From 62078f1796438e65ce31271e7dfae6fd32b6a741 Mon Sep 17 00:00:00 2001 From: "Xingjun.Wang" Date: Thu, 17 Jul 2025 19:27:39 +0800 Subject: [PATCH] Fix dataset infos (#1414) --- modelscope/hub/api.py | 133 ++++++++++----- modelscope/hub/file_download.py | 35 ++-- modelscope/hub/snapshot_download.py | 38 ++--- .../msdatasets/utils/hf_datasets_util.py | 156 +++++++----------- 4 files changed, 191 insertions(+), 171 deletions(-) diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index ff423838..e4de81b0 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -374,7 +374,7 @@ class HubApi: by a `/`. repo_type (`str`, *optional*): `None` or `"model"` if getting repository info from a model. Default is `None`. - TODO: support dataset and studio + TODO: support studio endpoint(`str`): None or specific endpoint to use, when None, use the default endpoint set in HubApi class (self.endpoint) @@ -886,6 +886,9 @@ class HubApi: raise_on_error(d) files = [] + if not d[API_RESPONSE_FIELD_DATA]['Files']: + logger.warning(f'No files found in model {model_id} at revision {revision}.') + return files for file in d[API_RESPONSE_FIELD_DATA]['Files']: if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes': continue @@ -993,29 +996,6 @@ class HubApi: dataset_type = resp['Data']['Type'] return dataset_id, dataset_type - def get_dataset_infos(self, - dataset_hub_id: str, - revision: str, - files_metadata: bool = False, - timeout: float = 100, - recursive: str = 'True', - endpoint: Optional[str] = None): - """ - Get dataset infos. - """ - if not endpoint: - endpoint = self.endpoint - datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree' - params = {'Revision': revision, 'Root': None, 'Recursive': recursive} - cookies = ModelScopeConfig.get_cookies() - if files_metadata: - params['blobs'] = True - r = self.session.get(datahub_url, params=params, cookies=cookies, timeout=timeout) - resp = r.json() - datahub_raise_on_error(datahub_url, resp, r) - - return resp - def list_repo_tree(self, dataset_name: str, namespace: str, @@ -1025,6 +1005,11 @@ class HubApi: page_number: int = 1, page_size: int = 100, endpoint: Optional[str] = None): + """ + @deprecated: Use `get_dataset_files` instead. + """ + warnings.warn('The function `list_repo_tree` is deprecated, use `get_dataset_files` instead.', + DeprecationWarning) dataset_hub_id, dataset_type = self.get_dataset_id_and_type( dataset_name=dataset_name, namespace=namespace, endpoint=endpoint) @@ -1044,6 +1029,59 @@ class HubApi: return resp + def get_dataset_files(self, + repo_id: str, + *, + revision: str = DEFAULT_REPOSITORY_REVISION, + root_path: str = '/', + recursive: bool = True, + page_number: int = 1, + page_size: int = 100, + endpoint: Optional[str] = None): + """ + Get the dataset files. + + Args: + repo_id (str): The repository id, in the format of `namespace/dataset_name`. + revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`. + root_path (str): The root path to list. Defaults to '/'. + recursive (bool): Whether to list recursively. Defaults to True. + page_number (int): The page number for pagination. Defaults to 1. + page_size (int): The number of items per page. Defaults to 100. + endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. + + Returns: + List: The response containing the dataset repository tree information. + e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...] 
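+
+        Example:
+            >>> # Illustrative sketch; 'namespace/dataset_name' is a placeholder repo_id
+            >>> # and must point to an existing dataset repository on the hub.
+            >>> api = HubApi()
+            >>> files = api.get_dataset_files(repo_id='namespace/dataset_name')
+            >>> file_paths = [f['Path'] for f in files if f['Type'] != 'tree']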
+ """ + from datasets.utils.file_utils import is_relative_path + + if is_relative_path(repo_id) and repo_id.count('/') == 1: + _owner, _dataset_name = repo_id.split('/') + else: + raise ValueError(f'Invalid repo_id: {repo_id} !') + + dataset_hub_id, dataset_type = self.get_dataset_id_and_type( + dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint) + + if not endpoint: + endpoint = self.endpoint + datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree' + params = { + 'Revision': revision, + 'Root': root_path, + 'Recursive': 'True' if recursive else 'False', + 'PageNumber': page_number, + 'PageSize': page_size + } + cookies = ModelScopeConfig.get_cookies() + + r = self.session.get(datahub_url, params=params, cookies=cookies) + resp = r.json() + datahub_raise_on_error(datahub_url, resp, r) + + return resp['Data']['Files'] + def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, dataset_id: str, revision: str, endpoint: Optional[str] = None): """ Get the meta file-list of the dataset. """ @@ -2150,22 +2188,40 @@ class HubApi: recursive=True, endpoint=endpoint ) - file_list = [f['Path'] for f in files] + file_paths = [f['Path'] for f in files] + elif repo_type == REPO_TYPE_DATASET: + file_paths = [] + page_number = 1 + page_size = 100 + while True: + try: + dataset_files: List[Dict[str, Any]] = self.get_dataset_files( + repo_id=repo_id, + revision=revision or DEFAULT_DATASET_REVISION, + recursive=True, + page_number=page_number, + page_size=page_size, + endpoint=endpoint, + ) + except Exception as e: + logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}') + break + + # Parse data (Type: 'tree' or 'blob') + for file_info_d in dataset_files: + if file_info_d['Type'] != 'tree': + file_paths.append(file_info_d['Path']) + + if len(dataset_files) < page_size: + break + + page_number += 1 else: - namespace, dataset_name = repo_id.split('/') - dataset_hub_id, _ = self.get_dataset_id_and_type(dataset_name, namespace, endpoint=endpoint) - dataset_info = self.get_dataset_infos( - dataset_hub_id, - revision or DEFAULT_DATASET_REVISION, - recursive='True', - endpoint=endpoint - ) - files = dataset_info.get('Data', {}).get('Files', []) - file_list = [f['Path'] for f in files] + raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}') # Glob pattern matching to_delete = [] - for path in file_list: + for path in file_paths: for delete_pattern in delete_patterns: if fnmatch.fnmatch(path, delete_pattern): to_delete.append(path) @@ -2181,12 +2237,15 @@ class HubApi: 'Revision': revision or DEFAULT_MODEL_REVISION, 'FilePath': path } - else: + elif repo_type == REPO_TYPE_DATASET: owner, dataset_name = repo_id.split('/') url = f'{endpoint}/api/v1/datasets/{owner}/{dataset_name}/repo' params = { 'FilePath': path } + else: + raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}') + r = self.session.delete(url, params=params, cookies=cookies, headers=headers) raise_for_http_status(r) resp = r.json() diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 872c7f4c..eeb0d414 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -234,25 +234,22 @@ def _repo_file_download( page_number = 1 page_size = 100 while True: - files_list_tree = _api.list_repo_tree( - dataset_name=name, - namespace=group_or_owner, - revision=revision, - root_path='/', - recursive=True, - page_number=page_number, - page_size=page_size, - 
endpoint=endpoint) - if not ('Code' in files_list_tree - and files_list_tree['Code'] == 200): - print( - 'Get dataset: %s file list failed, request_id: %s, message: %s' - % (repo_id, files_list_tree['RequestId'], - files_list_tree['Message'])) - return None - repo_files = files_list_tree['Data']['Files'] + try: + dataset_files = _api.get_dataset_files( + repo_id=repo_id, + revision=revision, + root_path='/', + recursive=True, + page_number=page_number, + page_size=page_size, + endpoint=endpoint) + except Exception as e: + logger.error( + f'Get dataset: {repo_id} file list failed, error: {e}') + break + is_exist = False - for repo_file in repo_files: + for repo_file in dataset_files: if repo_file['Type'] == 'tree': continue @@ -267,7 +264,7 @@ def _repo_file_download( file_to_download_meta = repo_file is_exist = True break - if len(repo_files) < page_size or is_exist: + if len(dataset_files) < page_size or is_exist: break page_number += 1 diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index a2e91d7e..f30c9312 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -381,8 +381,8 @@ def _snapshot_download( revision_detail = revision or DEFAULT_DATASET_REVISION logger.info('Fetching dataset repo file list...') - repo_files = fetch_repo_files(_api, name, group_or_owner, - revision_detail, endpoint) + repo_files = fetch_repo_files(_api, repo_id, revision_detail, + endpoint) if repo_files is None: logger.error( @@ -415,32 +415,28 @@ def _snapshot_download( return cache_root_path -def fetch_repo_files(_api, name, group_or_owner, revision, endpoint): +def fetch_repo_files(_api, repo_id, revision, endpoint): page_number = 1 page_size = 150 repo_files = [] while True: - files_list_tree = _api.list_repo_tree( - dataset_name=name, - namespace=group_or_owner, - revision=revision, - root_path='/', - recursive=True, - page_number=page_number, - page_size=page_size, - endpoint=endpoint) + try: + dataset_files = _api.get_dataset_files( + repo_id=repo_id, + revision=revision, + root_path='/', + recursive=True, + page_number=page_number, + page_size=page_size, + endpoint=endpoint) + except Exception as e: + logger.error(f'Error fetching dataset files: {e}') + break - if not ('Code' in files_list_tree and files_list_tree['Code'] == 200): - logger.error(f'Get dataset file list failed, request_id: \ - {files_list_tree["RequestId"]}, message: {files_list_tree["Message"]}' - ) - return None + repo_files.extend(dataset_files) - cur_repo_files = files_list_tree['Data']['Files'] - repo_files.extend(cur_repo_files) - - if len(cur_repo_files) < page_size: + if len(dataset_files) < page_size: break page_number += 1 diff --git a/modelscope/msdatasets/utils/hf_datasets_util.py b/modelscope/msdatasets/utils/hf_datasets_util.py index a76bfbce..224964f4 100644 --- a/modelscope/msdatasets/utils/hf_datasets_util.py +++ b/modelscope/msdatasets/utils/hf_datasets_util.py @@ -158,56 +158,46 @@ def _dataset_info( """ - _api = HubApi() - _namespace, _dataset_name = repo_id.split('/') - endpoint = _api.get_endpoint_for_read( - repo_id=repo_id, repo_type=REPO_TYPE_DATASET) - dataset_hub_id, dataset_type = _api.get_dataset_id_and_type( - dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint) + # Note: refer to `_list_repo_tree()`, for patching `HfApi.list_repo_tree` + repo_info_iter = self.list_repo_tree( + repo_id=repo_id, + path_in_repo='/', + revision=revision, + recursive=False, + expand=expand, + token=token, + repo_type=REPO_TYPE_DATASET, 
+ ) - revision: str = revision or DEFAULT_DATASET_REVISION - data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id, - revision=revision, - files_metadata=files_metadata, - timeout=timeout, - endpoint=endpoint) - - # Parse data - data_d: dict = data['Data'] - data_file_list: list = data_d['Files'] - # commit_info: dict = data_d['LatestCommitter'] - - # Update data # TODO: columns align with HfDatasetInfo - data['id'] = repo_id - data['private'] = False - data['author'] = repo_id.split('/')[0] if repo_id else None - data['sha'] = revision - data['lastModified'] = None - data['gated'] = False - data['disabled'] = False - data['downloads'] = 0 - data['likes'] = 0 - data['tags'] = [] - data['cardData'] = [] - data['createdAt'] = None + # Update data_info + data_info = dict({}) + data_info['id'] = repo_id + data_info['private'] = False + data_info['author'] = repo_id.split('/')[0] if repo_id else None + data_info['sha'] = revision + data_info['lastModified'] = None + data_info['gated'] = False + data_info['disabled'] = False + data_info['downloads'] = 0 + data_info['likes'] = 0 + data_info['tags'] = [] + data_info['cardData'] = [] + data_info['createdAt'] = None # e.g. {'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}} - data['siblings'] = [] - for file_info_d in data_file_list: - file_info = { - 'rfilename': file_info_d['Path'], - 'blobId': file_info_d['Id'], - 'size': file_info_d['Size'], - 'type': 'directory' if file_info_d['Type'] == 'tree' else 'file', - 'lfs': { - 'size': file_info_d['Size'], - 'sha256': file_info_d['Sha256'], - 'pointerSize': 0 - } - } - data['siblings'].append(file_info) + data_siblings = [] + for info_item in repo_info_iter: + if isinstance(info_item, RepoFile): + data_siblings.append( + dict( + rfilename=info_item.rfilename, + blobId=info_item.blob_id, + size=info_item.size, + ) + ) + data_info['siblings'] = data_siblings - return HfDatasetInfo(**data) + return HfDatasetInfo(**data_info) def _list_repo_tree( @@ -225,35 +215,26 @@ def _list_repo_tree( _api = HubApi(timeout=3 * 60, max_retries=3) endpoint = _api.get_endpoint_for_read( repo_id=repo_id, repo_type=REPO_TYPE_DATASET) - if is_relative_path(repo_id) and repo_id.count('/') == 1: - _namespace, _dataset_name = repo_id.split('/') - elif is_relative_path(repo_id) and repo_id.count('/') == 0: - logger.warning(f'Got a relative path: {repo_id} without namespace, ' - f'Use default namespace: {DEFAULT_DATASET_NAMESPACE}') - _namespace, _dataset_name = DEFAULT_DATASET_NAMESPACE, repo_id - else: - raise ValueError(f'Invalid repo_id: {repo_id} !') + # List all files in the repo page_number = 1 page_size = 100 while True: - data: dict = _api.list_repo_tree(dataset_name=_dataset_name, - namespace=_namespace, - revision=revision or DEFAULT_DATASET_REVISION, - root_path=path_in_repo or None, - recursive=True, - page_number=page_number, - page_size=page_size, - endpoint=endpoint - ) - if not ('Code' in data and data['Code'] == 200): - logger.error(f'Get dataset: {repo_id} file list failed, message: {data["Message"]}') - return None + try: + dataset_files = _api.get_dataset_files( + repo_id=repo_id, + revision=revision or DEFAULT_DATASET_REVISION, + root_path=path_in_repo or '/', + recursive=recursive, + page_number=page_number, + page_size=page_size, + endpoint=endpoint, + ) + except Exception as e: + logger.error(f'Get dataset: {repo_id} file list failed, message: {e}') + break - # Parse data (Type: 'tree' or 'blob') - data_file_list: list = data['Data']['Files'] - - 
for file_info_d in data_file_list: + for file_info_d in dataset_files: path_info = {} path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file' path_info['path'] = file_info_d['Path'] @@ -262,7 +243,7 @@ def _list_repo_tree( yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info) - if len(data_file_list) < page_size: + if len(dataset_files) < page_size: break page_number += 1 @@ -278,30 +259,17 @@ def _get_paths_info( token: Optional[Union[bool, str]] = None, ) -> List[Union[RepoFile, RepoFolder]]: - _api = HubApi() - _namespace, _dataset_name = repo_id.split('/') - endpoint = _api.get_endpoint_for_read( - repo_id=repo_id, repo_type=REPO_TYPE_DATASET) - dataset_hub_id, dataset_type = _api.get_dataset_id_and_type( - dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint) + # Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree` + repo_info_iter = self.list_repo_tree( + repo_id=repo_id, + recursive=False, + expand=expand, + revision=revision, + repo_type=repo_type, + token=token, + ) - revision: str = revision or DEFAULT_DATASET_REVISION - data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id, - revision=revision, - files_metadata=False, - recursive='False') - data_d: dict = data['Data'] - data_file_list: list = data_d['Files'] - - return [ - RepoFile(path=item_d['Name'], - size=item_d['Size'], - oid=item_d['Revision'], - lfs=None, # TODO: lfs type to be supported - last_commit=None, # TODO: lfs type to be supported - security=None - ) for item_d in data_file_list if item_d['Name'] == 'README.md' - ] + return [item_info for item_info in repo_info_iter] def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str):
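
Usage note: `HubApi.get_dataset_files` is paginated, so callers that need the complete
file list keep requesting pages until a short page comes back, as the updated call sites
above do. A minimal sketch of that pattern, assuming an existing dataset repository;
'namespace/dataset_name' is a placeholder repo_id:

    from modelscope.hub.api import HubApi

    api = HubApi()
    file_paths = []
    page_number, page_size = 1, 100
    while True:
        files = api.get_dataset_files(
            repo_id='namespace/dataset_name',
            recursive=True,
            page_number=page_number,
            page_size=page_size)
        # Entries of Type 'tree' are directories; everything else is a file blob.
        file_paths.extend(f['Path'] for f in files if f['Type'] != 'tree')
        if len(files) < page_size:
            break
        page_number += 1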