mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
wip
This commit is contained in:
@@ -1,18 +1,17 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
import importlib
|
||||
from types import MethodType
|
||||
from typing import BinaryIO, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub.hf_api import CommitInfo, future_compatible
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.constant import Invoke
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -26,13 +25,6 @@ for module in all_modules:
|
||||
all_imported_modules.append(importlib.import_module(f'transformers.{module}'))
|
||||
|
||||
|
||||
def user_agent(invoked_by=None):
|
||||
if invoked_by is None:
|
||||
invoked_by = Invoke.PRETRAINED
|
||||
uagent = '%s/%s' % (Invoke.KEY, invoked_by)
|
||||
return uagent
|
||||
|
||||
|
||||
def _patch_pretrained_class():
|
||||
|
||||
def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
|
||||
@@ -79,24 +71,51 @@ def _patch_pretrained_class():
|
||||
else:
|
||||
ignore_file_pattern_kwargs = {'ignore_file_pattern': ignore_file_pattern}
|
||||
|
||||
if name.endswith('HF'):
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
parameters = inspect.signature(var.from_pretrained).parameters
|
||||
is_peft = 'model' in parameters and 'model_id' in parameters
|
||||
if has_from_pretrained:
|
||||
if not is_peft:
|
||||
var.from_pretrained = partial(patch_pretrained_model_name_or_path,
|
||||
ori_func=var.from_pretrained,
|
||||
**ignore_file_pattern_kwargs)
|
||||
else:
|
||||
var.from_pretrained = partial(patch_peft_model_id,
|
||||
ori_func=var.from_pretrained,
|
||||
**ignore_file_pattern_kwargs)
|
||||
if has_get_peft_type:
|
||||
var._get_peft_type = partial(_get_peft_type,
|
||||
ori_func=var._get_peft_type,
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
parameters = inspect.signature(var.from_pretrained).parameters
|
||||
is_peft = 'model' in parameters and 'model_id' in parameters
|
||||
if has_from_pretrained and not hasattr(var, '_from_pretrained_origin'):
|
||||
var._from_pretrained_origin = var.from_pretrained
|
||||
if not is_peft:
|
||||
var.from_pretrained = partial(patch_pretrained_model_name_or_path,
|
||||
ori_func=var._from_pretrained_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
else:
|
||||
var.from_pretrained = partial(patch_peft_model_id,
|
||||
ori_func=var._from_pretrained_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
delattr(var, '_from_pretrained_origin')
|
||||
if has_get_peft_type and not hasattr(var, '_get_peft_type_origin'):
|
||||
var._get_peft_type_origin = var._get_peft_type
|
||||
var._get_peft_type = partial(_get_peft_type,
|
||||
ori_func=var._get_peft_type_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
delattr(var, '_get_peft_type_origin')
|
||||
|
||||
if has_get_config_dict and not hasattr(var, '_get_config_dict_origin'):
|
||||
var._get_config_dict_origin = var.get_config_dict
|
||||
var.get_config_dict = partial(patch_pretrained_model_name_or_path,
|
||||
ori_func=var._get_config_dict_origin,
|
||||
**ignore_file_pattern_kwargs)
|
||||
delattr(var, '_get_config_dict_origin')
|
||||
|
||||
|
||||
def _unpatch_pretrained_class():
|
||||
for var in all_imported_modules:
|
||||
if var is None:
|
||||
continue
|
||||
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
|
||||
var.from_pretrained = var._from_pretrained_origin
|
||||
if has_get_peft_type and hasattr(var, '_get_peft_type_origin'):
|
||||
var._get_peft_type = var._get_peft_type_origin
|
||||
if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
|
||||
var.get_config_dict = var._get_config_dict_origin
|
||||
|
||||
|
||||
def _patch_hub():
|
||||
@@ -160,26 +179,11 @@ def _patch_hub():
|
||||
|
||||
def _whoami(self, token: Union[bool, str, None] = None) -> Dict:
|
||||
from modelscope.hub.api import ModelScopeConfig
|
||||
from modelscope.hub.api import HubApi
|
||||
api = HubApi()
|
||||
api.try_login(token)
|
||||
return {'name': ModelScopeConfig.get_user_info()[0] or 'unknown'}
|
||||
|
||||
# Patch hf_hub_download
|
||||
huggingface_hub.hf_hub_download = _file_download
|
||||
huggingface_hub.file_download.hf_hub_download = _file_download
|
||||
|
||||
# Patch file_exists
|
||||
hf_api.file_exists = MethodType(_file_exists, api)
|
||||
huggingface_hub.file_exists = hf_api.file_exists
|
||||
huggingface_hub.hf_api.file_exists = hf_api.file_exists
|
||||
|
||||
# Patch whoami
|
||||
hf_api.whoami = MethodType(_whoami, api)
|
||||
huggingface_hub.whoami = hf_api.whoami
|
||||
huggingface_hub.hf_api.whoami = hf_api.whoami
|
||||
|
||||
# Patch repocard.validate
|
||||
from huggingface_hub import repocard
|
||||
repocard.RepoCard.validate = lambda *args, **kwargs: None
|
||||
|
||||
def create_repo(self,
|
||||
repo_id: str,
|
||||
*,
|
||||
@@ -244,25 +248,109 @@ def _patch_hub():
|
||||
push_files_to_hub(path_or_fileobj, path_in_repo, repo_id, token,
|
||||
revision, commit_message, commit_description)
|
||||
|
||||
# Patch create_repo
|
||||
from transformers.utils import hub
|
||||
hf_api.create_repo = MethodType(create_repo, api)
|
||||
huggingface_hub.create_repo = hf_api.create_repo
|
||||
huggingface_hub.hf_api.create_repo = hf_api.create_repo
|
||||
hub.create_repo = create_repo
|
||||
# Patch repocard.validate
|
||||
from huggingface_hub import repocard
|
||||
if not hasattr(repocard.RepoCard, '_validate_origin'):
|
||||
repocard.RepoCard._validate_origin = repocard.RepoCard.validate
|
||||
repocard.RepoCard.validate = lambda *args, **kwargs: None
|
||||
|
||||
# Patch upload_folder
|
||||
hf_api.upload_folder = MethodType(upload_folder, api)
|
||||
huggingface_hub.upload_folder = hf_api.upload_folder
|
||||
huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
|
||||
if not hasattr(hf_api, '_hf_hub_download_origin'):
|
||||
# Patch hf_hub_download
|
||||
hf_api._hf_hub_download_origin = huggingface_hub.file_download.hf_hub_download
|
||||
huggingface_hub.hf_hub_download = _file_download
|
||||
huggingface_hub.file_download.hf_hub_download = _file_download
|
||||
|
||||
# Patch upload_file
|
||||
hf_api.upload_file = MethodType(upload_file, api)
|
||||
huggingface_hub.upload_file = hf_api.upload_file
|
||||
huggingface_hub.hf_api.upload_file = hf_api.upload_file
|
||||
repocard.upload_file = hf_api.upload_file
|
||||
if not hasattr(hf_api, '_file_exists_origin'):
|
||||
# Patch file_exists
|
||||
hf_api._file_exists_origin = hf_api.file_exists
|
||||
hf_api.file_exists = MethodType(_file_exists, api)
|
||||
huggingface_hub.file_exists = hf_api.file_exists
|
||||
huggingface_hub.hf_api.file_exists = hf_api.file_exists
|
||||
|
||||
if not hasattr(hf_api, '_whoami_origin'):
|
||||
# Patch whoami
|
||||
hf_api._whoami_origin = hf_api.whoami
|
||||
hf_api.whoami = MethodType(_whoami, api)
|
||||
huggingface_hub.whoami = hf_api.whoami
|
||||
huggingface_hub.hf_api.whoami = hf_api.whoami
|
||||
|
||||
if not hasattr(hf_api, '_create_repo_origin'):
|
||||
# Patch create_repo
|
||||
from transformers.utils import hub
|
||||
hf_api._create_repo_origin = hf_api.create_repo
|
||||
hf_api.create_repo = MethodType(create_repo, api)
|
||||
huggingface_hub.create_repo = hf_api.create_repo
|
||||
huggingface_hub.hf_api.create_repo = hf_api.create_repo
|
||||
hub.create_repo = hf_api.create_repo
|
||||
|
||||
if not hasattr(hf_api, '_upload_folder_origin'):
|
||||
# Patch upload_folder
|
||||
hf_api._upload_folder_origin = hf_api.upload_folder
|
||||
hf_api.upload_folder = MethodType(upload_folder, api)
|
||||
huggingface_hub.upload_folder = hf_api.upload_folder
|
||||
huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
|
||||
|
||||
if not hasattr(hf_api, '_upload_file_origin'):
|
||||
# Patch upload_file
|
||||
hf_api._upload_file_origin = hf_api.upload_file
|
||||
hf_api.upload_file = MethodType(upload_file, api)
|
||||
huggingface_hub.upload_file = hf_api.upload_file
|
||||
huggingface_hub.hf_api.upload_file = hf_api.upload_file
|
||||
repocard.upload_file = hf_api.upload_file
|
||||
|
||||
|
||||
def _unpatch_hub():
|
||||
import huggingface_hub
|
||||
from huggingface_hub import hf_api
|
||||
|
||||
from huggingface_hub import repocard
|
||||
if hasattr(repocard.RepoCard, '_validate_origin'):
|
||||
repocard.RepoCard.validate = repocard.RepoCard._validate_origin
|
||||
delattr(repocard.RepoCard, '_validate_origin')
|
||||
|
||||
if hasattr(hf_api, '_hf_hub_download_origin'):
|
||||
huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
|
||||
huggingface_hub.hf_hub_download = hf_api._hf_hub_download_origin
|
||||
huggingface_hub.file_download.hf_hub_download = hf_api._hf_hub_download_origin
|
||||
delattr(hf_api, '_hf_hub_download_origin')
|
||||
|
||||
if hasattr(hf_api, '_file_exists_origin'):
|
||||
hf_api.file_exists = hf_api._file_exists_origin
|
||||
huggingface_hub.file_exists = hf_api.file_exists
|
||||
huggingface_hub.hf_api.file_exists = hf_api.file_exists
|
||||
delattr(hf_api, '_file_exists_origin')
|
||||
|
||||
if hasattr(hf_api, '_whoami_origin'):
|
||||
hf_api.whoami = hf_api._whoami_origin
|
||||
huggingface_hub.whoami = hf_api.whoami
|
||||
huggingface_hub.hf_api.whoami = hf_api.whoami
|
||||
delattr(hf_api, '_whoami_origin')
|
||||
|
||||
if hasattr(hf_api, '_create_repo_origin'):
|
||||
from transformers.utils import hub
|
||||
hf_api.create_repo = hf_api._create_repo_origin
|
||||
huggingface_hub.create_repo = hf_api.create_repo
|
||||
huggingface_hub.hf_api.create_repo = hf_api.create_repo
|
||||
hub.create_repo = hf_api.create_repo
|
||||
delattr(hf_api, '_create_repo_origin')
|
||||
|
||||
if hasattr(hf_api, '_upload_folder_origin'):
|
||||
hf_api.upload_folder = hf_api._upload_folder_origin
|
||||
huggingface_hub.upload_folder = hf_api.upload_folder
|
||||
huggingface_hub.hf_api.upload_folder = hf_api.upload_folder
|
||||
delattr(hf_api, '_upload_folder_origin')
|
||||
|
||||
if hasattr(hf_api, '_upload_file_origin'):
|
||||
hf_api.upload_file = hf_api._upload_file_origin
|
||||
huggingface_hub.upload_file = hf_api.upload_file
|
||||
huggingface_hub.hf_api.upload_file = hf_api.upload_file
|
||||
repocard.upload_file = hf_api.upload_file
|
||||
delattr(hf_api, '_upload_file_origin')
|
||||
|
||||
def patch_hub():
|
||||
_patch_hub()
|
||||
_patch_pretrained_class()
|
||||
|
||||
|
||||
def unpatch_hub():
|
||||
_unpatch_pretrained_class()
|
||||
|
||||
Reference in New Issue
Block a user