mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
[to #44842128] 修复MsDataset torch场景
1、to_torch_dataset时支持保留原数据类型
2、替换torch.utils.data.IterableDataset为torch.utils.data.Dataset,支持分布式训练和shuffle。后续等streaming数据加载方式支持后再引入IterableDataset
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10214102
This commit is contained in:
@@ -42,44 +42,40 @@ def format_list(para) -> List:
|
||||
return para
|
||||
|
||||
|
||||
class MsIterableDataset(torch.utils.data.IterableDataset):
|
||||
class MsMapDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
|
||||
columns):
|
||||
super(MsIterableDataset).__init__()
|
||||
columns, to_tensor):
|
||||
super(MsDataset).__init__()
|
||||
self.dataset = dataset
|
||||
self.preprocessor_list = preprocessor_list
|
||||
self.to_tensor = to_tensor
|
||||
self.retained_columns = retained_columns
|
||||
self.columns = columns
|
||||
|
||||
def __len__(self):
|
||||
return len(self.dataset)
|
||||
|
||||
def __iter__(self):
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is None: # single-process data loading
|
||||
iter_start = 0
|
||||
iter_end = len(self.dataset)
|
||||
else: # in a worker process
|
||||
per_worker = math.ceil(
|
||||
len(self.dataset) / float(worker_info.num_workers))
|
||||
worker_id = worker_info.id
|
||||
iter_start = worker_id * per_worker
|
||||
iter_end = min(iter_start + per_worker, len(self.dataset))
|
||||
def type_converter(self, x):
|
||||
if self.to_tensor:
|
||||
return torch.tensor(x)
|
||||
else:
|
||||
return x
|
||||
|
||||
for idx in range(iter_start, iter_end):
|
||||
item_dict = self.dataset[idx]
|
||||
res = {
|
||||
k: torch.tensor(item_dict[k])
|
||||
for k in self.columns if k in self.retained_columns
|
||||
}
|
||||
for preprocessor in self.preprocessor_list:
|
||||
res.update({
|
||||
k: torch.tensor(v)
|
||||
for k, v in preprocessor(item_dict).items()
|
||||
if k in self.retained_columns
|
||||
})
|
||||
yield res
|
||||
def __getitem__(self, index):
|
||||
item_dict = self.dataset[index]
|
||||
res = {
|
||||
k: self.type_converter(item_dict[k])
|
||||
for k in self.columns
|
||||
if (not self.to_tensor) or k in self.retained_columns
|
||||
}
|
||||
for preprocessor in self.preprocessor_list:
|
||||
res.update({
|
||||
k: self.type_converter(v)
|
||||
for k, v in preprocessor(item_dict).items()
|
||||
if (not self.to_tensor) or k in self.retained_columns
|
||||
})
|
||||
return res
|
||||
|
||||
|
||||
class MsDataset:
|
||||
@@ -339,6 +335,7 @@ class MsDataset:
|
||||
self,
|
||||
preprocessors: Union[Callable, List[Callable]],
|
||||
columns: Union[str, List[str]] = None,
|
||||
to_tensor: bool = True,
|
||||
):
|
||||
preprocessor_list = preprocessors if isinstance(
|
||||
preprocessors, list) else [preprocessors]
|
||||
@@ -348,28 +345,29 @@ class MsDataset:
|
||||
columns = [
|
||||
key for key in self._hf_ds.features.keys() if key in columns
|
||||
]
|
||||
sample = next(iter(self._hf_ds))
|
||||
|
||||
sample_res = {k: np.array(sample[k]) for k in columns}
|
||||
for processor in preprocessor_list:
|
||||
sample_res.update(
|
||||
{k: np.array(v)
|
||||
for k, v in processor(sample).items()})
|
||||
|
||||
def is_numpy_number(value):
|
||||
return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
|
||||
value.dtype, np.floating)
|
||||
|
||||
retained_columns = []
|
||||
for k in sample_res.keys():
|
||||
if not is_numpy_number(sample_res[k]):
|
||||
logger.warning(
|
||||
f'Data of column {k} is non-numeric, will be removed')
|
||||
continue
|
||||
retained_columns.append(k)
|
||||
if to_tensor:
|
||||
sample = next(iter(self._hf_ds))
|
||||
|
||||
return MsIterableDataset(self._hf_ds, preprocessor_list,
|
||||
retained_columns, columns)
|
||||
sample_res = {k: np.array(sample[k]) for k in columns}
|
||||
for processor in preprocessor_list:
|
||||
sample_res.update(
|
||||
{k: np.array(v)
|
||||
for k, v in processor(sample).items()})
|
||||
|
||||
def is_numpy_number(value):
|
||||
return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
|
||||
value.dtype, np.floating)
|
||||
|
||||
for k in sample_res.keys():
|
||||
if not is_numpy_number(sample_res[k]):
|
||||
logger.warning(
|
||||
f'Data of column {k} is non-numeric, will be removed')
|
||||
continue
|
||||
retained_columns.append(k)
|
||||
|
||||
return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns,
|
||||
columns, to_tensor)
|
||||
|
||||
def to_torch_dataset(
|
||||
self,
|
||||
@@ -377,6 +375,7 @@ class MsDataset:
|
||||
preprocessors: Union[Callable, List[Callable]] = None,
|
||||
task_name: str = None,
|
||||
task_data_config: ConfigDict = None,
|
||||
to_tensor: bool = True,
|
||||
**format_kwargs,
|
||||
):
|
||||
"""Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
|
||||
@@ -384,13 +383,14 @@ class MsDataset:
|
||||
|
||||
Args:
|
||||
preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
|
||||
every sample of the dataset. The output type of processors is dict, and each numeric field of the dict
|
||||
every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict
|
||||
will be used as a field of torch.utils.data.Dataset.
|
||||
columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only). If the
|
||||
preprocessor is None, the arg columns must have at least one column. If the `preprocessors` is not None,
|
||||
the output fields of processors will also be added.
|
||||
columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if
|
||||
`to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column.
|
||||
If the `preprocessors` is not None, the output fields of processors will also be added.
|
||||
task_name (str, default None): task name, refer to :obj:`Tasks` for more details
|
||||
task_data_config (ConfigDict, default None): config dict for model object.
|
||||
to_tensor (bool, default None): whether convert the data types of dataset column(s) to torch.tensor or not.
|
||||
format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`.
|
||||
|
||||
Returns:
|
||||
@@ -407,7 +407,7 @@ class MsDataset:
|
||||
return build_task_dataset(task_data_config, task_name)
|
||||
if preprocessors is not None:
|
||||
return self.to_torch_dataset_with_processors(
|
||||
preprocessors, columns=columns)
|
||||
preprocessors, columns=columns, to_tensor=to_tensor)
|
||||
else:
|
||||
self._hf_ds.reset_format()
|
||||
self._hf_ds.set_format(
|
||||
|
||||
Reference in New Issue
Block a user