add slicing method for NativeIterableDataset (#490)

This commit is contained in:
Xingjun.Wang
2023-08-22 09:29:51 +08:00
committed by GitHub
parent 040698e201
commit 82902b8758

View File

@@ -1,7 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import os
from itertools import islice
import datasets
import pandas as pd
@@ -102,30 +104,60 @@ class NativeIterableDataset(IterableDataset):
desc='Overall progress',
total=self.n_shards,
dynamic_ncols=True):
ret = {}
if isinstance(item, dict):
try:
for k, v in item.items():
ret[k] = v
if k.endswith(':FILE'):
dl_manager = self._ex_iterable.kwargs.get(
'dl_manager')
ex_cache_path = dl_manager.download_and_extract(v)
if isinstance(ex_cache_path, str):
ex_cache_path = [ex_cache_path]
ret[k] = ex_cache_path
except Exception as e:
logger.error(e)
ret = item
else:
ret = item
ret = self._download_item(item)
yield ret
def __len__(self):
    """
    Report the length of the dataset, defined here as `self.n_shards`.
    """
    shard_count = self.n_shards
    return shard_count
def __getitem__(self, index):
    """
    Lazily select items by position. Integer and slice indexing are supported.

    The result is always a generator, even for a single integer index:
    iterate it (or call `next` on it) to obtain the selected item(s). The
    sequence produced by `self.iter(batch_size=1, drop_last_batch=False)` is
    consumed from its beginning through `itertools.islice`, and every
    selected item is passed through `self._download_item` before being
    yielded.

    Args:
        index: A non-negative int, or an object exposing `start`, `stop`
            and `step` attributes (e.g. a `slice`). `None` bounds are
            forwarded to `itertools.islice` unchanged.

    Returns:
        A generator over the selected, processed items.

    Raises:
        ValueError: If `step` is not positive, or if an integer `start` or
            `stop` is negative. Raised when this method is called, not on
            first iteration.
    """
    if isinstance(index, int):
        start, stop, step = index, index + 1, None
    else:
        start, stop, step = index.start, index.stop, index.step

    # Validate eagerly: if this method were itself a generator function, the
    # checks below would not run until the caller started iterating.
    if step is not None and step <= 0:
        raise ValueError('step must be positive')
    for bound in (start, stop):
        # islice cannot count from the end of a stream; reject negative
        # positions here with a clear message instead of failing later.
        if isinstance(bound, int) and bound < 0:
            raise ValueError('negative indices are not supported')

    def _generate():
        # The progress bar and the underlying iterator are only created once
        # the caller actually starts consuming the result.
        for item in tqdm(
                islice(
                    self.iter(batch_size=1, drop_last_batch=False), start,
                    stop, step),
                desc='Slicing progress',
                dynamic_ncols=True):
            yield self._download_item(item)

    return _generate()
def _download_item(self, item):
    """
    Resolve the `:FILE` fields of a single item.

    Non-dict items are returned unchanged. For a dict, a new dict is built
    with the same keys in the same order; the value of every key ending in
    `:FILE` is replaced by the result of `dl_manager.download_and_extract`
    (wrapped in a one-element list when that result is a single string).
    The `dl_manager` is read from `self._ex_iterable.kwargs`.

    This is best-effort: if anything fails while processing the dict, the
    error is logged and the original item is returned untouched.
    """
    if not isinstance(item, dict):
        return item

    resolved = {}
    try:
        for key, value in item.items():
            if not key.endswith(':FILE'):
                resolved[key] = value
                continue
            manager = self._ex_iterable.kwargs.get('dl_manager')
            local_path = manager.download_and_extract(value)
            resolved[key] = ([local_path]
                             if isinstance(local_path, str) else local_path)
    except Exception as e:
        # Discard any partial result and fall back to the raw item.
        logger.error(e)
        return item
    return resolved
def head(self, n=5):
"""
Returns the first n rows of the dataset.