update get cookie

This commit is contained in:
Yunnglin
2025-12-15 20:31:17 +08:00
parent 95aaca3421
commit 5a9e5485f0
9 changed files with 161 additions and 100 deletions

View File

@@ -102,13 +102,15 @@ class HubApi:
def __init__(self, def __init__(self,
endpoint: Optional[str] = None, endpoint: Optional[str] = None,
timeout=API_HTTP_CLIENT_TIMEOUT, timeout=API_HTTP_CLIENT_TIMEOUT,
max_retries=API_HTTP_CLIENT_MAX_RETRIES): max_retries=API_HTTP_CLIENT_MAX_RETRIES,
token: Optional[str] = None):
"""The ModelScope HubApi。 """The ModelScope HubApi。
Args: Args:
endpoint (str, optional): The modelscope server http|https address. Defaults to None. endpoint (str, optional): The modelscope server http|https address. Defaults to None.
""" """
self.endpoint = endpoint if endpoint is not None else get_endpoint() self.endpoint = endpoint if endpoint is not None else get_endpoint()
self.token = token
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
self.session = Session() self.session = Session()
retry = Retry( retry = Retry(
@@ -154,12 +156,12 @@ class HubApi:
path='/') path='/')
return jar return jar
def get_cookies(self, access_token, cookies_required: Optional[bool] = False): def get_cookies(self, access_token: Optional[str] = None, cookies_required: Optional[bool] = False):
""" """
Get cookies for authentication from local cache or access_token. Get cookies for authentication from local cache or access_token.
Args: Args:
access_token (str): user access token on ModelScope access_token (Optional[str]): user access token on ModelScope. If not provided, try to get from local cache.
cookies_required (bool): whether to raise error if no cookies found, defaults to `False`. cookies_required (bool): whether to raise error if no cookies found, defaults to `False`.
Returns: Returns:
@@ -168,8 +170,9 @@ class HubApi:
Raises: Raises:
ValueError: If no credentials found and cookies_required is True. ValueError: If no credentials found and cookies_required is True.
""" """
if access_token: token = access_token if access_token is not None else self.token
cookies = self._get_cookies(access_token=access_token) if token:
cookies = self._get_cookies(access_token=token)
else: else:
cookies = ModelScopeConfig.get_cookies() cookies = ModelScopeConfig.get_cookies()
@@ -203,8 +206,7 @@ class HubApi:
Note: Note:
You only have to login once within 30 days. You only have to login once within 30 days.
""" """
if access_token is None: access_token = access_token or self.token or os.environ.get('MODELSCOPE_API_TOKEN')
access_token = os.environ.get('MODELSCOPE_API_TOKEN')
if not access_token: if not access_token:
return None, None return None, None
if not endpoint: if not endpoint:
@@ -423,12 +425,13 @@ class HubApi:
tag_url = f'{endpoint}/models/{model_id}/tags/{tag_name}' tag_url = f'{endpoint}/models/{model_id}/tags/{tag_name}'
return tag_url return tag_url
def delete_model(self, model_id: str, endpoint: Optional[str] = None): def delete_model(self, model_id: str, endpoint: Optional[str] = None, token: Optional[str] = None):
"""Delete model_id from ModelScope. """Delete model_id from ModelScope.
Args: Args:
model_id (str): The model id. model_id (str): The model id.
endpoint: the endpoint to use, default to None to use endpoint specified in the class endpoint: the endpoint to use, default to None to use endpoint specified in the class
token (str, optional): access token for authentication
Raises: Raises:
ValueError: If not login. ValueError: If not login.
@@ -436,7 +439,7 @@ class HubApi:
Note: Note:
model_id = {owner}/{name} model_id = {owner}/{name}
""" """
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
if cookies is None: if cookies is None:
@@ -458,7 +461,8 @@ class HubApi:
self, self,
model_id: str, model_id: str,
revision: Optional[str] = DEFAULT_MODEL_REVISION, revision: Optional[str] = DEFAULT_MODEL_REVISION,
endpoint: Optional[str] = None endpoint: Optional[str] = None,
token: Optional[str] = None,
) -> dict: ) -> dict:
"""Get model information at ModelScope """Get model information at ModelScope
@@ -466,6 +470,7 @@ class HubApi:
model_id (str): The model id. model_id (str): The model id.
revision (str optional): revision of model. revision (str optional): revision of model.
endpoint: the endpoint to use, default to None to use endpoint specified in the class endpoint: the endpoint to use, default to None to use endpoint specified in the class
token (str, optional): access token for authentication
Returns: Returns:
The model detail information. The model detail information.
@@ -476,7 +481,7 @@ class HubApi:
Note: Note:
model_id = {owner}/{name} model_id = {owner}/{name}
""" """
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=False)
owner_or_group, name = model_id_to_group_owner_name(model_id) owner_or_group, name = model_id_to_group_owner_name(model_id)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
@@ -744,7 +749,8 @@ class HubApi:
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
original_model_id: Optional[str] = None, original_model_id: Optional[str] = None,
ignore_file_pattern: Optional[Union[List[str], str]] = None, ignore_file_pattern: Optional[Union[List[str], str]] = None,
lfs_suffix: Optional[Union[str, List[str]]] = None): lfs_suffix: Optional[Union[str, List[str]]] = None,
token: Optional[str] = None):
warnings.warn( warnings.warn(
'This function is deprecated and will be removed in future versions. ' 'This function is deprecated and will be removed in future versions. '
'Please use git command directly or use HubApi().upload_folder instead', 'Please use git command directly or use HubApi().upload_folder instead',
@@ -810,7 +816,7 @@ class HubApi:
f'No {ModelFile.CONFIGURATION} file found in {model_dir}, creating a default one.') f'No {ModelFile.CONFIGURATION} file found in {model_dir}, creating a default one.')
HubApi._create_default_config(model_dir) HubApi._create_default_config(model_dir)
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
if cookies is None: if cookies is None:
raise NotLoginException('Must login before upload!') raise NotLoginException('Must login before upload!')
files_to_save = os.listdir(model_dir) files_to_save = os.listdir(model_dir)
@@ -881,7 +887,8 @@ class HubApi:
owner_or_group: str, owner_or_group: str,
page_number: Optional[int] = 1, page_number: Optional[int] = 1,
page_size: Optional[int] = 10, page_size: Optional[int] = 10,
endpoint: Optional[str] = None) -> dict: endpoint: Optional[str] = None,
token: Optional[str] = None) -> dict:
"""List models in owner or group. """List models in owner or group.
Args: Args:
@@ -889,6 +896,7 @@ class HubApi:
page_number(int, optional): The page number, default: 1 page_number(int, optional): The page number, default: 1
page_size(int, optional): The page size, default: 10 page_size(int, optional): The page size, default: 10
endpoint: the endpoint to use, default to None to use endpoint specified in the class endpoint: the endpoint to use, default to None to use endpoint specified in the class
token (str, optional): access token for authentication
Raises: Raises:
RequestError: The request error. RequestError: The request error.
@@ -896,7 +904,7 @@ class HubApi:
Returns: Returns:
dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group} dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group}
""" """
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=False)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
path = f'{endpoint}/api/v1/models/' path = f'{endpoint}/api/v1/models/'
@@ -925,7 +933,7 @@ class HubApi:
sort: Optional[str] = None, sort: Optional[str] = None,
search: Optional[str] = None, search: Optional[str] = None,
endpoint: Optional[str] = None, endpoint: Optional[str] = None,
) -> dict: token: Optional[str] = None) -> dict:
"""List datasets via OpenAPI with pagination, filtering and sorting. """List datasets via OpenAPI with pagination, filtering and sorting.
Args: Args:
@@ -937,6 +945,7 @@ class HubApi:
search (str, optional): Search by substring keywords in the dataset's Chinese name, search (str, optional): Search by substring keywords in the dataset's Chinese name,
English name, and authors (including organizations and individuals). English name, and authors (including organizations and individuals).
endpoint (str, optional): Hub endpoint to use. When None, use the endpoint specified in the class. endpoint (str, optional): Hub endpoint to use. When None, use the endpoint specified in the class.
token (str, optional): Access token for authentication.
Returns: Returns:
dict: The OpenAPI data payload, e.g. dict: The OpenAPI data payload, e.g.
@@ -966,7 +975,7 @@ class HubApi:
if owner_or_group: if owner_or_group:
params['author'] = owner_or_group params['author'] = owner_or_group
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=False)
headers = self.builder_headers(self.headers) headers = self.builder_headers(self.headers)
r = self.session.get( r = self.session.get(
@@ -991,9 +1000,7 @@ class HubApi:
if isinstance(use_cookies, CookieJar): if isinstance(use_cookies, CookieJar):
cookies = use_cookies cookies = use_cookies
elif isinstance(use_cookies, bool): elif isinstance(use_cookies, bool):
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(cookies_required=use_cookies)
if use_cookies and cookies is None:
raise ValueError('Token does not exist, please login first.')
return cookies return cookies
def list_model_revisions( def list_model_revisions(
@@ -1251,6 +1258,7 @@ class HubApi:
filename: str, filename: str,
*, *,
revision: Optional[str] = None, revision: Optional[str] = None,
token: Optional[str] = None,
): ):
"""Get if the specified file exists """Get if the specified file exists
@@ -1259,10 +1267,11 @@ class HubApi:
filename (`str`): The queried filename, if the file exists in a sub folder, filename (`str`): The queried filename, if the file exists in a sub folder,
please pass <sub-folder-name>/<file-name> please pass <sub-folder-name>/<file-name>
revision (`Optional[str]`): The repo revision revision (`Optional[str]`): The repo revision
token (`Optional[str]`): The access token
Returns: Returns:
The query result in bool value The query result in bool value
""" """
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
files = self.get_model_files( files = self.get_model_files(
repo_id, repo_id,
recursive=True, recursive=True,
@@ -1279,14 +1288,13 @@ class HubApi:
license: Optional[str] = Licenses.APACHE_V2, license: Optional[str] = Licenses.APACHE_V2,
visibility: Optional[int] = DatasetVisibility.PUBLIC, visibility: Optional[int] = DatasetVisibility.PUBLIC,
description: Optional[str] = '', description: Optional[str] = '',
endpoint: Optional[str] = None, ) -> str: endpoint: Optional[str] = None,
token: Optional[str] = None) -> str:
if dataset_name is None or namespace is None: if dataset_name is None or namespace is None:
raise InvalidParameter('dataset_name and namespace are required!') raise InvalidParameter('dataset_name and namespace are required!')
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
if cookies is None:
raise ValueError('Token does not exist, please login first.')
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
path = f'{endpoint}/api/v1/datasets' path = f'{endpoint}/api/v1/datasets'
@@ -1312,9 +1320,9 @@ class HubApi:
logger.info(f'Create dataset success: {dataset_repo_url}') logger.info(f'Create dataset success: {dataset_repo_url}')
return dataset_repo_url return dataset_repo_url
def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None): def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None, token: Optional[str] = None):
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
if cookies is None: if cookies is None:
@@ -1327,12 +1335,13 @@ class HubApi:
raise_for_http_status(r) raise_for_http_status(r)
raise_on_error(r.json()) raise_on_error(r.json())
def get_dataset_id_and_type(self, dataset_name: str, namespace: str, endpoint: Optional[str] = None): def get_dataset_id_and_type(self, dataset_name: str, namespace: str, endpoint: Optional[str] = None,
token: Optional[str] = None):
""" Get the dataset id and type. """ """ Get the dataset id and type. """
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}' datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get(datahub_url, cookies=cookies) r = self.session.get(datahub_url, cookies=cookies)
resp = r.json() resp = r.json()
datahub_raise_on_error(datahub_url, resp, r) datahub_raise_on_error(datahub_url, resp, r)
@@ -1348,7 +1357,8 @@ class HubApi:
recursive: bool = True, recursive: bool = True,
page_number: int = 1, page_number: int = 1,
page_size: int = 100, page_size: int = 100,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
""" """
@deprecated: Use `get_dataset_files` instead. @deprecated: Use `get_dataset_files` instead.
""" """
@@ -1356,7 +1366,7 @@ class HubApi:
DeprecationWarning) DeprecationWarning)
dataset_hub_id, dataset_type = self.get_dataset_id_and_type( dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
dataset_name=dataset_name, namespace=namespace, endpoint=endpoint) dataset_name=dataset_name, namespace=namespace, endpoint=endpoint, token=token)
recursive = 'True' if recursive else 'False' recursive = 'True' if recursive else 'False'
if not endpoint: if not endpoint:
@@ -1365,7 +1375,7 @@ class HubApi:
params = {'Revision': revision if revision else 'master', params = {'Revision': revision if revision else 'master',
'Root': root_path if root_path else '/', 'Recursive': recursive, 'Root': root_path if root_path else '/', 'Recursive': recursive,
'PageNumber': page_number, 'PageSize': page_size} 'PageNumber': page_number, 'PageSize': page_size}
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get(datahub_url, params=params, cookies=cookies) r = self.session.get(datahub_url, params=params, cookies=cookies)
resp = r.json() resp = r.json()
@@ -1380,7 +1390,8 @@ class HubApi:
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
page_number: int = 1, page_number: int = 1,
page_size: int = 50, page_size: int = 50,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
""" """
Get the commit history for a repository. Get the commit history for a repository.
@@ -1391,6 +1402,7 @@ class HubApi:
page_number (int): The page number for pagination. Defaults to 1. page_number (int): The page number for pagination. Defaults to 1.
page_size (int): The number of commits per page. Defaults to 50. page_size (int): The number of commits per page. Defaults to 50.
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
token (Optional[str]): The access token.
Returns: Returns:
CommitHistoryResponse: The commit history response. CommitHistoryResponse: The commit history response.
@@ -1420,7 +1432,7 @@ class HubApi:
'PageNumber': page_number, 'PageNumber': page_number,
'PageSize': page_size 'PageSize': page_size
} }
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
try: try:
r = self.session.get(commits_url, params=params, r = self.session.get(commits_url, params=params,
@@ -1443,7 +1455,8 @@ class HubApi:
recursive: bool = True, recursive: bool = True,
page_number: int = 1, page_number: int = 1,
page_size: int = 100, page_size: int = 100,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
""" """
Get the dataset files. Get the dataset files.
@@ -1455,6 +1468,7 @@ class HubApi:
page_number (int): The page number for pagination. Defaults to 1. page_number (int): The page number for pagination. Defaults to 1.
page_size (int): The number of items per page. Defaults to 100. page_size (int): The number of items per page. Defaults to 100.
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
token (Optional[str]): The access token.
Returns: Returns:
List: The response containing the dataset repository tree information. List: The response containing the dataset repository tree information.
@@ -1468,7 +1482,7 @@ class HubApi:
raise ValueError(f'Invalid repo_id: {repo_id} !') raise ValueError(f'Invalid repo_id: {repo_id} !')
dataset_hub_id, dataset_type = self.get_dataset_id_and_type( dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint) dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
@@ -1480,7 +1494,7 @@ class HubApi:
'PageNumber': page_number, 'PageNumber': page_number,
'PageSize': page_size 'PageSize': page_size
} }
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get(datahub_url, params=params, cookies=cookies) r = self.session.get(datahub_url, params=params, cookies=cookies)
resp = r.json() resp = r.json()
@@ -1492,7 +1506,8 @@ class HubApi:
self, self,
dataset_id: str, dataset_id: str,
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
endpoint: Optional[str] = None endpoint: Optional[str] = None,
token: Optional[str] = None
): ):
""" """
Get the dataset information. Get the dataset information.
@@ -1501,11 +1516,12 @@ class HubApi:
dataset_id (str): The dataset id. dataset_id (str): The dataset id.
revision (Optional[str]): The revision of the dataset. revision (Optional[str]): The revision of the dataset.
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
token (Optional[str]): The access token.
Returns: Returns:
dict: The dataset information. dict: The dataset information.
""" """
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
@@ -1522,12 +1538,13 @@ class HubApi:
return resp[API_RESPONSE_FIELD_DATA] return resp[API_RESPONSE_FIELD_DATA]
def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
dataset_id: str, revision: str, endpoint: Optional[str] = None): dataset_id: str, revision: str, endpoint: Optional[str] = None,
token: Optional[str] = None):
""" Get the meta file-list of the dataset. """ """ Get the meta file-list of the dataset. """
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
datahub_url = f'{endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}' datahub_url = f'{endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get(datahub_url, r = self.session.get(datahub_url,
cookies=cookies, cookies=cookies,
headers=self.builder_headers(self.headers)) headers=self.builder_headers(self.headers))
@@ -1559,11 +1576,12 @@ class HubApi:
namespace: str, namespace: str,
revision: str, revision: str,
meta_cache_dir: str, dataset_type: int, file_list: list, meta_cache_dir: str, dataset_type: int, file_list: list,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
local_paths = defaultdict(list) local_paths = defaultdict(list)
dataset_formation = DatasetFormations(dataset_type) dataset_formation = DatasetFormations(dataset_type)
dataset_meta_format = DatasetMetaFormats[dataset_formation] dataset_meta_format = DatasetMetaFormats[dataset_formation]
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
# Dump the data_type as a local file # Dump the data_type as a local file
HubApi.dump_datatype_file(dataset_type=dataset_type, meta_cache_dir=meta_cache_dir) HubApi.dump_datatype_file(dataset_type=dataset_type, meta_cache_dir=meta_cache_dir)
@@ -1591,7 +1609,8 @@ class HubApi:
return local_paths, dataset_formation return local_paths, dataset_formation
@staticmethod @staticmethod
def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS): def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
token: Optional[str] = None):
""" """
Fetch the meta-data files from the url, e.g. csv/jsonl files. Fetch the meta-data files from the url, e.g. csv/jsonl files.
""" """
@@ -1605,7 +1624,7 @@ class HubApi:
if os.path.exists(out_path): if os.path.exists(out_path):
logger.info(f'Reusing cached meta-data file: {out_path}') logger.info(f'Reusing cached meta-data file: {out_path}')
return out_path return out_path
cookies = ModelScopeConfig.get_cookies() cookies = HubApi().get_cookies(access_token=token)
# Make the request and get the response content as TextIO # Make the request and get the response content as TextIO
logger.info('Loading meta-data file ...') logger.info('Loading meta-data file ...')
@@ -1697,12 +1716,13 @@ class HubApi:
dataset_name: str, dataset_name: str,
namespace: str, namespace: str,
revision: Optional[str] = DEFAULT_DATASET_REVISION, revision: Optional[str] = DEFAULT_DATASET_REVISION,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \ datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
f'ststoken?Revision={revision}' f'ststoken?Revision={revision}'
return self.datahub_remote_call(datahub_url) return self.datahub_remote_call(datahub_url, token=token)
def get_dataset_access_config_session( def get_dataset_access_config_session(
self, self,
@@ -1710,7 +1730,8 @@ class HubApi:
namespace: str, namespace: str,
check_cookie: bool, check_cookie: bool,
revision: Optional[str] = DEFAULT_DATASET_REVISION, revision: Optional[str] = DEFAULT_DATASET_REVISION,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
@@ -1719,7 +1740,7 @@ class HubApi:
if check_cookie: if check_cookie:
cookies = self._check_cookie(use_cookies=True) cookies = self._check_cookie(use_cookies=True)
else: else:
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get( r = self.session.get(
url=datahub_url, url=datahub_url,
@@ -1729,7 +1750,7 @@ class HubApi:
raise_on_error(resp) raise_on_error(resp)
return resp['Data'] return resp['Data']
def get_virgo_meta(self, dataset_id: str, version: int = 1) -> dict: def get_virgo_meta(self, dataset_id: str, version: int = 1, token: Optional[str] = None) -> dict:
""" """
Get virgo dataset meta info. Get virgo dataset meta info.
""" """
@@ -1738,7 +1759,7 @@ class HubApi:
raise RuntimeError(f'Virgo endpoint is not set in env: {VirgoDatasetConfig.env_virgo_endpoint}') raise RuntimeError(f'Virgo endpoint is not set in env: {VirgoDatasetConfig.env_virgo_endpoint}')
virgo_dataset_url = f'{virgo_endpoint}/data/set/download' virgo_dataset_url = f'{virgo_endpoint}/data/set/download'
cookies = requests.utils.dict_from_cookiejar(ModelScopeConfig.get_cookies()) cookies = requests.utils.dict_from_cookiejar(self.get_cookies(access_token=token))
dataset_info = dict( dataset_info = dict(
dataSetId=dataset_id, dataSetId=dataset_id,
@@ -1763,11 +1784,12 @@ class HubApi:
namespace: str, namespace: str,
revision: str, revision: str,
zip_file_name: str, zip_file_name: str,
endpoint: Optional[str] = None): endpoint: Optional[str] = None,
token: Optional[str] = None):
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}' datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get(url=datahub_url, cookies=cookies, r = self.session.get(url=datahub_url, cookies=cookies,
headers=self.builder_headers(self.headers)) headers=self.builder_headers(self.headers))
resp = r.json() resp = r.json()
@@ -1787,13 +1809,14 @@ class HubApi:
return data_sts return data_sts
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit, def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None): is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None,
token: Optional[str] = None):
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \ url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}' f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
resp = self.session.get(url=url, cookies=cookies, timeout=1800) resp = self.session.get(url=url, cookies=cookies, timeout=1800)
resp = resp.json() resp = resp.json()
raise_on_error(resp) raise_on_error(resp)
@@ -1801,14 +1824,15 @@ class HubApi:
return resp return resp
def delete_oss_dataset_object(self, object_name: str, dataset_name: str, def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
namespace: str, revision: str, endpoint: Optional[str] = None) -> str: namespace: str, revision: str, endpoint: Optional[str] = None,
token: Optional[str] = None) -> str:
if not object_name or not dataset_name or not namespace or not revision: if not object_name or not dataset_name or not namespace or not revision:
raise ValueError('Args cannot be empty!') raise ValueError('Args cannot be empty!')
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}' url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
resp = self.session.delete(url=url, cookies=cookies) resp = self.session.delete(url=url, cookies=cookies)
resp = resp.json() resp = resp.json()
raise_on_error(resp) raise_on_error(resp)
@@ -1816,7 +1840,8 @@ class HubApi:
return resp return resp
def delete_oss_dataset_dir(self, object_name: str, dataset_name: str, def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
namespace: str, revision: str, endpoint: Optional[str] = None) -> str: namespace: str, revision: str, endpoint: Optional[str] = None,
token: Optional[str] = None) -> str:
if not object_name or not dataset_name or not namespace or not revision: if not object_name or not dataset_name or not namespace or not revision:
raise ValueError('Args cannot be empty!') raise ValueError('Args cannot be empty!')
if not endpoint: if not endpoint:
@@ -1824,15 +1849,15 @@ class HubApi:
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \ url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
f'&Revision={revision}' f'&Revision={revision}'
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
resp = self.session.delete(url=url, cookies=cookies) resp = self.session.delete(url=url, cookies=cookies)
resp = resp.json() resp = resp.json()
raise_on_error(resp) raise_on_http_response(resp, logger, cookies, url)
resp = resp['Message'] resp = resp['Message']
return resp return resp
def datahub_remote_call(self, url): def datahub_remote_call(self, url, token: Optional[str] = None):
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
r = self.session.get( r = self.session.get(
url, url,
cookies=cookies, cookies=cookies,
@@ -1842,13 +1867,14 @@ class HubApi:
return resp['Data'] return resp['Data']
def dataset_download_statistics(self, dataset_name: str, namespace: str, def dataset_download_statistics(self, dataset_name: str, namespace: str,
use_streaming: bool = False, endpoint: Optional[str] = None) -> None: use_streaming: bool = False, endpoint: Optional[str] = None,
token: Optional[str] = None) -> None:
is_ci_test = os.getenv('CI_TEST') == 'True' is_ci_test = os.getenv('CI_TEST') == 'True'
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
if dataset_name and namespace and not is_ci_test and not use_streaming: if dataset_name and namespace and not is_ci_test and not use_streaming:
try: try:
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token)
# Download count # Download count
download_count_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase' download_count_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
@@ -2219,6 +2245,7 @@ class HubApi:
data=path_or_fileobj, data=path_or_fileobj,
disable_tqdm=disable_tqdm, disable_tqdm=disable_tqdm,
tqdm_desc=tqdm_desc, tqdm_desc=tqdm_desc,
token=token,
) )
# Construct commit info and create commit # Construct commit info and create commit
@@ -2370,6 +2397,7 @@ class HubApi:
data=file_path, data=file_path,
disable_tqdm=file_size <= UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD, disable_tqdm=file_size <= UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD,
tqdm_desc='[Uploading ' + file_path_in_repo + ']', tqdm_desc='[Uploading ' + file_path_in_repo + ']',
token=token,
) )
return { return {
@@ -2446,6 +2474,7 @@ class HubApi:
disable_tqdm: Optional[bool] = False, disable_tqdm: Optional[bool] = False,
tqdm_desc: Optional[str] = '[Uploading]', tqdm_desc: Optional[str] = '[Uploading]',
buffer_size_mb: Optional[int] = 1, buffer_size_mb: Optional[int] = 1,
token: Optional[str] = None,
) -> dict: ) -> dict:
res_d: dict = dict( res_d: dict = dict(
@@ -2460,6 +2489,7 @@ class HubApi:
repo_id=repo_id, repo_id=repo_id,
repo_type=repo_type, repo_type=repo_type,
objects=objects, objects=objects,
token=token,
) )
# upload_object: {'url': 'xxx', 'oid': 'xxx'} # upload_object: {'url': 'xxx', 'oid': 'xxx'}
@@ -2470,7 +2500,7 @@ class HubApi:
res_d['is_uploaded'] = True res_d['is_uploaded'] = True
return res_d return res_d
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
cookies = dict(cookies) if cookies else None cookies = dict(cookies) if cookies else None
if cookies is None: if cookies is None:
raise ValueError('Token does not exist, please login first.') raise ValueError('Token does not exist, please login first.')
@@ -2536,7 +2566,8 @@ class HubApi:
repo_id: str, repo_id: str,
repo_type: str, repo_type: str,
objects: List[Dict[str, Any]], objects: List[Dict[str, Any]],
endpoint: Optional[str] = None endpoint: Optional[str] = None,
token: Optional[str] = None,
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
Check the blob has already uploaded. Check the blob has already uploaded.
@@ -2549,6 +2580,7 @@ class HubApi:
oid (str): The sha256 hash value. oid (str): The sha256 hash value.
size (int): The size of the blob. size (int): The size of the blob.
endpoint: the endpoint to use, default to None to use endpoint specified in the class endpoint: the endpoint to use, default to None to use endpoint specified in the class
token (str): The access token.
Returns: Returns:
List[Dict[str, Any]]: The result of the check. List[Dict[str, Any]]: The result of the check.
@@ -2565,9 +2597,7 @@ class HubApi:
'objects': objects, 'objects': objects,
} }
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
if cookies is None:
raise ValueError('Token does not exist, please login first.')
response = requests.post( response = requests.post(
url, url,
headers=self.builder_headers(self.headers), headers=self.builder_headers(self.headers),
@@ -2755,7 +2785,8 @@ class HubApi:
delete_patterns: Union[str, List[str]], delete_patterns: Union[str, List[str]],
*, *,
revision: Optional[str] = DEFAULT_MODEL_REVISION, revision: Optional[str] = DEFAULT_MODEL_REVISION,
endpoint: Optional[str] = None) -> Dict[str, Any]: endpoint: Optional[str] = None,
token: Optional[str] = None) -> Dict[str, Any]:
""" """
Delete files in batch using glob (wildcard) patterns, e.g. '*.py', 'data/*.csv', 'foo*', etc. Delete files in batch using glob (wildcard) patterns, e.g. '*.py', 'data/*.csv', 'foo*', etc.
@@ -2780,6 +2811,7 @@ class HubApi:
delete_patterns (str or List[str]): List of glob patterns, e.g. '*.py', 'data/*.csv', 'foo*' delete_patterns (str or List[str]): List of glob patterns, e.g. '*.py', 'data/*.csv', 'foo*'
revision (str, optional): Branch or tag name revision (str, optional): Branch or tag name
endpoint (str, optional): API endpoint endpoint (str, optional): API endpoint
token (str, optional): Access token
Returns: Returns:
dict: Deletion result dict: Deletion result
""" """
@@ -2790,7 +2822,7 @@ class HubApi:
if isinstance(delete_patterns, str): if isinstance(delete_patterns, str):
delete_patterns = [delete_patterns] delete_patterns = [delete_patterns]
cookies = ModelScopeConfig.get_cookies() cookies = self.get_cookies(access_token=token, cookies_required=True)
if not endpoint: if not endpoint:
endpoint = self.endpoint endpoint = self.endpoint
if cookies is None: if cookies is None:

