support endpoint fallback

Yingda Chen
2025-03-04 12:34:11 +08:00
parent 352336a8a2
commit a4c06da3c2
5 changed files with 267 additions and 110 deletions

View File

@@ -34,10 +34,12 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                       API_RESPONSE_FIELD_USERNAME,
                                       DEFAULT_CREDENTIALS_PATH,
                                       DEFAULT_MAX_WORKERS,
-                                      DEFAULT_MODELSCOPE_DOMAIN,
                                       MODELSCOPE_CLOUD_ENVIRONMENT,
                                       MODELSCOPE_CLOUD_USERNAME,
-                                      MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
+                                      MODELSCOPE_DOMAIN,
+                                      MODELSCOPE_PREFER_INTL,
+                                      MODELSCOPE_REQUEST_ID,
+                                      MODELSCOPE_URL_SCHEME, ONE_YEAR_SECONDS,
                                       REQUESTS_API_HTTP_METHOD,
                                       TEMPORARY_FOLDER_NAME, DatasetVisibility,
                                       Licenses, ModelVisibility, Visibility,
@@ -50,9 +52,9 @@ from modelscope.hub.errors import (InvalidParameter, NotExistError,
                                    raise_for_http_status, raise_on_error)
 from modelscope.hub.git import GitCommandWrapper
 from modelscope.hub.repository import Repository
-from modelscope.hub.utils.utils import (add_content_to_file, get_endpoint,
-                                        get_readable_folder_size,
-                                        get_release_datetime,
+from modelscope.hub.utils.utils import (add_content_to_file, get_domain,
+                                        get_endpoint, get_readable_folder_size,
+                                        get_release_datetime, is_env_true,
                                         model_id_to_group_owner_name)
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
@@ -118,14 +120,14 @@ class HubApi:
         jar = RequestsCookieJar()
         jar.set('m_session_id',
                 access_token,
-                domain=os.getenv('MODELSCOPE_DOMAIN',
-                                 DEFAULT_MODELSCOPE_DOMAIN),
+                domain=get_domain(),
                 path='/')
         return jar

     def login(
             self,
-            access_token: Optional[str] = None
+            access_token: Optional[str] = None,
+            endpoint: Optional[str] = None
     ):
         """Login with your SDK access token, which can be obtained from
         https://www.modelscope.cn user center.
@@ -133,6 +135,7 @@ class HubApi:
         Args:
             access_token (str): user access token on modelscope, set this argument or set `MODELSCOPE_API_TOKEN`.
                 If neither of the tokens exist, login will directly return.
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             cookies: to authenticate yourself to ModelScope open-api
@@ -145,7 +148,9 @@ class HubApi:
             access_token = os.environ.get('MODELSCOPE_API_TOKEN')
         if not access_token:
             return None, None
-        path = f'{self.endpoint}/api/v1/login'
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/login'
         r = self.session.post(
             path,
             json={'AccessToken': access_token},
@@ -172,7 +177,8 @@ class HubApi:
                      visibility: Optional[int] = ModelVisibility.PUBLIC,
                      license: Optional[str] = Licenses.APACHE_V2,
                      chinese_name: Optional[str] = None,
-                     original_model_id: Optional[str] = '') -> str:
+                     original_model_id: Optional[str] = '',
+                     endpoint: Optional[str] = None) -> str:
         """Create model repo at ModelScope Hub.

         Args:
@@ -181,6 +187,7 @@ class HubApi:
             license (str, optional): license of the model, default none.
             chinese_name (str, optional): chinese name of the model.
             original_model_id (str, optional): the base model id which this model is trained from
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             Name of the model created
@@ -197,8 +204,9 @@ class HubApi:
         cookies = ModelScopeConfig.get_cookies()
         if cookies is None:
             raise ValueError('Token does not exist, please login first.')
-        path = f'{self.endpoint}/api/v1/models'
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/models'
         owner_or_group, name = model_id_to_group_owner_name(model_id)
         body = {
             'Path': owner_or_group,
@@ -216,14 +224,15 @@ class HubApi:
             headers=self.builder_headers(self.headers))
         handle_http_post_error(r, path, body)
         raise_on_error(r.json())
-        model_repo_url = f'{self.endpoint}/{model_id}'
+        model_repo_url = f'{endpoint}/{model_id}'
         return model_repo_url

-    def delete_model(self, model_id: str):
+    def delete_model(self, model_id: str, endpoint: Optional[str] = None):
         """Delete model_id from ModelScope.

         Args:
             model_id (str): The model id.
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Raises:
             ValueError: If not login.
@@ -232,9 +241,11 @@ class HubApi:
             model_id = {owner}/{name}
         """
         cookies = ModelScopeConfig.get_cookies()
+        if not endpoint:
+            endpoint = self.endpoint
         if cookies is None:
             raise ValueError('Token does not exist, please login first.')
-        path = f'{self.endpoint}/api/v1/models/{model_id}'
+        path = f'{endpoint}/api/v1/models/{model_id}'
         r = self.session.delete(path,
                                 cookies=cookies,
@@ -242,19 +253,23 @@ class HubApi:
         raise_for_http_status(r)
         raise_on_error(r.json())

-    def get_model_url(self, model_id: str):
-        return f'{self.endpoint}/api/v1/models/{model_id}.git'
+    def get_model_url(self, model_id: str, endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
+        return f'{endpoint}/api/v1/models/{model_id}.git'

     def get_model(
             self,
             model_id: str,
             revision: Optional[str] = DEFAULT_MODEL_REVISION,
+            endpoint: Optional[str] = None
     ) -> str:
         """Get model information at ModelScope

         Args:
             model_id (str): The model id.
             revision (str optional): revision of model.
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             The model detail information.
@@ -267,10 +282,13 @@ class HubApi:
         """
         cookies = ModelScopeConfig.get_cookies()
         owner_or_group, name = model_id_to_group_owner_name(model_id)
+        if not endpoint:
+            endpoint = self.endpoint
         if revision:
-            path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
+            path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
         else:
-            path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'
+            path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}'
         r = self.session.get(path, cookies=cookies,
                              headers=self.builder_headers(self.headers))
@@ -283,11 +301,47 @@ class HubApi:
         else:
             raise_for_http_status(r)

+    def get_endpoint_for_read(self,
+                              repo_id: str,
+                              *,
+                              repo_type: Optional[str] = None) -> str:
+        """Get proper endpoint for read operation (such as download, list etc.)
+
+        1. If user has set MODELSCOPE_DOMAIN, construct endpoint with user-specified domain.
+           If the repo does not exist on that endpoint, throw 404 error, otherwise return the endpoint.
+        2. If domain is not set, check existence of repo in cn-site and intl-site respectively.
+           Checking order is determined by MODELSCOPE_PREFER_INTL.
+           a. if MODELSCOPE_PREFER_INTL is not set ,check cn-site first before intl-site
+           b. otherwise check intl-site before cn-site
+           return the endpoint with which the given repo_id exists.
+           if neither exists, throw 404 error
+        """
+        if MODELSCOPE_DOMAIN in os.environ:
+            endpoint = MODELSCOPE_URL_SCHEME + os.getenv(MODELSCOPE_DOMAIN)
+            if not self.repo_exists(repo_id=repo_id, repo_type=repo_type, endpoint=endpoint):
+                raise NotExistError(f'Repo {repo_id} not exists on {endpoint}')
+            else:
+                return endpoint
+        check_cn_first = not is_env_true(MODELSCOPE_PREFER_INTL)
+        prefer_endpoint = get_endpoint(cn_site=check_cn_first)
+        if not self.repo_exists(
+                repo_id, repo_type=repo_type, endpoint=prefer_endpoint):
+            logger.warning(f'Repo {repo_id} not exists on {prefer_endpoint}, will try on alternative endpoint.')
+            alternative_endpoint = get_endpoint(cn_site=(not check_cn_first))
+            if not self.repo_exists(
+                    repo_id, repo_type=repo_type, endpoint=alternative_endpoint):
+                raise NotExistError(f'Repo {repo_id} not exists on either {prefer_endpoint} or {alternative_endpoint}')
+            else:
+                return alternative_endpoint
+        else:
+            return prefer_endpoint
+
     def repo_exists(
             self,
             repo_id: str,
             *,
             repo_type: Optional[str] = None,
+            endpoint: Optional[str] = None,
     ) -> bool:
         """
         Checks if a repository exists on ModelScope
@@ -299,10 +353,14 @@ class HubApi:
             repo_type (`str`, *optional*):
                 `None` or `"model"` if getting repository info from a model. Default is `None`.
                 TODO: support dataset and studio
+            endpoint(`str`):
+                None or specific endpoint to use, when None, use the default endpoint
+                set in HubApi class (self.endpoint)

         Returns:
             True if the repository exists, False otherwise.
         """
+        if endpoint is None:
+            endpoint = self.endpoint
         if (repo_type is not None) and repo_type.lower() != REPO_TYPE_MODEL:
             raise Exception('Not support repo-type: %s' % repo_type)
         if (repo_id is None) or repo_id.count('/') != 1:
@@ -310,7 +368,7 @@ class HubApi:
         cookies = ModelScopeConfig.get_cookies()
         owner_or_group, name = model_id_to_group_owner_name(repo_id)
-        path = f'{self.endpoint}/api/v1/models/{owner_or_group}/{name}'
+        path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}'
         r = self.session.get(path, cookies=cookies,
                              headers=self.builder_headers(self.headers))
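The read-path fallback introduced by get_endpoint_for_read above can be exercised directly from the SDK. A minimal sketch, assuming a repo id that actually exists on the hub (the id below is only a placeholder):

import os
from modelscope.hub.api import HubApi

api = HubApi()
repo_id = 'some-namespace/some-model'  # placeholder repo id, for illustration only

# No MODELSCOPE_DOMAIN set: www.modelscope.cn is probed first, www.modelscope.ai second.
print(api.get_endpoint_for_read(repo_id=repo_id, repo_type='model'))

# Reverse the probing order.
os.environ['MODELSCOPE_PREFER_INTL'] = 'true'
print(api.get_endpoint_for_read(repo_id=repo_id, repo_type='model'))

# Pin a single domain: no fallback, a missing repo raises NotExistError instead.
os.environ['MODELSCOPE_DOMAIN'] = 'www.modelscope.cn'
print(api.get_endpoint_for_read(repo_id=repo_id, repo_type='model'))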
@@ -476,13 +534,15 @@ class HubApi:
     def list_models(self,
                     owner_or_group: str,
                     page_number: Optional[int] = 1,
-                    page_size: Optional[int] = 10) -> dict:
+                    page_size: Optional[int] = 10,
+                    endpoint: Optional[str] = None) -> dict:
         """List models in owner or group.

         Args:
             owner_or_group(str): owner or group.
             page_number(int, optional): The page number, default: 1
             page_size(int, optional): The page size, default: 10
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Raises:
             RequestError: The request error.
@@ -491,7 +551,9 @@ class HubApi:
             dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group}
         """
         cookies = ModelScopeConfig.get_cookies()
-        path = f'{self.endpoint}/api/v1/models/'
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/models/'
         r = self.session.put(
             path,
             data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
@@ -547,7 +609,8 @@ class HubApi:
             self,
             model_id: str,
             cutoff_timestamp: Optional[int] = None,
-            use_cookies: Union[bool, CookieJar] = False) -> List[str]:
+            use_cookies: Union[bool, CookieJar] = False,
+            endpoint: Optional[str] = None) -> List[str]:
         """Get model branch and tags.

         Args:
@@ -556,6 +619,7 @@ class HubApi:
                 The timestamp is represented by the seconds elapsed from the epoch time.
             use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
                 will load cookie from local. Defaults to False.
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             Tuple[List[str], List[str]]: Return list of branch name and tags
@@ -563,7 +627,9 @@ class HubApi:
         cookies = self._check_cookie(use_cookies)
         if cutoff_timestamp is None:
             cutoff_timestamp = get_release_datetime()
-        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
         r = self.session.get(path, cookies=cookies,
                              headers=self.builder_headers(self.headers))
         handle_http_response(r, logger, cookies, model_id)
@@ -582,14 +648,17 @@ class HubApi:
     def get_valid_revision_detail(self,
                                   model_id: str,
                                   revision=None,
-                                  cookies: Optional[CookieJar] = None):
+                                  cookies: Optional[CookieJar] = None,
+                                  endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
         release_timestamp = get_release_datetime()
         current_timestamp = int(round(datetime.datetime.now().timestamp()))
         # for active development in library codes (non-release-branches), release_timestamp
         # is set to be a far-away-time-in-the-future, to ensure that we shall
         # get the master-HEAD version from model repo by default (when no revision is provided)
         all_branches_detail, all_tags_detail = self.get_model_branches_and_tags_details(
-            model_id, use_cookies=False if cookies is None else cookies)
+            model_id, use_cookies=False if cookies is None else cookies, endpoint=endpoint)
         all_branches = [x['Revision'] for x in all_branches_detail] if all_branches_detail else []
         all_tags = [x['Revision'] for x in all_tags_detail] if all_tags_detail else []
         if release_timestamp > current_timestamp + ONE_YEAR_SECONDS:
@@ -658,6 +727,7 @@ class HubApi:
             self,
             model_id: str,
             use_cookies: Union[bool, CookieJar] = False,
+            endpoint: Optional[str] = None
     ) -> Tuple[List[str], List[str]]:
         """Get model branch and tags.
@@ -665,13 +735,15 @@ class HubApi:
             model_id (str): The model id
             use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
                 will load cookie from local. Defaults to False.
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             Tuple[List[str], List[str]]: Return list of branch name and tags
         """
         cookies = self._check_cookie(use_cookies)
-        path = f'{self.endpoint}/api/v1/models/{model_id}/revisions'
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/models/{model_id}/revisions'
         r = self.session.get(path, cookies=cookies,
                              headers=self.builder_headers(self.headers))
         handle_http_response(r, logger, cookies, model_id)
@@ -709,7 +781,8 @@ class HubApi:
                         root: Optional[str] = None,
                         recursive: Optional[str] = False,
                         use_cookies: Union[bool, CookieJar] = False,
-                        headers: Optional[dict] = {}) -> List[dict]:
+                        headers: Optional[dict] = {},
+                        endpoint: Optional[str] = None) -> List[dict]:
         """List the models files.

         Args:
@@ -720,16 +793,19 @@ class HubApi:
             use_cookies (Union[bool, CookieJar], optional): If is cookieJar, we will use this cookie, if True,
                 will load cookie from local. Defaults to False.
             headers: request headers
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             List[dict]: Model file list.
         """
+        if not endpoint:
+            endpoint = self.endpoint
         if revision:
             path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % (
-                self.endpoint, model_id, revision, recursive)
+                endpoint, model_id, revision, recursive)
         else:
             path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % (
-                self.endpoint, model_id, recursive)
+                endpoint, model_id, recursive)
         cookies = self._check_cookie(use_cookies)
         if root is not None:
             path = path + f'&Root={root}'
@@ -777,7 +853,8 @@ class HubApi:
                        chinese_name: Optional[str] = '',
                        license: Optional[str] = Licenses.APACHE_V2,
                        visibility: Optional[int] = DatasetVisibility.PUBLIC,
-                       description: Optional[str] = '') -> str:
+                       description: Optional[str] = '',
+                       endpoint: Optional[str] = None, ) -> str:

         if dataset_name is None or namespace is None:
             raise InvalidParameter('dataset_name and namespace are required!')
@@ -785,8 +862,9 @@ class HubApi:
         cookies = ModelScopeConfig.get_cookies()
         if cookies is None:
             raise ValueError('Token does not exist, please login first.')
-        path = f'{self.endpoint}/api/v1/datasets'
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/datasets'
         files = {
             'Name': (None, dataset_name),
             'ChineseName': (None, chinese_name),
@@ -805,12 +883,14 @@ class HubApi:
         handle_http_post_error(r, path, files)
         raise_on_error(r.json())

-        dataset_repo_url = f'{self.endpoint}/datasets/{namespace}/{dataset_name}'
+        dataset_repo_url = f'{endpoint}/datasets/{namespace}/{dataset_name}'
         logger.info(f'Create dataset success: {dataset_repo_url}')
         return dataset_repo_url

-    def list_datasets(self):
-        path = f'{self.endpoint}/api/v1/datasets'
+    def list_datasets(self, endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
+        path = f'{endpoint}/api/v1/datasets'
         params = {}
         r = self.session.get(path, params=params,
                              headers=self.builder_headers(self.headers))
@@ -818,9 +898,11 @@ class HubApi:
         dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
         return [x['Name'] for x in dataset_list]

-    def get_dataset_id_and_type(self, dataset_name: str, namespace: str):
+    def get_dataset_id_and_type(self, dataset_name: str, namespace: str, endpoint: Optional[str] = None):
         """ Get the dataset id and type. """
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
         cookies = ModelScopeConfig.get_cookies()
         r = self.session.get(datahub_url, cookies=cookies)
         resp = r.json()
@@ -834,11 +916,14 @@ class HubApi:
                           revision: str,
                           files_metadata: bool = False,
                           timeout: float = 100,
-                          recursive: str = 'True'):
+                          recursive: str = 'True',
+                          endpoint: Optional[str] = None):
         """
         Get dataset infos.
         """
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
         params = {'Revision': revision, 'Root': None, 'Recursive': recursive}
         cookies = ModelScopeConfig.get_cookies()
         if files_metadata:
@@ -856,13 +941,16 @@ class HubApi:
                                root_path: str,
                                recursive: bool = True,
                                page_number: int = 1,
-                               page_size: int = 100):
+                               page_size: int = 100,
+                               endpoint: Optional[str] = None):

         dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
             dataset_name=dataset_name, namespace=namespace)

         recursive = 'True' if recursive else 'False'
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
         params = {'Revision': revision if revision else 'master',
                   'Root': root_path if root_path else '/', 'Recursive': recursive,
                   'PageNumber': page_number, 'PageSize': page_size}
@@ -874,9 +962,12 @@ class HubApi:
         return resp

-    def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, dataset_id: str, revision: str):
+    def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
+                                   dataset_id: str, revision: str, endpoint: Optional[str] = None):
         """ Get the meta file-list of the dataset. """
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
         cookies = ModelScopeConfig.get_cookies()
         r = self.session.get(datahub_url,
                              cookies=cookies,
@@ -908,7 +999,8 @@ class HubApi:
     def get_dataset_meta_files_local_paths(self, dataset_name: str,
                                            namespace: str,
                                            revision: str,
-                                           meta_cache_dir: str, dataset_type: int, file_list: list):
+                                           meta_cache_dir: str, dataset_type: int, file_list: list,
+                                           endpoint: Optional[str] = None):
         local_paths = defaultdict(list)
         dataset_formation = DatasetFormations(dataset_type)
         dataset_meta_format = DatasetMetaFormats[dataset_formation]
@@ -916,12 +1008,13 @@ class HubApi:
         # Dump the data_type as a local file
         HubApi.dump_datatype_file(dataset_type=dataset_type, meta_cache_dir=meta_cache_dir)

+        if not endpoint:
+            endpoint = self.endpoint
         for file_info in file_list:
             file_path = file_info['Path']
             extension = os.path.splitext(file_path)[-1]
             if extension in dataset_meta_format:
-                datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
+                datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                               f'Revision={revision}&FilePath={file_path}'
                 r = self.session.get(datahub_url, cookies=cookies)
                 raise_for_http_status(r)
@@ -1001,7 +1094,8 @@ class HubApi:
                              namespace: str,
                              revision: Optional[str] = DEFAULT_DATASET_REVISION,
                              view: Optional[bool] = False,
-                             extension_filter: Optional[bool] = True):
+                             extension_filter: Optional[bool] = True,
+                             endpoint: Optional[str] = None):

         if not file_name or not dataset_name or not namespace:
             raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')
@@ -1009,7 +1103,9 @@ class HubApi:
             # Note: make sure the FilePath is the last parameter in the url
             params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
             params: str = urlencode(params)
-            file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'
+            if not endpoint:
+                endpoint = self.endpoint
+            file_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'

             return file_url
@@ -1028,9 +1124,12 @@ class HubApi:
                              file_name: str,
                              dataset_name: str,
                              namespace: str,
-                             revision: Optional[str] = DEFAULT_DATASET_REVISION):
+                             revision: Optional[str] = DEFAULT_DATASET_REVISION,
+                             endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
         if file_name and os.path.splitext(file_name)[-1] in META_FILES_FORMAT:
-            file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
+            file_name = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                         f'Revision={revision}&FilePath={file_name}'
         return file_name
@@ -1038,8 +1137,11 @@ class HubApi:
             self,
             dataset_name: str,
             namespace: str,
-            revision: Optional[str] = DEFAULT_DATASET_REVISION):
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
+            revision: Optional[str] = DEFAULT_DATASET_REVISION,
+            endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                       f'ststoken?Revision={revision}'
         return self.datahub_remote_call(datahub_url)
@@ -1048,9 +1150,12 @@ class HubApi:
             dataset_name: str,
             namespace: str,
             check_cookie: bool,
-            revision: Optional[str] = DEFAULT_DATASET_REVISION):
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
+            revision: Optional[str] = DEFAULT_DATASET_REVISION,
+            endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                       f'ststoken?Revision={revision}'
         if check_cookie:
             cookies = self._check_cookie(use_cookies=True)
@@ -1098,8 +1203,11 @@ class HubApi:
                                  dataset_name: str,
                                  namespace: str,
                                  revision: str,
-                                 zip_file_name: str):
-        datahub_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
+                                 zip_file_name: str,
+                                 endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
+        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
         cookies = ModelScopeConfig.get_cookies()
         r = self.session.get(url=datahub_url, cookies=cookies,
                              headers=self.builder_headers(self.headers))
@@ -1120,8 +1228,10 @@ class HubApi:
         return data_sts

     def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
-                                 is_recursive, is_filter_dir, revision):
-        url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
+                                 is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None):
+        if not endpoint:
+            endpoint = self.endpoint
+        url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
               f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
         cookies = ModelScopeConfig.get_cookies()
@@ -1132,11 +1242,12 @@ class HubApi:
         return resp

     def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
-                                  namespace: str, revision: str) -> str:
+                                  namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
         if not object_name or not dataset_name or not namespace or not revision:
             raise ValueError('Args cannot be empty!')
-        url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'
+        if not endpoint:
+            endpoint = self.endpoint
+        url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'
         cookies = ModelScopeConfig.get_cookies()
         resp = self.session.delete(url=url, cookies=cookies)
@@ -1146,11 +1257,12 @@ class HubApi:
         return resp

     def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
-                               namespace: str, revision: str) -> str:
+                               namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
         if not object_name or not dataset_name or not namespace or not revision:
             raise ValueError('Args cannot be empty!')
-        url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
+        if not endpoint:
+            endpoint = self.endpoint
+        url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
               f'&Revision={revision}'
         cookies = ModelScopeConfig.get_cookies()
@@ -1170,14 +1282,17 @@ class HubApi:
         datahub_raise_on_error(url, resp, r)
         return resp['Data']

-    def dataset_download_statistics(self, dataset_name: str, namespace: str, use_streaming: bool = False) -> None:
+    def dataset_download_statistics(self, dataset_name: str, namespace: str,
+                                    use_streaming: bool = False, endpoint: Optional[str] = None) -> None:
         is_ci_test = os.getenv('CI_TEST') == 'True'
+        if not endpoint:
+            endpoint = self.endpoint
         if dataset_name and namespace and not is_ci_test and not use_streaming:
             try:
                 cookies = ModelScopeConfig.get_cookies()
                 # Download count
-                download_count_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
+                download_count_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
                 download_count_resp = self.session.post(download_count_url, cookies=cookies,
                                                         headers=self.builder_headers(self.headers))
                 raise_for_http_status(download_count_resp)
@@ -1189,7 +1304,7 @@ class HubApi:
                     channel = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
                 if MODELSCOPE_CLOUD_USERNAME in os.environ:
                     user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
-                download_uv_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
+                download_uv_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
                                   f'{channel}?user={user_name}'
                 download_uv_resp = self.session.post(download_uv_url, cookies=cookies,
                                                      headers=self.builder_headers(self.headers))
@@ -1203,9 +1318,11 @@ class HubApi:
         return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
                 **headers}

-    def get_file_base_path(self, repo_id: str) -> str:
+    def get_file_base_path(self, repo_id: str, endpoint: Optional[str] = None) -> str:
         _namespace, _dataset_name = repo_id.split('/')
-        return f'{self.endpoint}/api/v1/datasets/{_namespace}/{_dataset_name}/repo?'
+        if not endpoint:
+            endpoint = self.endpoint
+        return f'{endpoint}/api/v1/datasets/{_namespace}/{_dataset_name}/repo?'
         # return f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Revision={revision}&FilePath='

     def create_repo(
@@ -1217,13 +1334,15 @@ class HubApi:
             repo_type: Optional[str] = REPO_TYPE_MODEL,
             chinese_name: Optional[str] = '',
             license: Optional[str] = Licenses.APACHE_V2,
+            endpoint: Optional[str] = None,
             **kwargs,
     ) -> str:
         # TODO: exist_ok
         if not repo_id:
             raise ValueError('Repo id cannot be empty!')
+        if not endpoint:
+            endpoint = self.endpoint
         self.login(access_token=token)

         repo_id_list = repo_id.split('/')
@@ -1261,7 +1380,7 @@ class HubApi:
                     'configuration.json', [json.dumps(config)],
                     ignore_push_error=True)
             else:
-                repo_url = f'{self.endpoint}/{repo_id}'
+                repo_url = f'{endpoint}/{repo_id}'
         elif repo_type == REPO_TYPE_DATASET:
             visibilities = {k: v for k, v in DatasetVisibility.__dict__.items() if not k.startswith('__')}
@@ -1278,7 +1397,7 @@ class HubApi:
                     visibility=visibility,
                 )
             else:
-                repo_url = f'{self.endpoint}/datasets/{namespace}/{repo_name}'
+                repo_url = f'{endpoint}/datasets/{namespace}/{repo_name}'
         else:
             raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
@@ -1295,9 +1414,12 @@ class HubApi:
             token: str = None,
             repo_type: Optional[str] = None,
             revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
+            endpoint: Optional[str] = None
     ) -> CommitInfo:
-        url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'
+        if not endpoint:
+            endpoint = self.endpoint
+        url = f'{endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'

         commit_message = commit_message or f'Commit to {repo_id}'
         commit_description = commit_description or ''
@@ -1640,6 +1762,7 @@ class HubApi:
             repo_id: str,
             repo_type: str,
             objects: List[Dict[str, Any]],
+            endpoint: Optional[str] = None
     ) -> List[Dict[str, Any]]:
         """
         Check the blob has already uploaded.
@@ -1651,13 +1774,16 @@ class HubApi:
             objects (List[Dict[str, Any]]): The objects to check.
                 oid (str): The sha256 hash value.
                 size (int): The size of the blob.
+            endpoint: the endpoint to use, default to None to use endpoint specified in the class

         Returns:
             List[Dict[str, Any]]: The result of the check.
         """
         # construct URL
-        url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'
+        if not endpoint:
+            endpoint = self.endpoint
+        url = f'{endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'

         # build payload
         payload = {
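Every method touched in this file follows the same pattern: an optional endpoint argument that falls back to self.endpoint when omitted. A hedged sketch of calling the write path against the intl site (the token and repo ids below are placeholders):

from modelscope.hub.api import HubApi
from modelscope.hub.constants import DEFAULT_MODELSCOPE_INTL_DATA_ENDPOINT

api = HubApi()
api.login(access_token='YOUR_SDK_TOKEN',  # placeholder token
          endpoint=DEFAULT_MODELSCOPE_INTL_DATA_ENDPOINT)
url = api.create_model('your-namespace/your-model',  # placeholder repo id
                       endpoint=DEFAULT_MODELSCOPE_INTL_DATA_ENDPOINT)
info = api.get_model('your-namespace/your-model')  # endpoint omitted: falls back to api.endpoint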

View File

@@ -4,7 +4,9 @@ from pathlib import Path
 MODELSCOPE_URL_SCHEME = 'https://'
 DEFAULT_MODELSCOPE_DOMAIN = 'www.modelscope.cn'
+DEFAULT_MODELSCOPE_INTL_DOMAIN = 'www.modelscope.ai'
 DEFAULT_MODELSCOPE_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN
+DEFAULT_MODELSCOPE_INTL_DATA_ENDPOINT = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_INTL_DOMAIN

 MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB = int(
     os.environ.get('MODELSCOPE_PARALLEL_DOWNLOAD_THRESHOLD_MB', 500))
 MODELSCOPE_DOWNLOAD_PARALLELS = int(
@@ -28,6 +30,8 @@ API_RESPONSE_FIELD_MESSAGE = 'Message'
 MODELSCOPE_CLOUD_ENVIRONMENT = 'MODELSCOPE_ENVIRONMENT'
 MODELSCOPE_CLOUD_USERNAME = 'MODELSCOPE_USERNAME'
 MODELSCOPE_SDK_DEBUG = 'MODELSCOPE_SDK_DEBUG'
+MODELSCOPE_PREFER_INTL = 'MODELSCOPE_PREFER_INTL'
+MODELSCOPE_DOMAIN = 'MODELSCOPE_DOMAIN'
 MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION = 'MODELSCOPE_ENABLE_DEFAULT_HASH_VALIDATION'
 ONE_YEAR_SECONDS = 24 * 365 * 60 * 60
 MODELSCOPE_REQUEST_ID = 'X-Request-ID'
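Note that MODELSCOPE_PREFER_INTL and MODELSCOPE_DOMAIN hold environment-variable names, not values, while the two domain constants are what the fallback composes into endpoints. A small sketch of how they fit together:

import os
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
                                      DEFAULT_MODELSCOPE_INTL_DOMAIN,
                                      MODELSCOPE_DOMAIN, MODELSCOPE_PREFER_INTL,
                                      MODELSCOPE_URL_SCHEME)

cn_endpoint = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_DOMAIN         # https://www.modelscope.cn
intl_endpoint = MODELSCOPE_URL_SCHEME + DEFAULT_MODELSCOPE_INTL_DOMAIN  # https://www.modelscope.ai

os.environ[MODELSCOPE_PREFER_INTL] = 'true'          # probe the intl site first
os.environ[MODELSCOPE_DOMAIN] = 'www.modelscope.cn'  # or pin one domain outright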

View File

@@ -199,17 +199,19 @@ def _repo_file_download(
     if cookies is None:
         cookies = ModelScopeConfig.get_cookies()
     repo_files = []
+    endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type)
     file_to_download_meta = None
     if repo_type == REPO_TYPE_MODEL:
         revision = _api.get_valid_revision(
-            repo_id, revision=revision, cookies=cookies)
+            repo_id, revision=revision, cookies=cookies, endpoint=endpoint)
         # we need to confirm the version is up-to-date
         # we need to get the file list to check if the latest version is cached, if so return, otherwise download
         repo_files = _api.get_model_files(
             model_id=repo_id,
             revision=revision,
             recursive=True,
-            use_cookies=False if cookies is None else cookies)
+            use_cookies=False if cookies is None else cookies,
+            endpoint=endpoint)
         for repo_file in repo_files:
             if repo_file['Type'] == 'tree':
                 continue
@@ -238,7 +240,8 @@ def _repo_file_download(
             root_path='/',
             recursive=True,
             page_number=page_number,
-            page_size=page_size)
+            page_size=page_size,
+            endpoint=endpoint)
         if not ('Code' in files_list_tree
                 and files_list_tree['Code'] == 200):
             print(
@@ -273,13 +276,15 @@ def _repo_file_download(
     # we need to download again
     if repo_type == REPO_TYPE_MODEL:
-        url_to_download = get_file_download_url(repo_id, file_path, revision)
+        url_to_download = get_file_download_url(repo_id, file_path, revision,
+                                                endpoint)
     elif repo_type == REPO_TYPE_DATASET:
         url_to_download = _api.get_dataset_file_url(
             file_name=file_to_download_meta['Path'],
             dataset_name=name,
             namespace=group_or_owner,
-            revision=revision)
+            revision=revision,
+            endpoint=endpoint)
     else:
         raise ValueError(f'Invalid repo type {repo_type}')
@@ -354,7 +359,8 @@ def create_temporary_directory_and_cache(model_id: str,
     return temporary_cache_dir, cache

-def get_file_download_url(model_id: str, file_path: str, revision: str):
+def get_file_download_url(model_id: str, file_path: str, revision: str,
+                          endpoint: str):
     """Format file download url according to `model_id`, `revision` and `file_path`.
     e.g., Given `model_id=john/bert`, `revision=master`, `file_path=README.md`,
     the resulted download url is: https://modelscope.cn/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md
@@ -363,6 +369,7 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
         model_id (str): The model_id.
         file_path (str): File path
         revision (str): File revision.
+        endpoint (str): The remote endpoint

     Returns:
         str: The file url.
@@ -370,8 +377,10 @@ def get_file_download_url(model_id: str, file_path: str, revision: str):
     file_path = urllib.parse.quote_plus(file_path)
     revision = urllib.parse.quote_plus(revision)
     download_url_template = '{endpoint}/api/v1/models/{model_id}/repo?Revision={revision}&FilePath={file_path}'
+    if not endpoint:
+        endpoint = get_endpoint()
     return download_url_template.format(
-        endpoint=get_endpoint(),
+        endpoint=endpoint,
         model_id=model_id,
         revision=revision,
         file_path=file_path,
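get_file_download_url now takes the endpoint explicitly. Reusing the john/bert example from the docstring above, and assuming the helpers keep their current module paths, a caller that wants the intl site would pass the endpoint returned by get_endpoint(cn_site=False):

from modelscope.hub.file_download import get_file_download_url
from modelscope.hub.utils.utils import get_endpoint

url = get_file_download_url('john/bert', 'README.md', 'master',
                            get_endpoint(cn_site=False))
# https://www.modelscope.ai/api/v1/models/john/bert/repo?Revision=master&FilePath=README.md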

View File

@@ -243,6 +243,8 @@ def _snapshot_download(
         # To count the download statistics, to add the snapshot-identifier as a header.
         headers['snapshot-identifier'] = str(uuid.uuid4())
         _api = HubApi()
+        endpoint = _api.get_endpoint_for_read(
+            repo_id=repo_id, repo_type=repo_type)
         if cookies is None:
             cookies = ModelScopeConfig.get_cookies()
         if repo_type == REPO_TYPE_MODEL:
@@ -255,7 +257,7 @@ def _snapshot_download(
                 *repo_id.split('/'))
             print(f'Downloading Model to directory: {directory}')
             revision_detail = _api.get_valid_revision_detail(
-                repo_id, revision=revision, cookies=cookies)
+                repo_id, revision=revision, cookies=cookies, endpoint=endpoint)
             revision = revision_detail['Revision']

             snapshot_header = headers if 'CI_TEST' in os.environ else {
@@ -274,7 +276,7 @@ def _snapshot_download(
                 recursive=True,
                 use_cookies=False if cookies is None else cookies,
                 headers=snapshot_header,
-            )
+                endpoint=endpoint)

             _download_file_lists(
                 repo_files,
                 cache,
@@ -324,7 +326,7 @@ def _snapshot_download(
             logger.info('Fetching dataset repo file list...')
             repo_files = fetch_repo_files(_api, name, group_or_owner,
-                                          revision_detail)
+                                          revision_detail, endpoint)

             if repo_files is None:
                 logger.error(
@@ -354,7 +356,7 @@ def _snapshot_download(
     return cache_root_path

-def fetch_repo_files(_api, name, group_or_owner, revision):
+def fetch_repo_files(_api, name, group_or_owner, revision, endpoint):
     page_number = 1
     page_size = 150
     repo_files = []
@@ -367,7 +369,8 @@ def fetch_repo_files(_api, name, group_or_owner, revision):
             root_path='/',
             recursive=True,
             page_number=page_number,
-            page_size=page_size)
+            page_size=page_size,
+            endpoint=endpoint)

         if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
             logger.error(f'Get dataset file list failed, request_id: \
@@ -416,22 +419,23 @@ def _get_valid_regex_pattern(patterns: List[str]):

 def _download_file_lists(
         repo_files: List[str],
         cache: ModelFileSystemCache,
         temporary_cache_dir: str,
         repo_id: str,
         api: HubApi,
         name: str,
         group_or_owner: str,
         headers,
         repo_type: Optional[str] = None,
         revision: Optional[str] = DEFAULT_MODEL_REVISION,
         cookies: Optional[CookieJar] = None,
         ignore_file_pattern: Optional[Union[str, List[str]]] = None,
         allow_file_pattern: Optional[Union[str, List[str]]] = None,
         allow_patterns: Optional[Union[List[str], str]] = None,
         ignore_patterns: Optional[Union[List[str], str]] = None,
-        max_workers: int = 8):
+        max_workers: int = 8,
+):
     ignore_patterns = _normalize_patterns(ignore_patterns)
     allow_patterns = _normalize_patterns(allow_patterns)
     ignore_file_pattern = _normalize_patterns(ignore_file_pattern)
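For snapshot downloads the endpoint is resolved once per call via get_endpoint_for_read and then threaded through revision lookup and file listing, so end users only toggle the environment variable. A minimal sketch (the repo id is a placeholder):

import os
from modelscope.hub.snapshot_download import snapshot_download

os.environ['MODELSCOPE_PREFER_INTL'] = 'true'  # probe www.modelscope.ai first, fall back to www.modelscope.cn
local_dir = snapshot_download('your-namespace/your-model')  # placeholder repo id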

View File

@@ -8,7 +8,9 @@ from typing import List, Optional, Union
from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN, from modelscope.hub.constants import (DEFAULT_MODELSCOPE_DOMAIN,
DEFAULT_MODELSCOPE_GROUP, DEFAULT_MODELSCOPE_GROUP,
MODEL_ID_SEPARATOR, MODELSCOPE_SDK_DEBUG, DEFAULT_MODELSCOPE_INTL_DOMAIN,
MODEL_ID_SEPARATOR, MODELSCOPE_DOMAIN,
MODELSCOPE_SDK_DEBUG,
MODELSCOPE_URL_SCHEME) MODELSCOPE_URL_SCHEME)
from modelscope.hub.errors import FileIntegrityError from modelscope.hub.errors import FileIntegrityError
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
@@ -26,6 +28,20 @@ def model_id_to_group_owner_name(model_id):
return group_or_owner, name return group_or_owner, name
def is_env_true(var_name):
value = os.environ.get(var_name, '').strip().lower()
return value == 'true'
def get_domain(cn_site=True):
if MODELSCOPE_DOMAIN in os.environ:
return os.getenv(MODELSCOPE_DOMAIN)
if cn_site:
return DEFAULT_MODELSCOPE_DOMAIN
else:
return DEFAULT_MODELSCOPE_INTL_DOMAIN
def convert_patterns(raw_input: Union[str, List[str]]): def convert_patterns(raw_input: Union[str, List[str]]):
output = None output = None
if isinstance(raw_input, str): if isinstance(raw_input, str):
@@ -105,10 +121,8 @@ def get_release_datetime():
return rt return rt
def get_endpoint(): def get_endpoint(cn_site=True):
modelscope_domain = os.getenv('MODELSCOPE_DOMAIN', return MODELSCOPE_URL_SCHEME + get_domain(cn_site)
DEFAULT_MODELSCOPE_DOMAIN)
return MODELSCOPE_URL_SCHEME + modelscope_domain
def compute_hash(file_path): def compute_hash(file_path):
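The new helpers keep the old behaviour when MODELSCOPE_DOMAIN is set and only differ in which default domain they pick. A quick sketch of the resolution order (the mirror domain below is hypothetical):

import os
from modelscope.hub.utils.utils import get_domain, get_endpoint, is_env_true

os.environ.pop('MODELSCOPE_DOMAIN', None)
print(get_endpoint())               # https://www.modelscope.cn (cn_site defaults to True)
print(get_endpoint(cn_site=False))  # https://www.modelscope.ai

os.environ['MODELSCOPE_DOMAIN'] = 'mirror.example.com'  # hypothetical mirror
print(get_domain(cn_site=False))    # 'mirror.example.com': an explicit domain always wins
print(is_env_true('MODELSCOPE_PREFER_INTL'))  # True only if the variable is set to 'true' (case-insensitive)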