mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[Feat & Fix] Refactor endpoint arg for CLI (#1695)
This commit is contained in:
@@ -6,6 +6,7 @@ from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.constants import (Licenses, ModelVisibility, Visibility,
|
||||
VisibilityMap)
|
||||
from modelscope.hub.utils.aigc import AigcModel
|
||||
from modelscope.hub.utils.utils import resolve_endpoint
|
||||
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -74,13 +75,23 @@ class CreateCMD(CLICommand):
|
||||
help=
|
||||
'Optional, License of the repo. Default to `Apache License 2.0`.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--exist_ok',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help=
|
||||
'If True, do not raise error when repo already exists. Defaults to False.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--endpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help='Optional, The modelscope server address. Default to None. '
|
||||
'If not provided, the CLI will use the default official ModelScope endpoint (`https://modelscope.cn`). '
|
||||
'`https://modelscope.ai` is also supported.',
|
||||
help=
|
||||
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
|
||||
'modelscope.ai (international site). Full URL like '
|
||||
'https://modelscope.cn is also accepted. Scheme (https://) is '
|
||||
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
|
||||
'then defaults to https://www.modelscope.cn.',
|
||||
)
|
||||
|
||||
# AIGC specific arguments
|
||||
@@ -152,7 +163,8 @@ class CreateCMD(CLICommand):
|
||||
def _create_regular_repo(self):
|
||||
# Check token and login
|
||||
# The cookies will be reused if the user has logged in before.
|
||||
api = HubApi(endpoint=self.args.endpoint)
|
||||
endpoint = resolve_endpoint(self.args.endpoint)
|
||||
api = HubApi(endpoint=endpoint)
|
||||
|
||||
# Create repo
|
||||
api.create_repo(
|
||||
@@ -162,14 +174,15 @@ class CreateCMD(CLICommand):
|
||||
repo_type=self.args.repo_type,
|
||||
chinese_name=self.args.chinese_name,
|
||||
license=self.args.license,
|
||||
exist_ok=True,
|
||||
exist_ok=self.args.exist_ok,
|
||||
create_default_config=True,
|
||||
endpoint=self.args.endpoint,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
def _create_aigc_model(self):
|
||||
"""Execute the command."""
|
||||
api = HubApi(endpoint=self.args.endpoint)
|
||||
endpoint = resolve_endpoint(self.args.endpoint)
|
||||
api = HubApi(endpoint=endpoint)
|
||||
model_id = self.args.repo_id
|
||||
|
||||
if self.args.from_json:
|
||||
|
||||
@@ -10,7 +10,7 @@ from modelscope.hub.file_download import (dataset_file_download,
|
||||
model_file_download)
|
||||
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
|
||||
snapshot_download)
|
||||
from modelscope.hub.utils.utils import convert_patterns
|
||||
from modelscope.hub.utils.utils import convert_patterns, resolve_endpoint
|
||||
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -106,6 +106,18 @@ class DownloadCMD(CLICommand):
|
||||
default=None,
|
||||
help='Glob patterns to exclude from files to download.'
|
||||
'Ignored if file is specified')
|
||||
parser.add_argument(
|
||||
'--endpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
|
||||
'modelscope.ai (international site). Full URL like '
|
||||
'https://modelscope.cn is also accepted. Scheme (https://) is '
|
||||
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
|
||||
'then defaults to https://www.modelscope.cn. '
|
||||
'When omitted, the CLI auto-detects the correct site '
|
||||
'(cn/intl) for download.')
|
||||
parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
@@ -133,9 +145,13 @@ class DownloadCMD(CLICommand):
|
||||
% self.args.repo_type)
|
||||
if not self.args.model and not self.args.dataset and not self.args.collection:
|
||||
raise Exception('Model, dataset, or collection must be set.')
|
||||
if self.args.endpoint:
|
||||
endpoint = resolve_endpoint(self.args.endpoint)
|
||||
else:
|
||||
endpoint = None
|
||||
cookies = None
|
||||
if self.args.token is not None:
|
||||
api = HubApi()
|
||||
api = HubApi(endpoint=endpoint)
|
||||
cookies = api.get_cookies(access_token=self.args.token)
|
||||
if self.args.model:
|
||||
if len(self.args.files) == 1: # download single file
|
||||
@@ -146,7 +162,8 @@ class DownloadCMD(CLICommand):
|
||||
local_dir=self.args.local_dir,
|
||||
revision=self.args.revision,
|
||||
cookies=cookies,
|
||||
token=self.args.token)
|
||||
token=self.args.token,
|
||||
endpoint=endpoint)
|
||||
elif len(
|
||||
self.args.files) > 1: # download specified multiple files.
|
||||
snapshot_download(
|
||||
@@ -157,7 +174,8 @@ class DownloadCMD(CLICommand):
|
||||
allow_file_pattern=self.args.files,
|
||||
max_workers=self.args.max_workers,
|
||||
cookies=cookies,
|
||||
token=self.args.token)
|
||||
token=self.args.token,
|
||||
endpoint=endpoint)
|
||||
else: # download repo
|
||||
snapshot_download(
|
||||
self.args.model,
|
||||
@@ -168,7 +186,8 @@ class DownloadCMD(CLICommand):
|
||||
ignore_file_pattern=convert_patterns(self.args.exclude),
|
||||
max_workers=self.args.max_workers,
|
||||
cookies=cookies,
|
||||
token=self.args.token)
|
||||
token=self.args.token,
|
||||
endpoint=endpoint)
|
||||
print(f'\nSuccessfully Downloaded from model {self.args.model}.\n')
|
||||
elif self.args.dataset:
|
||||
dataset_revision: str = self.args.revision if self.args.revision else DEFAULT_DATASET_REVISION
|
||||
@@ -180,7 +199,8 @@ class DownloadCMD(CLICommand):
|
||||
local_dir=self.args.local_dir,
|
||||
revision=dataset_revision,
|
||||
cookies=cookies,
|
||||
token=self.args.token)
|
||||
token=self.args.token,
|
||||
endpoint=endpoint)
|
||||
elif len(
|
||||
self.args.files) > 1: # download specified multiple files.
|
||||
dataset_snapshot_download(
|
||||
@@ -191,7 +211,8 @@ class DownloadCMD(CLICommand):
|
||||
allow_file_pattern=self.args.files,
|
||||
max_workers=self.args.max_workers,
|
||||
cookies=cookies,
|
||||
token=self.args.token)
|
||||
token=self.args.token,
|
||||
endpoint=endpoint)
|
||||
else: # download repo
|
||||
dataset_snapshot_download(
|
||||
self.args.dataset,
|
||||
@@ -202,14 +223,16 @@ class DownloadCMD(CLICommand):
|
||||
ignore_file_pattern=convert_patterns(self.args.exclude),
|
||||
max_workers=self.args.max_workers,
|
||||
cookies=cookies,
|
||||
token=self.args.token)
|
||||
token=self.args.token,
|
||||
endpoint=endpoint)
|
||||
print(
|
||||
f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
|
||||
)
|
||||
elif self.args.collection:
|
||||
api = HubApi(token=self.args.token)
|
||||
api = HubApi(endpoint=endpoint, token=self.args.token)
|
||||
local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
|
||||
data = api.get_collection(self.args.collection, repo_type='skill')
|
||||
data = api.get_collection(
|
||||
self.args.collection, repo_type='skill', endpoint=endpoint)
|
||||
elements = data.get('CollectionElements',
|
||||
{}).get('CollectionElementVoList', [])
|
||||
|
||||
@@ -245,7 +268,9 @@ class DownloadCMD(CLICommand):
|
||||
skill_id = f'{element_path}/{element_name}'
|
||||
try:
|
||||
skill_dir = api.download_skill(
|
||||
skill_id=skill_id, local_dir=local_dir)
|
||||
skill_id=skill_id,
|
||||
local_dir=local_dir,
|
||||
endpoint=endpoint)
|
||||
return (skill_id, skill_dir, None)
|
||||
except Exception as e:
|
||||
return (skill_id, None, str(e))
|
||||
|
||||
@@ -4,6 +4,7 @@ from argparse import ArgumentParser
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.utils.utils import resolve_endpoint
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
@@ -28,8 +29,18 @@ class LoginCMD(CLICommand):
|
||||
type=str,
|
||||
required=True,
|
||||
help='The Access Token for modelscope.')
|
||||
parser.add_argument(
|
||||
'--endpoint',
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
|
||||
'modelscope.ai (international site). Full URL like '
|
||||
'https://modelscope.cn is also accepted. Scheme (https://) is '
|
||||
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
|
||||
'then defaults to https://www.modelscope.cn.')
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def execute(self):
|
||||
api = HubApi()
|
||||
api = HubApi(endpoint=resolve_endpoint(self.args.endpoint))
|
||||
api.login(self.args.token)
|
||||
|
||||
@@ -4,7 +4,7 @@ from argparse import ArgumentParser, _SubParsersAction
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.utils.utils import convert_patterns, get_endpoint
|
||||
from modelscope.hub.utils.utils import convert_patterns, resolve_endpoint
|
||||
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
|
||||
|
||||
|
||||
@@ -90,8 +90,13 @@ class UploadCMD(CLICommand):
|
||||
parser.add_argument(
|
||||
'--endpoint',
|
||||
type=str,
|
||||
default=get_endpoint(),
|
||||
help='Endpoint for ModelScope service.')
|
||||
default=None,
|
||||
help=
|
||||
'ModelScope server endpoint, e.g. modelscope.cn (Chinese site) or '
|
||||
'modelscope.ai (international site). Full URL like '
|
||||
'https://modelscope.cn is also accepted. Scheme (https://) is '
|
||||
'auto-completed if omitted. Falls back to env MODELSCOPE_DOMAIN, '
|
||||
'then defaults to https://www.modelscope.cn.')
|
||||
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
@@ -135,7 +140,7 @@ class UploadCMD(CLICommand):
|
||||
self.local_path = self.args.local_path
|
||||
self.path_in_repo = self.args.path_in_repo
|
||||
|
||||
api = HubApi(endpoint=self.args.endpoint)
|
||||
api = HubApi(endpoint=resolve_endpoint(self.args.endpoint))
|
||||
|
||||
if os.path.isfile(self.local_path):
|
||||
api.upload_file(
|
||||
|
||||
@@ -3890,7 +3890,8 @@ class HubApi:
|
||||
collection_id: str,
|
||||
repo_type: str = 'skill',
|
||||
page_number: int = 1,
|
||||
page_size: int = 50) -> dict:
|
||||
page_size: int = 50,
|
||||
endpoint: Optional[str] = None) -> dict:
|
||||
"""Get collection details and its elements.
|
||||
|
||||
Args:
|
||||
@@ -3906,12 +3907,14 @@ class HubApi:
|
||||
ValueError: If repo_type is not 'skill'.
|
||||
RequestError: If the API request fails.
|
||||
"""
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
if repo_type != 'skill':
|
||||
raise ValueError(
|
||||
f'repo_type={repo_type} is not supported, '
|
||||
'only "skill" is currently supported.')
|
||||
cookies = self.get_cookies()
|
||||
path = f'{self.endpoint}/api/v1/collections'
|
||||
path = f'{endpoint}/api/v1/collections'
|
||||
params = {
|
||||
'Fid': collection_id,
|
||||
'ElementType': repo_type,
|
||||
@@ -3926,7 +3929,8 @@ class HubApi:
|
||||
return d[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def download_skill(self, skill_id: str,
|
||||
local_dir: Optional[str] = None) -> str:
|
||||
local_dir: Optional[str] = None,
|
||||
endpoint: Optional[str] = None) -> str:
|
||||
"""Download a single skill archive and extract it.
|
||||
|
||||
Args:
|
||||
@@ -3941,10 +3945,12 @@ class HubApi:
|
||||
ValueError: If skill_id format is invalid.
|
||||
RequestError: If the download request fails.
|
||||
"""
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
element_path, element_name = RepoUtils.validate_repo_id(skill_id)
|
||||
|
||||
cookies = self.get_cookies()
|
||||
url = f'{self.endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master'
|
||||
url = f'{endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master'
|
||||
|
||||
if local_dir is None:
|
||||
local_dir = os.getcwd()
|
||||
|
||||
@@ -53,6 +53,7 @@ def model_file_download(
|
||||
cookies: Optional[CookieJar] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
"""Download from a given URL and cache it if it's not already present in the local cache.
|
||||
|
||||
@@ -72,6 +73,7 @@ def model_file_download(
|
||||
cookies (CookieJar, optional): The cookie of download request.
|
||||
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
|
||||
token (str, optional): The user token.
|
||||
endpoint (str, optional): The remote endpoint.
|
||||
|
||||
Returns:
|
||||
string: string of local file or if networking is off, last version of
|
||||
@@ -101,7 +103,8 @@ def model_file_download(
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
local_dir=local_dir,
|
||||
token=token)
|
||||
token=token,
|
||||
endpoint=endpoint)
|
||||
|
||||
|
||||
def dataset_file_download(
|
||||
@@ -114,6 +117,7 @@ def dataset_file_download(
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
token: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
@@ -137,6 +141,7 @@ def dataset_file_download(
|
||||
local cached file if it exists.
|
||||
cookies (CookieJar, optional): The cookie of the request, default None.
|
||||
token (str, optional): The user token.
|
||||
endpoint (str, optional): The remote endpoint.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
@@ -162,7 +167,8 @@ def dataset_file_download(
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
local_dir=local_dir,
|
||||
token=token)
|
||||
token=token,
|
||||
endpoint=endpoint)
|
||||
|
||||
|
||||
def _repo_file_download(
|
||||
@@ -178,6 +184,7 @@ def _repo_file_download(
|
||||
local_dir: Optional[str] = None,
|
||||
disable_tqdm: bool = False,
|
||||
token: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
|
||||
if not repo_type:
|
||||
@@ -224,8 +231,9 @@ def _repo_file_download(
|
||||
if cookies is None:
|
||||
cookies = _api.get_cookies()
|
||||
repo_files = []
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=repo_type, token=token)
|
||||
if endpoint is None:
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=repo_type, token=token)
|
||||
file_to_download_meta = None
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
revision = _api.get_valid_revision(
|
||||
|
||||
@@ -56,6 +56,7 @@ def snapshot_download(
|
||||
enable_file_lock: Optional[bool] = None,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
token: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download all files of a repo.
|
||||
Downloads a whole snapshot of a repo's files at the specified revision. This
|
||||
@@ -156,7 +157,8 @@ def snapshot_download(
|
||||
allow_patterns=allow_patterns,
|
||||
max_workers=max_workers,
|
||||
progress_callbacks=progress_callbacks,
|
||||
token=token)
|
||||
token=token,
|
||||
endpoint=endpoint)
|
||||
|
||||
|
||||
def dataset_snapshot_download(
|
||||
@@ -174,6 +176,7 @@ def dataset_snapshot_download(
|
||||
enable_file_lock: Optional[bool] = None,
|
||||
max_workers: int = 8,
|
||||
token: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
@@ -254,7 +257,8 @@ def dataset_snapshot_download(
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns,
|
||||
max_workers=max_workers,
|
||||
token=token)
|
||||
token=token,
|
||||
endpoint=endpoint)
|
||||
|
||||
|
||||
def _snapshot_download(
|
||||
@@ -274,6 +278,7 @@ def _snapshot_download(
|
||||
max_workers: int = 8,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
token: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
):
|
||||
if not repo_type:
|
||||
repo_type = REPO_TYPE_MODEL
|
||||
@@ -314,8 +319,9 @@ def _snapshot_download(
|
||||
headers['x-aliyun-region-id'] = region_id
|
||||
|
||||
_api = HubApi(token=token)
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=repo_type, token=token)
|
||||
if endpoint is None:
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=repo_type, token=token)
|
||||
if cookies is None:
|
||||
cookies = _api.get_cookies()
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
|
||||
@@ -210,6 +210,41 @@ def get_endpoint(cn_site=True):
|
||||
return MODELSCOPE_URL_SCHEME + get_domain(cn_site)
|
||||
|
||||
|
||||
def resolve_endpoint(cli_endpoint: Optional[str] = None,
|
||||
cn_site: bool = True) -> str:
|
||||
"""Resolve the ModelScope API endpoint with automatic scheme completion.
|
||||
|
||||
Priority (highest to lowest):
|
||||
1. ``cli_endpoint`` (explicit CLI --endpoint argument)
|
||||
2. Environment variable ``MODELSCOPE_DOMAIN``
|
||||
3. Built-in default (https://www.modelscope.cn)
|
||||
|
||||
Scheme auto-completion:
|
||||
If the resolved value does not start with ``http://`` or ``https://``,
|
||||
``https://`` is prepended automatically so that callers may pass bare
|
||||
domain names such as ``modelscope.ai``.
|
||||
|
||||
Args:
|
||||
cli_endpoint: Value from the CLI ``--endpoint`` flag. When *None*,
|
||||
the function falls back to :func:`get_endpoint`.
|
||||
cn_site: Forwarded to :func:`get_endpoint` when *cli_endpoint* is
|
||||
*None*. ``True`` selects the Chinese site, ``False`` the
|
||||
international site.
|
||||
|
||||
Returns:
|
||||
A fully-qualified endpoint URL, e.g. ``https://www.modelscope.cn``.
|
||||
"""
|
||||
if cli_endpoint is None:
|
||||
return get_endpoint(cn_site=cn_site)
|
||||
endpoint = cli_endpoint.strip().rstrip('/')
|
||||
if not endpoint:
|
||||
return get_endpoint(cn_site=cn_site)
|
||||
if not endpoint.startswith('http://') and not endpoint.startswith(
|
||||
'https://'):
|
||||
endpoint = MODELSCOPE_URL_SCHEME + endpoint
|
||||
return endpoint
|
||||
|
||||
|
||||
def compute_hash(file_path):
|
||||
# 16MB buffer for large file hash computation
|
||||
BUFFER_SIZE = 1024 * 1024 * 16
|
||||
|
||||
Reference in New Issue
Block a user