mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
patch hf hub (#987)
This commit is contained in:
@@ -661,6 +661,26 @@ class HubApi:
|
||||
files.append(file)
|
||||
return files
|
||||
|
||||
def file_exists(
    self,
    repo_id: str,
    filename: str,
    *,
    revision: Optional[str] = None,
):
    """Check whether a given file is present in a model repo.

    Args:
        repo_id (`str`): The repo id to use
        filename (`str`): The queried filename
        revision (`Optional[str]`): The repo revision

    Returns:
        The query result in bool value
    """
    # Ask the hub for the repo's file listing and scan it for the name.
    repo_files = self.get_model_files(repo_id, revision=revision)
    return any(entry['Name'] == filename for entry in repo_files)
|
||||
|
||||
def create_dataset(self,
|
||||
dataset_name: str,
|
||||
namespace: str,
|
||||
|
||||
@@ -1,5 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import importlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
|
||||
from transformers import AutoConfig as AutoConfigHF
|
||||
from transformers import AutoImageProcessor as AutoImageProcessorHF
|
||||
@@ -14,10 +18,12 @@ from transformers import AutoTokenizer as AutoTokenizerHF
|
||||
from transformers import BatchFeature as BatchFeatureHF
|
||||
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
|
||||
from transformers import GenerationConfig as GenerationConfigHF
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
from transformers import (PretrainedConfig, PreTrainedModel,
|
||||
PreTrainedTokenizerBase)
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
|
||||
from .logger import get_logger
|
||||
|
||||
try:
|
||||
from transformers import GPTQConfig as GPTQConfigHF
|
||||
@@ -26,6 +32,8 @@ except ImportError:
|
||||
GPTQConfigHF = None
|
||||
AwqConfigHF = None
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def user_agent(invoked_by=None):
|
||||
if invoked_by is None:
|
||||
@@ -34,6 +42,157 @@ def user_agent(invoked_by=None):
|
||||
return uagent
|
||||
|
||||
|
||||
def _try_login(token: Optional[str] = None):
    """Best-effort login to the ModelScope hub.

    Uses the explicitly passed ``token``; when it is None, falls back to the
    ``MODELSCOPE_API_TOKEN`` environment variable. No login call is made when
    neither yields a truthy value.
    """
    from modelscope.hub.api import HubApi
    api = HubApi()
    effective_token = token
    if effective_token is None:
        effective_token = os.environ.get('MODELSCOPE_API_TOKEN')
    if effective_token:
        api.login(effective_token)
|
||||
|
||||
|
||||
def _file_exists(
    self,
    repo_id: str,
    filename: str,
    *,
    repo_type: Optional[str] = None,
    revision: Optional[str] = None,
    token: Union[str, bool, None] = None,
):
    """Patch huggingface_hub.file_exists

    Resolves the query against the ModelScope hub instead of the HF hub.
    ``repo_type`` is accepted for signature compatibility but ignored.
    """
    # repo_type cannot be honored here: only model repos are queryable.
    if repo_type is not None:
        logger.warning(
            'The passed in repo_type will not be used in modelscope. Now only model repo can be queried.'
        )
    _try_login(token)
    from modelscope.hub.api import HubApi
    return HubApi().file_exists(repo_id, filename, revision=revision)
|
||||
|
||||
|
||||
def _file_download(repo_id: str,
                   filename: str,
                   *,
                   subfolder: Optional[str] = None,
                   repo_type: Optional[str] = None,
                   revision: Optional[str] = None,
                   cache_dir: Union[str, Path, None] = None,
                   local_dir: Union[str, Path, None] = None,
                   token: Union[bool, str, None] = None,
                   local_files_only: bool = False,
                   **kwargs):
    """Patch huggingface_hub.hf_hub_download

    Downloads a single file from the ModelScope hub, dispatching to the
    model or dataset download helper according to ``repo_type``.
    Extra HF-specific keyword arguments are accepted but ignored.
    """
    if kwargs:
        logger.warning(
            'The passed in library_name,library_version,user_agent,force_download,proxies'
            'etag_timeout,headers,endpoint '
            'will not be used in modelscope.')
    assert repo_type in (
        None, 'model',
        'dataset'), f'repo_type={repo_type} is not supported in ModelScope'
    # None is treated as a model repo; only 'dataset' selects the dataset path.
    if repo_type == 'dataset':
        from modelscope.hub.file_download import dataset_file_download as file_download
    else:
        from modelscope.hub.file_download import model_file_download as file_download
    _try_login(token)
    file_path = os.path.join(subfolder, filename) if subfolder else filename
    return file_download(
        repo_id,
        file_path=file_path,
        cache_dir=cache_dir,
        local_dir=local_dir,
        local_files_only=local_files_only,
        revision=revision)
|
||||
|
||||
|
||||
def _patch_pretrained_class():
    """Patch the transformers ``from_pretrained`` entry points so model ids
    are resolved against the ModelScope hub.

    When ``pretrained_model_name_or_path`` is not an existing local path, the
    repo is first fetched with ``snapshot_download`` and the resulting local
    directory is forwarded to the original transformers implementation.
    """
    # Weight checkpoints are skipped where only tokenizer/config assets are
    # needed, avoiding large unnecessary downloads.
    weights_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']

    def get_model_dir(pretrained_model_name_or_path, ignore_file_pattern,
                      **kwargs):
        """Resolve a model id to a local directory, downloading on demand."""
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern)
        else:
            model_dir = pretrained_model_name_or_path
        return model_dir

    def patch_tokenizer_base():
        """ Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to modelscope hub.
        """
        ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      weights_file_pattern, **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        PreTrainedTokenizerBase.from_pretrained = from_pretrained

    def patch_config_base():
        """ Monkey patch PretrainedConfig.from_pretrained to adapt to modelscope hub.
        """
        ori_from_pretrained = PretrainedConfig.from_pretrained.__func__
        ori_get_config_dict = PretrainedConfig.get_config_dict.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      weights_file_pattern, **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        @classmethod
        def get_config_dict(cls, pretrained_model_name_or_path, **kwargs):
            model_dir = get_model_dir(pretrained_model_name_or_path,
                                      weights_file_pattern, **kwargs)
            return ori_get_config_dict(cls, model_dir, **kwargs)

        # Fix: the patched from_pretrained was defined but never installed,
        # leaving it dead code — install it for consistency with the
        # tokenizer/model patchers. Safe even though from_pretrained calls
        # get_config_dict internally: get_config_dict then receives an
        # already-resolved local directory and skips the download branch.
        PretrainedConfig.from_pretrained = from_pretrained
        PretrainedConfig.get_config_dict = get_config_dict

    def patch_model_base():
        """ Monkey patch PreTrainedModel.from_pretrained to adapt to modelscope hub.
        """
        ori_from_pretrained = PreTrainedModel.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            # Weights are required for model loading, so nothing is ignored.
            model_dir = get_model_dir(pretrained_model_name_or_path, None,
                                      **kwargs)
            return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

        PreTrainedModel.from_pretrained = from_pretrained

    patch_tokenizer_base()
    patch_config_base()
    patch_model_base()
|
||||
|
||||
|
||||
def patch_hub():
    """Redirect huggingface hub entry points to ModelScope so that model
    downloads and file queries are served from the ModelScope hub.
    """
    import huggingface_hub
    from huggingface_hub import hf_api
    from huggingface_hub.hf_api import api

    # Replace the download entry point everywhere it is re-exported.
    for holder in (huggingface_hub, huggingface_hub.file_download):
        holder.hf_hub_download = _file_download

    # Bind our existence check to the global HfApi instance and install it
    # at every public access path.
    patched_file_exists = MethodType(_file_exists, api)
    hf_api.file_exists = patched_file_exists
    huggingface_hub.file_exists = patched_file_exists
    huggingface_hub.hf_api.file_exists = patched_file_exists

    _patch_pretrained_class()
|
||||
|
||||
|
||||
def get_wrapped_class(module_class, ignore_file_pattern=[], **kwargs):
|
||||
"""Get a custom wrapper class for auto classes to download the models from the ModelScope hub
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=2.18.0
|
||||
datasets>=2.18.0,<3.0.0
|
||||
einops
|
||||
oss2
|
||||
Pillow
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=2.18.0
|
||||
datasets>=2.18.0,<3.0.0
|
||||
einops
|
||||
oss2
|
||||
Pillow
|
||||
|
||||
Reference in New Issue
Block a user