Support upload file and folder in the hub api (#1152)
* update features
* update api
* add upload_file and thread_executor
* update upload file
* update api.py
* add cli for uploading
* run lint
* lint in msdataset
* temp
* add tqdm_desc in thread_executor
* update
* refine upload_file and upload_folder
* add endpoint for cli
* add uploading checker
* add path_or_fileobj and path_in_repo check in upload_file func
* add size limit to lfs: 1MB by default
* update lfs limit size: 10MB
* 5MB lfs limit
* fix test issue
* add pbar for upload_blob; del size_to_chunk_mb; fix allow_patterns and ignore_patterns
* fix commit uploaded blobs
* add update action for folder
* fix issues
* add normal files check
* update
* update
* set normal file size limit to 500MB
* update tqdm
@@ -24,7 +24,7 @@ options:
 Get access token: obtain your **SDK token** from [My Page](https://modelscope.cn/my/myaccesstoken)
 
 
-## download model
+## download
 ```bash
 modelscope download --help
 
@@ -36,6 +36,7 @@ modelscope download --help
 options:
   -h, --help            show this help message and exit
   --model MODEL         The model id to be downloaded.
+  --dataset DATASET     The dataset id to be downloaded.
   --revision REVISION   Revision of the model.
   --cache_dir CACHE_DIR
                         Cache directory to save model.
@@ -11,6 +11,7 @@ from modelscope.cli.modelcard import ModelCardCMD
 from modelscope.cli.pipeline import PipelineCMD
 from modelscope.cli.plugins import PluginsCMD
 from modelscope.cli.server import ServerCMD
+from modelscope.cli.upload import UploadCMD
 from modelscope.hub.api import HubApi
 from modelscope.utils.logger import get_logger
 
@@ -25,6 +26,7 @@ def run_cmd():
     subparsers = parser.add_subparsers(help='modelscope commands helpers')
 
     DownloadCMD.define_args(subparsers)
+    UploadCMD.define_args(subparsers)
     ClearCacheCMD.define_args(subparsers)
     PluginsCMD.define_args(subparsers)
     PipelineCMD.define_args(subparsers)
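For orientation, here is a minimal sketch of the argparse wiring this registration relies on. The top-level parser construction and the dispatch tail of `run_cmd()` are not part of this diff, so those lines are assumptions inferred from the `set_defaults(func=subparser_func)` pattern used in `UploadCMD.define_args` below.

```python
from argparse import ArgumentParser

from modelscope.cli.upload import UploadCMD

# Assumed top-level parser; the real one is built inside run_cmd().
parser = ArgumentParser('modelscope', usage='modelscope <command> [<args>]')
subparsers = parser.add_subparsers(help='modelscope commands helpers')
UploadCMD.define_args(subparsers)

args = parser.parse_args(['upload', 'owner_name/test_repo', './weights'])
cmd = args.func(args)   # subparser_func(args) -> UploadCMD(args)
cmd.execute()           # resolves paths, checks login, then uploads
```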
modelscope/cli/upload.py (new file, 179 lines)
@@ -0,0 +1,179 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from argparse import ArgumentParser, _SubParsersAction

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
from modelscope.utils.logger import get_logger

logger = get_logger()


def subparser_func(args):
    """ Function which will be called for a specific sub parser.
    """
    return UploadCMD(args)


class UploadCMD(CLICommand):

    name = 'upload'

    def __init__(self, args: _SubParsersAction):
        self.args = args

    @staticmethod
    def define_args(parsers: _SubParsersAction):

        parser: ArgumentParser = parsers.add_parser(UploadCMD.name)

        parser.add_argument(
            'repo_id',
            type=str,
            help='The ID of the repo to upload to (e.g. `username/repo-name`)')
        parser.add_argument(
            'local_path',
            type=str,
            nargs='?',
            default=None,
            help='Optional, '
            'Local path to the file or folder to upload. Defaults to current directory.'
        )
        parser.add_argument(
            'path_in_repo',
            type=str,
            nargs='?',
            default=None,
            help='Optional, '
            'Path of the file or folder in the repo. Defaults to the relative path of the file or folder.'
        )
        parser.add_argument(
            '--repo-type',
            choices=REPO_TYPE_SUPPORT,
            default=REPO_TYPE_MODEL,
            help=
            'Type of the repo to upload to (e.g. `dataset`, `model`). Defaults to `model`.',
        )
        parser.add_argument(
            '--include',
            nargs='*',
            type=str,
            help='Glob patterns to match files to upload.')
        parser.add_argument(
            '--exclude',
            nargs='*',
            type=str,
            help='Glob patterns to exclude from files to upload.')
        parser.add_argument(
            '--commit-message',
            type=str,
            default=None,
            help='The commit message. Defaults to `None`.')
        parser.add_argument(
            '--commit-description',
            type=str,
            default=None,
            help=
            'The description of the generated commit. Defaults to `None`.')
        parser.add_argument(
            '--token',
            type=str,
            default=None,
            help=
            'A User Access Token generated from https://modelscope.cn/my/myaccesstoken'
        )
        parser.add_argument(
            '--max-workers',
            type=int,
            default=min(8,
                        os.cpu_count() + 4),
            help='The number of workers to use for uploading files.')
        parser.add_argument(
            '--endpoint',
            type=str,
            default='https://www.modelscope.cn',
            help='Endpoint for the ModelScope service.')

        parser.set_defaults(func=subparser_func)

    def execute(self):

        assert self.args.repo_id, '`repo_id` is required'
        assert self.args.repo_id.count(
            '/') == 1, 'repo_id should be in format of username/repo-name'
        repo_name: str = self.args.repo_id.split('/')[-1]
        self.repo_id = self.args.repo_id

        # Check path_in_repo
        if self.args.local_path is None and os.path.isfile(repo_name):
            # Case 1: modelscope upload owner_name/test_repo
            self.local_path = repo_name
            self.path_in_repo = repo_name
        elif self.args.local_path is None and os.path.isdir(repo_name):
            # Case 2: modelscope upload owner_name/test_repo (run command line in the `repo_name` dir)
            # => upload all files in current directory to remote root path
            self.local_path = repo_name
            self.path_in_repo = '.'
        elif self.args.local_path is None:
            # Case 3: user provided only a repo_id that does not match a local file or folder
            # => the user must explicitly provide a local_path => raise exception
            raise ValueError(
                f"'{repo_name}' is not a local file or folder. Please set `local_path` explicitly."
            )
        elif self.args.path_in_repo is None and os.path.isfile(
                self.args.local_path):
            # Case 4: modelscope upload owner_name/test_repo /path/to/your_file.csv
            # => upload it to remote root path with same name
            self.local_path = self.args.local_path
            self.path_in_repo = os.path.basename(self.args.local_path)
        elif self.args.path_in_repo is None:
            # Case 5: modelscope upload owner_name/test_repo /path/to/your_folder
            # => upload all files in that folder to remote root path
            self.local_path = self.args.local_path
            self.path_in_repo = ''
        else:
            # Finally, if both paths are explicit
            self.local_path = self.args.local_path
            self.path_in_repo = self.args.path_in_repo

        # Check token and login
        # The cookies will be reused if the user has logged in before.
        api = HubApi(endpoint=self.args.endpoint)

        if self.args.token:
            api.login(access_token=self.args.token)
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError(
                'The `token` is not provided! '
                'You can pass the `--token` argument, '
                'or use api.login(access_token=`your_sdk_token`). '
                'Your token is available at https://modelscope.cn/my/myaccesstoken'
            )

        if os.path.isfile(self.local_path):
            commit_info = api.upload_file(
                path_or_fileobj=self.local_path,
                path_in_repo=self.path_in_repo,
                repo_id=self.repo_id,
                repo_type=self.args.repo_type,
                commit_message=self.args.commit_message,
                commit_description=self.args.commit_description,
            )
        elif os.path.isdir(self.local_path):
            commit_info = api.upload_folder(
                repo_id=self.repo_id,
                folder_path=self.local_path,
                path_in_repo=self.path_in_repo,
                commit_message=self.args.commit_message,
                commit_description=self.args.commit_description,
                repo_type=self.args.repo_type,
                allow_patterns=self.args.include,
                ignore_patterns=self.args.exclude,
                max_workers=self.args.max_workers,
            )
        else:
            raise ValueError(f'{self.local_path} is not a valid local path')

        logger.info(f'Upload finished, commit info: {commit_info}')
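As a rough usage sketch (repo id and paths below are placeholders), `modelscope upload owner_name/test_repo ./my_model` falls into Case 5 above and is equivalent to:

```python
from modelscope.hub.api import HubApi

api = HubApi(endpoint='https://www.modelscope.cn')
api.login(access_token='<your_sdk_token>')   # or reuse previously cached cookies

commit_info = api.upload_folder(
    repo_id='owner_name/test_repo',
    folder_path='./my_model',
    path_in_repo='',            # folder contents land at the repo root
    repo_type='model',
)
print(commit_info)
```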
@@ -3,6 +3,7 @@
 
 import datetime
 import functools
+import io
 import os
 import pickle
 import platform
@@ -13,13 +14,15 @@ from collections import defaultdict
 from http import HTTPStatus
 from http.cookiejar import CookieJar
 from os.path import expanduser
-from typing import Dict, List, Optional, Tuple, Union
+from pathlib import Path
+from typing import Any, BinaryIO, Dict, Iterable, List, Optional, Tuple, Union
 from urllib.parse import urlencode
 
 import json
 import requests
 from requests import Session
 from requests.adapters import HTTPAdapter, Retry
+from tqdm import tqdm
 
 from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                       API_HTTP_CLIENT_TIMEOUT,
@@ -29,6 +32,7 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                       API_RESPONSE_FIELD_MESSAGE,
                                       API_RESPONSE_FIELD_USERNAME,
                                       DEFAULT_CREDENTIALS_PATH,
+                                      DEFAULT_MAX_WORKERS,
                                       MODELSCOPE_CLOUD_ENVIRONMENT,
                                       MODELSCOPE_CLOUD_USERNAME,
                                       MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
@@ -36,25 +40,34 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_MAX_RETRIES,
                                       TEMPORARY_FOLDER_NAME, DatasetVisibility,
                                       Licenses, ModelVisibility)
 from modelscope.hub.errors import (InvalidParameter, NotExistError,
-                                   NotLoginException, NoValidRevisionError,
-                                   RequestError, datahub_raise_on_error,
+                                   NotLoginException, RequestError,
+                                   datahub_raise_on_error,
                                    handle_http_post_error,
                                    handle_http_response, is_ok,
                                    raise_for_http_status, raise_on_error)
 from modelscope.hub.git import GitCommandWrapper
 from modelscope.hub.repository import Repository
+from modelscope.hub.utils.utils import (get_endpoint, get_readable_folder_size,
+                                        get_release_datetime,
+                                        model_id_to_group_owner_name)
 from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
                                        DEFAULT_MODEL_REVISION,
                                        DEFAULT_REPOSITORY_REVISION,
                                        MASTER_MODEL_BRANCH, META_FILES_FORMAT,
-                                       REPO_TYPE_MODEL, ConfigFields,
+                                       REPO_TYPE_DATASET, REPO_TYPE_MODEL,
+                                       REPO_TYPE_SUPPORT, ConfigFields,
                                        DatasetFormations, DatasetMetaFormats,
                                        DatasetVisibilityMap, DownloadChannel,
                                        DownloadMode, Frameworks, ModelFile,
                                        Tasks, VirgoDatasetConfig)
+from modelscope.utils.file_utils import get_file_hash, get_file_size
 from modelscope.utils.logger import get_logger
-from .utils.utils import (get_endpoint, get_readable_folder_size,
-                          get_release_datetime, model_id_to_group_owner_name)
+from modelscope.utils.repo_utils import (DATASET_LFS_SUFFIX,
+                                         DEFAULT_IGNORE_PATTERNS,
+                                         MODEL_LFS_SUFFIX, CommitInfo,
+                                         CommitOperation, CommitOperationAdd,
+                                         RepoUtils)
+from modelscope.utils.thread_utils import thread_executor
 
 logger = get_logger()
 
@@ -93,6 +106,8 @@ class HubApi:
                         getattr(self.session, method),
                         timeout=timeout))
 
+        self.upload_checker = UploadingCheck()
+
     def login(
             self,
             access_token: str,
@@ -181,7 +196,7 @@ class HubApi:
             headers=self.builder_headers(self.headers))
         handle_http_post_error(r, path, body)
         raise_on_error(r.json())
-        model_repo_url = f'{get_endpoint()}/{model_id}'
+        model_repo_url = f'{self.endpoint}/{model_id}'
         return model_repo_url
 
     def delete_model(self, model_id: str):
@@ -1173,6 +1188,572 @@ class HubApi:
        return f'{self.endpoint}/api/v1/datasets/{_namespace}/{_dataset_name}/repo?'
        # return f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Revision={revision}&FilePath='

    def create_repo(
            self,
            repo_id: str,
            *,
            token: Union[str, bool, None] = None,
            visibility: Optional[str] = 'public',
            repo_type: Optional[str] = REPO_TYPE_MODEL,
            chinese_name: Optional[str] = '',
            license: Optional[str] = Licenses.APACHE_V2,
    ) -> str:

        # TODO: exist_ok

        if not repo_id:
            raise ValueError('Repo id cannot be empty!')

        if token:
            self.login(access_token=token)
        else:
            logger.warning('No token provided, will use the cached token.')

        repo_id_list = repo_id.split('/')
        if len(repo_id_list) != 2:
            raise ValueError('Invalid repo id, should be in the format of `owner_name/repo_name`')
        namespace, repo_name = repo_id_list

        if repo_type == REPO_TYPE_MODEL:
            visibilities = {k: v for k, v in ModelVisibility.__dict__.items() if not k.startswith('__')}
            visibility: int = visibilities.get(visibility.upper())
            if visibility is None:
                raise ValueError(f'Invalid visibility: {visibility}, '
                                 f'supported visibilities: `public`, `private`, `internal`')
            repo_url: str = self.create_model(
                model_id=repo_id,
                visibility=visibility,
                license=license,
                chinese_name=chinese_name,
            )

        elif repo_type == REPO_TYPE_DATASET:
            visibilities = {k: v for k, v in DatasetVisibility.__dict__.items() if not k.startswith('__')}
            visibility: int = visibilities.get(visibility.upper())
            if visibility is None:
                raise ValueError(f'Invalid visibility: {visibility}, '
                                 f'supported visibilities: `public`, `private`, `internal`')
            repo_url: str = self.create_dataset(
                dataset_name=repo_name,
                namespace=namespace,
                chinese_name=chinese_name,
                license=license,
                visibility=visibility,
            )

        else:
            raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

        return repo_url

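A hedged usage sketch of `create_repo` (repo id and token are placeholders):

```python
from modelscope.hub.api import HubApi

api = HubApi()
repo_url = api.create_repo(
    'owner_name/test_repo',
    token='<your_sdk_token>',
    visibility='public',    # mapped to ModelVisibility / DatasetVisibility internally
    repo_type='model',
)
print(repo_url)             # e.g. <endpoint>/owner_name/test_repo
```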
    def create_commit(
            self,
            repo_id: str,
            operations: Iterable[CommitOperation],
            *,
            commit_message: str,
            commit_description: Optional[str] = None,
            token: str = None,
            repo_type: Optional[str] = None,
            revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
    ) -> CommitInfo:

        url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/commit/{revision}'
        commit_message = commit_message or f'Commit to {repo_id}'
        commit_description = commit_description or ''

        if token:
            self.login(access_token=token)

        # Construct payload
        payload = self._prepare_commit_payload(
            operations=operations,
            commit_message=commit_message,
        )

        # POST
        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        response = requests.post(
            url,
            headers=self.builder_headers(self.headers),
            data=json.dumps(payload),
            cookies=cookies
        )

        resp = response.json()

        if not resp['Success']:
            commit_message = resp['Message']
            logger.warning(f'{commit_message}')

        return CommitInfo(
            commit_url=url,
            commit_message=commit_message,
            commit_description=commit_description,
            oid='',
        )

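A minimal sketch of this lower-level path that `upload_file`/`upload_folder` build on: construct `CommitOperationAdd` items and commit them in one call. The repo id and path are placeholders; the two private attributes are set the same way `upload_file` does below.

```python
from modelscope.hub.api import HubApi
from modelscope.utils.repo_utils import CommitOperationAdd

api = HubApi()
op = CommitOperationAdd(path_in_repo='README.md', path_or_fileobj='./README.md')
op._upload_mode = 'normal'   # small text file: committed inline as base64
op._is_uploaded = False      # its blob has not been uploaded separately

info = api.create_commit(
    'owner_name/test_repo',
    [op],
    commit_message='add README',
    repo_type='model',
)
print(info.commit_url)
```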
    def upload_file(
            self,
            *,
            path_or_fileobj: Union[str, Path, bytes, BinaryIO],
            path_in_repo: str,
            repo_id: str,
            token: Union[str, None] = None,
            repo_type: Optional[str] = REPO_TYPE_MODEL,
            commit_message: Optional[str] = None,
            commit_description: Optional[str] = None,
            buffer_size_mb: Optional[int] = 1,
            tqdm_desc: Optional[str] = '[Uploading]',
            disable_tqdm: Optional[bool] = False,
    ) -> CommitInfo:

        if repo_type not in REPO_TYPE_SUPPORT:
            raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

        if not path_or_fileobj:
            raise ValueError('Path or file object cannot be empty!')

        if isinstance(path_or_fileobj, (str, Path)):
            path_or_fileobj = os.path.abspath(os.path.expanduser(path_or_fileobj))
            path_in_repo = path_in_repo or os.path.basename(path_or_fileobj)

        else:
            # If path_or_fileobj is bytes or BinaryIO, then path_in_repo must be provided
            if not path_in_repo:
                raise ValueError('Arg `path_in_repo` cannot be empty!')

        # Read file content if path_or_fileobj is a file-like object (BinaryIO)
        # TODO: to be refined
        if isinstance(path_or_fileobj, io.BufferedIOBase):
            path_or_fileobj = path_or_fileobj.read()

        self.upload_checker.check_file(path_or_fileobj)
        self.upload_checker.check_normal_files(
            file_path_list=[path_or_fileobj],
            repo_type=repo_type,
        )

        if token:
            self.login(access_token=token)

        commit_message = (
            commit_message if commit_message is not None else f'Upload {path_in_repo} to ModelScope hub'
        )

        if buffer_size_mb <= 0:
            raise ValueError('Buffer size: `buffer_size_mb` must be greater than 0')

        hash_info_d: dict = get_file_hash(
            file_path_or_obj=path_or_fileobj,
            buffer_size_mb=buffer_size_mb,
        )
        file_size: int = hash_info_d['file_size']
        file_hash: str = hash_info_d['file_hash']

        upload_res: dict = self._upload_blob(
            repo_id=repo_id,
            repo_type=repo_type,
            sha256=file_hash,
            size=file_size,
            data=path_or_fileobj,
            disable_tqdm=disable_tqdm,
            tqdm_desc=tqdm_desc,
        )

        # Construct commit info and create commit
        add_operation: CommitOperationAdd = CommitOperationAdd(
            path_in_repo=path_in_repo,
            path_or_fileobj=path_or_fileobj,
        )
        add_operation._upload_mode = 'lfs' if self.upload_checker.is_lfs(path_or_fileobj, repo_type) else 'normal'
        add_operation._is_uploaded = upload_res['is_uploaded']
        operations = [add_operation]

        commit_info: CommitInfo = self.create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            token=token,
            repo_type=repo_type,
        )

        return commit_info

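A hedged sketch of `upload_file` with an in-memory payload; `path_in_repo` is mandatory here because there is no local filename to fall back on (repo id and token are placeholders).

```python
from modelscope.hub.api import HubApi

api = HubApi()
commit_info = api.upload_file(
    path_or_fileobj=b'{"framework": "pytorch"}',   # bytes or BinaryIO also accepted
    path_in_repo='configuration.json',
    repo_id='owner_name/test_repo',
    repo_type='model',
    commit_message='add configuration',
    token='<your_sdk_token>',
)
print(commit_info.commit_message)
```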
    def upload_folder(
            self,
            *,
            repo_id: str,
            folder_path: Union[str, Path],
            path_in_repo: Optional[str] = '',
            commit_message: Optional[str] = None,
            commit_description: Optional[str] = None,
            token: Union[str, None] = None,
            repo_type: Optional[str] = REPO_TYPE_MODEL,
            allow_patterns: Optional[Union[List[str], str]] = None,
            ignore_patterns: Optional[Union[List[str], str]] = None,
            max_workers: int = DEFAULT_MAX_WORKERS,
    ) -> CommitInfo:

        if repo_type not in REPO_TYPE_SUPPORT:
            raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

        allow_patterns = allow_patterns if allow_patterns else None
        ignore_patterns = ignore_patterns if ignore_patterns else None

        self.upload_checker.check_folder(folder_path)

        # Ignore .git folder
        if ignore_patterns is None:
            ignore_patterns = []
        elif isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]
        ignore_patterns += DEFAULT_IGNORE_PATTERNS

        if token:
            self.login(access_token=token)

        commit_message = (
            commit_message if commit_message is not None else f'Upload folder to {repo_id} on ModelScope hub'
        )
        commit_description = commit_description or 'Uploading folder'

        # Get the list of files to upload, e.g. [('data/abc.png', '/path/to/abc.png'), ...]
        prepared_repo_objects = HubApi._prepare_upload_folder(
            folder_path=folder_path,
            path_in_repo=path_in_repo,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
        )

        self.upload_checker.check_normal_files(
            file_path_list=[item for _, item in prepared_repo_objects],
            repo_type=repo_type,
        )

        @thread_executor(max_workers=max_workers, disable_tqdm=False)
        def _upload_items(item_pair, **kwargs):
            file_path_in_repo, file_path = item_pair

            hash_info_d: dict = get_file_hash(
                file_path_or_obj=file_path,
            )
            file_size: int = hash_info_d['file_size']
            file_hash: str = hash_info_d['file_hash']

            upload_res: dict = self._upload_blob(
                repo_id=repo_id,
                repo_type=repo_type,
                sha256=file_hash,
                size=file_size,
                data=file_path,
                disable_tqdm=False if file_size > 10 * 1024 * 1024 else True,
            )

            return {
                'file_path_in_repo': file_path_in_repo,
                'file_path': file_path,
                'is_uploaded': upload_res['is_uploaded'],
            }

        uploaded_items_list = _upload_items(
            prepared_repo_objects,
            repo_id=repo_id,
            token=token,
            repo_type=repo_type,
            commit_message=commit_message,
            commit_description=commit_description,
            buffer_size_mb=1,
            tqdm_desc='[Uploading]',
            disable_tqdm=False,
        )

        logger.info(f'Uploading folder to {repo_id} finished')

        # Construct commit info and create commit
        operations = []

        for item_d in uploaded_items_list:
            prepared_path_in_repo: str = item_d['file_path_in_repo']
            prepared_file_path: str = item_d['file_path']
            is_uploaded: bool = item_d['is_uploaded']
            opt = CommitOperationAdd(
                path_in_repo=prepared_path_in_repo,
                path_or_fileobj=prepared_file_path,
            )

            # check normal or lfs
            opt._upload_mode = 'lfs' if self.upload_checker.is_lfs(prepared_file_path, repo_type) else 'normal'
            opt._is_uploaded = is_uploaded
            operations.append(opt)

        self.create_commit(
            repo_id=repo_id,
            operations=operations,
            commit_message=commit_message,
            commit_description=commit_description,
            token=token,
            repo_type=repo_type,
        )

        # Construct commit info
        commit_url = f'{self.endpoint}/api/v1/{repo_type}s/{repo_id}/commit/{DEFAULT_REPOSITORY_REVISION}'
        return CommitInfo(
            commit_url=commit_url,
            commit_message=commit_message,
            commit_description=commit_description,
            oid='')

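And a hedged sketch of a filtered, parallel folder upload to a dataset repo (repo id, paths and patterns are placeholders):

```python
from modelscope.hub.api import HubApi

api = HubApi()
api.upload_folder(
    repo_id='owner_name/test_dataset',
    folder_path='./data',
    path_in_repo='raw',                    # uploaded under the `raw/` prefix
    repo_type='dataset',
    allow_patterns=['*.jsonl', 'images/'], # only matching files are considered
    ignore_patterns=['*.tmp'],             # .git and .cache/modelscope are always excluded
    max_workers=8,
    token='<your_sdk_token>',
)
```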
    def _upload_blob(
            self,
            *,
            repo_id: str,
            repo_type: str,
            sha256: str,
            size: int,
            data: Union[str, Path, bytes, BinaryIO],
            disable_tqdm: Optional[bool] = False,
            tqdm_desc: Optional[str] = '[Uploading]',
            buffer_size_mb: Optional[int] = 1,
    ) -> dict:

        res_d: dict = dict(
            url=None,
            is_uploaded=False,
            status_code=None,
            status_msg=None,
        )

        objects = [{'oid': sha256, 'size': size}]
        upload_objects = self._validate_blob(
            repo_id=repo_id,
            repo_type=repo_type,
            objects=objects,
        )

        # upload_object: {'url': 'xxx', 'oid': 'xxx'}
        upload_object = upload_objects[0] if len(upload_objects) == 1 else None

        if upload_object is None:
            logger.info(f'Blob {sha256} has already been uploaded, reuse it.')
            res_d['is_uploaded'] = True
            return res_d

        cookies = ModelScopeConfig.get_cookies()
        cookies = dict(cookies) if cookies else None
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')

        self.headers.update({'Cookie': f"m_session_id={cookies['m_session_id']}"})
        headers = self.builder_headers(self.headers)

        def read_in_chunks(file_object, pbar, chunk_size=buffer_size_mb * 1024 * 1024):
            """Lazy function (generator) to read a file piece by piece."""
            while True:
                ck = file_object.read(chunk_size)
                if not ck:
                    break
                pbar.update(len(ck))
                yield ck

        with tqdm(
                total=size,
                unit='B',
                unit_scale=True,
                desc=tqdm_desc,
                disable=disable_tqdm
        ) as pbar:

            if isinstance(data, (str, Path)):
                with open(data, 'rb') as f:
                    response = requests.put(
                        upload_object['url'],
                        headers=headers,
                        data=read_in_chunks(f, pbar)
                    )

            elif isinstance(data, bytes):
                response = requests.put(
                    upload_object['url'],
                    headers=headers,
                    data=read_in_chunks(io.BytesIO(data), pbar)
                )

            elif isinstance(data, io.BufferedIOBase):
                response = requests.put(
                    upload_object['url'],
                    headers=headers,
                    data=read_in_chunks(data, pbar)
                )

            else:
                raise ValueError('Invalid data type to upload')

        resp = response.json()
        raise_on_error(resp)

        res_d['url'] = upload_object['url']
        res_d['status_code'] = resp['Code']
        res_d['status_msg'] = resp['Message']

        return res_d

    def _validate_blob(
            self,
            *,
            repo_id: str,
            repo_type: str,
            objects: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Check whether the blobs have already been uploaded.

        Args:
            repo_id (str): The repo id on ModelScope.
            repo_type (str): The repo type. `dataset`, `model`, etc.
            objects (List[Dict[str, Any]]): The objects to check.
                oid (str): The sha256 hash value.
                size (int): The size of the blob.

        Returns:
            List[Dict[str, Any]]: The objects that still need to be uploaded,
                each with its pre-signed upload URL.
        """

        # construct URL
        url = f'{self.endpoint}/api/v1/repos/{repo_type}s/{repo_id}/info/lfs/objects/batch'

        # build payload
        payload = {
            'operation': 'upload',
            'objects': objects,
        }

        cookies = ModelScopeConfig.get_cookies()
        if cookies is None:
            raise ValueError('Token does not exist, please login first.')
        response = requests.post(
            url,
            headers=self.builder_headers(self.headers),
            data=json.dumps(payload),
            cookies=cookies
        )

        resp = response.json()
        raise_on_error(resp)

        upload_objects = []  # list of objects to upload, [{'url': 'xxx', 'oid': 'xxx'}, ...]
        resp_objects = resp['Data']['objects']
        for obj in resp_objects:
            upload_objects.append(
                {'url': obj['actions']['upload']['href'],
                 'oid': obj['oid']}
            )

        return upload_objects

    @staticmethod
    def _prepare_upload_folder(
            folder_path: Union[str, Path],
            path_in_repo: str,
            allow_patterns: Optional[Union[List[str], str]] = None,
            ignore_patterns: Optional[Union[List[str], str]] = None,
    ) -> List[Union[tuple, list]]:

        folder_path = Path(folder_path).expanduser().resolve()
        if not folder_path.is_dir():
            raise ValueError(f"Provided path: '{folder_path}' is not a directory")

        # List files from folder
        relpath_to_abspath = {
            path.relative_to(folder_path).as_posix(): path
            for path in sorted(folder_path.glob('**/*'))  # sorted to be deterministic
            if path.is_file()
        }

        # Filter files
        filtered_repo_objects = list(
            RepoUtils.filter_repo_objects(
                relpath_to_abspath.keys(), allow_patterns=allow_patterns, ignore_patterns=ignore_patterns
            )
        )

        prefix = f"{path_in_repo.strip('/')}/" if path_in_repo else ''

        prepared_repo_objects = [
            (prefix + relpath, str(relpath_to_abspath[relpath]))
            for relpath in filtered_repo_objects
        ]

        return prepared_repo_objects

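For illustration, assuming a local layout of `./my_model/{README.md, ckpt/model.bin}`, the helper returns `(path_in_repo, local_path)` pairs like this (the absolute paths are placeholders):

```python
from modelscope.hub.api import HubApi

pairs = HubApi._prepare_upload_folder(
    folder_path='./my_model',
    path_in_repo='v1',
    ignore_patterns=['*.log'],
)
# [('v1/README.md', '/abs/path/my_model/README.md'),
#  ('v1/ckpt/model.bin', '/abs/path/my_model/ckpt/model.bin')]
```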
    @staticmethod
    def _prepare_commit_payload(
            operations: Iterable[CommitOperation],
            commit_message: str,
    ) -> Dict[str, Any]:
        """
        Prepare the commit payload to be sent to the ModelScope hub.
        """

        payload = {
            'commit_message': commit_message,
            'actions': []
        }

        nb_ignored_files = 0

        # Send operations, one per line
        for operation in operations:

            # Skip ignored files
            if isinstance(operation, CommitOperationAdd) and operation._should_ignore:
                logger.debug(f"Skipping file '{operation.path_in_repo}' in commit (ignored by gitignore file).")
                nb_ignored_files += 1
                continue

            # Case adding a normal file
            if isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'normal':

                commit_action = {
                    'action': 'update' if operation._is_uploaded else 'create',
                    'path': operation.path_in_repo,
                    'type': 'normal',
                    'size': operation.upload_info.size,
                    'sha256': '',
                    'content': operation.b64content().decode(),
                    'encoding': 'base64',
                }
                payload['actions'].append(commit_action)

            # Case adding an LFS file
            elif isinstance(operation, CommitOperationAdd) and operation._upload_mode == 'lfs':

                commit_action = {
                    'action': 'update' if operation._is_uploaded else 'create',
                    'path': operation.path_in_repo,
                    'type': 'lfs',
                    'size': operation.upload_info.size,
                    'sha256': operation.upload_info.sha256,
                    'content': '',
                    'encoding': '',
                }
                payload['actions'].append(commit_action)

            else:
                raise ValueError(
                    f'Unknown operation to commit. Operation: {operation}. Upload mode:'
                    f" {getattr(operation, '_upload_mode', None)}"
                )

        if nb_ignored_files > 0:
            logger.info(f'Skipped {nb_ignored_files} file(s) in commit (ignored by gitignore file).')

        return payload


class ModelScopeConfig:
    path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
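The resulting request body has this shape (values are illustrative placeholders, not from the commit):

```python
payload = {
    'commit_message': 'Upload folder to owner_name/test_repo on ModelScope hub',
    'actions': [
        {   # small file, committed inline
            'action': 'create', 'path': 'configuration.json', 'type': 'normal',
            'size': 245, 'sha256': '', 'content': '<base64 of the file>', 'encoding': 'base64',
        },
        {   # LFS file, referenced by hash; its blob was PUT separately by _upload_blob
            'action': 'update', 'path': 'weights.safetensors', 'type': 'lfs',
            'size': 2147483648, 'sha256': '<sha256 of the blob>', 'content': '', 'encoding': '',
        },
    ],
}
```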
@@ -1316,3 +1897,85 @@ class ModelScopeConfig:
        elif isinstance(user_agent, str):
            ua += '; ' + user_agent
        return ua


class UploadingCheck:

    def __init__(
            self,
            max_file_count: int = 100_000,
            max_file_count_in_dir: int = 10_000,
            max_file_size: int = 50 * 1024 ** 3,
            lfs_size_limit: int = 5 * 1024 * 1024,
            normal_file_size_total_limit: int = 500 * 1024 * 1024,
    ):
        self.max_file_count = max_file_count
        self.max_file_count_in_dir = max_file_count_in_dir
        self.max_file_size = max_file_size
        self.lfs_size_limit = lfs_size_limit
        self.normal_file_size_total_limit = normal_file_size_total_limit

    def check_file(self, file_path_or_obj):

        if isinstance(file_path_or_obj, (str, Path)):
            if not os.path.exists(file_path_or_obj):
                raise ValueError(f'File {file_path_or_obj} does not exist')

        file_size: int = get_file_size(file_path_or_obj)
        if file_size > self.max_file_size:
            raise ValueError(f'File exceeds size limit: {self.max_file_size / (1024 ** 3)} GB')

    def check_folder(self, folder_path: Union[str, Path]):
        file_count = 0
        dir_count = 0

        if isinstance(folder_path, str):
            folder_path = Path(folder_path)

        for item in folder_path.iterdir():
            if item.is_file():
                file_count += 1
            elif item.is_dir():
                dir_count += 1
                # Count items in subdirectories recursively
                sub_file_count, sub_dir_count = self.check_folder(item)
                if (sub_file_count + sub_dir_count) > self.max_file_count_in_dir:
                    raise ValueError(f'Directory {item} contains {sub_file_count + sub_dir_count} items '
                                     f'and exceeds limit: {self.max_file_count_in_dir}')
                file_count += sub_file_count
                dir_count += sub_dir_count

        if file_count > self.max_file_count:
            raise ValueError(f'Total file count {file_count} exceeds limit: {self.max_file_count}')

        return file_count, dir_count

    def is_lfs(self, file_path_or_obj: Union[str, Path, bytes, BinaryIO], repo_type: str) -> bool:

        hit_lfs_suffix = True

        if isinstance(file_path_or_obj, (str, Path)):
            file_path_or_obj = Path(file_path_or_obj)
            if not file_path_or_obj.exists():
                raise ValueError(f'File {file_path_or_obj} does not exist')

            if repo_type == REPO_TYPE_MODEL:
                if file_path_or_obj.suffix not in MODEL_LFS_SUFFIX:
                    hit_lfs_suffix = False
            elif repo_type == REPO_TYPE_DATASET:
                if file_path_or_obj.suffix not in DATASET_LFS_SUFFIX:
                    hit_lfs_suffix = False
            else:
                raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')

        file_size: int = get_file_size(file_path_or_obj)

        return file_size > self.lfs_size_limit or hit_lfs_suffix

    def check_normal_files(self, file_path_list: List[Union[str, Path]], repo_type: str) -> None:

        normal_file_list = [item for item in file_path_list if not self.is_lfs(item, repo_type)]
        total_size = sum([get_file_size(item) for item in normal_file_list])

        if total_size > self.normal_file_size_total_limit:
            raise ValueError(f'Total size of non-lfs files {total_size/(1024 * 1024)}MB '
                             f'exceeds limit: {self.normal_file_size_total_limit/(1024 * 1024)}MB')
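A small sketch of how the checker classifies files with its default thresholds (the two files are created on the fly so the snippet is runnable; the names are placeholders):

```python
import pathlib

from modelscope.hub.api import UploadingCheck

pathlib.Path('weights.safetensors').write_bytes(b'\0' * 1024)
pathlib.Path('configuration.json').write_text('{"framework": "pytorch"}')

checker = UploadingCheck()
# .safetensors is in MODEL_LFS_SUFFIX, so it is LFS even though it is tiny;
# any file larger than lfs_size_limit (5 MB) is LFS regardless of suffix.
print(checker.is_lfs('weights.safetensors', repo_type='model'))   # True
print(checker.is_lfs('configuration.json', repo_type='model'))    # False
# Non-LFS files are committed inline, so their combined size is capped at 500 MB.
checker.check_normal_files(['configuration.json'], repo_type='model')
```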
@@ -1,5 +1,4 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 
 import os
 from pathlib import Path
 
@@ -6,7 +6,8 @@ from typing import (Any, Callable, Dict, Iterable, List, Mapping, Optional,
                     Sequence, Union)
 
 import numpy as np
-from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
+from datasets import (Dataset, DatasetDict, Features, IterableDataset,
+                      IterableDatasetDict)
 from datasets.packaged_modules import _PACKAGED_DATASETS_MODULES
 from datasets.utils.file_utils import is_relative_path
 
@@ -163,6 +164,7 @@ class MsDataset:
             download_mode: Optional[DownloadMode] = DownloadMode.
             REUSE_DATASET_IF_EXISTS,
             cache_dir: Optional[str] = MS_DATASETS_CACHE,
+            features: Optional[Features] = None,
             use_streaming: Optional[bool] = False,
             stream_batch_size: Optional[int] = 1,
             custom_cfg: Optional[Config] = Config(),
@@ -305,7 +307,7 @@ class MsDataset:
                 data_files=data_files,
                 split=split,
                 cache_dir=cache_dir,
-                features=None,
+                features=features,
                 download_config=None,
                 download_mode=download_mode.value,
                 revision=version,
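A hedged sketch of the new `features` pass-through (dataset id and columns are placeholders; the argument is handed to the underlying `datasets` loader unchanged):

```python
from datasets import Features, Value
from modelscope.msdatasets import MsDataset

ds = MsDataset.load(
    'owner_name/some_dataset',
    split='train',
    features=Features({'text': Value('string'), 'label': Value('int64')}),
)
```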
@@ -334,6 +336,9 @@ class MsDataset:
             return dataset_inst
 
         elif hub == Hubs.virgo:
+            warnings.warn(
+                'The option `Hubs.virgo` is deprecated, '
+                'will be removed in the future version.', DeprecationWarning)
             from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
             from modelscope.utils.constant import VirgoDatasetConfig
             # Rewrite the namespace, version and cache_dir for virgo dataset.
@@ -395,8 +400,10 @@ class MsDataset:
 
         """
         warnings.warn(
-            'upload is deprecated, please use git command line to upload the dataset.',
-            DeprecationWarning)
+            'The function `upload` is deprecated, '
+            'please use git command '
+            'or modelscope.hub.api.HubApi.upload_folder '
+            'or modelscope.hub.api.HubApi.upload_file.', DeprecationWarning)
 
         if not object_name:
             raise ValueError('object_name cannot be empty!')
@@ -446,7 +453,7 @@ class MsDataset:
         """
 
         warnings.warn(
-            'upload is deprecated, please use git command line to upload the dataset.',
+            'The function `clone_meta` is deprecated, please use git command line to clone the repo.',
             DeprecationWarning)
 
         _repo = DatasetRepository(
@@ -487,6 +494,12 @@ class MsDataset:
             None
 
         """
+        warnings.warn(
+            'The function `upload_meta` is deprecated, '
+            'please use git command '
+            'or CLI `modelscope upload owner_name/repo_name ...`.',
+            DeprecationWarning)
+
         _repo = DatasetRepository(
             repo_work_dir=dataset_work_dir,
             dataset_id='',
@@ -1,9 +1,11 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import hashlib
 import inspect
+import io
 import os
 from pathlib import Path
 from shutil import Error, copy2, copystat
+from typing import BinaryIO, Optional, Union
 
 
 # TODO: remove this api, unify to flattened args
@@ -180,3 +182,85 @@ def copytree_py37(src,
    if errors:
        raise Error(errors)
    return dst


def get_file_size(file_path_or_obj: Union[str, Path, bytes, BinaryIO]) -> int:
    if isinstance(file_path_or_obj, (str, Path)):
        file_path = Path(file_path_or_obj)
        return file_path.stat().st_size
    elif isinstance(file_path_or_obj, bytes):
        return len(file_path_or_obj)
    elif isinstance(file_path_or_obj, io.BufferedIOBase):
        current_position = file_path_or_obj.tell()
        file_path_or_obj.seek(0, os.SEEK_END)
        size = file_path_or_obj.tell()
        file_path_or_obj.seek(current_position)
        return size
    else:
        raise TypeError(
            'Unsupported type: must be string, Path, bytes, or io.BufferedIOBase'
        )


def get_file_hash(
        file_path_or_obj: Union[str, Path, bytes, BinaryIO],
        buffer_size_mb: Optional[int] = 1,
        tqdm_desc: Optional[str] = '[Calculating]',
        disable_tqdm: Optional[bool] = True,
) -> dict:
    from tqdm import tqdm

    file_size = get_file_size(file_path_or_obj)
    buffer_size = buffer_size_mb * 1024 * 1024
    file_hash = hashlib.sha256()
    chunk_hash_list = []

    progress = tqdm(
        total=file_size,
        initial=0,
        unit_scale=True,
        dynamic_ncols=True,
        unit='B',
        desc=tqdm_desc,
        disable=disable_tqdm,
    )

    if isinstance(file_path_or_obj, (str, Path)):
        with open(file_path_or_obj, 'rb') as f:
            while byte_chunk := f.read(buffer_size):
                chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
                file_hash.update(byte_chunk)
                progress.update(len(byte_chunk))
        file_hash = file_hash.hexdigest()
        final_chunk_size = buffer_size

    elif isinstance(file_path_or_obj, bytes):
        file_hash.update(file_path_or_obj)
        file_hash = file_hash.hexdigest()
        chunk_hash_list.append(file_hash)
        final_chunk_size = len(file_path_or_obj)
        progress.update(final_chunk_size)

    elif isinstance(file_path_or_obj, io.BufferedIOBase):
        while byte_chunk := file_path_or_obj.read(buffer_size):
            chunk_hash_list.append(hashlib.sha256(byte_chunk).hexdigest())
            file_hash.update(byte_chunk)
            progress.update(len(byte_chunk))
        file_hash = file_hash.hexdigest()
        final_chunk_size = buffer_size

    else:
        progress.close()
        raise ValueError(
            'Input must be str, Path, bytes or an io.BufferedIOBase')

    progress.close()

    return {
        'file_path_or_obj': file_path_or_obj,
        'file_hash': file_hash,
        'file_size': file_size,
        'chunk_size': final_chunk_size,
        'chunk_nums': len(chunk_hash_list),
        'chunk_hash_list': chunk_hash_list,
    }
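A quick check of the returned dict on an in-memory payload (the commented values are what the code above yields for 5 MiB of zeros with the default 1 MB buffer):

```python
from modelscope.utils.file_utils import get_file_hash

info = get_file_hash(b'\0' * (5 * 1024 * 1024))
print(info['file_size'])    # 5242880
print(info['chunk_nums'])   # 1 -- bytes input is hashed as a single chunk
print(info['file_hash'])    # hex sha256 digest of the whole payload
```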
modelscope/utils/repo_utils.py (new file, 479 lines)
@@ -0,0 +1,479 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2022-present, the HuggingFace Inc. team.
import base64
import functools
import hashlib
import io
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from fnmatch import fnmatch
from pathlib import Path
from typing import (BinaryIO, Callable, Generator, Iterable, Iterator, List,
                    Literal, Optional, TypeVar, Union)

from modelscope.utils.file_utils import get_file_hash

T = TypeVar('T')
# Always ignore `.git` and `.cache/modelscope` folders in commits
DEFAULT_IGNORE_PATTERNS = [
    '.git',
    '.git/*',
    '*/.git',
    '**/.git/**',
    '.cache/modelscope',
    '.cache/modelscope/*',
    '*/.cache/modelscope',
    '**/.cache/modelscope/**',
]
# Forbidden to commit these folders
FORBIDDEN_FOLDERS = ['.git', '.cache']

UploadMode = Literal['lfs', 'normal']

DATASET_LFS_SUFFIX = [
    '.7z', '.aac', '.arrow', '.audio', '.bmp', '.bin', '.bz2', '.flac',
    '.ftz', '.gif', '.gz', '.h5', '.jack', '.jpeg', '.jpg', '.jsonl',
    '.joblib', '.lz4', '.msgpack', '.npy', '.npz', '.ot', '.parquet', '.pb',
    '.pickle', '.pcm', '.pkl', '.raw', '.rar', '.sam', '.tar', '.tgz',
    '.wasm', '.wav', '.webm', '.webp', '.zip', '.zst', '.tiff', '.mp3',
    '.mp4', '.ogg',
]

MODEL_LFS_SUFFIX = [
    '.7z', '.arrow', '.bin', '.bz2', '.ckpt', '.ftz', '.gz', '.h5',
    '.joblib', '.mlmodel', '.model', '.msgpack', '.npy', '.npz', '.onnx',
    '.ot', '.parquet', '.pb', '.pickle', '.pkl', '.pt', '.pth', '.rar',
    '.safetensors', '.tar', '.tflite', '.tgz', '.wasm', '.xz', '.zip',
    '.zst',
]

class RepoUtils:

    @staticmethod
    def filter_repo_objects(
            items: Iterable[T],
            *,
            allow_patterns: Optional[Union[List[str], str]] = None,
            ignore_patterns: Optional[Union[List[str], str]] = None,
            key: Optional[Callable[[T], str]] = None,
    ) -> Generator[T, None, None]:
        """Filter repo objects based on an allowlist and a denylist.

        Input must be a list of paths (`str` or `Path`) or a list of arbitrary objects.
        In the latter case, `key` must be provided and specifies a function of one argument
        that is used to extract a path from each element in iterable.

        Patterns are Unix shell-style wildcards which are NOT regular expressions. See
        https://docs.python.org/3/library/fnmatch.html for more details.

        Args:
            items (`Iterable`):
                List of items to filter.
            allow_patterns (`str` or `List[str]`, *optional*):
                Patterns constituting the allowlist. If provided, item paths must match at
                least one pattern from the allowlist.
            ignore_patterns (`str` or `List[str]`, *optional*):
                Patterns constituting the denylist. If provided, item paths must not match
                any patterns from the denylist.
            key (`Callable[[T], str]`, *optional*):
                Single-argument function to extract a path from each item. If not provided,
                the `items` must already be `str` or `Path`.

        Returns:
            Filtered list of objects, as a generator.

        Raises:
            :class:`ValueError`:
                If `key` is not provided and items are not `str` or `Path`.

        Example usage with paths:
        ```python
        >>> # Filter only PDFs that are not hidden.
        >>> list(RepoUtils.filter_repo_objects(
        ...     ["aaa.PDF", "bbb.jpg", ".ccc.pdf", ".ddd.png"],
        ...     allow_patterns=["*.pdf"],
        ...     ignore_patterns=[".*"],
        ... ))
        ["aaa.pdf"]
        ```
        """

        allow_patterns = allow_patterns if allow_patterns else None
        ignore_patterns = ignore_patterns if ignore_patterns else None

        if isinstance(allow_patterns, str):
            allow_patterns = [allow_patterns]

        if isinstance(ignore_patterns, str):
            ignore_patterns = [ignore_patterns]

        if allow_patterns is not None:
            allow_patterns = [
                RepoUtils._add_wildcard_to_directories(p)
                for p in allow_patterns
            ]
        if ignore_patterns is not None:
            ignore_patterns = [
                RepoUtils._add_wildcard_to_directories(p)
                for p in ignore_patterns
            ]

        if key is None:

            def _identity(item: T) -> str:
                if isinstance(item, str):
                    return item
                if isinstance(item, Path):
                    return str(item)
                raise ValueError(
                    f'Please provide `key` argument in `filter_repo_objects`: `{item}` is not a string.'
                )

            key = _identity  # Items must be `str` or `Path`, otherwise raise ValueError

        for item in items:
            path = key(item)

            # Skip if there's an allowlist and path doesn't match any
            if allow_patterns is not None and not any(
                    fnmatch(path, r) for r in allow_patterns):
                continue

            # Skip if there's a denylist and path matches any
            if ignore_patterns is not None and any(
                    fnmatch(path, r) for r in ignore_patterns):
                continue

            yield item

    @staticmethod
    def _add_wildcard_to_directories(pattern: str) -> str:
        if pattern[-1] == '/':
            return pattern + '*'
        return pattern

@dataclass
class CommitInfo(str):
    """Data structure containing information about a newly created commit.

    Returned by any method that creates a commit on the Hub: [`create_commit`], [`upload_file`], [`upload_folder`],
    [`delete_file`], [`delete_folder`]. It inherits from `str` for backward compatibility but using methods specific
    to `str` is deprecated.

    Attributes:
        commit_url (`str`):
            Url where to find the commit.

        commit_message (`str`):
            The summary (first line) of the commit that has been created.

        commit_description (`str`):
            Description of the commit that has been created. Can be empty.

        oid (`str`):
            Commit hash id. Example: `"91c54ad1727ee830252e457677f467be0bfd8a57"`.

        pr_url (`str`, *optional*):
            Url to the PR that has been created, if any. Populated when `create_pr=True`
            is passed.

        pr_revision (`str`, *optional*):
            Revision of the PR that has been created, if any. Populated when
            `create_pr=True` is passed. Example: `"refs/pr/1"`.

        pr_num (`int`, *optional*):
            Number of the PR discussion that has been created, if any. Populated when
            `create_pr=True` is passed. Can be passed as `discussion_num` in
            [`get_discussion_details`]. Example: `1`.

        _url (`str`, *optional*):
            Legacy url for `str` compatibility. Can be the url to the uploaded file on the Hub (if returned by
            [`upload_file`]), to the uploaded folder on the Hub (if returned by [`upload_folder`]) or to the commit on
            the Hub (if returned by [`create_commit`]). Defaults to `commit_url`. It is deprecated to use this
            attribute. Please use `commit_url` instead.
    """

    commit_url: str
    commit_message: str
    commit_description: str
    oid: str
    pr_url: Optional[str] = None

    # Computed from `pr_url` in `__post_init__`
    pr_revision: Optional[str] = field(init=False)
    pr_num: Optional[str] = field(init=False)

    # legacy url for `str` compatibility (ex: url to uploaded file, url to uploaded folder, url to PR, etc.)
    _url: str = field(
        repr=False, default=None)  # type: ignore # defaults to `commit_url`

    def __new__(cls,
                *args,
                commit_url: str,
                _url: Optional[str] = None,
                **kwargs):
        return str.__new__(cls, _url or commit_url)

    def to_dict(cls):
        return {
            'commit_url': cls.commit_url,
            'commit_message': cls.commit_message,
            'commit_description': cls.commit_description,
            'oid': cls.oid,
            'pr_url': cls.pr_url,
        }

|
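For context, a rough sketch of how the `CommitInfo` returned by an upload might be consumed. The `upload_folder` argument names, its return type, and the placeholder token/repo id are assumptions based on this PR's description, not a documented signature.

```python
# Sketch only: `upload_folder` argument names are assumptions; the token and
# repo id below are placeholders.
from modelscope.hub.api import HubApi

api = HubApi()
api.login('<your-sdk-token>')

info = api.upload_folder(
    repo_id='your-namespace/your-model',
    folder_path='./my_model',
    commit_message='Upload model weights',
)

# CommitInfo subclasses `str` for backward compatibility, but the structured
# attributes are the supported way to read it:
print(info.commit_url)
print(info.oid)
print(info.to_dict())
```
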
def git_hash(data: bytes) -> str:
    """
    Computes the git-sha1 hash of the given bytes, using the same algorithm as git.

    This is equivalent to running `git hash-object`. See https://git-scm.com/docs/git-hash-object
    for more details.

    Note: this method is valid for regular files. For LFS files, the proper git hash is supposed to be computed on the
    pointer file content, not the actual file content. However, for simplicity, we directly compare the sha256 of
    the LFS file content when we want to compare LFS files.

    Args:
        data (`bytes`):
            The data to compute the git-hash for.

    Returns:
        `str`: the git-hash of `data` as a hexadecimal string.
    """
    _kwargs = {'usedforsecurity': False} if sys.version_info >= (3, 9) else {}
    sha1 = functools.partial(hashlib.sha1, **_kwargs)
    sha = sha1()
    sha.update(b'blob ')
    sha.update(str(len(data)).encode())
    sha.update(b'\0')
    sha.update(data)
    return sha.hexdigest()

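A one-line sanity check, assuming the `git_hash` function above is in scope: it should reproduce git's well-known hash of the empty blob.

```python
# Equivalent shell check:  printf '' | git hash-object --stdin
assert git_hash(b'') == 'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391'
```
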
@dataclass
class UploadInfo:
    """
    Dataclass holding the required information to determine whether a blob
    should be uploaded to the hub using the LFS protocol or the regular protocol.

    Args:
        sha256 (`str`):
            SHA256 hash of the blob
        size (`int`):
            Size in bytes of the blob
        sample (`bytes`):
            First 512 bytes of the blob
    """

    sha256: str
    size: int
    sample: bytes

    @classmethod
    def from_path(cls, path: str):
        file_hash_info: dict = get_file_hash(path)
        size = file_hash_info['file_size']
        sha = file_hash_info['file_hash']
        # Read the first 512 bytes as the sample and close the handle promptly.
        with open(path, 'rb') as f:
            sample = f.read(512)
        return cls(sha256=sha, size=size, sample=sample)

    @classmethod
    def from_bytes(cls, data: bytes):
        sha = get_file_hash(data)['file_hash']
        return cls(size=len(data), sample=data[:512], sha256=sha)

    @classmethod
    def from_fileobj(cls, fileobj: BinaryIO):
        fileobj_info: dict = get_file_hash(fileobj)
        sample = fileobj.read(512)
        return cls(
            sha256=fileobj_info['file_hash'],
            size=fileobj_info['file_size'],
            sample=sample)

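To make the LFS decision concrete, a small sketch using `UploadInfo.from_bytes`. The 5 MB cut-off mirrors the limit mentioned in this PR's history and is written out here as an assumption rather than read from a library constant.

```python
# Decide how a payload would be uploaded; the 5 MB threshold is assumed, not imported.
LFS_LIMIT_BYTES = 5 * 1024 * 1024

payload = b'\0' * (6 * 1024 * 1024)
info = UploadInfo.from_bytes(payload)

upload_mode = 'lfs' if info.size > LFS_LIMIT_BYTES else 'regular'
print(info.size, upload_mode)  # 6291456 lfs
```
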
@dataclass
class CommitOperationAdd:
    """Data structure containing information about a file to be added to a commit."""

    path_in_repo: str
    path_or_fileobj: Union[str, Path, bytes, BinaryIO]
    upload_info: UploadInfo = field(init=False, repr=False)

    # Internal attributes

    # set to "lfs" or "regular" once known
    _upload_mode: Optional[UploadMode] = field(
        init=False, repr=False, default=None)

    # set to True if .gitignore rules prevent the file from being uploaded as LFS
    # (server-side check)
    _should_ignore: Optional[bool] = field(
        init=False, repr=False, default=None)

    # set to the remote OID of the file if it has already been uploaded
    # useful to determine if a commit will be empty or not
    _remote_oid: Optional[str] = field(init=False, repr=False, default=None)

    # set to True once the file has been uploaded as LFS
    _is_uploaded: bool = field(init=False, repr=False, default=False)

    # set to True once the file has been committed
    _is_committed: bool = field(init=False, repr=False, default=False)

    def __post_init__(self) -> None:
        """Validates `path_or_fileobj` and computes `upload_info`."""

        # Validate `path_or_fileobj` value
        if isinstance(self.path_or_fileobj, Path):
            self.path_or_fileobj = str(self.path_or_fileobj)
        if isinstance(self.path_or_fileobj, str):
            path_or_fileobj = os.path.normpath(
                os.path.expanduser(self.path_or_fileobj))
            if not os.path.isfile(path_or_fileobj):
                raise ValueError(
                    f"Provided path: '{path_or_fileobj}' is not a file on the local file system"
                )
        elif not isinstance(self.path_or_fileobj, (io.BufferedIOBase, bytes)):
            raise ValueError(
                'path_or_fileobj must be either an instance of str, bytes or'
                ' io.BufferedIOBase. If you passed a file-like object, make sure it is'
                ' in binary mode.')
        if isinstance(self.path_or_fileobj, io.BufferedIOBase):
            try:
                self.path_or_fileobj.tell()
                self.path_or_fileobj.seek(0, os.SEEK_CUR)
            except (OSError, AttributeError) as exc:
                raise ValueError(
                    'path_or_fileobj is a file-like object but does not implement seek() and tell()'
                ) from exc

        # Compute "upload_info" attribute
        if isinstance(self.path_or_fileobj, str):
            self.upload_info = UploadInfo.from_path(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, bytes):
            self.upload_info = UploadInfo.from_bytes(self.path_or_fileobj)
        else:
            self.upload_info = UploadInfo.from_fileobj(self.path_or_fileobj)

    @contextmanager
    def as_file(self) -> Iterator[BinaryIO]:
        """
        A context manager that yields a file-like object for reading the underlying
        data behind `path_or_fileobj`.
        """
        if isinstance(self.path_or_fileobj, str) or isinstance(
                self.path_or_fileobj, Path):
            with open(self.path_or_fileobj, 'rb') as file:
                yield file
        elif isinstance(self.path_or_fileobj, bytes):
            yield io.BytesIO(self.path_or_fileobj)
        elif isinstance(self.path_or_fileobj, io.BufferedIOBase):
            prev_pos = self.path_or_fileobj.tell()
            yield self.path_or_fileobj
            # Restore the original position once the caller is done reading
            self.path_or_fileobj.seek(prev_pos, 0)

    def b64content(self) -> bytes:
        """
        The base64-encoded content of `path_or_fileobj`.

        Returns: `bytes`
        """
        with self.as_file() as file:
            return base64.b64encode(file.read())

    @property
    def _local_oid(self) -> Optional[str]:
        """Return the OID of the local file.

        This OID is then compared to `self._remote_oid` to check if the file has changed compared to the remote one.
        If the file did not change, we won't upload it again to prevent empty commits.

        For LFS files, the OID corresponds to the SHA256 of the file content (used as the LFS ref).
        For regular files, the OID corresponds to the SHA1 of the file content.
        Note: this is slightly different from git OID computation since the oid of an LFS file is usually the git-SHA1
        of the pointer file content (not the actual file content). However, using the SHA256 is enough to detect
        changes and more convenient client-side.
        """
        if self._upload_mode is None:
            return None
        elif self._upload_mode == 'lfs':
            return self.upload_info.sha256
        else:
            # Regular file => compute sha1
            # => no need to read by chunk since the file is guaranteed to be <=5MB.
            with self.as_file() as file:
                return git_hash(file.read())


CommitOperation = Union[CommitOperationAdd, ]

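A small sketch of staging a file from memory with `CommitOperationAdd`; the class is assumed to be importable alongside the code above, and the printed values are illustrative.

```python
# Stage an in-memory README for a commit; the path and bytes are illustrative.
op = CommitOperationAdd(
    path_in_repo='README.md',
    path_or_fileobj=b'# My model\n',
)

print(op.upload_info.size)       # 11 (bytes)
print(op.upload_info.sha256[:12])
print(op.b64content())           # b'IyBNeSBtb2RlbAo=' -- payload used for regular (non-LFS) uploads
print(op._local_oid)             # None until the upload mode ("lfs"/"regular") is known
```
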
@@ -12,13 +12,15 @@ logger = get_logger()
 
 
-def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
-                    disable_tqdm=False):
+def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
+                    disable_tqdm: bool = False,
+                    tqdm_desc: str = None):
     """
     A decorator to execute a function in a threaded manner using ThreadPoolExecutor.
 
     Args:
         max_workers (int): The maximum number of threads to use.
         disable_tqdm (bool): disable progress bar.
+        tqdm_desc (str): Desc of tqdm.
 
     Returns:
         function: A wrapped function that executes with threading and a progress bar.
@@ -43,8 +45,11 @@ def thread_executor(max_workers: int = DEFAULT_MAX_WORKERS,
         results = []
         # Create a tqdm progress bar with the total number of items to process
         with tqdm(
+                unit_scale=True,
+                unit_divisor=1024,
+                initial=0,
                 total=len(iterable),
-                desc=f'Processing {len(iterable)} items',
+                desc=tqdm_desc or f'Processing {len(iterable)} items',
                 disable=disable_tqdm,
         ) as pbar:
             # Define a wrapper function to update the progress bar
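The new `tqdm_desc` parameter lets callers label the progress bar. A hedged usage sketch follows; the calling convention (the wrapped function receives the whole iterable as its first argument and is invoked once per item across threads) is inferred from the hunk above, and the decorated helper here is hypothetical.

```python
# Hypothetical helper; only `max_workers`, `disable_tqdm` and `tqdm_desc` come from
# the decorator shown above, and the per-item calling convention is inferred.
@thread_executor(max_workers=4, tqdm_desc='Hashing files')
def _hash_one(path: str) -> str:
    with open(path, 'rb') as f:
        return git_hash(f.read())  # `git_hash` as defined earlier in this commit

hashes = _hash_one(['a.bin', 'b.bin', 'c.bin'])  # one result per input item (inferred)
```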