From 2d75b3428c3b3ecb1604645ffbec342c4516b6f1 Mon Sep 17 00:00:00 2001 From: "xingjun.wang" Date: Mon, 21 Aug 2023 20:35:12 +0800 Subject: [PATCH] add slicing method for NativeIterableDataset --- modelscope/msdatasets/dataset_cls/dataset.py | 68 ++++++++++++++------ 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/modelscope/msdatasets/dataset_cls/dataset.py b/modelscope/msdatasets/dataset_cls/dataset.py index 48a5ab51..f9ffd9a7 100644 --- a/modelscope/msdatasets/dataset_cls/dataset.py +++ b/modelscope/msdatasets/dataset_cls/dataset.py @@ -1,7 +1,9 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import copy +import math import os +from itertools import islice import datasets import pandas as pd @@ -102,30 +104,60 @@ class NativeIterableDataset(IterableDataset): desc='Overall progress', total=self.n_shards, dynamic_ncols=True): - ret = {} - if isinstance(item, dict): - try: - for k, v in item.items(): - ret[k] = v - if k.endswith(':FILE'): - dl_manager = self._ex_iterable.kwargs.get( - 'dl_manager') - ex_cache_path = dl_manager.download_and_extract(v) - if isinstance(ex_cache_path, str): - ex_cache_path = [ex_cache_path] - ret[k] = ex_cache_path - - except Exception as e: - logger.error(e) - ret = item - else: - ret = item + ret = self._download_item(item) yield ret def __len__(self): return self.n_shards + def __getitem__(self, index): + """ + Returns the item at index `index` in the dataset. Slice indexing is supported. + """ + if isinstance(index, int): + start = index + stop = index + 1 + step = None + else: + start = index.start + stop = index.stop + step = index.step + + if step is not None and step <= 0: + raise ValueError('step must be positive') + + for item in tqdm( + islice( + self.iter(batch_size=1, drop_last_batch=False), start, + stop, step), + desc='Slicing progress', + dynamic_ncols=True): + ret = self._download_item(item) + + yield ret + + def _download_item(self, item): + ret = {} + if isinstance(item, dict): + try: + for k, v in item.items(): + ret[k] = v + if k.endswith(':FILE'): + dl_manager = self._ex_iterable.kwargs.get('dl_manager') + ex_cache_path = dl_manager.download_and_extract(v) + if isinstance(ex_cache_path, str): + ex_cache_path = [ex_cache_path] + ret[k] = ex_cache_path + + except Exception as e: + logger.error(e) + ret = item + else: + ret = item + + return ret + def head(self, n=5): """ Returns the first n rows of the dataset.