[Feat & Fix] Refactor endpoint arg for CLI (#1695)

This commit is contained in:
Xingjun.Wang
2026-04-27 17:28:37 +08:00
committed by GitHub
parent b3599deb03
commit 16a79bc80b
8 changed files with 144 additions and 35 deletions

View File

@@ -6,6 +6,7 @@ from modelscope.hub.api import HubApi
from modelscope.hub.constants import (Licenses, ModelVisibility, Visibility,
VisibilityMap)
from modelscope.hub.utils.aigc import AigcModel
from modelscope.hub.utils.utils import resolve_endpoint
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
from modelscope.utils.logger import get_logger
@@ -74,13 +75,23 @@ class CreateCMD(CLICommand):
help=
'Optional, License of the repo. Default to `Apache License 2.0`.',
)
parser.add_argument(
'--exist_ok',
action='store_true',
default=False,
help=
'If True, do not raise error when repo already exists. Defaults to False.',
)
parser.add_argument(
'--endpoint',
type=str,
default=None,
help='Optional, The modelscope server address. Default to None. '
'If not provided, the CLI will use the default official ModelScope endpoint (`https://modelscope.cn`). '
'`https://modelscope.ai` is also supported.',
help=
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
'modelscope.ai (international site). Full URL like '
'https://modelscope.cn is also accepted. Scheme (https://) is '
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
'then defaults to https://www.modelscope.cn.',
)
# AIGC specific arguments
@@ -152,7 +163,8 @@ class CreateCMD(CLICommand):
def _create_regular_repo(self):
# Check token and login
# The cookies will be reused if the user has logged in before.
api = HubApi(endpoint=self.args.endpoint)
endpoint = resolve_endpoint(self.args.endpoint)
api = HubApi(endpoint=endpoint)
# Create repo
api.create_repo(
@@ -162,14 +174,15 @@ class CreateCMD(CLICommand):
repo_type=self.args.repo_type,
chinese_name=self.args.chinese_name,
license=self.args.license,
exist_ok=True,
exist_ok=self.args.exist_ok,
create_default_config=True,
endpoint=self.args.endpoint,
endpoint=endpoint,
)
def _create_aigc_model(self):
"""Execute the command."""
api = HubApi(endpoint=self.args.endpoint)
endpoint = resolve_endpoint(self.args.endpoint)
api = HubApi(endpoint=endpoint)
model_id = self.args.repo_id
if self.args.from_json:

View File

@@ -10,7 +10,7 @@ from modelscope.hub.file_download import (dataset_file_download,
model_file_download)
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
snapshot_download)
from modelscope.hub.utils.utils import convert_patterns
from modelscope.hub.utils.utils import convert_patterns, resolve_endpoint
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
from modelscope.utils.logger import get_logger
@@ -106,6 +106,18 @@ class DownloadCMD(CLICommand):
default=None,
help='Glob patterns to exclude from files to download.'
'Ignored if file is specified')
parser.add_argument(
'--endpoint',
type=str,
default=None,
help=
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
'modelscope.ai (international site). Full URL like '
'https://modelscope.cn is also accepted. Scheme (https://) is '
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
'then defaults to https://www.modelscope.cn. '
'When omitted, the CLI auto-detects the correct site '
'(cn/intl) for download.')
parser.add_argument(
'--max-workers',
type=int,
@@ -133,9 +145,13 @@ class DownloadCMD(CLICommand):
% self.args.repo_type)
if not self.args.model and not self.args.dataset and not self.args.collection:
raise Exception('Model, dataset, or collection must be set.')
if self.args.endpoint:
endpoint = resolve_endpoint(self.args.endpoint)
else:
endpoint = None
cookies = None
if self.args.token is not None:
api = HubApi()
api = HubApi(endpoint=endpoint)
cookies = api.get_cookies(access_token=self.args.token)
if self.args.model:
if len(self.args.files) == 1: # download single file
@@ -146,7 +162,8 @@ class DownloadCMD(CLICommand):
local_dir=self.args.local_dir,
revision=self.args.revision,
cookies=cookies,
token=self.args.token)
token=self.args.token,
endpoint=endpoint)
elif len(
self.args.files) > 1: # download specified multiple files.
snapshot_download(
@@ -157,7 +174,8 @@ class DownloadCMD(CLICommand):
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
cookies=cookies,
token=self.args.token)
token=self.args.token,
endpoint=endpoint)
else: # download repo
snapshot_download(
self.args.model,
@@ -168,7 +186,8 @@ class DownloadCMD(CLICommand):
ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers,
cookies=cookies,
token=self.args.token)
token=self.args.token,
endpoint=endpoint)
print(f'\nSuccessfully Downloaded from model {self.args.model}.\n')
elif self.args.dataset:
dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
@@ -180,7 +199,8 @@ class DownloadCMD(CLICommand):
local_dir=self.args.local_dir,
revision=dataset_revision,
cookies=cookies,
token=self.args.token)
token=self.args.token,
endpoint=endpoint)
elif len(
self.args.files) > 1: # download specified multiple files.
dataset_snapshot_download(
@@ -191,7 +211,8 @@ class DownloadCMD(CLICommand):
allow_file_pattern=self.args.files,
max_workers=self.args.max_workers,
cookies=cookies,
token=self.args.token)
token=self.args.token,
endpoint=endpoint)
else: # download repo
dataset_snapshot_download(
self.args.dataset,
@@ -202,14 +223,16 @@ class DownloadCMD(CLICommand):
ignore_file_pattern=convert_patterns(self.args.exclude),
max_workers=self.args.max_workers,
cookies=cookies,
token=self.args.token)
token=self.args.token,
endpoint=endpoint)
print(
f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
)
elif self.args.collection:
api = HubApi(token=self.args.token)
api = HubApi(endpoint=endpoint, token=self.args.token)
local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
data = api.get_collection(self.args.collection, repo_type='skill')
data = api.get_collection(
self.args.collection, repo_type='skill', endpoint=endpoint)
elements = data.get('CollectionElements',
{}).get('CollectionElementVoList', [])
@@ -245,7 +268,9 @@ class DownloadCMD(CLICommand):
skill_id = f'{element_path}/{element_name}'
try:
skill_dir = api.download_skill(
skill_id=skill_id, local_dir=local_dir)
skill_id=skill_id,
local_dir=local_dir,
endpoint=endpoint)
return (skill_id, skill_dir, None)
except Exception as e:
return (skill_id, None, str(e))

