[to #44842128] 修复MsDataset torch场景

1、to_torch_dataset时支持保留原数据类型
2、替换torch.utils.data.IterableDataset为torch.utils.data.Dataset,支持分布式训练和shuffle。后续等streaming数据加载方式支持后再引入IterableDataset
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10214102
This commit is contained in:
feiwu.yfw
2022-09-23 15:27:05 +08:00
committed by Yingda Chen
parent 9f0dde31d8
commit 3b6c3b723c

View File

@@ -42,44 +42,40 @@ def format_list(para) -> List:
return para
class MsMapDataset(torch.utils.data.Dataset):
    """Map-style torch dataset that wraps an indexable dataset.

    Each item is a dict built from the selected raw columns of the wrapped
    dataset, updated with the output fields of every preprocessor.

    Args:
        dataset: The wrapped dataset. It is indexed with ``dataset[index]``
            and measured with ``len(dataset)``; each item is used as a
            mapping from column name to value.
        preprocessor_list: Callables applied to the raw item. Each must
            return a dict; its fields are merged into the result in order.
        retained_columns: Field names that are kept when ``to_tensor`` is
            true. Ignored when ``to_tensor`` is false.
        columns: Raw column names copied from the wrapped item.
        to_tensor: If true, every kept value is passed to ``torch.tensor``
            and only fields in ``retained_columns`` are kept. If false,
            values are returned unchanged and no field is filtered out.
    """

    def __init__(self, dataset: Iterable, preprocessor_list, retained_columns,
                 columns, to_tensor):
        # The zero-argument form binds to this class and instance; the
        # previous ``super(MsDataset).__init__()`` named an unrelated class
        # and never ran the parent initializer.
        super().__init__()
        self.dataset = dataset
        self.preprocessor_list = preprocessor_list
        self.to_tensor = to_tensor
        self.retained_columns = retained_columns
        self.columns = columns

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.dataset)

    def type_converter(self, x):
        """Return ``torch.tensor(x)`` if ``to_tensor`` is set, else ``x``."""
        if self.to_tensor:
            return torch.tensor(x)
        return x

    def _is_kept(self, key):
        """Tell whether the field ``key`` belongs in the returned item."""
        # Without tensor conversion nothing is filtered out.
        return (not self.to_tensor) or key in self.retained_columns

    def __getitem__(self, index):
        """Build the item at ``index``.

        Returns:
            dict: Kept raw columns, overridden or extended by the kept
            output fields of each preprocessor, in preprocessor order.
        """
        item_dict = self.dataset[index]
        res = {
            k: self.type_converter(item_dict[k])
            for k in self.columns if self._is_kept(k)
        }
        for preprocessor in self.preprocessor_list:
            # Preprocessors always receive the raw item, not earlier output.
            res.update({
                k: self.type_converter(v)
                for k, v in preprocessor(item_dict).items()
                if self._is_kept(k)
            })
        return res
class MsDataset:
@@ -339,6 +335,7 @@ class MsDataset:
self,
preprocessors: Union[Callable, List[Callable]],
columns: Union[str, List[str]] = None,
to_tensor: bool = True,
):
preprocessor_list = preprocessors if isinstance(
preprocessors, list) else [preprocessors]
@@ -348,28 +345,29 @@ class MsDataset:
columns = [
key for key in self._hf_ds.features.keys() if key in columns
]
sample = next(iter(self._hf_ds))
sample_res = {k: np.array(sample[k]) for k in columns}
for processor in preprocessor_list:
sample_res.update(
{k: np.array(v)
for k, v in processor(sample).items()})
def is_numpy_number(value):
return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
value.dtype, np.floating)
retained_columns = []
for k in sample_res.keys():
if not is_numpy_number(sample_res[k]):
logger.warning(
f'Data of column {k} is non-numeric, will be removed')
continue
retained_columns.append(k)
if to_tensor:
sample = next(iter(self._hf_ds))
return MsIterableDataset(self._hf_ds, preprocessor_list,
retained_columns, columns)
sample_res = {k: np.array(sample[k]) for k in columns}
for processor in preprocessor_list:
sample_res.update(
{k: np.array(v)
for k, v in processor(sample).items()})
def is_numpy_number(value):
return np.issubdtype(value.dtype, np.integer) or np.issubdtype(
value.dtype, np.floating)
for k in sample_res.keys():
if not is_numpy_number(sample_res[k]):
logger.warning(
f'Data of column {k} is non-numeric, will be removed')
continue
retained_columns.append(k)
return MsMapDataset(self._hf_ds, preprocessor_list, retained_columns,
columns, to_tensor)
def to_torch_dataset(
self,
@@ -377,6 +375,7 @@ class MsDataset:
preprocessors: Union[Callable, List[Callable]] = None,
task_name: str = None,
task_data_config: ConfigDict = None,
to_tensor: bool = True,
**format_kwargs,
):
"""Create a torch.utils.data.Dataset from the MS Dataset. The torch.utils.data.Dataset can be passed to
@@ -384,13 +383,14 @@ class MsDataset:
Args:
preprocessors (Callable or List[Callable], default None): (list of) Preprocessor object used to process
every sample of the dataset. The output type of processors is dict, and each numeric field of the dict
every sample of the dataset. The output type of processors is dict, and each (numeric) field of the dict
will be used as a field of torch.utils.data.Dataset.
columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only). If the
preprocessor is None, the arg columns must have at least one column. If the `preprocessors` is not None,
the output fields of processors will also be added.
columns (str or List[str], default None): Dataset column(s) to be loaded (numeric data only if
`to_tensor` is True). If the preprocessor is None, the arg columns must have at least one column.
If the `preprocessors` is not None, the output fields of processors will also be added.
task_name (str, default None): task name, refer to :obj:`Tasks` for more details
task_data_config (ConfigDict, default None): config dict for model object.
to_tensor (bool, default True): whether to convert the data of the dataset column(s) to torch.tensor. If False, the values are returned with their original data types.
format_kwargs: A `dict` of arguments to be passed to the `torch.tensor`.
Returns:
@@ -407,7 +407,7 @@ class MsDataset:
return build_task_dataset(task_data_config, task_name)
if preprocessors is not None:
return self.to_torch_dataset_with_processors(
preprocessors, columns=columns)
preprocessors, columns=columns, to_tensor=to_tensor)
else:
self._hf_ds.reset_format()
self._hf_ds.set_format(