Dataset refactor (#807)

* add main entry in ms_dataset

* update func get_data_patterns import

* modify return_config_only

* rename return_config_only to dataset_info_only

* update version for test

* del get_logger(__name__)

* fix py script loading

* fix loading with and without py script

* add subset support

* add hf_datasets_util; refine list_repo_tree_ms; fix private datasets loading issue

* update version to rc5

* fix and support preview for dataset_info_only mode

* fix urlencode

* update to rc7

* deprecate loading of dataset_infos.json; add some unit tests

* update version

* add escapechar for read_csv and to_csv

* add params: Source=SDK

* add create_dataset func

* overwrite _get_paths_info

* update & bump version

* update list_repo_tree name

* add get_module_with_script, fix download imports

* fix py script loading issue in dataset_module_factory

* fix create dataset

* update log info in api
Author: Xingjun.Wang
Date: 2024-03-22 17:30:34 +08:00 (committed by GitHub)
Parent: 9d2c2708ff
Commit: 1a66f069c4
10 changed files with 1873 additions and 25 deletions
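For orientation, a minimal usage sketch of the headline change in this refactor: datasets whose Hub type is 4 ("general") are now loaded through a Hugging Face datasets-compatible wrapper instead of the legacy MS_DATA_LOADER path. The dataset id below is a placeholder, not one of the repos touched in this commit.

    # Hypothetical example; assumes an existing public general-format dataset repo.
    from modelscope import MsDataset

    ds = MsDataset.load('some_namespace/some_general_dataset',
                        subset_name='default', split='train')
    print(next(iter(ds)))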


@@ -15,7 +15,9 @@ from http import HTTPStatus
 from http.cookiejar import CookieJar
 from os.path import expanduser
 from typing import Dict, List, Optional, Tuple, Union
+from urllib.parse import urlencode

+import json
 import pandas as pd
 import requests
 from requests import Session
@@ -31,7 +33,8 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
                                       MODELSCOPE_CLOUD_ENVIRONMENT,
                                       MODELSCOPE_CLOUD_USERNAME,
                                       MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
-                                      REQUESTS_API_HTTP_METHOD, Licenses,
+                                      REQUESTS_API_HTTP_METHOD,
+                                      DatasetVisibility, Licenses,
                                       ModelVisibility)
 from modelscope.hub.errors import (InvalidParameter, NotExistError,
                                    NotLoginException, NoValidRevisionError,
@@ -647,6 +650,44 @@ class HubApi:
             files.append(file)
         return files

+    def create_dataset(self,
+                       dataset_name: str,
+                       namespace: str,
+                       chinese_name: Optional[str] = '',
+                       license: Optional[str] = Licenses.APACHE_V2,
+                       visibility: Optional[int] = DatasetVisibility.PUBLIC,
+                       description: Optional[str] = '') -> str:
+
+        if dataset_name is None or namespace is None:
+            raise InvalidParameter('dataset_name and namespace are required!')
+
+        cookies = ModelScopeConfig.get_cookies()
+        if cookies is None:
+            raise ValueError('Token does not exist, please login first.')
+
+        path = f'{self.endpoint}/api/v1/datasets'
+        files = {
+            'Name': (None, dataset_name),
+            'ChineseName': (None, chinese_name),
+            'Owner': (None, namespace),
+            'License': (None, license),
+            'Visibility': (None, visibility),
+            'Description': (None, description)
+        }
+        r = self.session.post(
+            path,
+            files=files,
+            cookies=cookies,
+            headers=self.builder_headers(self.headers),
+        )
+        handle_http_post_error(r, path, files)
+        raise_on_error(r.json())
+
+        dataset_repo_url = f'{self.endpoint}/datasets/{namespace}/{dataset_name}'
+        logger.info(f'Create dataset success: {dataset_repo_url}')
+        return dataset_repo_url
+
     def list_datasets(self):
         path = f'{self.endpoint}/api/v1/datasets'
         params = {}
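A minimal sketch of calling the new create_dataset API. It assumes you have already logged in via HubApi.login with a valid SDK token (create_dataset reads the cookies saved by login); the token, namespace and dataset name are placeholders.

    from modelscope.hub.api import HubApi

    api = HubApi()
    api.login('<your-sdk-access-token>')   # placeholder token; saves cookies locally
    repo_url = api.create_dataset(
        dataset_name='my_new_dataset',     # placeholder
        namespace='my_namespace',          # placeholder
        description='Created via SDK')
    # License defaults to Licenses.APACHE_V2, visibility to DatasetVisibility.PUBLIC.
    print(repo_url)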
@@ -667,6 +708,47 @@ class HubApi:
         dataset_type = resp['Data']['Type']
         return dataset_id, dataset_type

+    def get_dataset_infos(self,
+                          dataset_hub_id: str,
+                          revision: str,
+                          files_metadata: bool = False,
+                          timeout: float = 100,
+                          recursive: str = 'True'):
+        """
+        Get dataset infos.
+        """
+        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+        params = {'Revision': revision, 'Root': None, 'Recursive': recursive}
+        cookies = ModelScopeConfig.get_cookies()
+        if files_metadata:
+            params['blobs'] = True
+
+        r = self.session.get(datahub_url, params=params, cookies=cookies, timeout=timeout)
+        resp = r.json()
+        datahub_raise_on_error(datahub_url, resp, r)
+
+        return resp
+
+    def list_repo_tree(self,
+                       dataset_name: str,
+                       namespace: str,
+                       revision: str,
+                       root_path: str,
+                       recursive: bool = True):
+
+        dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
+            dataset_name=dataset_name, namespace=namespace)
+
+        recursive = 'True' if recursive else 'False'
+        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+        params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive}
+        cookies = ModelScopeConfig.get_cookies()
+
+        r = self.session.get(datahub_url, params=params, cookies=cookies)
+        resp = r.json()
+        datahub_raise_on_error(datahub_url, resp, r)
+
+        return resp
+
     def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, dataset_id: str, revision: str):
         """ Get the meta file-list of the dataset. """
         datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
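A hedged sketch of how the new repo-tree helper might be exercised; the dataset name and namespace are placeholders, and the return value is the raw JSON payload from the Hub (its exact schema is defined server-side).

    from modelscope.hub.api import HubApi

    api = HubApi()
    tree = api.list_repo_tree(dataset_name='my_dataset',   # placeholder
                              namespace='my_namespace',    # placeholder
                              revision='master',
                              root_path='/',
                              recursive=True)
    print(tree)  # raw dict returned by the repo/tree endpoint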
@@ -735,7 +817,6 @@
         Fetch the meta-data files from the url, e.g. csv/jsonl files.
         """
         import hashlib
-        import json
         from tqdm import tqdm

         out_path = os.path.join(out_path, hashlib.md5(url.encode(encoding='UTF-8')).hexdigest())
         if mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(out_path):
@@ -774,7 +855,7 @@
                     else:
                         with_header = False
                     chunk_df = pd.DataFrame(chunk)
-                    chunk_df.to_csv(f, index=False, header=with_header)
+                    chunk_df.to_csv(f, index=False, header=with_header, escapechar='\\')
                     iter_num += 1
             else:
                 # csv or others
@@ -789,11 +870,28 @@
                              file_name: str,
                              dataset_name: str,
                              namespace: str,
-                             revision: Optional[str] = DEFAULT_DATASET_REVISION):
-        if file_name and os.path.splitext(file_name)[-1] in META_FILES_FORMAT:
-            file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
-                        f'Revision={revision}&FilePath={file_name}'
-        return file_name
+                             revision: Optional[str] = DEFAULT_DATASET_REVISION,
+                             extension_filter: Optional[bool] = True):
+
+        if not file_name or not dataset_name or not namespace:
+            raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')
+
+        # Note: make sure the FilePath is the last parameter in the url
+        params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name}
+        params: str = urlencode(params)
+        file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'
+        return file_url
+
+        # if extension_filter:
+        #     if os.path.splitext(file_name)[-1] in META_FILES_FORMAT:
+        #         file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?'\
+        #                    f'Revision={revision}&FilePath={file_name}'
+        #     else:
+        #         file_url = file_name
+        #     return file_url
+        # else:
+        #     return file_url

     def get_dataset_access_config(
             self,
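For reference, a small sketch of what the urlencode change buys: file paths containing spaces, '+' or other reserved characters are now escaped instead of being interpolated raw into the query string, and FilePath stays the last key (mirroring the note in the code above). The endpoint and path shown are illustrative.

    from urllib.parse import urlencode

    params = {'Source': 'SDK', 'Revision': 'master', 'FilePath': 'data/train 01+.csv'}
    query = urlencode(params)
    # query == 'Source=SDK&Revision=master&FilePath=data%2Ftrain+01%2B.csv'
    print(f'https://www.modelscope.cn/api/v1/datasets/ns/name/repo?{query}')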
@@ -931,7 +1029,7 @@
         datahub_raise_on_error(url, resp, r)
         return resp['Data']

-    def dataset_download_statistics(self, dataset_name: str, namespace: str, use_streaming: bool) -> None:
+    def dataset_download_statistics(self, dataset_name: str, namespace: str, use_streaming: bool = False) -> None:
         is_ci_test = os.getenv('CI_TEST') == 'True'
         if dataset_name and namespace and not is_ci_test and not use_streaming:
             try:
@@ -964,6 +1062,10 @@
         return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
                 **headers}

+    def get_file_base_path(self, namespace: str, dataset_name: str) -> str:
+        return f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?'
+        # return f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Revision={revision}&FilePath='
+

 class ModelScopeConfig:
     path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)


@@ -47,3 +47,9 @@ class ModelVisibility(object):
     PRIVATE = 1
     INTERNAL = 3
     PUBLIC = 5
+
+
+class DatasetVisibility(object):
+    PRIVATE = 1
+    INTERNAL = 3
+    PUBLIC = 5


@@ -92,6 +92,10 @@ class DataMetaManager(object):
         data_meta_config.meta_cache_dir = meta_cache_dir
         data_meta_config.dataset_scripts = dataset_scripts
         data_meta_config.dataset_formation = dataset_formation
+        if '.py' in dataset_scripts:
+            tmp_py_scripts = dataset_scripts['.py']
+            if len(tmp_py_scripts) > 0:
+                data_meta_config.dataset_py_script = tmp_py_scripts[0]

         # Set dataset_context_config
         self.dataset_context_config.data_meta_config = data_meta_config
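dataset_scripts here maps file extensions to locally cached meta files; the new branch simply promotes the first downloaded .py script to data_meta_config.dataset_py_script. A sketch of the shape this code expects (the paths are illustrative):

    # Illustrative only: extension -> list of cached local paths.
    dataset_scripts = {
        '.py': ['/cache/my_dataset/my_dataset.py'],
        '.json': ['/cache/my_dataset/dataset_infos.json'],
    }
    if '.py' in dataset_scripts and len(dataset_scripts['.py']) > 0:
        dataset_py_script = dataset_scripts['.py'][0]  # first script wins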


@@ -13,7 +13,6 @@ from datasets.utils.file_utils import is_relative_path
 from modelscope.hub.repository import DatasetRepository
 from modelscope.msdatasets.context.dataset_context_config import \
     DatasetContextConfig
-from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
 from modelscope.msdatasets.data_loader.data_loader_manager import (
     LocalDataLoaderManager, LocalDataLoaderType, RemoteDataLoaderManager,
     RemoteDataLoaderType)
@@ -22,14 +21,16 @@ from modelscope.msdatasets.dataset_cls import (ExternalDataset,
 from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
     build_custom_dataset
 from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
+from modelscope.msdatasets.utils.hf_datasets_util import \
+    load_dataset as hf_load_dataset_wrapper
 from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
 from modelscope.preprocessors import build_preprocessor
 from modelscope.utils.config import Config, ConfigDict
 from modelscope.utils.config_ds import MS_DATASETS_CACHE
 from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
                                        DEFAULT_DATASET_REVISION, ConfigFields,
-                                       DownloadMode, Hubs, ModeKeys, Tasks,
-                                       UploadMode, VirgoDatasetConfig)
+                                       DatasetFormations, DownloadMode, Hubs,
+                                       ModeKeys, Tasks, UploadMode)
 from modelscope.utils.import_utils import is_tf_available, is_torch_available
 from modelscope.utils.logger import get_logger
@@ -167,6 +168,7 @@ class MsDataset:
         stream_batch_size: Optional[int] = 1,
         custom_cfg: Optional[Config] = Config(),
         token: Optional[str] = None,
+        dataset_info_only: Optional[bool] = False,
         **config_kwargs,
     ) -> Union[dict, 'MsDataset', NativeIterableDataset]:
         """Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
@@ -196,6 +198,7 @@
             custom_cfg (str, Optional): Model configuration, this can be used for custom datasets.
                 see https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3
             token (str, Optional): SDK token of ModelScope.
+            dataset_info_only (bool, Optional): If set to True, only return the dataset config and info (dict).
             **config_kwargs (additional keyword arguments): Keyword arguments to be passed

         Returns:
@@ -279,19 +282,51 @@
                 return dataset_inst
         # Load from the modelscope hub
         elif hub == Hubs.modelscope:
-            remote_dataloader_manager = RemoteDataLoaderManager(
-                dataset_context_config)
-            dataset_inst = remote_dataloader_manager.load_dataset(
-                RemoteDataLoaderType.MS_DATA_LOADER)
-            dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
-            if isinstance(dataset_inst, MsDataset):
-                dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
-            if custom_cfg:
-                dataset_inst.to_custom_dataset(
-                    custom_cfg=custom_cfg, **config_kwargs)
-                dataset_inst.is_custom = True
-            return dataset_inst
+
+            # Get dataset type from ModelScope Hub; dataset_type->4: General Dataset
+            from modelscope.hub.api import HubApi
+            _api = HubApi()
+            dataset_id_on_hub, dataset_type = _api.get_dataset_id_and_type(
+                dataset_name=dataset_name, namespace=namespace)
+
+            logger.info(f'dataset_type: {dataset_type}')
+
+            # Load from the ModelScope Hub for type=4 (general)
+            if str(dataset_type) == str(DatasetFormations.general.value):
+                return hf_load_dataset_wrapper(
+                    path=namespace + '/' + dataset_name,
+                    name=subset_name,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    split=split,
+                    cache_dir=cache_dir,
+                    features=None,
+                    download_config=None,
+                    download_mode=download_mode.value,
+                    revision=version,
+                    token=token,
+                    streaming=use_streaming,
+                    dataset_info_only=dataset_info_only,
+                    **config_kwargs)
+            else:
+                remote_dataloader_manager = RemoteDataLoaderManager(
+                    dataset_context_config)
+                dataset_inst = remote_dataloader_manager.load_dataset(
+                    RemoteDataLoaderType.MS_DATA_LOADER)
+                dataset_inst = MsDataset.to_ms_dataset(
+                    dataset_inst, target=target)
+                if isinstance(dataset_inst, MsDataset):
+                    dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
+                if custom_cfg:
+                    dataset_inst.to_custom_dataset(
+                        custom_cfg=custom_cfg, **config_kwargs)
+                    dataset_inst.is_custom = True
+                return dataset_inst
+
         elif hub == Hubs.virgo:
+            from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
+            from modelscope.utils.constant import VirgoDatasetConfig
+
             # Rewrite the namespace, version and cache_dir for virgo dataset.
             if namespace == DEFAULT_DATASET_NAMESPACE:
                 dataset_context_config.namespace = VirgoDatasetConfig.default_virgo_namespace
@@ -323,6 +358,10 @@
              chunksize: Optional[int] = 1,
              filter_hidden_files: Optional[bool] = True,
              upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None:
+        r"""
+        @deprecated
+        This method is deprecated and may be removed in future releases, please use git command line instead.
+        """
         """Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first.

         Args:
@@ -346,6 +385,10 @@
             None
         """
+        warnings.warn(
+            'upload is deprecated, please use git command line to upload the dataset.',
+            DeprecationWarning)
+
         if not object_name:
             raise ValueError('object_name cannot be empty!')
@@ -393,6 +436,10 @@
             None
         """
+        warnings.warn(
+            'upload is deprecated, please use git command line to upload the dataset.',
+            DeprecationWarning)
+
         _repo = DatasetRepository(
             repo_work_dir=dataset_work_dir,
             dataset_id=dataset_id,
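Taken together, a hedged sketch of the two user-visible behaviours wired in above: previewing configs with dataset_info_only (no data download), and a regular load of a general-format repo routed through hf_load_dataset_wrapper. The dataset id is a placeholder.

    from modelscope import MsDataset

    # Preview only: returns the dataset config/info as a dict.
    infos = MsDataset.load('my_namespace/my_dataset', dataset_info_only=True)
    print(infos)

    # Regular load: general (Type=4) repos go through the HF-compatible wrapper.
    ds = MsDataset.load('my_namespace/my_dataset', split='train')
    print(next(iter(ds)))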


@@ -212,7 +212,10 @@
     csv_delimiter = context_config.config_kwargs.get('delimiter', ',')
     csv_df = pd.read_csv(
-        meta_csv_file_path, iterator=False, delimiter=csv_delimiter)
+        meta_csv_file_path,
+        iterator=False,
+        delimiter=csv_delimiter,
+        escapechar='\\')
     target_col = csv_df.columns[csv_df.columns.str.contains(
         ':FILE')].to_list()
     if len(target_col) == 0:
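The matching escapechar on the read side keeps meta CSVs with embedded delimiters parseable; a self-contained sketch (column names are illustrative, following the ':FILE' convention used above):

    import io
    import pandas as pd

    # A meta CSV where an embedded comma is backslash-escaped rather than quoted.
    csv_text = 'Image:FILE,caption\nimg/0001.png,hello\\, world\n'
    df = pd.read_csv(io.StringIO(csv_text), delimiter=',', escapechar='\\')
    print(df['caption'][0])   # -> 'hello, world' (the escaped comma is not a separator)
    print(len(df.columns))    # -> 2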

File diff suppressed because it is too large.


@@ -0,0 +1,237 @@
# noqa: isort:skip_file, yapf: disable
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
import json
import os
import re
import shutil
import warnings
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from urllib.parse import urljoin, urlparse
import requests
from datasets import config
from datasets.utils.file_utils import hash_url_to_filename, get_authentication_headers_for_url, ftp_head, fsspec_head, \
http_head, _raise_if_offline_mode_is_enabled, ftp_get, fsspec_get, http_get
from filelock import FileLock
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.logger import get_logger
from modelscope.hub.api import HubApi, ModelScopeConfig
logger = get_logger()
def get_from_cache_ms(
url,
cache_dir=None,
force_download=False,
proxies=None,
etag_timeout=100,
resume_download=False,
user_agent=None,
local_files_only=False,
use_etag=True,
max_retries=0,
token=None,
use_auth_token='deprecated',
ignore_url_params=False,
storage_options=None,
download_desc=None,
) -> str:
"""
Given a URL, look for the corresponding file in the local cache.
If it's not there, download it. Then return the path to the cached file.
Return:
Local path (string)
Raises:
FileNotFoundError: in case of non-recoverable file
(non-existent or no cache on disk)
ConnectionError: in case of unreachable url
and no cache on disk
"""
if use_auth_token != 'deprecated':
warnings.warn(
"'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
f"You can remove this warning by passing 'token={use_auth_token}' instead.",
FutureWarning,
)
token = use_auth_token
if cache_dir is None:
cache_dir = MS_DATASETS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
if ignore_url_params:
# strip all query parameters and #fragments from the URL
cached_url = urljoin(url, urlparse(url).path)
else:
cached_url = url # additional parameters may be added to the given URL
connected = False
response = None
cookies = None
etag = None
head_error = None
scheme = None
# Try a first time to find the file on the local file system without eTag (None)
# if we don't ask for 'force_download' then we spare a request
filename = hash_url_to_filename(cached_url, etag=None)
cache_path = os.path.join(cache_dir, filename)
if os.path.exists(cache_path) and not force_download and not use_etag:
return cache_path
# Prepare headers for authentication
headers = get_authentication_headers_for_url(url, token=token)
if user_agent is not None:
headers['user-agent'] = user_agent
# We don't have the file locally or we need an eTag
if not local_files_only:
scheme = urlparse(url).scheme
if scheme == 'ftp':
connected = ftp_head(url)
elif scheme not in ('http', 'https'):
response = fsspec_head(url, storage_options=storage_options)
# s3fs uses "ETag", gcsfs uses "etag"
etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
connected = True
try:
cookies = ModelScopeConfig.get_cookies()
response = http_head(
url,
allow_redirects=True,
proxies=proxies,
timeout=etag_timeout,
max_retries=max_retries,
headers=headers,
cookies=cookies,
)
if response.status_code == 200: # ok
etag = response.headers.get('ETag') if use_etag else None
for k, v in response.cookies.items():
# In some edge cases, we need to get a confirmation token
if k.startswith('download_warning') and 'drive.google.com' in url:
url += '&confirm=' + v
cookies = response.cookies
connected = True
# Fix Google Drive URL to avoid Virus scan warning
if 'drive.google.com' in url and 'confirm=' not in url:
url += '&confirm=t'
# In some edge cases, head request returns 400 but the connection is actually ok
elif (
(response.status_code == 400 and 'firebasestorage.googleapis.com' in url)
or (response.status_code == 405 and 'drive.google.com' in url)
or (
response.status_code == 403
and (
re.match(r'^https?://github.com/.*?/.*?/releases/download/.*?/.*?$', url)
or re.match(r'^https://.*?s3.*?amazonaws.com/.*?$', response.url)
)
)
or (response.status_code == 403 and 'ndownloader.figstatic.com' in url)
):
connected = True
logger.info(f"Couldn't get ETag version for url {url}")
elif response.status_code == 401 and config.HF_ENDPOINT in url and token is None:
raise ConnectionError(
f'Unauthorized for URL {url}. '
f'Please use the parameter `token=True` after logging in with `huggingface-cli login`'
)
except (OSError, requests.exceptions.Timeout) as e:
# not connected
head_error = e
pass
# connected == False = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
# try to get the last downloaded one
if not connected:
if os.path.exists(cache_path) and not force_download:
return cache_path
if local_files_only:
raise FileNotFoundError(
f'Cannot find the requested files in the cached path at {cache_path} and outgoing traffic has been'
" disabled. To enable file online look-ups, set 'local_files_only' to False."
)
elif response is not None and response.status_code == 404:
raise FileNotFoundError(f"Couldn't find file at {url}")
_raise_if_offline_mode_is_enabled(f'Tried to reach {url}')
if head_error is not None:
raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
elif response is not None:
raise ConnectionError(f"Couldn't reach {url} (error {response.status_code})")
else:
raise ConnectionError(f"Couldn't reach {url}")
# Try a second time
filename = hash_url_to_filename(cached_url, etag)
cache_path = os.path.join(cache_dir, filename)
if os.path.exists(cache_path) and not force_download:
return cache_path
# From now on, connected is True.
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + '.lock'
with FileLock(lock_path):
# Retry in case previously locked processes just enter after the precedent process releases the lock
if os.path.exists(cache_path) and not force_download:
return cache_path
incomplete_path = cache_path + '.incomplete'
@contextmanager
def temp_file_manager(mode='w+b'):
with open(incomplete_path, mode) as f:
yield f
resume_size = 0
if resume_download:
temp_file_manager = partial(temp_file_manager, mode='a+b')
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
# Download to temporary file, then copy to cache path once finished.
# Otherwise, you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info(f'Downloading to {temp_file.name}')
# GET file object
if scheme == 'ftp':
ftp_get(url, temp_file)
elif scheme not in ('http', 'https'):
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
else:
http_get(
url,
temp_file=temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
)
logger.info(f'storing {url} in cache at {cache_path}')
shutil.move(temp_file.name, cache_path)
umask = os.umask(0o666)
os.umask(umask)
os.chmod(cache_path, 0o666 & ~umask)
logger.info(f'creating metadata file for {cache_path}')
meta = {'url': url, 'etag': etag}
meta_path = cache_path + '.json'
with open(meta_path, 'w', encoding='utf-8') as meta_file:
json.dump(meta, meta_file)
return cache_path
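A minimal, clearly hypothetical sketch of calling this helper directly; in the SDK it is intended to serve the ModelScope download path (with the login cookies attached above) rather than end users, and the URL below is a placeholder.

    # Assumes this module is importable and the URL points at a real file.
    local_path = get_from_cache_ms(
        'https://www.modelscope.cn/api/v1/datasets/ns/name/repo?Revision=master&FilePath=data.csv',
        cache_dir=None,      # falls back to MS_DATASETS_CACHE
        use_etag=False,      # reuse a cached copy without the ETag round-trip
    )
    print(local_path)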


@@ -393,9 +393,14 @@ class DatasetFormations(enum.Enum):
     # formation that is compatible with official huggingface dataset, which
     # organizes whole dataset into one single (zip) file.
     hf_compatible = 1
     # native modelscope formation that supports, among other things,
     # multiple files in a dataset
     native = 2
+    # general formation for datasets
+    general = 4
+
     # for local meta cache mark
     formation_mark_ext = '.formation_mark'
@@ -403,6 +408,7 @@
 DatasetMetaFormats = {
     DatasetFormations.native: ['.json'],
     DatasetFormations.hf_compatible: ['.py'],
+    DatasetFormations.general: ['.py'],
 }
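A quick sketch of how the new value resolves through the mapping above (based only on the enum and dict shown here):

    from modelscope.utils.constant import DatasetFormations, DatasetMetaFormats

    formation = DatasetFormations(4)        # -> DatasetFormations.general
    print(DatasetMetaFormats[formation])    # -> ['.py']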


@@ -4,6 +4,7 @@ datasets>=2.14.5
 einops
 filelock>=3.3.0
 gast>=0.2.2
+huggingface_hub
 numpy
 oss2
 pandas


@@ -0,0 +1,103 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope import MsDataset
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
logger = get_logger()
# Note: MODELSCOPE_DOMAIN is set to 'test.modelscope.cn' in the environment variable
# TODO: ONLY FOR TEST ENVIRONMENT, to be replaced by the online domain
TEST_INNER_LEVEL = 1
class GeneralMsDatasetTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_return_dataset_info_only(self):
ds = MsDataset.load(
'wangxingjun778test/aya_dataset_mini', dataset_info_only=True)
print(f'>>output of test_return_dataset_info_only:\n {ds}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_fashion_mnist(self):
# inner means the dataset is on the test.modelscope.cn environment
ds = MsDataset.load(
'xxxxtest0004/ms_test_0308_py',
subset_name='fashion_mnist',
split='train')
print(f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_clue(self):
ds = MsDataset.load(
'wangxingjun778test/clue', subset_name='afqmc', split='train')
print(f'>>output of test_inner_clue:\n {next(iter(ds))}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_cats_and_dogs_mini(self):
ds = MsDataset.load(
'wangxingjun778test/cats_and_dogs_mini', split='train')
print(f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_aya_dataset_mini(self):
# Dataset Format:
# data/train-xxx-of-xxx.parquet; data/test-xxx-of-xxx.parquet
# demographics/train-xxx-of-xxx.parquet
ds = MsDataset.load(
'wangxingjun778test/aya_dataset_mini', split='train')
print(f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
ds = MsDataset.load(
'wangxingjun778test/aya_dataset_mini', subset_name='demographics')
assert next(iter(ds['train']))
print(
f">>output of test_inner_aya_dataset_mini:\n {next(iter(ds['train']))}"
)
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_no_standard_imgs(self):
infos = MsDataset.load(
'xxxxtest0004/png_jpg_txt_test', dataset_info_only=True)
assert infos['default']
ds = MsDataset.load('xxxxtest0004/png_jpg_txt_test', split='train')
print(f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
assert next(iter(ds))
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_hf_pictures(self):
ds = MsDataset.load('xxxxtest0004/hf_Pictures')
print(ds)
assert next(iter(ds))
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
def test_inner_speech_yinpin(self):
ds = MsDataset.load('xxxxtest0004/hf_lj_speech_yinpin_test')
print(ds)
assert next(iter(ds))
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_yuancheng_picture(self):
ds = MsDataset.load(
'xxxxtest0004/yuancheng_picture',
subset_name='remote_images',
split='train')
print(next(iter(ds)))
assert next(iter(ds))
if __name__ == '__main__':
unittest.main()