From b4b7e29b286e08495d0b84fab08ac0f43f995fa9 Mon Sep 17 00:00:00 2001 From: "xingjun.wang" Date: Sat, 20 Jul 2024 18:59:19 +0800 Subject: [PATCH] fix streaming for youku-mplug and adopt latest datasets --- modelscope/msdatasets/__init__.py | 2 +- modelscope/msdatasets/dataset_cls/dataset.py | 1 + .../msdatasets/download/download_config.py | 23 ++++++----- .../msdatasets/download/download_manager.py | 38 +++++++++++++++++++ requirements/framework.txt | 2 +- 5 files changed, 52 insertions(+), 14 deletions(-) diff --git a/modelscope/msdatasets/__init__.py b/modelscope/msdatasets/__init__.py index 70200e44..534a0500 100644 --- a/modelscope/msdatasets/__init__.py +++ b/modelscope/msdatasets/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) Alibaba, Inc. and its affiliates. -from .ms_dataset import MsDataset +from modelscope.msdatasets.ms_dataset import MsDataset diff --git a/modelscope/msdatasets/dataset_cls/dataset.py b/modelscope/msdatasets/dataset_cls/dataset.py index f9ffd9a7..9c1c7584 100644 --- a/modelscope/msdatasets/dataset_cls/dataset.py +++ b/modelscope/msdatasets/dataset_cls/dataset.py @@ -149,6 +149,7 @@ class NativeIterableDataset(IterableDataset): if isinstance(ex_cache_path, str): ex_cache_path = [ex_cache_path] ret[k] = ex_cache_path + ret[k.strip(':FILE')] = v except Exception as e: logger.error(e) diff --git a/modelscope/msdatasets/download/download_config.py b/modelscope/msdatasets/download/download_config.py index 11118f85..0fc8b5cf 100644 --- a/modelscope/msdatasets/download/download_config.py +++ b/modelscope/msdatasets/download/download_config.py @@ -6,16 +6,15 @@ from datasets.download.download_config import DownloadConfig class DataDownloadConfig(DownloadConfig): + """ + Extends `DownloadConfig` with additional attributes for data download. + """ - def __init__(self): - self.dataset_name: Optional[str] = None - self.namespace: Optional[str] = None - self.version: Optional[str] = None - self.split: Optional[Union[str, list]] = None - self.data_dir: Optional[str] = None - self.oss_config: Optional[dict] = {} - self.meta_args_map: Optional[dict] = {} - self.num_proc: int = 4 - - def copy(self) -> 'DataDownloadConfig': - return self + dataset_name: Optional[str] = None + namespace: Optional[str] = None + version: Optional[str] = None + split: Optional[Union[str, list]] = None + data_dir: Optional[str] = None + oss_config: Optional[dict] = {} + meta_args_map: Optional[dict] = {} + num_proc: int = 4 diff --git a/modelscope/msdatasets/download/download_manager.py b/modelscope/msdatasets/download/download_manager.py index 4799171a..d241b4fa 100644 --- a/modelscope/msdatasets/download/download_manager.py +++ b/modelscope/msdatasets/download/download_manager.py @@ -36,6 +36,26 @@ class DataDownloadManager(DownloadManager): return cached_path( url_or_filename, download_config=download_config) + def _download_single(self, url_or_filename: str, + download_config: DataDownloadConfig) -> str: + # Note: _download_single is adapted to the datasets>=2.19.0 + + url_or_filename = str(url_or_filename) + + oss_utilities = OssUtilities( + oss_config=download_config.oss_config, + dataset_name=download_config.dataset_name, + namespace=download_config.namespace, + revision=download_config.version) + + if is_relative_path(url_or_filename): + # fetch oss files + return oss_utilities.download( + url_or_filename, download_config=download_config) + else: + return cached_path( + url_or_filename, download_config=download_config) + class DataStreamingDownloadManager(StreamingDownloadManager): """The data streaming download manager.""" @@ -62,3 +82,21 @@ class DataStreamingDownloadManager(StreamingDownloadManager): else: return cached_path( url_or_filename, download_config=self.download_config) + + def _download_single(self, url_or_filename: str) -> str: + # Note: _download_single is adapted to the datasets>=2.19.0 + + url_or_filename = str(url_or_filename) + oss_utilities = OssUtilities( + oss_config=self.download_config.oss_config, + dataset_name=self.download_config.dataset_name, + namespace=self.download_config.namespace, + revision=self.download_config.version) + + if is_relative_path(url_or_filename): + # fetch oss files + return oss_utilities.download( + url_or_filename, download_config=self.download_config) + else: + return cached_path( + url_or_filename, download_config=self.download_config) diff --git a/requirements/framework.txt b/requirements/framework.txt index d6317bf2..21fda818 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=2.16.0,<2.19.0 +datasets>=2.19.0,<2.21.0 einops oss2 python-dateutil>=2.1