mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Merge commit '2fc6d8e82229e6bd4af9cefc362a0cad95c43227' into release/1.29
This commit is contained in:
@@ -63,6 +63,7 @@ from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
handle_http_response, is_ok,
|
||||
raise_for_http_status, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.info import DatasetInfo, ModelInfo
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.utils.aigc import AigcModel
|
||||
from modelscope.hub.utils.utils import (add_content_to_file, get_domain,
|
||||
@@ -83,7 +84,8 @@ from modelscope.utils.file_utils import get_file_hash, get_file_size
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX,
|
||||
DEFAULT_IGNORE_PATTERNS,
|
||||
MODEL_LFS_SUFFIX, CommitInfo,
|
||||
MODEL_LFS_SUFFIX,
|
||||
CommitHistoryResponse, CommitInfo,
|
||||
CommitOperation, CommitOperationAdd,
|
||||
RepoUtils)
|
||||
from modelscope.utils.thread_utils import thread_executor
|
||||
@@ -434,6 +436,101 @@ class HubApi:
|
||||
else:
|
||||
return prefer_endpoint
|
||||
|
||||
def model_info(self,
               repo_id: str,
               *,
               revision: Optional[str] = DEFAULT_MODEL_REVISION,
               endpoint: Optional[str] = None) -> ModelInfo:
    """Fetch detailed model information, including commit history.

    Args:
        repo_id (str): The model id in the format of
            ``namespace/model_name``.
        revision (str, optional): Specific revision of the model.
            Defaults to ``DEFAULT_MODEL_REVISION``.
        endpoint (str, optional): Hub endpoint to use. When ``None``,
            use the endpoint specified when initializing :class:`HubApi`.

    Returns:
        ModelInfo: The model detailed information returned by
        ModelScope Hub with commit history.
    """
    namespace, _ = model_id_to_group_owner_name(repo_id)
    raw_data = self.get_model(
        model_id=repo_id, revision=revision, endpoint=endpoint)
    history = self.list_repo_commits(
        repo_id=repo_id, repo_type=REPO_TYPE_MODEL, revision=revision, endpoint=endpoint)

    # Fold the raw payload, the commit history and the derived author
    # into a single ModelInfo instance.
    return ModelInfo(**raw_data, commits=history, author=namespace)
||||
def dataset_info(self,
                 repo_id: str,
                 *,
                 revision: Optional[str] = None,
                 endpoint: Optional[str] = None) -> DatasetInfo:
    """Fetch detailed dataset information, including commit history.

    Args:
        repo_id (str): The dataset id in the format of
            ``namespace/dataset_name``.
        revision (str, optional): Specific revision of the dataset.
            Defaults to ``None``.
        endpoint (str, optional): Hub endpoint to use. When ``None``,
            use the endpoint specified when initializing :class:`HubApi`.

    Returns:
        DatasetInfo: The dataset detailed information returned by
        ModelScope Hub with commit history.
    """
    namespace, _ = model_id_to_group_owner_name(repo_id)
    raw_data = self.get_dataset(
        dataset_id=repo_id, revision=revision, endpoint=endpoint)
    history = self.list_repo_commits(
        repo_id=repo_id, repo_type=REPO_TYPE_DATASET, revision=revision, endpoint=endpoint)

    # Fold the raw payload, the commit history and the derived author
    # into a single DatasetInfo instance.
    return DatasetInfo(**raw_data, commits=history, author=namespace)
||||
def repo_info(
    self,
    repo_id: str,
    *,
    repo_type: Optional[str] = REPO_TYPE_MODEL,
    revision: Optional[str] = DEFAULT_MODEL_REVISION,
    endpoint: Optional[str] = None
) -> Union[ModelInfo, DatasetInfo]:
    """Get repository information for models or datasets.

    Args:
        repo_id (str): The repository id in the format of
            ``namespace/repo_name``.
        repo_type (str, optional): Type of the repository. Supported
            values are ``"model"`` and ``"dataset"``. If not provided,
            ``"model"`` is assumed.
        revision (str, optional): Specific revision of the repository.
            Currently only effective for model repositories. Defaults to
            ``DEFAULT_MODEL_REVISION``.
        endpoint (str, optional): Hub endpoint to use. When ``None``,
            use the endpoint specified when initializing :class:`HubApi`.

    Returns:
        Union[ModelInfo, DatasetInfo]: The repository detailed information
        returned by ModelScope Hub.

    Raises:
        InvalidParameter: If ``repo_type`` is not a supported repo type.
    """
    # A missing repo_type is treated as a model lookup.
    effective_type = REPO_TYPE_MODEL if repo_type is None else repo_type

    if effective_type == REPO_TYPE_MODEL:
        return self.model_info(repo_id=repo_id, revision=revision, endpoint=endpoint)
    if effective_type == REPO_TYPE_DATASET:
        return self.dataset_info(repo_id=repo_id, revision=revision, endpoint=endpoint)

    raise InvalidParameter(
        f'Arg repo_type {repo_type} not supported. Please choose from {REPO_TYPE_SUPPORT}.')
