# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import importlib
import inspect
import os
import re
import sys
from asyncio import Future
from functools import partial
from pathlib import Path
from types import MethodType
from typing import BinaryIO, Dict, Iterable, List, Optional, Union

from modelscope.hub.constants import DEFAULT_MODELSCOPE_DATA_ENDPOINT
from modelscope.utils.repo_utils import (CommitInfo, CommitOperation,
                                         CommitOperationAdd)

ignore_file_pattern = [
    r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5',
    r'\w+\.ckpt'
]


def get_all_imported_modules():
    """Find all modules in transformers/peft/diffusers"""
    all_imported_modules = []
    transformers_include_names = [
        'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
        'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
        'Pipeline'
    ]
    peft_include_names = ['.*PeftModel.*', '.*Config']
    diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
    if importlib.util.find_spec('transformers') is not None:
        import transformers
        lazy_module = sys.modules['transformers']
        _import_structure = lazy_module._import_structure
        for key in _import_structure:
            if 'dummy' in key.lower():
                continue
            values = _import_structure[key]
            for value in values:
                # pretrained
                if any([
                        re.fullmatch(name, value)
                        for name in transformers_include_names
                ]):
                    try:
                        module = importlib.import_module(
                            f'.{key}', transformers.__name__)
                        value = getattr(module, value)
                        all_imported_modules.append(value)
                    except:  # noqa
                        pass

    if importlib.util.find_spec('peft') is not None:
        try:
            import peft
        except:  # noqa
            pass
        else:
            attributes = dir(peft)
            imports = [
                attr for attr in attributes if not attr.startswith('__')
            ]
            all_imported_modules.extend([
                getattr(peft, _import) for _import in imports if any([
                    re.fullmatch(name, _import) for name in peft_include_names
                ])
            ])

    if importlib.util.find_spec('diffusers') is not None:
        try:
            import diffusers
        except:  # noqa
            pass
        else:
            lazy_module = sys.modules['diffusers']
            if hasattr(lazy_module, '_import_structure'):
                _import_structure = lazy_module._import_structure
                for key in _import_structure:
                    if 'dummy' in key.lower():
                        continue
                    values = _import_structure[key]
                    for value in values:
                        if any([
                                re.fullmatch(name, value)
                                for name in diffusers_include_names
                        ]):
                            try:
                                module = importlib.import_module(
                                    f'.{key}', diffusers.__name__)
                                value = getattr(module, value)
                                all_imported_modules.append(value)
                            except:  # noqa
                                pass
            else:
                attributes = dir(lazy_module)
                imports = [
                    attr for attr in attributes if not attr.startswith('__')
                ]
                all_imported_modules.extend([
                    getattr(lazy_module, _import) for _import in imports
                    if any([
                        re.fullmatch(name, _import)
                        for name in diffusers_include_names
                    ])
                ])
    return all_imported_modules
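
# Usage sketch (illustrative, not part of the library API): the returned list
# holds the classes that the patcher below will wrap or monkey patch, e.g.
#
#     modules = get_all_imported_modules()
#     names = sorted({m.__name__ for m in modules if hasattr(m, '__name__')})
#     # With transformers installed, `names` typically contains entries such
#     # as 'AutoConfig', 'AutoModel' and 'AutoTokenizer'.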


def _patch_pretrained_class(all_imported_modules, wrap=False):
    """Patch classes so that their files are downloaded from ModelScope.

    Args:
        wrap: If True, wrap each class in a subclass; otherwise monkey patch
            the original class in place.

    Returns:
        The patched classes.
    """

    def get_model_dir(pretrained_model_name_or_path,
                      ignore_file_pattern=None,
                      allow_file_pattern=None,
                      **kwargs):
        from modelscope import snapshot_download
        subfolder = kwargs.pop('subfolder', None)
        file_filter = None
        if subfolder:
            file_filter = f'{subfolder}/*'
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            if revision is None or revision == 'main':
                revision = 'master'
            if file_filter is not None:
                allow_file_pattern = file_filter
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern,
                allow_file_pattern=allow_file_pattern)
            if subfolder:
                model_dir = os.path.join(model_dir, subfolder)
        else:
            model_dir = pretrained_model_name_or_path
        return model_dir

    def patch_pretrained_model_name_or_path(cls, pretrained_model_name_or_path,
                                            *model_args, **kwargs):
        """Patch all from_pretrained"""
        model_dir = get_model_dir(pretrained_model_name_or_path,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._from_pretrained_origin.__func__(cls, model_dir,
                                                    *model_args, **kwargs)

    def patch_get_config_dict(cls, pretrained_model_name_or_path, *model_args,
                              **kwargs):
        """Patch all get_config_dict"""
        model_dir = get_model_dir(pretrained_model_name_or_path,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._get_config_dict_origin.__func__(cls, model_dir,
                                                    *model_args, **kwargs)

    def patch_peft_model_id(cls, model, model_id, *model_args, **kwargs):
        """Patch all peft.from_pretrained"""
        model_dir = get_model_dir(model_id,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._from_pretrained_origin.__func__(cls, model, model_dir,
                                                    *model_args, **kwargs)

    def patch_get_peft_type(cls, model_id, **kwargs):
        """Patch all _get_peft_type"""
        model_dir = get_model_dir(model_id,
                                  kwargs.pop('ignore_file_pattern', None),
                                  kwargs.pop('allow_file_pattern', None),
                                  **kwargs)
        return cls._get_peft_type_origin.__func__(cls, model_dir, **kwargs)

    def get_wrapped_class(
            module_class: 'PreTrainedModel',
            ignore_file_pattern: Optional[Union[str, List[str]]] = None,
            allow_file_pattern: Optional[Union[str, List[str]]] = None,
            **kwargs):
        """Get a custom wrapper class for auto classes to download the models from the ModelScope hub.

        Args:
            module_class (`PreTrainedModel`): The actual module class.
            ignore_file_pattern (`str` or `List`, *optional*, default to `None`):
                Any file pattern to be ignored, like exact file names or file extensions.
            allow_file_pattern (`str` or `List`, *optional*, default to `None`):
                Any file pattern to be included, like exact file names or file extensions.

        Returns:
            The wrapped class.
        """

        def from_pretrained(model, model_id, *model_args, **kwargs):
            # model is an instance
            model_dir = get_model_dir(
                model_id,
                ignore_file_pattern=ignore_file_pattern,
                allow_file_pattern=allow_file_pattern,
                **kwargs)

            module_obj = module_class.from_pretrained(model, model_dir,
                                                      *model_args, **kwargs)

            return module_obj

        class ClassWrapper(module_class):

            @classmethod
            def from_pretrained(cls, pretrained_model_name_or_path,
                                *model_args, **kwargs):
                model_dir = get_model_dir(
                    pretrained_model_name_or_path,
                    ignore_file_pattern=ignore_file_pattern,
                    allow_file_pattern=allow_file_pattern,
                    **kwargs)

                module_obj = module_class.from_pretrained(
                    model_dir, *model_args, **kwargs)

                if module_class.__name__.startswith('AutoModel'):
                    module_obj.model_dir = model_dir
                return module_obj

            @classmethod
            def _get_peft_type(cls, model_id, **kwargs):
                model_dir = get_model_dir(
                    model_id,
                    ignore_file_pattern=ignore_file_pattern,
                    allow_file_pattern=allow_file_pattern,
                    **kwargs)
                module_obj = module_class._get_peft_type(model_dir, **kwargs)
                return module_obj

            @classmethod
            def get_config_dict(cls, pretrained_model_name_or_path,
                                *model_args, **kwargs):
                model_dir = get_model_dir(
                    pretrained_model_name_or_path,
                    ignore_file_pattern=ignore_file_pattern,
                    allow_file_pattern=allow_file_pattern,
                    **kwargs)

                module_obj = module_class.get_config_dict(
                    model_dir, *model_args, **kwargs)
                return module_obj

            def save_pretrained(
                    self,
                    save_directory: Union[str, os.PathLike],
                    safe_serialization: bool = True,
                    **kwargs,
            ):
                push_to_hub = kwargs.pop('push_to_hub', False)
                if push_to_hub:
                    from modelscope.hub.push_to_hub import push_to_hub
                    from modelscope.hub.api import HubApi
                    from modelscope.hub.repository import Repository

                    token = kwargs.get('token')
                    commit_message = kwargs.pop('commit_message', None)
                    repo_name = kwargs.pop(
                        'repo_id',
                        save_directory.split(os.path.sep)[-1])

                    api = HubApi()
                    api.login(token)
                    api.create_repo(repo_name)
                    # clone the repo
                    Repository(save_directory, repo_name)

                super().save_pretrained(
                    save_directory=save_directory,
                    safe_serialization=safe_serialization,
                    push_to_hub=False,
                    **kwargs)

                # Class members may be unpatched, so push_to_hub is done
                # separately here
                if push_to_hub:
                    push_to_hub(
                        repo_name=repo_name,
                        output_dir=save_directory,
                        commit_message=commit_message,
                        token=token)

        if not hasattr(module_class, 'from_pretrained'):
            del ClassWrapper.from_pretrained
        else:
            parameters = inspect.signature(var.from_pretrained).parameters
            if 'model' in parameters and 'model_id' in parameters:
                # peft
                ClassWrapper.from_pretrained = from_pretrained

        if not hasattr(module_class, '_get_peft_type'):
            del ClassWrapper._get_peft_type

        if not hasattr(module_class, 'get_config_dict'):
            del ClassWrapper.get_config_dict

        if not hasattr(module_class, 'save_pretrained'):
            del ClassWrapper.save_pretrained

        ClassWrapper.__name__ = module_class.__name__
        ClassWrapper.__qualname__ = module_class.__qualname__
        return ClassWrapper

    all_available_modules = []
    for var in all_imported_modules:
        if var is None or not hasattr(var, '__name__'):
            continue
        name = var.__name__
        need_model = 'model' in name.lower() or 'processor' in name.lower(
        ) or 'extractor' in name.lower() or 'pipeline' in name.lower()
        if need_model:
            ignore_file_pattern_kwargs = {}
        else:
            ignore_file_pattern_kwargs = {
                'ignore_file_pattern': ignore_file_pattern
            }

        try:
            # some TFxxx classes have import errors
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
            has_save_pretrained = hasattr(var, 'save_pretrained')
        except:  # noqa
            continue

        # save_pretrained is an instance method, not a classmethod, so it
        # cannot be overridden by replacing a class attribute; the class
        # itself has to be wrapped instead.
        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
            try:
                if (not has_from_pretrained and not has_get_config_dict
                        and not has_get_peft_type and not has_save_pretrained):
                    all_available_modules.append(var)
                else:
                    all_available_modules.append(
                        get_wrapped_class(var, **ignore_file_pattern_kwargs))
            except Exception:
                all_available_modules.append(var)
        else:
            if has_from_pretrained and not hasattr(var,
                                                   '_from_pretrained_origin'):
                parameters = inspect.signature(var.from_pretrained).parameters
                # peft uses different argument names
                is_peft = 'model' in parameters and 'model_id' in parameters
                var._from_pretrained_origin = var.from_pretrained
                if not is_peft:
                    var.from_pretrained = classmethod(
                        partial(patch_pretrained_model_name_or_path,
                                **ignore_file_pattern_kwargs))
                else:
                    var.from_pretrained = classmethod(
                        partial(patch_peft_model_id,
                                **ignore_file_pattern_kwargs))
            if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
                var._get_peft_type_origin = var._get_peft_type
                var._get_peft_type = classmethod(
                    partial(patch_get_peft_type, **ignore_file_pattern_kwargs))

            if has_get_config_dict and not hasattr(var,
                                                   '_get_config_dict_origin'):
                var._get_config_dict_origin = var.get_config_dict
                var.get_config_dict = classmethod(
                    partial(patch_get_config_dict,
                            **ignore_file_pattern_kwargs))

            all_available_modules.append(var)
    return all_available_modules
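
# Usage sketch (illustrative; `wrap=True` here is an assumption about how a
# caller might use this helper): with wrapping enabled the original classes
# are left untouched and wrapped subclasses are returned instead, e.g.
#
#     wrapped = _patch_pretrained_class(get_all_imported_modules(), wrap=True)
#     # Each entry in `wrapped` behaves like its transformers/peft/diffusers
#     # counterpart, but from_pretrained('some-org/some-model') resolves the
#     # repo id through modelscope.snapshot_download first.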


def _unpatch_pretrained_class(all_imported_modules):
    for var in all_imported_modules:
        if var is None:
            continue

        try:
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
        except:  # noqa
            continue
        if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
            var.from_pretrained = var._from_pretrained_origin
            delattr(var, '_from_pretrained_origin')
        if has_get_peft_type and hasattr(var, '_get_peft_type_origin'):
            var._get_peft_type = var._get_peft_type_origin
            delattr(var, '_get_peft_type_origin')
        if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
            var.get_config_dict = var._get_config_dict_origin
            delattr(var, '_get_config_dict_origin')


def _patch_hub():
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub.hf_api import api
    from huggingface_hub.hf_api import future_compatible
    from modelscope import get_logger
    logger = get_logger()

    def _file_exists(
        self,
        repo_id: str,
        filename: str,
        *,
        repo_type: Optional[str] = None,
        revision: Optional[str] = None,
        token: Union[str, bool, None] = None,
    ):
        """Patch huggingface_hub.file_exists"""
        if repo_type is not None:
            logger.warning(
                'The passed-in repo_type will not be used in ModelScope; '
                'currently only model repos can be queried.')
        from modelscope.hub.api import HubApi
        api = HubApi()
        api.login(token)
        if revision is None or revision == 'main':
            revision = 'master'
        return api.file_exists(repo_id, filename, revision=revision)

    def _file_download(repo_id: str,
                       filename: str,
                       *,
                       subfolder: Optional[str] = None,
                       repo_type: Optional[str] = None,
                       revision: Optional[str] = None,
                       cache_dir: Union[str, Path, None] = None,
                       local_dir: Union[str, Path, None] = None,
                       token: Union[bool, str, None] = None,
                       local_files_only: bool = False,
                       **kwargs):
        """Patch huggingface_hub.hf_hub_download"""
        if len(kwargs) > 0:
            logger.warning(
                'The passed-in library_name, library_version, user_agent, '
                'force_download, proxies, etag_timeout, headers and endpoint '
                'will not be used in ModelScope.')
        assert repo_type in (
            None, 'model',
            'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
        if repo_type in (None, 'model'):
            from modelscope.hub.file_download import model_file_download as file_download
        else:
            from modelscope.hub.file_download import dataset_file_download as file_download
        from modelscope import HubApi
        api = HubApi()
        api.login(token)
        if revision is None or revision == 'main':
            revision = 'master'
        return file_download(
            repo_id,
            file_path=os.path.join(subfolder, filename)
            if subfolder else filename,
            cache_dir=cache_dir,
            local_dir=local_dir,
            local_files_only=local_files_only,
            revision=revision)

    def _whoami(self, token: Union[bool, str, None] = None) -> Dict:
        from modelscope.hub.api import ModelScopeConfig
        from modelscope.hub.api import HubApi
        api = HubApi()
        api.login(token)
        return {'name': ModelScopeConfig.get_user_info()[0] or 'unknown'}

    def create_repo(self,
                    repo_id: str,
                    *,
                    token: Union[str, bool, None] = None,
                    private: bool = False,
                    **kwargs) -> 'RepoUrl':
        """
        Create a new repository on the hub.

        Args:
            repo_id: The ID of the repository to create.
            token: The authentication token to use.
            private: Whether the repository should be private.
            **kwargs: Additional arguments.

        Returns:
            RepoUrl: The URL of the created repository.
        """
        from modelscope.hub.api import HubApi
        api = HubApi()
        visibility = 'private' if private else 'public'
        repo_url = api.create_repo(
            repo_id, token=token, visibility=visibility, **kwargs)
        from modelscope.utils.repo_utils import RepoUrl
        return RepoUrl(url=repo_url, repo_type='model', repo_id=repo_id)

    @future_compatible
    def upload_folder(
        self,
        *,
        repo_id: str,
        folder_path: Union[str, Path],
        path_in_repo: Optional[str] = None,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = 'master',
        ignore_patterns: Optional[Union[List[str], str]] = None,
        **kwargs,
    ):
        from modelscope.hub.push_to_hub import _push_files_to_hub
        if revision is None or revision == 'main':
            revision = 'master'
        _push_files_to_hub(
            path_or_fileobj=folder_path,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            commit_message=commit_message,
            commit_description=commit_description,
            revision=revision,
            token=token)
        from modelscope.utils.repo_utils import CommitInfo
        return CommitInfo(
            commit_url=
            f'{DEFAULT_MODELSCOPE_DATA_ENDPOINT}/models/{repo_id}/files',
            commit_message=commit_message,
            commit_description=commit_description,
            oid=None,
        )

    from modelscope.utils.constant import DEFAULT_REPOSITORY_REVISION

    @future_compatible
    def upload_file(
        self,
        *,
        path_or_fileobj: Union[str, Path, bytes, BinaryIO],
        path_in_repo: str,
        repo_id: str,
        token: Union[str, bool, None] = None,
        revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
        commit_message: Optional[str] = None,
        commit_description: Optional[str] = None,
        **kwargs,
    ):
        if revision is None or revision == 'main':
            revision = 'master'
        from modelscope.hub.push_to_hub import _push_files_to_hub
        _push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
                           revision, commit_message, commit_description)

    @future_compatible
    def create_commit(
        self,
        repo_id: str,
        operations: Iterable[CommitOperation],
        *,
        commit_message: str,
        commit_description: Optional[str] = None,
        token: Union[str, bool, None] = None,
        repo_type: Optional[str] = None,
        revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
        **kwargs,
    ) -> Union[CommitInfo, Future[CommitInfo]]:
        from modelscope.hub.api import HubApi
        api = HubApi()
        if any(['Add' not in op.__class__.__name__ for op in operations]):
            raise ValueError(
                'ModelScope create_commit only supports Add operations for now.'
            )
        if revision is None or revision == 'main':
            revision = 'master'
        all_files = [op.path_or_fileobj for op in operations]
        api.upload_folder(
            repo_id=repo_id,
            folder_path=all_files,
            commit_message=commit_message,
            commit_description=commit_description,
            token=token,
            revision=revision,
            repo_type=repo_type or 'model')

    def load(
        cls,
        repo_id_or_path: Union[str, Path],
        repo_type: Optional[str] = None,
        token: Optional[str] = None,
        ignore_metadata_errors: bool = False,
    ):
        from modelscope.hub.api import HubApi
        api = HubApi()
        api.login(token)
        if os.path.exists(repo_id_or_path):
            file_path = repo_id_or_path
        elif repo_type == 'model' or repo_type is None:
            from modelscope import model_file_download
            file_path = model_file_download(repo_id_or_path, 'README.md')
        elif repo_type == 'dataset':
            from modelscope import dataset_file_download
            file_path = dataset_file_download(repo_id_or_path, 'README.md')
        else:
            raise ValueError(
                f'repo_type should be `model` or `dataset`, but got {repo_type}'
            )

        with open(file_path, 'r') as f:
            repo_card = cls(
                f.read(), ignore_metadata_errors=ignore_metadata_errors)
            if not hasattr(repo_card.data, 'tags'):
                repo_card.data.tags = []
            return repo_card

    # Patch repocard.validate
    from huggingface_hub import repocard
    if not hasattr(repocard.RepoCard, '_validate_origin'):

        def load(*args, **kwargs):  # noqa
            from huggingface_hub.errors import EntryNotFoundError
            raise EntryNotFoundError(message='API not supported.')

        repocard.RepoCard._validate_origin = repocard.RepoCard.validate
        repocard.RepoCard.validate = lambda *args, **kwargs: None
        repocard.RepoCard._load_origin = repocard.RepoCard.load
        repocard.RepoCard.load = MethodType(load, repocard.RepoCard)

    if not hasattr(hf_api, '_hf_hub_download_origin'):
        # Patch hf_hub_download
        hf_api._hf_hub_download_origin = huggingface_hub.file_download.hf_hub_download
        huggingface_hub.hf_hub_download = _file_download
        huggingface_hub.file_download.hf_hub_download = _file_download

    if not hasattr(hf_api, '_file_exists_origin'):
        # Patch file_exists
        hf_api._file_exists_origin = hf_api.file_exists
        hf_api.file_exists = MethodType(_file_exists, api)
        huggingface_hub.file_exists = hf_api.file_exists
        huggingface_hub.hf_api.file_exists = hf_api.file_exists

    if not hasattr(hf_api, '_whoami_origin'):
        # Patch whoami
        hf_api._whoami_origin = hf_api.whoami
        hf_api.whoami = MethodType(_whoami, api)
        huggingface_hub.whoami = hf_api.whoami
        huggingface_hub.hf_api.whoami = hf_api.whoami

    if not hasattr(hf_api, '_create_repo_origin'):
        # Patch create_repo
        from transformers.utils import hub
        hf_api._create_repo_origin = hf_api.create_repo
        hf_api.create_repo = MethodType(create_repo, api)
        huggingface_hub.create_repo = hf_api.create_repo
        huggingface_hub.hf_api.create_repo = hf_api.create_repo
        hub.create_repo = hf_api.create_repo

    if not hasattr(hf_api, '_upload_folder_origin'):
        # Patch upload_folder
        hf_api._upload_folder_origin = hf_api.upload_folder
        hf_api.upload_folder = MethodType(upload_folder, api)
        huggingface_hub.upload_folder = hf_api.upload_folder
        huggingface_hub.hf_api.upload_folder = hf_api.upload_folder

    if not hasattr(hf_api, '_upload_file_origin'):
        # Patch upload_file
        hf_api._upload_file_origin = hf_api.upload_file
        hf_api.upload_file = MethodType(upload_file, api)
        huggingface_hub.upload_file = hf_api.upload_file
        huggingface_hub.hf_api.upload_file = hf_api.upload_file
        repocard.upload_file = hf_api.upload_file

    if not hasattr(hf_api, '_create_commit_origin'):
        # Patch create_commit
        hf_api._create_commit_origin = hf_api.create_commit
        hf_api.create_commit = MethodType(create_commit, api)
        huggingface_hub.create_commit = hf_api.create_commit
        huggingface_hub.hf_api.create_commit = hf_api.create_commit
        from transformers.utils import hub
        hub.create_commit = hf_api.create_commit
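
# Usage sketch (illustrative; the repo id below is a placeholder): once
# `_patch_hub()` has run, plain huggingface_hub calls resolve against
# ModelScope instead of the Hugging Face Hub, e.g.
#
#     from huggingface_hub import hf_hub_download
#     path = hf_hub_download('some-org/some-model', 'config.json')
#     # `path` now points to a file fetched via modelscope's
#     # model_file_download rather than the original HF endpoint.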


def _unpatch_hub():
    import huggingface_hub
    from huggingface_hub import hf_api

    from huggingface_hub import repocard
    if hasattr(repocard.RepoCard, '_validate_origin'):
        repocard.RepoCard.validate = repocard.RepoCard._validate_origin
        delattr(repocard.RepoCard, '_validate_origin')
    if hasattr(repocard.RepoCard, '_load_origin'):
        repocard.RepoCard.load = repocard.RepoCard._load_origin
        delattr(repocard.RepoCard, '_load_origin')

    if hasattr(hf_api, '_hf_hub_download_origin'):
        huggingface_hub.hf_hub_download = hf_api._hf_hub_download_origin
        huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
        delattr(hf_api, '_hf_hub_download_origin')

    if hasattr(hf_api, '_file_exists_origin'):
        hf_api.file_exists = hf_api._file_exists_origin
        huggingface_hub.file_exists = hf_api.file_exists
        huggingface_hub.hf_api.file_exists = hf_api.file_exists
        delattr(hf_api, '_file_exists_origin')

    if hasattr(hf_api, '_whoami_origin'):
        hf_api.whoami = hf_api._whoami_origin
        huggingface_hub.whoami = hf_api.whoami
        huggingface_hub.hf_api.whoami = hf_api.whoami
        delattr(hf_api, '_whoami_origin')

    if hasattr(hf_api, '_create_repo_origin'):
        from transformers.utils import hub
        hf_api.create_repo = hf_api._create_repo_origin
        huggingface_hub.create_repo = hf_api.create_repo
        huggingface_hub.hf_api.create_repo = hf_api.create_repo
        hub.create_repo = hf_api.create_repo
        delattr(hf_api, '_create_repo_origin')

    if hasattr(hf_api, '_upload_folder_origin'):
        hf_api.upload_folder = hf_api._upload_folder_origin
        huggingface_hub.upload_folder = hf_api.upload_folder
        huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
        delattr(hf_api, '_upload_folder_origin')

    if hasattr(hf_api, '_upload_file_origin'):
        hf_api.upload_file = hf_api._upload_file_origin
        huggingface_hub.upload_file = hf_api.upload_file
        huggingface_hub.hf_api.upload_file = hf_api.upload_file
        repocard.upload_file = hf_api.upload_file
        delattr(hf_api, '_upload_file_origin')

    if hasattr(hf_api, '_create_commit_origin'):
        hf_api.create_commit = hf_api._create_commit_origin
        huggingface_hub.create_commit = hf_api.create_commit
        huggingface_hub.hf_api.create_commit = hf_api.create_commit
        from transformers.utils import hub
        hub.create_commit = hf_api.create_commit
        delattr(hf_api, '_create_commit_origin')


def patch_hub():
    _patch_hub()
    _patch_pretrained_class(get_all_imported_modules())


def unpatch_hub():
    _unpatch_pretrained_class(get_all_imported_modules())
    _unpatch_hub()


@contextlib.contextmanager
def patch_context():
    patch_hub()
    yield
    unpatch_hub()
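
# Usage sketch (illustrative; the repo id is a placeholder): the context
# manager applies the hub and class patches on entry and restores the
# originals on exit, e.g.
#
#     from transformers import AutoTokenizer
#
#     with patch_context():
#         # downloads via ModelScope while the patches are active
#         tokenizer = AutoTokenizer.from_pretrained('some-org/some-model')
#     # outside the context, transformers resolves repos from the HF Hub again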