This commit is contained in:
xingjun.wang
2024-08-05 19:20:25 +08:00
parent 234729b2f8
commit 7151ebc761

View File

@@ -857,6 +857,13 @@ def get_module_with_script(self) -> DatasetModule:
return DatasetModule(module_path, hash, builder_kwargs)
def increase_load_count_ms(name: str, resource_type: str):
"""
Placeholder for increasing the load count of a hf resource.
"""
...
class DatasetsWrapperHF:
@staticmethod
@@ -1330,6 +1337,8 @@ class DatasetsWrapperHF:
@contextlib.contextmanager
def load_dataset_with_ctx(*args, **kwargs):
from datasets.load import increase_load_count as increase_load_count
hf_endpoint_origin = config.HF_ENDPOINT
get_from_cache_origin = file_utils.get_from_cache
@@ -1337,6 +1346,8 @@ def load_dataset_with_ctx(*args, **kwargs):
_download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
else DownloadManager._download_single
increase_load_count_origin = increase_load_count
dataset_info_origin = HfApi.dataset_info
list_repo_tree_origin = HfApi.list_repo_tree
get_paths_info_origin = HfApi.get_paths_info
@@ -1359,6 +1370,7 @@ def load_dataset_with_ctx(*args, **kwargs):
data_files.resolve_pattern = _resolve_pattern
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
increase_load_count = increase_load_count_ms
try:
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
@@ -1379,3 +1391,4 @@ def load_dataset_with_ctx(*args, **kwargs):
data_files.resolve_pattern = resolve_pattern_origin
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
increase_load_count = increase_load_count_origin