View File

@@ -4,6 +4,7 @@ from argparse import ArgumentParser
from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.utils.utils import resolve_endpoint
def subparser_func(args):
@@ -28,8 +29,18 @@ class LoginCMD(CLICommand):
type=str,
required=True,
help='The Access Token for modelscope.')
parser.add_argument(
'--endpoint',
type=str,
default=None,
help=
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
'modelscope.ai (international site). Full URL like '
'https://modelscope.cn is also accepted. Scheme (https://) is '
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
'then defaults to https://www.modelscope.cn.')
parser.set_defaults(func=subparser_func)
def execute(self):
api = HubApi()
api = HubApi(endpoint=resolve_endpoint(self.args.endpoint))
api.login(self.args.token)

View File

@@ -4,7 +4,7 @@ from argparse import ArgumentParser, _SubParsersAction
from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.utils.utils import convert_patterns, get_endpoint
from modelscope.hub.utils.utils import convert_patterns, resolve_endpoint
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
@@ -90,8 +90,13 @@ class UploadCMD(CLICommand):
parser.add_argument(
'--endpoint',
type=str,
default=get_endpoint(),
help='Endpoint for ModelScope service.')
default=None,
help=
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
'modelscope.ai (international site). Full URL like '
'https://modelscope.cn is also accepted. Scheme (https://) is '
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
'then defaults to https://www.modelscope.cn.')
parser.set_defaults(func=subparser_func)
@@ -135,7 +140,7 @@ class UploadCMD(CLICommand):
self.local_path = self.args.local_path
self.path_in_repo = self.args.path_in_repo
api = HubApi(endpoint=self.args.endpoint)
api = HubApi(endpoint=resolve_endpoint(self.args.endpoint))
if os.path.isfile(self.local_path):
api.upload_file(

View File

@@ -3890,7 +3890,8 @@ class HubApi:
collection_id: str,
repo_type: str = 'skill',
page_number: int = 1,
page_size: int = 50) -> dict:
page_size: int = 50,
endpoint: Optional[str] = None) -> dict:
"""Get collection details and its elements.
Args:
@@ -3906,12 +3907,14 @@ class HubApi:
ValueError: If repo_type is not 'skill'.
RequestError: If the API request fails.
"""
if not endpoint:
endpoint = self.endpoint
if repo_type != 'skill':
raise ValueError(
f'repo_type={repo_type} is not supported, '
'only "skill" is currently supported.')
cookies = self.get_cookies()
path = f'{self.endpoint}/api/v1/collections'
path = f'{endpoint}/api/v1/collections'
params = {
'Fid': collection_id,
'ElementType': repo_type,
@@ -3926,7 +3929,8 @@ class HubApi:
return d[API_RESPONSE_FIELD_DATA]
def download_skill(self, skill_id: str,
local_dir: Optional[str] = None) -> str:
local_dir: Optional[str] = None,
endpoint: Optional[str] = None) -> str:
"""Download a single skill archive and extract it.
Args:
@@ -3941,10 +3945,12 @@ class HubApi:
ValueError: If skill_id format is invalid.
RequestError: If the download request fails.
"""
if not endpoint:
endpoint = self.endpoint
element_path, element_name = RepoUtils.validate_repo_id(skill_id)
cookies = self.get_cookies()
url = f'{self.endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master'
url = f'{endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master'
if local_dir is None:
local_dir = os.getcwd()

View File