View File

@@ -33,6 +33,7 @@ def get_model_id_from_cache(model_root_path: str, ) -> str:
def check_local_model_is_latest( def check_local_model_is_latest(
model_root_path: str, model_root_path: str,
user_agent: Optional[Union[Dict, str]] = None, user_agent: Optional[Union[Dict, str]] = None,
token: Optional[str] = None,
): ):
"""Check local model repo is latest. """Check local model repo is latest.
Check local model repo is same as hub latest version. Check local model repo is same as hub latest version.
@@ -45,7 +46,8 @@ def check_local_model_is_latest(
'user-agent': 'user-agent':
ModelScopeConfig.get_user_agent(user_agent=user_agent, ) ModelScopeConfig.get_user_agent(user_agent=user_agent, )
} }
cookies = ModelScopeConfig.get_cookies() _api = HubApi(timeout=20, token=token)
cookies = _api.get_cookies()
snapshot_header = headers if 'CI_TEST' in os.environ else { snapshot_header = headers if 'CI_TEST' in os.environ else {
**headers, **headers,
@@ -53,7 +55,6 @@ def check_local_model_is_latest(
'Snapshot': 'True' 'Snapshot': 'True'
} }
} }
_api = HubApi(timeout=20)
try: try:
_, revisions = _api.get_model_branches_and_tags( _, revisions = _api.get_model_branches_and_tags(
model_id=model_id, use_cookies=cookies) model_id=model_id, use_cookies=cookies)

