mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
update pipeline builder
This commit is contained in:
@@ -27,7 +27,8 @@ def get_all_imported_modules():
|
|||||||
transformers_include_names = [
|
transformers_include_names = [
|
||||||
'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
|
'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
|
||||||
'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
|
'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
|
||||||
'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
|
'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
|
||||||
|
'Pipeline'
|
||||||
]
|
]
|
||||||
peft_include_names = ['.*PeftModel.*', '.*Config']
|
peft_include_names = ['.*PeftModel.*', '.*Config']
|
||||||
diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
|
diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
|
||||||
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
model_dir, *model_args, **kwargs)
|
model_dir, *model_args, **kwargs)
|
||||||
return module_obj
|
return module_obj
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
safe_serialization: bool = True,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
push_to_hub = kwargs.pop('push_to_hub', False)
|
||||||
|
if push_to_hub:
|
||||||
|
from modelscope.hub.push_to_hub import push_to_hub
|
||||||
|
from modelscope.hub.api import HubApi
|
||||||
|
from modelscope.hub.repository import Repository
|
||||||
|
|
||||||
|
token = kwargs.get('token')
|
||||||
|
commit_message = kwargs.pop('commit_message', None)
|
||||||
|
repo_name = kwargs.pop(
|
||||||
|
'repo_id',
|
||||||
|
save_directory.split(os.path.sep)[-1])
|
||||||
|
|
||||||
|
api = HubApi()
|
||||||
|
api.login(token)
|
||||||
|
api.create_repo(repo_name)
|
||||||
|
# clone the repo
|
||||||
|
Repository(save_directory, repo_name)
|
||||||
|
|
||||||
|
super().save_pretrained(
|
||||||
|
save_directory=save_directory,
|
||||||
|
safe_serialization=safe_serialization,
|
||||||
|
push_to_hub=False,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
# Class members may be unpatched, so push_to_hub is done separately here
|
||||||
|
if push_to_hub:
|
||||||
|
push_to_hub(
|
||||||
|
repo_name=repo_name,
|
||||||
|
output_dir=save_directory,
|
||||||
|
commit_message=commit_message,
|
||||||
|
token=token)
|
||||||
|
|
||||||
if not hasattr(module_class, 'from_pretrained'):
|
if not hasattr(module_class, 'from_pretrained'):
|
||||||
del ClassWrapper.from_pretrained
|
del ClassWrapper.from_pretrained
|
||||||
else:
|
else:
|
||||||
@@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
if not hasattr(module_class, 'get_config_dict'):
|
if not hasattr(module_class, 'get_config_dict'):
|
||||||
del ClassWrapper.get_config_dict
|
del ClassWrapper.get_config_dict
|
||||||
|
|
||||||
|
if not hasattr(module_class, 'save_pretrained'):
|
||||||
|
del ClassWrapper.save_pretrained
|
||||||
|
|
||||||
ClassWrapper.__name__ = module_class.__name__
|
ClassWrapper.__name__ = module_class.__name__
|
||||||
ClassWrapper.__qualname__ = module_class.__qualname__
|
ClassWrapper.__qualname__ = module_class.__qualname__
|
||||||
return ClassWrapper
|
return ClassWrapper
|
||||||
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
|||||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||||
|
has_save_pretrained = hasattr(var, 'save_pretrained')
|
||||||
except: # noqa
|
except: # noqa
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if wrap:
|
# save_pretrained is not a classmethod and cannot be overridden by replacing
|
||||||
|
# the class method. It requires replacing the class object method.
|
||||||
|
if wrap or ('pipeline' in name.lower() and has_save_pretrained):
|
||||||
try:
|
try:
|
||||||
if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
|
if (not has_from_pretrained and not has_get_config_dict
|
||||||
|
and not has_get_peft_type and not has_save_pretrained):
|
||||||
all_available_modules.append(var)
|
all_available_modules.append(var)
|
||||||
else:
|
else:
|
||||||
all_available_modules.append(
|
all_available_modules.append(
|
||||||
|
|||||||
@@ -1,6 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import Pipeline as PipelineHF
|
||||||
|
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
|
||||||
|
from transformers.pipelines import check_task, get_task
|
||||||
|
|
||||||
from modelscope.hub import snapshot_download
|
from modelscope.hub import snapshot_download
|
||||||
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
|
from modelscope.utils.hf_util.patcher import _patch_pretrained_class
|
||||||
|
|
||||||
@@ -16,7 +21,6 @@ def _get_hf_device(device):
|
|||||||
|
|
||||||
|
|
||||||
def _get_hf_pipeline_class(task, model):
|
def _get_hf_pipeline_class(task, model):
|
||||||
from transformers.pipelines import check_task, get_task
|
|
||||||
if not task:
|
if not task:
|
||||||
task = get_task(model)
|
task = get_task(model)
|
||||||
normalized_task, targeted_task, task_options = check_task(task)
|
normalized_task, targeted_task, task_options = check_task(task)
|
||||||
@@ -31,9 +35,7 @@ def hf_pipeline(
|
|||||||
framework: Optional[str] = None,
|
framework: Optional[str] = None,
|
||||||
device: Optional[Union[int, str, 'torch.device']] = None,
|
device: Optional[Union[int, str, 'torch.device']] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> 'transformers.Pipeline':
|
) -> PipelineHF:
|
||||||
from transformers import pipeline
|
|
||||||
|
|
||||||
if isinstance(model, str):
|
if isinstance(model, str):
|
||||||
if not os.path.exists(model):
|
if not os.path.exists(model):
|
||||||
model = snapshot_download(model)
|
model = snapshot_download(model)
|
||||||
|
|||||||
Reference in New Issue
Block a user