Merge remote-tracking branch 'origin' into fix/trust_remote_code_

This commit is contained in:
suluyan
2025-03-13 15:36:19 +08:00
2 changed files with 4 additions and 11 deletions

View File

@@ -345,7 +345,7 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
else:
all_available_modules.append(
get_wrapped_class(var, **ignore_file_pattern_kwargs))
except Exception:
except: # noqa
all_available_modules.append(var)
else:
if has_from_pretrained and not hasattr(var,
@@ -370,10 +370,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
if has_get_config_dict and not hasattr(var,
'_get_config_dict_origin'):
var._get_config_dict_origin = var.get_config_dict
var.get_config_dict = partial(
patch_pretrained_model_name_or_path,
ori_func=var._get_config_dict_origin,
**ignore_file_pattern_kwargs)
var.get_config_dict = classmethod(
partial(patch_get_config_dict,
**ignore_file_pattern_kwargs))
all_available_modules.append(var)
return all_available_modules
@@ -619,11 +618,6 @@ def _patch_hub():
# Patch repocard.validate
from huggingface_hub import repocard
if not hasattr(repocard.RepoCard, '_validate_origin'):
def load(*args, **kwargs): # noqa
from huggingface_hub.errors import EntryNotFoundError
raise EntryNotFoundError(message='API not supported.')
repocard.RepoCard._validate_origin = repocard.RepoCard.validate
repocard.RepoCard.validate = lambda *args, **kwargs: None
repocard.RepoCard._load_origin = repocard.RepoCard.load

View File

@@ -33,7 +33,6 @@ def hf_pipeline(
**kwargs,
) -> 'transformers.Pipeline':
from transformers import pipeline
if isinstance(model, str):
if not os.path.exists(model):
model = snapshot_download(model)