[to #42322933] fix bug: default data collator fails at runtime in tensor_utils.py

1. 修复default data collator的输入类型为tuple时运行会失败的问题
2. 修复default data collator的输入类型为dict时不兼容BatchEncoding的问题
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9403517

    * fix bug: 1. fix the runtime failure when the input data type is tuple 2. change the type check from dict to Mapping so that transformers.BatchEncoding inputs are accepted
This commit is contained in:
yuze.zyz
2022-07-16 10:21:43 +08:00
parent 231f400133
commit 8d0d6252ca

View File

@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
from collections.abc import Mapping
def torch_nested_numpify(tensors):
@@ -31,7 +32,7 @@ def torch_default_data_collator(features):
# features = [vars(f) for f in features]
first = features[0]
if isinstance(first, dict):
if isinstance(first, Mapping):
batch = {}
# Special handling for labels.
# Ensure that tensor is created with the correct type
@@ -65,9 +66,9 @@ def torch_default_data_collator(features):
batch = []
for idx in range(len(first)):
if isinstance(first[idx], torch.Tensor):
batch.append(torch.stack([f[k] for f in features]))
batch.append(torch.stack([f[idx] for f in features]))
else:
batch.append(torch.tensor([f[k] for f in features]))
batch.append(torch.tensor([f[idx] for f in features]))
else:
if isinstance(first, torch.Tensor):
batch = torch.stack(features)