||||
def repo_exists(
|
||||
self,
|
||||
repo_id: str,
|
||||
@@ -1111,6 +1208,68 @@ class HubApi:
|
||||
|
||||
return resp
|
||||
|
||||
def list_repo_commits(self,
                      repo_id: str,
                      *,
                      repo_type: Optional[str] = REPO_TYPE_MODEL,
                      revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                      page_number: int = 1,
                      page_size: int = 50,
                      endpoint: Optional[str] = None):
    """
    Get the commit history for a repository.

    Args:
        repo_id (str): The repository id, in the format of `namespace/repo_name`.
        repo_type (Optional[str]): The type of the repository. Supported types are `model` and `dataset`.
        revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
        page_number (int): The page number for pagination. Defaults to 1.
        page_size (int): The number of commits per page. Defaults to 50.
        endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

    Returns:
        CommitHistoryResponse: The commit history response.

    Raises:
        ValueError: If `repo_id` is not of the form `namespace/repo_name`.
        Exception: If the request fails or the hub returns an unexpected code.

    Examples:
        >>> from modelscope.hub.api import HubApi
        >>> api = HubApi()
        >>> commit_history = api.list_repo_commits('meituan/Meeseeks')
        >>> print(f"Total commits: {commit_history.total_count}")
        >>> for commit in commit_history.commits:
        ...     print(f"{commit.short_id}: {commit.title}")
    """
    # Validate `namespace/repo_name` with a plain string check instead of
    # importing `datasets` (a heavy optional dependency) just for this.
    parts = repo_id.split('/') if repo_id else []
    if (len(parts) != 2 or not all(parts) or '://' in repo_id
            or repo_id.startswith(('/', '~'))):
        raise ValueError(f'Invalid repo_id: {repo_id} !')

    if not endpoint:
        endpoint = self.endpoint

    commits_url = f'{endpoint}/api/v1/{repo_type}s/{repo_id}/commits' if repo_type else \
        f'{endpoint}/api/v1/models/{repo_id}/commits'
    params = {
        'Ref': revision or DEFAULT_MODEL_REVISION or DEFAULT_REPOSITORY_REVISION,
        'PageNumber': page_number,
        'PageSize': page_size
    }
    cookies = ModelScopeConfig.get_cookies()

    try:
        r = self.session.get(commits_url, params=params,
                             cookies=cookies, headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        resp = r.json()
        raise_on_error(resp)
    except requests.exceptions.RequestException as e:
        raise Exception(f'Failed to get repository commits for {repo_id}: {str(e)}')

    if resp.get('Code') == HTTPStatus.OK:
        return CommitHistoryResponse.from_api_response(resp)

    # Previously this path fell through and silently returned None;
    # fail loudly on an unexpected code instead.
    raise Exception(
        f'Failed to get repository commits for {repo_id}: '
        f"unexpected response code {resp.get('Code')}")
||||
def get_dataset_files(self,
|
||||
repo_id: str,
|
||||
*,
|
||||
@@ -1164,6 +1323,39 @@ class HubApi:
|
||||
|
||||
return resp['Data']['Files']
|
||||
|
||||
def get_dataset(
    self,
    dataset_id: str,
    revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
    endpoint: Optional[str] = None
):
    """
    Get the dataset information.

    Args:
        dataset_id (str): The dataset id.
        revision (Optional[str]): The revision of the dataset.
        endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

    Returns:
        dict: The dataset information.
    """
    cookies = ModelScopeConfig.get_cookies()
    endpoint = endpoint or self.endpoint

    # Only append the Revision query parameter when a revision was given.
    path = f'{endpoint}/api/v1/datasets/{dataset_id}'
    if revision:
        path = f'{path}?Revision={revision}'

    response = self.session.get(
        path, cookies=cookies, headers=self.builder_headers(self.headers))
    raise_for_http_status(response)
    payload = response.json()
    datahub_raise_on_error(path, payload, response)
    return payload[API_RESPONSE_FIELD_DATA]
||||
def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
|
||||
dataset_id: str, revision: str, endpoint: Optional[str] = None):
|
||||
""" Get the meta file-list of the dataset. """
|
||||
|
||||
259
modelscope/hub/info.py
Normal file
259
modelscope/hub/info.py
Normal file
@@ -0,0 +1,259 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
# yapf: disable
|
||||
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from modelscope.hub.utils.utils import convert_timestamp
|
||||
|
||||
|
||||
@dataclass
class OrganizationInfo:
    """Organization information for a repository."""
    id: Optional[int]
    name: Optional[str]
    full_name: Optional[str]
    description: Optional[str]
    avatar: Optional[str]
    github_address: Optional[str]
    type: Optional[int]
    email: Optional[str]
    created_time: Optional[datetime.datetime]
    modified_time: Optional[datetime.datetime]

    def __init__(self, **kwargs):
        """Populate fields from a raw hub API payload (PascalCase keys)."""
        pop = kwargs.pop
        self.id = pop('Id', None)
        self.name = pop('Name', '')
        self.full_name = pop('FullName', '')
        self.description = pop('Description', '')
        self.avatar = pop('Avatar', '')
        self.github_address = pop('GithubAddress', '')
        # Both the PascalCase and lowercase variants are consumed so that
        # neither key survives in `kwargs`; the PascalCase one wins.
        lowercase_type = pop('type', None)
        self.type = pop('Type', lowercase_type)
        lowercase_email = pop('email', '')
        self.email = pop('Email', lowercase_email)
        raw_created = pop('created_time', None)
        raw_created = pop('GmtCreated', raw_created)
        self.created_time = convert_timestamp(raw_created) if raw_created else None
        raw_modified = pop('modified_time', None)
        raw_modified = pop('GmtModified', raw_modified)
        self.modified_time = convert_timestamp(raw_modified) if raw_modified else None
||||
@dataclass
class ModelInfo:
    """
    Contains detailed information about a model on ModelScope Hub. This object is returned by [`model_info`].

    Attributes:
        id (`int`, *optional*): Model ID.
        name (`str`, *optional*): Model name.
        author (`str`, *optional*): Model author.
        chinese_name (`str`, *optional*): Chinese display name.
        visibility (`int`, *optional*): Visibility level (1=private, 5=public).
        is_published (`int`, *optional*): Whether the model is published.
        is_online (`int`, *optional*): Whether the model is online.
        already_star (`bool`, *optional*): Whether current user has starred this model.
        description (`str`, *optional*): Model description.
        license (`str`, *optional*): Model license.
        downloads (`int`, *optional*): Number of downloads.
        likes (`int`, *optional*): Number of likes.
        created_at (`datetime`, *optional*): Date of creation of the repo on the Hub.
        last_updated_time (`datetime`, *optional*): Last update timestamp.
        architectures (`List[str]`, *optional*): Model architectures.
        model_type (`List[str]`, *optional*): Model types.
        tasks (`List[Dict[str, Any]]`, *optional*): Supported tasks.
        readme_content (`str`, *optional*): README content.
        organization (`OrganizationInfo`, *optional*): Organization information.
        created_by (`str`, *optional*): Creator username.
        is_certification (`int`, *optional*): Certification status.
        approval_mode (`int`, *optional*): Approval mode.
        card_ready (`int`, *optional*): Whether model card is ready.
        backend_support (`str`, *optional*): Backend support information.
        model_infos (`Dict[str, Any]`, *optional*): Detailed model configuration information.
        tags (`List[str]`, *optional*): Model Tags.
        is_accessible (`int`, *optional*): Whether accessible.
        revision (`str`, *optional*): Revision/branch.
        related_arxiv_id (`List[str]`, *optional*): Related arXiv paper IDs.
        related_paper (`List[int]`, *optional*): Related papers.
        sha (`str`, *optional*): Latest commit SHA.
        last_modified (`datetime`, *optional*): Latest commit date.
        last_commit (`Dict[str, Any]`, *optional*): Latest commit information.
    """

    id: Optional[int]
    name: Optional[str]
    author: Optional[str]
    chinese_name: Optional[str]
    visibility: Optional[int]
    is_published: Optional[int]
    is_online: Optional[int]
    already_star: Optional[bool]
    description: Optional[str]
    license: Optional[str]
    downloads: Optional[int]
    likes: Optional[int]
    created_at: Optional[datetime.datetime]
    last_updated_time: Optional[datetime.datetime]
    architectures: Optional[List[str]]
    model_type: Optional[List[str]]
    tasks: Optional[List[Dict[str, Any]]]
    readme_content: Optional[str]
    # Quoted forward reference: avoids a hard dependency on declaration
    # order when the class body's annotations are evaluated.
    organization: Optional['OrganizationInfo']
    created_by: Optional[str]

    # Certification and approval
    is_certification: Optional[int]
    approval_mode: Optional[int]
    card_ready: Optional[int]

    # Model specific
    backend_support: Optional[str]
    model_infos: Optional[Dict[str, Any]]

    # Content and settings
    tags: Optional[List[str]]

    # Additional flags
    is_accessible: Optional[int]

    # Revision and version info
    revision: Optional[str]

    # External references
    related_arxiv_id: Optional[List[str]]
    related_paper: Optional[List[int]]

    # Latest commit information
    last_commit: Optional[Dict[str, Any]]
    sha: Optional[str]
    last_modified: Optional[datetime.datetime]

    def __init__(self, **kwargs):
        """Build a ModelInfo from a raw hub API payload (PascalCase keys).

        Any unrecognized keys left in ``kwargs`` are attached as plain
        attributes for backward compatibility.
        """
        self.id = kwargs.pop('Id', None)
        self.name = kwargs.pop('Name', '')
        self.chinese_name = kwargs.pop('ChineseName', '')
        self.visibility = kwargs.pop('Visibility', None)
        self.is_published = kwargs.pop('IsPublished', None)
        self.is_online = kwargs.pop('IsOnline', None)
        self.already_star = kwargs.pop('AlreadyStar', None)
        self.description = kwargs.pop('Description', '')
        self.license = kwargs.pop('License', '')
        self.downloads = kwargs.pop('Downloads', None)
        # Pop BOTH counters so neither leaks into __dict__ via the
        # backward-compatibility update below; previously a truthy 'Stars'
        # short-circuited the `or` and left 'Likes' behind as a stray key.
        stars = kwargs.pop('Stars', None)
        likes = kwargs.pop('Likes', None)
        self.likes = stars or likes
        created_time = kwargs.pop('CreatedTime', None)
        self.created_at = convert_timestamp(created_time) if created_time else None
        last_updated_time = kwargs.pop('LastUpdatedTime', None)
        self.last_updated_time = convert_timestamp(last_updated_time) if last_updated_time else None
        self.architectures = kwargs.pop('Architectures', [])
        self.model_type = kwargs.pop('ModelType', [])
        self.tasks = kwargs.pop('Tasks', [])
        self.readme_content = kwargs.pop('ReadMeContent', '')
        org_data = kwargs.pop('Organization', None)
        self.organization = OrganizationInfo(**org_data) if org_data else None
        self.created_by = kwargs.pop('CreatedBy', None)
        self.is_certification = kwargs.pop('IsCertification', None)
        self.approval_mode = kwargs.pop('ApprovalMode', None)
        self.card_ready = kwargs.pop('CardReady', None)
        self.backend_support = kwargs.pop('BackendSupport', '{}')
        self.model_infos = kwargs.pop('ModelInfos', {})
        self.tags = kwargs.pop('Tags', [])
        self.is_accessible = kwargs.pop('IsAccessible', None)
        self.revision = kwargs.pop('Revision', '')
        self.related_arxiv_id = kwargs.pop('RelatedArxivId', [])
        self.related_paper = kwargs.pop('RelatedPaper', [])

        # Derive last-commit fields from an optional CommitHistoryResponse.
        commits = kwargs.pop('commits', None)
        if commits and hasattr(commits, 'commits') and commits.commits:
            last_commit = commits.commits[0]
            self.last_commit = last_commit.to_dict() if hasattr(last_commit, 'to_dict') else None
            self.sha = self.last_commit.get('id') if self.last_commit else None
            self.last_modified = convert_timestamp(self.last_commit.get('committed_date')) if self.last_commit else None
        else:
            self.last_commit = None
            self.sha = None
            self.last_modified = None
        self.author = kwargs.pop('author', '')

        # backward compatibility: expose any remaining raw keys as attributes
        self.__dict__.update(kwargs)
|
||||
@dataclass
class DatasetInfo:
    """
    Contains detailed information about a dataset on ModelScope Hub. This object is returned by [`dataset_info`].

    Attributes:
        id (`int`, *optional*): Dataset ID.
        name (`str`, *optional*): Dataset name.
        author (`str`, *optional*): Dataset owner (user or organization).
        chinese_name (`str`, *optional*): Chinese display name.
        visibility (`int`, *optional*): Visibility level (1=private, 3=internal, 5=public).
            'internal' means visible to logged-in users only.
        already_star (`bool`, *optional*): Whether current user has starred this dataset.
        description (`str`, *optional*): Dataset description.
        license (`str`, *optional*): Dataset license.
        downloads (`int`, *optional*): Number of downloads.
        likes (`int`, *optional*): Number of likes.
        created_at (`datetime`, *optional*): Creation timestamp.
        last_updated_time (`datetime`, *optional*): Last update timestamp.
        readme_content (`str`, *optional*): README content.
        organization (`OrganizationInfo`, *optional*): Organization information.
        created_by (`str`, *optional*): Creator username.
        tags (`List[Dict[str, Any]]`): Dataset tags.
        last_commit (`Dict[str, Any]`, *optional*): Latest commit information.
        sha (`str`, *optional*): Latest commit SHA.
        last_modified (`datetime`, *optional*): Latest commit date.
    """

    id: Optional[int]
    name: Optional[str]
    author: Optional[str]
    chinese_name: Optional[str]
    visibility: Optional[Literal[1, 3, 5]]
    already_star: Optional[bool]
    description: Optional[str]
    license: Optional[str]
    downloads: Optional[int]
    likes: Optional[int]
    created_at: Optional[datetime.datetime]
    last_updated_time: Optional[datetime.datetime]
    readme_content: Optional[str]
    # Quoted forward reference: avoids a hard dependency on declaration
    # order when the class body's annotations are evaluated.
    organization: Optional['OrganizationInfo']
    created_by: Optional[str]
    tags: Optional[List[Dict[str, Any]]]
    last_commit: Optional[Dict[str, Any]]
    sha: Optional[str]
    last_modified: Optional[datetime.datetime]

    def __init__(self, **kwargs):
        """Build a DatasetInfo from a raw hub API payload (PascalCase keys).

        Any unrecognized keys left in ``kwargs`` are attached as plain
        attributes for backward compatibility.
        """
        self.id = kwargs.pop('Id', None)
        self.name = kwargs.pop('Name', '')
        # Explicit 'author' wins; otherwise fall back to Owner, then Namespace.
        self.author = kwargs.pop('author', kwargs.pop('Owner', None) or kwargs.pop('Namespace', None))
        self.chinese_name = kwargs.pop('ChineseName', '')
        self.visibility = kwargs.pop('Visibility', None)
        self.already_star = kwargs.pop('AlreadyStar', None)
        self.description = kwargs.pop('Description', '')
        # Pop BOTH counters so neither leaks into __dict__ via the
        # backward-compatibility update below; previously a truthy 'Likes'
        # short-circuited the `or` and left 'Stars' behind as a stray key.
        likes = kwargs.pop('Likes', None)
        stars = kwargs.pop('Stars', None)
        self.likes = likes or stars
        self.license = kwargs.pop('License', '')
        self.downloads = kwargs.pop('Downloads', None)
        created_time = kwargs.pop('GmtCreate', None)
        self.created_at = convert_timestamp(created_time) if created_time else None
        last_updated_time = kwargs.pop('LastUpdatedTime', None)
        self.last_updated_time = convert_timestamp(last_updated_time) if last_updated_time else None
        self.readme_content = kwargs.pop('ReadMeContent', '')
        org_data = kwargs.pop('Organization', None)
        self.organization = OrganizationInfo(**org_data) if org_data else None
        self.created_by = kwargs.pop('CreatedBy', None)
        self.tags = kwargs.pop('Tags', [])

        # Derive last-commit fields from an optional CommitHistoryResponse.
        commits = kwargs.pop('commits', None)
        if commits and hasattr(commits, 'commits') and commits.commits:
            last_commit = commits.commits[0]
            self.last_commit = last_commit.to_dict() if hasattr(last_commit, 'to_dict') else None
            self.sha = self.last_commit.get('id') if self.last_commit else None
            self.last_modified = convert_timestamp(self.last_commit.get('committed_date')) if self.last_commit else None
        else:
            self.last_commit = None
            self.sha = None
            self.last_modified = None

        # backward compatibility: expose any remaining raw keys as attributes
        self.__dict__.update(kwargs)
||||
@@ -4,6 +4,7 @@ import hashlib
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
import threading
|
||||
from shutil import move, rmtree
|
||||
from typing import Dict
|
||||
|
||||
@@ -39,6 +40,8 @@ class FileSystemCache(object):
|
||||
cache_root_location (str): The root location to store files.
|
||||
kwargs(dict): The keyword arguments.
|
||||
"""
|
||||
self._cache_lock = threading.RLock()
|
||||
|
||||
os.makedirs(cache_root_location, exist_ok=True)
|
||||
self.cache_root_location = cache_root_location
|
||||
self.load_cache()
|
||||
@@ -55,15 +58,30 @@ class FileSystemCache(object):
|
||||
self.cached_files = pickle.load(f)
|
||||
|
||||
def save_cached_files(self):
    """
    Save cache metadata in order to verify that the cached content is consistent with the remote content.

    Writes to a temp file in the cache directory first, then atomically
    moves it over KEY_FILE_NAME so readers never observe a partial file.
    The whole operation is serialized through `self._cache_lock`.

    Example of the cached content:
    [{'Path': 'configuration.json', 'Revision': 'f01dxxx'}, {'Path': 'model.bin', 'Revision': '1159xxx'}, ...]
    """
    # NOTE(review): this chunk previously carried both the old unlocked
    # implementation and the new locked one (diff artifact); only the
    # locked, atomic version is kept.
    with self._cache_lock:
        cache_keys_file_path = os.path.join(self.cache_root_location,
                                            FileSystemCache.KEY_FILE_NAME)
        # Create the temp file in the cache dir so the final move stays on
        # the same filesystem and remains atomic.
        fd, temp_filename = tempfile.mkstemp(
            suffix='.tmp', dir=self.cache_root_location)

        try:
            with os.fdopen(fd, 'wb') as f:
                pickle.dump(self.cached_files, f)
            move(temp_filename, cache_keys_file_path)
        except Exception:
            # Best-effort cleanup: the fd may already be closed by the
            # context manager above, hence the swallowed OSError.
            try:
                os.close(fd)
            except OSError:
                pass
            if os.path.exists(temp_filename):
                os.unlink(temp_filename)
            raise
|
||||
def get_file(self, key):
|
||||
"""Check the key is in the cache, if exist, return the file, otherwise return None.
|
||||
|
||||
@@ -5,6 +5,7 @@ import hashlib
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import zoneinfo
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Generator, List, Optional, Union
|
||||
@@ -299,3 +300,83 @@ def weak_file_lock(lock_file: Union[str, Path],
|
||||
Path(lock_file).unlink()
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def convert_timestamp(time_stamp: Union[int, float, str, datetime],
                      time_zone: str = 'Asia/Shanghai') -> Optional[datetime]:
    """Convert a UNIX/string timestamp to a timezone-aware datetime object.

    Args:
        time_stamp: UNIX timestamp (int or float), ISO-like string, or datetime object
        time_zone: Target timezone for non-UTC timestamps (default: 'Asia/Shanghai')

    Returns:
        Timezone-aware datetime object, or None if the input is falsy
        (None, 0, or an empty string)

    Raises:
        ValueError: If a string/number cannot be parsed as a datetime.
        TypeError: If `time_stamp` has an unsupported type.
    """
    # NOTE: a 0 timestamp (the epoch) is deliberately treated as "missing"
    # and maps to None, matching the hub's use of 0/empty for absent dates.
    if not time_stamp:
        return None

    # Handle datetime objects first (naive ones are passed through as-is)
    if isinstance(time_stamp, datetime):
        return time_stamp

    if isinstance(time_stamp, str):
        try:
            if time_stamp.endswith('Z'):
                # Normalize fractional seconds to 6 digits so strptime's %f accepts them
                if '.' not in time_stamp:
                    # No fractional seconds (e.g., "2024-11-16T00:27:02Z")
                    time_stamp = time_stamp[:-1] + '.000000Z'
                else:
                    # Has fractional seconds (e.g., "2022-08-19T07:19:38.123456789Z")
                    base, fraction = time_stamp[:-1].split('.')
                    # Truncate or pad to 6 digits
                    fraction = fraction[:6].ljust(6, '0')
                    time_stamp = f'{base}.{fraction}Z'

                dt = datetime.strptime(time_stamp,
                                       '%Y-%m-%dT%H:%M:%S.%fZ').replace(
                                           tzinfo=zoneinfo.ZoneInfo('UTC'))
                if time_zone != 'UTC':
                    dt = dt.astimezone(zoneinfo.ZoneInfo(time_zone))
                return dt
            else:
                # Try parsing common ISO formats; values without a 'Z' suffix
                # are interpreted as local time in `time_zone`, not UTC.
                formats = [
                    '%Y-%m-%dT%H:%M:%S.%f',  # With microseconds
                    '%Y-%m-%dT%H:%M:%S',  # Without microseconds
                    '%Y-%m-%d %H:%M:%S.%f',  # Space separator with microseconds
                    '%Y-%m-%d %H:%M:%S',  # Space separator without microseconds
                ]
                for fmt in formats:
                    try:
                        return datetime.strptime(
                            time_stamp,
                            fmt).replace(tzinfo=zoneinfo.ZoneInfo(time_zone))
                    except ValueError:
                        continue

                raise ValueError(
                    f"Unsupported timestamp format: '{time_stamp}'")

        except ValueError as e:
            raise ValueError(
                f"Cannot parse '{time_stamp}' as a datetime. Expected formats: "
                f"'YYYY-MM-DDTHH:MM:SS[.ffffff]Z' (UTC) or 'YYYY-MM-DDTHH:MM:SS[.ffffff]' (local)"
            ) from e

    elif isinstance(time_stamp, (int, float)):
        # Accept floats too (fractional-second UNIX timestamps); previously
        # only int was handled and floats fell through to the TypeError below.
        try:
            # UNIX timestamps are always in UTC, then convert to target timezone
            return datetime.fromtimestamp(
                time_stamp, tz=zoneinfo.ZoneInfo('UTC')).astimezone(
                    zoneinfo.ZoneInfo(time_zone))
        except (ValueError, OSError) as e:
            raise ValueError(
                f"Cannot convert '{time_stamp}' to datetime. Ensure it's a valid UNIX timestamp."
            ) from e

    else:
        raise TypeError(
            f"Unsupported type '{type(time_stamp)}'. Expected int, str, or datetime."
        )
||||
@@ -8,12 +8,14 @@ import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import (Any, BinaryIO, Callable, Generator, Iterable, Iterator,
|
||||
List, Literal, Optional, TypeVar, Union)
|
||||
|
||||
from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
|
||||
from modelscope.hub.utils.utils import convert_timestamp
|
||||
from modelscope.utils.file_utils import get_file_hash
|
||||
|
||||
T = TypeVar('T')
|
||||
@@ -257,6 +259,76 @@ class CommitInfo:
|
||||
}
|
||||
|
||||
|
||||
@dataclass
class DetailedCommitInfo:
    """Detailed commit information from repository history API."""
    id: Optional[str]
    short_id: Optional[str]
    title: Optional[str]
    message: Optional[str]
    author_name: Optional[str]
    authored_date: Optional[datetime]
    author_email: Optional[str]
    committed_date: Optional[datetime]
    committer_name: Optional[str]
    committer_email: Optional[str]
    created_at: Optional[datetime]

    # Single source of truth for the attribute order used by to_dict().
    _FIELD_NAMES = ('id', 'short_id', 'title', 'message', 'author_name',
                    'authored_date', 'author_email', 'committed_date',
                    'committer_name', 'committer_email', 'created_at')

    @classmethod
    def from_api_response(cls, data: dict) -> 'DetailedCommitInfo':
        """Create DetailedCommitInfo from API response data."""
        read = data.get
        values = {
            'id': read('Id', ''),
            'short_id': read('ShortId', ''),
            'title': read('Title', ''),
            'message': read('Message', ''),
            'author_name': read('AuthorName', ''),
            'authored_date': convert_timestamp(read('AuthoredDate', None)),
            'author_email': read('AuthorEmail', ''),
            'committed_date': convert_timestamp(read('CommittedDate', None)),
            'committer_name': read('CommitterName', ''),
            'committer_email': read('CommitterEmail', ''),
            'created_at': convert_timestamp(read('CreatedAt', None)),
        }
        return cls(**values)

    def to_dict(self) -> dict:
        """Convert to dictionary."""
        return {name: getattr(self, name) for name in self._FIELD_NAMES}
||||
|
||||
@dataclass
class CommitHistoryResponse:
    """Response from commit history API."""
    commits: Optional[List[DetailedCommitInfo]]
    total_count: Optional[int]

    @classmethod
    def from_api_response(cls, data: dict) -> 'CommitHistoryResponse':
        """Create CommitHistoryResponse from API response data."""
        raw_commits = data.get('Data', {}).get('Commit', [])
        parsed = list(map(DetailedCommitInfo.from_api_response, raw_commits))
        return cls(commits=parsed, total_count=data.get('TotalCount', 0))
|
||||
|
||||
@dataclass
|
||||
class RepoUrl:
|
||||
|
||||
|
||||
474
tests/hub/test_hub_repo_info.py
Normal file
474
tests/hub/test_hub_repo_info.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# yapf: disable
|
||||
|
||||
import datetime
|
||||
import unittest
|
||||
import zoneinfo
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from modelscope.hub.api import DatasetInfo, HubApi, ModelInfo
|
||||
from modelscope.utils.constant import REPO_TYPE_DATASET, REPO_TYPE_MODEL
|
||||
from modelscope.utils.repo_utils import DetailedCommitInfo
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class HubRepoInfoTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.api = HubApi()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@patch.object(HubApi, 'get_model')
|
||||
@patch.object(HubApi, 'list_repo_commits')
|
||||
def test_model_info(self, mock_list_repo_commits, mock_get_model):
|
||||
# Setup mock responses
|
||||
mock_get_model.return_value = {
|
||||
'Id': 123,
|
||||
'Name': 'demo-model',
|
||||
'ChineseName': '测试模型',
|
||||
'Description': 'A test model',
|
||||
'Tasks': [{
|
||||
'Name': 'text-classification',
|
||||
'Description': 'A test task'
|
||||
}],
|
||||
'Tags': ['nlp', 'text']
|
||||
}
|
||||
|
||||
# Mock commit history response
|
||||
commit = DetailedCommitInfo(
|
||||
id='abc123',
|
||||
short_id='abc12',
|
||||
title='Initial commit',
|
||||
message='Initial commit',
|
||||
author_name='Test User',
|
||||
authored_date=None,
|
||||
author_email='test@example.com',
|
||||
committed_date=None,
|
||||
committer_name='Test User',
|
||||
committer_email='test@example.com',
|
||||
created_at=None)
|
||||
commits_response = Mock()
|
||||
commits_response.commits = [commit]
|
||||
mock_list_repo_commits.return_value = commits_response
|
||||
|
||||
# Call the method
|
||||
info = self.api.model_info(
|
||||
repo_id='demo/model', revision='master', endpoint=None)
|
||||
|
||||
# Verify results
|
||||
self.assertEqual(info.id, 123)
|
||||
self.assertEqual(info.name, 'demo-model')
|
||||
self.assertEqual(info.author, 'demo')
|
||||
self.assertEqual(info.chinese_name, '测试模型')
|
||||
self.assertEqual(info.description, 'A test model')
|
||||
self.assertEqual(info.tasks, [{
|
||||
'Name': 'text-classification',
|
||||
'Description': 'A test task'
|
||||
}])
|
||||
self.assertEqual(info.tags, ['nlp', 'text'])
|
||||
self.assertEqual(info.sha, 'abc123')
|
||||
self.assertEqual(info.last_commit, commit.to_dict())
|
||||
|
||||
# Verify correct method calls
|
||||
mock_get_model.assert_called_once_with(
|
||||
model_id='demo/model', revision='master', endpoint=None)
|
||||
mock_list_repo_commits.assert_called_once_with(
|
||||
repo_id='demo/model',
|
||||
repo_type=REPO_TYPE_MODEL,
|
||||
revision='master',
|
||||
endpoint=None)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@patch.object(HubApi, 'get_dataset')
|
||||
@patch.object(HubApi, 'list_repo_commits')
|
||||
def test_dataset_info(self, mock_list_repo_commits, mock_get_dataset):
|
||||
# Setup mock responses
|
||||
mock_get_dataset.return_value = {
|
||||
'Id': 456,
|
||||
'Name': 'demo-dataset',
|
||||
'ChineseName': '演示数据集',
|
||||
'Description': 'A test dataset',
|
||||
'Tags': [
|
||||
{
|
||||
'Name': 'nlp',
|
||||
'Color': 'blue'
|
||||
},
|
||||
{
|
||||
'Name': 'text',
|
||||
'Color': 'green'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Mock commit history response
|
||||
commits = [
|
||||
DetailedCommitInfo(
|
||||
id='c1',
|
||||
short_id='c1',
|
||||
title='Update data',
|
||||
message='Update data',
|
||||
author_name='Test User',
|
||||
authored_date=None,
|
||||
author_email='test@example.com',
|
||||
committed_date=None,
|
||||
committer_name='Test User',
|
||||
committer_email='test@example.com',
|
||||
created_at=None),
|
||||
DetailedCommitInfo(
|
||||
id='c2',
|
||||
short_id='c2',
|
||||
title='Initial commit',
|
||||
message='Initial commit',
|
||||
author_name='Test User',
|
||||
authored_date=None,
|
||||
author_email='test@example.com',
|
||||
committed_date=1756284063,
|
||||
committer_name='Test User',
|
||||
committer_email='test@example.com',
|
||||
created_at=None)
|
||||
]
|
||||
commits_response = Mock()
|
||||
commits_response.commits = commits
|
||||
mock_list_repo_commits.return_value = commits_response
|
||||
|
||||
# Call the method
|
||||
info = self.api.dataset_info('demo/dataset')
|
||||
|
||||
# Verify results
|
||||
self.assertEqual(info.id, 456)
|
||||
self.assertEqual(info.name, 'demo-dataset')
|
||||
self.assertEqual(info.author, 'demo')
|
||||
self.assertEqual(info.chinese_name, '演示数据集')
|
||||
self.assertEqual(info.description, 'A test dataset')
|
||||
self.assertEqual(info.tags, [{
|
||||
'Name': 'nlp',
|
||||
'Color': 'blue'
|
||||
}, {
|
||||
'Name': 'text',
|
||||
'Color': 'green'
|
||||
}])
|
||||
self.assertEqual(info.sha, 'c1')
|
||||
self.assertEqual(info.last_commit, commits[0].to_dict())
|
||||
|
||||
# Verify correct method calls
|
||||
mock_get_dataset.assert_called_once_with(
|
||||
dataset_id='demo/dataset', revision=None, endpoint=None)
|
||||
mock_list_repo_commits.assert_called_once_with(
|
||||
repo_id='demo/dataset',
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
revision=None,
|
||||
endpoint=None)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@patch.object(HubApi, 'model_info')
|
||||
@patch.object(HubApi, 'dataset_info')
|
||||
def test_repo_info_model(self, mock_dataset_info, mock_model_info):
|
||||
# Setup mock response
|
||||
model_info = ModelInfo(
|
||||
id=123, name='demo-model', description='A test model')
|
||||
mock_model_info.return_value = model_info
|
||||
|
||||
# Call the method with model type
|
||||
info = self.api.repo_info(
|
||||
repo_id='demo/model', revision='master', endpoint=None)
|
||||
|
||||
# Verify results
|
||||
self.assertEqual(info, model_info)
|
||||
mock_model_info.assert_called_once_with(
|
||||
repo_id='demo/model', revision='master', endpoint=None)
|
||||
mock_dataset_info.assert_not_called()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
@patch.object(HubApi, 'model_info')
|
||||
@patch.object(HubApi, 'dataset_info')
|
||||
def test_repo_info_dataset(self, mock_dataset_info, mock_model_info):
|
||||
# Setup mock response
|
||||
dataset_info = DatasetInfo(
|
||||
id=456, name='demo-dataset', description='A test dataset')
|
||||
mock_dataset_info.return_value = dataset_info
|
||||
|
||||
# Call the method with dataset type
|
||||
info = self.api.repo_info(
|
||||
repo_id='demo/dataset',
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
revision='master',
|
||||
endpoint=None)
|
||||
|
||||
# Verify results
|
||||
self.assertEqual(info, dataset_info)
|
||||
mock_dataset_info.assert_called_once_with(
|
||||
repo_id='demo/dataset', revision='master', endpoint=None)
|
||||
mock_model_info.assert_not_called()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_model_info_class_comprehensive(self):
|
||||
"""Test ModelInfo class initialization and properties."""
|
||||
model_data = {
|
||||
'Id': 123,
|
||||
'Name': 'demo-model',
|
||||
'author': 'demo',
|
||||
'ChineseName': '演示模型',
|
||||
'Description': 'A test model',
|
||||
'Tasks': [{
|
||||
'Name': 'text-classification',
|
||||
'Description': 'A test task'
|
||||
}],
|
||||
'Tags': ['nlp', 'text'],
|
||||
'CreatedTime': '2023-01-01T00:00:00Z',
|
||||
'Visibility': 5,
|
||||
'IsPublished': 1,
|
||||
'IsOnline': 1,
|
||||
'License': 'Apache-2.0',
|
||||
'Downloads': 100,
|
||||
'Stars': 50,
|
||||
'Architectures': ['transformer'],
|
||||
'ModelType': ['nlp']
|
||||
}
|
||||
|
||||
# Create mock commits
|
||||
commit = DetailedCommitInfo(
|
||||
id='abc123',
|
||||
short_id='abc12',
|
||||
title='Initial commit',
|
||||
message='Initial commit',
|
||||
author_name='Test User',
|
||||
authored_date=None,
|
||||
author_email='test@example.com',
|
||||
committed_date=None,
|
||||
committer_name='Test User',
|
||||
committer_email='test@example.com',
|
||||
created_at=None)
|
||||
commits = Mock()
|
||||
commits.commits = [commit]
|
||||
|
||||
# Create ModelInfo instance
|
||||
model_info = ModelInfo(**model_data, commits=commits)
|
||||
|
||||
# Verify properties
|
||||
self.assertEqual(model_info.id, 123)
|
||||
self.assertEqual(model_info.name, 'demo-model')
|
||||
self.assertEqual(model_info.author, 'demo')
|
||||
self.assertEqual(model_info.chinese_name, '演示模型')
|
||||
self.assertEqual(model_info.description, 'A test model')
|
||||
self.assertEqual(model_info.tasks, [{
|
||||
'Name': 'text-classification',
|
||||
'Description': 'A test task'
|
||||
}])
|
||||
self.assertEqual(model_info.tags, ['nlp', 'text'])
|
||||
self.assertEqual(
|
||||
model_info.created_at,
|
||||
datetime.datetime(
|
||||
2023, 1, 1, 8, 0, 0).replace(tzinfo=zoneinfo.ZoneInfo('Asia/Shanghai')))
|
||||
self.assertEqual(model_info.sha, 'abc123')
|
||||
self.assertEqual(model_info.last_commit, commit.to_dict())
|
||||
self.assertEqual(model_info.visibility, 5)
|
||||
self.assertEqual(model_info.is_published, 1)
|
||||
self.assertEqual(model_info.is_online, 1)
|
||||
self.assertEqual(model_info.license, 'Apache-2.0')
|
||||
self.assertEqual(model_info.downloads, 100)
|
||||
self.assertEqual(model_info.likes, 50)
|
||||
self.assertEqual(model_info.architectures, ['transformer'])
|
||||
self.assertEqual(model_info.model_type, ['nlp'])
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_dataset_info_class(self):
|
||||
# Test DatasetInfo class initialization and properties
|
||||
dataset_data = {
|
||||
'Id': 456,
|
||||
'Name': 'demo-dataset',
|
||||
'Owner': 'demo',
|
||||
'ChineseName': '演示数据集',
|
||||
'Description': 'A test dataset',
|
||||
'Tags': [
|
||||
{
|
||||
'Name': 'nlp',
|
||||
'Color': 'blue'
|
||||
},
|
||||
{
|
||||
'Name': 'text',
|
||||
'Color': 'green'
|
||||
}
|
||||
],
|
||||
'GmtCreate': 1755752511,
|
||||
'Visibility': 5,
|
||||
'License': 'MIT',
|
||||
'Downloads': 200,
|
||||
'Likes': 75
|
||||
}
|
||||
|
||||
# Create mock commits
|
||||
commit = DetailedCommitInfo(
|
||||
id='c1',
|
||||
short_id='c1',
|
||||
title='Initial commit',
|
||||
message='Initial commit',
|
||||
author_name='Test User',
|
||||
authored_date=None,
|
||||
author_email='test@example.com',
|
||||
committed_date='2024-09-18T06:20:05Z',
|
||||
committer_name='Test User',
|
||||
committer_email='test@example.com',
|
||||
created_at=None)
|
||||
commits = Mock()
|
||||
commits.commits = [commit]
|
||||
|
||||
# Create DatasetInfo instance
|
||||
dataset_info = DatasetInfo(**dataset_data, commits=commits)
|
||||
|
||||
# Verify properties
|
||||
self.assertEqual(dataset_info.id, 456)
|
||||
self.assertEqual(dataset_info.name, 'demo-dataset')
|
||||
self.assertEqual(dataset_info.author, 'demo')
|
||||
self.assertEqual(dataset_info.chinese_name, '演示数据集')
|
||||
self.assertEqual(dataset_info.description, 'A test dataset')
|
||||
self.assertEqual(
|
||||
dataset_info.tags,
|
||||
[
|
||||
{
|
||||
'Name': 'nlp',
|
||||
'Color': 'blue'
|
||||
},
|
||||
{
|
||||
'Name': 'text',
|
||||
'Color': 'green'
|
||||
}
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
dataset_info.created_at,
|
||||
datetime.datetime(
|
||||
2025, 8, 21, 13, 1, 51).replace(tzinfo=zoneinfo.ZoneInfo('Asia/Shanghai')))
|
||||
self.assertEqual(dataset_info.sha, 'c1')
|
||||
self.assertEqual(
|
||||
dataset_info.last_modified,
|
||||
datetime.datetime(
|
||||
2024, 9, 18, 14, 20, 5).replace(tzinfo=zoneinfo.ZoneInfo('Asia/Shanghai')))
|
||||
self.assertEqual(dataset_info.last_commit, commit.to_dict())
|
||||
self.assertEqual(dataset_info.visibility, 5)
|
||||
self.assertEqual(dataset_info.license, 'MIT')
|
||||
self.assertEqual(dataset_info.downloads, 200)
|
||||
self.assertEqual(dataset_info.likes, 75)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_model_info_empty_commits(self):
|
||||
"""Test ModelInfo with empty commits."""
|
||||
model_data = {'Id': 123, 'Name': 'demo-model', 'author': 'demo'}
|
||||
|
||||
# Create ModelInfo with no commits
|
||||
model_info = ModelInfo(**model_data, commits=None)
|
||||
|
||||
# Verify commit-related fields are None
|
||||
self.assertIsNone(model_info.sha)
|
||||
self.assertIsNone(model_info.last_commit)
|
||||
self.assertIsNone(model_info.last_modified)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_dataset_info_empty_commits(self):
|
||||
"""Test DatasetInfo with empty commits."""
|
||||
dataset_data = {'Id': 456, 'Name': 'demo-dataset', 'Owner': 'demo'}
|
||||
|
||||
# Create DatasetInfo with no commits
|
||||
dataset_info = DatasetInfo(**dataset_data, commits=None)
|
||||
|
||||
# Verify commit-related fields are None
|
||||
self.assertIsNone(dataset_info.sha)
|
||||
self.assertIsNone(dataset_info.last_commit)
|
||||
self.assertIsNone(dataset_info.last_modified)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_detailed_commit_info_to_dict(self):
|
||||
"""Test DetailedCommitInfo to_dict method."""
|
||||
commit = DetailedCommitInfo(
|
||||
id='abc123',
|
||||
short_id='abc12',
|
||||
title='Test commit',
|
||||
message='Test commit message',
|
||||
author_name='Test Author',
|
||||
authored_date=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
author_email='test@example.com',
|
||||
committed_date=datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
committer_name='Test Committer',
|
||||
committer_email='committer@example.com',
|
||||
created_at=datetime.datetime(2023, 1, 1, 0, 0, 0))
|
||||
|
||||
result = commit.to_dict()
|
||||
|
||||
expected = {
|
||||
'id': 'abc123',
|
||||
'short_id': 'abc12',
|
||||
'title': 'Test commit',
|
||||
'message': 'Test commit message',
|
||||
'author_name': 'Test Author',
|
||||
'authored_date': datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
'author_email': 'test@example.com',
|
||||
'committed_date': datetime.datetime(2023, 1, 1, 0, 0, 0),
|
||||
'committer_name': 'Test Committer',
|
||||
'committer_email': 'committer@example.com',
|
||||
'created_at': datetime.datetime(2023, 1, 1, 0, 0, 0)
|
||||
}
|
||||
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_real_model_repo_info(self):
|
||||
"""Test getting real model repository information without mocks."""
|
||||
# Use a real model repository
|
||||
model_repo_id = 'black-forest-labs/FLUX.1-Krea-dev'
|
||||
|
||||
# Get repository information
|
||||
info = self.api.repo_info(
|
||||
repo_id=model_repo_id, repo_type=REPO_TYPE_MODEL)
|
||||
|
||||
# Basic validation
|
||||
self.assertIsNotNone(info)
|
||||
self.assertEqual(info.author, 'black-forest-labs')
|
||||
self.assertEqual(info.name, 'FLUX.1-Krea-dev')
|
||||
|
||||
# Check commit information
|
||||
self.assertIsNotNone(info.sha)
|
||||
if hasattr(info, 'last_commit') and info.last_commit:
|
||||
self.assertIn('id', info.last_commit)
|
||||
self.assertIn('title', info.last_commit)
|
||||
|
||||
# Print some information for debugging
|
||||
print(f'\nModel Info for {model_repo_id}:')
|
||||
print(f'ID: {info.id}')
|
||||
print(f'Name: {info.name}')
|
||||
print(f'Author: {info.author}')
|
||||
print(f'SHA: {info.sha}')
|
||||
if hasattr(info, 'last_modified'):
|
||||
print(f'Last Modified: {info.last_modified}')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_real_dataset_repo_info(self):
|
||||
"""Test getting real dataset repository information without mocks."""
|
||||
# Use a real dataset repository
|
||||
dataset_repo_id = 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT'
|
||||
|
||||
# Get repository information
|
||||
info = self.api.repo_info(
|
||||
repo_id=dataset_repo_id, repo_type=REPO_TYPE_DATASET)
|
||||
|
||||
# Basic validation
|
||||
self.assertIsNotNone(info)
|
||||
self.assertEqual(info.author, 'swift')
|
||||
self.assertTrue('Chinese-Qwen3' in info.name)
|
||||
|
||||
# Check commit information
|
||||
self.assertIsNotNone(info.sha)
|
||||
if hasattr(info, 'last_commit') and info.last_commit:
|
||||
self.assertIn('id', info.last_commit)
|
||||
self.assertIn('title', info.last_commit)
|
||||
|
||||
# Print some information for debugging
|
||||
print(f'\nDataset Info for {dataset_repo_id}:')
|
||||
print(f'ID: {info.id}')
|
||||
print(f'Name: {info.name}')
|
||||
print(f'Author: {info.author}')
|
||||
print(f'SHA: {info.sha}')
|
||||
if hasattr(info, 'last_modified'):
|
||||
print(f'Last Modified: {info.last_modified}')
|
||||
|
||||
|
||||
# Allow running this test module directly (e.g. `python tests/...`).
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user