diff --git a/modelscope/utils/hf_util/patcher.py b/modelscope/utils/hf_util/patcher.py
index cc874683..a51f8911 100644
--- a/modelscope/utils/hf_util/patcher.py
+++ b/modelscope/utils/hf_util/patcher.py
@@ -78,9 +78,10 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                       **kwargs):
         from modelscope import snapshot_download
         if not os.path.exists(pretrained_model_name_or_path):
+            revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
-                revision=kwargs.pop('revision', None),
+                revision=revision,
                 ignore_file_pattern=ignore_file_pattern,
                 allow_file_pattern=allow_file_pattern)
         else:
@@ -88,23 +89,33 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         return model_dir

     ignore_file_pattern = [
-        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
+        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5',
+        r'\w+\.ckpt'
     ]

     def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
                                             *model_args, **kwargs):
         """Patch all from_pretrained/get_config_dict"""
-        model_dir = get_model_dir(pretrained_model_name_or_path, **kwargs)
+        model_dir = get_model_dir(pretrained_model_name_or_path,
+                                  kwargs.pop('ignore_file_pattern', None),
+                                  kwargs.pop('allow_file_pattern', None),
+                                  **kwargs)
         return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs)

     def patch_peft_model_id(model, model_id, *model_args, **kwargs):
         """Patch all peft.from_pretrained"""
-        model_dir = get_model_dir(model_id, **kwargs)
+        model_dir = get_model_dir(model_id,
+                                  kwargs.pop('ignore_file_pattern', None),
+                                  kwargs.pop('allow_file_pattern', None),
+                                  **kwargs)
         return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs)

     def _get_peft_type(model_id, **kwargs):
         """Patch all _get_peft_type"""
-        model_dir = get_model_dir(model_id, **kwargs)
+        model_dir = get_model_dir(model_id,
+                                  kwargs.pop('ignore_file_pattern', None),
+                                  kwargs.pop('allow_file_pattern', None),
+                                  **kwargs)
         return kwargs.pop('ori_func')(model_dir, **kwargs)

     def get_wrapped_class(
@@ -156,15 +167,22 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             @classmethod
             def _get_peft_type(cls, model_id, **kwargs):
-                model_dir = get_model_dir(model_id, **kwargs)
+                model_dir = get_model_dir(
+                    model_id,
+                    ignore_file_pattern=ignore_file_pattern,
+                    allow_file_pattern=allow_file_pattern,
+                    **kwargs)

                 module_obj = module_class._get_peft_type(model_dir, **kwargs)
                 return module_obj

             @classmethod
             def get_config_dict(cls, pretrained_model_name_or_path,
                                 *model_args, **kwargs):
-                model_dir = get_model_dir(pretrained_model_name_or_path,
-                                          **kwargs)
+                model_dir = get_model_dir(
+                    pretrained_model_name_or_path,
+                    ignore_file_pattern=ignore_file_pattern,
+                    allow_file_pattern=allow_file_pattern,
+                    **kwargs)

                 module_obj = module_class.get_config_dict(
                     model_dir, *model_args, **kwargs)
@@ -194,7 +212,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             continue
         name = var.__name__
         need_model = 'model' in name.lower() or 'processor' in name.lower(
-        ) or 'extractor' in name.lower()
+        ) or 'extractor' in name.lower() or 'pipeline' in name.lower()
         if need_model:
             ignore_file_pattern_kwargs = {}
         else:
@@ -216,7 +234,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                     all_available_modules.append(var)
                 else:
                     all_available_modules.append(
-                        get_wrapped_class(var, ignore_file_pattern))
+                        get_wrapped_class(var, **ignore_file_pattern_kwargs))
             except Exception:
                 all_available_modules.append(var)
         else:
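For reference, the behavioral core of the patched path above can be read as the standalone sketch below: a ModelScope model id that does not exist on disk is resolved to a local snapshot via snapshot_download, with weight files skipped when only lightweight artifacts (config, tokenizer, remote code) are needed. The helper name resolve_model_dir is hypothetical; the keyword arguments mirror the hunks above.

import os

from modelscope import snapshot_download


def resolve_model_dir(pretrained_model_name_or_path, **kwargs):
    """Hypothetical standalone equivalent of the patched get_model_dir path."""
    if not os.path.exists(pretrained_model_name_or_path):
        # Resolve a hub id to a local directory; the ignore patterns match the
        # updated default list, so *.bin/*.safetensors/*.pth/*.pt/*.h5/*.ckpt
        # weights are not downloaded for classes that do not need them.
        return snapshot_download(
            pretrained_model_name_or_path,
            revision=kwargs.pop('revision', None),
            ignore_file_pattern=[
                r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt',
                r'\w+\.h5', r'\w+\.ckpt'
            ],
            allow_file_pattern=None)
    # Local paths are passed through unchanged.
    return pretrained_model_name_or_path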
diff --git a/tests/utils/test_hf_util.py b/tests/utils/test_hf_util.py
index 6b5d39ed..87650c5c 100644
--- a/tests/utils/test_hf_util.py
+++ b/tests/utils/test_hf_util.py
@@ -154,9 +154,10 @@ class HFUtilTest(unittest.TestCase):
         with patch_context():
             from transformers import AutoModelForCausalLM
             from peft import PeftModel
-            model = AutoModelForCausalLM.from_pretrained('OpenBMB/MiniCPM3-4B')
-            model = PeftModel.from_pretrained(model,
-                                              'OpenBMB/MiniCPM3-RAG-LoRA')
+            model = AutoModelForCausalLM.from_pretrained(
+                'OpenBMB/MiniCPM3-4B', trust_remote_code=True)
+            model = PeftModel.from_pretrained(
+                model, 'OpenBMB/MiniCPM3-RAG-LoRA', trust_remote_code=True)
             self.assertTrue(model is not None)
             self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin'))
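The test now passes trust_remote_code=True because MiniCPM3 ships custom modeling code on the hub, and transformers refuses to execute repository code without that flag. A minimal sketch of the config-only path under the patch; the top-level import of patch_context from modelscope is an assumption based on this test file.

from modelscope import patch_context  # assumed top-level export

with patch_context():
    from transformers import AutoConfig

    # AutoConfig is not a model/processor/extractor/pipeline class, so the
    # wrapper applies the default ignore_file_pattern and only lightweight
    # files (config, tokenizer, remote code) are fetched from the hub.
    config = AutoConfig.from_pretrained(
        'OpenBMB/MiniCPM3-4B', trust_remote_code=True)
    print(config.model_type)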