tastelikefeet
2025-01-26 18:31:41 +08:00
parent 22b7b25f44
commit d0cccf64af
2 changed files with 32 additions and 13 deletions

View File

@@ -78,9 +78,10 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                       **kwargs):
         from modelscope import snapshot_download
         if not os.path.exists(pretrained_model_name_or_path):
+            revision = kwargs.pop('revision', None)
             model_dir = snapshot_download(
                 pretrained_model_name_or_path,
-                revision=kwargs.pop('revision', None),
+                revision=revision,
                 ignore_file_pattern=ignore_file_pattern,
                 allow_file_pattern=allow_file_pattern)
         else:
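Read in isolation, the hunk above just hoists the `revision` pop out of the call site. For orientation, here is a standalone sketch of the resulting resolve flow; names follow the diff, but treat it as an illustration rather than the module's exact code:

```python
import os

def resolve_model_dir(pretrained_model_name_or_path,
                      ignore_file_pattern=None,
                      allow_file_pattern=None,
                      **kwargs):
    # Illustrative standalone version of the patched helper: a local
    # directory passes through untouched, anything else is treated as a
    # ModelScope hub ID and downloaded.
    from modelscope import snapshot_download
    if not os.path.exists(pretrained_model_name_or_path):
        # Pop 'revision' once, up front, so kwargs no longer carries this
        # ModelScope-only key into any downstream call.
        revision = kwargs.pop('revision', None)
        return snapshot_download(
            pretrained_model_name_or_path,
            revision=revision,
            ignore_file_pattern=ignore_file_pattern,
            allow_file_pattern=allow_file_pattern)
    # Local checkout: use the path as-is, no download needed.
    return pretrained_model_name_or_path
```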
@@ -88,23 +89,33 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         return model_dir

     ignore_file_pattern = [
-        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
+        r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5',
+        r'\w+\.ckpt'
     ]

     def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
                                             *model_args, **kwargs):
         """Patch all from_pretrained/get_config_dict"""
-        model_dir = get_model_dir(pretrained_model_name_or_path, **kwargs)
+        model_dir = get_model_dir(pretrained_model_name_or_path,
+                                  kwargs.pop('ignore_file_pattern', None),
+                                  kwargs.pop('allow_file_pattern', None),
+                                  **kwargs)
         return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs)

     def patch_peft_model_id(model, model_id, *model_args, **kwargs):
         """Patch all peft.from_pretrained"""
-        model_dir = get_model_dir(model_id, **kwargs)
+        model_dir = get_model_dir(model_id,
+                                  kwargs.pop('ignore_file_pattern', None),
+                                  kwargs.pop('allow_file_pattern', None),
+                                  **kwargs)
         return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs)

     def _get_peft_type(model_id, **kwargs):
         """Patch all _get_peft_type"""
-        model_dir = get_model_dir(model_id, **kwargs)
+        model_dir = get_model_dir(model_id,
+                                  kwargs.pop('ignore_file_pattern', None),
+                                  kwargs.pop('allow_file_pattern', None),
+                                  **kwargs)
         return kwargs.pop('ori_func')(model_dir, **kwargs)

     def get_wrapped_class(
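All three wrappers now consume `ignore_file_pattern`/`allow_file_pattern` out of `kwargs` before delegating, so these ModelScope-only keys never reach the original Hugging Face function, which would reject unexpected keyword arguments. A hypothetical factory condensing that pop-before-delegate pattern (`make_patched` and `resolve_dir` are illustrative names, not the module's API):

```python
def make_patched(ori_func, resolve_dir):
    # Wraps an original from_pretrained-style callable so that hub
    # resolution happens first and ModelScope-only keys are consumed here.
    def patched(pretrained_model_name_or_path, *model_args, **kwargs):
        model_dir = resolve_dir(
            pretrained_model_name_or_path,
            kwargs.pop('ignore_file_pattern', None),
            kwargs.pop('allow_file_pattern', None),
            **kwargs)
        # By this point kwargs holds only arguments the original expects.
        return ori_func(model_dir, *model_args, **kwargs)

    return patched
```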
@@ -156,15 +167,22 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             @classmethod
             def _get_peft_type(cls, model_id, **kwargs):
-                model_dir = get_model_dir(model_id, **kwargs)
+                model_dir = get_model_dir(
+                    model_id,
+                    ignore_file_pattern=ignore_file_pattern,
+                    allow_file_pattern=allow_file_pattern,
+                    **kwargs)
                 module_obj = module_class._get_peft_type(model_dir, **kwargs)
                 return module_obj

             @classmethod
             def get_config_dict(cls, pretrained_model_name_or_path,
                                 *model_args, **kwargs):
-                model_dir = get_model_dir(pretrained_model_name_or_path,
-                                          **kwargs)
+                model_dir = get_model_dir(
+                    pretrained_model_name_or_path,
+                    ignore_file_pattern=ignore_file_pattern,
+                    allow_file_pattern=allow_file_pattern,
+                    **kwargs)
                 module_obj = module_class.get_config_dict(
                     model_dir, *model_args, **kwargs)
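The `ignore_file_pattern` entries these methods now forward explicitly are filename regexes, and the new `r'\w+\.ckpt'` entry extends the skip list to one more weight format. A quick self-contained check of what the list matches (assuming regex-search semantics; the hub client may apply the patterns slightly differently):

```python
import re

# The pattern list from this commit, including the new r'\w+\.ckpt' entry.
ignore_file_pattern = [
    r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5',
    r'\w+\.ckpt'
]

def is_ignored(filename):
    # A file is skipped if any pattern matches it.
    return any(re.search(p, filename) for p in ignore_file_pattern)

assert is_ignored('model.ckpt')            # covered by the new entry
assert is_ignored('pytorch_model-00001-of-00002.safetensors')
assert not is_ignored('config.json')       # configs still download
```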
@@ -194,7 +212,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             continue
         name = var.__name__
         need_model = 'model' in name.lower() or 'processor' in name.lower(
-        ) or 'extractor' in name.lower()
+        ) or 'extractor' in name.lower() or 'pipeline' in name.lower()
         if need_model:
             ignore_file_pattern_kwargs = {}
         else:
@@ -216,7 +234,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                     all_available_modules.append(var)
                 else:
                     all_available_modules.append(
-                        get_wrapped_class(var, ignore_file_pattern))
+                        get_wrapped_class(var, **ignore_file_pattern_kwargs))
             except Exception:
                 all_available_modules.append(var)
         else:
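Taken together, the last two hunks make the skip-weights decision data-driven: classes whose names contain 'model', 'processor', 'extractor', or now 'pipeline' must download weights and therefore get an empty kwargs dict, while everything else (config- and tokenizer-like classes) receives the ignore patterns via `**ignore_file_pattern_kwargs`. A condensed sketch; the else-branch dict is inferred from how it is unpacked into `get_wrapped_class` and may differ in detail from the real module:

```python
weight_patterns = [r'\w+\.bin', r'\w+\.safetensors', r'\w+\.ckpt']

def ignore_kwargs_for(name: str) -> dict:
    lowered = name.lower()
    need_model = ('model' in lowered or 'processor' in lowered
                  or 'extractor' in lowered or 'pipeline' in lowered)
    # Classes that must load weights get no ignore patterns; the rest can
    # safely skip downloading weight files.
    return {} if need_model else {'ignore_file_pattern': weight_patterns}

assert ignore_kwargs_for('AutoModelForCausalLM') == {}
assert ignore_kwargs_for('Pipeline') == {}                # newly matched
assert 'ignore_file_pattern' in ignore_kwargs_for('AutoConfig')
```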

View File

@@ -154,9 +154,10 @@ class HFUtilTest(unittest.TestCase):
         with patch_context():
             from transformers import AutoModelForCausalLM
             from peft import PeftModel
-            model = AutoModelForCausalLM.from_pretrained('OpenBMB/MiniCPM3-4B')
-            model = PeftModel.from_pretrained(model,
-                                              'OpenBMB/MiniCPM3-RAG-LoRA')
+            model = AutoModelForCausalLM.from_pretrained(
+                'OpenBMB/MiniCPM3-4B', trust_remote_code=True)
+            model = PeftModel.from_pretrained(
+                model, 'OpenBMB/MiniCPM3-RAG-LoRA', trust_remote_code=True)
             self.assertTrue(model is not None)
             self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin'))
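The test change adds `trust_remote_code=True`, which transformers requires before executing the custom modeling code that MiniCPM3 ships in its repo. A minimal sketch of the flow the test exercises; the `patch_context` import path is assumed here, as it is not shown in this diff:

```python
# Assumed import path for the context manager used in the test file.
from modelscope.utils.hf_util import patch_context

with patch_context():
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # trust_remote_code=True is required because MiniCPM3 defines its own
    # modeling classes in the repo rather than inside transformers.
    model = AutoModelForCausalLM.from_pretrained(
        'OpenBMB/MiniCPM3-4B', trust_remote_code=True)
    model = PeftModel.from_pretrained(
        model, 'OpenBMB/MiniCPM3-RAG-LoRA', trust_remote_code=True)
```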