mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
Add monkey patches for the base tokenizer and base model
This commit is contained in:
@@ -8,6 +8,7 @@ from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
|
||||
from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
|
||||
from transformers import AutoTokenizer as AutoTokenizerHF
|
||||
from transformers import GenerationConfig as GenerationConfigHF
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizerBase
|
||||
|
||||
from modelscope import snapshot_download
|
||||
from modelscope.utils.constant import Invoke
|
||||
@@ -20,16 +21,66 @@ def user_agent(invoked_by=None):
|
||||
return uagent
|
||||
|
||||
|
||||
class AutoModel(AutoModelHF):
|
||||
def patch_tokenizer_base():
    """Monkey patch ``PreTrainedTokenizerBase.from_pretrained`` so that
    model ids are resolved through the modelscope hub.

    Local filesystem paths are passed through untouched; anything else is
    treated as a modelscope model id and downloaded first.
    """
    original_impl = PreTrainedTokenizerBase.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Tokenizer loading never needs the weight files — skip them.
        skip_patterns = [r'\w+\.bin', r'\w+\.safetensors']
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=skip_patterns)
        return original_impl(cls, model_dir, *model_args, **kwargs)

    PreTrainedTokenizerBase.from_pretrained = from_pretrained
|
||||
|
||||
|
||||
def patch_model_base():
    """Monkey patch ``PreTrainedModel.from_pretrained`` so that model ids
    are resolved through the modelscope hub.

    Local filesystem paths are passed through untouched; anything else is
    treated as a modelscope model id and downloaded first.
    """
    original_impl = PreTrainedModel.from_pretrained.__func__

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                        **kwargs):
        # Prefer the .bin weights; do not download the safetensors copies.
        skip_patterns = [r'\w+\.safetensors']
        if os.path.exists(pretrained_model_name_or_path):
            model_dir = pretrained_model_name_or_path
        else:
            model_dir = snapshot_download(
                pretrained_model_name_or_path,
                revision=kwargs.pop('revision', None),
                ignore_file_pattern=skip_patterns)
        return original_impl(cls, model_dir, *model_args, **kwargs)

    PreTrainedModel.from_pretrained = from_pretrained
|
||||
|
||||
|
||||
# Apply the patches at import time so that plain ``transformers`` classes
# (not just the modelscope wrappers below) resolve ids via the modelscope hub.
patch_tokenizer_base()
patch_model_base()
|
||||
|
||||
|
||||
class AutoModel(AutoModelHF):
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [r'\w+\.safetensors']
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', None)
|
||||
model_dir = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
user_agent=user_agent())
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
@@ -42,11 +93,13 @@ class AutoModelForCausalLM(AutoModelForCausalLMHF):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [r'\w+\.safetensors']
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', None)
|
||||
model_dir = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
user_agent=user_agent())
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
@@ -59,11 +112,13 @@ class AutoModelForSeq2SeqLM(AutoModelForSeq2SeqLMHF):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [r'\w+\.safetensors']
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', None)
|
||||
model_dir = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_file_pattern,
|
||||
user_agent=user_agent())
|
||||
else:
|
||||
model_dir = pretrained_model_name_or_path
|
||||
@@ -76,7 +131,7 @@ class AutoTokenizer(AutoTokenizerHF):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
|
||||
**kwargs):
|
||||
ignore_file_pattern = [r'\w+\.bin']
|
||||
ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
|
||||
if not os.path.exists(pretrained_model_name_or_path):
|
||||
revision = kwargs.pop('revision', None)
|
||||
model_dir = snapshot_download(
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import LlamaForCausalLM, LlamaTokenizer
|
||||
|
||||
from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
|
||||
AutoTokenizer, GenerationConfig)
|
||||
|
||||
@@ -42,6 +44,14 @@ class HFUtilTest(unittest.TestCase):
|
||||
revision='v1.0.3')
|
||||
self.assertEqual(gen_config.assistant_token_id, 196)
|
||||
|
||||
def test_transformer_patch(self):
    """Loading native transformers classes by modelscope id should work
    once the base-class patches are in effect."""
    patched_tokenizer = LlamaTokenizer.from_pretrained(
        'skyline2006/llama-7b', revision='v1.0.1')
    self.assertIsNotNone(patched_tokenizer)
    patched_model = LlamaForCausalLM.from_pretrained(
        'skyline2006/llama-7b', revision='v1.0.1')
    self.assertIsNotNone(patched_model)
|
||||
|
||||
|
||||
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user