From 97d22ade76d15771564944210fcd8e117a83db4e Mon Sep 17 00:00:00 2001
From: "wenmeng.zwm"
Date: Wed, 19 Jul 2023 21:04:36 +0800
Subject: [PATCH] add monkey patch for base tokenizer and base model

---
 modelscope/utils/hf_util.py | 59 +++++++++++++++++++++++++++++++++++--
 tests/utils/test_hf_util.py | 10 +++++++
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py
index 1893daf1..8e0a9ef1 100644
--- a/modelscope/utils/hf_util.py
+++ b/modelscope/utils/hf_util.py
@@ -8,6 +8,7 @@ from transformers import AutoModelForCausalLM as AutoModelForCausalLMHF
 from transformers import AutoModelForSeq2SeqLM as AutoModelForSeq2SeqLMHF
 from transformers import AutoTokenizer as AutoTokenizerHF
 from transformers import GenerationConfig as GenerationConfigHF
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from modelscope import snapshot_download
 from modelscope.utils.constant import Invoke
@@ -20,16 +21,66 @@ def user_agent(invoked_by=None):
     return uagent
 
 
-class AutoModel(AutoModelHF):
+def patch_tokenizer_base():
+    """ Monkey patch PreTrainedTokenizerBase.from_pretrained to adapt to modelscope hub.
+    """
+    ori_from_pretrained = PreTrainedTokenizerBase.from_pretrained.__func__
 
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
+        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
                 revision=revision,
+                ignore_file_pattern=ignore_file_pattern)
+        else:
+            model_dir = pretrained_model_name_or_path
+        return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
+
+    PreTrainedTokenizerBase.from_pretrained = from_pretrained
+
+
+def patch_model_base():
+    """ Monkey patch PreTrainedModel.from_pretrained to adapt to modelscope hub.
+    """
+    ori_from_pretrained = PreTrainedModel.from_pretrained.__func__
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
+                        **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
+        if not os.path.exists(pretrained_model_name_or_path):
+            revision = kwargs.pop('revision', None)
+            model_dir = snapshot_download(
+                pretrained_model_name_or_path,
+                revision=revision,
+                ignore_file_pattern=ignore_file_pattern)
+        else:
+            model_dir = pretrained_model_name_or_path
+        return ori_from_pretrained(cls, model_dir, *model_args, **kwargs)
+
+    PreTrainedModel.from_pretrained = from_pretrained
+
+
+patch_tokenizer_base()
+patch_model_base()
+
+
+class AutoModel(AutoModelHF):
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
+                        **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
+        if not os.path.exists(pretrained_model_name_or_path):
+            revision = kwargs.pop('revision', None)
+            model_dir = snapshot_download(
+                pretrained_model_name_or_path,
+                revision=revision,
+                ignore_file_pattern=ignore_file_pattern,
                 user_agent=user_agent())
         else:
             model_dir = pretrained_model_name_or_path
@@ -42,11 +93,13 @@ class AutoModelForCausalLM(AutoModelForCausalLMHF):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
                 revision=revision,
+                ignore_file_pattern=ignore_file_pattern,
                 user_agent=user_agent())
         else:
             model_dir = pretrained_model_name_or_path
@@ -59,11 +112,13 @@ class AutoModelForSeq2SeqLM(AutoModelForSeq2SeqLMHF):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
+        ignore_file_pattern = [r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
                 revision=revision,
+                ignore_file_pattern=ignore_file_pattern,
                 user_agent=user_agent())
         else:
             model_dir = pretrained_model_name_or_path
@@ -76,7 +131,7 @@ class AutoTokenizer(AutoTokenizerHF):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args,
                         **kwargs):
-        ignore_file_pattern = [r'\w+\.bin']
+        ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors']
         if not os.path.exists(pretrained_model_name_or_path):
             revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
diff --git a/tests/utils/test_hf_util.py b/tests/utils/test_hf_util.py
index 50ecbaa5..b0279c6a 100644
--- a/tests/utils/test_hf_util.py
+++ b/tests/utils/test_hf_util.py
@@ -2,6 +2,8 @@
 
 import unittest
 
+from transformers import LlamaForCausalLM, LlamaTokenizer
+
 from modelscope import (AutoConfig, AutoModel, AutoModelForCausalLM,
                         AutoTokenizer, GenerationConfig)
 
@@ -42,6 +44,14 @@ class HFUtilTest(unittest.TestCase):
             revision='v1.0.3')
         self.assertEqual(gen_config.assistant_token_id, 196)
 
+    def test_transformer_patch(self):
+        tokenizer = LlamaTokenizer.from_pretrained(
+            'skyline2006/llama-7b', revision='v1.0.1')
+        self.assertIsNotNone(tokenizer)
+        model = LlamaForCausalLM.from_pretrained(
+            'skyline2006/llama-7b', revision='v1.0.1')
+        self.assertIsNotNone(model)
+
 
 if __name__ == '__main__':
     unittest.main()