View File

@@ -7,7 +7,7 @@ import json
import requests import requests
from attrs import asdict, define, field, validators from attrs import asdict, define, field, validators
from modelscope.hub.api import ModelScopeConfig from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA, from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
API_RESPONSE_FIELD_MESSAGE) API_RESPONSE_FIELD_MESSAGE)
from modelscope.hub.errors import (NotLoginException, NotSupportError, from modelscope.hub.errors import (NotLoginException, NotSupportError,
@@ -188,13 +188,11 @@ class ServiceDeployer(object):
"""Facilitate model deployment on to supported service provider(s). """Facilitate model deployment on to supported service provider(s).
""" """
def __init__(self, endpoint=None): def __init__(self, endpoint=None, token: Optional[str] = None):
self.endpoint = endpoint if endpoint is not None else get_endpoint() self.endpoint = endpoint if endpoint is not None else get_endpoint()
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()} self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
self.cookies = ModelScopeConfig.get_cookies() self.cookies = HubApi().get_cookies(
if self.cookies is None: access_token=token, cookies_required=True)
raise NotLoginException(
'Token does not exist, please login with HubApi first.')
# deploy_model # deploy_model
def create(self, model_id: str, revision: str, instance_name: str, def create(self, model_id: str, revision: str, instance_name: str,

View File

@@ -49,6 +49,7 @@ def model_file_download(
local_files_only: Optional[bool] = False, local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None, cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None, local_dir: Optional[str] = None,
token: Optional[str] = None,
) -> Optional[str]: # pragma: no cover ) -> Optional[str]: # pragma: no cover
"""Download from a given URL and cache it if it's not already present in the local cache. """Download from a given URL and cache it if it's not already present in the local cache.
@@ -67,6 +68,7 @@ def model_file_download(
local cached file if it exists. if `False`, download the file anyway even it exists. local cached file if it exists. if `False`, download the file anyway even it exists.
cookies (CookieJar, optional): The cookie of download request. cookies (CookieJar, optional): The cookie of download request.
local_dir (str, optional): Specific local directory path to which the file will be downloaded. local_dir (str, optional): Specific local directory path to which the file will be downloaded.
token (str, optional): The user token.
Returns: Returns:
string: string of local file or if networking is off, last version of string: string of local file or if networking is off, last version of
@@ -95,7 +97,8 @@ def model_file_download(
user_agent=user_agent, user_agent=user_agent,
local_files_only=local_files_only, local_files_only=local_files_only,
cookies=cookies, cookies=cookies,
local_dir=local_dir) local_dir=local_dir,
token=token)
def dataset_file_download( def dataset_file_download(
@@ -107,6 +110,7 @@ def dataset_file_download(
user_agent: Optional[Union[Dict, str]] = None, user_agent: Optional[Union[Dict, str]] = None,
local_files_only: Optional[bool] = False, local_files_only: Optional[bool] = False,
cookies: Optional[CookieJar] = None, cookies: Optional[CookieJar] = None,
token: Optional[str] = None,
) -> str: ) -> str:
"""Download raw files of a dataset. """Download raw files of a dataset.
Downloads all files at the specified revision. This Downloads all files at the specified revision. This
@@ -129,6 +133,7 @@ def dataset_file_download(
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
local cached file if it exists. local cached file if it exists.
cookies (CookieJar, optional): The cookie of the request, default None. cookies (CookieJar, optional): The cookie of the request, default None.
token (str, optional): The user token.
Raises: Raises:
ValueError: the value details. ValueError: the value details.
@@ -153,7 +158,8 @@ def dataset_file_download(
user_agent=user_agent, user_agent=user_agent,
local_files_only=local_files_only, local_files_only=local_files_only,
cookies=cookies, cookies=cookies,
local_dir=local_dir) local_dir=local_dir,
token=token)
def _repo_file_download( def _repo_file_download(
@@ -168,6 +174,7 @@ def _repo_file_download(
cookies: Optional[CookieJar] = None, cookies: Optional[CookieJar] = None,
local_dir: Optional[str] = None, local_dir: Optional[str] = None,
disable_tqdm: bool = False, disable_tqdm: bool = False,
token: Optional[str] = None,
) -> Optional[str]: # pragma: no cover ) -> Optional[str]: # pragma: no cover
if not repo_type: if not repo_type:
@@ -194,7 +201,7 @@ def _repo_file_download(
' traffic has been disabled. To enable look-ups and downloads' ' traffic has been disabled. To enable look-ups and downloads'
" online, set 'local_files_only' to False.") " online, set 'local_files_only' to False.")
_api = HubApi() _api = HubApi(token=token)
headers = { headers = {
'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ), 'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
@@ -212,7 +219,7 @@ def _repo_file_download(
headers['x-aliyun-region-id'] = region_id headers['x-aliyun-region-id'] = region_id
if cookies is None: if cookies is None:
cookies = ModelScopeConfig.get_cookies() cookies = _api.get_cookies()
repo_files = [] repo_files = []
endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type) endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type)
file_to_download_meta = None file_to_download_meta = None

View File

@@ -48,6 +48,7 @@ def snapshot_download(
repo_type: Optional[str] = REPO_TYPE_MODEL, repo_type: Optional[str] = REPO_TYPE_MODEL,
enable_file_lock: Optional[bool] = None, enable_file_lock: Optional[bool] = None,
progress_callbacks: List[Type[ProgressCallback]] = None, progress_callbacks: List[Type[ProgressCallback]] = None,
token: Optional[str] = None,
) -> str: ) -> str:
"""Download all files of a repo. """Download all files of a repo.
Downloads a whole snapshot of a repo's files at the specified revision. This Downloads a whole snapshot of a repo's files at the specified revision. This
@@ -88,6 +89,7 @@ def snapshot_download(
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`. change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`): progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
progress callbacks to track the download progress. progress callbacks to track the download progress.
token (str, optional): The user token.
Raises: Raises:
ValueError: the value details. ValueError: the value details.
@@ -146,7 +148,8 @@ def snapshot_download(
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
max_workers=max_workers, max_workers=max_workers,
progress_callbacks=progress_callbacks) progress_callbacks=progress_callbacks,
token=token)
def dataset_snapshot_download( def dataset_snapshot_download(
@@ -163,6 +166,7 @@ def dataset_snapshot_download(
ignore_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None,
enable_file_lock: Optional[bool] = None, enable_file_lock: Optional[bool] = None,
max_workers: int = 8, max_workers: int = 8,
token: Optional[str] = None,
) -> str: ) -> str:
"""Download raw files of a dataset. """Download raw files of a dataset.
Downloads all files at the specified revision. This Downloads all files at the specified revision. This
@@ -199,6 +203,7 @@ def dataset_snapshot_download(
If you find something wrong with file lock and have a problem modifying your code, If you find something wrong with file lock and have a problem modifying your code,
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`. change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
max_workers (`int`): The maximum number of workers to download files, default 8. max_workers (`int`): The maximum number of workers to download files, default 8.
token (str, optional): The user token.
Raises: Raises:
ValueError: the value details. ValueError: the value details.
@@ -241,7 +246,8 @@ def dataset_snapshot_download(
local_dir=local_dir, local_dir=local_dir,
ignore_patterns=ignore_patterns, ignore_patterns=ignore_patterns,
allow_patterns=allow_patterns, allow_patterns=allow_patterns,
max_workers=max_workers) max_workers=max_workers,
token=token)
def _snapshot_download( def _snapshot_download(
@@ -260,6 +266,7 @@ def _snapshot_download(
ignore_patterns: Optional[Union[List[str], str]] = None, ignore_patterns: Optional[Union[List[str], str]] = None,
max_workers: int = 8, max_workers: int = 8,
progress_callbacks: List[Type[ProgressCallback]] = None, progress_callbacks: List[Type[ProgressCallback]] = None,
token: Optional[str] = None,
): ):
if not repo_type: if not repo_type:
repo_type = REPO_TYPE_MODEL repo_type = REPO_TYPE_MODEL
@@ -299,11 +306,11 @@ def _snapshot_download(
) )
headers['x-aliyun-region-id'] = region_id headers['x-aliyun-region-id'] = region_id
_api = HubApi() _api = HubApi(token=token)
endpoint = _api.get_endpoint_for_read( endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=repo_type) repo_id=repo_id, repo_type=repo_type)
if cookies is None: if cookies is None:
cookies = ModelScopeConfig.get_cookies() cookies = _api.get_cookies()
if repo_type == REPO_TYPE_MODEL: if repo_type == REPO_TYPE_MODEL:
if local_dir: if local_dir:
directory = os.path.abspath(local_dir) directory = os.path.abspath(local_dir)

View File

@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates. # Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Mapping, Sequence, Union from typing import Mapping, Optional, Sequence, Union
from modelscope.msdatasets.auth.auth_config import BaseAuthConfig from modelscope.msdatasets.auth.auth_config import BaseAuthConfig
from modelscope.msdatasets.download.download_config import DataDownloadConfig from modelscope.msdatasets.download.download_config import DataDownloadConfig
@@ -11,14 +11,24 @@ from modelscope.utils.constant import DownloadMode, Hubs
class DatasetContextConfig: class DatasetContextConfig:
"""Context configuration of dataset.""" """Context configuration of dataset."""
def __init__(self, dataset_name: Union[str, list], namespace: str, def __init__(self,
version: str, subset_name: str, split: Union[str, list], dataset_name: Union[str, list],
target: str, hub: Hubs, data_dir: str, namespace: str,
version: str,
subset_name: str,
split: Union[str, list],
target: str,
hub: Hubs,
data_dir: str,
data_files: Union[str, Sequence[str], data_files: Union[str, Sequence[str],
Mapping[str, Union[str, Sequence[str]]]], Mapping[str, Union[str, Sequence[str]]]],
download_mode: DownloadMode, cache_root_dir: str, download_mode: DownloadMode,
use_streaming: bool, stream_batch_size: int, cache_root_dir: str,
trust_remote_code: bool, **kwargs): use_streaming: bool,
stream_batch_size: int,
trust_remote_code: bool,
token: Optional[str] = None,
**kwargs):
self._download_config = None self._download_config = None
self._data_meta_config = None self._data_meta_config = None
@@ -32,6 +42,7 @@ class DatasetContextConfig:
# General arguments for dataset # General arguments for dataset
self.hub = hub self.hub = hub
self.token = token
self.download_mode = download_mode self.download_mode = download_mode
self.dataset_name = dataset_name self.dataset_name = dataset_name
self.namespace = namespace self.namespace = namespace

View File

@@ -8,7 +8,7 @@ from datasets import (Dataset, DatasetBuilder, DatasetDict, IterableDataset,
IterableDatasetDict) IterableDatasetDict)
from datasets import load_dataset as hf_load_dataset from datasets import load_dataset as hf_load_dataset
from modelscope.hub.api import ModelScopeConfig from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.msdatasets.auth.auth_config import OssAuthConfig from modelscope.msdatasets.auth.auth_config import OssAuthConfig
from modelscope.msdatasets.context.dataset_context_config import \ from modelscope.msdatasets.context.dataset_context_config import \
DatasetContextConfig DatasetContextConfig
@@ -86,7 +86,8 @@ class OssDownloader(BaseDownloader):
def _authorize(self) -> None: def _authorize(self) -> None:
""" Authorization of target dataset. """ Authorization of target dataset.
Get credentials from cache and send to the modelscope-hub in the future. """ Get credentials from cache and send to the modelscope-hub in the future. """
cookies = ModelScopeConfig.get_cookies() cookies = HubApi().get_cookies(
access_token=self.dataset_context_config.token)
git_token = ModelScopeConfig.get_token() git_token = ModelScopeConfig.get_token()
user_info = ModelScopeConfig.get_user_info() user_info = ModelScopeConfig.get_user_info()
@@ -178,7 +179,8 @@ class VirgoDownloader(BaseDownloader):
"""Authorization of virgo dataset.""" """Authorization of virgo dataset."""
from modelscope.msdatasets.auth.auth_config import VirgoAuthConfig from modelscope.msdatasets.auth.auth_config import VirgoAuthConfig
cookies = ModelScopeConfig.get_cookies() cookies = HubApi().get_cookies(
access_token=self.dataset_context_config.token)
user_info = ModelScopeConfig.get_user_info() user_info = ModelScopeConfig.get_user_info()
if not self.dataset_context_config.auth_config: if not self.dataset_context_config.auth_config:

View File

@@ -97,7 +97,8 @@ class RemoteDataLoaderManager(DataLoaderManager):
def __init__(self, dataset_context_config: DatasetContextConfig): def __init__(self, dataset_context_config: DatasetContextConfig):
super().__init__(dataset_context_config=dataset_context_config) super().__init__(dataset_context_config=dataset_context_config)
self.api = HubApi()
self.api = HubApi(token=dataset_context_config.token)
def load_dataset(self, data_loader_type: enum.Enum): def load_dataset(self, data_loader_type: enum.Enum):
# Get args from context # Get args from context
@@ -112,6 +113,7 @@ class RemoteDataLoaderManager(DataLoaderManager):
use_streaming = self.dataset_context_config.use_streaming use_streaming = self.dataset_context_config.use_streaming
input_config_kwargs = self.dataset_context_config.config_kwargs input_config_kwargs = self.dataset_context_config.config_kwargs
trust_remote_code = self.dataset_context_config.trust_remote_code trust_remote_code = self.dataset_context_config.trust_remote_code
token = self.dataset_context_config.token
# To use the huggingface data loader # To use the huggingface data loader
if data_loader_type == RemoteDataLoaderType.HF_DATA_LOADER: if data_loader_type == RemoteDataLoaderType.HF_DATA_LOADER:
@@ -129,6 +131,7 @@ class RemoteDataLoaderManager(DataLoaderManager):
download_mode=download_mode_val, download_mode=download_mode_val,
streaming=use_streaming, streaming=use_streaming,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
token=token,
**input_config_kwargs) **input_config_kwargs)
# download statistics # download statistics
self.api.dataset_download_statistics( self.api.dataset_download_statistics(

View File

@@ -24,7 +24,7 @@ from filelock import FileLock
from modelscope.utils.config_ds import MS_DATASETS_CACHE from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.logger import get_logger from modelscope.utils.logger import get_logger
from modelscope.hub.api import ModelScopeConfig from modelscope.hub.api import HubApi
from modelscope import __version__ from modelscope import __version__
@@ -219,7 +219,7 @@ def get_from_cache_ms(
etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
connected = True connected = True
try: try:
cookies = ModelScopeConfig.get_cookies() cookies = HubApi().get_cookies(access_token=token)
response = http_head_ms( response = http_head_ms(
url, url,
allow_redirects=True, allow_redirects=True,