mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
Feat: update hubapi get cookie (#1573)
This commit is contained in:
@@ -102,13 +102,15 @@ class HubApi:
|
||||
def __init__(self,
|
||||
endpoint: Optional[str] = None,
|
||||
timeout=API_HTTP_CLIENT_TIMEOUT,
|
||||
max_retries=API_HTTP_CLIENT_MAX_RETRIES):
|
||||
max_retries=API_HTTP_CLIENT_MAX_RETRIES,
|
||||
token: Optional[str] = None):
|
||||
"""The ModelScope HubApi。
|
||||
|
||||
Args:
|
||||
endpoint (str, optional): The modelscope server http|https address. Defaults to None.
|
||||
"""
|
||||
self.endpoint = endpoint if endpoint is not None else get_endpoint()
|
||||
self.token = token
|
||||
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
|
||||
self.session = Session()
|
||||
retry = Retry(
|
||||
@@ -154,12 +156,12 @@ class HubApi:
|
||||
path='/')
|
||||
return jar
|
||||
|
||||
def get_cookies(self, access_token, cookies_required: Optional[bool] = False):
|
||||
def get_cookies(self, access_token: Optional[str] = None, cookies_required: Optional[bool] = False):
|
||||
"""
|
||||
Get cookies for authentication from local cache or access_token.
|
||||
|
||||
Args:
|
||||
access_token (str): user access token on ModelScope
|
||||
access_token (Optional[str]): user access token on ModelScope. If not provided, try to get from local cache.
|
||||
cookies_required (bool): whether to raise error if no cookies found, defaults to `False`.
|
||||
|
||||
Returns:
|
||||
@@ -168,8 +170,9 @@ class HubApi:
|
||||
Raises:
|
||||
ValueError: If no credentials found and cookies_required is True.
|
||||
"""
|
||||
if access_token:
|
||||
cookies = self._get_cookies(access_token=access_token)
|
||||
token = access_token or self.token or os.environ.get('MODELSCOPE_API_TOKEN')
|
||||
if token:
|
||||
cookies = self._get_cookies(access_token=token)
|
||||
else:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
|
||||
@@ -203,8 +206,7 @@ class HubApi:
|
||||
Note:
|
||||
You only have to login once within 30 days.
|
||||
"""
|
||||
if access_token is None:
|
||||
access_token = os.environ.get('MODELSCOPE_API_TOKEN')
|
||||
access_token = access_token or self.token or os.environ.get('MODELSCOPE_API_TOKEN')
|
||||
if not access_token:
|
||||
return None, None
|
||||
if not endpoint:
|
||||
@@ -423,12 +425,13 @@ class HubApi:
|
||||
tag_url = f'{endpoint}/models/{model_id}/tags/{tag_name}'
|
||||
return tag_url
|
||||
|
||||
def delete_model(self, model_id: str, endpoint: Optional[str] = None):
|
||||
def delete_model(self, model_id: str, endpoint: Optional[str] = None, token: Optional[str] = None):
|
||||
"""Delete model_id from ModelScope.
|
||||
|
||||
Args:
|
||||
model_id (str): The model id.
|
||||
endpoint: the endpoint to use, default to None to use endpoint specified in the class
|
||||
token (str, optional): access token for authentication
|
||||
|
||||
Raises:
|
||||
ValueError: If not login.
|
||||
@@ -436,7 +439,7 @@ class HubApi:
|
||||
Note:
|
||||
model_id = {owner}/{name}
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
if cookies is None:
|
||||
@@ -458,7 +461,8 @@ class HubApi:
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
endpoint: Optional[str] = None
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Get model information at ModelScope
|
||||
|
||||
@@ -466,6 +470,7 @@ class HubApi:
|
||||
model_id (str): The model id.
|
||||
revision (str optional): revision of model.
|
||||
endpoint: the endpoint to use, default to None to use endpoint specified in the class
|
||||
token (str, optional): access token for authentication
|
||||
|
||||
Returns:
|
||||
The model detail information.
|
||||
@@ -476,7 +481,7 @@ class HubApi:
|
||||
Note:
|
||||
model_id = {owner}/{name}
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=False)
|
||||
owner_or_group, name = model_id_to_group_owner_name(model_id)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
@@ -696,7 +701,12 @@ class HubApi:
|
||||
'Failed to check existence of repo: %s, make sure you have access authorization.'
|
||||
% repo_type)
|
||||
|
||||
def delete_repo(self, repo_id: str, repo_type: str, endpoint: Optional[str] = None):
|
||||
def delete_repo(self,
|
||||
repo_id: str,
|
||||
repo_type: str,
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Delete a repository from ModelScope.
|
||||
|
||||
@@ -709,15 +719,23 @@ class HubApi:
|
||||
endpoint(`str`):
|
||||
The endpoint to use. If not provided, the default endpoint is `https://www.modelscope.cn`
|
||||
Could be set to `https://ai.modelscope.ai` for international version.
|
||||
token (str): Access token of the ModelScope.
|
||||
"""
|
||||
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
|
||||
if repo_type == REPO_TYPE_DATASET:
|
||||
self.delete_dataset(repo_id, endpoint)
|
||||
self.delete_dataset(
|
||||
dataset_id=repo_id,
|
||||
endpoint=endpoint,
|
||||
token=token
|
||||
)
|
||||
elif repo_type == REPO_TYPE_MODEL:
|
||||
self.delete_model(repo_id, endpoint)
|
||||
self.delete_model(
|
||||
model_id=repo_id,
|
||||
endpoint=endpoint,
|
||||
token=token)
|
||||
else:
|
||||
raise Exception(f'Arg repo_type {repo_type} not supported.')
|
||||
|
||||
@@ -744,7 +762,8 @@ class HubApi:
|
||||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
||||
original_model_id: Optional[str] = None,
|
||||
ignore_file_pattern: Optional[Union[List[str], str]] = None,
|
||||
lfs_suffix: Optional[Union[str, List[str]]] = None):
|
||||
lfs_suffix: Optional[Union[str, List[str]]] = None,
|
||||
token: Optional[str] = None):
|
||||
warnings.warn(
|
||||
'This function is deprecated and will be removed in future versions. '
|
||||
'Please use git command directly or use HubApi().upload_folder instead',
|
||||
@@ -810,7 +829,7 @@ class HubApi:
|
||||
f'No {ModelFile.CONFIGURATION} file found in {model_dir}, creating a default one.')
|
||||
HubApi._create_default_config(model_dir)
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
if cookies is None:
|
||||
raise NotLoginException('Must login before upload!')
|
||||
files_to_save = os.listdir(model_dir)
|
||||
@@ -881,7 +900,8 @@ class HubApi:
|
||||
owner_or_group: str,
|
||||
page_number: Optional[int] = 1,
|
||||
page_size: Optional[int] = 10,
|
||||
endpoint: Optional[str] = None) -> dict:
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None) -> dict:
|
||||
"""List models in owner or group.
|
||||
|
||||
Args:
|
||||
@@ -889,6 +909,7 @@ class HubApi:
|
||||
page_number(int, optional): The page number, default: 1
|
||||
page_size(int, optional): The page size, default: 10
|
||||
endpoint: the endpoint to use, default to None to use endpoint specified in the class
|
||||
token (str, optional): access token for authentication
|
||||
|
||||
Raises:
|
||||
RequestError: The request error.
|
||||
@@ -896,7 +917,7 @@ class HubApi:
|
||||
Returns:
|
||||
dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group}
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=False)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
path = f'{endpoint}/api/v1/models/'
|
||||
@@ -925,7 +946,7 @@ class HubApi:
|
||||
sort: Optional[str] = None,
|
||||
search: Optional[str] = None,
|
||||
endpoint: Optional[str] = None,
|
||||
) -> dict:
|
||||
token: Optional[str] = None) -> dict:
|
||||
"""List datasets via OpenAPI with pagination, filtering and sorting.
|
||||
|
||||
Args:
|
||||
@@ -937,6 +958,7 @@ class HubApi:
|
||||
search (str, optional): Search by substring keywords in the dataset's Chinese name,
|
||||
English name, and authors (including organizations and individuals).
|
||||
endpoint (str, optional): Hub endpoint to use. When None, use the endpoint specified in the class.
|
||||
token (str, optional): Access token for authentication.
|
||||
|
||||
Returns:
|
||||
dict: The OpenAPI data payload, e.g.
|
||||
@@ -966,7 +988,7 @@ class HubApi:
|
||||
if owner_or_group:
|
||||
params['author'] = owner_or_group
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=False)
|
||||
headers = self.builder_headers(self.headers)
|
||||
|
||||
r = self.session.get(
|
||||
@@ -991,9 +1013,7 @@ class HubApi:
|
||||
if isinstance(use_cookies, CookieJar):
|
||||
cookies = use_cookies
|
||||
elif isinstance(use_cookies, bool):
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if use_cookies and cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
cookies = self.get_cookies(cookies_required=use_cookies)
|
||||
return cookies
|
||||
|
||||
def list_model_revisions(
|
||||
@@ -1251,6 +1271,7 @@ class HubApi:
|
||||
filename: str,
|
||||
*,
|
||||
revision: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
):
|
||||
"""Get if the specified file exists
|
||||
|
||||
@@ -1259,10 +1280,11 @@ class HubApi:
|
||||
filename (`str`): The queried filename, if the file exists in a sub folder,
|
||||
please pass <sub-folder-name>/<file-name>
|
||||
revision (`Optional[str]`): The repo revision
|
||||
token (`Optional[str]`): The access token
|
||||
Returns:
|
||||
The query result in bool value
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
files = self.get_model_files(
|
||||
repo_id,
|
||||
recursive=True,
|
||||
@@ -1279,14 +1301,29 @@ class HubApi:
|
||||
license: Optional[str] = Licenses.APACHE_V2,
|
||||
visibility: Optional[int] = DatasetVisibility.PUBLIC,
|
||||
description: Optional[str] = '',
|
||||
endpoint: Optional[str] = None, ) -> str:
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None) -> str:
|
||||
"""
|
||||
Create a dataset in ModelScope.
|
||||
|
||||
Args:
|
||||
dataset_name (str): The name of the dataset.
|
||||
namespace (str): The namespace (user or organization) for the dataset.
|
||||
chinese_name (str, optional): The Chinese name of the dataset. Defaults to ''.
|
||||
license (str, optional): The license of the dataset. Defaults to Licenses.APACHE_V2.
|
||||
visibility (int, optional): The visibility of the dataset. Defaults to DatasetVisibility.PUBLIC.
|
||||
description (str, optional): The description of the dataset. Defaults to ''.
|
||||
endpoint (str, optional): The endpoint to use. If not provided, the default endpoint is used.
|
||||
token (str, optional): The access token for authentication.
|
||||
|
||||
Returns:
|
||||
str: The URL of the created dataset repository.
|
||||
"""
|
||||
|
||||
if dataset_name is None or namespace is None:
|
||||
raise InvalidParameter('dataset_name and namespace are required!')
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
path = f'{endpoint}/api/v1/datasets'
|
||||
@@ -1310,11 +1347,25 @@ class HubApi:
|
||||
raise_on_error(r.json())
|
||||
dataset_repo_url = f'{endpoint}/datasets/{namespace}/{dataset_name}'
|
||||
logger.info(f'Create dataset success: {dataset_repo_url}')
|
||||
|
||||
return dataset_repo_url
|
||||
|
||||
def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None):
|
||||
def delete_dataset(self,
|
||||
dataset_id: str,
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
"""
|
||||
Delete a dataset from ModelScope.
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
Args:
|
||||
dataset_id (str): The dataset id to delete.
|
||||
endpoint (str, optional): The endpoint to use. If not provided, the default endpoint is used.
|
||||
token (str, optional): The access token for authentication.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
if cookies is None:
|
||||
@@ -1327,12 +1378,16 @@ class HubApi:
|
||||
raise_for_http_status(r)
|
||||
raise_on_error(r.json())
|
||||
|
||||
def get_dataset_id_and_type(self, dataset_name: str, namespace: str, endpoint: Optional[str] = None):
|
||||
def get_dataset_id_and_type(self,
|
||||
dataset_name: str,
|
||||
namespace: str,
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
""" Get the dataset id and type. """
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
r = self.session.get(datahub_url, cookies=cookies)
|
||||
resp = r.json()
|
||||
datahub_raise_on_error(datahub_url, resp, r)
|
||||
@@ -1348,7 +1403,8 @@ class HubApi:
|
||||
recursive: bool = True,
|
||||
page_number: int = 1,
|
||||
page_size: int = 100,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
"""
|
||||
@deprecated: Use `get_dataset_files` instead.
|
||||
"""
|
||||
@@ -1356,7 +1412,7 @@ class HubApi:
|
||||
DeprecationWarning)
|
||||
|
||||
dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
|
||||
dataset_name=dataset_name, namespace=namespace, endpoint=endpoint)
|
||||
dataset_name=dataset_name, namespace=namespace, endpoint=endpoint, token=token)
|
||||
|
||||
recursive = 'True' if recursive else 'False'
|
||||
if not endpoint:
|
||||
@@ -1365,7 +1421,7 @@ class HubApi:
|
||||
params = {'Revision': revision if revision else 'master',
|
||||
'Root': root_path if root_path else '/', 'Recursive': recursive,
|
||||
'PageNumber': page_number, 'PageSize': page_size}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
|
||||
r = self.session.get(datahub_url, params=params, cookies=cookies)
|
||||
resp = r.json()
|
||||
@@ -1380,7 +1436,8 @@ class HubApi:
|
||||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
||||
page_number: int = 1,
|
||||
page_size: int = 50,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
"""
|
||||
Get the commit history for a repository.
|
||||
|
||||
@@ -1391,6 +1448,7 @@ class HubApi:
|
||||
page_number (int): The page number for pagination. Defaults to 1.
|
||||
page_size (int): The number of commits per page. Defaults to 50.
|
||||
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
|
||||
token (Optional[str]): The access token.
|
||||
|
||||
Returns:
|
||||
CommitHistoryResponse: The commit history response.
|
||||
@@ -1420,7 +1478,7 @@ class HubApi:
|
||||
'PageNumber': page_number,
|
||||
'PageSize': page_size
|
||||
}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
|
||||
try:
|
||||
r = self.session.get(commits_url, params=params,
|
||||
@@ -1443,7 +1501,8 @@ class HubApi:
|
||||
recursive: bool = True,
|
||||
page_number: int = 1,
|
||||
page_size: int = 100,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
"""
|
||||
Get the dataset files.
|
||||
|
||||
@@ -1455,6 +1514,7 @@ class HubApi:
|
||||
page_number (int): The page number for pagination. Defaults to 1.
|
||||
page_size (int): The number of items per page. Defaults to 100.
|
||||
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
|
||||
token (Optional[str]): The access token.
|
||||
|
||||
Returns:
|
||||
List: The response containing the dataset repository tree information.
|
||||
@@ -1468,7 +1528,7 @@ class HubApi:
|
||||
raise ValueError(f'Invalid repo_id: {repo_id} !')
|
||||
|
||||
dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
|
||||
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
@@ -1480,7 +1540,7 @@ class HubApi:
|
||||
'PageNumber': page_number,
|
||||
'PageSize': page_size
|
||||
}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
|
||||
r = self.session.get(datahub_url, params=params, cookies=cookies)
|
||||
resp = r.json()
|
||||
@@ -1492,7 +1552,8 @@ class HubApi:
|
||||
self,
|
||||
dataset_id: str,
|
||||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
||||
endpoint: Optional[str] = None
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Get the dataset information.
|
||||
@@ -1501,11 +1562,12 @@ class HubApi:
|
||||
dataset_id (str): The dataset id.
|
||||
revision (Optional[str]): The revision of the dataset.
|
||||
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
|
||||
token (Optional[str]): The access token.
|
||||
|
||||
Returns:
|
||||
dict: The dataset information.
|
||||
"""
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
|
||||
@@ -1522,12 +1584,13 @@ class HubApi:
|
||||
return resp[API_RESPONSE_FIELD_DATA]
|
||||
|
||||
def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
|
||||
dataset_id: str, revision: str, endpoint: Optional[str] = None):
|
||||
dataset_id: str, revision: str, endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
""" Get the meta file-list of the dataset. """
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
datahub_url = f'{endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
r = self.session.get(datahub_url,
|
||||
cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
@@ -1559,11 +1622,12 @@ class HubApi:
|
||||
namespace: str,
|
||||
revision: str,
|
||||
meta_cache_dir: str, dataset_type: int, file_list: list,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
local_paths = defaultdict(list)
|
||||
dataset_formation = DatasetFormations(dataset_type)
|
||||
dataset_meta_format = DatasetMetaFormats[dataset_formation]
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
|
||||
# Dump the data_type as a local file
|
||||
HubApi.dump_datatype_file(dataset_type=dataset_type, meta_cache_dir=meta_cache_dir)
|
||||
@@ -1591,7 +1655,8 @@ class HubApi:
|
||||
return local_paths, dataset_formation
|
||||
|
||||
@staticmethod
|
||||
def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS):
|
||||
def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
||||
token: Optional[str] = None):
|
||||
"""
|
||||
Fetch the meta-data files from the url, e.g. csv/jsonl files.
|
||||
"""
|
||||
@@ -1605,7 +1670,7 @@ class HubApi:
|
||||
if os.path.exists(out_path):
|
||||
logger.info(f'Reusing cached meta-data file: {out_path}')
|
||||
return out_path
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = HubApi().get_cookies(access_token=token)
|
||||
|
||||
# Make the request and get the response content as TextIO
|
||||
logger.info('Loading meta-data file ...')
|
||||
@@ -1697,12 +1762,13 @@ class HubApi:
|
||||
dataset_name: str,
|
||||
namespace: str,
|
||||
revision: Optional[str] = DEFAULT_DATASET_REVISION,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
|
||||
f'ststoken?Revision={revision}'
|
||||
return self.datahub_remote_call(datahub_url)
|
||||
return self.datahub_remote_call(datahub_url, token=token)
|
||||
|
||||
def get_dataset_access_config_session(
|
||||
self,
|
||||
@@ -1710,7 +1776,8 @@ class HubApi:
|
||||
namespace: str,
|
||||
check_cookie: bool,
|
||||
revision: Optional[str] = DEFAULT_DATASET_REVISION,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
@@ -1719,7 +1786,7 @@ class HubApi:
|
||||
if check_cookie:
|
||||
cookies = self._check_cookie(use_cookies=True)
|
||||
else:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
|
||||
r = self.session.get(
|
||||
url=datahub_url,
|
||||
@@ -1729,7 +1796,7 @@ class HubApi:
|
||||
raise_on_error(resp)
|
||||
return resp['Data']
|
||||
|
||||
def get_virgo_meta(self, dataset_id: str, version: int = 1) -> dict:
|
||||
def get_virgo_meta(self, dataset_id: str, version: int = 1, token: Optional[str] = None) -> dict:
|
||||
"""
|
||||
Get virgo dataset meta info.
|
||||
"""
|
||||
@@ -1738,7 +1805,7 @@ class HubApi:
|
||||
raise RuntimeError(f'Virgo endpoint is not set in env: {VirgoDatasetConfig.env_virgo_endpoint}')
|
||||
|
||||
virgo_dataset_url = f'{virgo_endpoint}/data/set/download'
|
||||
cookies = requests.utils.dict_from_cookiejar(ModelScopeConfig.get_cookies())
|
||||
cookies = requests.utils.dict_from_cookiejar(self.get_cookies(access_token=token))
|
||||
|
||||
dataset_info = dict(
|
||||
dataSetId=dataset_id,
|
||||
@@ -1763,11 +1830,12 @@ class HubApi:
|
||||
namespace: str,
|
||||
revision: str,
|
||||
zip_file_name: str,
|
||||
endpoint: Optional[str] = None):
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
r = self.session.get(url=datahub_url, cookies=cookies,
|
||||
headers=self.builder_headers(self.headers))
|
||||
resp = r.json()
|
||||
@@ -1787,13 +1855,14 @@ class HubApi:
|
||||
return data_sts
|
||||
|
||||
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
|
||||
is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None):
|
||||
is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
|
||||
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
resp = self.session.get(url=url, cookies=cookies, timeout=1800)
|
||||
resp = resp.json()
|
||||
raise_on_error(resp)
|
||||
@@ -1801,14 +1870,15 @@ class HubApi:
|
||||
return resp
|
||||
|
||||
def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
|
||||
namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
|
||||
namespace: str, revision: str, endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None) -> str:
|
||||
if not object_name or not dataset_name or not namespace or not revision:
|
||||
raise ValueError('Args cannot be empty!')
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
resp = self.session.delete(url=url, cookies=cookies)
|
||||
resp = resp.json()
|
||||
raise_on_error(resp)
|
||||
@@ -1816,7 +1886,8 @@ class HubApi:
|
||||
return resp
|
||||
|
||||
def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
|
||||
namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
|
||||
namespace: str, revision: str, endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None) -> str:
|
||||
if not object_name or not dataset_name or not namespace or not revision:
|
||||
raise ValueError('Args cannot be empty!')
|
||||
if not endpoint:
|
||||
@@ -1824,15 +1895,15 @@ class HubApi:
|
||||
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
|
||||
f'&Revision={revision}'
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
resp = self.session.delete(url=url, cookies=cookies)
|
||||
resp = resp.json()
|
||||
raise_on_error(resp)
|
||||
resp = resp['Message']
|
||||
return resp
|
||||
|
||||
def datahub_remote_call(self, url):
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
def datahub_remote_call(self, url, token: Optional[str] = None):
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
r = self.session.get(
|
||||
url,
|
||||
cookies=cookies,
|
||||
@@ -1842,13 +1913,14 @@ class HubApi:
|
||||
return resp['Data']
|
||||
|
||||
def dataset_download_statistics(self, dataset_name: str, namespace: str,
|
||||
use_streaming: bool = False, endpoint: Optional[str] = None) -> None:
|
||||
use_streaming: bool = False, endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None) -> None:
|
||||
is_ci_test = os.getenv('CI_TEST') == 'True'
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
if dataset_name and namespace and not is_ci_test and not use_streaming:
|
||||
try:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
|
||||
# Download count
|
||||
download_count_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
|
||||
@@ -1926,8 +1998,6 @@ class HubApi:
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
|
||||
self.login(access_token=token, endpoint=endpoint)
|
||||
|
||||
repo_exists: bool = self.repo_exists(repo_id, repo_type=repo_type, endpoint=endpoint, token=token)
|
||||
if repo_exists:
|
||||
if exist_ok:
|
||||
@@ -1953,7 +2023,8 @@ class HubApi:
|
||||
visibility=visibility,
|
||||
license=license,
|
||||
chinese_name=chinese_name,
|
||||
aigc_model=aigc_model
|
||||
aigc_model=aigc_model,
|
||||
token=token,
|
||||
)
|
||||
if create_default_config:
|
||||
with tempfile.TemporaryDirectory() as temp_cache_dir:
|
||||
@@ -1986,6 +2057,7 @@ class HubApi:
|
||||
chinese_name=chinese_name,
|
||||
license=license,
|
||||
visibility=visibility,
|
||||
token=token,
|
||||
)
|
||||
print(f'New dataset created successfully at {repo_url}.', flush=True)
|
||||
|
||||
@@ -2161,7 +2233,6 @@ class HubApi:
|
||||
... )
|
||||
>>> print(commit_info)
|
||||
"""
|
||||
|
||||
if repo_type not in REPO_TYPE_SUPPORT:
|
||||
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
||||
|
||||
@@ -2219,6 +2290,7 @@ class HubApi:
|
||||
data=path_or_fileobj,
|
||||
disable_tqdm=disable_tqdm,
|
||||
tqdm_desc=tqdm_desc,
|
||||
token=token,
|
||||
)
|
||||
|
||||
# Construct commit info and create commit
|
||||
@@ -2370,6 +2442,7 @@ class HubApi:
|
||||
data=file_path,
|
||||
disable_tqdm=file_size <= UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD,
|
||||
tqdm_desc='[Uploading ' + file_path_in_repo + ']',
|
||||
token=token,
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -2446,6 +2519,7 @@ class HubApi:
|
||||
disable_tqdm: Optional[bool] = False,
|
||||
tqdm_desc: Optional[str] = '[Uploading]',
|
||||
buffer_size_mb: Optional[int] = 1,
|
||||
token: Optional[str] = None,
|
||||
) -> dict:
|
||||
|
||||
res_d: dict = dict(
|
||||
@@ -2460,6 +2534,7 @@ class HubApi:
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
objects=objects,
|
||||
token=token,
|
||||
)
|
||||
|
||||
# upload_object: {'url': 'xxx', 'oid': 'xxx'}
|
||||
@@ -2470,7 +2545,7 @@ class HubApi:
|
||||
res_d['is_uploaded'] = True
|
||||
return res_d
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
cookies = dict(cookies) if cookies else None
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
@@ -2536,7 +2611,8 @@ class HubApi:
|
||||
repo_id: str,
|
||||
repo_type: str,
|
||||
objects: List[Dict[str, Any]],
|
||||
endpoint: Optional[str] = None
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Check the blob has already uploaded.
|
||||
@@ -2549,6 +2625,7 @@ class HubApi:
|
||||
oid (str): The sha256 hash value.
|
||||
size (int): The size of the blob.
|
||||
endpoint: the endpoint to use, default to None to use endpoint specified in the class
|
||||
token (str): The access token.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The result of the check.
|
||||
@@ -2565,9 +2642,7 @@ class HubApi:
|
||||
'objects': objects,
|
||||
}
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.builder_headers(self.headers),
|
||||
@@ -2755,7 +2830,8 @@ class HubApi:
|
||||
delete_patterns: Union[str, List[str]],
|
||||
*,
|
||||
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
endpoint: Optional[str] = None) -> Dict[str, Any]:
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Delete files in batch using glob (wildcard) patterns, e.g. '*.py', 'data/*.csv', 'foo*', etc.
|
||||
|
||||
@@ -2780,6 +2856,7 @@ class HubApi:
|
||||
delete_patterns (str or List[str]): List of glob patterns, e.g. '*.py', 'data/*.csv', 'foo*'
|
||||
revision (str, optional): Branch or tag name
|
||||
endpoint (str, optional): API endpoint
|
||||
token (str, optional): Access token
|
||||
Returns:
|
||||
dict: Deletion result
|
||||
"""
|
||||
@@ -2790,7 +2867,7 @@ class HubApi:
|
||||
if isinstance(delete_patterns, str):
|
||||
delete_patterns = [delete_patterns]
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
if cookies is None:
|
||||
|
||||
@@ -33,6 +33,7 @@ def get_model_id_from_cache(model_root_path: str, ) -> str:
|
||||
def check_local_model_is_latest(
|
||||
model_root_path: str,
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
token: Optional[str] = None,
|
||||
):
|
||||
"""Check local model repo is latest.
|
||||
Check local model repo is same as hub latest version.
|
||||
@@ -45,7 +46,8 @@ def check_local_model_is_latest(
|
||||
'user-agent':
|
||||
ModelScopeConfig.get_user_agent(user_agent=user_agent, )
|
||||
}
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
_api = HubApi(timeout=20, token=token)
|
||||
cookies = _api.get_cookies()
|
||||
|
||||
snapshot_header = headers if 'CI_TEST' in os.environ else {
|
||||
**headers,
|
||||
@@ -53,7 +55,6 @@ def check_local_model_is_latest(
|
||||
'Snapshot': 'True'
|
||||
}
|
||||
}
|
||||
_api = HubApi(timeout=20)
|
||||
try:
|
||||
_, revisions = _api.get_model_branches_and_tags(
|
||||
model_id=model_id, use_cookies=cookies)
|
||||
|
||||
@@ -7,7 +7,7 @@ import json
|
||||
import requests
|
||||
from attrs import asdict, define, field, validators
|
||||
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.hub.constants import (API_RESPONSE_FIELD_DATA,
|
||||
API_RESPONSE_FIELD_MESSAGE)
|
||||
from modelscope.hub.errors import (NotLoginException, NotSupportError,
|
||||
@@ -188,13 +188,11 @@ class ServiceDeployer(object):
|
||||
"""Facilitate model deployment on to supported service provider(s).
|
||||
"""
|
||||
|
||||
def __init__(self, endpoint=None):
|
||||
def __init__(self, endpoint=None, token: Optional[str] = None):
|
||||
self.endpoint = endpoint if endpoint is not None else get_endpoint()
|
||||
self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
|
||||
self.cookies = ModelScopeConfig.get_cookies()
|
||||
if self.cookies is None:
|
||||
raise NotLoginException(
|
||||
'Token does not exist, please login with HubApi first.')
|
||||
self.cookies = HubApi().get_cookies(
|
||||
access_token=token, cookies_required=True)
|
||||
|
||||
# deploy_model
|
||||
def create(self, model_id: str, revision: str, instance_name: str,
|
||||
|
||||
@@ -49,6 +49,7 @@ def model_file_download(
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
"""Download from a given URL and cache it if it's not already present in the local cache.
|
||||
|
||||
@@ -67,6 +68,7 @@ def model_file_download(
|
||||
local cached file if it exists. if `False`, download the file anyway even it exists.
|
||||
cookies (CookieJar, optional): The cookie of download request.
|
||||
local_dir (str, optional): Specific local directory path to which the file will be downloaded.
|
||||
token (str, optional): The user token.
|
||||
|
||||
Returns:
|
||||
string: string of local file or if networking is off, last version of
|
||||
@@ -95,7 +97,8 @@ def model_file_download(
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
local_dir=local_dir)
|
||||
local_dir=local_dir,
|
||||
token=token)
|
||||
|
||||
|
||||
def dataset_file_download(
|
||||
@@ -107,6 +110,7 @@ def dataset_file_download(
|
||||
user_agent: Optional[Union[Dict, str]] = None,
|
||||
local_files_only: Optional[bool] = False,
|
||||
cookies: Optional[CookieJar] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
@@ -129,6 +133,7 @@ def dataset_file_download(
|
||||
local_files_only (bool, optional): If `True`, avoid downloading the file and return the path to the
|
||||
local cached file if it exists.
|
||||
cookies (CookieJar, optional): The cookie of the request, default None.
|
||||
token (str, optional): The user token.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
@@ -153,7 +158,8 @@ def dataset_file_download(
|
||||
user_agent=user_agent,
|
||||
local_files_only=local_files_only,
|
||||
cookies=cookies,
|
||||
local_dir=local_dir)
|
||||
local_dir=local_dir,
|
||||
token=token)
|
||||
|
||||
|
||||
def _repo_file_download(
|
||||
@@ -168,6 +174,7 @@ def _repo_file_download(
|
||||
cookies: Optional[CookieJar] = None,
|
||||
local_dir: Optional[str] = None,
|
||||
disable_tqdm: bool = False,
|
||||
token: Optional[str] = None,
|
||||
) -> Optional[str]: # pragma: no cover
|
||||
|
||||
if not repo_type:
|
||||
@@ -194,7 +201,7 @@ def _repo_file_download(
|
||||
' traffic has been disabled. To enable look-ups and downloads'
|
||||
" online, set 'local_files_only' to False.")
|
||||
|
||||
_api = HubApi()
|
||||
_api = HubApi(token=token)
|
||||
|
||||
headers = {
|
||||
'user-agent': ModelScopeConfig.get_user_agent(user_agent=user_agent, ),
|
||||
@@ -212,7 +219,7 @@ def _repo_file_download(
|
||||
headers['x-aliyun-region-id'] = region_id
|
||||
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = _api.get_cookies()
|
||||
repo_files = []
|
||||
endpoint = _api.get_endpoint_for_read(repo_id=repo_id, repo_type=repo_type)
|
||||
file_to_download_meta = None
|
||||
|
||||
@@ -48,6 +48,7 @@ def snapshot_download(
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
enable_file_lock: Optional[bool] = None,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
token: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download all files of a repo.
|
||||
Downloads a whole snapshot of a repo's files at the specified revision. This
|
||||
@@ -88,6 +89,7 @@ def snapshot_download(
|
||||
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
|
||||
progress_callbacks (`List[Type[ProgressCallback]]`, **optional**, default to `None`):
|
||||
progress callbacks to track the download progress.
|
||||
token (str, optional): The user token.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
@@ -146,7 +148,8 @@ def snapshot_download(
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns,
|
||||
max_workers=max_workers,
|
||||
progress_callbacks=progress_callbacks)
|
||||
progress_callbacks=progress_callbacks,
|
||||
token=token)
|
||||
|
||||
|
||||
def dataset_snapshot_download(
|
||||
@@ -163,6 +166,7 @@ def dataset_snapshot_download(
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
enable_file_lock: Optional[bool] = None,
|
||||
max_workers: int = 8,
|
||||
token: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Download raw files of a dataset.
|
||||
Downloads all files at the specified revision. This
|
||||
@@ -199,6 +203,7 @@ def dataset_snapshot_download(
|
||||
If you find something wrong with file lock and have a problem modifying your code,
|
||||
change `MODELSCOPE_HUB_FILE_LOCK` env to `false`.
|
||||
max_workers (`int`): The maximum number of workers to download files, default 8.
|
||||
token (str, optional): The user token.
|
||||
Raises:
|
||||
ValueError: the value details.
|
||||
|
||||
@@ -241,7 +246,8 @@ def dataset_snapshot_download(
|
||||
local_dir=local_dir,
|
||||
ignore_patterns=ignore_patterns,
|
||||
allow_patterns=allow_patterns,
|
||||
max_workers=max_workers)
|
||||
max_workers=max_workers,
|
||||
token=token)
|
||||
|
||||
|
||||
def _snapshot_download(
|
||||
@@ -260,6 +266,7 @@ def _snapshot_download(
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
max_workers: int = 8,
|
||||
progress_callbacks: List[Type[ProgressCallback]] = None,
|
||||
token: Optional[str] = None,
|
||||
):
|
||||
if not repo_type:
|
||||
repo_type = REPO_TYPE_MODEL
|
||||
@@ -299,11 +306,11 @@ def _snapshot_download(
|
||||
)
|
||||
headers['x-aliyun-region-id'] = region_id
|
||||
|
||||
_api = HubApi()
|
||||
_api = HubApi(token=token)
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=repo_type)
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = _api.get_cookies()
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
if local_dir:
|
||||
directory = os.path.abspath(local_dir)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import Mapping, Sequence, Union
|
||||
from typing import Mapping, Optional, Sequence, Union
|
||||
|
||||
from modelscope.msdatasets.auth.auth_config import BaseAuthConfig
|
||||
from modelscope.msdatasets.download.download_config import DataDownloadConfig
|
||||
@@ -11,14 +11,24 @@ from modelscope.utils.constant import DownloadMode, Hubs
|
||||
class DatasetContextConfig:
|
||||
"""Context configuration of dataset."""
|
||||
|
||||
def __init__(self, dataset_name: Union[str, list], namespace: str,
|
||||
version: str, subset_name: str, split: Union[str, list],
|
||||
target: str, hub: Hubs, data_dir: str,
|
||||
def __init__(self,
|
||||
dataset_name: Union[str, list],
|
||||
namespace: str,
|
||||
version: str,
|
||||
subset_name: str,
|
||||
split: Union[str, list],
|
||||
target: str,
|
||||
hub: Hubs,
|
||||
data_dir: str,
|
||||
data_files: Union[str, Sequence[str],
|
||||
Mapping[str, Union[str, Sequence[str]]]],
|
||||
download_mode: DownloadMode, cache_root_dir: str,
|
||||
use_streaming: bool, stream_batch_size: int,
|
||||
trust_remote_code: bool, **kwargs):
|
||||
download_mode: DownloadMode,
|
||||
cache_root_dir: str,
|
||||
use_streaming: bool,
|
||||
stream_batch_size: int,
|
||||
trust_remote_code: bool,
|
||||
token: Optional[str] = None,
|
||||
**kwargs):
|
||||
|
||||
self._download_config = None
|
||||
self._data_meta_config = None
|
||||
@@ -32,6 +42,7 @@ class DatasetContextConfig:
|
||||
|
||||
# General arguments for dataset
|
||||
self.hub = hub
|
||||
self.token = token
|
||||
self.download_mode = download_mode
|
||||
self.dataset_name = dataset_name
|
||||
self.namespace = namespace
|
||||
|
||||
@@ -8,7 +8,7 @@ from datasets import (Dataset, DatasetBuilder, DatasetDict, IterableDataset,
|
||||
IterableDatasetDict)
|
||||
from datasets import load_dataset as hf_load_dataset
|
||||
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.msdatasets.auth.auth_config import OssAuthConfig
|
||||
from modelscope.msdatasets.context.dataset_context_config import \
|
||||
DatasetContextConfig
|
||||
@@ -86,7 +86,8 @@ class OssDownloader(BaseDownloader):
|
||||
def _authorize(self) -> None:
|
||||
""" Authorization of target dataset.
|
||||
Get credentials from cache and send to the modelscope-hub in the future. """
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = HubApi().get_cookies(
|
||||
access_token=self.dataset_context_config.token)
|
||||
git_token = ModelScopeConfig.get_token()
|
||||
user_info = ModelScopeConfig.get_user_info()
|
||||
|
||||
@@ -178,7 +179,8 @@ class VirgoDownloader(BaseDownloader):
|
||||
"""Authorization of virgo dataset."""
|
||||
from modelscope.msdatasets.auth.auth_config import VirgoAuthConfig
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = HubApi().get_cookies(
|
||||
access_token=self.dataset_context_config.token)
|
||||
user_info = ModelScopeConfig.get_user_info()
|
||||
|
||||
if not self.dataset_context_config.auth_config:
|
||||
|
||||
@@ -97,7 +97,8 @@ class RemoteDataLoaderManager(DataLoaderManager):
|
||||
|
||||
def __init__(self, dataset_context_config: DatasetContextConfig):
|
||||
super().__init__(dataset_context_config=dataset_context_config)
|
||||
self.api = HubApi()
|
||||
|
||||
self.api = HubApi(token=dataset_context_config.token)
|
||||
|
||||
def load_dataset(self, data_loader_type: enum.Enum):
|
||||
# Get args from context
|
||||
@@ -112,6 +113,7 @@ class RemoteDataLoaderManager(DataLoaderManager):
|
||||
use_streaming = self.dataset_context_config.use_streaming
|
||||
input_config_kwargs = self.dataset_context_config.config_kwargs
|
||||
trust_remote_code = self.dataset_context_config.trust_remote_code
|
||||
token = self.dataset_context_config.token
|
||||
|
||||
# To use the huggingface data loader
|
||||
if data_loader_type == RemoteDataLoaderType.HF_DATA_LOADER:
|
||||
@@ -129,6 +131,7 @@ class RemoteDataLoaderManager(DataLoaderManager):
|
||||
download_mode=download_mode_val,
|
||||
streaming=use_streaming,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token=token,
|
||||
**input_config_kwargs)
|
||||
# download statistics
|
||||
self.api.dataset_download_statistics(
|
||||
|
||||
@@ -24,7 +24,7 @@ from filelock import FileLock
|
||||
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
from modelscope.hub.api import HubApi
|
||||
|
||||
from modelscope import __version__
|
||||
|
||||
@@ -219,7 +219,7 @@ def get_from_cache_ms(
|
||||
etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
|
||||
connected = True
|
||||
try:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = HubApi().get_cookies(access_token=token)
|
||||
response = http_head_ms(
|
||||
url,
|
||||
allow_redirects=True,
|
||||
|
||||
Reference in New Issue
Block a user