# modelscope/modelscope/utils/hf_util.py
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import sys

from transformers import CONFIG_MAPPING
from transformers import AutoConfig as AutoConfigHF
from transformers import AutoModel as AutoModelHF
from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
from transformers import \
    AutoModelForSequenceClassification as AutoModelForSequenceClassificationHF
from transformers import \
    AutoModelForTokenClassification as AutoModelForTokenClassificationHF
from transformers import AutoTokenizer as AutoTokenizerHF
from transformers import BitsAndBytesConfig as BitsAndBytesConfigHF
from transformers import GenerationConfig as GenerationConfigHF
from transformers import (PretrainedConfig, PreTrainedModel,
                          PreTrainedTokenizerBase)
from transformers.models.auto.tokenization_auto import (
    TOKENIZER_MAPPING_NAMES, get_tokenizer_config)

from modelscope import snapshot_download
from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke
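
# GPTQConfig exists only in recent transformers releases; degrade gracefully
# so this module still imports against older versions.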
try:
    from transformers import GPTQConfig as GPTQConfigHF
except ImportError:
    GPTQConfigHF = None


def user_agent(invoked_by=None):
    """Build the user-agent string sent along with hub downloads."""
    if invoked_by is None:
        invoked_by = Invoke.PRETRAINED
    uagent = '%s/%s' % (Invoke.KEY, invoked_by)
    return uagent
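

# For example, user_agent() yields '%s/%s' % (Invoke.KEY, Invoke.PRETRAINED);
# get_wrapped_class below passes it to snapshot_download so downloads can be
# attributed to transformers-style from_pretrained calls.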


def patch_tokenizer_base():
    """Monkey patch PreTrainedTokenizerBase.from_pretrained so tokenizers
    load from the ModelScope hub."""
    # Grab the plain function behind the classmethod so the original
    # implementation can be called with an explicit cls below.
    ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Tokenizers never need the weight files, so skip those downloads.
        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
        if not os.path.exists(pretrained_model_name_or_path):
            # Not a local path: treat it as a ModelScope model id and
            # download a snapshot of the repo.
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern)
        else:
            model_dir = pretrained_model_name_or_path
        return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

    PreTrainedTokenizerBase.from_pretrained = from_pretrained


def patch_model_base():
    """Monkey patch PreTrainedModel.from_pretrained so models load from
    the ModelScope hub."""
    # Same rebinding pattern as patch_tokenizer_base above.
    ori_from_pretrained = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Keep the *.bin weights and skip the redundant safetensors copies.
        ignore_file_pattern = [r'\w+\.safetensors']
        if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=revision,
                ignore_file_pattern=ignore_file_pattern)
        else:
            model_dir = pretrained_model_name_or_path
        return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = from_pretrained


# Applied at import time: importing this module is enough to make every
# transformers from_pretrained call ModelScope-aware.
patch_tokenizer_base()
patch_model_base()
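
# A minimal usage sketch, assuming this module has been imported and the
# (hypothetical) model id exists on the ModelScope hub:
#
#   from transformers import BertTokenizer
#   tokenizer = BertTokenizer.from_pretrained('some-org/some-bert-model')
#
# Because 'some-org/some-bert-model' is not a local path, the patched
# from_pretrained downloads a snapshot first, then loads from that directory.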


def get_wrapped_class(module_class, ignore_file_pattern=None, **kwargs):
    """Get a custom wrapper class for auto classes to download models from
    the ModelScope hub.

    Args:
        module_class: The actual module class to wrap.
        ignore_file_pattern (`str` or `List`, *optional*, defaults to `None`):
            Any file pattern to be ignored when downloading, such as exact
            file names or file extensions.

    Returns:
        The wrapper class.
    """
    default_ignore_file_pattern = ignore_file_pattern

    class ClassWrapper(module_class):

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                            **kwargs):
            # Callers may override the ignore pattern per call.
            ignore_file_pattern = kwargs.pop('ignore_file_pattern',
                                             default_ignore_file_pattern)
            if not os.path.exists(pretrained_model_name_or_path):
                revision = kwargs.pop('revision', DEFAULT_MODEL_REVISION)
                model_dir = snapshot_download(
                    pretrained_model_name_or_path,
                    revision=revision,
                    ignore_file_pattern=ignore_file_pattern,
                    user_agent=user_agent())
            else:
                model_dir = pretrained_model_name_or_path

            module_obj = module_class.from_pretrained(model_dir, *model_args,
                                                      **kwargs)

            # Remember the local snapshot directory on model instances so
            # callers can locate files shipped alongside the weights.
            if module_class.__name__.startswith('AutoModel'):
                module_obj.model_dir = model_dir
            return module_obj

    # Keep the wrapper transparent in reprs, logs and error messages.
    ClassWrapper.__name__ = module_class.__name__
    ClassWrapper.__qualname__ = module_class.__qualname__
    return ClassWrapper
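

# The wrappers below mirror the transformers auto classes one-to-one, so
# switching an import from transformers to modelscope is the only change
# needed to load from the ModelScope hub.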


AutoModel = get_wrapped_class(
    AutoModelHF, ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForCausalLM = get_wrapped_class(
    AutoModelForCausalLMHF, ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForSeq2SeqLM = get_wrapped_class(
    AutoModelForSeq2SeqLMHF, ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForSequenceClassification = get_wrapped_class(
    AutoModelForSequenceClassificationHF,
    ignore_file_pattern=[r'\w+\.safetensors'])
AutoModelForTokenClassification = get_wrapped_class(
    AutoModelForTokenClassificationHF,
    ignore_file_pattern=[r'\w+\.safetensors'])

AutoTokenizer = get_wrapped_class(
    AutoTokenizerHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
AutoConfig = get_wrapped_class(
    AutoConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])
GenerationConfig = get_wrapped_class(
    GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors'])

# The quantization configs need no hub access, so they are re-exported
# unchanged (GPTQConfig is None on older transformers, see the import above).
GPTQConfig = GPTQConfigHF
BitsAndBytesConfig = BitsAndBytesConfigHF
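
# A minimal end-to-end sketch, assuming a hypothetical causal-LM repo
# 'some-org/some-chat-model' on the ModelScope hub and network access:
#
#   from modelscope import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained('some-org/some-chat-model')
#   model = AutoModelForCausalLM.from_pretrained(
#       'some-org/some-chat-model', trust_remote_code=True)
#   print(model.model_dir)  # local snapshot directory set by the wrapper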