@@ -53,6 +53,7 @@ def model_file_download(
cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None,
token: Optional[str] = None,
endpoint: Optional[str] = None,
) -> Optional[str]: # pragma: no cover
"""Download from a given URL and cache it if it's not already present in the local cache.
@@ -72,6 +73,7 @@ def model_file_download(
cookies (CookieJar, optional): The cookie of download request.
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
token (str, optional): The user token.
endpoint (str, optional): The remote endpoint.
Returns:
string: string of local file or if networking is off, last version of
@@ -101,7 +103,8 @@ def model_file_download(
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir,
token=token)
token=token,
endpoint=endpoint)
def dataset_file_download(
@@ -114,6 +117,7 @@ def dataset_file_download(
local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None,
token: Optional[str] = None,
endpoint: Optional[str] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
@@ -137,6 +141,7 @@ def dataset_file_download(
local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None.
token (str, optional): The user token.
endpoint (str, optional): The remote endpoint.
Raises:
ValueError: the value details.
@@ -162,7 +167,8 @@ def dataset_file_download(
local_files_only=local_files_only,
cookies=cookies,
local_dir=local_dir,
token=token)
token=token,
endpoint=endpoint)
def _repo_file_download(
@@ -178,6 +184,7 @@ def _repo_file_download(
local_dir: Optional[str] = None,
disable_tqdm: bool = False,
token: Optional[str] = None,
endpoint: Optional[str] = None,
) -> Optional[str]: # pragma: no cover
if not repo_type:
@@ -224,8 +231,9 @@ def _repo_file_download(
if cookies is None:
cookies = _api.get_cookies()
repo_files = []
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=repo_type, token=token)
if endpoint is None:
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=repo_type, token=token)
file_to_download_meta = None
if repo_type == REPO_TYPE_MODEL:
revision = _api.get_valid_revision(

View File

@@ -56,6 +56,7 @@ def snapshot_download(
enable_file_lock: Optional[bool] = None,
progress_callbacks: List[Type[ProgressCallback]] = None,
token: Optional[str] = None,
endpoint: Optional[str] = None,
) -> str:
"""Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -156,7 +157,8 @@ def snapshot_download(
allow_patterns=allow_patterns,
max_workers=max_workers,
progress_callbacks=progress_callbacks,
token=token)
token=token,
endpoint=endpoint)
def dataset_snapshot_download(
@@ -174,6 +176,7 @@ def dataset_snapshot_download(
enable_file_lock: Optional[bool] = None,
max_workers: int = 8,
token: Optional[str] = None,
endpoint: Optional[str] = None,
) -> str:
"""Download raw files of a dataset.
Downloads all files at the specified revision. This
@@ -254,7 +257,8 @@ def dataset_snapshot_download(
ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns,
max_workers=max_workers,
token=token)
token=token,
endpoint=endpoint)
def _snapshot_download(
@@ -274,6 +278,7 @@ def _snapshot_download(
max_workers: int = 8,
progress_callbacks: List[Type[ProgressCallback]] = None,
token: Optional[str] = None,
endpoint: Optional[str] = None,
):
if not repo_type:
repo_type = REPO_TYPE_MODEL
@@ -314,8 +319,9 @@ def _snapshot_download(
headers['x-aliyun-region-id'] = region_id
_api = HubApi(token=token)
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=repo_type, token=token)
if endpoint is None:
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=repo_type, token=token)
if cookies is None:
cookies = _api.get_cookies()
if repo_type == REPO_TYPE_MODEL:

View File

@@ -210,6 +210,41 @@ def get_endpoint(cn_site=True):
return MODELSCOPE_URL_SCHEME + get_domain(cn_site)
def resolve_endpoint(cli_endpoint: Optional[str] = None,
cn_site: bool = True) -> str:
"""Resolve the ModelScope API endpoint with automatic scheme completion.
Priority (highest to lowest):
1. ``cli_endpoint`` (explicit CLI --endpoint argument)
2. Environment variable ``MODELSCOPE_DOMAIN``
3. Built-in default (https://www.modelscope.cn)
Scheme auto-completion:
If the resolved value does not start with ``http://`` or ``https://``,
``https://`` is prepended automatically so that callers may pass bare
domain names such as ``modelscope.ai``.
Args:
cli_endpoint: Value from the CLI ``--endpoint`` flag. When *None*,
the function falls back to :func:`get_endpoint`.
cn_site: Forwarded to :func:`get_endpoint` when *cli_endpoint* is
*None*. ``True`` selects the Chinese site, ``False`` the
international site.
Returns:
A fully-qualified endpoint URL, e.g. ``https://www.modelscope.cn``.
"""
if cli_endpoint is None:
return get_endpoint(cn_site=cn_site)
endpoint = cli_endpoint.strip().rstrip('/')
if not endpoint:
return get_endpoint(cn_site=cn_site)
if not endpoint.startswith('http://') and not endpoint.startswith(
'https://'):
endpoint = MODELSCOPE_URL_SCHEME + endpoint
return endpoint
def compute_hash(file_path):
# 16MB buffer for large file hash computation
BUFFER_SIZE = 1024 * 1024 * 16