update pipeline builder

Author: xingjun.wxj
Date: 2025-03-13 15:07:26 +08:00
2 changed files with 55 additions and 7 deletions


@@ -27,7 +27,8 @@ def get_all_imported_modules():
     transformers_include_names = [
         'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
         'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
-        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
+        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
+        'Pipeline'
     ]
     peft_include_names = ['.*PeftModel.*', '.*Config']
     diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
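
For context, a minimal sketch (not part of the commit) of how an include-name list like this behaves when each entry is treated as a regex over exported class names: the bare 'Pipeline' entry added here admits only the exact name, while 'Auto.*' admits the whole Auto family. The helper and the full-match assumption below are illustrative, not modelscope's actual filter.

import re

transformers_include_names = ['Auto.*', 'Pipeline', 'PreTrainedModel']

def matches_include(name, patterns=transformers_include_names):
    # Assumes full-string matching: 'Auto.*' admits 'AutoModelForCausalLM',
    # but the bare 'Pipeline' does not admit 'TextGenerationPipeline'.
    return any(re.fullmatch(p, name) for p in patterns)

assert matches_include('AutoModelForCausalLM')    # via 'Auto.*'
assert matches_include('Pipeline')                # the name this commit adds
assert not matches_include('TextGenerationPipeline')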
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                     model_dir, *model_args, **kwargs)
                 return module_obj

+            def save_pretrained(
+                self,
+                save_directory: Union[str, os.PathLike],
+                safe_serialization: bool = True,
+                **kwargs,
+            ):
+                push_to_hub = kwargs.pop('push_to_hub', False)
+                if push_to_hub:
+                    from modelscope.hub.push_to_hub import push_to_hub
+                    from modelscope.hub.api import HubApi
+                    from modelscope.hub.repository import Repository
+                    token = kwargs.get('token')
+                    commit_message = kwargs.pop('commit_message', None)
+                    repo_name = kwargs.pop(
+                        'repo_id',
+                        save_directory.split(os.path.sep)[-1])
+                    api = HubApi()
+                    api.login(token)
+                    api.create_repo(repo_name)
+                    # clone the repo
+                    Repository(save_directory, repo_name)
+                super().save_pretrained(
+                    save_directory=save_directory,
+                    safe_serialization=safe_serialization,
+                    push_to_hub=False,
+                    **kwargs)
+                # Class members may be unpatched, so push_to_hub is done separately here
+                if push_to_hub:
+                    push_to_hub(
+                        repo_name=repo_name,
+                        output_dir=save_directory,
+                        commit_message=commit_message,
+                        token=token)

         if not hasattr(module_class, 'from_pretrained'):
             del ClassWrapper.from_pretrained
         else:
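
A hedged usage sketch for the wrapper above; the top-level export, model id, repo id, and token are placeholders rather than values taken from the commit:

from modelscope import AutoModelForCausalLM  # patched wrapper class (assumed export)

model = AutoModelForCausalLM.from_pretrained('qwen/Qwen2-0.5B')  # placeholder id
model.save_pretrained(
    './output',
    push_to_hub=True,                 # pops the flag, then uploads via modelscope
    repo_id='my-namespace/my-model',  # placeholder; defaults to the directory name
    token='<modelscope-token>')       # placeholder access token

Note that super().save_pretrained() is always called with push_to_hub=False, so transformers itself never pushes; the upload goes through modelscope's own push_to_hub afterwards.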
@@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
         if not hasattr(module_class, 'get_config_dict'):
             del ClassWrapper.get_config_dict
+        if not hasattr(module_class, 'save_pretrained'):
+            del ClassWrapper.save_pretrained
         ClassWrapper.__name__ = module_class.__name__
         ClassWrapper.__qualname__ = module_class.__qualname__
         return ClassWrapper
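
The two new del statements follow ClassWrapper's existing convention: every method is defined up front, and the ones the wrapped class never had are deleted again so the wrapper's surface stays identical to the original. A toy illustration of the pattern (names hypothetical):

class Base:                        # stands in for module_class
    pass

class Wrapper(Base):               # stands in for ClassWrapper
    def save_pretrained(self, *args, **kwargs):
        ...

if not hasattr(Base, 'save_pretrained'):
    del Wrapper.save_pretrained    # Wrapper now mirrors Base exactly

assert not hasattr(Wrapper(), 'save_pretrained')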
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
             has_from_pretrained = hasattr(var, 'from_pretrained')
             has_get_peft_type = hasattr(var, '_get_peft_type')
             has_get_config_dict = hasattr(var, 'get_config_dict')
+            has_save_pretrained = hasattr(var, 'save_pretrained')
         except:  # noqa
             continue
-        if wrap:
+        # save_pretrained is not a classmethod, so it cannot be patched by
+        # replacing the method on the class; the class object itself has to be
+        # wrapped instead.
+        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
             try:
-                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
+                if (not has_from_pretrained and not has_get_config_dict
+                        and not has_get_peft_type and not has_save_pretrained):
                     all_available_modules.append(var)
                 else:
                     all_available_modules.append(
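
The comment in this hunk carries the design rationale: from_pretrained is a classmethod and can be swapped out on the class, but save_pretrained is an instance method that must call the original through super(), so Pipeline classes are routed into the subclass-wrapper path. A toy sketch of that wrap step (names hypothetical):

class OriginalPipeline:                    # stands in for a transformers Pipeline
    def save_pretrained(self, save_directory):
        print(f'saving to {save_directory}')

class WrappedPipeline(OriginalPipeline):   # mirrors ClassWrapper above
    def save_pretrained(self, save_directory, **kwargs):
        push = kwargs.pop('push_to_hub', False)
        super().save_pretrained(save_directory)        # original behavior kept
        if push:
            print('pushing to the ModelScope hub...')  # extra hub step

WrappedPipeline().save_pretrained('./out', push_to_hub=True)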


@@ -1,6 +1,11 @@
 import os
 from typing import Optional, Union

+import torch
+from transformers import Pipeline as PipelineHF
+from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
+from transformers.pipelines import check_task, get_task
+
 from modelscope.hub import snapshot_download
 from modelscope.utils.hf_util.patcher import _patch_pretrained_class
@@ -16,7 +21,6 @@ def _get_hf_device(device):
 def _get_hf_pipeline_class(task, model):
-    from transformers.pipelines import check_task, get_task
     if not task:
         task = get_task(model)
     normalized_task, targeted_task, task_options = check_task(task)
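
check_task looks the task up in transformers' pipeline registry; the 'impl' key of the returned entry holds the pipeline class that the surrounding function resolves. A quick illustration (key name as in transformers' registry):

from transformers.pipelines import check_task

normalized_task, targeted_task, task_options = check_task('text-classification')
pipeline_class = targeted_task['impl']   # e.g. TextClassificationPipeline
print(normalized_task, pipeline_class.__name__)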
@@ -31,9 +35,7 @@ def hf_pipeline(
     framework: Optional[str] = None,
     device: Optional[Union[int, str, 'torch.device']] = None,
     **kwargs,
-) -> 'transformers.Pipeline':
-    from transformers import pipeline
+) -> PipelineHF:
     if isinstance(model, str):
         if not os.path.exists(model):
             model = snapshot_download(model)
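
End to end, a hedged usage sketch of hf_pipeline; the import location is assumed from this file's package and the model id is a placeholder:

from modelscope.utils.hf_util import hf_pipeline  # assumed export location

pipe = hf_pipeline(
    task='text-generation',
    model='qwen/Qwen2-0.5B-Instruct',  # fetched via snapshot_download if not a local path
    device='cuda:0')
print(pipe('Hello')[0]['generated_text'])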