mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-14 15:27:42 +01:00
Support upload file and folder in the hub api (#1152)
* update features * update api * add upload_file and thread_executor * update upload file * update api.py * add cli for uploading * run lint * lint in msdataset * temp * add tqdm_desc in thread_executor * update * refine upload_file and upload_folder * add endpoint for cli * add uploading checker * add path_or_fileobj and path_in_repo check in upload_file func * add size limit to lfs: 1MB by default * update lfs limit size: 10MB * 5MB lfs limit * fix test issue * add pbar for upload_blob; del size_to_chunk_mb; fix allow_patterns and ignore_patterns * fix commit uploaded blobs * add update action for folder * fix issues * add normal files check * update * update * set normal file size limit to 500MB * update tqdm
This commit is contained in:
@@ -24,7 +24,7 @@ options:
|
||||
Get access token: [我的页面](https://modelscope.cn/my/myaccesstoken)获取**SDK 令牌**
|
||||
|
||||
|
||||
## download model
|
||||
## download
|
||||
```bash
|
||||
modelscope download --help
|
||||
|
||||
@@ -36,6 +36,7 @@ modelscope download --help
|
||||
options:
|
||||
-h, --help show this help message and exit
|
||||
--model MODEL The model id to be downloaded.
|
||||
--dataset DATASET The dataset id to be downloaded.
|
||||
--revision REVISION Revision of the model.
|
||||
--cache_dir CACHE_DIR
|
||||
Cache directory to save model.
|
||||
|
||||
@@ -11,6 +11,7 @@ from modelscope.cli.modelcard import ModelCardCMD
|
||||
from modelscope.cli.pipeline import PipelineCMD
|
||||
from modelscope.cli.plugins import PluginsCMD
|
||||
from modelscope.cli.server import ServerCMD
|
||||
from modelscope.cli.upload import UploadCMD
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
@@ -25,6 +26,7 @@ def run_cmd():
|
||||
subparsers = parser.add_subparsers(help='modelscope commands helpers')
|
||||
|
||||
DownloadCMD.define_args(subparsers)
|
||||
UploadCMD.define_args(subparsers)
|
||||
ClearCacheCMD.define_args(subparsers)
|
||||
PluginsCMD.define_args(subparsers)
|
||||
PipelineCMD.define_args(subparsers)
|
||||
|
||||
179
modelscope/cli/upload.py
Normal file
179
modelscope/cli/upload.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from argparse import ArgumentParser, _SubParsersAction
|
||||
|
||||
from modelscope.cli.base import CLICommand
|
||||
from modelscope.hub.api import HubApi, ModelScopeConfig
|
||||
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def subparser_func(args):
|
||||
""" Function which will be called for a specific sub parser.
|
||||
"""
|
||||
return UploadCMD(args)
|
||||
|
||||
|
||||
class UploadCMD(CLICommand):
|
||||
|
||||
name = 'upload'
|
||||
|
||||
def __init__(self, args: _SubParsersAction):
|
||||
self.args = args
|
||||
|
||||
@staticmethod
|
||||
def define_args(parsers: _SubParsersAction):
|
||||
|
||||
parser: ArgumentParser = parsers.add_parser(UploadCMD.name)
|
||||
|
||||
parser.add_argument(
|
||||
'repo_id',
|
||||
type=str,
|
||||
help='The ID of the repo to upload to (e.g. `username/repo-name`)')
|
||||
parser.add_argument(
|
||||
'local_path',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default=None,
|
||||
help='Optional, '
|
||||
'Local path to the file or folder to upload. Defaults to current directory.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'path_in_repo',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default=None,
|
||||
help='Optional, '
|
||||
'Path of the file or folder in the repo. Defaults to the relative path of the file or folder.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--repo-type',
|
||||
choices=REPO_TYPE_SUPPORT,
|
||||
default=REPO_TYPE_MODEL,
|
||||
help=
|
||||
'Type of the repo to upload to (e.g. `dataset`, `model`). Defaults to be `model`.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--include',
|
||||
nargs='*',
|
||||
type=str,
|
||||
help='Glob patterns to match files to upload.')
|
||||
parser.add_argument(
|
||||
'--exclude',
|
||||
nargs='*',
|
||||
type=str,
|
||||
help='Glob patterns to exclude from files to upload.')
|
||||
parser.add_argument(
|
||||
'--commit-message',
|
||||
type=str,
|
||||
default=None,
|
||||
help='The message of commit. Default to be `None`.')
|
||||
parser.add_argument(
|
||||
'--commit-description',
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
'The description of the generated commit. Default to be `None`.')
|
||||
parser.add_argument(
|
||||
'--token',
|
||||
type=str,
|
||||
default=None,
|
||||
help=
|
||||
'A User Access Token generated from https://modelscope.cn/my/myaccesstoken'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max-workers',
|
||||
type=int,
|
||||
default=min(8,
|
||||
os.cpu_count() + 4),
|
||||
help='The number of workers to use for uploading files.')
|
||||
parser.add_argument(
|
||||
'--endpoint',
|
||||
type=str,
|
||||
default='https://www.modelscope.cn',
|
||||
help='Endpoint for Modelscope service.')
|
||||
|
||||
parser.set_defaults(func=subparser_func)
|
||||
|
||||
def execute(self):
|
||||
|
||||
assert self.args.repo_id, '`repo_id` is required'
|
||||
assert self.args.repo_id.count(
|
||||
'/') == 1, 'repo_id should be in format of username/repo-name'
|
||||
repo_name: str = self.args.repo_id.split('/')[-1]
|
||||
self.repo_id = self.args.repo_id
|
||||
|
||||
# Check path_in_repo
|
||||
if self.args.local_path is None and os.path.isfile(repo_name):
|
||||
# Case 1: modelscope upload owner_name/test_repo
|
||||
self.local_path = repo_name
|
||||
self.path_in_repo = repo_name
|
||||
elif self.args.local_path is None and os.path.isdir(repo_name):
|
||||
# Case 2: modelscope upload owner_name/test_repo (run command line in the `repo_name` dir)
|
||||
# => upload all files in current directory to remote root path
|
||||
self.local_path = repo_name
|
||||
self.path_in_repo = '.'
|
||||
elif self.args.local_path is None:
|
||||
# Case 3: user provided only a repo_id that does not match a local file or folder
|
||||
# => the user must explicitly provide a local_path => raise exception
|
||||
raise ValueError(
|
||||
f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly."
|
||||
)
|
||||
elif self.args.path_in_repo is None and os.path.isfile(
|
||||
self.args.local_path):
|
||||
# Case 4: modelscope upload owner_name/test_repo /path/to/your_file.csv
|
||||
# => upload it to remote root path with same name
|
||||
self.local_path = self.args.local_path
|
||||
self.path_in_repo = os.path.basename(self.args.local_path)
|
||||
elif self.args.path_in_repo is None:
|
||||
# Case 5: modelscope upload owner_name/test_repo /path/to/your_folder
|
||||
# => upload all files in current directory to remote root path
|
||||
self.local_path = self.args.local_path
|
||||
self.path_in_repo = ''
|
||||
else:
|
||||
# Finally, if both paths are explicit
|
||||
self.local_path = self.args.local_path
|
||||
self.path_in_repo = self.args.path_in_repo
|
||||
|
||||
# Check token and login
|
||||
# The cookies will be reused if the user has logged in before.
|
||||
api = HubApi(endpoint=self.args.endpoint)
|
||||
|
||||
if self.args.token:
|
||||
api.login(access_token=self.args.token)
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise ValueError(
|
||||
'The `token` is not provided! '
|
||||
'You can pass the `--token` argument, '
|
||||
'or use api.login(access_token=`your_sdk_token`). '
|
||||
'Your token is available at https://modelscope.cn/my/myaccesstoken'
|
||||
)
|
||||
|
||||
if os.path.isfile(self.local_path):
|
||||
commit_info = api.upload_file(
|
||||
path_or_fileobj=self.local_path,
|
||||
path_in_repo=self.path_in_repo,
|
||||
repo_id=self.repo_id,
|
||||
repo_type=self.args.repo_type,
|
||||
commit_message=self.args.commit_message,
|
||||
commit_description=self.args.commit_description,
|
||||
)
|
||||
elif os.path.isdir(self.local_path):
|
||||
commit_info = api.upload_folder(
|
||||
repo_id=self.repo_id,
|
||||
folder_path=self.local_path,
|
||||
path_in_repo=self.path_in_repo,
|
||||
commit_message=self.args.commit_message,
|
||||
commit_description=self.args.commit_description,
|
||||
repo_type=self.args.repo_type,
|
||||
allow_patterns=self.args.include,
|
||||
ignore_patterns=self.args.exclude,
|
||||
max_workers=self.args.max_workers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'{self.local_path} is not a valid local path')
|
||||
|
||||
logger.info(f'Upload finished, commit info: {commit_info}')
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import io
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
@@ -13,13 +14,15 @@ from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
from http.cookiejar import CookieJar
|
||||
from os.path import expanduser
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from pathlib import Path
|
||||
from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import json
|
||||
import requests
|
||||
from requests import Session
|
||||
from requests.adapters import HTTPAdapter, Retry
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
||||
API_HTTP_CLIENT_TIMEOUT,
|
||||
@@ -29,6 +32,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
||||
API_RESPONSE_FIELD_MESSAGE,
|
||||
API_RESPONSE_FIELD_USERNAME,
|
||||
DEFAULT_CREDENTIALS_PATH,
|
||||
DEFAULT_MAX_WORKERS,
|
||||
MODELSCOPE_CLOUD_ENVIRONMENT,
|
||||
MODELSCOPE_CLOUD_USERNAME,
|
||||
MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
|
||||
@@ -36,25 +40,34 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
|
||||
TEMPORARY_FOLDER_NAME, DatasetVisibility,
|
||||
Licenses, ModelVisibility)
|
||||
from modelscope.hub.errors import (InvalidParameter, NotExistError,
|
||||
NotLoginException, NoValidRevisionError,
|
||||
RequestError, datahub_raise_on_error,
|
||||
NotLoginException, RequestError,
|
||||
datahub_raise_on_error,
|
||||
handle_http_post_error,
|
||||
handle_http_response, is_ok,
|
||||
raise_for_http_status, raise_on_error)
|
||||
from modelscope.hub.git import GitCommandWrapper
|
||||
from modelscope.hub.repository import Repository
|
||||
from modelscope.hub.utils.utils import (get_endpoint, get_readable_folder_size,
|
||||
get_release_datetime,
|
||||
model_id_to_group_owner_name)
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION,
|
||||
DEFAULT_REPOSITORY_REVISION,
|
||||
MASTER_MODEL_BRANCH, META_FILES_FORMAT,
|
||||
REPO_TYPE_MODEL, ConfigFields,
|
||||
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
|
||||
REPO_TYPE_SUPPORT, ConfigFields,
|
||||
DatasetFormations, DatasetMetaFormats,
|
||||
DatasetVisibilityMap, DownloadChannel,
|
||||
DownloadMode, Frameworks, ModelFile,
|
||||
Tasks, VirgoDatasetConfig)
|
||||
from modelscope.utils.file_utils import get_file_hash, get_file_size
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .utils.utils import (get_endpoint, get_readable_folder_size,
|
||||
get_release_datetime, model_id_to_group_owner_name)
|
||||
from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX,
|
||||
DEFAULT_IGNORE_PATTERNS,
|
||||
MODEL_LFS_SUFFIX, CommitInfo,
|
||||
CommitOperation, CommitOperationAdd,
|
||||
RepoUtils)
|
||||
from modelscope.utils.thread_utils import thread_executor
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -93,6 +106,8 @@ class HubApi:
|
||||
getattr(self.session, method),
|
||||
timeout=timeout))
|
||||
|
||||
self.upload_checker = UploadingCheck()
|
||||
|
||||
def login(
|
||||
self,
|
||||
access_token: str,
|
||||
@@ -181,7 +196,7 @@ class HubApi:
|
||||
headers=self.builder_headers(self.headers))
|
||||
handle_http_post_error(r, path, body)
|
||||
raise_on_error(r.json())
|
||||
model_repo_url = f'{get_endpoint()}/{model_id}'
|
||||
model_repo_url = f'{self.endpoint}/{model_id}'
|
||||
return model_repo_url
|
||||
|
||||
def delete_model(self, model_id: str):
|
||||
@@ -1173,6 +1188,572 @@ class HubApi:
|
||||
return f'{self.endpoint}/api/v1/datasets/{_namespace}/{_dataset_name}/repo?'
|
||||
# return f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Revision={revision}&FilePath='
|
||||
|
||||
def create_repo(
|
||||
self,
|
||||
repo_id: str,
|
||||
*,
|
||||
token: Union[str, bool, None] = None,
|
||||
visibility: Optional[str] = 'public',
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
chinese_name: Optional[str] = '',
|
||||
license: Optional[str] = Licenses.APACHE_V2,
|
||||
) -> str:
|
||||
|
||||
# TODO: exist_ok
|
||||
|
||||
if not repo_id:
|
||||
raise ValueError('Repo id cannot be empty!')
|
||||
|
||||
if token:
|
||||
self.login(access_token=token)
|
||||
else:
|
||||
logger.warning('No token provided, will use the cached token.')
|
||||
|
||||
repo_id_list = repo_id.split('/')
|
||||
if len(repo_id_list) != 2:
|
||||
raise ValueError('Invalid repo id, should be in the format of `owner_name/repo_name`')
|
||||
namespace, repo_name = repo_id_list
|
||||
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
visibilities = {k: v for k, v in ModelVisibility.__dict__.items() if not k.startswith('__')}
|
||||
visibility: int = visibilities.get(visibility.upper())
|
||||
if visibility is None:
|
||||
raise ValueError(f'Invalid visibility: {visibility}, '
|
||||
f'supported visibilities: `public`, `private`, `internal`')
|
||||
repo_url: str = self.create_model(
|
||||
model_id=repo_id,
|
||||
visibility=visibility,
|
||||
license=license,
|
||||
chinese_name=chinese_name,
|
||||
)
|
||||
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
visibilities = {k: v for k, v in DatasetVisibility.__dict__.items() if not k.startswith('__')}
|
||||
visibility: int = visibilities.get(visibility.upper())
|
||||
if visibility is None:
|
||||
raise ValueError(f'Invalid visibility: {visibility}, '
|
||||
f'supported visibilities: `public`, `private`, `internal`')
|
||||
repo_url: str = self.create_dataset(
|
||||
dataset_name=repo_name,
|
||||
namespace=namespace,
|
||||
chinese_name=chinese_name,
|
||||
license=license,
|
||||
visibility=visibility,
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
||||
|
||||
return repo_url
|
||||
|
||||
def create_commit(
|
||||
self,
|
||||
repo_id: str,
|
||||
operations: Iterable[CommitOperation],
|
||||
*,
|
||||
commit_message: str,
|
||||
commit_description: Optional[str] = None,
|
||||
token: str = None,
|
||||
repo_type: Optional[str] = None,
|
||||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
||||
) -> CommitInfo:
|
||||
|
||||
url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'
|
||||
commit_message = commit_message or f'Commit to {repo_id}'
|
||||
commit_description = commit_description or ''
|
||||
|
||||
if token:
|
||||
self.login(access_token=token)
|
||||
|
||||
# Construct payload
|
||||
payload = self._prepare_commit_payload(
|
||||
operations=operations,
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
# POST
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.builder_headers(self.headers),
|
||||
data=json.dumps(payload),
|
||||
cookies=cookies
|
||||
)
|
||||
|
||||
resp = response.json()
|
||||
|
||||
if not resp['Success']:
|
||||
commit_message = resp['Message']
|
||||
logger.warning(f'{commit_message}')
|
||||
|
||||
return CommitInfo(
|
||||
commit_url=url,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
oid='',
|
||||
)
|
||||
|
||||
def upload_file(
|
||||
self,
|
||||
*,
|
||||
path_or_fileobj: Union[str, Path, bytes, BinaryIO],
|
||||
path_in_repo: str,
|
||||
repo_id: str,
|
||||
token: Union[str, None] = None,
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
commit_message: Optional[str] = None,
|
||||
commit_description: Optional[str] = None,
|
||||
buffer_size_mb: Optional[int] = 1,
|
||||
tqdm_desc: Optional[str] = '[Uploading]',
|
||||
disable_tqdm: Optional[bool] = False,
|
||||
) -> CommitInfo:
|
||||
|
||||
if repo_type not in REPO_TYPE_SUPPORT:
|
||||
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
||||
|
||||
if not path_or_fileobj:
|
||||
raise ValueError('Path or file object cannot be empty!')
|
||||
|
||||
if isinstance(path_or_fileobj, (str, Path)):
|
||||
path_or_fileobj = os.path.abspath(os.path.expanduser(path_or_fileobj))
|
||||
path_in_repo = path_in_repo or os.path.basename(path_or_fileobj)
|
||||
|
||||
else:
|
||||
# If path_or_fileobj is bytes or BinaryIO, then path_in_repo must be provided
|
||||
if not path_in_repo:
|
||||
raise ValueError('Arg `path_in_repo` cannot be empty!')
|
||||
|
||||
# Read file content if path_or_fileobj is a file-like object (BinaryIO)
|
||||
# TODO: to be refined
|
||||
if isinstance(path_or_fileobj, io.BufferedIOBase):
|
||||
path_or_fileobj = path_or_fileobj.read()
|
||||
|
||||
self.upload_checker.check_file(path_or_fileobj)
|
||||
self.upload_checker.check_normal_files(
|
||||
file_path_list=[path_or_fileobj],
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
if token:
|
||||
self.login(access_token=token)
|
||||
|
||||
commit_message = (
|
||||
commit_message if commit_message is not None else f'Upload {path_in_repo} to ModelScope hub'
|
||||
)
|
||||
|
||||
if buffer_size_mb <= 0:
|
||||
raise ValueError('Buffer size: `buffer_size_mb` must be greater than 0')
|
||||
|
||||
hash_info_d: dict = get_file_hash(
|
||||
file_path_or_obj=path_or_fileobj,
|
||||
buffer_size_mb=buffer_size_mb,
|
||||
)
|
||||
file_size: int = hash_info_d['file_size']
|
||||
file_hash: str = hash_info_d['file_hash']
|
||||
|
||||
upload_res: dict = self._upload_blob(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
sha256=file_hash,
|
||||
size=file_size,
|
||||
data=path_or_fileobj,
|
||||
disable_tqdm=disable_tqdm,
|
||||
tqdm_desc=tqdm_desc,
|
||||
)
|
||||
|
||||
# Construct commit info and create commit
|
||||
add_operation: CommitOperationAdd = CommitOperationAdd(
|
||||
path_in_repo=path_in_repo,
|
||||
path_or_fileobj=path_or_fileobj,
|
||||
)
|
||||
add_operation._upload_mode = 'lfs' if self.upload_checker.is_lfs(path_or_fileobj, repo_type) else 'normal'
|
||||
add_operation._is_uploaded = upload_res['is_uploaded']
|
||||
operations = [add_operation]
|
||||
|
||||
commit_info: CommitInfo = self.create_commit(
|
||||
repo_id=repo_id,
|
||||
operations=operations,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
token=token,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
return commit_info
|
||||
|
||||
def upload_folder(
|
||||
self,
|
||||
*,
|
||||
repo_id: str,
|
||||
folder_path: Union[str, Path],
|
||||
path_in_repo: Optional[str] = '',
|
||||
commit_message: Optional[str] = None,
|
||||
commit_description: Optional[str] = None,
|
||||
token: Union[str, None] = None,
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
max_workers: int = DEFAULT_MAX_WORKERS,
|
||||
) -> CommitInfo:
|
||||
|
||||
if repo_type not in REPO_TYPE_SUPPORT:
|
||||
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
||||
|
||||
allow_patterns = allow_patterns if allow_patterns else None
|
||||
ignore_patterns = ignore_patterns if ignore_patterns else None
|
||||
|
||||
self.upload_checker.check_folder(folder_path)
|
||||
|
||||
# Ignore .git folder
|
||||
if ignore_patterns is None:
|
||||
ignore_patterns = []
|
||||
elif isinstance(ignore_patterns, str):
|
||||
ignore_patterns = [ignore_patterns]
|
||||
ignore_patterns += DEFAULT_IGNORE_PATTERNS
|
||||
|
||||
if token:
|
||||
self.login(access_token=token)
|
||||
|
||||
commit_message = (
|
||||
commit_message if commit_message is not None else f'Upload folder to {repo_id} on ModelScope hub'
|
||||
)
|
||||
commit_description = commit_description or 'Uploading folder'
|
||||
|
||||
# Get the list of files to upload, e.g. [('data/abc.png', '/path/to/abc.png'), ...]
|
||||
prepared_repo_objects = HubApi._prepare_upload_folder(
|
||||
folder_path=folder_path,
|
||||
path_in_repo=path_in_repo,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
)
|
||||
|
||||
self.upload_checker.check_normal_files(
|
||||
file_path_list = [item for _, item in prepared_repo_objects],
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
@thread_executor(max_workers=max_workers, disable_tqdm=False)
|
||||
def _upload_items(item_pair, **kwargs):
|
||||
file_path_in_repo, file_path = item_pair
|
||||
|
||||
hash_info_d: dict = get_file_hash(
|
||||
file_path_or_obj=file_path,
|
||||
)
|
||||
file_size: int = hash_info_d['file_size']
|
||||
file_hash: str = hash_info_d['file_hash']
|
||||
|
||||
upload_res: dict = self._upload_blob(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
sha256=file_hash,
|
||||
size=file_size,
|
||||
data=file_path,
|
||||
disable_tqdm=False if file_size > 10 * 1024 * 1024 else True,
|
||||
)
|
||||
|
||||
return {
|
||||
'file_path_in_repo': file_path_in_repo,
|
||||
'file_path': file_path,
|
||||
'is_uploaded': upload_res['is_uploaded'],
|
||||
}
|
||||
|
||||
uploaded_items_list = _upload_items(
|
||||
prepared_repo_objects,
|
||||
repo_id=repo_id,
|
||||
token=token,
|
||||
repo_type=repo_type,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
buffer_size_mb=1,
|
||||
tqdm_desc='[Uploading]',
|
||||
disable_tqdm=False,
|
||||
)
|
||||
|
||||
logger.info(f'Uploading folder to {repo_id} finished')
|
||||
|
||||
# Construct commit info and create commit
|
||||
operations = []
|
||||
|
||||
for item_d in uploaded_items_list:
|
||||
prepared_path_in_repo: str = item_d['file_path_in_repo']
|
||||
prepared_file_path: str = item_d['file_path']
|
||||
is_uploaded: bool = item_d['is_uploaded']
|
||||
opt = CommitOperationAdd(
|
||||
path_in_repo=prepared_path_in_repo,
|
||||
path_or_fileobj=prepared_file_path,
|
||||
)
|
||||
|
||||
# check normal or lfs
|
||||
opt._upload_mode = 'lfs' if self.upload_checker.is_lfs(prepared_file_path, repo_type) else 'normal'
|
||||
opt._is_uploaded = is_uploaded
|
||||
operations.append(opt)
|
||||
|
||||
self.create_commit(
|
||||
repo_id=repo_id,
|
||||
operations=operations,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
token=token,
|
||||
repo_type=repo_type,
|
||||
)
|
||||
|
||||
# Construct commit info
|
||||
commit_url = f'{self.endpoint}/api/v1/{repo_type}s/{repo_id}/commit/{DEFAULT_REPOSITORY_REVISION}'
|
||||
return CommitInfo(
|
||||
commit_url=commit_url,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
oid='')
|
||||
|
||||
def _upload_blob(
|
||||
self,
|
||||
*,
|
||||
repo_id: str,
|
||||
repo_type: str,
|
||||
sha256: str,
|
||||
size: int,
|
||||
data: Union[str, Path, bytes, BinaryIO],
|
||||
disable_tqdm: Optional[bool] = False,
|
||||
tqdm_desc: Optional[str] = '[Uploading]',
|
||||
buffer_size_mb: Optional[int] = 1,
|
||||
) -> dict:
|
||||
|
||||
res_d: dict = dict(
|
||||
url=None,
|
||||
is_uploaded=False,
|
||||
status_code=None,
|
||||
status_msg=None,
|
||||
)
|
||||
|
||||
objects = [{'oid': sha256, 'size': size}]
|
||||
upload_objects = self._validate_blob(
|
||||
repo_id=repo_id,
|
||||
repo_type=repo_type,
|
||||
objects=objects,
|
||||
)
|
||||
|
||||
# upload_object: {'url': 'xxx', 'oid': 'xxx'}
|
||||
upload_object = upload_objects[0] if len(upload_objects) == 1 else None
|
||||
|
||||
if upload_object is None:
|
||||
logger.info(f'Blob {sha256} has already uploaded, reuse it.')
|
||||
res_d['is_uploaded'] = True
|
||||
return res_d
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
cookies = dict(cookies) if cookies else None
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
|
||||
self.headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
|
||||
headers = self.builder_headers(self.headers)
|
||||
|
||||
def read_in_chunks(file_object, pbar, chunk_size=buffer_size_mb * 1024 * 1024):
|
||||
"""Lazy function (generator) to read a file piece by piece."""
|
||||
while True:
|
||||
ck = file_object.read(chunk_size)
|
||||
if not ck:
|
||||
break
|
||||
pbar.update(len(ck))
|
||||
yield ck
|
||||
|
||||
with tqdm(
|
||||
total=size,
|
||||
unit='B',
|
||||
unit_scale=True,
|
||||
desc=tqdm_desc,
|
||||
disable=disable_tqdm
|
||||
) as pbar:
|
||||
|
||||
if isinstance(data, (str, Path)):
|
||||
with open(data, 'rb') as f:
|
||||
response = requests.put(
|
||||
upload_object['url'],
|
||||
headers=headers,
|
||||
data=read_in_chunks(f, pbar)
|
||||
)
|
||||
|
||||
elif isinstance(data, bytes):
|
||||
response = requests.put(
|
||||
upload_object['url'],
|
||||
headers=headers,
|
||||
data=read_in_chunks(io.BytesIO(data), pbar)
|
||||
)
|
||||
|
||||
elif isinstance(data, io.BufferedIOBase):
|
||||
response = requests.put(
|
||||
upload_object['url'],
|
||||
headers=headers,
|
||||
data=read_in_chunks(data, pbar)
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError('Invalid data type to upload')
|
||||
|
||||
resp = response.json()
|
||||
raise_on_error(resp)
|
||||
|
||||
res_d['url'] = upload_object['url']
|
||||
res_d['status_code'] = resp['Code']
|
||||
res_d['status_msg'] = resp['Message']
|
||||
|
||||
return res_d
|
||||
|
||||
def _validate_blob(
|
||||
self,
|
||||
*,
|
||||
repo_id: str,
|
||||
repo_type: str,
|
||||
objects: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Check the blob has already uploaded.
|
||||
True -- uploaded; False -- not uploaded.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repo id ModelScope.
|
||||
repo_type (str): The repo type. `dataset`, `model`, etc.
|
||||
objects (List[Dict[str, Any]]): The objects to check.
|
||||
oid (str): The sha256 hash value.
|
||||
size (int): The size of the blob.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]: The result of the check.
|
||||
"""
|
||||
|
||||
# construct URL
|
||||
url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'
|
||||
|
||||
# build payload
|
||||
payload = {
|
||||
'operation': 'upload',
|
||||
'objects': objects,
|
||||
}
|
||||
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
if cookies is None:
|
||||
raise ValueError('Token does not exist, please login first.')
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.builder_headers(self.headers),
|
||||
data=json.dumps(payload),
|
||||
cookies=cookies
|
||||
)
|
||||
|
||||
resp = response.json()
|
||||
raise_on_error(resp)
|
||||
|
||||
upload_objects = [] # list of objects to upload, [{'url': 'xxx', 'oid': 'xxx'}, ...]
|
||||
resp_objects = resp['Data']['objects']
|
||||
for obj in resp_objects:
|
||||
upload_objects.append(
|
||||
{'url': obj['actions']['upload']['href'],
|
||||
'oid': obj['oid']}
|
||||
)
|
||||
|
||||
return upload_objects
|
||||
|
||||
@staticmethod
|
||||
def _prepare_upload_folder(
|
||||
folder_path: Union[str, Path],
|
||||
path_in_repo: str,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
) -> List[Union[tuple, list]]:
|
||||
|
||||
folder_path = Path(folder_path).expanduser().resolve()
|
||||
if not folder_path.is_dir():
|
||||
raise ValueError(f"Provided path: '{folder_path}' is not a directory")
|
||||
|
||||
# List files from folder
|
||||
relpath_to_abspath = {
|
||||
path.relative_to(folder_path).as_posix(): path
|
||||
for path in sorted(folder_path.glob('**/*')) # sorted to be deterministic
|
||||
if path.is_file()
|
||||
}
|
||||
|
||||
# Filter files
|
||||
filtered_repo_objects = list(
|
||||
RepoUtils.filter_repo_objects(
|
||||
relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
|
||||
)
|
||||
)
|
||||
|
||||
prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ''
|
||||
|
||||
prepared_repo_objects = [
|
||||
(prefix + relpath, str(relpath_to_abspath[relpath]))
|
||||
for relpath in filtered_repo_objects
|
||||
]
|
||||
|
||||
return prepared_repo_objects
|
||||
|
||||
@staticmethod
|
||||
def _prepare_commit_payload(
|
||||
operations: Iterable[CommitOperation],
|
||||
commit_message: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare the commit payload to be sent to the ModelScope hub.
|
||||
"""
|
||||
|
||||
payload = {
|
||||
'commit_message': commit_message,
|
||||
'actions': []
|
||||
}
|
||||
|
||||
nb_ignored_files = 0
|
||||
|
||||
# 2. Send operations, one per line
|
||||
for operation in operations:
|
||||
|
||||
# Skip ignored files
|
||||
if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
|
||||
logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
|
||||
nb_ignored_files += 1
|
||||
continue
|
||||
|
||||
# 2.a. Case adding a normal file
|
||||
if isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'normal':
|
||||
|
||||
commit_action = {
|
||||
'action': 'update' if operation._is_uploaded else 'create',
|
||||
'path': operation.path_in_repo,
|
||||
'type': 'normal',
|
||||
'size': operation.upload_info.size,
|
||||
'sha256': '',
|
||||
'content': operation.b64content().decode(),
|
||||
'encoding': 'base64',
|
||||
}
|
||||
payload['actions'].append(commit_action)
|
||||
|
||||
# 2.b. Case adding an LFS file
|
||||
elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'lfs':
|
||||
|
||||
commit_action = {
|
||||
'action': 'update' if operation._is_uploaded else 'create',
|
||||
'path': operation.path_in_repo,
|
||||
'type': 'lfs',
|
||||
'size': operation.upload_info.size,
|
||||
'sha256': operation.upload_info.sha256,
|
||||
'content': '',
|
||||
'encoding': '',
|
||||
}
|
||||
payload['actions'].append(commit_action)
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f'Unknown operation to commit. Operation: {operation}. Upload mode:'
|
||||
f" {getattr(operation, '_upload_mode', None)}"
|
||||
)
|
||||
|
||||
if nb_ignored_files > 0:
|
||||
logger.info(f'Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).')
|
||||
|
||||
return payload
|
||||
|
||||
|
||||
class ModelScopeConfig:
|
||||
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
|
||||
@@ -1316,3 +1897,85 @@ class ModelScopeConfig:
|
||||
elif isinstance(user_agent, str):
|
||||
ua += '; ' + user_agent
|
||||
return ua
|
||||
|
||||
|
||||
class UploadingCheck:
|
||||
def __init__(
|
||||
self,
|
||||
max_file_count: int = 100_000,
|
||||
max_file_count_in_dir: int = 10_000,
|
||||
max_file_size: int = 50 * 1024 ** 3,
|
||||
lfs_size_limit: int = 5 * 1024 * 1024,
|
||||
normal_file_size_total_limit: int = 500 * 1024 * 1024,
|
||||
):
|
||||
self.max_file_count = max_file_count
|
||||
self.max_file_count_in_dir = max_file_count_in_dir
|
||||
self.max_file_size = max_file_size
|
||||
self.lfs_size_limit = lfs_size_limit
|
||||
self.normal_file_size_total_limit = normal_file_size_total_limit
|
||||
|
||||
def check_file(self, file_path_or_obj):
|
||||
|
||||
if isinstance(file_path_or_obj, (str, Path)):
|
||||
if not os.path.exists(file_path_or_obj):
|
||||
raise ValueError(f'File {file_path_or_obj} does not exist')
|
||||
|
||||
file_size: int = get_file_size(file_path_or_obj)
|
||||
if file_size > self.max_file_size:
|
||||
raise ValueError(f'File exceeds size limit: {self.max_file_size / (1024 ** 3)} GB')
|
||||
|
||||
def check_folder(self, folder_path: Union[str, Path]):
|
||||
file_count = 0
|
||||
dir_count = 0
|
||||
|
||||
if isinstance(folder_path, str):
|
||||
folder_path = Path(folder_path)
|
||||
|
||||
for item in folder_path.iterdir():
|
||||
if item.is_file():
|
||||
file_count += 1
|
||||
elif item.is_dir():
|
||||
dir_count += 1
|
||||
# Count items in subdirectories recursively
|
||||
sub_file_count, sub_dir_count = self.check_folder(item)
|
||||
if (sub_file_count + sub_dir_count) > self.max_file_count_in_dir:
|
||||
raise ValueError(f'Directory {item} contains {sub_file_count + sub_dir_count} items '
|
||||
f'and exceeds limit: {self.max_file_count_in_dir}')
|
||||
file_count += sub_file_count
|
||||
dir_count += sub_dir_count
|
||||
|
||||
if file_count > self.max_file_count:
|
||||
raise ValueError(f'Total file count {file_count} and exceeds limit: {self.max_file_count}')
|
||||
|
||||
return file_count, dir_count
|
||||
|
||||
def is_lfs(self, file_path_or_obj: Union[str, Path, bytes, BinaryIO], repo_type: str) -> bool:
|
||||
|
||||
hit_lfs_suffix = True
|
||||
|
||||
if isinstance(file_path_or_obj, (str, Path)):
|
||||
file_path_or_obj = Path(file_path_or_obj)
|
||||
if not file_path_or_obj.exists():
|
||||
raise ValueError(f'File {file_path_or_obj} does not exist')
|
||||
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
if file_path_or_obj.suffix not in MODEL_LFS_SUFFIX:
|
||||
hit_lfs_suffix = False
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
if file_path_or_obj.suffix not in DATASET_LFS_SUFFIX:
|
||||
hit_lfs_suffix = False
|
||||
else:
|
||||
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
|
||||
|
||||
file_size: int = get_file_size(file_path_or_obj)
|
||||
|
||||
return file_size > self.lfs_size_limit or hit_lfs_suffix
|
||||
|
||||
def check_normal_files(self, file_path_list: List[Union[str, Path]], repo_type: str) -> None:
|
||||
|
||||
normal_file_list = [item for item in file_path_list if not self.is_lfs(item, repo_type)]
|
||||
total_size = sum([get_file_size(item) for item in normal_file_list])
|
||||
|
||||
if total_size > self.normal_file_size_total_limit:
|
||||
raise ValueError(f'Total size of non-lfs files {total_size/(1024 * 1024)}MB '
|
||||
f'and exceeds limit: {self.normal_file_size_total_limit/(1024 * 1024)}MB')
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -6,7 +6,8 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
|
||||
Sequence, Union)
|
||||
|
||||
import numpy as np
|
||||
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
|
||||
from datasets import (Dataset, DatasetDict, Features, IterableDataset,
|
||||
IterableDatasetDict)
|
||||
from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
|
||||
from datasets.utils.file_utils import is_relative_path
|
||||
|
||||
@@ -163,6 +164,7 @@ class MsDataset:
|
||||
download_mode: Optional[DownloadMode] = DownloadMode.
|
||||
REUSE_DATASET_IF_EXISTS,
|
||||
cache_dir: Optional[str] = MS_DATASETS_CACHE,
|
||||
features: Optional[Features] = None,
|
||||
use_streaming: Optional[bool] = False,
|
||||
stream_batch_size: Optional[int] = 1,
|
||||
custom_cfg: Optional[Config] = Config(),
|
||||
@@ -305,7 +307,7 @@ class MsDataset:
|
||||
data_files=data_files,
|
||||
split=split,
|
||||
cache_dir=cache_dir,
|
||||
features=None,
|
||||
features=features,
|
||||
download_config=None,
|
||||
download_mode=download_mode.value,
|
||||
revision=version,
|
||||
@@ -334,6 +336,9 @@ class MsDataset:
|
||||
return dataset_inst
|
||||
|
||||
elif hub == Hubs.virgo:
|
||||
warnings.warn(
|
||||
'The option `Hubs.virgo` is deprecated, '
|
||||
'will be removed in the future version.', DeprecationWarning)
|
||||
from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
|
||||
from modelscope.utils.constant import VirgoDatasetConfig
|
||||
# Rewrite the namespace, version and cache_dir for virgo dataset.
|
||||
@@ -395,8 +400,10 @@ class MsDataset:
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
'upload is deprecated, please use git command line to upload the dataset.',
|
||||
DeprecationWarning)
|
||||
'The function `upload` is deprecated, '
|
||||
'please use git command '
|
||||
'or modelscope.hub.api.HubApi.upload_folder '
|
||||
'or modelscope.hub.api.HubApi.upload_file.', DeprecationWarning)
|
||||
|
||||
if not object_name:
|
||||
raise ValueError('object_name cannot be empty!')
|
||||
@@ -446,7 +453,7 @@ class MsDataset:
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
'upload is deprecated, please use git command line to upload the dataset.',
|
||||
'The function `clone_meta` is deprecated, please use git command line to clone the repo.',
|
||||
DeprecationWarning)
|
||||
|
||||
_repo = DatasetRepository(
|
||||
@@ -487,6 +494,12 @@ class MsDataset:
|
||||
None
|
||||
|
||||
"""
|
||||
warnings.warn(
|
||||
'The function `upload_meta` is deprecated, '
|
||||
'please use git command '
|
||||
'or CLI `modelscope upload owner_name/repo_name ...`.',
|
||||
DeprecationWarning)
|
||||
|
||||
_repo = DatasetRepository(
|
||||
repo_work_dir=dataset_work_dir,
|
||||
dataset_id='',
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
from pathlib import Path
|
||||
from shutil import Error, copy2, copystat
|
||||
from typing import BinaryIO, Optional, Union
|
||||
|
||||
|
||||
# TODO: remove this api, unify to flattened args
|
||||
@@ -180,3 +182,85 @@ def copytree_py37(src,
|
||||
if errors:
|
||||
raise Error(errors)
|
||||
return dst
|
||||
|
||||
|
||||
def get_file_size(file_path_or_obj: Union[str, Path, bytes, BinaryIO]) -> int:
|
||||
if isinstance(file_path_or_obj, (str, Path)):
|
||||
file_path = Path(file_path_or_obj)
|
||||
return file_path.stat().st_size
|
||||
elif isinstance(file_path_or_obj, bytes):
|
||||
return len(file_path_or_obj)
|
||||
elif isinstance(file_path_or_obj, io.BufferedIOBase):
|
||||
current_position = file_path_or_obj.tell()
|
||||
file_path_or_obj.seek(0, os.SEEK_END)
|
||||
size = file_path_or_obj.tell()
|
||||
file_path_or_obj.seek(current_position)
|
||||
return size
|
||||
else:
|
||||
raise TypeError(
|
||||
'Unsupported type: must be string, Path, bytes, or io.BufferedIOBase'
|
||||
)
|
||||
|
||||
|
||||
def get_file_hash(
|
||||
file_path_or_obj: Union[str, Path, bytes, BinaryIO],
|
||||
buffer_size_mb: Optional[int] = 1,
|
||||
tqdm_desc: Optional[str] = '[Calculating]',
|
||||
disable_tqdm: Optional[bool] = True,
|
||||
) -> dict:
|
||||
from tqdm import tqdm
|
||||
|
||||
file_size = get_file_size(file_path_or_obj)
|
||||
buffer_size = buffer_size_mb * 1024 * 1024
|
||||
file_hash = hashlib.sha256()
|
||||
chunk_hash_list = []
|
||||
|
||||
progress = tqdm(
|
||||
total=file_size,
|
||||
initial=0,
|
||||
unit_scale=True,
|
||||
dynamic_ncols=True,
|
||||
unit='B',
|
||||
desc=tqdm_desc,
|
||||
disable=disable_tqdm,
|
||||
)
|
||||
|
||||
if isinstance(file_path_or_obj, (str, Path)):
|
||||
with open(file_path_or_obj, 'rb') as f:
|
||||
while byte_chunk := f.read(buffer_size):
|
||||
chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
|
||||
file_hash.update(byte_chunk)
|
||||
progress.update(len(byte_chunk))
|
||||
file_hash = file_hash.hexdigest()
|
||||
final_chunk_size = buffer_size
|
||||
|
||||
elif isinstance(file_path_or_obj, bytes):
|
||||
file_hash.update(file_path_or_obj)
|
||||
file_hash = file_hash.hexdigest()
|
||||
chunk_hash_list.append(file_hash)
|
||||
final_chunk_size = len(file_path_or_obj)
|
||||
progress.update(final_chunk_size)
|
||||
|
||||
elif isinstance(file_path_or_obj, io.BufferedIOBase):
|
||||
while byte_chunk := file_path_or_obj.read(buffer_size):
|
||||
chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
|
||||
file_hash.update(byte_chunk)
|
||||
progress.update(len(byte_chunk))
|
||||
file_hash = file_hash.hexdigest()
|
||||
final_chunk_size = buffer_size
|
||||
|
||||
else:
|
||||
progress.close()
|
||||
raise ValueError(
|
||||
'Input must be str, Path, bytes or a io.BufferedIOBase')
|
||||
|
||||
progress.close()
|
||||
|
||||
return {
|
||||
'file_path_or_obj': file_path_or_obj,
|
||||
'file_hash': file_hash,
|
||||
'file_size': file_size,
|
||||
'chunk_size': final_chunk_size,
|
||||
'chunk_nums': len(chunk_hash_list),
|
||||
'chunk_hash_list': chunk_hash_list,
|
||||
}
|
||||
|
||||
479
modelscope/utils/repo_utils.py
Normal file
479
modelscope/utils/repo_utils.py
Normal file
@@ -0,0 +1,479 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Copyright 2022-present, the HuggingFace Inc. team.
|
||||
import base64
|
||||
import functools
|
||||
import hashlib
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from typing import (BinaryIO, Callable, Generator, Iterable, Iterator, List,
|
||||
Literal, Optional, TypeVar, Union)
|
||||
|
||||
from modelscope.utils.file_utils import get_file_hash
|
||||
|
||||
T = TypeVar('T')
|
||||
# Always ignore `.git` and `.cache/modelscope` folders in commits
|
||||
DEFAULT_IGNORE_PATTERNS = [
|
||||
'.git',
|
||||
'.git/*',
|
||||
'*/.git',
|
||||
'**/.git/**',
|
||||
'.cache/modelscope',
|
||||
'.cache/modelscope/*',
|
||||
'*/.cache/modelscope',
|
||||
'**/.cache/modelscope/**',
|
||||
]
|
||||
# Forbidden to commit these folders
|
||||
FORBIDDEN_FOLDERS = ['.git', '.cache']
|
||||
|
||||
UploadMode = Literal['lfs', 'normal']
|
||||
|
||||
DATASET_LFS_SUFFIX = [
|
||||
'.7z',
|
||||
'.aac',
|
||||
'.arrow',
|
||||
'.audio',
|
||||
'.bmp',
|
||||
'.bin',
|
||||
'.bz2',
|
||||
'.flac',
|
||||
'.ftz',
|
||||
'.gif',
|
||||
'.gz',
|
||||
'.h5',
|
||||
'.jack',
|
||||
'.jpeg',
|
||||
'.jpg',
|
||||
'.jsonl',
|
||||
'.joblib',
|
||||
'.lz4',
|
||||
'.msgpack',
|
||||
'.npy',
|
||||
'.npz',
|
||||
'.ot',
|
||||
'.parquet',
|
||||
'.pb',
|
||||
'.pickle',
|
||||
'.pcm',
|
||||
'.pkl',
|
||||
'.raw',
|
||||
'.rar',
|
||||
'.sam',
|
||||
'.tar',
|
||||
'.tgz',
|
||||
'.wasm',
|
||||
'.wav',
|
||||
'.webm',
|
||||
'.webp',
|
||||
'.zip',
|
||||
'.zst',
|
||||
'.tiff',
|
||||
'.mp3',
|
||||
'.mp4',
|
||||
'.ogg',
|
||||
]
|
||||
|
||||
MODEL_LFS_SUFFIX = [
|
||||
'.7z',
|
||||
'.arrow',
|
||||
'.bin',
|
||||
'.bz2',
|
||||
'.ckpt',
|
||||
'.ftz',
|
||||
'.gz',
|
||||
'.h5',
|
||||
'.joblib',
|
||||
'.mlmodel',
|
||||
'.model',
|
||||
'.msgpack',
|
||||
'.npy',
|
||||
'.npz',
|
||||
'.onnx',
|
||||
'.ot',
|
||||
'.parquet',
|
||||
'.pb',
|
||||
'.pickle',
|
||||
'.pkl',
|
||||
'.pt',
|
||||
'.pth',
|
||||
'.rar',
|
||||
'.safetensors',
|
||||
'.tar',
|
||||
'.tflite',
|
||||
'.tgz',
|
||||
'.wasm',
|
||||
'.xz',
|
||||
'.zip',
|
||||
'.zst',
|
||||
]
|
||||
|
||||
|
||||
class RepoUtils:
|
||||
|
||||
@staticmethod
|
||||
def filter_repo_objects(
|
||||
items: Iterable[T],
|
||||
*,
|
||||
allow_patterns: Optional[Union[List[str], str]] = None,
|
||||
ignore_patterns: Optional[Union[List[str], str]] = None,
|
||||
key: Optional[Callable[[T], str]] = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""Filter repo objects based on an allowlist and a denylist.
|
||||
|
||||
Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
|
||||
In the later case, `key` must be provided and specifies a function of one argument
|
||||
that is used to extract a path from each element in iterable.
|
||||
|
||||
Patterns are Unix shell-style wildcards which are NOT regular expressions. See
|
||||
https://docs.python.org/3/library/fnmatch.html for more details.
|
||||
|
||||
Args:
|
||||
items (`Iterable`):
|
||||
List of items to filter.
|
||||
allow_patterns (`str` or `List[str]`, *optional*):
|
||||
Patterns constituting the allowlist. If provided, item paths must match at
|
||||
least one pattern from the allowlist.
|
||||
ignore_patterns (`str` or `List[str]`, *optional*):
|
||||
Patterns constituting the denylist. If provided, item paths must not match
|
||||
any patterns from the denylist.
|
||||
key (`Callable[[T], str]`, *optional*):
|
||||
Single-argument function to extract a path from each item. If not provided,
|
||||
the `items` must already be `str` or `Path`.
|
||||
|
||||
Returns:
|
||||
Filtered list of objects, as a generator.
|
||||
|
||||
Raises:
|
||||
:class:`ValueError`:
|
||||
If `key` is not provided and items are not `str` or `Path`.
|
||||
|
||||
Example usage with paths:
|
||||
```python
|
||||
>>> # Filter only PDFs that are not hidden.
|
||||
>>> list(RepoUtils.filter_repo_objects(
|
||||
... ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
|
||||
... allow_patterns=["*.pdf"],
|
||||
... ignore_patterns=[".*"],
|
||||
... ))
|
||||
["aaa.pdf"]
|
||||
```
|
||||
"""
|
||||
|
||||
allow_patterns = allow_patterns if allow_patterns else None
|
||||
ignore_patterns = ignore_patterns if ignore_patterns else None
|
||||
|
||||
if isinstance(allow_patterns, str):
|
||||
allow_patterns = [allow_patterns]
|
||||
|
||||
if isinstance(ignore_patterns, str):
|
||||
ignore_patterns = [ignore_patterns]
|
||||
|
||||
if allow_patterns is not None:
|
||||
allow_patterns = [
|
||||
RepoUtils._add_wildcard_to_directories(p)
|
||||
for p in allow_patterns
|
||||
]
|
||||
if ignore_patterns is not None:
|
||||
ignore_patterns = [
|
||||
RepoUtils._add_wildcard_to_directories(p)
|
||||
for p in ignore_patterns
|
||||
]
|
||||
|
||||
if key is None:
|
||||
|
||||
def _identity(item: T) -> str:
|
||||
if isinstance(item, str):
|
||||
return item
|
||||
if isinstance(item, Path):
|
||||
return str(item)
|
||||
raise ValueError(
|
||||
f'Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.'
|
||||
)
|
||||
|
||||
key = _identity # Items must be `str` or `Path`, otherwise raise ValueError
|
||||
|
||||
for item in items:
|
||||
path = key(item)
|
||||
|
||||
# Skip if there's an allowlist and path doesn't match any
|
||||
if allow_patterns is not None and not any(
|
||||
fnmatch(path, r) for r in allow_patterns):
|
||||
continue
|
||||
|
||||
# Skip if there's a denylist and path matches any
|
||||
if ignore_patterns is not None and any(
|
||||
fnmatch(path, r) for r in ignore_patterns):
|
||||
continue
|
||||
|
||||
yield item
|
||||
|
||||
@staticmethod
|
||||
def _add_wildcard_to_directories(pattern: str) -> str:
|
||||
if pattern[-1] == '/':
|
||||
return pattern + '*'
|
||||
return pattern
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitInfo(str):
|
||||
"""Data structure containing information about a newly created commit.
|
||||
|
||||
Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`],
|
||||
[`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific
|
||||
to `str` is deprecated.
|
||||
|
||||
Attributes:
|
||||
commit_url (`str`):
|
||||
Url where to find the commit.
|
||||
|
||||
commit_message (`str`):
|
||||
The summary (first line) of the commit that has been created.
|
||||
|
||||
commit_description (`str`):
|
||||
Description of the commit that has been created. Can be empty.
|
||||
|
||||
oid (`str`):
|
||||
Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.
|
||||
|
||||
pr_url (`str`, *optional*):
|
||||
Url to the PR that has been created, if any. Populated when `create_pr=True`
|
||||
is passed.
|
||||
|
||||
pr_revision (`str`, *optional*):
|
||||
Revision of the PR that has been created, if any. Populated when
|
||||
`create_pr=True` is passed. Example: `"refs/pr/1"`.
|
||||
|
||||
pr_num (`int`, *optional*):
|
||||
Number of the PR discussion that has been created, if any. Populated when
|
||||
`create_pr=True` is passed. Can be passed as `discussion_num` in
|
||||
[`get_discussion_details`]. Example: `1`.
|
||||
|
||||
_url (`str`, *optional*):
|
||||
Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by
|
||||
[`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on
|
||||
the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this
|
||||
attribute. Please use `commit_url` instead.
|
||||
"""
|
||||
|
||||
commit_url: str
|
||||
commit_message: str
|
||||
commit_description: str
|
||||
oid: str
|
||||
pr_url: Optional[str] = None
|
||||
|
||||
# Computed from `pr_url` in `__post_init__`
|
||||
pr_revision: Optional[str] = field(init=False)
|
||||
pr_num: Optional[str] = field(init=False)
|
||||
|
||||
# legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.)
|
||||
_url: str = field(
|
||||
repr=False, default=None) # type: ignore # defaults to `commit_url`
|
||||
|
||||
def __new__(cls,
|
||||
*args,
|
||||
commit_url: str,
|
||||
_url: Optional[str] = None,
|
||||
**kwargs):
|
||||
return str.__new__(cls, _url or commit_url)
|
||||
|
||||
def to_dict(cls):
|
||||
return {
|
||||
'commit_url': cls.commit_url,
|
||||
'commit_message': cls.commit_message,
|
||||
'commit_description': cls.commit_description,
|
||||
'oid': cls.oid,
|
||||
'pr_url': cls.pr_url,
|
||||
}
|
||||
|
||||
|
||||
def git_hash(data: bytes) -> str:
|
||||
"""
|
||||
Computes the git-sha1 hash of the given bytes, using the same algorithm as git.
|
||||
|
||||
This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
|
||||
for more details.
|
||||
|
||||
Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
|
||||
pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
|
||||
the LFS file content when we want to compare LFS files.
|
||||
|
||||
Args:
|
||||
data (`bytes`):
|
||||
The data to compute the git-hash for.
|
||||
|
||||
Returns:
|
||||
`str`: the git-hash of `data` as an hexadecimal string.
|
||||
"""
|
||||
_kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
|
||||
sha1 = functools.partial(hashlib.sha1, **_kwargs)
|
||||
sha = sha1()
|
||||
sha.update(b'blob ')
|
||||
sha.update(str(len(data)).encode())
|
||||
sha.update(b'\0')
|
||||
sha.update(data)
|
||||
return sha.hexdigest()
|
||||
|
||||
|
||||
@dataclass
|
||||
class UploadInfo:
|
||||
"""
|
||||
Dataclass holding required information to determine whether a blob
|
||||
should be uploaded to the hub using the LFS protocol or the regular protocol
|
||||
|
||||
Args:
|
||||
sha256 (`str`):
|
||||
SHA256 hash of the blob
|
||||
size (`int`):
|
||||
Size in bytes of the blob
|
||||
sample (`bytes`):
|
||||
First 512 bytes of the blob
|
||||
"""
|
||||
|
||||
sha256: str
|
||||
size: int
|
||||
sample: bytes
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, path: str):
|
||||
|
||||
file_hash_info: dict = get_file_hash(path)
|
||||
size = file_hash_info['file_size']
|
||||
sha = file_hash_info['file_hash']
|
||||
sample = open(path, 'rb').read(512)
|
||||
|
||||
return cls(sha256=sha, size=size, sample=sample)
|
||||
|
||||
@classmethod
|
||||
def from_bytes(cls, data: bytes):
|
||||
sha = get_file_hash(data)['file_hash']
|
||||
return cls(size=len(data), sample=data[:512], sha256=sha)
|
||||
|
||||
@classmethod
|
||||
def from_fileobj(cls, fileobj: BinaryIO):
|
||||
fileobj_info: dict = get_file_hash(fileobj)
|
||||
sample = fileobj.read(512)
|
||||
return cls(
|
||||
sha256=fileobj_info['file_hash'],
|
||||
size=fileobj_info['file_size'],
|
||||
sample=sample)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommitOperationAdd:
|
||||
"""Data structure containing information about a file to be added to a commit."""
|
||||
|
||||
path_in_repo: str
|
||||
path_or_fileobj: Union[str, Path, bytes, BinaryIO]
|
||||
upload_info: UploadInfo = field(init=False, repr=False)
|
||||
|
||||
# Internal attributes
|
||||
|
||||
# set to "lfs" or "regular" once known
|
||||
_upload_mode: Optional[UploadMode] = field(
|
||||
init=False, repr=False, default=None)
|
||||
|
||||
# set to True if .gitignore rules prevent the file from being uploaded as LFS
|
||||
# (server-side check)
|
||||
_should_ignore: Optional[bool] = field(
|
||||
init=False, repr=False, default=None)
|
||||
|
||||
# set to the remote OID of the file if it has already been uploaded
|
||||
# useful to determine if a commit will be empty or not
|
||||
_remote_oid: Optional[str] = field(init=False, repr=False, default=None)
|
||||
|
||||
# set to True once the file has been uploaded as LFS
|
||||
_is_uploaded: bool = field(init=False, repr=False, default=False)
|
||||
|
||||
# set to True once the file has been committed
|
||||
_is_committed: bool = field(init=False, repr=False, default=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validates `path_or_fileobj` and compute `upload_info`."""
|
||||
|
||||
# Validate `path_or_fileobj` value
|
||||
if isinstance(self.path_or_fileobj, Path):
|
||||
self.path_or_fileobj = str(self.path_or_fileobj)
|
||||
if isinstance(self.path_or_fileobj, str):
|
||||
path_or_fileobj = os.path.normpath(
|
||||
os.path.expanduser(self.path_or_fileobj))
|
||||
if not os.path.isfile(path_or_fileobj):
|
||||
raise ValueError(
|
||||
f"Provided path: '{path_or_fileobj}' is not a file on the local file system"
|
||||
)
|
||||
elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
|
||||
raise ValueError(
|
||||
'path_or_fileobj must be either an instance of str, bytes or'
|
||||
' io.BufferedIOBase. If you passed a file-like object, make sure it is'
|
||||
' in binary mode.')
|
||||
if isinstance(self.path_or_fileobj, io.BufferedIOBase):
|
||||
try:
|
||||
self.path_or_fileobj.tell()
|
||||
self.path_or_fileobj.seek(0, os.SEEK_CUR)
|
||||
except (OSError, AttributeError) as exc:
|
||||
raise ValueError(
|
||||
'path_or_fileobj is a file-like object but does not implement seek() and tell()'
|
||||
) from exc
|
||||
|
||||
# Compute "upload_info" attribute
|
||||
if isinstance(self.path_or_fileobj, str):
|
||||
self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
|
||||
elif isinstance(self.path_or_fileobj, bytes):
|
||||
self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
|
||||
else:
|
||||
self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)
|
||||
|
||||
@contextmanager
|
||||
def as_file(self) -> Iterator[BinaryIO]:
|
||||
"""
|
||||
A context manager that yields a file-like object allowing to read the underlying
|
||||
data behind `path_or_fileobj`.
|
||||
"""
|
||||
if isinstance(self.path_or_fileobj, str) or isinstance(
|
||||
self.path_or_fileobj, Path):
|
||||
with open(self.path_or_fileobj, 'rb') as file:
|
||||
yield file
|
||||
elif isinstance(self.path_or_fileobj, bytes):
|
||||
yield io.BytesIO(self.path_or_fileobj)
|
||||
elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
|
||||
prev_pos = self.path_or_fileobj.tell()
|
||||
yield self.path_or_fileobj
|
||||
self.path_or_fileobj.seek(prev_pos, 0)
|
||||
|
||||
def b64content(self) -> bytes:
|
||||
"""
|
||||
The base64-encoded content of `path_or_fileobj`
|
||||
|
||||
Returns: `bytes`
|
||||
"""
|
||||
with self.as_file() as file:
|
||||
return base64.b64encode(file.read())
|
||||
|
||||
@property
|
||||
def _local_oid(self) -> Optional[str]:
|
||||
"""Return the OID of the local file.
|
||||
|
||||
This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
|
||||
If the file did not change, we won't upload it again to prevent empty commits.
|
||||
|
||||
For LFS files, the OID corresponds to the SHA256 of the file content (used a LFS ref).
|
||||
For regular files, the OID corresponds to the SHA1 of the file content.
|
||||
Note: this is slightly different to git OID computation since the oid of an LFS file is usually the git-SHA1
|
||||
of the pointer file content (not the actual file content). However, using the SHA256 is enough to detect
|
||||
changes and more convenient client-side.
|
||||
"""
|
||||
if self._upload_mode is None:
|
||||
return None
|
||||
elif self._upload_mode == 'lfs':
|
||||
return self.upload_info.sha256
|
||||
else:
|
||||
# Regular file => compute sha1
|
||||
# => no need to read by chunk since the file is guaranteed to be <=5MB.
|
||||
with self.as_file() as file:
|
||||
return git_hash(file.read())
|
||||
|
||||
|
||||
CommitOperation = Union[CommitOperationAdd, ]
|
||||
@@ -12,13 +12,15 @@ logger = get_logger()
|
||||
|
||||
|
||||
def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
|
||||
disable_tqdm=False):
|
||||
disable_tqdm: bool = False,
|
||||
tqdm_desc: str = None):
|
||||
"""
|
||||
A decorator to execute a function in a threaded manner using ThreadPoolExecutor.
|
||||
|
||||
Args:
|
||||
max_workers (int): The maximum number of threads to use.
|
||||
disable_tqdm (bool): disable progress bar.
|
||||
tqdm_desc (str): Desc of tqdm.
|
||||
|
||||
Returns:
|
||||
function: A wrapped function that executes with threading and a progress bar.
|
||||
@@ -43,8 +45,11 @@ def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
|
||||
results = []
|
||||
# Create a tqdm progress bar with the total number of items to process
|
||||
with tqdm(
|
||||
unit_scale=True,
|
||||
unit_divisor=1024,
|
||||
initial=0,
|
||||
total=len(iterable),
|
||||
desc=f'Processing {len(iterable)} items',
|
||||
desc=tqdm_desc or f'Processing {len(iterable)} items',
|
||||
disable=disable_tqdm,
|
||||
) as pbar:
|
||||
# Define a wrapper function to update the progress bar
|
||||
|
||||
Reference in New Issue
Block a user