diff --git a/modelscope/utils/hf_util.py b/modelscope/utils/hf_util.py index 01aeebef..f5fb8d33 100644 --- a/modelscope/utils/hf_util.py +++ b/modelscope/utils/hf_util.py @@ -127,7 +127,9 @@ def _patch_pretrained_class(): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors'] + ignore_file_pattern = [ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' + ] model_dir = get_model_dir(pretrained_model_name_or_path, ignore_file_pattern, **kwargs) return ori_from_pretrained(cls, model_dir, *model_args, **kwargs) @@ -143,14 +145,18 @@ def _patch_pretrained_class(): @classmethod def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors'] + ignore_file_pattern = [ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' + ] model_dir = get_model_dir(pretrained_model_name_or_path, ignore_file_pattern, **kwargs) return ori_from_pretrained(cls, model_dir, *model_args, **kwargs) @classmethod def get_config_dict(cls, pretrained_model_name_or_path, **kwargs): - ignore_file_pattern = [r'\w+\.bin', r'\w+\.safetensors'] + ignore_file_pattern = [ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' + ] model_dir = get_model_dir(pretrained_model_name_or_path, ignore_file_pattern, **kwargs) return ori_get_config_dict(cls, model_dir, **kwargs) @@ -242,11 +248,20 @@ AutoModelForTokenClassification = get_wrapped_class( AutoModelForTokenClassificationHF) AutoTokenizer = get_wrapped_class( - AutoTokenizerHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors']) + AutoTokenizerHF, + ignore_file_pattern=[ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' + ]) AutoConfig = get_wrapped_class( - AutoConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors']) + AutoConfigHF, + ignore_file_pattern=[ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' + ]) GenerationConfig = get_wrapped_class( - GenerationConfigHF, ignore_file_pattern=[r'\w+\.bin', r'\w+\.safetensors']) + GenerationConfigHF, + ignore_file_pattern=[ + r'\w+\.bin', r'\w+\.safetensors', r'\w+\.pth', r'\w+\.pt' + ]) GPTQConfig = GPTQConfigHF AwqConfig = AwqConfigHF BitsAndBytesConfig = BitsAndBytesConfigHF