mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
Update list datasets to OpenAPI (#1532)
* update list datasets * Update modelscope/hub/constants.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update modelscope/hub/api.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update modelscope/hub/api.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update modelscope/hub/api.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * fix lint * update workflow * fix lint --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
7
.github/workflows/citest.yaml
vendored
7
.github/workflows/citest.yaml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
lfs: 'true'
|
||||
lfs: 'false'
|
||||
submodules: 'true'
|
||||
fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
|
||||
- name: Get changed files
|
||||
@@ -65,8 +65,9 @@ jobs:
|
||||
else
|
||||
echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Checkout LFS objects
|
||||
run: git lfs checkout
|
||||
- name: Fetch LFS objects
|
||||
run: |
|
||||
git lfs pull
|
||||
- name: Run unittest
|
||||
shell: bash
|
||||
run: |
|
||||
|
||||
@@ -54,8 +54,8 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
||||
UPLOAD_MAX_FILE_SIZE,
|
||||
UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
|
||||
UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
|
||||
DatasetVisibility, Licenses,
|
||||
ModelVisibility, Visibility,
|
||||
VALID_SORT_KEYS, DatasetVisibility,
|
||||
Licenses, ModelVisibility, Visibility,
|
||||
VisibilityMap)
|
||||
from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
NotLoginException, RequestError,
|
||||
@@ -913,6 +913,75 @@ class HubApi:
|
||||
raise_for_http_status(r)
|
||||
return None
|
||||
|
||||
def list_datasets(self,
|
||||
owner_or_group: str,
|
||||
*,
|
||||
page_number: Optional[int] = 1,
|
||||
page_size: Optional[int] = 10,
|
||||
sort: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""List datasets via OpenAPI with pagination, filtering and sorting.
|
||||
|
||||
Args:
|
||||
owner_or_group (str): Search by dataset authors (including organizations and individuals).
|
||||
page_number (int, optional): The page number. Defaults to 1.
|
||||
page_size (int, optional): The page size. Defaults to 10.
|
||||
sort (str, optional): Sort key. If not provided, the server's default sorting is used.
|
||||
choose from ['default', 'downloads', 'likes', 'last_modified'].
|
||||
search (str, optional): Search by substring keywords in the dataset's Chinese name,
|
||||
English name, and authors (including organizations and individuals).
|
||||
endpoint (str, optional): Hub endpoint to use. When None, use the endpoint specified in the class.
|
||||
|
||||
Returns:
|
||||
dict: The OpenAPI data payload, e.g.
|
||||
{
|
||||
"datasets": [...],
|
||||
"total_count": int,
|
||||
"page_number": int,
|
||||
"page_size": int
|
||||
}
|
||||
"""
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
path = f'{endpoint}/openapi/v1/datasets'
|
||||
|
||||
# Build query params
|
||||
params: Dict[str, Any] = {
|
||||
'page_number': page_number,
|
||||
'page_size': page_size,
|
||||
}
|
||||
if sort:
|
||||
if sort not in VALID_SORT_KEYS:
|
||||
raise InvalidParameter(
|
||||
f'Invalid sort key: {sort}. Supported sort keys: {list(VALID_SORT_KEYS)}')
|
||||
params['sort'] = sort
|
||||
if search:
|
||||
params['search'] = search
|
||||
if owner_or_group:
|
||||
params['author'] = owner_or_group
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
headers = self.builder_headers(self.headers)
|
||||
|
||||
r = self.session.get(
|
||||
path,
|
||||
params=params,
|
||||
cookies=cookies,
|
||||
headers=headers
|
||||
)
|
||||
raise_for_http_status(r)
|
||||
resp = r.json()
|
||||
|
||||
# OpenAPI success schema
|
||||
if resp.get('success') is True and 'data' in resp:
|
||||
return resp['data']
|
||||
else:
|
||||
# Fallback for unexpected schema
|
||||
msg = resp.get('message') or 'Failed to list datasets'
|
||||
raise RequestError(msg)
|
||||
|
||||
def _check_cookie(self, use_cookies: Union[bool, CookieJar] = False) -> CookieJar: # noqa
|
||||
cookies = None
|
||||
if isinstance(use_cookies, CookieJar):
|
||||
@@ -1239,17 +1308,6 @@ class HubApi:
|
||||
logger.info(f'Create dataset success: {dataset_repo_url}')
|
||||
return dataset_repo_url
|
||||
|
||||
def list_datasets(self, endpoint: Optional[str] = None):
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
path = f'{endpoint}/api/v1/datasets'
|
||||
params = {}
|
||||
r = self.session.get(path, params=params,
|
||||
headers=self.builder_headers(self.headers))
|
||||
raise_for_http_status(r)
|
||||
dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
|
||||
return [x['Name'] for x in dataset_list]
|
||||
|
||||
def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None):
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
@@ -1803,6 +1861,7 @@ class HubApi:
|
||||
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
|
||||
download_uv_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
|
||||
f'{channel}?user={user_name}'
|
||||
|
||||
download_uv_resp = self.session.post(download_uv_url, cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
download_uv_resp = download_uv_resp.json()
|
||||
|
||||
@@ -116,3 +116,18 @@ VisibilityMap = {
|
||||
ModelVisibility.INTERNAL: Visibility.INTERNAL,
|
||||
ModelVisibility.PUBLIC: Visibility.PUBLIC
|
||||
}
|
||||
|
||||
|
||||
class SortKey(object):
|
||||
DEFAULT = 'default'
|
||||
DOWNLOADS = 'downloads'
|
||||
LIKES = 'likes'
|
||||
LAST_MODIFIED = 'last_modified'
|
||||
|
||||
|
||||
VALID_SORT_KEYS = {
|
||||
SortKey.DEFAULT,
|
||||
SortKey.DOWNLOADS,
|
||||
SortKey.LIKES,
|
||||
SortKey.LAST_MODIFIED,
|
||||
}
|
||||
|
||||
0
modelscope/msdatasets.lock
Normal file
0
modelscope/msdatasets.lock
Normal file
35
tests/hub/test_hub_list.py
Normal file
35
tests/hub/test_hub_list.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from modelscope import HubApi
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
default_owner = 'modelscope'
|
||||
|
||||
|
||||
class HubListHubTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.api = HubApi()
|
||||
|
||||
def test_list_datasets(self):
|
||||
# Use default args
|
||||
result = self.api.list_datasets(owner_or_group=default_owner)
|
||||
logger.info(f'List datasets result: {result}')
|
||||
|
||||
def test_list_datasets_with_args(self):
|
||||
result = self.api.list_datasets(
|
||||
owner_or_group=default_owner,
|
||||
page_number=1,
|
||||
page_size=2,
|
||||
sort='downloads',
|
||||
search='chinese',
|
||||
)
|
||||
logger.info(f'List datasets with full result: {result}')
|
||||
|
||||
def test_list_models(self):
|
||||
result = self.api.list_models(
|
||||
owner_or_group='Qwen', page_number=1, page_size=2)
|
||||
logger.info(f'List models result: {result}')
|
||||
Reference in New Issue
Block a user