diff --git a/modelscope/utils/tensor_utils.py b/modelscope/utils/tensor_utils.py index dc38f9bf..a80ca6cd 100644 --- a/modelscope/utils/tensor_utils.py +++ b/modelscope/utils/tensor_utils.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. # Part of the implementation is borrowed from huggingface/transformers. +from collections.abc import Mapping def torch_nested_numpify(tensors): @@ -31,7 +32,7 @@ def torch_default_data_collator(features): # features = [vars(f) for f in features] first = features[0] - if isinstance(first, dict): + if isinstance(first, Mapping): batch = {} # Special handling for labels. # Ensure that tensor is created with the correct type @@ -65,9 +66,9 @@ def torch_default_data_collator(features): batch = [] for idx in range(len(first)): if isinstance(first[idx], torch.Tensor): - batch.append(torch.stack([f[k] for f in features])) + batch.append(torch.stack([f[idx] for f in features])) else: - batch.append(torch.tensor([f[k] for f in features])) + batch.append(torch.tensor([f[idx] for f in features])) else: if isinstance(first, torch.Tensor): batch = torch.stack(features)