def list_datasets(self,
                  owner_or_group: str,
                  *,
                  page_number: Optional[int] = 1,
                  page_size: Optional[int] = 10,
                  sort: Optional[str] = None,
                  search: Optional[str] = None,
                  endpoint: Optional[str] = None,
                  ) -> dict:
    """List datasets via OpenAPI with pagination, filtering and sorting.

    Args:
        owner_or_group (str): Search by dataset authors (including organizations and individuals).
            A falsy value omits the ``author`` filter from the request.
        page_number (int, optional): The page number. Defaults to 1.
        page_size (int, optional): The page size. Defaults to 10.
        sort (str, optional): Sort key. If not provided, the server's default sorting is used.
            Choose from ['default', 'downloads', 'likes', 'last_modified'].
        search (str, optional): Search by substring keywords in the dataset's Chinese name,
            English name, and authors (including organizations and individuals).
        endpoint (str, optional): Hub endpoint to use. When None, use the endpoint specified in the class.

    Returns:
        dict: The OpenAPI data payload, e.g.
            {
                "datasets": [...],
                "total_count": int,
                "page_number": int,
                "page_size": int
            }

    Raises:
        InvalidParameter: If ``sort`` is given but is not one of ``VALID_SORT_KEYS``.
        RequestError: If the response body is not valid JSON, is not a JSON object,
            or does not follow the success schema (``success`` is True and ``data`` present).

    Any error raised by ``raise_for_http_status`` for the HTTP response is propagated unchanged.
    """
    # Reject a bad sort key up front, before any cookie lookup or network traffic.
    if sort and sort not in VALID_SORT_KEYS:
        # sorted() keeps the message deterministic; VALID_SORT_KEYS is an unordered set.
        raise InvalidParameter(
            f'Invalid sort key: {sort}. Supported sort keys: {sorted(VALID_SORT_KEYS)}')

    if not endpoint:
        endpoint = self.endpoint
    path = f'{endpoint}/openapi/v1/datasets'

    # Build query params; optional filters are only sent when they carry a value.
    params: Dict[str, Any] = {
        'page_number': page_number,
        'page_size': page_size,
    }
    if sort:
        params['sort'] = sort
    if search:
        params['search'] = search
    if owner_or_group:
        params['author'] = owner_or_group

    cookies = ModelScopeConfig.get_cookies()
    headers = self.builder_headers(self.headers)

    r = self.session.get(
        path,
        params=params,
        cookies=cookies,
        headers=headers
    )
    raise_for_http_status(r)

    # A 2xx response can still carry a non-JSON body (e.g. a proxy page); report it
    # as a RequestError instead of leaking a bare decoding ValueError to the caller.
    try:
        resp = r.json()
    except ValueError as e:
        raise RequestError(
            'Failed to list datasets: response body is not valid JSON') from e

    # OpenAPI success schema
    if isinstance(resp, dict) and resp.get('success') is True and 'data' in resp:
        return resp['data']

    # Fallback for unexpected schema (including JSON that is not an object).
    msg = resp.get('message') if isinstance(resp, dict) else None
    raise RequestError(msg or 'Failed to list datasets')
class SortKey:
    """String constants naming the sort orders accepted when listing hub resources."""

    # Explicit name for "no particular ordering requested".
    DEFAULT = 'default'
    # Order by download count.
    DOWNLOADS = 'downloads'
    # Order by like count.
    LIKES = 'likes'
    # Order by last modification time.
    LAST_MODIFIED = 'last_modified'


# Membership set used to validate a caller-supplied sort key.
VALID_SORT_KEYS = {
    SortKey.LAST_MODIFIED,
    SortKey.LIKES,
    SortKey.DOWNLOADS,
    SortKey.DEFAULT,
}
import unittest

from modelscope import HubApi
from modelscope.hub.errors import InvalidParameter
from modelscope.utils.logger import get_logger

logger = get_logger()

default_owner = 'modelscope'


class HubListHubTest(unittest.TestCase):
    """Smoke tests for the HubApi listing endpoints.

    Apart from the invalid-sort test, these call the live hub, so they need
    network access to the configured endpoint.
    """

    def setUp(self):
        self.api = HubApi()

    def test_list_datasets(self):
        # Use default args
        result = self.api.list_datasets(owner_or_group=default_owner)
        logger.info(f'List datasets result: {result}')
        # list_datasets documents a dict payload carrying a 'datasets' entry.
        self.assertIsInstance(result, dict)
        self.assertIn('datasets', result)

    def test_list_datasets_with_args(self):
        page_size = 2
        result = self.api.list_datasets(
            owner_or_group=default_owner,
            page_number=1,
            page_size=page_size,
            sort='downloads',
            search='chinese',
        )
        logger.info(f'List datasets with full result: {result}')
        self.assertIsInstance(result, dict)
        self.assertIn('datasets', result)
        # A page must never hold more entries than were requested; tolerate an
        # empty or null list when the search matches nothing.
        self.assertLessEqual(len(result.get('datasets') or []), page_size)

    def test_list_datasets_invalid_sort(self):
        # The sort key is validated before any request is sent, so this
        # check does not depend on the network.
        with self.assertRaises(InvalidParameter):
            self.api.list_datasets(
                owner_or_group=default_owner, sort='not_a_sort_key')

    def test_list_models(self):
        result = self.api.list_models(
            owner_or_group='Qwen', page_number=1, page_size=2)
        logger.info(f'List models result: {result}')
        # NOTE(review): the list_models return schema is not visible from this
        # change; add payload assertions once it is confirmed.