Fix dataset infos (#1414)
@@ -374,7 +374,7 @@ class HubApi:
                 by a `/`.
             repo_type (`str`, *optional*):
                 `None` or `"model"` if getting repository info from a model. Default is `None`.
-                TODO: support dataset and studio
+                TODO: support studio
             endpoint(`str`):
                 None or specific endpoint to use, when None, use the default endpoint
                 set in HubApi class (self.endpoint)
@@ -886,6 +886,9 @@ class HubApi:
        raise_on_error(d)

        files = []
+        if not d[API_RESPONSE_FIELD_DATA]['Files']:
+            logger.warning(f'No files found in model {model_id} at revision {revision}.')
+            return files
        for file in d[API_RESPONSE_FIELD_DATA]['Files']:
            if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes':
                continue
@@ -993,29 +996,6 @@ class HubApi:
        dataset_type = resp['Data']['Type']
        return dataset_id, dataset_type

-    def get_dataset_infos(self,
-                          dataset_hub_id: str,
-                          revision: str,
-                          files_metadata: bool = False,
-                          timeout: float = 100,
-                          recursive: str = 'True',
-                          endpoint: Optional[str] = None):
-        """
-        Get dataset infos.
-        """
-        if not endpoint:
-            endpoint = self.endpoint
-        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
-        params = {'Revision': revision, 'Root': None, 'Recursive': recursive}
-        cookies = ModelScopeConfig.get_cookies()
-        if files_metadata:
-            params['blobs'] = True
-        r = self.session.get(datahub_url, params=params, cookies=cookies, timeout=timeout)
-        resp = r.json()
-        datahub_raise_on_error(datahub_url, resp, r)
-
-        return resp
-
    def list_repo_tree(self,
                       dataset_name: str,
                       namespace: str,
@@ -1025,6 +1005,11 @@ class HubApi:
                       page_number: int = 1,
                       page_size: int = 100,
                       endpoint: Optional[str] = None):
+        """
+        @deprecated: Use `get_dataset_files` instead.
+        """
+        warnings.warn('The function `list_repo_tree` is deprecated, use `get_dataset_files` instead.',
+                      DeprecationWarning)

        dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
            dataset_name=dataset_name, namespace=namespace, endpoint=endpoint)
@@ -1044,6 +1029,59 @@ class HubApi:

        return resp

+    def get_dataset_files(self,
+                          repo_id: str,
+                          *,
+                          revision: str = DEFAULT_REPOSITORY_REVISION,
+                          root_path: str = '/',
+                          recursive: bool = True,
+                          page_number: int = 1,
+                          page_size: int = 100,
+                          endpoint: Optional[str] = None):
+        """
+        Get the dataset files.
+
+        Args:
+            repo_id (str): The repository id, in the format of `namespace/dataset_name`.
+            revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
+            root_path (str): The root path to list. Defaults to '/'.
+            recursive (bool): Whether to list recursively. Defaults to True.
+            page_number (int): The page number for pagination. Defaults to 1.
+            page_size (int): The number of items per page. Defaults to 100.
+            endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
+
+        Returns:
+            List: The response containing the dataset repository tree information.
+                e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...]
+        """
+        from datasets.utils.file_utils import is_relative_path
+
+        if is_relative_path(repo_id) and repo_id.count('/') == 1:
+            _owner, _dataset_name = repo_id.split('/')
+        else:
+            raise ValueError(f'Invalid repo_id: {repo_id} !')
+
+        dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
+            dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
+
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+        params = {
+            'Revision': revision,
+            'Root': root_path,
+            'Recursive': 'True' if recursive else 'False',
+            'PageNumber': page_number,
+            'PageSize': page_size
+        }
+        cookies = ModelScopeConfig.get_cookies()
+
+        r = self.session.get(datahub_url, params=params, cookies=cookies)
+        resp = r.json()
+        datahub_raise_on_error(datahub_url, resp, r)
+
+        return resp['Data']['Files']
+
    def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
                                   dataset_id: str, revision: str, endpoint: Optional[str] = None):
        """ Get the meta file-list of the dataset. """
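A minimal usage sketch of the `get_dataset_files` API added above; the repo id and revision are hypothetical placeholders, and the `modelscope.hub.api` import path is assumed:

    from modelscope.hub.api import HubApi

    api = HubApi()
    files = api.get_dataset_files(
        repo_id='namespace/dataset_name',  # hypothetical dataset repo
        revision='master',                 # assumed default branch name
        root_path='/',
        recursive=True,
        page_size=100,
    )
    # Each returned entry describes a 'tree' (directory) or 'blob' (file) node.
    file_paths = [f['Path'] for f in files if f['Type'] != 'tree']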
@@ -2150,22 +2188,40 @@ class HubApi:
                recursive=True,
                endpoint=endpoint
            )
-            file_list = [f['Path'] for f in files]
+            file_paths = [f['Path'] for f in files]
+        elif repo_type == REPO_TYPE_DATASET:
+            file_paths = []
+            page_number = 1
+            page_size = 100
+            while True:
+                try:
+                    dataset_files: List[Dict[str, Any]] = self.get_dataset_files(
+                        repo_id=repo_id,
+                        revision=revision or DEFAULT_DATASET_REVISION,
+                        recursive=True,
+                        page_number=page_number,
+                        page_size=page_size,
+                        endpoint=endpoint,
+                    )
+                except Exception as e:
+                    logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}')
+                    break
+
+                # Parse data (Type: 'tree' or 'blob')
+                for file_info_d in dataset_files:
+                    if file_info_d['Type'] != 'tree':
+                        file_paths.append(file_info_d['Path'])
+
+                if len(dataset_files) < page_size:
+                    break
+
+                page_number += 1
        else:
-            namespace, dataset_name = repo_id.split('/')
-            dataset_hub_id, _ = self.get_dataset_id_and_type(dataset_name, namespace, endpoint=endpoint)
-            dataset_info = self.get_dataset_infos(
-                dataset_hub_id,
-                revision or DEFAULT_DATASET_REVISION,
-                recursive='True',
-                endpoint=endpoint
-            )
-            files = dataset_info.get('Data', {}).get('Files', [])
-            file_list = [f['Path'] for f in files]
+            raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

        # Glob pattern matching
        to_delete = []
-        for path in file_list:
+        for path in file_paths:
            for delete_pattern in delete_patterns:
                if fnmatch.fnmatch(path, delete_pattern):
                    to_delete.append(path)
@@ -2181,12 +2237,15 @@ class HubApi:
                    'Revision': revision or DEFAULT_MODEL_REVISION,
                    'FilePath': path
                }
-            else:
+            elif repo_type == REPO_TYPE_DATASET:
                owner, dataset_name = repo_id.split('/')
                url = f'{endpoint}/api/v1/datasets/{owner}/{dataset_name}/repo'
                params = {
                    'FilePath': path
                }
+            else:
+                raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

            r = self.session.delete(url, params=params, cookies=cookies, headers=headers)
            raise_for_http_status(r)
            resp = r.json()
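A small standalone sketch of the glob-matching step used in the deletion logic above, with hypothetical file paths and delete patterns:

    import fnmatch

    # Hypothetical repository file paths and delete patterns.
    file_paths = ['data/train.csv', 'data/test.csv', 'README.md', 'images/cat.png']
    delete_patterns = ['*.csv', 'images/*']

    to_delete = [
        path for path in file_paths
        if any(fnmatch.fnmatch(path, pattern) for pattern in delete_patterns)
    ]
    print(to_delete)  # ['data/train.csv', 'data/test.csv', 'images/cat.png']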
@@ -234,25 +234,22 @@ def _repo_file_download(
        page_number = 1
        page_size = 100
        while True:
-            files_list_tree = _api.list_repo_tree(
-                dataset_name=name,
-                namespace=group_or_owner,
-                revision=revision,
-                root_path='/',
-                recursive=True,
-                page_number=page_number,
-                page_size=page_size,
-                endpoint=endpoint)
-            if not ('Code' in files_list_tree
-                    and files_list_tree['Code'] == 200):
-                print(
-                    'Get dataset: %s file list failed, request_id: %s, message: %s'
-                    % (repo_id, files_list_tree['RequestId'],
-                       files_list_tree['Message']))
-                return None
-            repo_files = files_list_tree['Data']['Files']
+            try:
+                dataset_files = _api.get_dataset_files(
+                    repo_id=repo_id,
+                    revision=revision,
+                    root_path='/',
+                    recursive=True,
+                    page_number=page_number,
+                    page_size=page_size,
+                    endpoint=endpoint)
+            except Exception as e:
+                logger.error(
+                    f'Get dataset: {repo_id} file list failed, error: {e}')
+                break

            is_exist = False
-            for repo_file in repo_files:
+            for repo_file in dataset_files:
                if repo_file['Type'] == 'tree':
                    continue
@@ -267,7 +264,7 @@ def _repo_file_download(
                    file_to_download_meta = repo_file
                    is_exist = True
                    break
-            if len(repo_files) < page_size or is_exist:
+            if len(dataset_files) < page_size or is_exist:
                break
            page_number += 1

@@ -381,8 +381,8 @@ def _snapshot_download(
        revision_detail = revision or DEFAULT_DATASET_REVISION

        logger.info('Fetching dataset repo file list...')
-        repo_files = fetch_repo_files(_api, name, group_or_owner,
-                                      revision_detail, endpoint)
+        repo_files = fetch_repo_files(_api, repo_id, revision_detail,
+                                      endpoint)

        if repo_files is None:
            logger.error(
@@ -415,32 +415,28 @@ def _snapshot_download(
    return cache_root_path


-def fetch_repo_files(_api, name, group_or_owner, revision, endpoint):
+def fetch_repo_files(_api, repo_id, revision, endpoint):
    page_number = 1
    page_size = 150
    repo_files = []

    while True:
-        files_list_tree = _api.list_repo_tree(
-            dataset_name=name,
-            namespace=group_or_owner,
-            revision=revision,
-            root_path='/',
-            recursive=True,
-            page_number=page_number,
-            page_size=page_size,
-            endpoint=endpoint)
+        try:
+            dataset_files = _api.get_dataset_files(
+                repo_id=repo_id,
+                revision=revision,
+                root_path='/',
+                recursive=True,
+                page_number=page_number,
+                page_size=page_size,
+                endpoint=endpoint)
+        except Exception as e:
+            logger.error(f'Error fetching dataset files: {e}')
+            break

-        if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
-            logger.error(f'Get dataset file list failed, request_id: \
-                {files_list_tree["RequestId"]}, message: {files_list_tree["Message"]}'
-                         )
-            return None
+        repo_files.extend(dataset_files)

-        cur_repo_files = files_list_tree['Data']['Files']
-        repo_files.extend(cur_repo_files)
-
-        if len(cur_repo_files) < page_size:
+        if len(dataset_files) < page_size:
            break

        page_number += 1

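The paging loop in `fetch_repo_files` follows a fetch-until-short-page pattern; a generic sketch of that pattern, with `fetch_page` as a hypothetical callable that returns one page of results, might look like this:

    from typing import Any, Callable, Dict, List


    def collect_all_pages(fetch_page: Callable[[int, int], List[Dict[str, Any]]],
                          page_size: int = 150) -> List[Dict[str, Any]]:
        """Collect results page by page until a short page signals the end."""
        results: List[Dict[str, Any]] = []
        page_number = 1
        while True:
            page = fetch_page(page_number, page_size)
            results.extend(page)
            # A page shorter than `page_size` means there is nothing left to fetch.
            if len(page) < page_size:
                break
            page_number += 1
        return results

    # e.g. collect_all_pages(lambda n, s: api.get_dataset_files(
    #          repo_id='namespace/dataset_name', page_number=n, page_size=s))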
@@ -158,56 +158,46 @@ def _dataset_info(

    </Tip>
    """
    _api = HubApi()
-    _namespace, _dataset_name = repo_id.split('/')
    endpoint = _api.get_endpoint_for_read(
        repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
-    dataset_hub_id, dataset_type = _api.get_dataset_id_and_type(
-        dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint)
+    # Note: refer to `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
+    repo_info_iter = self.list_repo_tree(
+        repo_id=repo_id,
+        path_in_repo='/',
+        revision=revision,
+        recursive=False,
+        expand=expand,
+        token=token,
+        repo_type=REPO_TYPE_DATASET,
+    )

    revision: str = revision or DEFAULT_DATASET_REVISION
-    data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id,
-                                  revision=revision,
-                                  files_metadata=files_metadata,
-                                  timeout=timeout,
-                                  endpoint=endpoint)
-
-    # Parse data
-    data_d: dict = data['Data']
-    data_file_list: list = data_d['Files']
-    # commit_info: dict = data_d['LatestCommitter']
-
-    # Update data # TODO: columns align with HfDatasetInfo
-    data['id'] = repo_id
-    data['private'] = False
-    data['author'] = repo_id.split('/')[0] if repo_id else None
-    data['sha'] = revision
-    data['lastModified'] = None
-    data['gated'] = False
-    data['disabled'] = False
-    data['downloads'] = 0
-    data['likes'] = 0
-    data['tags'] = []
-    data['cardData'] = []
-    data['createdAt'] = None
+    # Update data_info
+    data_info = dict({})
+    data_info['id'] = repo_id
+    data_info['private'] = False
+    data_info['author'] = repo_id.split('/')[0] if repo_id else None
+    data_info['sha'] = revision
+    data_info['lastModified'] = None
+    data_info['gated'] = False
+    data_info['disabled'] = False
+    data_info['downloads'] = 0
+    data_info['likes'] = 0
+    data_info['tags'] = []
+    data_info['cardData'] = []
+    data_info['createdAt'] = None

    # e.g. {'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}}
-    data['siblings'] = []
-    for file_info_d in data_file_list:
-        file_info = {
-            'rfilename': file_info_d['Path'],
-            'blobId': file_info_d['Id'],
-            'size': file_info_d['Size'],
-            'type': 'directory' if file_info_d['Type'] == 'tree' else 'file',
-            'lfs': {
-                'size': file_info_d['Size'],
-                'sha256': file_info_d['Sha256'],
-                'pointerSize': 0
-            }
-        }
-        data['siblings'].append(file_info)
+    data_siblings = []
+    for info_item in repo_info_iter:
+        if isinstance(info_item, RepoFile):
+            data_siblings.append(
+                dict(
+                    rfilename=info_item.rfilename,
+                    blobId=info_item.blob_id,
+                    size=info_item.size,
+                )
+            )
+    data_info['siblings'] = data_siblings

-    return HfDatasetInfo(**data)
+    return HfDatasetInfo(**data_info)


def _list_repo_tree(
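A hedged sketch of reading the resulting dataset info, assuming `api` is an `HfApi` instance that already carries the ModelScope patch for `_dataset_info` above, and using a hypothetical repo id:

    from huggingface_hub import HfApi

    api = HfApi()  # assumed to already carry the ModelScope patch shown above
    info = api.dataset_info('namespace/dataset_name')  # hypothetical dataset repo
    # `siblings` is assembled from the RepoFile entries collected above.
    for sibling in info.siblings:
        print(sibling.rfilename)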
@@ -225,35 +215,26 @@ def _list_repo_tree(
    _api = HubApi(timeout=3 * 60, max_retries=3)
    endpoint = _api.get_endpoint_for_read(
        repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
-    if is_relative_path(repo_id) and repo_id.count('/') == 1:
-        _namespace, _dataset_name = repo_id.split('/')
-    elif is_relative_path(repo_id) and repo_id.count('/') == 0:
-        logger.warning(f'Got a relative path: {repo_id} without namespace, '
-                       f'Use default namespace: {DEFAULT_DATASET_NAMESPACE}')
-        _namespace, _dataset_name = DEFAULT_DATASET_NAMESPACE, repo_id
-    else:
-        raise ValueError(f'Invalid repo_id: {repo_id} !')

    # List all files in the repo
    page_number = 1
    page_size = 100
    while True:
-        data: dict = _api.list_repo_tree(dataset_name=_dataset_name,
-                                         namespace=_namespace,
-                                         revision=revision or DEFAULT_DATASET_REVISION,
-                                         root_path=path_in_repo or None,
-                                         recursive=True,
-                                         page_number=page_number,
-                                         page_size=page_size,
-                                         endpoint=endpoint
-                                         )
-        if not ('Code' in data and data['Code'] == 200):
-            logger.error(f'Get dataset: {repo_id} file list failed, message: {data["Message"]}')
-            return None
+        try:
+            dataset_files = _api.get_dataset_files(
+                repo_id=repo_id,
+                revision=revision or DEFAULT_DATASET_REVISION,
+                root_path=path_in_repo or '/',
+                recursive=recursive,
+                page_number=page_number,
+                page_size=page_size,
+                endpoint=endpoint,
+            )
+        except Exception as e:
+            logger.error(f'Get dataset: {repo_id} file list failed, message: {e}')
+            break

        # Parse data (Type: 'tree' or 'blob')
-        data_file_list: list = data['Data']['Files']
-
-        for file_info_d in data_file_list:
+        for file_info_d in dataset_files:
            path_info = {}
            path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file'
            path_info['path'] = file_info_d['Path']
@@ -262,7 +243,7 @@ def _list_repo_tree(

            yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)

-        if len(data_file_list) < page_size:
+        if len(dataset_files) < page_size:
            break
        page_number += 1

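A minimal sketch of consuming the generator above through `HfApi.list_repo_tree`, assuming the ModelScope patch has been applied to the `HfApi` instance and using a hypothetical repo id:

    from huggingface_hub import HfApi
    from huggingface_hub.hf_api import RepoFile, RepoFolder

    api = HfApi()  # assumed to already carry the ModelScope patch shown above
    for entry in api.list_repo_tree('namespace/dataset_name',  # hypothetical dataset repo
                                    repo_type='dataset', recursive=True):
        if isinstance(entry, RepoFile):
            print('file:', entry.path, entry.size)
        elif isinstance(entry, RepoFolder):
            print('dir: ', entry.path)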
@@ -278,30 +259,17 @@ def _get_paths_info(
        token: Optional[Union[bool, str]] = None,
) -> List[Union[RepoFile, RepoFolder]]:

    _api = HubApi()
-    _namespace, _dataset_name = repo_id.split('/')
-    endpoint = _api.get_endpoint_for_read(
-        repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
-    dataset_hub_id, dataset_type = _api.get_dataset_id_and_type(
-        dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint)
+    # Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
+    repo_info_iter = self.list_repo_tree(
+        repo_id=repo_id,
+        recursive=False,
+        expand=expand,
+        revision=revision,
+        repo_type=repo_type,
+        token=token,
+    )

-    revision: str = revision or DEFAULT_DATASET_REVISION
-    data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id,
-                                  revision=revision,
-                                  files_metadata=False,
-                                  recursive='False')
-    data_d: dict = data['Data']
-    data_file_list: list = data_d['Files']
-
-    return [
-        RepoFile(path=item_d['Name'],
-                 size=item_d['Size'],
-                 oid=item_d['Revision'],
-                 lfs=None,  # TODO: lfs type to be supported
-                 last_commit=None,  # TODO: lfs type to be supported
-                 security=None
-                 ) for item_d in data_file_list if item_d['Name'] == 'README.md'
-    ]
+    return [item_info for item_info in repo_info_iter]


def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str):