diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index af3539a4..96c576de 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -63,6 +63,7 @@ from modelscope.hub.errors import (InvalidParameter, NotExistError, handle_http_response, is_ok, raise_for_http_status, raise_on_error) from modelscope.hub.git import GitCommandWrapper +from modelscope.hub.info import DatasetInfo, ModelInfo from modelscope.hub.repository import Repository from modelscope.hub.utils.aigc import AigcModel from modelscope.hub.utils.utils import (add_content_to_file, get_domain, @@ -83,7 +84,8 @@ from modelscope.utils.file_utils import get_file_hash, get_file_size from modelscope.utils.logger import get_logger from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX, DEFAULT_IGNORE_PATTERNS, - MODEL_LFS_SUFFIX, CommitInfo, + MODEL_LFS_SUFFIX, + CommitHistoryResponse, CommitInfo, CommitOperation, CommitOperationAdd, RepoUtils) from modelscope.utils.thread_utils import thread_executor @@ -434,6 +436,101 @@ class HubApi: else: return prefer_endpoint + def model_info(self, + repo_id: str, + *, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + endpoint: Optional[str] = None) -> ModelInfo: + """Get model information including commit history. + + Args: + repo_id (str): The model id in the format of + ``namespace/model_name``. + revision (str, optional): Specific revision of the model. + Defaults to ``DEFAULT_MODEL_REVISION``. + endpoint (str, optional): Hub endpoint to use. When ``None``, + use the endpoint specified when initializing :class:`HubApi`. + + Returns: + ModelInfo: The model detailed information returned by + ModelScope Hub with commit history. + """ + owner_or_group, _ = model_id_to_group_owner_name(repo_id) + model_data = self.get_model( + model_id=repo_id, revision=revision, endpoint=endpoint) + commits = self.list_repo_commits( + repo_id=repo_id, repo_type=REPO_TYPE_MODEL, revision=revision, endpoint=endpoint) + + # Create ModelInfo from API response data + model_info = ModelInfo(**model_data, commits=commits, author=owner_or_group) + + return model_info + + def dataset_info(self, + repo_id: str, + *, + revision: Optional[str] = None, + endpoint: Optional[str] = None) -> DatasetInfo: + """Get dataset information including commit history. + + Args: + repo_id (str): The dataset id in the format of + ``namespace/dataset_name``. + revision (str, optional): Specific revision of the dataset. + Defaults to ``None``. + endpoint (str, optional): Hub endpoint to use. When ``None``, + use the endpoint specified when initializing :class:`HubApi`. + + Returns: + DatasetInfo: The dataset detailed information returned by + ModelScope Hub with commit history. + """ + owner_or_group, _ = model_id_to_group_owner_name(repo_id) + dataset_data = self.get_dataset( + dataset_id=repo_id, revision=revision, endpoint=endpoint) + commits = self.list_repo_commits( + repo_id=repo_id, repo_type=REPO_TYPE_DATASET, revision=revision, endpoint=endpoint) + + # Create DatasetInfo from API response data + dataset_info = DatasetInfo(**dataset_data, commits=commits, author=owner_or_group) + + return dataset_info + + def repo_info( + self, + repo_id: str, + *, + repo_type: Optional[str] = REPO_TYPE_MODEL, + revision: Optional[str] = DEFAULT_MODEL_REVISION, + endpoint: Optional[str] = None + ) -> Union[ModelInfo, DatasetInfo]: + """Get repository information for models or datasets. + + Args: + repo_id (str): The repository id in the format of + ``namespace/repo_name``. 
+ revision (str, optional): Specific revision of the repository. + Currently only effective for model repositories. Defaults to + ``DEFAULT_MODEL_REVISION``. + repo_type (str, optional): Type of the repository. Supported + values are ``"model"`` and ``"dataset"``. If not provided, + ``"model"`` is assumed. + endpoint (str, optional): Hub endpoint to use. When ``None``, + use the endpoint specified when initializing :class:`HubApi`. + + Returns: + Union[ModelInfo, DatasetInfo]: The repository detailed information + returned by ModelScope Hub. + """ + if repo_type is None or repo_type == REPO_TYPE_MODEL: + return self.model_info(repo_id=repo_id, revision=revision, endpoint=endpoint) + + if repo_type == REPO_TYPE_DATASET: + return self.dataset_info(repo_id=repo_id, revision=revision, endpoint=endpoint) + + raise InvalidParameter( + f'Arg repo_type {repo_type} not supported. Please choose from {REPO_TYPE_SUPPORT}.') + def repo_exists( self, repo_id: str, @@ -1111,6 +1208,68 @@ class HubApi: return resp + def list_repo_commits(self, + repo_id: str, + *, + repo_type: Optional[str] = REPO_TYPE_MODEL, + revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, + page_number: int = 1, + page_size: int = 50, + endpoint: Optional[str] = None): + """ + Get the commit history for a repository. + + Args: + repo_id (str): The repository id, in the format of `namespace/repo_name`. + repo_type (Optional[str]): The type of the repository. Supported types are `model` and `dataset`. + revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`. + page_number (int): The page number for pagination. Defaults to 1. + page_size (int): The number of commits per page. Defaults to 50. + endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. + + Returns: + CommitHistoryResponse: The commit history response. + + Examples: + >>> from modelscope.hub.api import HubApi + >>> api = HubApi() + >>> commit_history = api.list_repo_commits('meituan/Meeseeks') + >>> print(f"Total commits: {commit_history.total_count}") + >>> for commit in commit_history.commits: + ... print(f"{commit.short_id}: {commit.title}") + """ + from datasets.utils.file_utils import is_relative_path + + if is_relative_path(repo_id) and repo_id.count('/') == 1: + _owner, _dataset_name = repo_id.split('/') + else: + raise ValueError(f'Invalid repo_id: {repo_id} !') + + if not endpoint: + endpoint = self.endpoint + + commits_url = f'{endpoint}/api/v1/{repo_type}s/{repo_id}/commits' if repo_type else \ + f'{endpoint}/api/v1/models/{repo_id}/commits' + params = { + 'Ref': revision or DEFAULT_MODEL_REVISION or DEFAULT_REPOSITORY_REVISION, + 'PageNumber': page_number, + 'PageSize': page_size + } + cookies = ModelScopeConfig.get_cookies() + + try: + r = self.session.get(commits_url, params=params, + cookies=cookies, headers=self.builder_headers(self.headers)) + raise_for_http_status(r) + resp = r.json() + raise_on_error(resp) + + if resp.get('Code') == HTTPStatus.OK: + return CommitHistoryResponse.from_api_response(resp) + + except requests.exceptions.RequestException as e: + raise Exception(f'Failed to get repository commits for {repo_id}: {str(e)}') + def get_dataset_files(self, repo_id: str, *, @@ -1164,6 +1323,39 @@ class HubApi: return resp['Data']['Files'] + def get_dataset( + self, + dataset_id: str, + revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, + endpoint: Optional[str] = None + ): + """ + Get the dataset information. + + Args: + dataset_id (str): The dataset id. 
+ revision (Optional[str]): The revision of the dataset. + endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. + + Returns: + dict: The dataset information. + """ + cookies = ModelScopeConfig.get_cookies() + if not endpoint: + endpoint = self.endpoint + + if revision: + path = f'{endpoint}/api/v1/datasets/{dataset_id}?Revision={revision}' + else: + path = f'{endpoint}/api/v1/datasets/{dataset_id}' + + r = self.session.get( + path, cookies=cookies, headers=self.builder_headers(self.headers)) + raise_for_http_status(r) + resp = r.json() + datahub_raise_on_error(path, resp, r) + return resp[API_RESPONSE_FIELD_DATA] + def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, dataset_id: str, revision: str, endpoint: Optional[str] = None): """ Get the meta file-list of the dataset. """ diff --git a/modelscope/hub/info.py b/modelscope/hub/info.py new file mode 100644 index 00000000..0dd95e98 --- /dev/null +++ b/modelscope/hub/info.py @@ -0,0 +1,259 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Copyright 2022-present, the HuggingFace Inc. team. +# yapf: disable + +import datetime +from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional + +from modelscope.hub.utils.utils import convert_timestamp + + +@dataclass +class OrganizationInfo: + """Organization information for a repository.""" + id: Optional[int] + name: Optional[str] + full_name: Optional[str] + description: Optional[str] + avatar: Optional[str] + github_address: Optional[str] + type: Optional[int] + email: Optional[str] + created_time: Optional[datetime.datetime] + modified_time: Optional[datetime.datetime] + + def __init__(self, **kwargs): + self.id = kwargs.pop('Id', None) + self.name = kwargs.pop('Name', '') + self.full_name = kwargs.pop('FullName', '') + self.description = kwargs.pop('Description', '') + self.avatar = kwargs.pop('Avatar', '') + self.github_address = kwargs.pop('GithubAddress', '') + self.type = kwargs.pop('Type', kwargs.pop('type', None)) + self.email = kwargs.pop('Email', kwargs.pop('email', '')) + created_time = kwargs.pop('GmtCreated', kwargs.pop('created_time', None)) + self.created_time = convert_timestamp(created_time) if created_time else None + modified_time = kwargs.pop('GmtModified', kwargs.pop('modified_time', None)) + self.modified_time = convert_timestamp(modified_time) if modified_time else None + + +@dataclass +class ModelInfo: + """ + Contains detailed information about a model on ModelScope Hub. This object is returned by [`model_info`]. + + Attributes: + id (`int`, *optional*): Model ID. + name (`str`, *optional*): Model name. + author (`str`, *optional*): Model author. + chinese_name (`str`, *optional*): Chinese display name. + visibility (`int`, *optional*): Visibility level (1=private, 5=public). + is_published (`int`, *optional*): Whether the model is published. + is_online (`int`, *optional*): Whether the model is online. + already_star (`bool`, *optional*): Whether current user has starred this model. + description (`str`, *optional*): Model description. + license (`str`, *optional*): Model license. + downloads (`int`, *optional*): Number of downloads. + likes (`int`, *optional*): Number of likes. + created_at (`datetime`, *optional*): Date of creation of the repo on the Hub.. + last_updated_time (`datetime`, *optional*): Last update timestamp. + architectures (`List[str]`, *optional*): Model architectures. + model_type (`List[str]`, *optional*): Model types. 
+ tasks (`List[Dict[str, Any]]`, *optional*): Supported tasks. + readme_content (`str`, *optional*): README content. + organization (`OrganizationInfo`, *optional*): Organization information. + created_by (`str`, *optional*): Creator username. + is_certification (`int`, *optional*): Certification status. + approval_mode (`int`, *optional*): Approval mode. + card_ready (`int`, *optional*): Whether model card is ready. + backend_support (`str`, *optional*): Backend support information. + model_infos (`Dict[str, Any]`, *optional*): Detailed model configuration information. + tags (`List[str]`, *optional*): Model Tags. + is_accessible (`int`, *optional*): Whether accessible. + revision (`str`, *optional*): Revision/branch. + related_arxiv_id (`List[str]`, *optional*): Related arXiv paper IDs. + related_paper (`List[int]`, *optional*): Related papers. + sha (`str`, *optional*): Latest commit SHA. + last_modified (`datetime`, *optional*): Latest commit date. + last_commit (`Dict[str, Any]`, *optional*): Latest commit information. + """ + + id: Optional[int] + name: Optional[str] + author: Optional[str] + chinese_name: Optional[str] + visibility: Optional[int] + is_published: Optional[int] + is_online: Optional[int] + already_star: Optional[bool] + description: Optional[str] + license: Optional[str] + downloads: Optional[int] + likes: Optional[int] + created_at: Optional[datetime.datetime] + last_updated_time: Optional[datetime.datetime] + architectures: Optional[List[str]] + model_type: Optional[List[str]] + tasks: Optional[List[Dict[str, Any]]] + readme_content: Optional[str] + organization: Optional[OrganizationInfo] + created_by: Optional[str] + + # Certification and approval + is_certification: Optional[int] + approval_mode: Optional[int] + card_ready: Optional[int] + + # Model specific + backend_support: Optional[str] + model_infos: Optional[Dict[str, Any]] + + # Content and settings + tags: Optional[List[str]] + + # Additional flags + is_accessible: Optional[int] + + # Revision and version info + revision: Optional[str] + + # External references + related_arxiv_id: Optional[List[str]] + related_paper: Optional[List[int]] + + # latest commit infomation + last_commit: Optional[Dict[str, Any]] + sha: Optional[str] + last_modified: Optional[datetime.datetime] + + def __init__(self, **kwargs): + self.id = kwargs.pop('Id', None) + self.name = kwargs.pop('Name', '') + self.chinese_name = kwargs.pop('ChineseName', '') + self.visibility = kwargs.pop('Visibility', None) + self.is_published = kwargs.pop('IsPublished', None) + self.is_online = kwargs.pop('IsOnline', None) + self.already_star = kwargs.pop('AlreadyStar', None) + self.description = kwargs.pop('Description', '') + self.license = kwargs.pop('License', '') + self.downloads = kwargs.pop('Downloads', None) + self.likes = kwargs.pop('Stars', None) or kwargs.pop('Likes', None) + created_time = kwargs.pop('CreatedTime', None) + self.created_at = convert_timestamp(created_time) if created_time else None + last_updated_time = kwargs.pop('LastUpdatedTime', None) + self.last_updated_time = convert_timestamp(last_updated_time) if last_updated_time else None + self.architectures = kwargs.pop('Architectures', []) + self.model_type = kwargs.pop('ModelType', []) + self.tasks = kwargs.pop('Tasks', []) + self.readme_content = kwargs.pop('ReadMeContent', '') + org_data = kwargs.pop('Organization', None) + self.organization = OrganizationInfo(**org_data) if org_data else None + self.created_by = kwargs.pop('CreatedBy', None) + self.is_certification = 
kwargs.pop('IsCertification', None) + self.approval_mode = kwargs.pop('ApprovalMode', None) + self.card_ready = kwargs.pop('CardReady', None) + self.backend_support = kwargs.pop('BackendSupport', '{}') + self.model_infos = kwargs.pop('ModelInfos', {}) + self.tags = kwargs.pop('Tags', []) + self.is_accessible = kwargs.pop('IsAccessible', None) + self.revision = kwargs.pop('Revision', '') + self.related_arxiv_id = kwargs.pop('RelatedArxivId', []) + self.related_paper = kwargs.pop('RelatedPaper', []) + + commits = kwargs.pop('commits', None) + if commits and hasattr(commits, 'commits') and commits.commits: + last_commit = commits.commits[0] + self.last_commit = last_commit.to_dict() if hasattr(last_commit, 'to_dict') else None + self.sha = self.last_commit.get('id') if self.last_commit else None + self.last_modified = convert_timestamp(self.last_commit.get('committed_date')) if self.last_commit else None + else: + self.last_commit = None + self.sha = None + self.last_modified = None + self.author = kwargs.pop('author', '') + + # backward compatibility + self.__dict__.update(kwargs) + + +@dataclass +class DatasetInfo: + """ + Contains detailed information about a dataset on ModelScope Hub. This object is returned by [`dataset_info`]. + + Attributes: + id (`int`, *optional*): Dataset ID. + name (`str`, *optional*): Dataset name. + author (`str`, *optional*): Dataset owner (user or organization). + chinese_name (`str`, *optional*): Chinese display name. + visibility (`int`, *optional*): Visibility level (1=private, 3=internal, 5=public). + 'internal' means visible to logged-in users only. + already_star (`bool`, *optional*): Whether current user has starred this dataset. + description (`str`, *optional*): Dataset description. + license (`str`, *optional*): Dataset license. + downloads (`int`, *optional*): Number of downloads. + likes (`int`, *optional*): Number of likes. + created_at (`datetime`, *optional*): Creation date. + last_updated_time (`datetime`, *optional*): Last update timestamp. + readme_content (`str`, *optional*): README content. + organization (`OrganizationInfo`, *optional*): Organization information. + created_by (`str`, *optional*): Creator username. + tags (`List[Dict[str, Any]]`, *optional*): Dataset tags. + last_commit (`Dict[str, Any]`, *optional*): Latest commit information. + sha (`str`, *optional*): Latest commit SHA. + last_modified (`datetime`, *optional*): Latest commit date.
+ """ + + id: Optional[int] + name: Optional[str] + author: Optional[str] + chinese_name: Optional[str] + visibility: Optional[Literal[1, 3, 5]] + already_star: Optional[bool] + description: Optional[str] + license: Optional[str] + downloads: Optional[int] + likes: Optional[int] + created_at: Optional[datetime.datetime] + last_updated_time: Optional[datetime.datetime] + readme_content: Optional[str] + organization: Optional[OrganizationInfo] + created_by: Optional[str] + tags: Optional[List[Dict[str, Any]]] + last_commit: Optional[Dict[str, Any]] + sha: Optional[str] + last_modified: Optional[datetime.datetime] + + def __init__(self, **kwargs): + self.id = kwargs.pop('Id', None) + self.name = kwargs.pop('Name', '') + self.author = kwargs.pop('author', kwargs.pop('Owner', None) or kwargs.pop('Namespace', None)) + self.chinese_name = kwargs.pop('ChineseName', '') + self.visibility = kwargs.pop('Visibility', None) + self.already_star = kwargs.pop('AlreadyStar', None) + self.description = kwargs.pop('Description', '') + self.likes = kwargs.pop('Likes', None) or kwargs.pop('Stars', None) + self.license = kwargs.pop('License', '') + self.downloads = kwargs.pop('Downloads', None) + created_time = kwargs.pop('GmtCreate', None) + self.created_at = convert_timestamp(created_time) if created_time else None + last_updated_time = kwargs.pop('LastUpdatedTime', None) + self.last_updated_time = convert_timestamp(last_updated_time) if last_updated_time else None + self.readme_content = kwargs.pop('ReadMeContent', '') + org_data = kwargs.pop('Organization', None) + self.organization = OrganizationInfo(**org_data) if org_data else None + self.created_by = kwargs.pop('CreatedBy', None) + self.tags = kwargs.pop('Tags', []) + commits = kwargs.pop('commits', None) + + if commits and hasattr(commits, 'commits') and commits.commits: + last_commit = commits.commits[0] + self.last_commit = last_commit.to_dict() if hasattr(last_commit, 'to_dict') else None + self.sha = self.last_commit.get('id') if self.last_commit else None + self.last_modified = convert_timestamp(self.last_commit.get('committed_date')) if self.last_commit else None + else: + self.last_commit = None + self.sha = None + self.last_modified = None + + # backward compatibility + self.__dict__.update(kwargs) diff --git a/modelscope/hub/utils/utils.py b/modelscope/hub/utils/utils.py index 28bcdbf2..10658b0e 100644 --- a/modelscope/hub/utils/utils.py +++ b/modelscope/hub/utils/utils.py @@ -5,6 +5,7 @@ import hashlib import os import sys import time +import zoneinfo from datetime import datetime from pathlib import Path from typing import Generator, List, Optional, Union @@ -299,3 +300,83 @@ def weak_file_lock(lock_file: Union[str, Path], Path(lock_file).unlink() except OSError: pass + + +def convert_timestamp(time_stamp: Union[int, str, datetime], + time_zone: str = 'Asia/Shanghai') -> Optional[datetime]: + """Convert a UNIX/string timestamp to a timezone-aware datetime object. + + Args: + time_stamp: UNIX timestamp (int), ISO string, or datetime object + time_zone: Target timezone for non-UTC timestamps (default: 'Asia/Shanghai') + + Returns: + Timezone-aware datetime object or None if input is None + """ + if not time_stamp: + return None + + # Handle datetime objects first + if isinstance(time_stamp, datetime): + return time_stamp + + if isinstance(time_stamp, str): + try: + if time_stamp.endswith('Z'): + # Normalize fractional seconds to 6 digits + if '.' 
not in time_stamp: + # No fractional seconds (e.g., "2024-11-16T00:27:02Z") + time_stamp = time_stamp[:-1] + '.000000Z' + else: + # Has fractional seconds (e.g., "2022-08-19T07:19:38.123456789Z") + base, fraction = time_stamp[:-1].split('.') + # Truncate or pad to 6 digits + fraction = fraction[:6].ljust(6, '0') + time_stamp = f'{base}.{fraction}Z' + + dt = datetime.strptime(time_stamp, + '%Y-%m-%dT%H:%M:%S.%fZ').replace( + tzinfo=zoneinfo.ZoneInfo('UTC')) + if time_zone != 'UTC': + dt = dt.astimezone(zoneinfo.ZoneInfo(time_zone)) + return dt + else: + # Try parsing common ISO formats + formats = [ + '%Y-%m-%dT%H:%M:%S.%f', # With microseconds + '%Y-%m-%dT%H:%M:%S', # Without microseconds + '%Y-%m-%d %H:%M:%S.%f', # Space separator with microseconds + '%Y-%m-%d %H:%M:%S', # Space separator without microseconds + ] + for fmt in formats: + try: + return datetime.strptime( + time_stamp, + fmt).replace(tzinfo=zoneinfo.ZoneInfo(time_zone)) + except ValueError: + continue + + raise ValueError( + f"Unsupported timestamp format: '{time_stamp}'") + + except ValueError as e: + raise ValueError( + f"Cannot parse '{time_stamp}' as a datetime. Expected formats: " + f"'YYYY-MM-DDTHH:MM:SS[.ffffff]Z' (UTC) or 'YYYY-MM-DDTHH:MM:SS[.ffffff]' (local)" + ) from e + + elif isinstance(time_stamp, int): + try: + # UNIX timestamps are always in UTC, then convert to target timezone + return datetime.fromtimestamp( + time_stamp, tz=zoneinfo.ZoneInfo('UTC')).astimezone( + zoneinfo.ZoneInfo(time_zone)) + except (ValueError, OSError) as e: + raise ValueError( + f"Cannot convert '{time_stamp}' to datetime. Ensure it's a valid UNIX timestamp." + ) from e + + else: + raise TypeError( + f"Unsupported type '{type(time_stamp)}'. Expected int, str, or datetime." + ) diff --git a/modelscope/utils/repo_utils.py b/modelscope/utils/repo_utils.py index 446f3857..842cdf60 100644 --- a/modelscope/utils/repo_utils.py +++ b/modelscope/utils/repo_utils.py @@ -8,12 +8,14 @@ import os import sys from contextlib import contextmanager from dataclasses import dataclass, field +from datetime import datetime from fnmatch import fnmatch from pathlib import Path from typing import (Any, BinaryIO, Callable, Generator, Iterable, Iterator, List, Literal, Optional, TypeVar, Union) from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT +from modelscope.hub.utils.utils import convert_timestamp from modelscope.utils.file_utils import get_file_hash T = TypeVar('T') @@ -257,6 +259,76 @@ class CommitInfo: } +@dataclass +class DetailedCommitInfo: + """Detailed commit information from repository history API.""" + id: Optional[str] + short_id: Optional[str] + title: Optional[str] + message: Optional[str] + author_name: Optional[str] + authored_date: Optional[datetime] + author_email: Optional[str] + committed_date: Optional[datetime] + committer_name: Optional[str] + committer_email: Optional[str] + created_at: Optional[datetime] + + @classmethod + def from_api_response(cls, data: dict) -> 'DetailedCommitInfo': + """Create DetailedCommitInfo from API response data.""" + return cls( + id=data.get('Id', ''), + short_id=data.get('ShortId', ''), + title=data.get('Title', ''), + message=data.get('Message', ''), + author_name=data.get('AuthorName', ''), + authored_date=convert_timestamp(data.get('AuthoredDate', None)), + author_email=data.get('AuthorEmail', ''), + committed_date=convert_timestamp(data.get('CommittedDate', None)), + committer_name=data.get('CommitterName', ''), + committer_email=data.get('CommitterEmail', ''), + 
created_at=convert_timestamp(data.get('CreatedAt', None)), + ) + + def to_dict(self) -> dict: + """Convert to dictionary.""" + return { + 'id': self.id, + 'short_id': self.short_id, + 'title': self.title, + 'message': self.message, + 'author_name': self.author_name, + 'authored_date': self.authored_date, + 'author_email': self.author_email, + 'committed_date': self.committed_date, + 'committer_name': self.committer_name, + 'committer_email': self.committer_email, + 'created_at': self.created_at, + } + + +@dataclass +class CommitHistoryResponse: + """Response from commit history API.""" + commits: Optional[List[DetailedCommitInfo]] + total_count: Optional[int] + + @classmethod + def from_api_response(cls, data: dict) -> 'CommitHistoryResponse': + """Create CommitHistoryResponse from API response data.""" + commits_data = data.get('Data', {}).get('Commit', []) + commits = [ + DetailedCommitInfo.from_api_response(commit) + for commit in commits_data + ] + + return cls( + commits=commits, + total_count=data.get('TotalCount', 0), + ) + + @dataclass class RepoUrl: diff --git a/tests/hub/test_hub_repo_info.py b/tests/hub/test_hub_repo_info.py new file mode 100644 index 00000000..0b021931 --- /dev/null +++ b/tests/hub/test_hub_repo_info.py @@ -0,0 +1,474 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# yapf: disable + +import datetime +import unittest +import zoneinfo +from unittest.mock import Mock, patch + +from modelscope.hub.api import DatasetInfo, HubApi, ModelInfo +from modelscope.utils.constant import REPO_TYPE_DATASET, REPO_TYPE_MODEL +from modelscope.utils.repo_utils import DetailedCommitInfo +from modelscope.utils.test_utils import test_level + + +class HubRepoInfoTest(unittest.TestCase): + + def setUp(self): + self.api = HubApi() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @patch.object(HubApi, 'get_model') + @patch.object(HubApi, 'list_repo_commits') + def test_model_info(self, mock_list_repo_commits, mock_get_model): + # Setup mock responses + mock_get_model.return_value = { + 'Id': 123, + 'Name': 'demo-model', + 'ChineseName': '测试模型', + 'Description': 'A test model', + 'Tasks': [{ + 'Name': 'text-classification', + 'Description': 'A test task' + }], + 'Tags': ['nlp', 'text'] + } + + # Mock commit history response + commit = DetailedCommitInfo( + id='abc123', + short_id='abc12', + title='Initial commit', + message='Initial commit', + author_name='Test User', + authored_date=None, + author_email='test@example.com', + committed_date=None, + committer_name='Test User', + committer_email='test@example.com', + created_at=None) + commits_response = Mock() + commits_response.commits = [commit] + mock_list_repo_commits.return_value = commits_response + + # Call the method + info = self.api.model_info( + repo_id='demo/model', revision='master', endpoint=None) + + # Verify results + self.assertEqual(info.id, 123) + self.assertEqual(info.name, 'demo-model') + self.assertEqual(info.author, 'demo') + self.assertEqual(info.chinese_name, '测试模型') + self.assertEqual(info.description, 'A test model') + self.assertEqual(info.tasks, [{ + 'Name': 'text-classification', + 'Description': 'A test task' + }]) + self.assertEqual(info.tags, ['nlp', 'text']) + self.assertEqual(info.sha, 'abc123') + self.assertEqual(info.last_commit, commit.to_dict()) + + # Verify correct method calls + mock_get_model.assert_called_once_with( + model_id='demo/model', revision='master', endpoint=None) + mock_list_repo_commits.assert_called_once_with( + repo_id='demo/model', + 
repo_type=REPO_TYPE_MODEL, + revision='master', + endpoint=None) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @patch.object(HubApi, 'get_dataset') + @patch.object(HubApi, 'list_repo_commits') + def test_dataset_info(self, mock_list_repo_commits, mock_get_dataset): + # Setup mock responses + mock_get_dataset.return_value = { + 'Id': 456, + 'Name': 'demo-dataset', + 'ChineseName': '演示数据集', + 'Description': 'A test dataset', + 'Tags': [ + { + 'Name': 'nlp', + 'Color': 'blue' + }, + { + 'Name': 'text', + 'Color': 'green' + } + ] + } + + # Mock commit history response + commits = [ + DetailedCommitInfo( + id='c1', + short_id='c1', + title='Update data', + message='Update data', + author_name='Test User', + authored_date=None, + author_email='test@example.com', + committed_date=None, + committer_name='Test User', + committer_email='test@example.com', + created_at=None), + DetailedCommitInfo( + id='c2', + short_id='c2', + title='Initial commit', + message='Initial commit', + author_name='Test User', + authored_date=None, + author_email='test@example.com', + committed_date=1756284063, + committer_name='Test User', + committer_email='test@example.com', + created_at=None) + ] + commits_response = Mock() + commits_response.commits = commits + mock_list_repo_commits.return_value = commits_response + + # Call the method + info = self.api.dataset_info('demo/dataset') + + # Verify results + self.assertEqual(info.id, 456) + self.assertEqual(info.name, 'demo-dataset') + self.assertEqual(info.author, 'demo') + self.assertEqual(info.chinese_name, '演示数据集') + self.assertEqual(info.description, 'A test dataset') + self.assertEqual(info.tags, [{ + 'Name': 'nlp', + 'Color': 'blue' + }, { + 'Name': 'text', + 'Color': 'green' + }]) + self.assertEqual(info.sha, 'c1') + self.assertEqual(info.last_commit, commits[0].to_dict()) + + # Verify correct method calls + mock_get_dataset.assert_called_once_with( + dataset_id='demo/dataset', revision=None, endpoint=None) + mock_list_repo_commits.assert_called_once_with( + repo_id='demo/dataset', + repo_type=REPO_TYPE_DATASET, + revision=None, + endpoint=None) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @patch.object(HubApi, 'model_info') + @patch.object(HubApi, 'dataset_info') + def test_repo_info_model(self, mock_dataset_info, mock_model_info): + # Setup mock response + model_info = ModelInfo( + id=123, name='demo-model', description='A test model') + mock_model_info.return_value = model_info + + # Call the method with model type + info = self.api.repo_info( + repo_id='demo/model', revision='master', endpoint=None) + + # Verify results + self.assertEqual(info, model_info) + mock_model_info.assert_called_once_with( + repo_id='demo/model', revision='master', endpoint=None) + mock_dataset_info.assert_not_called() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + @patch.object(HubApi, 'model_info') + @patch.object(HubApi, 'dataset_info') + def test_repo_info_dataset(self, mock_dataset_info, mock_model_info): + # Setup mock response + dataset_info = DatasetInfo( + id=456, name='demo-dataset', description='A test dataset') + mock_dataset_info.return_value = dataset_info + + # Call the method with dataset type + info = self.api.repo_info( + repo_id='demo/dataset', + repo_type=REPO_TYPE_DATASET, + revision='master', + endpoint=None) + + # Verify results + self.assertEqual(info, dataset_info) + mock_dataset_info.assert_called_once_with( + repo_id='demo/dataset', revision='master', 
endpoint=None) + mock_model_info.assert_not_called() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_model_info_class_comprehensive(self): + """Test ModelInfo class initialization and properties.""" + model_data = { + 'Id': 123, + 'Name': 'demo-model', + 'author': 'demo', + 'ChineseName': '演示模型', + 'Description': 'A test model', + 'Tasks': [{ + 'Name': 'text-classification', + 'Description': 'A test task' + }], + 'Tags': ['nlp', 'text'], + 'CreatedTime': '2023-01-01T00:00:00Z', + 'Visibility': 5, + 'IsPublished': 1, + 'IsOnline': 1, + 'License': 'Apache-2.0', + 'Downloads': 100, + 'Stars': 50, + 'Architectures': ['transformer'], + 'ModelType': ['nlp'] + } + + # Create mock commits + commit = DetailedCommitInfo( + id='abc123', + short_id='abc12', + title='Initial commit', + message='Initial commit', + author_name='Test User', + authored_date=None, + author_email='test@example.com', + committed_date=None, + committer_name='Test User', + committer_email='test@example.com', + created_at=None) + commits = Mock() + commits.commits = [commit] + + # Create ModelInfo instance + model_info = ModelInfo(**model_data, commits=commits) + + # Verify properties + self.assertEqual(model_info.id, 123) + self.assertEqual(model_info.name, 'demo-model') + self.assertEqual(model_info.author, 'demo') + self.assertEqual(model_info.chinese_name, '演示模型') + self.assertEqual(model_info.description, 'A test model') + self.assertEqual(model_info.tasks, [{ + 'Name': 'text-classification', + 'Description': 'A test task' + }]) + self.assertEqual(model_info.tags, ['nlp', 'text']) + self.assertEqual( + model_info.created_at, + datetime.datetime( + 2023, 1, 1, 8, 0, 0).replace(tzinfo=zoneinfo.ZoneInfo('Asia/Shanghai'))) + self.assertEqual(model_info.sha, 'abc123') + self.assertEqual(model_info.last_commit, commit.to_dict()) + self.assertEqual(model_info.visibility, 5) + self.assertEqual(model_info.is_published, 1) + self.assertEqual(model_info.is_online, 1) + self.assertEqual(model_info.license, 'Apache-2.0') + self.assertEqual(model_info.downloads, 100) + self.assertEqual(model_info.likes, 50) + self.assertEqual(model_info.architectures, ['transformer']) + self.assertEqual(model_info.model_type, ['nlp']) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_dataset_info_class(self): + # Test DatasetInfo class initialization and properties + dataset_data = { + 'Id': 456, + 'Name': 'demo-dataset', + 'Owner': 'demo', + 'ChineseName': '演示数据集', + 'Description': 'A test dataset', + 'Tags': [ + { + 'Name': 'nlp', + 'Color': 'blue' + }, + { + 'Name': 'text', + 'Color': 'green' + } + ], + 'GmtCreate': 1755752511, + 'Visibility': 5, + 'License': 'MIT', + 'Downloads': 200, + 'Likes': 75 + } + + # Create mock commits + commit = DetailedCommitInfo( + id='c1', + short_id='c1', + title='Initial commit', + message='Initial commit', + author_name='Test User', + authored_date=None, + author_email='test@example.com', + committed_date='2024-09-18T06:20:05Z', + committer_name='Test User', + committer_email='test@example.com', + created_at=None) + commits = Mock() + commits.commits = [commit] + + # Create DatasetInfo instance + dataset_info = DatasetInfo(**dataset_data, commits=commits) + + # Verify properties + self.assertEqual(dataset_info.id, 456) + self.assertEqual(dataset_info.name, 'demo-dataset') + self.assertEqual(dataset_info.author, 'demo') + self.assertEqual(dataset_info.chinese_name, '演示数据集') + self.assertEqual(dataset_info.description, 'A test dataset') + 
self.assertEqual( + dataset_info.tags, + [ + { + 'Name': 'nlp', + 'Color': 'blue' + }, + { + 'Name': 'text', + 'Color': 'green' + } + ] + ) + self.assertEqual( + dataset_info.created_at, + datetime.datetime( + 2025, 8, 21, 13, 1, 51).replace(tzinfo=zoneinfo.ZoneInfo('Asia/Shanghai'))) + self.assertEqual(dataset_info.sha, 'c1') + self.assertEqual( + dataset_info.last_modified, + datetime.datetime( + 2024, 9, 18, 14, 20, 5).replace(tzinfo=zoneinfo.ZoneInfo('Asia/Shanghai'))) + self.assertEqual(dataset_info.last_commit, commit.to_dict()) + self.assertEqual(dataset_info.visibility, 5) + self.assertEqual(dataset_info.license, 'MIT') + self.assertEqual(dataset_info.downloads, 200) + self.assertEqual(dataset_info.likes, 75) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_model_info_empty_commits(self): + """Test ModelInfo with empty commits.""" + model_data = {'Id': 123, 'Name': 'demo-model', 'author': 'demo'} + + # Create ModelInfo with no commits + model_info = ModelInfo(**model_data, commits=None) + + # Verify commit-related fields are None + self.assertIsNone(model_info.sha) + self.assertIsNone(model_info.last_commit) + self.assertIsNone(model_info.last_modified) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_dataset_info_empty_commits(self): + """Test DatasetInfo with empty commits.""" + dataset_data = {'Id': 456, 'Name': 'demo-dataset', 'Owner': 'demo'} + + # Create DatasetInfo with no commits + dataset_info = DatasetInfo(**dataset_data, commits=None) + + # Verify commit-related fields are None + self.assertIsNone(dataset_info.sha) + self.assertIsNone(dataset_info.last_commit) + self.assertIsNone(dataset_info.last_modified) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_detailed_commit_info_to_dict(self): + """Test DetailedCommitInfo to_dict method.""" + commit = DetailedCommitInfo( + id='abc123', + short_id='abc12', + title='Test commit', + message='Test commit message', + author_name='Test Author', + authored_date=datetime.datetime(2023, 1, 1, 0, 0, 0), + author_email='test@example.com', + committed_date=datetime.datetime(2023, 1, 1, 0, 0, 0), + committer_name='Test Committer', + committer_email='committer@example.com', + created_at=datetime.datetime(2023, 1, 1, 0, 0, 0)) + + result = commit.to_dict() + + expected = { + 'id': 'abc123', + 'short_id': 'abc12', + 'title': 'Test commit', + 'message': 'Test commit message', + 'author_name': 'Test Author', + 'authored_date': datetime.datetime(2023, 1, 1, 0, 0, 0), + 'author_email': 'test@example.com', + 'committed_date': datetime.datetime(2023, 1, 1, 0, 0, 0), + 'committer_name': 'Test Committer', + 'committer_email': 'committer@example.com', + 'created_at': datetime.datetime(2023, 1, 1, 0, 0, 0) + } + + self.assertEqual(result, expected) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_real_model_repo_info(self): + """Test getting real model repository information without mocks.""" + # Use a real model repository + model_repo_id = 'black-forest-labs/FLUX.1-Krea-dev' + + # Get repository information + info = self.api.repo_info( + repo_id=model_repo_id, repo_type=REPO_TYPE_MODEL) + + # Basic validation + self.assertIsNotNone(info) + self.assertEqual(info.author, 'black-forest-labs') + self.assertEqual(info.name, 'FLUX.1-Krea-dev') + + # Check commit information + self.assertIsNotNone(info.sha) + if hasattr(info, 'last_commit') and info.last_commit: + self.assertIn('id', 
info.last_commit) + self.assertIn('title', info.last_commit) + + # Print some information for debugging + print(f'\nModel Info for {model_repo_id}:') + print(f'ID: {info.id}') + print(f'Name: {info.name}') + print(f'Author: {info.author}') + print(f'SHA: {info.sha}') + if hasattr(info, 'last_modified'): + print(f'Last Modified: {info.last_modified}') + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_real_dataset_repo_info(self): + """Test getting real dataset repository information without mocks.""" + # Use a real dataset repository + dataset_repo_id = 'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT' + + # Get repository information + info = self.api.repo_info( + repo_id=dataset_repo_id, repo_type=REPO_TYPE_DATASET) + + # Basic validation + self.assertIsNotNone(info) + self.assertEqual(info.author, 'swift') + self.assertTrue('Chinese-Qwen3' in info.name) + + # Check commit information + self.assertIsNotNone(info.sha) + if hasattr(info, 'last_commit') and info.last_commit: + self.assertIn('id', info.last_commit) + self.assertIn('title', info.last_commit) + + # Print some information for debugging + print(f'\nDataset Info for {dataset_repo_id}:') + print(f'ID: {info.id}') + print(f'Name: {info.name}') + print(f'Author: {info.author}') + print(f'SHA: {info.sha}') + if hasattr(info, 'last_modified'): + print(f'Last Modified: {info.last_modified}') + + +if __name__ == '__main__': + unittest.main()
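
A minimal usage sketch (not part of the diff) of the APIs added above, assuming the patch is applied and the ModelScope Hub is reachable; the repository ids are the public repos already referenced in the integration tests.

from modelscope.hub.api import HubApi
from modelscope.utils.constant import REPO_TYPE_DATASET

api = HubApi()

# repo_info() defaults to repo_type='model' and returns a ModelInfo built from
# get_model() plus the latest commit metadata from list_repo_commits().
model = api.repo_info('black-forest-labs/FLUX.1-Krea-dev')
print(model.name, model.author, model.sha, model.last_modified)

# With repo_type='dataset' the call is routed to dataset_info() and returns a DatasetInfo.
dataset = api.repo_info(
    'swift/Chinese-Qwen3-235B-2507-Distill-data-110k-SFT',
    repo_type=REPO_TYPE_DATASET)
print(dataset.name, dataset.likes, dataset.created_at)

# Raw, server-side paginated commit history as a CommitHistoryResponse.
history = api.list_repo_commits(
    'black-forest-labs/FLUX.1-Krea-dev', page_size=10)
for commit in history.commits:
    print(commit.short_id, commit.title, commit.committed_date)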