mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
Merge pull request #834 from modelscope/fix/datasets_hf_utils
Fix datasets hf utils
This commit is contained in:
@@ -21,8 +21,7 @@ from modelscope.msdatasets.dataset_cls import (ExternalDataset,
|
||||
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
|
||||
build_custom_dataset
|
||||
from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
|
||||
from modelscope.msdatasets.utils.hf_datasets_util import \
|
||||
load_dataset as hf_load_dataset_wrapper
|
||||
from modelscope.msdatasets.utils.hf_datasets_util import load_dataset_with_ctx
|
||||
from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
|
||||
from modelscope.preprocessors import build_preprocessor
|
||||
from modelscope.utils.config import Config, ConfigDict
|
||||
@@ -293,21 +292,25 @@ class MsDataset:
|
||||
|
||||
# Load from the ModelScope Hub for type=4 (general)
|
||||
if str(dataset_type) == str(DatasetFormations.general.value):
|
||||
return hf_load_dataset_wrapper(
|
||||
path=namespace + '/' + dataset_name,
|
||||
name=subset_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=split,
|
||||
cache_dir=cache_dir,
|
||||
features=None,
|
||||
download_config=None,
|
||||
download_mode=download_mode.value,
|
||||
revision=version,
|
||||
token=token,
|
||||
streaming=use_streaming,
|
||||
dataset_info_only=dataset_info_only,
|
||||
**config_kwargs)
|
||||
|
||||
with load_dataset_with_ctx(
|
||||
path=namespace + '/' + dataset_name,
|
||||
name=subset_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=split,
|
||||
cache_dir=cache_dir,
|
||||
features=None,
|
||||
download_config=None,
|
||||
download_mode=download_mode.value,
|
||||
revision=version,
|
||||
token=token,
|
||||
streaming=use_streaming,
|
||||
dataset_info_only=dataset_info_only,
|
||||
**config_kwargs) as dataset_res:
|
||||
|
||||
return dataset_res
|
||||
|
||||
else:
|
||||
|
||||
remote_dataloader_manager = RemoteDataLoaderManager(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
|
||||
import importlib
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
@@ -52,7 +53,7 @@ from datasets.utils.track import tracked_str
|
||||
from fsspec import filesystem
|
||||
from fsspec.core import _un_chain
|
||||
from fsspec.utils import stringify_path
|
||||
from huggingface_hub import (DatasetCard, DatasetCardData, HfFileSystem)
|
||||
from huggingface_hub import (DatasetCard, DatasetCardData)
|
||||
from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
|
||||
from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
|
||||
from packaging import version
|
||||
@@ -66,14 +67,8 @@ from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
config.HF_ENDPOINT = get_endpoint()
|
||||
|
||||
|
||||
file_utils.get_from_cache = get_from_cache_ms
|
||||
|
||||
|
||||
def _download(self, url_or_filename: str,
|
||||
download_config: DownloadConfig) -> str:
|
||||
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
|
||||
url_or_filename = str(url_or_filename)
|
||||
# for temp val
|
||||
revision = None
|
||||
@@ -94,9 +89,6 @@ def _download(self, url_or_filename: str,
|
||||
return out
|
||||
|
||||
|
||||
DownloadManager._download = _download
|
||||
|
||||
|
||||
def _dataset_info(
|
||||
self,
|
||||
repo_id: str,
|
||||
@@ -193,9 +185,6 @@ def _dataset_info(
|
||||
return HfDatasetInfo(**data)
|
||||
|
||||
|
||||
HfApi.dataset_info = _dataset_info
|
||||
|
||||
|
||||
def _list_repo_tree(
|
||||
self,
|
||||
repo_id: str,
|
||||
@@ -244,9 +233,6 @@ def _list_repo_tree(
|
||||
**path_info)
|
||||
|
||||
|
||||
HfApi.list_repo_tree = _list_repo_tree
|
||||
|
||||
|
||||
def _get_paths_info(
|
||||
self,
|
||||
repo_id: str,
|
||||
@@ -282,9 +268,6 @@ def _get_paths_info(
|
||||
]
|
||||
|
||||
|
||||
HfApi.get_paths_info = _get_paths_info
|
||||
|
||||
|
||||
def get_fs_token_paths(
|
||||
urlpath,
|
||||
storage_options=None,
|
||||
@@ -420,9 +403,6 @@ def _resolve_pattern(
|
||||
return out
|
||||
|
||||
|
||||
data_files.resolve_pattern = _resolve_pattern
|
||||
|
||||
|
||||
def _get_data_patterns(
|
||||
base_path: str,
|
||||
download_config: Optional[DownloadConfig] = None) -> Dict[str,
|
||||
@@ -668,9 +648,6 @@ def get_module_without_script(self) -> DatasetModule:
|
||||
)
|
||||
|
||||
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
|
||||
|
||||
|
||||
def _download_additional_modules(
|
||||
name: str,
|
||||
dataset_name: str,
|
||||
@@ -863,9 +840,6 @@ def get_module_with_script(self) -> DatasetModule:
|
||||
return DatasetModule(module_path, hash, builder_kwargs)
|
||||
|
||||
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
|
||||
|
||||
|
||||
class DatasetsWrapperHF:
|
||||
|
||||
@staticmethod
|
||||
@@ -1336,4 +1310,40 @@ class DatasetsWrapperHF:
|
||||
f'any data file in the same directory.')
|
||||
|
||||
|
||||
load_dataset = DatasetsWrapperHF.load_dataset
|
||||
@contextlib.contextmanager
|
||||
def load_dataset_with_ctx(*args, **kwargs):
|
||||
hf_endpoint_origin = config.HF_ENDPOINT
|
||||
get_from_cache_origin = file_utils.get_from_cache
|
||||
_download_origin = DownloadManager._download
|
||||
dataset_info_origin = HfApi.dataset_info
|
||||
list_repo_tree_origin = HfApi.list_repo_tree
|
||||
get_paths_info_origin = HfApi.get_paths_info
|
||||
resolve_pattern_origin = data_files.resolve_pattern
|
||||
get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
|
||||
get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
|
||||
|
||||
config.HF_ENDPOINT = get_endpoint()
|
||||
file_utils.get_from_cache = get_from_cache_ms
|
||||
DownloadManager._download = _download_ms
|
||||
HfApi.dataset_info = _dataset_info
|
||||
HfApi.list_repo_tree = _list_repo_tree
|
||||
HfApi.get_paths_info = _get_paths_info
|
||||
data_files.resolve_pattern = _resolve_pattern
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
|
||||
|
||||
try:
|
||||
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
|
||||
yield dataset_res
|
||||
finally:
|
||||
config.HF_ENDPOINT = hf_endpoint_origin
|
||||
file_utils.get_from_cache = get_from_cache_origin
|
||||
DownloadManager._download = _download_origin
|
||||
HfApi.dataset_info = dataset_info_origin
|
||||
HfApi.list_repo_tree = list_repo_tree_origin
|
||||
HfApi.get_paths_info = get_paths_info_origin
|
||||
data_files.resolve_pattern = resolve_pattern_origin
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
|
||||
|
||||
logger.info('Context manager of ms-dataset exited.')
|
||||
|
||||
@@ -20,7 +20,7 @@ from filelock import FileLock
|
||||
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=2.14.5
|
||||
datasets>=2.16.0,<2.19.0
|
||||
einops
|
||||
filelock>=3.3.0
|
||||
gast>=0.2.2
|
||||
|
||||
Reference in New Issue
Block a user