This commit is contained in:
yuze.zyz
2025-01-08 20:50:07 +08:00
parent 640b3bd49b
commit c8f958182d

View File

@@ -1,18 +1,17 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import inspect
import os
import sys
from functools import partial
from pathlib import Path
import importlib
from types import MethodType
from typing import BinaryIO, Dict, List, Optional, Union
from huggingface_hub.hf_api import CommitInfo, future_compatible
from modelscope import snapshot_download
from modelscope.utils.constant import Invoke
from modelscope.utils.logger import get_logger
from modelscope import snapshot_download
from modelscope.utils.logger import get_logger
logger = get_logger()
@@ -26,13 +25,6 @@ for module in all_modules:
all_imported_modules.append(importlib.import_module(f'transformers.{module}'))
def user_agent(invoked_by=None):
if invoked_by is None:
invoked_by = Invoke.PRETRAINED
uagent = '%s/%s' % (Invoke.KEY, invoked_by)
return uagent
def _patch_pretrained_class():
def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
@@ -79,24 +71,51 @@ def _patch_pretrained_class():
else:
ignore_file_pattern_kwargs = {'ignore_file_pattern': ignore_file_pattern}
if name.endswith('HF'):
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
parameters = inspect.signature(var.from_pretrained).parameters
is_peft = 'model' in parameters and 'model_id' in parameters
if has_from_pretrained:
if not is_peft:
var.from_pretrained = partial(patch_pretrained_model_name_or_path,
ori_func=var.from_pretrained,
**ignore_file_pattern_kwargs)
else:
var.from_pretrained = partial(patch_peft_model_id,
ori_func=var.from_pretrained,
**ignore_file_pattern_kwargs)
if has_get_peft_type:
var._get_peft_type = partial(_get_peft_type,
ori_func=var._get_peft_type,
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict')
parameters = inspect.signature(var.from_pretrained).parameters
is_peft = 'model' in parameters and 'model_id' in parameters
if has_from_pretrained and not hasattr(var, '_from_pretrained_origin'):
var._from_pretrained_origin = var.from_pretrained
if not is_peft:
var.from_pretrained = partial(patch_pretrained_model_name_or_path,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
else:
var.from_pretrained = partial(patch_peft_model_id,
ori_func=var._from_pretrained_origin,
**ignore_file_pattern_kwargs)
delattr(var, '_from_pretrained_origin')
if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
var._get_peft_type_origin = var._get_peft_type
var._get_peft_type = partial(_get_peft_type,
ori_func=var._get_peft_type_origin,
**ignore_file_pattern_kwargs)
delattr(var, '_get_peft_type_origin')
if has_get_config_dict and not hasattr(var, '_get_config_dict_origin'):
var._get_config_dict_origin = var.get_config_dict
var.get_config_dict = partial(patch_pretrained_model_name_or_path,
ori_func=var._get_config_dict_origin,
**ignore_file_pattern_kwargs)
delattr(var, '_get_config_dict_origin')
def _unpatch_pretrained_class():
for var in all_imported_modules:
if var is None:
continue
has_from_pretrained = hasattr(var, 'from_pretrained')
has_get_peft_type = hasattr(var, '_get_peft_type')
has_get_config_dict = hasattr(var, 'get_config_dict')
if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
var.from_pretrained = var._from_pretrained_origin
if has_get_peft_type and hasattr(var, '_get_peft_type_origin'):
var._get_peft_type = var._get_peft_type_origin
if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
var.get_config_dict = var._get_config_dict_origin
def _patch_hub():
@@ -160,26 +179,11 @@ def _patch_hub():
def _whoami(self, token: Union[bool, str, None] = None) -> Dict:
from modelscope.hub.api import ModelScopeConfig
from modelscope.hub.api import HubApi
api = HubApi()
api.try_login(token)
return {'name': ModelScopeConfig.get_user_info()[0] or 'unknown'}
# Patch hf_hub_download
huggingface_hub.hf_hub_download = _file_download
huggingface_hub.file_download.hf_hub_download = _file_download
# Patch file_exists
hf_api.file_exists = MethodType(_file_exists, api)
huggingface_hub.file_exists = hf_api.file_exists
huggingface_hub.hf_api.file_exists = hf_api.file_exists
# Patch whoami
hf_api.whoami = MethodType(_whoami, api)
huggingface_hub.whoami = hf_api.whoami
huggingface_hub.hf_api.whoami = hf_api.whoami
# Patch repocard.validate
from huggingface_hub import repocard
repocard.RepoCard.validate = lambda *args, **kwargs: None
def create_repo(self,
repo_id: str,
*,
@@ -244,25 +248,109 @@ def _patch_hub():
push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
revision, commit_message, commit_description)
# Patch create_repo
from transformers.utils import hub
hf_api.create_repo = MethodType(create_repo, api)
huggingface_hub.create_repo = hf_api.create_repo
huggingface_hub.hf_api.create_repo = hf_api.create_repo
hub.create_repo = create_repo
# Patch repocard.validate
from huggingface_hub import repocard
if not hasattr(repocard.RepoCard, '_validate_origin'):
repocard.RepoCard._validate_origin = repocard.RepoCard.validate
repocard.RepoCard.validate = lambda *args, **kwargs: None
# Patch upload_folder
hf_api.upload_folder = MethodType(upload_folder, api)
huggingface_hub.upload_folder = hf_api.upload_folder
huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
if not hasattr(hf_api, '_hf_hub_download_origin'):
# Patch hf_hub_download
hf_api._hf_hub_download_origin = huggingface_hub.file_download.hf_hub_download
huggingface_hub.hf_hub_download = _file_download
huggingface_hub.file_download.hf_hub_download = _file_download
# Patch upload_file
hf_api.upload_file = MethodType(upload_file, api)
huggingface_hub.upload_file = hf_api.upload_file
huggingface_hub.hf_api.upload_file = hf_api.upload_file
repocard.upload_file = hf_api.upload_file
if not hasattr(hf_api, '_file_exists_origin'):
# Patch file_exists
hf_api._file_exists_origin = hf_api.file_exists
hf_api.file_exists = MethodType(_file_exists, api)
huggingface_hub.file_exists = hf_api.file_exists
huggingface_hub.hf_api.file_exists = hf_api.file_exists
if not hasattr(hf_api, '_whoami_origin'):
# Patch whoami
hf_api._whoami_origin = hf_api.whoami
hf_api.whoami = MethodType(_whoami, api)
huggingface_hub.whoami = hf_api.whoami
huggingface_hub.hf_api.whoami = hf_api.whoami
if not hasattr(hf_api, '_create_repo_origin'):
# Patch create_repo
from transformers.utils import hub
hf_api._create_repo_origin = hf_api.create_repo
hf_api.create_repo = MethodType(create_repo, api)
huggingface_hub.create_repo = hf_api.create_repo
huggingface_hub.hf_api.create_repo = hf_api.create_repo
hub.create_repo = hf_api.create_repo
if not hasattr(hf_api, '_upload_folder_origin'):
# Patch upload_folder
hf_api._upload_folder_origin = hf_api.upload_folder
hf_api.upload_folder = MethodType(upload_folder, api)
huggingface_hub.upload_folder = hf_api.upload_folder
huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
if not hasattr(hf_api, '_upload_file_origin'):
# Patch upload_file
hf_api._upload_file_origin = hf_api.upload_file
hf_api.upload_file = MethodType(upload_file, api)
huggingface_hub.upload_file = hf_api.upload_file
huggingface_hub.hf_api.upload_file = hf_api.upload_file
repocard.upload_file = hf_api.upload_file
def _unpatch_hub():
import huggingface_hub
from huggingface_hub import hf_api
from huggingface_hub import repocard
if hasattr(repocard.RepoCard, '_validate_origin'):
repocard.RepoCard.validate = repocard.RepoCard._validate_origin
delattr(repocard.RepoCard, '_validate_origin')
if hasattr(hf_api, '_hf_hub_download_origin'):
huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
huggingface_hub.hf_hub_download = hf_api._hf_hub_download_origin
huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
delattr(hf_api, '_hf_hub_download_origin')
if hasattr(hf_api, '_file_exists_origin'):
hf_api.file_exists = hf_api._file_exists_origin
huggingface_hub.file_exists = hf_api.file_exists
huggingface_hub.hf_api.file_exists = hf_api.file_exists
delattr(hf_api, '_file_exists_origin')
if hasattr(hf_api, '_whoami_origin'):
hf_api.whoami = hf_api._whoami_origin
huggingface_hub.whoami = hf_api.whoami
huggingface_hub.hf_api.whoami = hf_api.whoami
delattr(hf_api, '_whoami_origin')
if hasattr(hf_api, '_create_repo_origin'):
from transformers.utils import hub
hf_api.create_repo = hf_api._create_repo_origin
huggingface_hub.create_repo = hf_api.create_repo
huggingface_hub.hf_api.create_repo = hf_api.create_repo
hub.create_repo = hf_api.create_repo
delattr(hf_api, '_create_repo_origin')
if hasattr(hf_api, '_upload_folder_origin'):
hf_api.upload_folder = hf_api._upload_folder_origin
huggingface_hub.upload_folder = hf_api.upload_folder
huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
delattr(hf_api, '_upload_folder_origin')
if hasattr(hf_api, '_upload_file_origin'):
hf_api.upload_file = hf_api._upload_file_origin
huggingface_hub.upload_file = hf_api.upload_file
huggingface_hub.hf_api.upload_file = hf_api.upload_file
repocard.upload_file = hf_api.upload_file
delattr(hf_api, '_upload_file_origin')
def patch_hub():
_patch_hub()
_patch_pretrained_class()
def unpatch_hub():
_unpatch_pretrained_class()