From 8d0d6252ca0c3c36606dee148bb3f5a3d76594b1 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Sat, 16 Jul 2022 10:21:43 +0800
Subject: [PATCH] [to #42322933] fix bug: run failed in tensor_utils.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

1. Fix the default data collator failing to run when the input type is tuple
2. Fix the default data collator's dict branch being incompatible with BatchEncoding

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9403517

* fix bug:
  1. run failed when the data type is tuple
  2. change type checking from dict to Mapping to fit transformers.datasets.BatchEncoding
---
 modelscope/utils/tensor_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/modelscope/utils/tensor_utils.py b/modelscope/utils/tensor_utils.py
index dc38f9bf..a80ca6cd 100644
--- a/modelscope/utils/tensor_utils.py
+++ b/modelscope/utils/tensor_utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from huggingface/transformers.
+from collections.abc import Mapping
 
 
 def torch_nested_numpify(tensors):
@@ -31,7 +32,7 @@ def torch_default_data_collator(features):
     #     features = [vars(f) for f in features]
     first = features[0]
-    if isinstance(first, dict):
+    if isinstance(first, Mapping):
         batch = {}
 
         # Special handling for labels.
         # Ensure that tensor is created with the correct type
@@ -65,9 +66,9 @@ def torch_default_data_collator(features):
             batch = []
             for idx in range(len(first)):
                 if isinstance(first[idx], torch.Tensor):
-                    batch.append(torch.stack([f[k] for f in features]))
+                    batch.append(torch.stack([f[idx] for f in features]))
                 else:
-                    batch.append(torch.tensor([f[k] for f in features]))
+                    batch.append(torch.tensor([f[idx] for f in features]))
     else:
         if isinstance(first, torch.Tensor):
             batch = torch.stack(features)
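
Note: below is a minimal, runnable sketch of the fixed tuple/list branch. The wrapper function collate_sequence_features and the sample data are hypothetical, not part of modelscope; the full torch_default_data_collator also handles Mapping inputs (including transformers.BatchEncoding, which subclasses UserDict and therefore passes the collections.abc.Mapping check even though it is not a dict) and label casting, which are omitted here.

# Sketch only: mirrors the tuple/list branch of torch_default_data_collator
# after this patch; the names below are illustrative, not from modelscope.
import torch


def collate_sequence_features(features):
    """Collate a batch of tuple/list samples field by field."""
    first = features[0]
    batch = []
    for idx in range(len(first)):
        if isinstance(first[idx], torch.Tensor):
            # Index every sample with `idx`; the pre-patch code indexed with
            # `k`, a name left over from the dict branch, which is why tuple
            # inputs failed.
            batch.append(torch.stack([f[idx] for f in features]))
        else:
            batch.append(torch.tensor([f[idx] for f in features]))
    return batch


# Usage: two (feature tensor, label) pairs collate into a [2, 3] feature
# tensor and a label tensor([0, 1]).
samples = [(torch.ones(3), 0), (torch.zeros(3), 1)]
stacked, labels = collate_sequence_features(samples)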