add monkey patch for base tokenizer and base model

Author: wenmeng.zwm
Date:   2023-07-19 21:04:36 +08:00
Parent: a4f46cb379
Commit: 97d22ade76

2 changed files with 67 additions and 2 deletions


@@ -8,6 +8,7 @@ from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
 from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
 from transformers import AutoTokenizer as AutoTokenizerHF
 from transformers import GenerationConfig as GenerationConfigHF
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from modelscope import snapshot_download
 from modelscope.utils.constant import Invoke
@@ -20,16 +21,66 @@ def user_agent(invoked_by=None):
     return uagent
 
 
+def patch_tokenizer_base():
+    """ Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to modelscope hub.
+    """
+    ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
+                        **kwargs):
+        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
+        if not os.path.exists(pretrained_model_name_or_path):
+            revision = kwargs.pop('revision', None)
+            model_dir = snapshot_download(
+                pretrained_model_name_or_path,
+                revision=revision,
+                ignore_file_pattern=ignore_file_pattern)
+        else:
+            model_dir = pretrained_model_name_or_path
+        return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
+
+    PreTrainedTokenizerBase.from_pretrained = from_pretrained
+
+
+def patch_model_base():
+    """ Monkey patch PreTrainedModel.from_pretrained to adapt to modelscope hub.
+    """
+    ori_from_pretrained = PreTrainedModel.from_pretrained.__func__
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
+                        **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
+        if not os.path.exists(pretrained_model_name_or_path):
+            revision = kwargs.pop('revision', None)
+            model_dir = snapshot_download(
+                pretrained_model_name_or_path,
+                revision=revision,
+                ignore_file_pattern=ignore_file_pattern)
+        else:
+            model_dir = pretrained_model_name_or_path
+        return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
+
+    PreTrainedModel.from_pretrained = from_pretrained
+
+
+patch_tokenizer_base()
+patch_model_base()
+
+
 class AutoModel(AutoModelHF):
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
                 revision=revision,
+                ignore_file_pattern=ignore_file_pattern,
                 user_agent=user_agent())
         else:
             model_dir = pretrained_model_name_or_path
@@ -42,11 +93,13 @@ class AutoModelForCausalLM(AutoModelForCausalLMHF):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
                 revision=revision,
+                ignore_file_pattern=ignore_file_pattern,
                 user_agent=user_agent())
         else:
             model_dir = pretrained_model_name_or_path
@@ -59,11 +112,13 @@ class AutoModelForSeq2SeqLM(AutoModelForSeq2SeqLMHF):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
                 revision=revision,
+                ignore_file_pattern=ignore_file_pattern,
                 user_agent=user_agent())
         else:
             model_dir = pretrained_model_name_or_path
@@ -76,7 +131,7 @@ class AutoTokenizer(AutoTokenizerHF):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
-        ignore_file_pattern = [r'\w+\.bin']
+        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
            revision = kwargs.pop('revision', None)
            model_dir = snapshot_download(
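
Both patch functions above follow the same pattern: recover the original classmethod's underlying function through __func__, define a replacement @classmethod that resolves a model ID to a local directory (downloading from the ModelScope hub when the path does not exist locally), and rebind the replacement on the transformers base class. The tokenizer patch skips downloading both weight formats (\w+\.bin and \w+\.safetensors), since tokenizer files are all it needs, while the model patch skips only safetensors. Below is a minimal, self-contained sketch of the rebinding pattern; Base and resolve() are hypothetical stand-ins for PreTrainedTokenizerBase/PreTrainedModel and snapshot_download():

    import os


    class Base:

        @classmethod
        def from_pretrained(cls, path, **kwargs):
            print(f'{cls.__name__} loading from {path}')
            return cls()


    def resolve(model_id):
        # Pretend download: map a hub model ID to a local cache directory.
        return '/tmp/cache/' + model_id.replace('/', '--')


    def patch_base():
        # A classmethod accessed on its class is a bound method; __func__
        # recovers the plain function so cls can be passed back explicitly.
        ori_from_pretrained = Base.from_pretrained.__func__

        @classmethod
        def from_pretrained(cls, path, **kwargs):
            model_dir = path if os.path.exists(path) else resolve(path)
            return ori_from_pretrained(cls, model_dir, **kwargs)

        # Rebinding on Base also redirects every subclass that does not
        # override from_pretrained, which is what makes the patch global.
        Base.from_pretrained = from_pretrained


    patch_base()
    Base.from_pretrained('damo/some-model')  # prints the resolved cache path

Because patch_tokenizer_base() and patch_model_base() are called at module import time, merely importing modelscope redirects the plain transformers loaders; the test added in the second changed file, below, verifies exactly that.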


@@ -2,6 +2,8 @@
 import unittest
 
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
 from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
                         AutoTokenizer, GenerationConfig)
@@ -42,6 +44,14 @@ class HFUtilTest(unittest.TestCase):
             revision='v1.0.3')
         self.assertEqual(gen_config.assistant_token_id, 196)
 
+    def test_transformer_patch(self):
+        tokenizer = LlamaTokenizer.from_pretrained(
+            'skyline2006/llama-7b', revision='v1.0.1')
+        self.assertIsNotNone(tokenizer)
+        model = LlamaForCausalLM.from_pretrained(
+            'skyline2006/llama-7b', revision='v1.0.1')
+        self.assertIsNotNone(model)
+
 
 if __name__ == '__main__':
     unittest.main()
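
For reference, a usage sketch of what test_transformer_patch verifies, assuming the patches are applied by importing modelscope first (the patch calls run at import time); the model ID and revision are taken from the test above:

    import modelscope  # noqa: F401 -- importing applies the monkey patches

    from transformers import LlamaForCausalLM, LlamaTokenizer

    # 'skyline2006/llama-7b' is not a local path, so the patched
    # from_pretrained resolves it via snapshot_download() on the ModelScope
    # hub and hands the resulting local directory to transformers.
    tokenizer = LlamaTokenizer.from_pretrained(
        'skyline2006/llama-7b', revision='v1.0.1')
    model = LlamaForCausalLM.from_pretrained(
        'skyline2006/llama-7b', revision='v1.0.1')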