# Copyright (c) Alibaba, Inc. and its affiliates.
# yapf: disable

import datetime
import fnmatch
import functools
import io
import os
import pickle
import platform
import re
import shutil
import tempfile
import time
import uuid
import warnings
from collections import defaultdict
from http import HTTPStatus
from http.cookiejar import CookieJar
from os.path import expanduser
from pathlib import Path
from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union
from urllib.parse import urlencode

import json
import requests
from requests import Session
from requests.adapters import HTTPAdapter, Retry
from requests.exceptions import HTTPError
from tqdm.auto import tqdm

from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                      API_HTTP_CLIENT_TIMEOUT,
                                      API_RESPONSE_FIELD_DATA,
                                      API_RESPONSE_FIELD_EMAIL,
                                      API_RESPONSE_FIELD_GIT_ACCESS_TOKEN,
                                      API_RESPONSE_FIELD_MESSAGE,
                                      API_RESPONSE_FIELD_USERNAME,
                                      DEFAULT_MAX_WORKERS,
                                      MODELSCOPE_CLOUD_ENVIRONMENT,
                                      MODELSCOPE_CLOUD_USERNAME,
                                      MODELSCOPE_CREDENTIALS_PATH,
                                      MODELSCOPE_DOMAIN,
                                      MODELSCOPE_PREFER_AI_SITE,
                                      MODELSCOPE_REQUEST_ID,
                                      MODELSCOPE_URL_SCHEME, ONE_YEAR_SECONDS,
                                      REQUESTS_API_HTTP_METHOD,
                                      TEMPORARY_FOLDER_NAME,
                                      UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD,
                                      UPLOAD_COMMIT_BATCH_SIZE,
                                      UPLOAD_MAX_FILE_COUNT,
                                      UPLOAD_MAX_FILE_COUNT_IN_DIR,
                                      UPLOAD_MAX_FILE_SIZE,
                                      UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
                                      UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
                                      DatasetVisibility, Licenses,
                                      ModelVisibility, Visibility,
                                      VisibilityMap)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
                                   NotLoginException, RequestError,
                                   datahub_raise_on_error,
                                   handle_http_post_error,
                                   handle_http_response, is_ok,
                                   raise_for_http_status, raise_on_error)
from modelscope.hub.git import GitCommandWrapper
from modelscope.hub.info import DatasetInfo, ModelInfo
from modelscope.hub.repository import Repository
from modelscope.hub.utils.aigc import AigcModel
from modelscope.hub.utils.utils import (add_content_to_file, get_domain,
                                        get_endpoint, get_readable_folder_size,
                                        get_release_datetime, is_env_true,
                                        model_id_to_group_owner_name)
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                       DEFAULT_MODEL_REVISION,
                                       DEFAULT_REPOSITORY_REVISION,
                                       MASTER_MODEL_BRANCH, META_FILES_FORMAT,
                                       REPO_TYPE_DATASET, REPO_TYPE_MODEL,
                                       REPO_TYPE_SUPPORT, ConfigFields,
                                       DatasetFormations, DatasetMetaFormats,
                                       DownloadChannel, DownloadMode,
                                       Frameworks, ModelFile, Tasks,
                                       VirgoDatasetConfig)
from modelscope.utils.file_utils import get_file_hash, get_file_size
from modelscope.utils.logger import get_logger
from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX,
                                         DEFAULT_IGNORE_PATTERNS,
                                         MODEL_LFS_SUFFIX,
                                         CommitHistoryResponse, CommitInfo,
                                         CommitOperation, CommitOperationAdd,
                                         RepoUtils)
from modelscope.utils.thread_utils import thread_executor

logger = get_logger()

class HubApi:
    """Model hub api interface.
    """

    def __init__(self,
                 endpoint: Optional[str] = None,
                 timeout=API_HTTP_CLIENT_TIMEOUT,
                 max_retries=API_HTTP_CLIENT_MAX_RETRIES):
        """The ModelScope HubApi.

        Args:
            endpoint (str, optional): The modelscope server http|https address. Defaults to None.
            timeout (int, optional): HTTP request timeout. Defaults to API_HTTP_CLIENT_TIMEOUT.
            max_retries (int, optional): Max retries for the HTTP client. Defaults to API_HTTP_CLIENT_MAX_RETRIES.
        """
        self.endpoint = endpoint if endpoint is not None else get_endpoint()
        self.headers = {'user-agent': ModelScopeConfig.get_user_agent()}
        self.session = Session()
        retry = Retry(
            total=max_retries,
            read=2,
            connect=2,
            backoff_factor=1,
            status_forcelist=(500, 502, 503, 504),
            respect_retry_after_header=False,
        )
        adapter = HTTPAdapter(max_retries=retry)
        self.session.mount('http://', adapter)
        self.session.mount('https://', adapter)
        # set http timeout
        for method in REQUESTS_API_HTTP_METHOD:
            setattr(
                self.session, method,
                functools.partial(
                    getattr(self.session, method),
                    timeout=timeout))

        self.upload_checker = UploadingCheck()

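    # Illustrative sketch (not executed as part of this module): constructing the
    # client. The endpoint/timeout/retry values below are placeholders, not defaults.
    #
    #   api = HubApi()                                    # use the default endpoint
    #   api = HubApi(endpoint='https://www.modelscope.cn',
    #                timeout=60, max_retries=3)           # or override explicitly
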
    @staticmethod
    def _get_cookies(access_token: str):
        """
        Get jar cookies for authentication from access_token.

        Args:
            access_token (str): user access token on ModelScope.

        Returns:
            jar (CookieJar): cookies for authentication.
        """
        from requests.cookies import RequestsCookieJar
        jar = RequestsCookieJar()
        jar.set('m_session_id',
                access_token,
                domain=get_domain(),
                path='/')
        return jar

    def get_cookies(self, access_token, cookies_required: Optional[bool] = False):
        """
        Get cookies for authentication from the local cache or access_token.

        Args:
            access_token (str): user access token on ModelScope.
            cookies_required (bool): whether to raise an error if no cookies are found, defaults to `False`.

        Returns:
            cookies (CookieJar): cookies for authentication.

        Raises:
            ValueError: If no credentials are found and cookies_required is True.
        """
        if access_token:
            cookies = self._get_cookies(access_token=access_token)
        else:
            cookies = ModelScopeConfig.get_cookies()

        if cookies is None and cookies_required:
            raise ValueError(
                'No credentials found. '
                'You can pass the `--token` argument, '
                'or use HubApi().login(access_token=`your_sdk_token`). '
                'Your token is available at https://modelscope.cn/my/myaccesstoken'
            )

        return cookies

    def login(
            self,
            access_token: Optional[str] = None,
            endpoint: Optional[str] = None
    ):
        """Login with your SDK access token, which can be obtained from
        the https://www.modelscope.cn user center.

        Args:
            access_token (str): user access token on modelscope; set this argument or set `MODELSCOPE_API_TOKEN`.
                If neither token exists, login returns immediately.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            cookies: to authenticate yourself to ModelScope open-api.
            git_token: token to access your git repository.

        Note:
            You only have to login once within 30 days.
        """
        if access_token is None:
            access_token = os.environ.get('MODELSCOPE_API_TOKEN')
        if not access_token:
            return None, None
        if not endpoint:
            endpoint = self.endpoint
        path = f'{endpoint}/api/v1/login'
        r = self.session.post(
            path,
            json={'AccessToken': access_token},
            headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        d = r.json()
        raise_on_error(d)

        token = d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_GIT_ACCESS_TOKEN]
        cookies = r.cookies

        # save token and cookie
        ModelScopeConfig.save_token(token)
        ModelScopeConfig.save_cookies(cookies)
        ModelScopeConfig.save_user_info(
            d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_USERNAME],
            d[API_RESPONSE_FIELD_DATA][API_RESPONSE_FIELD_EMAIL])

        return d[API_RESPONSE_FIELD_DATA][
            API_RESPONSE_FIELD_GIT_ACCESS_TOKEN], cookies

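    # Illustrative sketch (not executed as part of this module): the token string
    # below is a placeholder obtained from https://modelscope.cn/my/myaccesstoken.
    #
    #   api = HubApi()
    #   git_token, cookies = api.login(access_token='<your_sdk_token>')
    #   # or: export MODELSCOPE_API_TOKEN=<your_sdk_token> and call api.login()
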
    def create_model(self,
                     model_id: str,
                     visibility: Optional[int] = ModelVisibility.PUBLIC,
                     license: Optional[str] = Licenses.APACHE_V2,
                     chinese_name: Optional[str] = None,
                     original_model_id: Optional[str] = '',
                     endpoint: Optional[str] = None,
                     token: Optional[str] = None,
                     aigc_model: Optional['AigcModel'] = None) -> str:
        """Create a model repo at ModelScope Hub.

        Args:
            model_id (str): The model id in format {owner}/{name}.
            visibility (int, optional): visibility of the model (1-private, 5-public), default 5.
            license (str, optional): license of the model, default apache-2.0.
            chinese_name (str, optional): chinese name of the model.
            original_model_id (str, optional): the base model id which this model is trained from.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.
            token (str, optional): access token for authentication.
            aigc_model (AigcModel, optional): AigcModel instance for AIGC model creation.
                If provided, will create an AIGC model with automatic file upload.
                Refer to modelscope.hub.utils.aigc.AigcModel for details.

        Returns:
            str: URL of the created model repository.

        Raises:
            InvalidParameter: If model_id is invalid or required AIGC parameters are missing.
            ValueError: If not logged in.

        Note:
            model_id = {owner}/{name}
        """
        if model_id is None:
            raise InvalidParameter('model_id is required!')
        # Get cookies for authentication.
        cookies = self.get_cookies(access_token=token, cookies_required=True)
        if not endpoint:
            endpoint = self.endpoint

        owner_or_group, name = model_id_to_group_owner_name(model_id)

        # Base body configuration
        body = {
            'Path': owner_or_group,
            'Name': name,
            'ChineseName': chinese_name,
            'Visibility': visibility,
            'License': license,
            'OriginalModelId': original_model_id,
            'TrainId': os.environ.get('MODELSCOPE_TRAIN_ID', '')
        }

        # Set path based on model type
        if aigc_model is not None:
            # Use AIGC model endpoint
            path = f'{endpoint}/api/v1/models/aigc'
            # Best-effort pre-upload of weights so the server recognizes the sha256 (use existing cookies)
            aigc_model.preupload_weights(cookies=cookies, headers=self.builder_headers(self.headers))

            # Add AIGC-specific fields to body
            body.update({
                'TagShowName': aigc_model.revision,
                'CoverImages': aigc_model.cover_images,
                'AigcType': aigc_model.aigc_type,
                'TagDescription': aigc_model.description,
                'VisionFoundation': aigc_model.base_model_type,
                'BaseModel': aigc_model.base_model_id or original_model_id,
                'WeightsName': aigc_model.weight_filename,
                'WeightsSha256': aigc_model.weight_sha256,
                'WeightsSize': aigc_model.weight_size,
                'ModelPath': aigc_model.model_path
            })

        else:
            # Use regular model endpoint
            path = f'{endpoint}/api/v1/models'

        r = self.session.post(
            path,
            json=body,
            cookies=cookies,
            headers=self.builder_headers(self.headers))
        handle_http_post_error(r, path, body)
        raise_on_error(r.json())
        model_repo_url = f'{endpoint}/models/{model_id}'

        # TODO: to be aligned with the new api
        # Upload model files for AIGC models
        # if aigc_model is not None:
        #     aigc_model.upload_to_repo(self, model_id, token)

        return model_repo_url

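    # Illustrative sketch (not executed as part of this module): 'your-namespace/your-model'
    # is a placeholder repo id; the visibility/license constants come from modelscope.hub.constants.
    #
    #   api = HubApi()
    #   api.login(access_token='<your_sdk_token>')
    #   url = api.create_model('your-namespace/your-model',
    #                          visibility=ModelVisibility.PRIVATE,
    #                          license=Licenses.APACHE_V2)
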
    def delete_model(self, model_id: str, endpoint: Optional[str] = None):
        """Delete model_id from ModelScope.

        Args:
            model_id (str): The model id.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Raises:
            ValueError: If not logged in.

        Note:
            model_id = {owner}/{name}
        """
        cookies = ModelScopeConfig.get_cookies()
        if not endpoint:
            endpoint = self.endpoint
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        path = f'{endpoint}/api/v1/models/{model_id}'

        r = self.session.delete(path,
                                cookies=cookies,
                                headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        raise_on_error(r.json())

    def get_model_url(self, model_id: str, endpoint: Optional[str] = None):
        if not endpoint:
            endpoint = self.endpoint
        return f'{endpoint}/api/v1/models/{model_id}.git'

    def get_model(
            self,
            model_id: str,
            revision: Optional[str] = DEFAULT_MODEL_REVISION,
            endpoint: Optional[str] = None
    ) -> str:
        """Get model information from ModelScope.

        Args:
            model_id (str): The model id.
            revision (str, optional): revision of the model.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            The model detail information.

        Raises:
            NotExistError: If the model does not exist, a NotExistError is raised.

        Note:
            model_id = {owner}/{name}
        """
        cookies = ModelScopeConfig.get_cookies()
        owner_or_group, name = model_id_to_group_owner_name(model_id)
        if not endpoint:
            endpoint = self.endpoint

        if revision:
            path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}?Revision={revision}'
        else:
            path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}'

        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, model_id)
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
                return r.json()[API_RESPONSE_FIELD_DATA]
            else:
                raise NotExistError(r.json()[API_RESPONSE_FIELD_MESSAGE])
        else:
            raise_for_http_status(r)

    def get_endpoint_for_read(self,
                              repo_id: str,
                              *,
                              repo_type: Optional[str] = None) -> str:
        """Get the proper endpoint for a read operation (such as download, list, etc.).

        1. If the user has set MODELSCOPE_DOMAIN, construct the endpoint with the user-specified domain.
           If the repo does not exist on that endpoint, throw a 404 error, otherwise return the endpoint.
        2. If the domain is not set, check the existence of the repo on the cn-site and the ai-site
           (intl version) respectively. The checking order is determined by MODELSCOPE_PREFER_AI_SITE.
           a. if MODELSCOPE_PREFER_AI_SITE is not set, check the cn-site before the ai-site (intl version)
           b. otherwise check the ai-site before the cn-site
           Return the endpoint on which the given repo_id exists.
           If it exists on neither, throw a 404 error.
        """
        s = os.environ.get(MODELSCOPE_DOMAIN)
        if s is not None and s.strip() != '':
            endpoint = MODELSCOPE_URL_SCHEME + s
            try:
                self.repo_exists(repo_id=repo_id, repo_type=repo_type, endpoint=endpoint, re_raise=True)
            except Exception:
                logger.error(f'Repo {repo_id} does not exist on {endpoint}.')
                raise
            return endpoint

        check_cn_first = not is_env_true(MODELSCOPE_PREFER_AI_SITE)
        prefer_endpoint = get_endpoint(cn_site=check_cn_first)
        if not self.repo_exists(
                repo_id, repo_type=repo_type, endpoint=prefer_endpoint):
            alternative_endpoint = get_endpoint(cn_site=(not check_cn_first))
            logger.warning(f'Repo {repo_id} does not exist on {prefer_endpoint}, '
                           f'will try the alternative endpoint {alternative_endpoint}.')
            try:
                self.repo_exists(
                    repo_id, repo_type=repo_type, endpoint=alternative_endpoint, re_raise=True)
            except Exception:
                logger.error(f'Repo {repo_id} does not exist on either {prefer_endpoint} or {alternative_endpoint}')
                raise
            else:
                return alternative_endpoint
        else:
            return prefer_endpoint

    def model_info(self,
                   repo_id: str,
                   *,
                   revision: Optional[str] = DEFAULT_MODEL_REVISION,
                   endpoint: Optional[str] = None) -> ModelInfo:
        """Get model information including commit history.

        Args:
            repo_id (str): The model id in the format of
                ``namespace/model_name``.
            revision (str, optional): Specific revision of the model.
                Defaults to ``DEFAULT_MODEL_REVISION``.
            endpoint (str, optional): Hub endpoint to use. When ``None``,
                use the endpoint specified when initializing :class:`HubApi`.

        Returns:
            ModelInfo: The model detailed information returned by
                ModelScope Hub with commit history.
        """
        owner_or_group, _ = model_id_to_group_owner_name(repo_id)
        model_data = self.get_model(
            model_id=repo_id, revision=revision, endpoint=endpoint)
        commits = self.list_repo_commits(
            repo_id=repo_id, repo_type=REPO_TYPE_MODEL, revision=revision, endpoint=endpoint)

        # Create ModelInfo from API response data
        model_info = ModelInfo(**model_data, commits=commits, author=owner_or_group)

        return model_info

    def dataset_info(self,
                     repo_id: str,
                     *,
                     revision: Optional[str] = None,
                     endpoint: Optional[str] = None) -> DatasetInfo:
        """Get dataset information including commit history.

        Args:
            repo_id (str): The dataset id in the format of
                ``namespace/dataset_name``.
            revision (str, optional): Specific revision of the dataset.
                Defaults to ``None``.
            endpoint (str, optional): Hub endpoint to use. When ``None``,
                use the endpoint specified when initializing :class:`HubApi`.

        Returns:
            DatasetInfo: The dataset detailed information returned by
                ModelScope Hub with commit history.
        """
        owner_or_group, _ = model_id_to_group_owner_name(repo_id)
        dataset_data = self.get_dataset(
            dataset_id=repo_id, revision=revision, endpoint=endpoint)
        commits = self.list_repo_commits(
            repo_id=repo_id, repo_type=REPO_TYPE_DATASET, revision=revision, endpoint=endpoint)

        # Create DatasetInfo from API response data
        dataset_info = DatasetInfo(**dataset_data, commits=commits, author=owner_or_group)

        return dataset_info

    def repo_info(
            self,
            repo_id: str,
            *,
            repo_type: Optional[str] = REPO_TYPE_MODEL,
            revision: Optional[str] = DEFAULT_MODEL_REVISION,
            endpoint: Optional[str] = None
    ) -> Union[ModelInfo, DatasetInfo]:
        """Get repository information for models or datasets.

        Args:
            repo_id (str): The repository id in the format of
                ``namespace/repo_name``.
            revision (str, optional): Specific revision of the repository.
                Currently only effective for model repositories. Defaults to
                ``DEFAULT_MODEL_REVISION``.
            repo_type (str, optional): Type of the repository. Supported
                values are ``"model"`` and ``"dataset"``. If not provided,
                ``"model"`` is assumed.
            endpoint (str, optional): Hub endpoint to use. When ``None``,
                use the endpoint specified when initializing :class:`HubApi`.

        Returns:
            Union[ModelInfo, DatasetInfo]: The repository detailed information
                returned by ModelScope Hub.
        """
        if repo_type is None or repo_type == REPO_TYPE_MODEL:
            return self.model_info(repo_id=repo_id, revision=revision, endpoint=endpoint)

        if repo_type == REPO_TYPE_DATASET:
            return self.dataset_info(repo_id=repo_id, revision=revision, endpoint=endpoint)

        raise InvalidParameter(
            f'Arg repo_type {repo_type} not supported. Please choose from {REPO_TYPE_SUPPORT}.')

    def repo_exists(
            self,
            repo_id: str,
            *,
            repo_type: Optional[str] = None,
            endpoint: Optional[str] = None,
            re_raise: Optional[bool] = False,
            token: Optional[str] = None
    ) -> bool:
        """
        Checks if a repository exists on ModelScope.

        Args:
            repo_id (`str`):
                A namespace (user or an organization) and a repo name separated
                by a `/`.
            repo_type (`str`, *optional*):
                `None` or `"model"` if getting repository info from a model. Default is `None`.
                TODO: support studio
            endpoint(`str`):
                None or a specific endpoint to use; when None, use the default endpoint
                set in the HubApi class (self.endpoint).
            re_raise(`bool`):
                raise an exception on error.
            token (`str`, *optional*): access token to use for checking existence.
        Returns:
            True if the repository exists, False otherwise.
        """
        if endpoint is None:
            endpoint = self.endpoint
        if (repo_type is not None) and repo_type.lower() not in REPO_TYPE_SUPPORT:
            raise Exception('Unsupported repo_type: %s' % repo_type)
        if (repo_id is None) or repo_id.count('/') != 1:
            raise Exception('Invalid repo_id: %s, must be of format namespace/name' % repo_id)

        cookies = self.get_cookies(access_token=token, cookies_required=False)
        owner_or_group, name = model_id_to_group_owner_name(repo_id)
        if (repo_type is not None) and repo_type.lower() == REPO_TYPE_DATASET:
            path = f'{endpoint}/api/v1/datasets/{owner_or_group}/{name}'
        else:
            path = f'{endpoint}/api/v1/models/{owner_or_group}/{name}'

        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        code = handle_http_response(r, logger, cookies, repo_id, False)
        if code == 200:
            return True
        elif code == 404:
            if re_raise:
                raise HTTPError(r)
            else:
                return False
        else:
            logger.warning(f'Check repo_exists returned status code {code}.')
            raise Exception(
                'Failed to check existence of repo: %s, make sure you have access authorization.'
                % repo_id)

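    # Illustrative sketch (not executed as part of this module): the repo ids below
    # are placeholders; REPO_TYPE_DATASET comes from modelscope.utils.constant.
    #
    #   api = HubApi()
    #   if api.repo_exists('your-namespace/your-model'):
    #       ...
    #   api.repo_exists('your-namespace/your-dataset', repo_type=REPO_TYPE_DATASET)
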
    def delete_repo(self, repo_id: str, repo_type: str, endpoint: Optional[str] = None):
        """
        Delete a repository from ModelScope.

        Args:
            repo_id (`str`):
                A namespace (user or an organization) and a repo name separated
                by a `/`.
            repo_type (`str`):
                The type of the repository. Supported types are `model` and `dataset`.
            endpoint(`str`):
                The endpoint to use. If not provided, the default endpoint is `https://www.modelscope.cn`.
                Could be set to `https://ai.modelscope.ai` for the international version.
        """

        if not endpoint:
            endpoint = self.endpoint

        if repo_type == REPO_TYPE_DATASET:
            self.delete_dataset(repo_id, endpoint)
        elif repo_type == REPO_TYPE_MODEL:
            self.delete_model(repo_id, endpoint)
        else:
            raise Exception(f'Arg repo_type {repo_type} not supported.')

        logger.info(f'Repo {repo_id} deleted successfully.')

    @staticmethod
    def _create_default_config(model_dir):
        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
        cfg = {
            ConfigFields.framework: Frameworks.torch,
            ConfigFields.task: Tasks.other,
        }
        with open(cfg_file, 'w') as file:
            json.dump(cfg, file)

    def push_model(self,
                   model_id: str,
                   model_dir: str,
                   visibility: Optional[int] = ModelVisibility.PUBLIC,
                   license: Optional[str] = Licenses.APACHE_V2,
                   chinese_name: Optional[str] = None,
                   commit_message: Optional[str] = 'upload model',
                   tag: Optional[str] = None,
                   revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                   original_model_id: Optional[str] = None,
                   ignore_file_pattern: Optional[Union[List[str], str]] = None,
                   lfs_suffix: Optional[Union[str, List[str]]] = None):
        """Upload a model from a given directory to a given repository. A valid model directory
        must contain a configuration.json file.

        This function uploads the files in the given directory to the given repository. If the
        given repository does not exist on the remote, it will automatically be created with the
        given visibility, license and chinese_name parameters. If the revision also does not
        exist in the remote repository, a new branch will be created for it.

        This function must be called after HubApi's login has been called with a valid token,
        which can be obtained from ModelScope's website.

        If any error occurs, please upload via git commands.

        Args:
            model_id (str):
                The model id to be uploaded; the caller must have write permission for it.
            model_dir(str):
                The absolute path of the finetune result.
            visibility(int, optional):
                Visibility of the newly created model (1-private, 5-public). If the model does
                not exist on ModelScope, this function will create a new model with this
                visibility and this parameter is required. You can ignore this parameter
                if you are sure the model exists.
            license(`str`, defaults to `None`):
                License of the newly created model (see License). If the model does not exist
                on ModelScope, this function will create a new model with this license
                and this parameter is required. You can ignore this parameter if you
                are sure the model exists.
            chinese_name(`str`, *optional*, defaults to `None`):
                chinese name of the newly created model.
            commit_message(`str`, *optional*, defaults to `None`):
                commit message of the push request.
            tag(`str`, *optional*, defaults to `None`):
                The tag on this commit.
            revision (`str`, *optional*, default to DEFAULT_MODEL_REVISION):
                which branch to push. If the branch does not exist, a new
                branch will be created and pushed to.
            original_model_id (str, optional): The base model id which this model is trained from.
            ignore_file_pattern (`Union[List[str], str]`, optional): The file pattern to ignore uploading.
            lfs_suffix (`List[str]`, optional): File types to use LFS to manage. Example: '*.safetensors'.

        Raises:
            InvalidParameter: Parameter invalid.
            NotLoginException: Not logged in.
            ValueError: No configuration.json.
            Exception: Creation failed.
        """
        warnings.warn(
            'This function is deprecated and will be removed in future versions. '
            'Please use git command directly or use HubApi().upload_folder instead',
            DeprecationWarning,
            stacklevel=2
        )
        if model_id is None:
            raise InvalidParameter('model_id cannot be empty!')
        if model_dir is None:
            raise InvalidParameter('model_dir cannot be empty!')
        if not os.path.exists(model_dir) or os.path.isfile(model_dir):
            raise InvalidParameter('model_dir must be a valid directory.')
        cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
        if not os.path.exists(cfg_file):
            logger.warning(
                f'No {ModelFile.CONFIGURATION} file found in {model_dir}, creating a default one.')
            HubApi._create_default_config(model_dir)

        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise NotLoginException('Must login before upload!')
        files_to_save = os.listdir(model_dir)
        folder_size = get_readable_folder_size(model_dir)
        if ignore_file_pattern is None:
            ignore_file_pattern = []
        if isinstance(ignore_file_pattern, str):
            ignore_file_pattern = [ignore_file_pattern]
        if visibility is None or license is None:
            raise InvalidParameter('Visibility and License cannot be empty for a new model.')
        if not self.repo_exists(model_id):
            logger.info('Creating new model [%s]' % model_id)
            self.create_model(
                model_id=model_id,
                visibility=visibility,
                license=license,
                chinese_name=chinese_name,
                original_model_id=original_model_id)
        tmp_dir = os.path.join(model_dir, TEMPORARY_FOLDER_NAME)  # make temporary folder
        git_wrapper = GitCommandWrapper()
        logger.info(f'Pushing folder {model_dir} as model {model_id}.')
        logger.info(f'Total folder size {folder_size}, this may take a while depending on actual pushing size...')
        try:
            repo = Repository(model_dir=tmp_dir, clone_from=model_id)
            branches = git_wrapper.get_remote_branches(tmp_dir)
            if revision not in branches:
                logger.info('Creating new branch %s' % revision)
                git_wrapper.new_branch(tmp_dir, revision)
            git_wrapper.checkout(tmp_dir, revision)
            files_in_repo = os.listdir(tmp_dir)
            for f in files_in_repo:
                if f[0] != '.':
                    src = os.path.join(tmp_dir, f)
                    if os.path.isfile(src):
                        os.remove(src)
                    else:
                        shutil.rmtree(src, ignore_errors=True)
            for f in files_to_save:
                if f[0] != '.':
                    if any([re.search(pattern, f) is not None for pattern in ignore_file_pattern]):
                        continue
                    src = os.path.join(model_dir, f)
                    if os.path.isdir(src):
                        shutil.copytree(src, os.path.join(tmp_dir, f))
                    else:
                        shutil.copy(src, tmp_dir)
            if not commit_message:
                date = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
                commit_message = '[automsg] push model %s to hub at %s' % (
                    model_id, date)
            if lfs_suffix is not None:
                lfs_suffix_list = [lfs_suffix] if isinstance(lfs_suffix, str) else lfs_suffix
                for suffix in lfs_suffix_list:
                    repo.add_lfs_type(suffix)
            repo.push(
                commit_message=commit_message,
                local_branch=revision,
                remote_branch=revision)
            if tag is not None:
                repo.tag_and_push(tag, tag)
            logger.info(f'Successfully pushed folder {model_dir} to remote repo [{model_id}].')
        except Exception:
            raise
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

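    # Illustrative sketch (not executed as part of this module): push_model is
    # deprecated and its deprecation notice points to HubApi().upload_folder. The
    # call below is a sketch only — upload_folder is defined elsewhere in this class
    # and its exact signature is not shown here; repo id and path are placeholders.
    #
    #   api = HubApi()
    #   api.login(access_token='<your_sdk_token>')
    #   api.upload_folder(repo_id='your-namespace/your-model',
    #                     folder_path='/path/to/local/model_dir')
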
    def list_models(self,
                    owner_or_group: str,
                    page_number: Optional[int] = 1,
                    page_size: Optional[int] = 10,
                    endpoint: Optional[str] = None) -> dict:
        """List models of an owner or group.

        Args:
            owner_or_group(str): owner or group.
            page_number(int, optional): The page number, default: 1
            page_size(int, optional): The page size, default: 10
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Raises:
            RequestError: The request error.

        Returns:
            dict: {"models": "list of models", "TotalCount": total_number_of_models_in_owner_or_group}
        """
        cookies = ModelScopeConfig.get_cookies()
        if not endpoint:
            endpoint = self.endpoint
        path = f'{endpoint}/api/v1/models/'
        r = self.session.put(
            path,
            data='{"Path":"%s", "PageNumber":%s, "PageSize": %s}' %
            (owner_or_group, page_number, page_size),
            cookies=cookies,
            headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, owner_or_group)
        if r.status_code == HTTPStatus.OK:
            if is_ok(r.json()):
                data = r.json()[API_RESPONSE_FIELD_DATA]
                return data
            else:
                raise RequestError(r.json()[API_RESPONSE_FIELD_MESSAGE])
        else:
            raise_for_http_status(r)
        return None

    def _check_cookie(self, use_cookies: Union[bool, CookieJar] = False) -> CookieJar:  # noqa
        cookies = None
        if isinstance(use_cookies, CookieJar):
            cookies = use_cookies
        elif use_cookies:
            cookies = ModelScopeConfig.get_cookies()
            if cookies is None:
                raise ValueError('Token does not exist, please login first.')
        return cookies

    def list_model_revisions(
            self,
            model_id: str,
            cutoff_timestamp: Optional[int] = None,
            use_cookies: Union[bool, CookieJar] = False) -> List[str]:
        """Get model tags (revisions).

        Args:
            model_id (str): The model id.
            cutoff_timestamp (int): Tags created before the cutoff will be included.
                The timestamp is represented by the seconds elapsed from the epoch time.
            use_cookies (Union[bool, CookieJar], optional): If a CookieJar, this cookie will be used;
                if True, the cookie will be loaded from local. Defaults to False.

        Returns:
            List[str]: A list of tag (revision) names.
        """
        tags_details = self.list_model_revisions_detail(model_id=model_id,
                                                        cutoff_timestamp=cutoff_timestamp,
                                                        use_cookies=use_cookies)
        tags = [x['Revision'] for x in tags_details
                ] if tags_details else []
        return tags

    def list_model_revisions_detail(
            self,
            model_id: str,
            cutoff_timestamp: Optional[int] = None,
            use_cookies: Union[bool, CookieJar] = False,
            endpoint: Optional[str] = None) -> List[dict]:
        """Get detailed model tag (revision) information.

        Args:
            model_id (str): The model id.
            cutoff_timestamp (int): Tags created before the cutoff will be included.
                The timestamp is represented by the seconds elapsed from the epoch time.
            use_cookies (Union[bool, CookieJar], optional): If a CookieJar, this cookie will be used;
                if True, the cookie will be loaded from local. Defaults to False.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            List[dict]: A list of tag details.
        """
        cookies = self._check_cookie(use_cookies)
        if cutoff_timestamp is None:
            cutoff_timestamp = get_release_datetime()
        if not endpoint:
            endpoint = self.endpoint
        path = f'{endpoint}/api/v1/models/{model_id}/revisions?EndTime=%s' % cutoff_timestamp
        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, model_id)
        d = r.json()
        raise_on_error(d)
        info = d[API_RESPONSE_FIELD_DATA]
        # tags returned from backend are guaranteed to be ordered by create-time
        return info['RevisionMap']['Tags']

    def get_branch_tag_detail(self, details, name):
        for item in details:
            if item['Revision'] == name:
                return item
        return None

    def get_valid_revision_detail(self,
                                  model_id: str,
                                  revision=None,
                                  cookies: Optional[CookieJar] = None,
                                  endpoint: Optional[str] = None):
        if not endpoint:
            endpoint = self.endpoint
        release_timestamp = get_release_datetime()
        current_timestamp = int(round(datetime.datetime.now().timestamp()))
        # for active development in library codes (non-release-branches), release_timestamp
        # is set to be a far-away-time-in-the-future, to ensure that we shall
        # get the master-HEAD version from model repo by default (when no revision is provided)
        all_branches_detail, all_tags_detail = self.get_model_branches_and_tags_details(
            model_id, use_cookies=False if cookies is None else cookies, endpoint=endpoint)
        all_branches = [x['Revision'] for x in all_branches_detail] if all_branches_detail else []
        all_tags = [x['Revision'] for x in all_tags_detail] if all_tags_detail else []
        if release_timestamp > current_timestamp + ONE_YEAR_SECONDS:
            if revision is None:
                revision = MASTER_MODEL_BRANCH
                logger.info(
                    'Model revision not specified, using default [%s] version.'
                    % revision)
            if revision not in all_branches and revision not in all_tags:
                raise NotExistError('The model: %s has no revision: %s .' % (model_id, revision))

            revision_detail = self.get_branch_tag_detail(all_tags_detail, revision)
            if revision_detail is None:
                revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
            logger.debug('Development mode use revision: %s' % revision)
        else:
            if revision is not None and revision in all_branches:
                revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
                return revision_detail

            if len(all_tags_detail) == 0:  # no tags exist, use master as the default revision.
                if revision is None or revision == MASTER_MODEL_BRANCH:
                    revision = MASTER_MODEL_BRANCH
                else:
                    raise NotExistError('The model: %s has no revision: %s !' % (model_id, revision))
                revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
            else:
                if revision is None:  # user did not specify a revision, use the latest revision before release time
                    revisions_detail = [x for x in
                                        all_tags_detail if
                                        x['CreatedAt'] <= release_timestamp] if all_tags_detail else []  # noqa E501
                    if len(revisions_detail) > 0:
                        revision = revisions_detail[0]['Revision']  # use latest revision before release time.
                        revision_detail = revisions_detail[0]
                    else:
                        revision = MASTER_MODEL_BRANCH
                        revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
                    vl = '[%s]' % ','.join(all_tags)
                    logger.warning('Model revision should be specified from revisions: %s' % (vl))
                    logger.warning('Model revision not specified, use revision: %s' % revision)
                else:
                    # use user-specified revision
                    if revision not in all_tags:
                        if revision == MASTER_MODEL_BRANCH:
                            logger.warning('Using the master branch is fragile, please use it with caution!')
                            revision_detail = self.get_branch_tag_detail(all_branches_detail, revision)
                        else:
                            vl = '[%s]' % ','.join(all_tags)
                            raise NotExistError('The model: %s has no revision: %s valid are: %s!' %
                                                (model_id, revision, vl))
                    else:
                        revision_detail = self.get_branch_tag_detail(all_tags_detail, revision)
                    logger.info('Use user-specified model revision: %s' % revision)
        return revision_detail

    def get_valid_revision(self,
                           model_id: str,
                           revision=None,
                           cookies: Optional[CookieJar] = None,
                           endpoint: Optional[str] = None):
        return self.get_valid_revision_detail(model_id=model_id,
                                              revision=revision,
                                              cookies=cookies,
                                              endpoint=endpoint)['Revision']

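    # Illustrative sketch (not executed as part of this module): resolve which
    # revision a download would use. The repo id and tag below are placeholders.
    #
    #   api = HubApi()
    #   revision = api.get_valid_revision('your-namespace/your-model')     # latest usable tag or master
    #   detail = api.get_valid_revision_detail('your-namespace/your-model',
    #                                          revision='v1.0.0')          # full tag/branch record
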
    def get_model_branches_and_tags_details(
            self,
            model_id: str,
            use_cookies: Union[bool, CookieJar] = False,
            endpoint: Optional[str] = None
    ) -> Tuple[List[str], List[str]]:
        """Get model branches and tags.

        Args:
            model_id (str): The model id.
            use_cookies (Union[bool, CookieJar], optional): If a CookieJar, this cookie will be used;
                if True, the cookie will be loaded from local. Defaults to False.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            Tuple[List[str], List[str]]: A tuple of branch details and tag details.
        """
        cookies = self._check_cookie(use_cookies)
        if not endpoint:
            endpoint = self.endpoint
        path = f'{endpoint}/api/v1/models/{model_id}/revisions'
        r = self.session.get(path, cookies=cookies,
                             headers=self.builder_headers(self.headers))
        handle_http_response(r, logger, cookies, model_id)
        d = r.json()
        raise_on_error(d)
        info = d[API_RESPONSE_FIELD_DATA]
        return info['RevisionMap']['Branches'], info['RevisionMap']['Tags']

    def get_model_branches_and_tags(
            self,
            model_id: str,
            use_cookies: Union[bool, CookieJar] = False,
    ) -> Tuple[List[str], List[str]]:
        """Get model branches and tags.

        Args:
            model_id (str): The model id.
            use_cookies (Union[bool, CookieJar], optional): If a CookieJar, this cookie will be used;
                if True, the cookie will be loaded from local. Defaults to False.

        Returns:
            Tuple[List[str], List[str]]: A tuple of branch names and tag names.
        """
        branches_detail, tags_detail = self.get_model_branches_and_tags_details(model_id=model_id,
                                                                                use_cookies=use_cookies)
        branches = [x['Revision'] for x in branches_detail
                    ] if branches_detail else []
        tags = [x['Revision'] for x in tags_detail
                ] if tags_detail else []
        return branches, tags

    def get_model_files(self,
                        model_id: str,
                        revision: Optional[str] = DEFAULT_MODEL_REVISION,
                        root: Optional[str] = None,
                        recursive: Optional[bool] = False,
                        use_cookies: Union[bool, CookieJar] = False,
                        headers: Optional[dict] = {},
                        endpoint: Optional[str] = None) -> List[dict]:
        """List the model's files.

        Args:
            model_id (str): The model id.
            revision (Optional[str], optional): The branch or tag name.
            root (Optional[str], optional): The root path. Defaults to None.
            recursive (Optional[bool], optional): Whether to list files recursively. Defaults to False.
            use_cookies (Union[bool, CookieJar], optional): If a CookieJar, this cookie will be used;
                if True, the cookie will be loaded from local. Defaults to False.
            headers: request headers.
            endpoint: the endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            List[dict]: Model file list.
        """
        if not endpoint:
            endpoint = self.endpoint
        if revision:
            path = '%s/api/v1/models/%s/repo/files?Revision=%s&Recursive=%s' % (
                endpoint, model_id, revision, recursive)
        else:
            path = '%s/api/v1/models/%s/repo/files?Recursive=%s' % (
                endpoint, model_id, recursive)
        cookies = self._check_cookie(use_cookies)
        if root is not None:
            path = path + f'&Root={root}'
        headers = self.headers if headers is None else headers
        headers['X-Request-ID'] = str(uuid.uuid4().hex)
        r = self.session.get(
            path, cookies=cookies, headers=headers)

        handle_http_response(r, logger, cookies, model_id)
        d = r.json()
        raise_on_error(d)

        files = []
        if not d[API_RESPONSE_FIELD_DATA]['Files']:
            logger.warning(f'No files found in model {model_id} at revision {revision}.')
            return files
        for file in d[API_RESPONSE_FIELD_DATA]['Files']:
            if file['Name'] == '.gitignore' or file['Name'] == '.gitattributes':
                continue

            files.append(file)
        return files

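    # Illustrative sketch (not executed as part of this module): list files of a
    # placeholder repo id; 'Path' is used elsewhere in this class, 'Size' is assumed
    # to be present in the file records and is read defensively here.
    #
    #   api = HubApi()
    #   for f in api.get_model_files('your-namespace/your-model', recursive=True):
    #       print(f['Path'], f.get('Size'))
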
    def file_exists(
            self,
            repo_id: str,
            filename: str,
            *,
            revision: Optional[str] = None,
    ):
        """Check whether the specified file exists.

        Args:
            repo_id (`str`): The repo id to use.
            filename (`str`): The queried filename; if the file exists in a sub folder,
                please pass <sub-folder-name>/<file-name>.
            revision (`Optional[str]`): The repo revision.
        Returns:
            The query result as a bool value.
        """
        cookies = ModelScopeConfig.get_cookies()
        files = self.get_model_files(
            repo_id,
            recursive=True,
            revision=revision,
            use_cookies=False if cookies is None else cookies,
        )
        files = [file['Path'] for file in files]
        return filename in files

    def create_dataset(self,
                       dataset_name: str,
                       namespace: str,
                       chinese_name: Optional[str] = '',
                       license: Optional[str] = Licenses.APACHE_V2,
                       visibility: Optional[int] = DatasetVisibility.PUBLIC,
                       description: Optional[str] = '',
                       endpoint: Optional[str] = None) -> str:

        if dataset_name is None or namespace is None:
            raise InvalidParameter('dataset_name and namespace are required!')

        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        if not endpoint:
            endpoint = self.endpoint
        path = f'{endpoint}/api/v1/datasets'
        files = {
            'Name': (None, dataset_name),
            'ChineseName': (None, chinese_name),
            'Owner': (None, namespace),
            'License': (None, license),
            'Visibility': (None, visibility),
            'Description': (None, description)
        }

        r = self.session.post(
            path,
            files=files,
            cookies=cookies,
            headers=self.builder_headers(self.headers),
        )

        handle_http_post_error(r, path, files)
        raise_on_error(r.json())
        dataset_repo_url = f'{endpoint}/datasets/{namespace}/{dataset_name}'
        logger.info(f'Create dataset success: {dataset_repo_url}')
        return dataset_repo_url

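    # Illustrative sketch (not executed as part of this module): the namespace and
    # dataset name below are placeholders.
    #
    #   api = HubApi()
    #   api.login(access_token='<your_sdk_token>')
    #   url = api.create_dataset('your-dataset', namespace='your-namespace',
    #                            visibility=DatasetVisibility.PRIVATE)
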
    def list_datasets(self, endpoint: Optional[str] = None):
        if not endpoint:
            endpoint = self.endpoint
        path = f'{endpoint}/api/v1/datasets'
        params = {}
        r = self.session.get(path, params=params,
                             headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        dataset_list = r.json()[API_RESPONSE_FIELD_DATA]
        return [x['Name'] for x in dataset_list]

    def delete_dataset(self, dataset_id: str, endpoint: Optional[str] = None):

        cookies = ModelScopeConfig.get_cookies()
        if not endpoint:
            endpoint = self.endpoint
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')

        path = f'{endpoint}/api/v1/datasets/{dataset_id}'
        r = self.session.delete(path,
                                cookies=cookies,
                                headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        raise_on_error(r.json())

    def get_dataset_id_and_type(self, dataset_name: str, namespace: str, endpoint: Optional[str] = None):
        """ Get the dataset id and type. """
        if not endpoint:
            endpoint = self.endpoint
        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
        cookies = ModelScopeConfig.get_cookies()
        r = self.session.get(datahub_url, cookies=cookies)
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp, r)
        dataset_id = resp['Data']['Id']
        dataset_type = resp['Data']['Type']
        return dataset_id, dataset_type

    def list_repo_tree(self,
                       dataset_name: str,
                       namespace: str,
                       revision: str,
                       root_path: str,
                       recursive: bool = True,
                       page_number: int = 1,
                       page_size: int = 100,
                       endpoint: Optional[str] = None):
        """
        @deprecated: Use `get_dataset_files` instead.
        """
        warnings.warn('The function `list_repo_tree` is deprecated, use `get_dataset_files` instead.',
                      DeprecationWarning)

        dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
            dataset_name=dataset_name, namespace=namespace, endpoint=endpoint)

        recursive = 'True' if recursive else 'False'
        if not endpoint:
            endpoint = self.endpoint
        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
        params = {'Revision': revision if revision else 'master',
                  'Root': root_path if root_path else '/', 'Recursive': recursive,
                  'PageNumber': page_number, 'PageSize': page_size}
        cookies = ModelScopeConfig.get_cookies()

        r = self.session.get(datahub_url, params=params, cookies=cookies)
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp, r)

        return resp

    def list_repo_commits(self,
                          repo_id: str,
                          *,
                          repo_type: Optional[str] = REPO_TYPE_MODEL,
                          revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                          page_number: int = 1,
                          page_size: int = 50,
                          endpoint: Optional[str] = None):
        """
        Get the commit history for a repository.

        Args:
            repo_id (str): The repository id, in the format of `namespace/repo_name`.
            repo_type (Optional[str]): The type of the repository. Supported types are `model` and `dataset`.
            revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
            page_number (int): The page number for pagination. Defaults to 1.
            page_size (int): The number of commits per page. Defaults to 50.
            endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            CommitHistoryResponse: The commit history response.

        Examples:
            >>> from modelscope.hub.api import HubApi
            >>> api = HubApi()
            >>> commit_history = api.list_repo_commits('meituan/Meeseeks')
            >>> print(f"Total commits: {commit_history.total_count}")
            >>> for commit in commit_history.commits:
            ...     print(f"{commit.short_id}: {commit.title}")
        """
        from datasets.utils.file_utils import is_relative_path

        if is_relative_path(repo_id) and repo_id.count('/') == 1:
            _owner, _dataset_name = repo_id.split('/')
        else:
            raise ValueError(f'Invalid repo_id: {repo_id} !')

        if not endpoint:
            endpoint = self.endpoint

        commits_url = f'{endpoint}/api/v1/{repo_type}s/{repo_id}/commits' if repo_type else \
            f'{endpoint}/api/v1/models/{repo_id}/commits'
        params = {
            'Ref': revision or DEFAULT_MODEL_REVISION or DEFAULT_REPOSITORY_REVISION,
            'PageNumber': page_number,
            'PageSize': page_size
        }
        cookies = ModelScopeConfig.get_cookies()

        try:
            r = self.session.get(commits_url, params=params,
                                 cookies=cookies, headers=self.builder_headers(self.headers))
            raise_for_http_status(r)
            resp = r.json()
            raise_on_error(resp)

            if resp.get('Code') == HTTPStatus.OK:
                return CommitHistoryResponse.from_api_response(resp)

        except requests.exceptions.RequestException as e:
            raise Exception(f'Failed to get repository commits for {repo_id}: {str(e)}')

    def get_dataset_files(self,
                          repo_id: str,
                          *,
                          revision: str = DEFAULT_REPOSITORY_REVISION,
                          root_path: str = '/',
                          recursive: bool = True,
                          page_number: int = 1,
                          page_size: int = 100,
                          endpoint: Optional[str] = None):
        """
        Get the dataset files.

        Args:
            repo_id (str): The repository id, in the format of `namespace/dataset_name`.
            revision (str): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
            root_path (str): The root path to list. Defaults to '/'.
            recursive (bool): Whether to list recursively. Defaults to True.
            page_number (int): The page number for pagination. Defaults to 1.
            page_size (int): The number of items per page. Defaults to 100.
            endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            List: The response containing the dataset repository tree information.
                e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...]
        """
        from datasets.utils.file_utils import is_relative_path

        if is_relative_path(repo_id) and repo_id.count('/') == 1:
            _owner, _dataset_name = repo_id.split('/')
        else:
            raise ValueError(f'Invalid repo_id: {repo_id} !')

        dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
            dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)

        if not endpoint:
            endpoint = self.endpoint
        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
        params = {
            'Revision': revision,
            'Root': root_path,
            'Recursive': 'True' if recursive else 'False',
            'PageNumber': page_number,
            'PageSize': page_size
        }
        cookies = ModelScopeConfig.get_cookies()

        r = self.session.get(datahub_url, params=params, cookies=cookies)
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp, r)

        return resp['Data']['Files']

    def get_dataset(
            self,
            dataset_id: str,
            revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
            endpoint: Optional[str] = None
    ):
        """
        Get the dataset information.

        Args:
            dataset_id (str): The dataset id.
            revision (Optional[str]): The revision of the dataset.
            endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.

        Returns:
            dict: The dataset information.
        """
        cookies = ModelScopeConfig.get_cookies()
        if not endpoint:
            endpoint = self.endpoint

        if revision:
            path = f'{endpoint}/api/v1/datasets/{dataset_id}?Revision={revision}'
        else:
            path = f'{endpoint}/api/v1/datasets/{dataset_id}'

        r = self.session.get(
            path, cookies=cookies, headers=self.builder_headers(self.headers))
        raise_for_http_status(r)
        resp = r.json()
        datahub_raise_on_error(path, resp, r)
        return resp[API_RESPONSE_FIELD_DATA]

    def get_dataset_meta_file_list(self, dataset_name: str, namespace: str,
                                   dataset_id: str, revision: str, endpoint: Optional[str] = None):
        """ Get the meta file-list of the dataset. """
        if not endpoint:
            endpoint = self.endpoint
        datahub_url = f'{endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
        cookies = ModelScopeConfig.get_cookies()
        r = self.session.get(datahub_url,
                             cookies=cookies,
                             headers=self.builder_headers(self.headers))
        resp = r.json()
        datahub_raise_on_error(datahub_url, resp, r)
        file_list = resp['Data']
        if file_list is None:
            raise NotExistError(
                f'The modelscope dataset [dataset_name = {dataset_name}, namespace = {namespace}, '
                f'version = {revision}] does not exist')

        file_list = file_list['Files']
        return file_list

    @staticmethod
    def dump_datatype_file(dataset_type: int, meta_cache_dir: str):
        """
        Dump the data_type as a local file, in order to get the dataset
        formation without calling the datahub.
        For more details, please refer to the class
        `modelscope.utils.constant.DatasetFormations`.
        """
        dataset_type_file_path = os.path.join(meta_cache_dir,
                                              f'{str(dataset_type)}{DatasetFormations.formation_mark_ext.value}')
        with open(dataset_type_file_path, 'w') as fp:
            fp.write('*** Automatically-generated file, do not modify ***')

    def get_dataset_meta_files_local_paths(self, dataset_name: str,
                                           namespace: str,
                                           revision: str,
                                           meta_cache_dir: str, dataset_type: int, file_list: list,
                                           endpoint: Optional[str] = None):
        local_paths = defaultdict(list)
        dataset_formation = DatasetFormations(dataset_type)
        dataset_meta_format = DatasetMetaFormats[dataset_formation]
        cookies = ModelScopeConfig.get_cookies()

        # Dump the data_type as a local file
        HubApi.dump_datatype_file(dataset_type=dataset_type, meta_cache_dir=meta_cache_dir)
        if not endpoint:
            endpoint = self.endpoint
        for file_info in file_list:
            file_path = file_info['Path']
            extension = os.path.splitext(file_path)[-1]
            if extension in dataset_meta_format:
                datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                              f'Revision={revision}&FilePath={file_path}'
                r = self.session.get(datahub_url, cookies=cookies)
                raise_for_http_status(r)
                local_path = os.path.join(meta_cache_dir, file_path)
                if os.path.exists(local_path):
                    logger.warning(
                        f"Reusing dataset {dataset_name}'s python file ({local_path})"
                    )
                    local_paths[extension].append(local_path)
                    continue
                with open(local_path, 'wb') as f:
                    f.write(r.content)
                local_paths[extension].append(local_path)

        return local_paths, dataset_formation

    @staticmethod
    def fetch_meta_files_from_url(url, out_path, chunk_size=1024, mode=DownloadMode.REUSE_DATASET_IF_EXISTS):
        """
        Fetch the meta-data files from the url, e.g. csv/jsonl files.
        """
        import hashlib
        from tqdm.auto import tqdm
        import pandas as pd

        out_path = os.path.join(out_path, hashlib.md5(url.encode(encoding='UTF-8')).hexdigest())
        if mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(out_path):
            os.remove(out_path)
        if os.path.exists(out_path):
            logger.info(f'Reusing cached meta-data file: {out_path}')
            return out_path
        cookies = ModelScopeConfig.get_cookies()

        # Make the request and stream the response content line by line
        logger.info('Loading meta-data file ...')
        response = requests.get(url, cookies=cookies, stream=True)
        total_size = int(response.headers.get('content-length', 0))
        progress = tqdm(total=total_size, dynamic_ncols=True)

        def get_chunk(resp):
            chunk_data = []
            for data in resp.iter_lines():
                data = data.decode('utf-8')
                chunk_data.append(data)
                if len(chunk_data) >= chunk_size:
                    yield chunk_data
                    chunk_data = []
            yield chunk_data

        iter_num = 0
        with open(out_path, 'a') as f:
            for chunk in get_chunk(response):
                progress.update(len(chunk))
                if url.endswith('jsonl'):
                    chunk = [json.loads(line) for line in chunk if line.strip()]
                    if len(chunk) == 0:
                        continue
                    if iter_num == 0:
                        with_header = True
                    else:
                        with_header = False
                    chunk_df = pd.DataFrame(chunk)
                    chunk_df.to_csv(f, index=False, header=with_header, escapechar='\\')
                    iter_num += 1
                else:
                    # csv or others
                    for line in chunk:
                        f.write(line + '\n')
        progress.close()

        return out_path

    def get_dataset_file_url(
            self,
            file_name: str,
            dataset_name: str,
            namespace: str,
            revision: Optional[str] = DEFAULT_DATASET_REVISION,
            view: Optional[bool] = False,
            extension_filter: Optional[bool] = True,
            endpoint: Optional[str] = None):

        if not file_name or not dataset_name or not namespace:
            raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')

        # Note: make sure the FilePath is the last parameter in the url
        params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name, 'View': view}
        params: str = urlencode(params)
        if not endpoint:
            endpoint = self.endpoint
        file_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'

        return file_url

        # if extension_filter:
        #     if os.path.splitext(file_name)[-1] in META_FILES_FORMAT:
        #         file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?'\
        #                    f'Revision={revision}&FilePath={file_name}'
        #     else:
        #         file_url = file_name
        #     return file_url
        # else:
        #     return file_url

    def get_dataset_file_url_origin(
            self,
            file_name: str,
            dataset_name: str,
            namespace: str,
            revision: Optional[str] = DEFAULT_DATASET_REVISION,
            endpoint: Optional[str] = None):
        if not endpoint:
            endpoint = self.endpoint
        if file_name and os.path.splitext(file_name)[-1] in META_FILES_FORMAT:
            file_name = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
                        f'Revision={revision}&FilePath={file_name}'
        return file_name

    def get_dataset_access_config(
            self,
            dataset_name: str,
            namespace: str,
            revision: Optional[str] = DEFAULT_DATASET_REVISION,
            endpoint: Optional[str] = None):
        if not endpoint:
            endpoint = self.endpoint
        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                      f'ststoken?Revision={revision}'
        return self.datahub_remote_call(datahub_url)

    def get_dataset_access_config_session(
            self,
            dataset_name: str,
            namespace: str,
            check_cookie: bool,
            revision: Optional[str] = DEFAULT_DATASET_REVISION,
            endpoint: Optional[str] = None):

        if not endpoint:
            endpoint = self.endpoint
        datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/' \
                      f'ststoken?Revision={revision}'
        if check_cookie:
            cookies = self._check_cookie(use_cookies=True)
        else:
            cookies = ModelScopeConfig.get_cookies()

        r = self.session.get(
            url=datahub_url,
            cookies=cookies,
            headers=self.builder_headers(self.headers))
        resp = r.json()
        raise_on_error(resp)
        return resp['Data']

    def get_virgo_meta(self, dataset_id: str, version: int = 1) -> dict:
        """
        Get virgo dataset meta info.
        """
        virgo_endpoint = os.environ.get(VirgoDatasetConfig.env_virgo_endpoint, '')
        if not virgo_endpoint:
            raise RuntimeError(f'Virgo endpoint is not set in env: {VirgoDatasetConfig.env_virgo_endpoint}')

        virgo_dataset_url = f'{virgo_endpoint}/data/set/download'
        cookies = requests.utils.dict_from_cookiejar(ModelScopeConfig.get_cookies())

        dataset_info = dict(
            dataSetId=dataset_id,
            dataSetVersion=version
        )
        data = dict(
            data=dataset_info,
        )
        r = self.session.post(url=virgo_dataset_url,
                              json=data,
                              cookies=cookies,
                              headers=self.builder_headers(self.headers),
                              timeout=900)
        resp = r.json()
        if resp['code'] != 0:
            raise RuntimeError(f'Failed to get virgo dataset: {resp}')

        return resp['data']

def get_dataset_access_config_for_unzipped(self,
|
|
dataset_name: str,
|
|
namespace: str,
|
|
revision: str,
|
|
zip_file_name: str,
|
|
endpoint: Optional[str] = None):
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
r = self.session.get(url=datahub_url, cookies=cookies,
|
|
headers=self.builder_headers(self.headers))
|
|
resp = r.json()
|
|
# get visibility of the dataset
|
|
raise_on_error(resp)
|
|
data = resp['Data']
|
|
visibility = VisibilityMap.get(data['Visibility'])
|
|
|
|
datahub_sts_url = f'{datahub_url}/ststoken?Revision={revision}'
|
|
r_sts = self.session.get(url=datahub_sts_url, cookies=cookies,
|
|
headers=self.builder_headers(self.headers))
|
|
resp_sts = r_sts.json()
|
|
raise_on_error(resp_sts)
|
|
data_sts = resp_sts['Data']
|
|
file_dir = visibility + '-unzipped' + '/' + namespace + '_' + dataset_name + '_' + zip_file_name
|
|
data_sts['Dir'] = file_dir
|
|
return data_sts
|
|
|
|
def list_oss_dataset_objects(self, dataset_name, namespace, max_limit,
|
|
is_recursive, is_filter_dir, revision, endpoint: Optional[str] = None):
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/tree/?' \
|
|
f'MaxLimit={max_limit}&Revision={revision}&Recursive={is_recursive}&FilterDir={is_filter_dir}'
|
|
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
resp = self.session.get(url=url, cookies=cookies, timeout=1800)
|
|
resp = resp.json()
|
|
raise_on_error(resp)
|
|
resp = resp['Data']
|
|
return resp
|
|
|
|
def delete_oss_dataset_object(self, object_name: str, dataset_name: str,
|
|
namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
|
|
if not object_name or not dataset_name or not namespace or not revision:
|
|
raise ValueError('Args cannot be empty!')
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss?Path={object_name}&Revision={revision}'
|
|
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
resp = self.session.delete(url=url, cookies=cookies)
|
|
resp = resp.json()
|
|
raise_on_error(resp)
|
|
resp = resp['Message']
|
|
return resp
|
|
|
|
def delete_oss_dataset_dir(self, object_name: str, dataset_name: str,
|
|
namespace: str, revision: str, endpoint: Optional[str] = None) -> str:
|
|
if not object_name or not dataset_name or not namespace or not revision:
|
|
raise ValueError('Args cannot be empty!')
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/oss/prefix?Prefix={object_name}/' \
|
|
f'&Revision={revision}'
|
|
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
resp = self.session.delete(url=url, cookies=cookies)
|
|
resp = resp.json()
|
|
raise_on_error(resp)
|
|
resp = resp['Message']
|
|
return resp
|
|
|
|
def datahub_remote_call(self, url):
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
r = self.session.get(
|
|
url,
|
|
cookies=cookies,
|
|
headers={'user-agent': ModelScopeConfig.get_user_agent()})
|
|
resp = r.json()
|
|
datahub_raise_on_error(url, resp, r)
|
|
return resp['Data']
|
|
|
|
def dataset_download_statistics(self, dataset_name: str, namespace: str,
|
|
use_streaming: bool = False, endpoint: Optional[str] = None) -> None:
|
|
is_ci_test = os.getenv('CI_TEST') == 'True'
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
if dataset_name and namespace and not is_ci_test and not use_streaming:
|
|
try:
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
|
|
# Download count
|
|
download_count_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/increase'
|
|
download_count_resp = self.session.post(download_count_url, cookies=cookies,
|
|
headers=self.builder_headers(self.headers))
|
|
raise_for_http_status(download_count_resp)
|
|
|
|
# Download uv
|
|
channel = DownloadChannel.LOCAL.value
|
|
user_name = ''
|
|
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
|
|
channel = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
|
|
if MODELSCOPE_CLOUD_USERNAME in os.environ:
|
|
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
|
|
download_uv_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/download/uv/' \
|
|
f'{channel}?user={user_name}'
|
|
download_uv_resp = self.session.post(download_uv_url, cookies=cookies,
|
|
headers=self.builder_headers(self.headers))
|
|
download_uv_resp = download_uv_resp.json()
|
|
raise_on_error(download_uv_resp)
|
|
|
|
except Exception as e:
|
|
logger.error(e)
|
|
|
|
def builder_headers(self, headers):
|
|
return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
|
|
**headers}
|
|
|
|
def get_file_base_path(self, repo_id: str, endpoint: Optional[str] = None) -> str:
|
|
_namespace, _dataset_name = repo_id.split('/')
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
return f'{endpoint}/api/v1/datasets/{_namespace}/{_dataset_name}/repo?'
|
|
# return f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Revision={revision}&FilePath='
|
|
|
|
def create_repo(
|
|
self,
|
|
repo_id: str,
|
|
*,
|
|
token: Union[str, bool, None] = None,
|
|
visibility: Optional[str] = Visibility.PUBLIC,
|
|
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
|
chinese_name: Optional[str] = None,
|
|
license: Optional[str] = Licenses.APACHE_V2,
|
|
endpoint: Optional[str] = None,
|
|
exist_ok: Optional[bool] = False,
|
|
create_default_config: Optional[bool] = True,
|
|
aigc_model: Optional[AigcModel] = None,
|
|
**kwargs,
|
|
) -> str:
|
|
"""
|
|
Create a repository on the ModelScope Hub.
|
|
|
|
Args:
|
|
repo_id (str): The repo id in the format of `owner_name/repo_name`.
|
|
token (Union[str, bool, None]): The access token.
|
|
            visibility (Optional[str]): The visibility of the repo,
                one of `public`, `private` or `internal`. Defaults to `public`.
            repo_type (Optional[str]): The repo type, defaults to `model`.
            chinese_name (Optional[str]): The Chinese name of the repo.
            license (Optional[str]): The license of the repo, defaults to `apache-2.0`.
            endpoint (Optional[str]): The endpoint to use,
                e.g. `https://www.modelscope.cn` or `https://www.modelscope.ai`.
            exist_ok (Optional[bool]): If the repo already exists, return its url instead of raising an error.
            create_default_config (Optional[bool]): If True, create a default configuration file in the model repo.
            aigc_model (Optional[AigcModel]): Optional AIGC model information attached when creating a model repo.
            **kwargs: The additional arguments.
|
|
|
|
Returns:
|
|
str: The repo url.
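
        Examples:
            >>> # A minimal sketch; assumes you are already logged in (or pass a valid token)
            >>> # and that 'your-namespace/your-repo-name' is a placeholder repo id.
            >>> from modelscope.hub.api import HubApi
            >>> api = HubApi()
            >>> repo_url = api.create_repo(
            ...     repo_id='your-namespace/your-repo-name',
            ...     visibility='public',
            ...     repo_type='model',
            ...     exist_ok=True,
            ... )
            >>> print(repo_url)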
|
|
"""
|
|
|
|
if not repo_id:
|
|
raise ValueError('Repo id cannot be empty!')
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
|
|
self.login(access_token=token, endpoint=endpoint)
|
|
|
|
repo_exists: bool = self.repo_exists(repo_id, repo_type=repo_type, endpoint=endpoint, token=token)
|
|
if repo_exists:
|
|
if exist_ok:
|
|
repo_url: str = f'{endpoint}/{repo_type}s/{repo_id}'
|
|
logger.warning(f'Repo {repo_id} already exists, got repo url: {repo_url}')
|
|
return repo_url
|
|
else:
|
|
raise ValueError(f'Repo {repo_id} already exists!')
|
|
|
|
repo_id_list = repo_id.split('/')
|
|
if len(repo_id_list) != 2:
|
|
raise ValueError('Invalid repo id, should be in the format of `owner_name/repo_name`')
|
|
namespace, repo_name = repo_id_list
|
|
|
|
if repo_type == REPO_TYPE_MODEL:
|
|
visibilities = {k: v for k, v in ModelVisibility.__dict__.items() if not k.startswith('__')}
|
|
            visibility_str = visibility
            visibility: int = visibilities.get(visibility_str.upper())
            if visibility is None:
                raise ValueError(f'Invalid visibility: {visibility_str}, '
                                 f'supported visibilities: `public`, `private`, `internal`')
|
|
repo_url: str = self.create_model(
|
|
model_id=repo_id,
|
|
visibility=visibility,
|
|
license=license,
|
|
chinese_name=chinese_name,
|
|
aigc_model=aigc_model
|
|
)
|
|
if create_default_config:
|
|
with tempfile.TemporaryDirectory() as temp_cache_dir:
|
|
from modelscope.hub.repository import Repository
|
|
repo = Repository(temp_cache_dir, repo_id)
|
|
default_config = {
|
|
'framework': 'pytorch',
|
|
'task': 'text-generation',
|
|
'allow_remote': True
|
|
}
|
|
config_json = kwargs.get('config_json')
|
|
if not config_json:
|
|
config_json = {}
|
|
config = {**default_config, **config_json}
|
|
add_content_to_file(
|
|
repo,
|
|
'configuration.json', [json.dumps(config)],
|
|
ignore_push_error=True)
|
|
print(f'New model created successfully at {repo_url}.', flush=True)
|
|
|
|
elif repo_type == REPO_TYPE_DATASET:
|
|
visibilities = {k: v for k, v in DatasetVisibility.__dict__.items() if not k.startswith('__')}
|
|
            visibility_str = visibility
            visibility: int = visibilities.get(visibility_str.upper())
            if visibility is None:
                raise ValueError(f'Invalid visibility: {visibility_str}, '
                                 f'supported visibilities: `public`, `private`, `internal`')
|
|
repo_url: str = self.create_dataset(
|
|
dataset_name=repo_name,
|
|
namespace=namespace,
|
|
chinese_name=chinese_name,
|
|
license=license,
|
|
visibility=visibility,
|
|
)
|
|
print(f'New dataset created successfully at {repo_url}.', flush=True)
|
|
|
|
else:
|
|
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
return repo_url
|
|
|
|
def create_commit(
|
|
self,
|
|
repo_id: str,
|
|
operations: Iterable[CommitOperation],
|
|
*,
|
|
commit_message: str,
|
|
commit_description: Optional[str] = None,
|
|
token: str = None,
|
|
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
|
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
|
endpoint: Optional[str] = None,
|
|
max_retries: int = 3,
|
|
timeout: int = 180,
|
|
) -> CommitInfo:
|
|
"""
|
|
Create a commit on the ModelScope Hub with retry mechanism.
|
|
|
|
Args:
|
|
repo_id (str): The repo id in the format of `owner_name/repo_name`.
|
|
operations (Iterable[CommitOperation]): The commit operations.
|
|
commit_message (str): The commit message.
|
|
commit_description (Optional[str]): The commit description.
|
|
token (str): The access token. If None, will use the cookies from the local cache.
|
|
See `https://modelscope.cn/my/myaccesstoken` to get your token.
|
|
repo_type (Optional[str]): The repo type, should be `model` or `dataset`. Defaults to `model`.
|
|
revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
|
|
            endpoint (Optional[str]): The endpoint to use,
                e.g. `https://www.modelscope.cn` or `https://www.modelscope.ai`.
|
|
max_retries (int): Number of max retry attempts (default: 3).
|
|
timeout (int): Timeout for each request in seconds (default: 180).
|
|
|
|
Returns:
|
|
CommitInfo: The commit info.
|
|
|
|
Raises:
|
|
requests.exceptions.RequestException: If all retry attempts fail.
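
        Examples:
            >>> # A hedged sketch: in practice `upload_file` / `upload_folder` are the
            >>> # recommended entry points; they hash the files, upload the blobs and
            >>> # build the `CommitOperationAdd` operations before calling `create_commit`.
            >>> from modelscope.hub.api import HubApi
            >>> api = HubApi()
            >>> commit_info = api.upload_file(
            ...     path_or_fileobj='/path/to/file.txt',
            ...     path_in_repo='file.txt',
            ...     repo_id='your-namespace/your-repo-name',
            ... )
            >>> print(commit_info.commit_url)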
|
|
"""
|
|
if not repo_id:
|
|
raise ValueError('Repo id cannot be empty!')
|
|
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
|
|
if repo_type not in REPO_TYPE_SUPPORT:
|
|
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
url = f'{endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'
|
|
commit_message = commit_message or f'Commit to {repo_id}'
|
|
commit_description = commit_description or ''
|
|
|
|
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
|
|
|
# Construct payload
|
|
payload = self._prepare_commit_payload(
|
|
operations=operations,
|
|
commit_message=commit_message,
|
|
)
|
|
|
|
# POST with retry mechanism
|
|
last_exception = None
|
|
for attempt in range(max_retries):
|
|
try:
|
|
if attempt > 0:
|
|
logger.info(f'Attempt {attempt + 1} to create commit for {repo_id}...')
|
|
response = requests.post(
|
|
url,
|
|
headers=self.builder_headers(self.headers),
|
|
data=json.dumps(payload),
|
|
cookies=cookies,
|
|
timeout=timeout,
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
try:
|
|
error_detail = response.json()
|
|
except json.JSONDecodeError:
|
|
error_detail = response.text
|
|
|
|
error_msg = (
|
|
f'HTTP {response.status_code} error from {url}: '
|
|
f'{error_detail}'
|
|
)
|
|
|
|
# If server error (5xx), we can retry, otherwise (4xx) raise immediately
|
|
if 500 <= response.status_code < 600:
|
|
logger.warning(
|
|
f'Server error on attempt {attempt + 1}: {error_msg}'
|
|
)
|
|
else:
|
|
raise ValueError(f'Client request failed: {error_msg}')
|
|
else:
|
|
resp = response.json()
|
|
|
|
oid = resp.get('Data', {}).get('oid', '')
|
|
logger.info(f'Commit succeeded: {url}')
|
|
return CommitInfo(
|
|
commit_url=url,
|
|
commit_message=commit_message,
|
|
commit_description=commit_description,
|
|
oid=oid,
|
|
)
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
last_exception = e
|
|
logger.warning(f'Request failed on attempt {attempt + 1}: {str(e)}')
|
|
|
|
except Exception as e:
|
|
last_exception = e
|
|
logger.error(f'Unexpected error on attempt {attempt + 1}: {str(e)}')
|
|
if attempt == max_retries - 1:
|
|
raise
|
|
|
|
if attempt < max_retries - 1:
|
|
time.sleep(1)
|
|
|
|
# All retries exhausted
|
|
raise requests.exceptions.RequestException(
|
|
f'Failed to create commit after {max_retries} attempts. Last error: {last_exception}'
|
|
)
|
|
|
|
def upload_file(
|
|
self,
|
|
*,
|
|
path_or_fileobj: Union[str, Path, bytes, BinaryIO],
|
|
path_in_repo: str,
|
|
repo_id: str,
|
|
token: Union[str, None] = None,
|
|
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
|
commit_message: Optional[str] = None,
|
|
commit_description: Optional[str] = None,
|
|
buffer_size_mb: Optional[int] = 1,
|
|
tqdm_desc: Optional[str] = '[Uploading]',
|
|
disable_tqdm: Optional[bool] = False,
|
|
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION
|
|
) -> CommitInfo:
|
|
"""
|
|
Upload a file to the ModelScope Hub.
|
|
|
|
Args:
|
|
path_or_fileobj (Union[str, Path, bytes, BinaryIO]):
|
|
The local file path or file-like object (BinaryIO) or bytes to upload.
|
|
path_in_repo (str): The path in the repo to upload to.
|
|
repo_id (str): The repo id in the format of `owner_name/repo_name`.
|
|
token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
|
|
See `https://modelscope.cn/my/myaccesstoken` to get your token.
|
|
repo_type (Optional[str]): The repo type, default to `model`.
|
|
commit_message (Optional[str]): The commit message.
|
|
commit_description (Optional[str]): The commit description.
|
|
buffer_size_mb (Optional[int]): The buffer size in MB for reading the file. Default to 1MB.
|
|
tqdm_desc (Optional[str]): The description for the tqdm progress bar. Default to '[Uploading]'.
|
|
disable_tqdm (Optional[bool]): Whether to disable the tqdm progress bar. Default to False.
|
|
revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
|
|
|
|
Returns:
|
|
CommitInfo: The commit info.
|
|
|
|
Examples:
|
|
>>> from modelscope.hub.api import HubApi
|
|
>>> api = HubApi()
|
|
>>> commit_info = api.upload_file(
|
|
... path_or_fileobj='/path/to/your/file.txt',
|
|
... path_in_repo='optional/path/in/repo/file.txt',
|
|
... repo_id='your-namespace/your-repo-name',
|
|
... commit_message='Upload file.txt to ModelScope hub'
|
|
... )
|
|
>>> print(commit_info)
|
|
"""
|
|
|
|
if repo_type not in REPO_TYPE_SUPPORT:
|
|
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
if not path_or_fileobj:
|
|
raise ValueError('Path or file object cannot be empty!')
|
|
|
|
# Check authentication first
|
|
self.get_cookies(access_token=token, cookies_required=True)
|
|
|
|
if isinstance(path_or_fileobj, (str, Path)):
|
|
path_or_fileobj = os.path.abspath(os.path.expanduser(path_or_fileobj))
|
|
path_in_repo = path_in_repo or os.path.basename(path_or_fileobj)
|
|
else:
|
|
# If path_or_fileobj is bytes or BinaryIO, then path_in_repo must be provided
|
|
if not path_in_repo:
|
|
raise ValueError('Arg `path_in_repo` cannot be empty!')
|
|
|
|
# Read file content if path_or_fileobj is a file-like object (BinaryIO)
|
|
# TODO: to be refined
|
|
if isinstance(path_or_fileobj, io.BufferedIOBase):
|
|
path_or_fileobj = path_or_fileobj.read()
|
|
|
|
self.upload_checker.check_file(path_or_fileobj)
|
|
self.upload_checker.check_normal_files(
|
|
file_path_list=[path_or_fileobj],
|
|
repo_type=repo_type,
|
|
)
|
|
|
|
commit_message = (
|
|
commit_message if commit_message is not None else f'Upload {path_in_repo} to ModelScope hub'
|
|
)
|
|
|
|
if buffer_size_mb <= 0:
|
|
raise ValueError('Buffer size: `buffer_size_mb` must be greater than 0')
|
|
|
|
hash_info_d: dict = get_file_hash(
|
|
file_path_or_obj=path_or_fileobj,
|
|
buffer_size_mb=buffer_size_mb,
|
|
)
|
|
file_size: int = hash_info_d['file_size']
|
|
file_hash: str = hash_info_d['file_hash']
|
|
|
|
self.create_repo(repo_id=repo_id,
|
|
token=token,
|
|
repo_type=repo_type,
|
|
endpoint=self.endpoint,
|
|
exist_ok=True,
|
|
create_default_config=False)
|
|
|
|
upload_res: dict = self._upload_blob(
|
|
repo_id=repo_id,
|
|
repo_type=repo_type,
|
|
sha256=file_hash,
|
|
size=file_size,
|
|
data=path_or_fileobj,
|
|
disable_tqdm=disable_tqdm,
|
|
tqdm_desc=tqdm_desc,
|
|
)
|
|
|
|
# Construct commit info and create commit
|
|
add_operation: CommitOperationAdd = CommitOperationAdd(
|
|
path_in_repo=path_in_repo,
|
|
path_or_fileobj=path_or_fileobj,
|
|
file_hash_info=hash_info_d,
|
|
)
|
|
add_operation._upload_mode = 'lfs' if self.upload_checker.is_lfs(path_or_fileobj, repo_type) else 'normal'
|
|
add_operation._is_uploaded = upload_res['is_uploaded']
|
|
operations = [add_operation]
|
|
|
|
print(f'Committing file to {repo_id} ...', flush=True)
|
|
commit_info: CommitInfo = self.create_commit(
|
|
repo_id=repo_id,
|
|
operations=operations,
|
|
commit_message=commit_message,
|
|
commit_description=commit_description,
|
|
token=token,
|
|
repo_type=repo_type,
|
|
revision=revision,
|
|
)
|
|
|
|
return commit_info
|
|
|
|
def upload_folder(
|
|
self,
|
|
*,
|
|
repo_id: str,
|
|
folder_path: Union[str, Path, List[str], List[Path]],
|
|
path_in_repo: Optional[str] = '',
|
|
commit_message: Optional[str] = None,
|
|
commit_description: Optional[str] = None,
|
|
token: Union[str, None] = None,
|
|
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
|
max_workers: int = DEFAULT_MAX_WORKERS,
|
|
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
|
) -> Union[CommitInfo, List[CommitInfo]]:
|
|
"""
|
|
Upload a folder to the ModelScope Hub.
|
|
|
|
Args:
|
|
repo_id (str): The repo id in the format of `owner_name/repo_name`.
|
|
folder_path (Union[str, Path, List[str], List[Path]]): The folder path or list of file paths to upload.
|
|
path_in_repo (Optional[str]): The path in the repo to upload to.
|
|
commit_message (Optional[str]): The commit message.
|
|
commit_description (Optional[str]): The commit description.
|
|
token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
|
|
See `https://modelscope.cn/my/myaccesstoken` to get your token.
|
|
repo_type (Optional[str]): The repo type, default to `model`.
|
|
allow_patterns (Optional[Union[List[str], str]]): The patterns to allow.
|
|
ignore_patterns (Optional[Union[List[str], str]]): The patterns to ignore.
|
|
max_workers (int): The maximum number of workers to use for uploading files concurrently.
|
|
Defaults to `DEFAULT_MAX_WORKERS`.
|
|
revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
|
|
|
|
Returns:
|
|
Union[CommitInfo, List[CommitInfo]]:
|
|
The commit info or list of commit infos if multiple batches are committed.
|
|
|
|
Examples:
|
|
>>> from modelscope.hub.api import HubApi
|
|
>>> api = HubApi()
|
|
>>> commit_info = api.upload_folder(
|
|
... repo_id='your-namespace/your-repo-name',
|
|
... folder_path='/path/to/your/folder',
|
|
... path_in_repo='optional/path/in/repo',
|
|
... commit_message='Upload my folder',
|
|
... token='your-access-token'
|
|
... )
|
|
>>> print(commit_info.commit_url)
|
|
"""
|
|
if not repo_id:
|
|
raise ValueError('The arg `repo_id` cannot be empty!')
|
|
|
|
if folder_path is None:
|
|
raise ValueError('The arg `folder_path` cannot be None!')
|
|
|
|
if repo_type not in REPO_TYPE_SUPPORT:
|
|
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
# Check authentication first
|
|
self.get_cookies(access_token=token, cookies_required=True)
|
|
|
|
allow_patterns = allow_patterns if allow_patterns else None
|
|
ignore_patterns = ignore_patterns if ignore_patterns else None
|
|
|
|
# Ignore .git folder
|
|
if ignore_patterns is None:
|
|
ignore_patterns = []
|
|
elif isinstance(ignore_patterns, str):
|
|
ignore_patterns = [ignore_patterns]
|
|
ignore_patterns += DEFAULT_IGNORE_PATTERNS
|
|
|
|
commit_message = (
|
|
commit_message if commit_message is not None else f'Upload to {repo_id} on ModelScope hub'
|
|
)
|
|
commit_description = commit_description or 'Uploading files'
|
|
|
|
# Get the list of files to upload, e.g. [('data/abc.png', '/path/to/abc.png'), ...]
|
|
logger.info('Preparing files to upload ...')
|
|
prepared_repo_objects = self._prepare_upload_folder(
|
|
folder_path_or_files=folder_path,
|
|
path_in_repo=path_in_repo,
|
|
allow_patterns=allow_patterns,
|
|
ignore_patterns=ignore_patterns,
|
|
)
|
|
if len(prepared_repo_objects) == 0:
|
|
raise ValueError(f'No files to upload in the folder: {folder_path} !')
|
|
|
|
logger.info(f'Checking {len(prepared_repo_objects)} files to upload ...')
|
|
self.upload_checker.check_normal_files(
|
|
file_path_list=[item for _, item in prepared_repo_objects],
|
|
repo_type=repo_type,
|
|
)
|
|
|
|
self.create_repo(repo_id=repo_id,
|
|
token=token,
|
|
repo_type=repo_type,
|
|
endpoint=self.endpoint,
|
|
exist_ok=True,
|
|
create_default_config=False)
|
|
|
|
@thread_executor(max_workers=max_workers, disable_tqdm=False)
|
|
def _upload_items(item_pair, **kwargs):
|
|
file_path_in_repo, file_path = item_pair
|
|
|
|
hash_info_d: dict = get_file_hash(
|
|
file_path_or_obj=file_path,
|
|
)
|
|
file_size: int = hash_info_d['file_size']
|
|
file_hash: str = hash_info_d['file_hash']
|
|
|
|
upload_res: dict = self._upload_blob(
|
|
repo_id=repo_id,
|
|
repo_type=repo_type,
|
|
sha256=file_hash,
|
|
size=file_size,
|
|
data=file_path,
|
|
disable_tqdm=file_size <= UPLOAD_BLOB_TQDM_DISABLE_THRESHOLD,
|
|
tqdm_desc='[Uploading ' + file_path_in_repo + ']',
|
|
)
|
|
|
|
return {
|
|
'file_path_in_repo': file_path_in_repo,
|
|
'file_path': file_path,
|
|
'is_uploaded': upload_res['is_uploaded'],
|
|
'file_hash_info': hash_info_d,
|
|
}
|
|
|
|
uploaded_items_list = _upload_items(
|
|
prepared_repo_objects,
|
|
repo_id=repo_id,
|
|
token=token,
|
|
repo_type=repo_type,
|
|
commit_message=commit_message,
|
|
commit_description=commit_description,
|
|
buffer_size_mb=1,
|
|
disable_tqdm=False,
|
|
)
|
|
|
|
# Construct commit info and create commit
|
|
operations = []
|
|
|
|
for item_d in uploaded_items_list:
|
|
prepared_path_in_repo: str = item_d['file_path_in_repo']
|
|
prepared_file_path: str = item_d['file_path']
|
|
is_uploaded: bool = item_d['is_uploaded']
|
|
file_hash_info: dict = item_d['file_hash_info']
|
|
opt = CommitOperationAdd(
|
|
path_in_repo=prepared_path_in_repo,
|
|
path_or_fileobj=prepared_file_path,
|
|
file_hash_info=file_hash_info,
|
|
)
|
|
|
|
# check normal or lfs
|
|
opt._upload_mode = 'lfs' if self.upload_checker.is_lfs(prepared_file_path, repo_type) else 'normal'
|
|
opt._is_uploaded = is_uploaded
|
|
operations.append(opt)
|
|
|
|
if len(operations) == 0:
|
|
raise ValueError(f'No files to upload in the folder: {folder_path} !')
|
|
|
|
# Commit the operations in batches
|
|
commit_batch_size: int = UPLOAD_COMMIT_BATCH_SIZE if UPLOAD_COMMIT_BATCH_SIZE > 0 else len(operations)
|
|
num_batches = (len(operations) - 1) // commit_batch_size + 1
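        # Ceiling division: e.g. 250 operations with a batch size of 100 yield 3 batches.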
|
|
print(f'Committing {len(operations)} files in {num_batches} batch(es) of size {commit_batch_size}.',
|
|
flush=True)
|
|
commit_infos: List[CommitInfo] = []
|
|
for i in tqdm(range(num_batches), desc='[Committing batches] ', total=num_batches):
|
|
batch_operations = operations[i * commit_batch_size: (i + 1) * commit_batch_size]
|
|
batch_commit_message = f'{commit_message} (batch {i + 1}/{num_batches})'
|
|
|
|
commit_info: CommitInfo = self.create_commit(
|
|
repo_id=repo_id,
|
|
operations=batch_operations,
|
|
commit_message=batch_commit_message,
|
|
commit_description=commit_description,
|
|
token=token,
|
|
repo_type=repo_type,
|
|
revision=revision,
|
|
)
|
|
commit_infos.append(commit_info)
|
|
|
|
return commit_infos[0] if len(commit_infos) == 1 else commit_infos
|
|
|
|
def _upload_blob(
|
|
self,
|
|
*,
|
|
repo_id: str,
|
|
repo_type: str,
|
|
sha256: str,
|
|
size: int,
|
|
data: Union[str, Path, bytes, BinaryIO],
|
|
disable_tqdm: Optional[bool] = False,
|
|
tqdm_desc: Optional[str] = '[Uploading]',
|
|
buffer_size_mb: Optional[int] = 1,
|
|
) -> dict:
|
|
|
|
res_d: dict = dict(
|
|
url=None,
|
|
is_uploaded=False,
|
|
status_code=None,
|
|
status_msg=None,
|
|
)
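
        # Ask the hub (LFS objects/batch endpoint) which blobs still need uploading;
        # if this blob already exists server-side, skip the PUT below and only
        # reference it in the subsequent commit.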
|
|
|
|
objects = [{'oid': sha256, 'size': size}]
|
|
upload_objects = self._validate_blob(
|
|
repo_id=repo_id,
|
|
repo_type=repo_type,
|
|
objects=objects,
|
|
)
|
|
|
|
# upload_object: {'url': 'xxx', 'oid': 'xxx'}
|
|
upload_object = upload_objects[0] if len(upload_objects) == 1 else None
|
|
|
|
if upload_object is None:
|
|
            logger.debug(f'Blob {sha256[:8]} has already been uploaded, reusing it.')
|
|
res_d['is_uploaded'] = True
|
|
return res_d
|
|
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
cookies = dict(cookies) if cookies else None
|
|
if cookies is None:
|
|
raise ValueError('Token does not exist, please login first.')
|
|
|
|
self.headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
|
|
headers = self.builder_headers(self.headers)
|
|
|
|
def read_in_chunks(file_object, pbar, chunk_size=buffer_size_mb * 1024 * 1024):
|
|
"""Lazy function (generator) to read a file piece by piece."""
|
|
while True:
|
|
ck = file_object.read(chunk_size)
|
|
if not ck:
|
|
break
|
|
pbar.update(len(ck))
|
|
yield ck
|
|
|
|
with tqdm(
|
|
total=size,
|
|
unit='B',
|
|
unit_scale=True,
|
|
desc=tqdm_desc,
|
|
disable=disable_tqdm
|
|
) as pbar:
|
|
|
|
if isinstance(data, (str, Path)):
|
|
with open(data, 'rb') as f:
|
|
response = requests.put(
|
|
upload_object['url'],
|
|
headers=headers,
|
|
data=read_in_chunks(f, pbar)
|
|
)
|
|
|
|
elif isinstance(data, bytes):
|
|
response = requests.put(
|
|
upload_object['url'],
|
|
headers=headers,
|
|
data=read_in_chunks(io.BytesIO(data), pbar)
|
|
)
|
|
|
|
elif isinstance(data, io.BufferedIOBase):
|
|
response = requests.put(
|
|
upload_object['url'],
|
|
headers=headers,
|
|
data=read_in_chunks(data, pbar)
|
|
)
|
|
|
|
else:
|
|
raise ValueError('Invalid data type to upload')
|
|
|
|
raise_for_http_status(rsp=response)
|
|
resp = response.json()
|
|
raise_on_error(rsp=resp)
|
|
|
|
res_d['url'] = upload_object['url']
|
|
res_d['status_code'] = resp['Code']
|
|
res_d['status_msg'] = resp['Message']
|
|
|
|
return res_d
|
|
|
|
def _validate_blob(
|
|
self,
|
|
*,
|
|
repo_id: str,
|
|
repo_type: str,
|
|
objects: List[Dict[str, Any]],
|
|
endpoint: Optional[str] = None
|
|
) -> List[Dict[str, Any]]:
|
|
"""
|
|
        Check which of the given blobs still need to be uploaded.

        Args:
            repo_id (str): The repo id on ModelScope.
            repo_type (str): The repo type. `dataset`, `model`, etc.
            objects (List[Dict[str, Any]]): The objects to check, each containing:
                oid (str): The sha256 hash value.
                size (int): The size of the blob.
            endpoint (Optional[str]): The endpoint to use; defaults to None,
                which uses the endpoint configured on this class.

        Returns:
            List[Dict[str, Any]]: The objects that still need to be uploaded,
                each as ``{'url': <presigned upload url>, 'oid': <sha256>}``.
|
|
"""
|
|
|
|
# construct URL
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
url = f'{endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'
|
|
|
|
# build payload
|
|
payload = {
|
|
'operation': 'upload',
|
|
'objects': objects,
|
|
}
|
|
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
if cookies is None:
|
|
raise ValueError('Token does not exist, please login first.')
|
|
response = requests.post(
|
|
url,
|
|
headers=self.builder_headers(self.headers),
|
|
data=json.dumps(payload),
|
|
cookies=cookies
|
|
)
|
|
|
|
raise_for_http_status(rsp=response)
|
|
resp = response.json()
|
|
raise_on_error(rsp=resp)
|
|
|
|
upload_objects = [] # list of objects to upload, [{'url': 'xxx', 'oid': 'xxx'}, ...]
|
|
resp_objects = resp['Data']['objects']
|
|
for obj in resp_objects:
|
|
upload_objects.append(
|
|
{'url': obj['actions']['upload']['href'],
|
|
'oid': obj['oid']}
|
|
)
|
|
|
|
return upload_objects
|
|
|
|
def _prepare_upload_folder(
|
|
self,
|
|
folder_path_or_files: Union[str, Path, List[str], List[Path]],
|
|
path_in_repo: str,
|
|
allow_patterns: Optional[Union[List[str], str]] = None,
|
|
ignore_patterns: Optional[Union[List[str], str]] = None,
|
|
) -> List[Union[tuple, list]]:
|
|
folder_path = None
|
|
files_path = None
|
|
if isinstance(folder_path_or_files, list):
|
|
if os.path.isfile(folder_path_or_files[0]):
|
|
files_path = folder_path_or_files
|
|
else:
|
|
                raise ValueError('Uploading multiple folders is not currently supported.')
|
|
else:
|
|
if os.path.isfile(folder_path_or_files):
|
|
files_path = [folder_path_or_files]
|
|
else:
|
|
folder_path = folder_path_or_files
|
|
|
|
if files_path is None:
|
|
self.upload_checker.check_folder(folder_path)
|
|
folder_path = Path(folder_path).expanduser().resolve()
|
|
if not folder_path.is_dir():
|
|
raise ValueError(f"Provided path: '{folder_path}' is not a directory")
|
|
|
|
# List files from folder
|
|
relpath_to_abspath = {
|
|
path.relative_to(folder_path).as_posix(): path
|
|
for path in sorted(folder_path.glob('**/*')) # sorted to be deterministic
|
|
if path.is_file()
|
|
}
|
|
else:
|
|
relpath_to_abspath = {}
|
|
for path in files_path:
|
|
if os.path.isfile(path):
|
|
self.upload_checker.check_file(path)
|
|
relpath_to_abspath[os.path.basename(path)] = path
|
|
|
|
# Filter files
|
|
filtered_repo_objects = list(
|
|
RepoUtils.filter_repo_objects(
|
|
relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
|
)
|
|
)
|
|
|
|
prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ''
|
|
|
|
prepared_repo_objects = [
|
|
(prefix + relpath, str(relpath_to_abspath[relpath]))
|
|
for relpath in filtered_repo_objects
|
|
]
|
|
|
|
logger.info(f'Prepared {len(prepared_repo_objects)} files for upload.')
|
|
|
|
return prepared_repo_objects
|
|
|
|
@staticmethod
|
|
def _prepare_commit_payload(
|
|
operations: Iterable[CommitOperation],
|
|
commit_message: str,
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Prepare the commit payload to be sent to the ModelScope hub.
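
        The resulting payload has this shape (illustrative values):

            {
                'commit_message': '...',
                'actions': [
                    # normal (non-LFS) file: content is embedded as base64
                    {'action': 'create', 'path': 'README.md', 'type': 'normal',
                     'size': 123, 'sha256': '', 'content': '<base64>', 'encoding': 'base64'},
                    # LFS file: only the sha256 pointer is sent
                    {'action': 'update', 'path': 'model.bin', 'type': 'lfs',
                     'size': 456, 'sha256': '<sha256>', 'content': '', 'encoding': ''},
                ]
            }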
|
|
"""
|
|
|
|
payload = {
|
|
'commit_message': commit_message,
|
|
'actions': []
|
|
}
|
|
|
|
nb_ignored_files = 0
|
|
|
|
# 2. Send operations, one per line
|
|
for operation in operations:
|
|
|
|
# Skip ignored files
|
|
if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
|
|
logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
|
|
nb_ignored_files += 1
|
|
continue
|
|
|
|
# 2.a. Case adding a normal file
|
|
if isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'normal':
|
|
|
|
commit_action = {
|
|
'action': 'update' if operation._is_uploaded else 'create',
|
|
'path': operation.path_in_repo,
|
|
'type': 'normal',
|
|
'size': operation.upload_info.size,
|
|
'sha256': '',
|
|
'content': operation.b64content().decode(),
|
|
'encoding': 'base64',
|
|
}
|
|
payload['actions'].append(commit_action)
|
|
|
|
# 2.b. Case adding an LFS file
|
|
elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'lfs':
|
|
|
|
commit_action = {
|
|
'action': 'update' if operation._is_uploaded else 'create',
|
|
'path': operation.path_in_repo,
|
|
'type': 'lfs',
|
|
'size': operation.upload_info.size,
|
|
'sha256': operation.upload_info.sha256,
|
|
'content': '',
|
|
'encoding': '',
|
|
}
|
|
payload['actions'].append(commit_action)
|
|
|
|
else:
|
|
raise ValueError(
|
|
f'Unknown operation to commit. Operation: {operation}. Upload mode:'
|
|
f" {getattr(operation, '_upload_mode', None)}"
|
|
)
|
|
|
|
if nb_ignored_files > 0:
|
|
logger.info(f'Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).')
|
|
|
|
return payload
|
|
|
|
def _get_internal_acceleration_domain(self, internal_timeout: float = 0.2):
|
|
"""
|
|
Get the internal acceleration domain.
|
|
|
|
Args:
|
|
internal_timeout (float): The timeout for the request. Default to 0.2s
|
|
|
|
Returns:
|
|
str: The internal acceleration domain. e.g. `cn-hangzhou`, `cn-zhangjiakou`
|
|
"""
|
|
|
|
def send_request(url: str, timeout: float):
|
|
try:
|
|
response = requests.get(url, timeout=timeout)
|
|
response.raise_for_status()
|
|
except requests.exceptions.RequestException:
|
|
response = None
|
|
|
|
return response
|
|
|
|
internal_url = f'{self.endpoint}/api/v1/repos/internalAccelerationInfo'
|
|
|
|
# Get internal url and region for acceleration
|
|
internal_info_response = send_request(url=internal_url, timeout=internal_timeout)
|
|
region_id: str = ''
|
|
if internal_info_response is not None:
|
|
internal_info_response = internal_info_response.json()
|
|
if 'Data' in internal_info_response:
|
|
query_addr = internal_info_response['Data']['InternalRegionQueryAddress']
|
|
else:
|
|
query_addr: str = ''
|
|
|
|
if query_addr:
|
|
domain_response = send_request(query_addr, timeout=internal_timeout)
|
|
if domain_response is not None:
|
|
region_id = domain_response.text.strip()
|
|
|
|
return region_id
|
|
|
|
def delete_files(self,
|
|
repo_id: str,
|
|
repo_type: str,
|
|
delete_patterns: Union[str, List[str]],
|
|
*,
|
|
revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
|
endpoint: Optional[str] = None) -> Dict[str, Any]:
|
|
"""
|
|
Delete files in batch using glob (wildcard) patterns, e.g. '*.py', 'data/*.csv', 'foo*', etc.
|
|
|
|
Example:
|
|
# Delete all Python and Markdown files in a model repo
|
|
api.delete_files(
|
|
repo_id='your_username/your_model',
|
|
repo_type=REPO_TYPE_MODEL,
|
|
delete_patterns=['*.py', '*.md']
|
|
)
|
|
|
|
# Delete all CSV files in the data/ directory of a dataset repo
|
|
api.delete_files(
|
|
repo_id='your_username/your_dataset',
|
|
repo_type=REPO_TYPE_DATASET,
|
|
delete_patterns='data/*.csv'
|
|
)
|
|
|
|
Args:
|
|
repo_id (str): 'owner/repo_name' or 'owner/dataset_name', e.g. 'Koko/my_model'
|
|
repo_type (str): REPO_TYPE_MODEL or REPO_TYPE_DATASET
|
|
delete_patterns (str or List[str]): List of glob patterns, e.g. '*.py', 'data/*.csv', 'foo*'
|
|
revision (str, optional): Branch or tag name
|
|
endpoint (str, optional): API endpoint
|
|
Returns:
|
|
dict: Deletion result
|
|
"""
|
|
if repo_type not in REPO_TYPE_SUPPORT:
|
|
raise ValueError(f'Unsupported repo_type: {repo_type}')
|
|
if not delete_patterns:
|
|
raise ValueError('delete_patterns cannot be empty')
|
|
if isinstance(delete_patterns, str):
|
|
delete_patterns = [delete_patterns]
|
|
|
|
cookies = ModelScopeConfig.get_cookies()
|
|
if not endpoint:
|
|
endpoint = self.endpoint
|
|
if cookies is None:
|
|
raise ValueError('Token does not exist, please login first.')
|
|
headers = self.builder_headers(self.headers)
|
|
|
|
# List all files in the repo
|
|
if repo_type == REPO_TYPE_MODEL:
|
|
files = self.get_model_files(
|
|
repo_id,
|
|
revision=revision or DEFAULT_MODEL_REVISION,
|
|
recursive=True,
|
|
endpoint=endpoint,
|
|
use_cookies=cookies,
|
|
)
|
|
file_paths = [f['Path'] for f in files]
|
|
elif repo_type == REPO_TYPE_DATASET:
|
|
file_paths = []
|
|
page_number = 1
|
|
page_size = 100
|
|
while True:
|
|
try:
|
|
dataset_files: List[Dict[str, Any]] = self.get_dataset_files(
|
|
repo_id=repo_id,
|
|
revision=revision or DEFAULT_DATASET_REVISION,
|
|
recursive=True,
|
|
page_number=page_number,
|
|
page_size=page_size,
|
|
endpoint=endpoint,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}')
|
|
break
|
|
|
|
# Parse data (Type: 'tree' or 'blob')
|
|
for file_info_d in dataset_files:
|
|
if file_info_d['Type'] != 'tree':
|
|
file_paths.append(file_info_d['Path'])
|
|
|
|
if len(dataset_files) < page_size:
|
|
break
|
|
|
|
page_number += 1
|
|
else:
|
|
raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
# Glob pattern matching
|
|
to_delete = []
|
|
for path in file_paths:
|
|
for delete_pattern in delete_patterns:
|
|
if fnmatch.fnmatch(path, delete_pattern):
|
|
to_delete.append(path)
|
|
break
|
|
|
|
deleted_files, failed_files = [], []
|
|
for path in to_delete:
|
|
try:
|
|
if repo_type == REPO_TYPE_MODEL:
|
|
owner, repo_name = repo_id.split('/')
|
|
url = f'{endpoint}/api/v1/models/{owner}/{repo_name}/file'
|
|
params = {
|
|
'Revision': revision or DEFAULT_MODEL_REVISION,
|
|
'FilePath': path
|
|
}
|
|
elif repo_type == REPO_TYPE_DATASET:
|
|
owner, dataset_name = repo_id.split('/')
|
|
url = f'{endpoint}/api/v1/datasets/{owner}/{dataset_name}/repo'
|
|
params = {
|
|
'FilePath': path
|
|
}
|
|
else:
|
|
raise ValueError(f'Unsupported repo_type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
r = self.session.delete(url, params=params, cookies=cookies, headers=headers)
|
|
raise_for_http_status(r)
|
|
resp = r.json()
|
|
raise_on_error(resp)
|
|
deleted_files.append(path)
|
|
except Exception as e:
|
|
failed_files.append(path)
|
|
logger.error(f'Failed to delete {path}: {str(e)}')
|
|
|
|
return {
|
|
'deleted_files': deleted_files,
|
|
'failed_files': failed_files,
|
|
'total_files': len(to_delete)
|
|
}
|
|
|
|
|
|
class ModelScopeConfig:
|
|
path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH)
|
|
COOKIES_FILE_NAME = 'cookies'
|
|
GIT_TOKEN_FILE_NAME = 'git_token'
|
|
USER_INFO_FILE_NAME = 'user'
|
|
USER_SESSION_ID_FILE_NAME = 'session'
|
|
cookie_expired_warning = False
|
|
|
|
@staticmethod
|
|
def make_sure_credential_path_exist():
|
|
os.makedirs(ModelScopeConfig.path_credential, exist_ok=True)
|
|
|
|
@staticmethod
|
|
def save_cookies(cookies: CookieJar):
|
|
ModelScopeConfig.make_sure_credential_path_exist()
|
|
with open(
|
|
os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.COOKIES_FILE_NAME), 'wb+') as f:
|
|
pickle.dump(cookies, f)
|
|
|
|
@staticmethod
|
|
def get_cookies():
|
|
cookies_path = os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.COOKIES_FILE_NAME)
|
|
if os.path.exists(cookies_path):
|
|
with open(cookies_path, 'rb') as f:
|
|
cookies = pickle.load(f)
|
|
for cookie in cookies:
|
|
if cookie.name == 'm_session_id' and cookie.is_expired() and \
|
|
not ModelScopeConfig.cookie_expired_warning:
|
|
ModelScopeConfig.cookie_expired_warning = True
|
|
                        logger.info('Not logged-in, you can login for uploading '
                                    'or accessing controlled entities.')
|
|
return None
|
|
return cookies
|
|
return None
|
|
|
|
@staticmethod
|
|
def get_user_session_id():
|
|
session_path = os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.USER_SESSION_ID_FILE_NAME)
|
|
session_id = ''
|
|
if os.path.exists(session_path):
|
|
with open(session_path, 'rb') as f:
|
|
session_id = str(f.readline().strip(), encoding='utf-8')
|
|
return session_id
|
|
if session_id == '' or len(session_id) != 32:
|
|
session_id = str(uuid.uuid4().hex)
|
|
ModelScopeConfig.make_sure_credential_path_exist()
|
|
with open(session_path, 'w+') as wf:
|
|
wf.write(session_id)
|
|
|
|
return session_id
|
|
|
|
@staticmethod
|
|
def save_token(token: str):
|
|
ModelScopeConfig.make_sure_credential_path_exist()
|
|
with open(
|
|
os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.GIT_TOKEN_FILE_NAME), 'w+') as f:
|
|
f.write(token)
|
|
|
|
@staticmethod
|
|
def save_user_info(user_name: str, user_email: str):
|
|
ModelScopeConfig.make_sure_credential_path_exist()
|
|
with open(
|
|
os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.USER_INFO_FILE_NAME), 'w+') as f:
|
|
f.write('%s:%s' % (user_name, user_email))
|
|
|
|
@staticmethod
|
|
def get_user_info() -> Tuple[str, str]:
|
|
try:
|
|
with open(
|
|
os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.USER_INFO_FILE_NAME),
|
|
'r',
|
|
encoding='utf-8') as f:
|
|
info = f.read()
|
|
return info.split(':')[0], info.split(':')[1]
|
|
except FileNotFoundError:
|
|
pass
|
|
return None, None
|
|
|
|
@staticmethod
|
|
def get_token() -> Optional[str]:
|
|
"""
|
|
Get token or None if not existent.
|
|
|
|
Returns:
|
|
`str` or `None`: The token, `None` if it doesn't exist.
|
|
|
|
"""
|
|
token = None
|
|
try:
|
|
with open(
|
|
os.path.join(ModelScopeConfig.path_credential,
|
|
ModelScopeConfig.GIT_TOKEN_FILE_NAME),
|
|
'r',
|
|
encoding='utf-8') as f:
|
|
token = f.read()
|
|
except FileNotFoundError:
|
|
pass
|
|
return token
|
|
|
|
@staticmethod
|
|
def get_user_agent(user_agent: Union[Dict, str, None] = None, ) -> str:
|
|
"""Formats a user-agent string with basic info about a request.
|
|
|
|
Args:
|
|
user_agent (`str`, `dict`, *optional*):
|
|
The user agent info in the form of a dictionary or a single string.
|
|
|
|
Returns:
|
|
The formatted user-agent string.
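
            An illustrative value (placeholders, not tied to a specific release):
            ``modelscope/<version>; python/<python-version>; session_id/<hex>; platform/<platform>; processor/<arch>; env/custom; user/unknown``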
|
|
"""
|
|
|
|
# include some more telemetrics when executing in dedicated
|
|
# cloud containers
|
|
env = 'custom'
|
|
if MODELSCOPE_CLOUD_ENVIRONMENT in os.environ:
|
|
env = os.environ[MODELSCOPE_CLOUD_ENVIRONMENT]
|
|
user_name = 'unknown'
|
|
if MODELSCOPE_CLOUD_USERNAME in os.environ:
|
|
user_name = os.environ[MODELSCOPE_CLOUD_USERNAME]
|
|
|
|
from modelscope import __version__
|
|
ua = 'modelscope/%s; python/%s; session_id/%s; platform/%s; processor/%s; env/%s; user/%s' % (
|
|
__version__,
|
|
platform.python_version(),
|
|
ModelScopeConfig.get_user_session_id(),
|
|
platform.platform(),
|
|
platform.processor(),
|
|
env,
|
|
user_name,
|
|
)
|
|
if isinstance(user_agent, dict):
|
|
ua += '; ' + '; '.join(f'{k}/{v}' for k, v in user_agent.items())
|
|
elif isinstance(user_agent, str):
|
|
ua += '; ' + user_agent
|
|
return ua
|
|
|
|
|
|
class UploadingCheck:
|
|
"""
|
|
Check the files and folders to be uploaded.
|
|
|
|
Args:
|
|
max_file_count (int): The maximum number of files to be uploaded. Default to `UPLOAD_MAX_FILE_COUNT`.
|
|
max_file_count_in_dir (int): The maximum number of files in a directory.
|
|
Default to `UPLOAD_MAX_FILE_COUNT_IN_DIR`.
|
|
max_file_size (int): The maximum size of a single file in bytes. Default to `UPLOAD_MAX_FILE_SIZE`.
|
|
size_threshold_to_enforce_lfs (int): The size threshold to enforce LFS in bytes.
|
|
Files larger than this size will be enforced to be uploaded via LFS.
|
|
Default to `UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS`.
|
|
normal_file_size_total_limit (int): The total size limit of normal files in bytes.
|
|
Default to `UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT`.
|
|
|
|
Examples:
|
|
>>> from modelscope.hub.api import UploadingCheck
|
|
>>> upload_checker = UploadingCheck()
|
|
>>> upload_checker.check_file('/path/to/your/file.txt')
|
|
>>> upload_checker.check_folder('/path/to/your/folder')
|
|
>>> is_lfs = upload_checker.is_lfs('/path/to/your/file.txt', repo_type='model')
|
|
>>> print(f'Is LFS: {is_lfs}')
|
|
"""
|
|
def __init__(
|
|
self,
|
|
max_file_count: int = UPLOAD_MAX_FILE_COUNT,
|
|
max_file_count_in_dir: int = UPLOAD_MAX_FILE_COUNT_IN_DIR,
|
|
max_file_size: int = UPLOAD_MAX_FILE_SIZE,
|
|
size_threshold_to_enforce_lfs: int = UPLOAD_SIZE_THRESHOLD_TO_ENFORCE_LFS,
|
|
normal_file_size_total_limit: int = UPLOAD_NORMAL_FILE_SIZE_TOTAL_LIMIT,
|
|
):
|
|
self.max_file_count = max_file_count
|
|
self.max_file_count_in_dir = max_file_count_in_dir
|
|
self.max_file_size = max_file_size
|
|
self.size_threshold_to_enforce_lfs = size_threshold_to_enforce_lfs
|
|
self.normal_file_size_total_limit = normal_file_size_total_limit
|
|
|
|
def check_file(self, file_path_or_obj) -> None:
|
|
"""
|
|
Check a single file to be uploaded.
|
|
|
|
Args:
|
|
file_path_or_obj (Union[str, Path, bytes, BinaryIO]): The file path or file-like object to be checked.
|
|
|
|
Raises:
|
|
            ValueError: If the file does not exist. Exceeding the size limit only logs a warning.
|
|
"""
|
|
if isinstance(file_path_or_obj, (str, Path)):
|
|
if not os.path.exists(file_path_or_obj):
|
|
raise ValueError(f'File {file_path_or_obj} does not exist')
|
|
|
|
file_size: int = get_file_size(file_path_or_obj)
|
|
if file_size > self.max_file_size:
|
|
logger.warning(f'File exceeds size limit: {self.max_file_size / (1024 ** 3)} GB, '
|
|
f'got {round(file_size / (1024 ** 3), 4)} GB')
|
|
|
|
def check_folder(self, folder_path: Union[str, Path]):
|
|
"""
|
|
Check a folder to be uploaded.
|
|
|
|
Args:
|
|
folder_path (Union[str, Path]): The folder path to be checked.
|
|
|
|
Raises:
|
|
ValueError: If the folder does not exist or exceeds the file count limit.
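
        Returns:
            Tuple[int, int]: The number of files and the number of directories found, counted recursively.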
|
|
"""
|
|
file_count = 0
|
|
dir_count = 0
|
|
|
|
if isinstance(folder_path, str):
|
|
folder_path = Path(folder_path)
|
|
|
|
for item in folder_path.iterdir():
|
|
if item.is_file():
|
|
file_count += 1
|
|
item_size: int = get_file_size(item)
|
|
if item_size > self.max_file_size:
|
|
                    logger.warning(f'File {item} exceeds size limit: {self.max_file_size / (1024 ** 3)} GB, '
                                   f'got {round(item_size / (1024 ** 3), 4)} GB')
|
|
elif item.is_dir():
|
|
dir_count += 1
|
|
# Count items in subdirectories recursively
|
|
sub_file_count, sub_dir_count = self.check_folder(item)
|
|
if (sub_file_count + sub_dir_count) > self.max_file_count_in_dir:
|
|
raise ValueError(f'Directory {item} contains {sub_file_count + sub_dir_count} items '
|
|
f'and exceeds limit: {self.max_file_count_in_dir}')
|
|
file_count += sub_file_count
|
|
dir_count += sub_dir_count
|
|
|
|
if file_count > self.max_file_count:
|
|
            raise ValueError(f'Total file count {file_count} exceeds limit: {self.max_file_count}')
|
|
|
|
return file_count, dir_count
|
|
|
|
def is_lfs(self, file_path_or_obj: Union[str, Path, bytes, BinaryIO], repo_type: str) -> bool:
|
|
"""
|
|
Check if a file should be uploaded via LFS.
|
|
|
|
Args:
|
|
file_path_or_obj (Union[str, Path, bytes, BinaryIO]): The file path or file-like object to be checked.
|
|
repo_type (str): The repo type, either `model` or `dataset`.
|
|
|
|
Returns:
|
|
bool: True if the file should be uploaded via LFS, False otherwise.
|
|
"""
|
|
hit_lfs_suffix = True
|
|
|
|
if isinstance(file_path_or_obj, (str, Path)):
|
|
file_path_or_obj = Path(file_path_or_obj)
|
|
if not file_path_or_obj.exists():
|
|
raise ValueError(f'File {file_path_or_obj} does not exist')
|
|
|
|
if repo_type == REPO_TYPE_MODEL:
|
|
if file_path_or_obj.suffix not in MODEL_LFS_SUFFIX:
|
|
hit_lfs_suffix = False
|
|
elif repo_type == REPO_TYPE_DATASET:
|
|
if file_path_or_obj.suffix not in DATASET_LFS_SUFFIX:
|
|
hit_lfs_suffix = False
|
|
else:
|
|
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
|
|
|
file_size: int = get_file_size(file_path_or_obj)
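
        # A file is routed to LFS if its suffix is in the repo-type specific LFS list,
        # or if its size exceeds the enforcement threshold.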
|
|
|
|
return file_size > self.size_threshold_to_enforce_lfs or hit_lfs_suffix
|
|
|
|
def check_normal_files(self, file_path_list: List[Union[str, Path]], repo_type: str) -> None:
|
|
"""
|
|
Check a list of normal files to be uploaded.
|
|
|
|
Args:
|
|
file_path_list (List[Union[str, Path]]): The list of file paths to be checked.
|
|
repo_type (str): The repo type, either `model` or `dataset`.
|
|
|
|
Raises:
|
|
ValueError: If the total size of normal files exceeds the limit.
|
|
|
|
Returns: None
|
|
"""
|
|
normal_file_list = [item for item in file_path_list if not self.is_lfs(item, repo_type)]
|
|
total_size = sum([get_file_size(item) for item in normal_file_list])
|
|
|
|
if total_size > self.normal_file_size_total_limit:
|
|
            raise ValueError(f'Total size of non-lfs files {total_size / (1024 * 1024)}MB '
                             f'exceeds limit: {self.normal_file_size_total_limit / (1024 * 1024)}MB')
|