mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
add slicing method for NativeIterableDataset (#490)
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
from itertools import islice
|
||||
|
||||
import datasets
|
||||
import pandas as pd
|
||||
@@ -102,30 +104,60 @@ class NativeIterableDataset(IterableDataset):
|
||||
desc='Overall progress',
|
||||
total=self.n_shards,
|
||||
dynamic_ncols=True):
|
||||
ret = {}
|
||||
if isinstance(item, dict):
|
||||
try:
|
||||
for k, v in item.items():
|
||||
ret[k] = v
|
||||
if k.endswith(':FILE'):
|
||||
dl_manager = self._ex_iterable.kwargs.get(
|
||||
'dl_manager')
|
||||
ex_cache_path = dl_manager.download_and_extract(v)
|
||||
if isinstance(ex_cache_path, str):
|
||||
ex_cache_path = [ex_cache_path]
|
||||
ret[k] = ex_cache_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
ret = item
|
||||
else:
|
||||
ret = item
|
||||
ret = self._download_item(item)
|
||||
|
||||
yield ret
|
||||
|
||||
def __len__(self):
|
||||
return self.n_shards
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Returns the item at index `index` in the dataset. Slice indexing is supported.
|
||||
"""
|
||||
if isinstance(index, int):
|
||||
start = index
|
||||
stop = index + 1
|
||||
step = None
|
||||
else:
|
||||
start = index.start
|
||||
stop = index.stop
|
||||
step = index.step
|
||||
|
||||
if step is not None and step <= 0:
|
||||
raise ValueError('step must be positive')
|
||||
|
||||
for item in tqdm(
|
||||
islice(
|
||||
self.iter(batch_size=1, drop_last_batch=False), start,
|
||||
stop, step),
|
||||
desc='Slicing progress',
|
||||
dynamic_ncols=True):
|
||||
ret = self._download_item(item)
|
||||
|
||||
yield ret
|
||||
|
||||
def _download_item(self, item):
|
||||
ret = {}
|
||||
if isinstance(item, dict):
|
||||
try:
|
||||
for k, v in item.items():
|
||||
ret[k] = v
|
||||
if k.endswith(':FILE'):
|
||||
dl_manager = self._ex_iterable.kwargs.get('dl_manager')
|
||||
ex_cache_path = dl_manager.download_and_extract(v)
|
||||
if isinstance(ex_cache_path, str):
|
||||
ex_cache_path = [ex_cache_path]
|
||||
ret[k] = ex_cache_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
ret = item
|
||||
else:
|
||||
ret = item
|
||||
|
||||
return ret
|
||||
|
||||
def head(self, n=5):
|
||||
"""
|
||||
Returns the first n rows of the dataset.
|
||||
|
||||
Reference in New Issue
Block a user