Update list datasets to OpenAPI (#1532)

* update list datasets

* Update modelscope/hub/constants.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update modelscope/hub/api.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update modelscope/hub/api.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Update modelscope/hub/api.py

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* fix lint

* update workflow

* fix lint

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Yunlin Mao
2025-11-05 17:04:29 +08:00
committed by GitHub
parent c1fc7bf6c2
commit 723599ac48
5 changed files with 126 additions and 16 deletions

View File

@@ -54,7 +54,7 @@ jobs:
- name: Checkout
uses: actions/checkout@v3
with:
lfs: 'true'
lfs: 'false'
submodules: 'true'
fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
- name: Get changed files
@@ -65,8 +65,9 @@ jobs:
else
echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
fi
- name: Checkout LFS objects
run: git lfs checkout
- name: Fetch LFS objects
run: |
git lfs pull
- name: Run unittest
shell: bash
run: |

View File

@@ -54,8 +54,8 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
UPLOAD_MAX_FILE_SIZE,
UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
DatasetVisibility, Licenses,
ModelVisibility, Visibility,
VALID_SORT_KEYS, DatasetVisibility,
Licenses, ModelVisibility, Visibility,
VisibilityMap)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
NotLoginException, RequestError,
@@ -913,6 +913,75 @@ class HubApi:
raise_for_http_status(r)
return None
def list_datasets(self,
                  owner_or_group: str,
                  *,
                  page_number: Optional[int] = 1,
                  page_size: Optional[int] = 10,
                  sort: Optional[str] = None,
                  search: Optional[str] = None,
                  endpoint: Optional[str] = None,
                  ) -> dict:
    """List datasets via OpenAPI with pagination, filtering and sorting.

    Args:
        owner_or_group (str): Search by dataset authors (including organizations and individuals).
            Sent as the `author` query parameter; omitted when empty.
        page_number (int, optional): The page number. Defaults to 1.
        page_size (int, optional): The page size. Defaults to 10.
        sort (str, optional): Sort key. If not provided, the server's default sorting is used.
            choose from ['default', 'downloads', 'likes', 'last_modified'].
        search (str, optional): Search by substring keywords in the dataset's Chinese name,
            English name, and authors (including organizations and individuals).
        endpoint (str, optional): Hub endpoint to use. When None, use the endpoint specified in the class.

    Returns:
        dict: The OpenAPI data payload, e.g.
            {
                "datasets": [...],
                "total_count": int,
                "page_number": int,
                "page_size": int
            }

    Raises:
        InvalidParameter: If `sort` is given but is not one of VALID_SORT_KEYS.
        RequestError: If the response body is not a success envelope
            (`success` is True and a `data` field is present).
    """
    if not endpoint:
        endpoint = self.endpoint
    path = f'{endpoint}/openapi/v1/datasets'

    # Validate before doing any I/O. The supported keys are sorted so the
    # error message is stable across runs (set iteration order is not).
    if sort and sort not in VALID_SORT_KEYS:
        raise InvalidParameter(
            f'Invalid sort key: {sort}. Supported sort keys: {sorted(VALID_SORT_KEYS)}')

    # Build query params; optional filters are only sent when provided.
    params: Dict[str, Any] = {
        'page_number': page_number,
        'page_size': page_size,
    }
    if sort:
        params['sort'] = sort
    if search:
        params['search'] = search
    if owner_or_group:
        params['author'] = owner_or_group

    cookies = ModelScopeConfig.get_cookies()
    headers = self.builder_headers(self.headers)
    r = self.session.get(
        path,
        params=params,
        cookies=cookies,
        headers=headers
    )
    raise_for_http_status(r)
    resp = r.json()

    # A JSON body that is not an object (list, string, null) has no `.get`;
    # treat it as an unexpected schema instead of leaking AttributeError.
    is_envelope = isinstance(resp, dict)

    # OpenAPI success schema
    if is_envelope and resp.get('success') is True and 'data' in resp:
        return resp['data']

    # Fallback for unexpected schema
    msg = (resp.get('message') if is_envelope else None) or 'Failed to list datasets'
    raise RequestError(msg)
def _check_cookie(self, use_cookies: Union[bool, CookieJar] = False) -> CookieJar: # noqa
cookies = None
if isinstance(use_cookies, CookieJar):
@@ -1239,17 +1308,6 @@ class HubApi:
logger.info(f'Create dataset success: {dataset_repo_url}')
return dataset_repo_url
def list_datasets(self, endpoint: Optional[str] = None):
    """Return the names of the datasets listed by `{endpoint}/api/v1/datasets`.

    Args:
        endpoint (str, optional): Hub endpoint to query. Falls back to
            `self.endpoint` when not given.

    Returns:
        list: The `Name` value of every entry found under the
            `API_RESPONSE_FIELD_DATA` field of the JSON response.
    """
    base_url = endpoint or self.endpoint
    response = self.session.get(
        f'{base_url}/api/v1/datasets',
        params={},
        headers=self.builder_headers(self.headers))
    raise_for_http_status(response)
    entries = response.json()[API_RESPONSE_FIELD_DATA]
    return [entry['Name'] for entry in entries]
def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None):
cookies = ModelScopeConfig.get_cookies()
@@ -1803,6 +1861,7 @@ class HubApi:
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
download_uv_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
f'{channel}?user={user_name}'
download_uv_resp = self.session.post(download_uv_url, cookies=cookies,
headers=self.builder_headers(self.headers))
download_uv_resp = download_uv_resp.json()

View File

@@ -116,3 +116,18 @@ VisibilityMap = {
ModelVisibility.INTERNAL: Visibility.INTERNAL,
ModelVisibility.PUBLIC: Visibility.PUBLIC
}
class SortKey(object):
    """String constants naming the supported sort keys."""
    DEFAULT = 'default'
    DOWNLOADS = 'downloads'
    LIKES = 'likes'
    LAST_MODIFIED = 'last_modified'


# Whitelist of every key defined on SortKey. Frozen so that importers cannot
# mutate the shared set by accident; membership tests, iteration and equality
# with a plain set behave exactly as before.
VALID_SORT_KEYS = frozenset({
    SortKey.DEFAULT,
    SortKey.DOWNLOADS,
    SortKey.LIKES,
    SortKey.LAST_MODIFIED,
})

View File

View File

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope import HubApi
from modelscope.utils.logger import get_logger
logger = get_logger()
default_owner = 'modelscope'
class HubListHubTest(unittest.TestCase):
    """Smoke tests for the HubApi listing methods.

    NOTE(review): these presumably call the real hub endpoint configured for
    HubApi, so they need network access - confirm before running offline.
    """

    def setUp(self):
        self.api = HubApi()

    def test_list_datasets(self):
        # Use default args
        result = self.api.list_datasets(owner_or_group=default_owner)
        logger.info(f'List datasets result: {result}')
        # list_datasets is annotated to return the OpenAPI data payload as a
        # dict; its docstring shows a 'datasets' entry in that payload.
        self.assertIsInstance(result, dict)
        self.assertIn('datasets', result)

    def test_list_datasets_with_args(self):
        result = self.api.list_datasets(
            owner_or_group=default_owner,
            page_number=1,
            page_size=2,
            sort='downloads',
            search='chinese',
        )
        logger.info(f'List datasets with full result: {result}')
        self.assertIsInstance(result, dict)
        self.assertIn('datasets', result)

    def test_list_models(self):
        # Only checks that the call completes; the return shape of
        # list_models is not visible here, so nothing is asserted on it.
        result = self.api.list_models(
            owner_or_group='Qwen', page_number=1, page_size=2)
        logger.info(f'List models result: {result}')