mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-17 00:37:43 +01:00
fix
This commit is contained in:
@@ -78,9 +78,10 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
**kwargs):
|
**kwargs):
|
||||||
from modelscope import snapshot_download
|
from modelscope import snapshot_download
|
||||||
if not os.path.exists(pretrained_model_name_or_path):
|
if not os.path.exists(pretrained_model_name_or_path):
|
||||||
|
revision = kwargs.pop('revision', None)
|
||||||
model_dir = snapshot_download(
|
model_dir = snapshot_download(
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
revision=kwargs.pop('revision', None),
|
revision=revision,
|
||||||
ignore_file_pattern=ignore_file_pattern,
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
allow_file_pattern=allow_file_pattern)
|
allow_file_pattern=allow_file_pattern)
|
||||||
else:
|
else:
|
||||||
@@ -88,23 +89,33 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
return model_dir
|
return model_dir
|
||||||
|
|
||||||
ignore_file_pattern = [
|
ignore_file_pattern = [
|
||||||
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5'
|
r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt', r'\w+\.h5',
|
||||||
|
r'\w+\.ckpt'
|
||||||
]
|
]
|
||||||
|
|
||||||
def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
|
def patch_pretrained_model_name_or_path(pretrained_model_name_or_path,
|
||||||
*model_args, **kwargs):
|
*model_args, **kwargs):
|
||||||
"""Patch all from_pretrained/get_config_dict"""
|
"""Patch all from_pretrained/get_config_dict"""
|
||||||
model_dir = get_model_dir(pretrained_model_name_or_path, **kwargs)
|
model_dir = get_model_dir(pretrained_model_name_or_path,
|
||||||
|
kwargs.pop('ignore_file_pattern', None),
|
||||||
|
kwargs.pop('allow_file_pattern', None),
|
||||||
|
**kwargs)
|
||||||
return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs)
|
return kwargs.pop('ori_func')(model_dir, *model_args, **kwargs)
|
||||||
|
|
||||||
def patch_peft_model_id(model, model_id, *model_args, **kwargs):
|
def patch_peft_model_id(model, model_id, *model_args, **kwargs):
|
||||||
"""Patch all peft.from_pretrained"""
|
"""Patch all peft.from_pretrained"""
|
||||||
model_dir = get_model_dir(model_id, **kwargs)
|
model_dir = get_model_dir(model_id,
|
||||||
|
kwargs.pop('ignore_file_pattern', None),
|
||||||
|
kwargs.pop('allow_file_pattern', None),
|
||||||
|
**kwargs)
|
||||||
return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs)
|
return kwargs.pop('ori_func')(model, model_dir, *model_args, **kwargs)
|
||||||
|
|
||||||
def _get_peft_type(model_id, **kwargs):
|
def _get_peft_type(model_id, **kwargs):
|
||||||
"""Patch all _get_peft_type"""
|
"""Patch all _get_peft_type"""
|
||||||
model_dir = get_model_dir(model_id, **kwargs)
|
model_dir = get_model_dir(model_id,
|
||||||
|
kwargs.pop('ignore_file_pattern', None),
|
||||||
|
kwargs.pop('allow_file_pattern', None),
|
||||||
|
**kwargs)
|
||||||
return kwargs.pop('ori_func')(model_dir, **kwargs)
|
return kwargs.pop('ori_func')(model_dir, **kwargs)
|
||||||
|
|
||||||
def get_wrapped_class(
|
def get_wrapped_class(
|
||||||
@@ -156,15 +167,22 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_peft_type(cls, model_id, **kwargs):
|
def _get_peft_type(cls, model_id, **kwargs):
|
||||||
model_dir = get_model_dir(model_id, **kwargs)
|
model_dir = get_model_dir(
|
||||||
|
model_id,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
**kwargs)
|
||||||
module_obj = module_class._get_peft_type(model_dir, **kwargs)
|
module_obj = module_class._get_peft_type(model_dir, **kwargs)
|
||||||
return module_obj
|
return module_obj
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_config_dict(cls, pretrained_model_name_or_path,
|
def get_config_dict(cls, pretrained_model_name_or_path,
|
||||||
*model_args, **kwargs):
|
*model_args, **kwargs):
|
||||||
model_dir = get_model_dir(pretrained_model_name_or_path,
|
model_dir = get_model_dir(
|
||||||
**kwargs)
|
pretrained_model_name_or_path,
|
||||||
|
ignore_file_pattern=ignore_file_pattern,
|
||||||
|
allow_file_pattern=allow_file_pattern,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
module_obj = module_class.get_config_dict(
|
module_obj = module_class.get_config_dict(
|
||||||
model_dir, *model_args, **kwargs)
|
model_dir, *model_args, **kwargs)
|
||||||
@@ -194,7 +212,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
continue
|
continue
|
||||||
name = var.__name__
|
name = var.__name__
|
||||||
need_model = 'model' in name.lower() or 'processor' in name.lower(
|
need_model = 'model' in name.lower() or 'processor' in name.lower(
|
||||||
) or 'extractor' in name.lower()
|
) or 'extractor' in name.lower() or 'pipeline' in name.lower()
|
||||||
if need_model:
|
if need_model:
|
||||||
ignore_file_pattern_kwargs = {}
|
ignore_file_pattern_kwargs = {}
|
||||||
else:
|
else:
|
||||||
@@ -216,7 +234,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
all_available_modules.append(var)
|
all_available_modules.append(var)
|
||||||
else:
|
else:
|
||||||
all_available_modules.append(
|
all_available_modules.append(
|
||||||
get_wrapped_class(var, ignore_file_pattern))
|
get_wrapped_class(var, **ignore_file_pattern_kwargs))
|
||||||
except Exception:
|
except Exception:
|
||||||
all_available_modules.append(var)
|
all_available_modules.append(var)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -154,9 +154,10 @@ class HFUtilTest(unittest.TestCase):
|
|||||||
with patch_context():
|
with patch_context():
|
||||||
from transformers import AutoModelForCausalLM
|
from transformers import AutoModelForCausalLM
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
model = AutoModelForCausalLM.from_pretrained('OpenBMB/MiniCPM3-4B')
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
model = PeftModel.from_pretrained(model,
|
'OpenBMB/MiniCPM3-4B', trust_remote_code=True)
|
||||||
'OpenBMB/MiniCPM3-RAG-LoRA')
|
model = PeftModel.from_pretrained(
|
||||||
|
model, 'OpenBMB/MiniCPM3-RAG-LoRA', trust_remote_code=True)
|
||||||
self.assertTrue(model is not None)
|
self.assertTrue(model is not None)
|
||||||
self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin'))
|
self.assertFalse(hasattr(PeftModel, '_from_pretrained_origin'))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user