mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
[to #42322933] fix bug: run failed in tensor_utils.py
1. 修复default data collator的输入类型为tuple时运行会失败的问题
2. 修复default data collator的输入类型为dict时不兼容BatchEncoding的问题
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9403517
* fix bugs: 1. fix a failure when the input type is tuple; 2. change the type check from dict to collections.abc.Mapping so that transformers' BatchEncoding (from transformers.tokenization_utils_base) is also accepted
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from huggingface/transformers.
|
||||
from collections.abc import Mapping
|
||||
|
||||
|
||||
def torch_nested_numpify(tensors):
|
||||
@@ -31,7 +32,7 @@ def torch_default_data_collator(features):
|
||||
# features = [vars(f) for f in features]
|
||||
first = features[0]
|
||||
|
||||
if isinstance(first, dict):
|
||||
if isinstance(first, Mapping):
|
||||
batch = {}
|
||||
# Special handling for labels.
|
||||
# Ensure that tensor is created with the correct type
|
||||
@@ -65,9 +66,9 @@ def torch_default_data_collator(features):
|
||||
batch = []
|
||||
for idx in range(len(first)):
|
||||
if isinstance(first[idx], torch.Tensor):
|
||||
batch.append(torch.stack([f[k] for f in features]))
|
||||
batch.append(torch.stack([f[idx] for f in features]))
|
||||
else:
|
||||
batch.append(torch.tensor([f[k] for f in features]))
|
||||
batch.append(torch.tensor([f[idx] for f in features]))
|
||||
else:
|
||||
if isinstance(first, torch.Tensor):
|
||||
batch = torch.stack(features)
|
||||
|
||||
Reference in New Issue
Block a user