Files
modelscope/modelscope/utils/data_utils.py
yuze.zyz ca1321f53f Support trainer prediction and fix some bugs
1. Support trainer prediction
2. Fix bug in text classification metric
3. Move load checkpoint out of checkpointhook
4. Fix bug in train progressing (inner_iter variable not correct)

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11560269
2023-02-10 06:19:37 +00:00

38 lines
1.2 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
from collections.abc import Mapping
import torch
from modelscope.outputs import ModelOutputBase
def to_device(batch, device, non_blocking=False):
    """Put the data on the target device just before the forward function.

    Containers are traversed recursively:

    * ``ModelOutputBase`` instances and mappings that support item
      assignment are updated in place and the same object is returned, so
      any extra attributes attached to the batch survive (e.g. for
      prediction).
    * Mappings without ``__setitem__`` are rebuilt as ``type(batch)(dict)``.
    * Lists, tuples and named tuples are rebuilt with the same type.
    * Tensors are moved with ``torch.Tensor.to``.
    * Every other object is returned unchanged.

    Args:
        batch: The batch data out of the dataloader.
        device (str | torch.device): The target device for the data.
        non_blocking (bool): Forwarded to ``torch.Tensor.to`` for every
            tensor found in ``batch``, including tensors nested inside
            containers.

    Returns:
        The batch with all contained tensors moved to ``device``.
    """

    def _move(value):
        # Propagate ``non_blocking`` so nested tensors honour it too.
        return to_device(value, device, non_blocking=non_blocking)

    if isinstance(batch, ModelOutputBase):
        for idx in range(len(batch)):
            batch[idx] = _move(batch[idx])
        return batch
    if isinstance(batch, Mapping):  # ``dict`` is a ``Mapping`` as well
        if hasattr(batch, '__setitem__'):
            # Reuse mini-batch to keep attributes for prediction. Snapshot the
            # items first so writing values back cannot disturb iteration.
            for key, value in list(batch.items()):
                batch[key] = _move(value)
            return batch
        return type(batch)(
            {key: _move(value)
             for key, value in batch.items()})
    if isinstance(batch, tuple) and hasattr(batch, '_fields'):
        # namedtuple constructors take their fields positionally rather than
        # as a single iterable argument.
        return type(batch)(*(_move(value) for value in batch))
    if isinstance(batch, (tuple, list)):
        return type(batch)(_move(value) for value in batch)
    if isinstance(batch, torch.Tensor):
        return batch.to(device, non_blocking=non_blocking)
    return batch