patch hf hub (#987)

tastelikefeet
2024-09-14 12:24:19 +08:00
committed by GitHub
parent 51b33cecef
commit 4c518db424
4 changed files with 182 additions and 3 deletions

View File

@@ -661,6 +661,26 @@ class HubApi:
            files.append(file)
        return files

    def file_exists(
        self,
        repo_id: str,
        filename: str,
        *,
        revision: Optional[str] = None,
    ):
        """Check whether the specified file exists in a repo.

        Args:
            repo_id (`str`): The repo id to query.
            filename (`str`): The name of the queried file.
            revision (`Optional[str]`): The repo revision to query.

        Returns:
            `True` if the file exists, otherwise `False`.
        """
        files = self.get_model_files(repo_id, revision=revision)
        files = [file['Name'] for file in files]
        return filename in files

    def create_dataset(self,
                       dataset_name: str,
                       namespace: str,

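For reference, the new helper can be called directly on a HubApi instance (the import path is the one used elsewhere in this PR). A minimal sketch, where 'owner/some-model' is a placeholder repo id:

    from modelscope.hub.api import HubApi

    api = HubApi()
    # Returns True only if the file is listed among the repo's model files.
    if api.file_exists('owner/some-model', 'config.json'):
        print('config.json is present in the repo')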
View File

@@ -1,5 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib
import os
from pathlib import Path
from types import MethodType
from typing import Dict, Literal, Optional, Union
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoImageProcessor as AutoImageProcessorHF
@@ -14,10 +18,12 @@ from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BatchFeature as BatchFeatureHF
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers import (PretrainedConfig, PreTrainedModel,
                          PreTrainedTokenizerBase)
from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
from .logger import get_logger
try:
    from transformers import GPTQConfig as GPTQConfigHF
@@ -26,6 +32,8 @@ except ImportError:
    GPTQConfigHF = None
    AwqConfigHF = None

logger = get_logger()
def user_agent(invoked_by=None):
    if invoked_by is None:
@@ -34,6 +42,157 @@ def user_agent(invoked_by=None):
    return uagent

def _try_login(token: Optional[str] = None):
    from modelscope.hub.api import HubApi
    api = HubApi()
    if token is None:
        token = os.environ.get('MODELSCOPE_API_TOKEN')
    if token:
        api.login(token)
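Because the token falls back to an environment variable, private ModelScope repos can be reached without an explicit login call. A minimal sketch; the token value is a placeholder:

    import os

    # Placeholder token; with this set, _try_login() logs in automatically before
    # any of the patched download or query helpers run.
    os.environ['MODELSCOPE_API_TOKEN'] = '<your-modelscope-sdk-token>'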

def _file_exists(
    self,
    repo_id: str,
    filename: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    token: Union[str, bool, None] = None,
):
    """Patch huggingface_hub.file_exists to query the ModelScope hub."""
    if repo_type is not None:
        logger.warning(
            'The passed-in repo_type will not be used in ModelScope; '
            'only model repos can be queried.')
    _try_login(token)
    from modelscope.hub.api import HubApi
    api = HubApi()
    return api.file_exists(repo_id, filename, revision=revision)

def _file_download(repo_id: str,
                   filename: str,
                   *,
                   subfolder: Optional[str] = None,
                   repo_type: Optional[str] = None,
                   revision: Optional[str] = None,
                   cache_dir: Union[str, Path, None] = None,
                   local_dir: Union[str, Path, None] = None,
                   token: Union[bool, str, None] = None,
                   local_files_only: bool = False,
                   **kwargs):
    """Patch huggingface_hub.hf_hub_download to download from the ModelScope hub."""
    if len(kwargs) > 0:
        logger.warning(
            'The passed-in library_name, library_version, user_agent, '
            'force_download, proxies, etag_timeout, headers and endpoint '
            'arguments will not be used in ModelScope.')
    assert repo_type in (
        None, 'model',
        'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
    if repo_type in (None, 'model'):
        from modelscope.hub.file_download import model_file_download as file_download
    else:
        from modelscope.hub.file_download import dataset_file_download as file_download
    _try_login(token)
    return file_download(
        repo_id,
        file_path=os.path.join(subfolder, filename) if subfolder else filename,
        cache_dir=cache_dir,
        local_dir=local_dir,
        local_files_only=local_files_only,
        revision=revision)
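After patch_hub() (defined further down) installs this function, the familiar huggingface_hub entry point fetches single files from the ModelScope hub instead. A sketch with a placeholder repo id; the patch_hub import path is an assumption, not shown in this diff:

    import huggingface_hub
    from modelscope.utils.hf_util import patch_hub  # assumed import path

    patch_hub()
    # Resolved through ModelScope's model_file_download; returns a local file path.
    local_path = huggingface_hub.hf_hub_download('owner/some-model', 'config.json')
    print(local_path)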

def _patch_pretrained_class():

    def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
                      **kwargs):
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern)
        else:
            model_dir = pretrained_model_name_or_path
        return model_dir

    def patch_tokenizer_base():
        """Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to the ModelScope hub."""
        ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      ignore_file_pattern, **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        PreTrainedTokenizerBase.from_pretrained = from_pretrained

    def patch_config_base():
        """Monkey patch PretrainedConfig.from_pretrained to adapt to the ModelScope hub."""
        ori_from_pretrained = PretrainedConfig.from_pretrained.__func__
        ori_get_config_dict = PretrainedConfig.get_config_dict.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      ignore_file_pattern, **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        @classmethod
        def get_config_dict(cls, pretrained_model_name_or_path, **kwargs):
            ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      ignore_file_pattern, **kwargs)
            return ori_get_config_dict(cls, model_dir, **kwargs)

        PretrainedConfig.from_pretrained = from_pretrained
        PretrainedConfig.get_config_dict = get_config_dict

    def patch_model_base():
        """Monkey patch PreTrainedModel.from_pretrained to adapt to the ModelScope hub."""
        ori_from_pretrained = PreTrainedModel.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            model_dir = get_model_dir(pretrained_model_name_or_path, None,
                                      **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        PreTrainedModel.from_pretrained = from_pretrained

    patch_tokenizer_base()
    patch_config_base()
    patch_model_base()
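With these monkey patches in place (they are applied by patch_hub() below), the stock transformers from_pretrained entry points accept ModelScope model ids and resolve them through snapshot_download. A sketch with a placeholder model id; the patch_hub import path is an assumption:

    from transformers import AutoConfig, AutoTokenizer
    from modelscope.utils.hf_util import patch_hub  # assumed import path

    patch_hub()
    # The config/tokenizer patches pass ignore_file_pattern for *.bin and *.safetensors,
    # so these two calls skip the model weights entirely.
    config = AutoConfig.from_pretrained('owner/some-model')
    tokenizer = AutoTokenizer.from_pretrained('owner/some-model')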

def patch_hub():
    """Patch the Hugging Face hub so that models are downloaded from the
    ModelScope hub instead, which speeds up downloads for ModelScope users.
    """
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub.hf_api import api

    huggingface_hub.hf_hub_download = _file_download
    huggingface_hub.file_download.hf_hub_download = _file_download

    hf_api.file_exists = MethodType(_file_exists, api)
    huggingface_hub.file_exists = hf_api.file_exists
    huggingface_hub.hf_api.file_exists = hf_api.file_exists

    _patch_pretrained_class()
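The hub-level query helpers are rerouted as well, so after a single patch_hub() call an existence check goes through HubApi.file_exists on ModelScope. A sketch with a placeholder repo id; the patch_hub import path is an assumption:

    import huggingface_hub
    from modelscope.utils.hf_util import patch_hub  # assumed import path

    patch_hub()
    # Routed to HubApi.file_exists; a passed-in repo_type is ignored with a warning.
    print(huggingface_hub.file_exists('owner/some-model', 'config.json'))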

def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
    """Get a custom wrapper class for auto classes to download the models from the ModelScope hub

    Args:

View File

@@ -1,6 +1,6 @@
addict
attrs
datasets>=2.18.0
datasets>=2.18.0,<3.0.0
einops
oss2
Pillow

View File

@@ -1,6 +1,6 @@
addict
attrs
datasets>=2.18.0
datasets>=2.18.0,<3.0.0
einops
oss2
Pillow