suluyan
2025-02-11 10:25:41 +08:00
parent d46dcbf2a3
commit 462eaab3cf
4 changed files with 153 additions and 4 deletions

View File

@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
        git_env = os.environ.copy()
        git_env['GIT_TERMINAL_PROMPT'] = '0'
        command = [self.git_path, *args]
        command = [item for item in command if item]
        response = subprocess.run(
            command,
            stdout=subprocess.PIPE,
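
The new filtering line keeps empty strings out of the git argv; git would otherwise receive '' as a real argument (for example an empty pathspec) and fail. A minimal illustration with placeholder values:

import subprocess

command = ['git', 'status', '', '--short']
command = [item for item in command if item]   # drops the empty entry
subprocess.run(command, stdout=subprocess.PIPE)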

View File

@@ -1,2 +1,3 @@
from .auto_class import *
from .patcher import patch_context, patch_hub, unpatch_hub
from .pipeline_builder import hf_pipeline
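
With this re-export in place, the new builder can be imported straight from the package (usage sketch; a fuller call is sketched after the new pipeline_builder file below):

from modelscope.utils.hf_util import hf_pipeline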

View File

@@ -26,7 +26,7 @@ def get_all_imported_modules():
    all_imported_modules = []
    transformers_include_names = [
        'Auto', 'T5', 'BitsAndBytes', 'GenerationConfig', 'Quant', 'Awq',
        'GPTQ', 'BatchFeature', 'Qwen', 'Llama'
        'GPTQ', 'BatchFeature', 'Qwen', 'Llama', 'Pipeline'
    ]
    diffusers_include_names = ['Pipeline']
    if importlib.util.find_spec('transformers') is not None:
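
Adding 'Pipeline' to transformers_include_names pulls the transformers pipeline classes into the set of symbols considered for patching. A rough sketch of the substring selection these include lists suggest (illustrative only; not the module's actual collection code):

import transformers

include_names = ['Auto', 'Qwen', 'Pipeline']
selected = [name for name in dir(transformers)
            if any(key in name for key in include_names)]
# With 'Pipeline' in the list, classes such as TextGenerationPipeline are
# selected alongside the Auto* classes.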
@@ -144,6 +144,35 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
            **kwargs)
        return kwargs.pop('ori_func')(model_dir, **kwargs)

    def save_pretrained(save_directory: Union[str, os.PathLike],
                        safe_serialization: bool = True,
                        **kwargs):
        obj = kwargs.pop('obj')
        push_to_hub = kwargs.pop('push_to_hub', False)
        obj._save_pretrained_origin(obj,
                                    save_directory=save_directory,
                                    safe_serialization=safe_serialization,
                                    push_to_hub=False,
                                    **kwargs)
        # push_to_hub itself may be unpatched on this class, so the push to
        # the ModelScope hub is done separately here
        if push_to_hub:
            from modelscope.hub.push_to_hub import push_to_hub
            from modelscope.hub.api import HubApi
            api = HubApi()
            token = kwargs.get('token')
            commit_message = kwargs.pop('commit_message', None)
            repo_name = kwargs.pop('repo_id',
                                   save_directory.split(os.path.sep)[-1])
            api.create_repo(repo_name, **kwargs)
            push_to_hub(repo_name=repo_name,
                        output_dir=save_directory,
                        commit_message=commit_message,
                        token=token)

    def get_wrapped_class(
            module_class: 'PreTrainedModel',
            ignore_file_pattern: Optional[Union[str, List[str]]] = None,
@@ -214,6 +243,56 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                    model_dir, *model_args, **kwargs)
                return module_obj

            def save_pretrained(
                    self,
                    save_directory: Union[str, os.PathLike],
                    safe_serialization: bool = True,
                    **kwargs,
            ):
                push_to_hub = kwargs.pop('push_to_hub', False)
                if push_to_hub:
                    import json
                    from modelscope.hub.repository import Repository
                    from modelscope.hub.utils.utils import add_content_to_file
                    from modelscope.hub.push_to_hub import push_to_hub
                    from modelscope.hub.api import HubApi
                    api = HubApi()
                    token = kwargs.get('token')
                    api.login(token)
                    commit_message = kwargs.pop('commit_message', None)
                    repo_name = kwargs.pop(
                        'repo_id', save_directory.split(os.path.sep)[-1])
                    api.create_repo(repo_name)
                    repo = Repository(save_directory, repo_name)
                    default_config = {
                        'framework': 'pytorch',
                        'task': 'text-generation',
                        'allow_remote': True
                    }
                    config_json = kwargs.get('config_json')
                    if not config_json:
                        config_json = {}
                    config = {**default_config, **config_json}
                    add_content_to_file(
                        repo,
                        'configuration.json', [json.dumps(config)],
                        ignore_push_error=True)
                super().save_pretrained(save_directory=save_directory,
                                        safe_serialization=safe_serialization,
                                        push_to_hub=False,
                                        **kwargs)
                # push_to_hub itself may be unpatched on this class, so the
                # push to the ModelScope hub is done separately here
                if push_to_hub:
                    push_to_hub(repo_name=repo_name,
                                output_dir=save_directory,
                                commit_message=commit_message,
                                token=token)

        if not hasattr(module_class, 'from_pretrained'):
            del ClassWrapper.from_pretrained
        else:
@@ -228,6 +307,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
        if not hasattr(module_class, 'get_config_dict'):
            del ClassWrapper.get_config_dict
        if not hasattr(module_class, 'save_pretrained'):
            del ClassWrapper.save_pretrained
        ClassWrapper.__name__ = module_class.__name__
        ClassWrapper.__qualname__ = module_class.__qualname__
        return ClassWrapper
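
The wrapper now also drops its save_pretrained override whenever the wrapped class has none, so ClassWrapper keeps exactly the same surface as the original class. A self-contained sketch of that conditional-override pattern (hypothetical classes):

class Base:
    def from_pretrained(self):
        return 'base'            # Base defines from_pretrained only


class Wrapper(Base):
    def from_pretrained(self):
        return 'wrapped'

    def save_pretrained(self):
        return 'wrapped save'    # Base has no save_pretrained


if not hasattr(Base, 'save_pretrained'):
    del Wrapper.save_pretrained  # keep the wrapper's API identical to Base's

assert hasattr(Wrapper, 'from_pretrained')
assert not hasattr(Wrapper, 'save_pretrained')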
@@ -251,17 +333,21 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
            has_save_pretrained = hasattr(var, 'save_pretrained')
        except ImportError:
            continue

        if wrap:
        # save_pretrained is an instance method, so it cannot be patched by
        # simply replacing the attribute on the class; classes that expose it
        # (pipelines, for example) are wrapped by subclassing instead.
        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
            try:
                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type and not has_save_pretrained:
                    all_available_modules.append(var)
                else:
                    all_available_modules.append(
                        get_wrapped_class(var, **ignore_file_pattern_kwargs))
            except Exception:
            except Exception as e:
                print(f'wrap failed: {e}')
                all_available_modules.append(var)
        else:
            if has_from_pretrained and not hasattr(var,
@@ -295,6 +381,14 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                    ori_func=var._get_config_dict_origin,
                    **ignore_file_pattern_kwargs)
            if has_save_pretrained and not hasattr(var,
                                                   '_save_pretrained_origin'):
                var._save_pretrained_origin = var.save_pretrained
                var.save_pretrained = partial(
                    save_pretrained,
                    ori_func=var._save_pretrained_origin,
                    obj=var,
                    **ignore_file_pattern_kwargs)
            all_available_modules.append(var)

    return all_available_modules
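
The partial-based patch hands the target class to the replacement through obj= because a functools.partial stored on a class is not a descriptor, so no self is bound when it is called. A minimal demonstration (hypothetical class, unrelated to the patcher):

from functools import partial


def report(**kwargs):
    return kwargs['obj'].__name__


class Demo:
    pass


Demo.describe = partial(report, obj=Demo)

d = Demo()
assert Demo.describe() == 'Demo'   # no self is injected in either call
assert d.describe() == 'Demo'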
@@ -308,6 +402,7 @@ def _unpatch_pretrained_class(all_imported_modules):
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
            has_save_pretrained = hasattr(var, 'save_pretrained')
        except ImportError:
            continue
        if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
@@ -319,6 +414,9 @@ def _unpatch_pretrained_class(all_imported_modules):
        if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
            var.get_config_dict = var._get_config_dict_origin
            delattr(var, '_get_config_dict_origin')
        if has_save_pretrained and hasattr(var, '_save_pretrained_origin'):
            var.save_pretrained = var._save_pretrained_origin
            delattr(var, '_save_pretrained_origin')


def _patch_hub():
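
The unpatch branch mirrors the patch: the original method is remembered on the class as _save_pretrained_origin and restored from it later. A small sketch of that symmetry (hypothetical stand-in class):

class Var:
    @staticmethod
    def save_pretrained():
        return 'original'


def patched_save_pretrained():
    return 'patched'


# patch: remember the original, install the replacement
if not hasattr(Var, '_save_pretrained_origin'):
    Var._save_pretrained_origin = Var.save_pretrained
    Var.save_pretrained = patched_save_pretrained

assert Var.save_pretrained() == 'patched'

# unpatch: restore the original and drop the marker attribute
if hasattr(Var, '_save_pretrained_origin'):
    Var.save_pretrained = Var._save_pretrained_origin
    delattr(Var, '_save_pretrained_origin')

assert Var.save_pretrained() == 'original'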

View File

@@ -0,0 +1,49 @@
from typing import Optional, Union
import os
import torch
from modelscope.hub import snapshot_download
from transformers import (TFPreTrainedModel, PreTrainedModel, pipeline)
from transformers.pipelines import get_task, check_task
from transformers import Pipeline as PipelineHF
from modelscope.utils.hf_util.patcher import patch_hub, _patch_pretrained_class


def _get_hf_device(device):
    if isinstance(device, str):
        device_name = device.lower()
        eles = device_name.split(':')
        if eles[0] == 'gpu':
            eles = ['cuda'] + eles[1:]
        device = ':'.join(eles)
    return device


def _get_hf_pipeline_class(task, model):
    if not task:
        task = get_task(model)
    normalized_task, targeted_task, task_options = check_task(task)
    pipeline_class = targeted_task['impl']
    pipeline_class = _patch_pretrained_class([pipeline_class])[0]
    return pipeline_class


def hf_pipeline(
    task: str = None,
    model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
    framework: Optional[str] = None,
    device: Optional[Union[int, str, 'torch.device']] = None,
    **kwargs,
) -> PipelineHF:
    if isinstance(model, str):
        if not os.path.exists(model):
            model = snapshot_download(model)

    framework = 'pt' if framework == 'pytorch' else framework
    device = _get_hf_device(device)
    pipeline_class = _get_hf_pipeline_class(task, model)

    return pipeline(task=task,
                    model=model,
                    framework=framework,
                    device=device,
                    pipeline_class=pipeline_class,
                    **kwargs)
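
Usage sketch for the new builder (model id is a placeholder): a string model id that does not exist locally is resolved through snapshot_download, and a 'gpu:N' device string is normalized to 'cuda:N' by _get_hf_device before the call is handed to transformers.pipeline with the patched pipeline class.

from modelscope.utils.hf_util import hf_pipeline

pipe = hf_pipeline(task='text-generation',
                   model='some-namespace/some-model',
                   device='gpu:0')
print(pipe('Hello, world!'))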