runnable
@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
        git_env = os.environ.copy()
        git_env['GIT_TERMINAL_PROMPT'] = '0'
        command = [self.git_path, *args]
        command = [item for item in command if item]
        response = subprocess.run(
            command,
            stdout=subprocess.PIPE,
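For context, GIT_TERMINAL_PROMPT=0 makes git fail immediately instead of blocking on an interactive username/password prompt when a command needs credentials. A minimal standalone sketch of the same pattern (the repository URL is only illustrative, not from this commit):

import os
import subprocess

# Copy the current environment and disable git's interactive credential prompt,
# so cloning a private or nonexistent repo errors out instead of hanging.
git_env = os.environ.copy()
git_env['GIT_TERMINAL_PROMPT'] = '0'

response = subprocess.run(
    ['git', 'clone', 'https://example.com/private/repo.git'],  # illustrative URL
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    env=git_env)
print(response.returncode)  # non-zero exit instead of waiting for user input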
@@ -1,2 +1,3 @@
from .auto_class import *
from .patcher import patch_context, patch_hub, unpatch_hub
from .pipeline_builder import hf_pipeline
@@ -26,7 +26,7 @@ def get_all_imported_modules():
    all_imported_modules = []
    transformers_include_names = [
        'Auto', 'T5', 'BitsAndBytes', 'GenerationConfig', 'Quant', 'Awq',
        'GPTQ', 'BatchFeature', 'Qwen', 'Llama'
        'GPTQ', 'BatchFeature', 'Qwen', 'Llama', 'Pipeline'
    ]
    diffusers_include_names = ['Pipeline']
    if importlib.util.find_spec('transformers') is not None:
@@ -144,6 +144,35 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                **kwargs)
        return kwargs.pop('ori_func')(model_dir, **kwargs)

    def save_pretrained(save_directory: Union[str, os.PathLike],
                        safe_serialization: bool = True,
                        **kwargs):
        obj = kwargs.pop('obj')
        push_to_hub = kwargs.pop('push_to_hub', False)

        obj._save_pretrained_origin(obj,
                                    save_directory=save_directory,
                                    safe_serialization=safe_serialization,
                                    push_to_hub=False,
                                    **kwargs)

        # Class members may be unpatched, so push_to_hub is done separately here
        if push_to_hub:
            from modelscope.hub.push_to_hub import push_to_hub
            from modelscope.hub.api import HubApi
            api = HubApi()

            token = kwargs.get("token")
            commit_message = kwargs.pop("commit_message", None)
            repo_name = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
            api.create_repo(repo_name, **kwargs)

            push_to_hub(repo_name=repo_name,
                        output_dir=save_directory,
                        commit_message=commit_message,
                        token=token)
        #return kwargs.pop('ori_func')(obj, save_directory, safe_serialization, **kwargs)

    def get_wrapped_class(
            module_class: 'PreTrainedModel',
            ignore_file_pattern: Optional[Union[str, List[str]]] = None,
@@ -214,6 +243,56 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                    model_dir, *model_args, **kwargs)
            return module_obj

        def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            safe_serialization: bool = True,
            **kwargs,
        ):
            push_to_hub = kwargs.pop('push_to_hub', False)
            if push_to_hub:
                import json
                from modelscope.hub.repository import Repository
                from modelscope.hub.utils.utils import add_content_to_file
                from modelscope.hub.push_to_hub import push_to_hub
                from modelscope.hub.api import HubApi
                api = HubApi()

                token = kwargs.get("token")
                api.login(token)
                commit_message = kwargs.pop("commit_message", None)
                repo_name = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
                api.create_repo(repo_name)
                repo = Repository(save_directory, repo_name)
                default_config = {
                    'framework': 'pytorch',
                    'task': 'text-generation',
                    'allow_remote': True
                }
                config_json = kwargs.get('config_json')
                if not config_json:
                    config_json = {}
                config = {**default_config, **config_json}
                add_content_to_file(
                    repo,
                    'configuration.json', [json.dumps(config)],
                    ignore_push_error=True)
            super().save_pretrained(save_directory=save_directory,
                                    safe_serialization=safe_serialization,
                                    push_to_hub=False,
                                    **kwargs)

            # Class members may be unpatched, so push_to_hub is done separately here
            if push_to_hub:

                #api.create_repo(repo_name, **kwargs)


                push_to_hub(repo_name=repo_name,
                            output_dir=save_directory,
                            commit_message=commit_message,
                            token=token)

        if not hasattr(module_class, 'from_pretrained'):
            del ClassWrapper.from_pretrained
        else:
@@ -228,6 +307,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
        if not hasattr(module_class, 'get_config_dict'):
            del ClassWrapper.get_config_dict

        if not hasattr(module_class, 'save_pretrained'):
            del ClassWrapper.save_pretrained

        ClassWrapper.__name__ = module_class.__name__
        ClassWrapper.__qualname__ = module_class.__qualname__
        return ClassWrapper
@@ -251,17 +333,21 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
            has_save_pretrained = hasattr(var, 'save_pretrained')
        except ImportError:
            continue

        if wrap:
        # save_pretrained is not a classmethod and cannot be overridden by replacing the class method. It requires replacing the class object method.
        if wrap or ('pipeline' in name.lower() and has_save_pretrained):
            print(f'var wrap: {var}')
            try:
                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
                if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type and not has_save_pretrained:
                    all_available_modules.append(var)
                else:
                    all_available_modules.append(
                        get_wrapped_class(var, **ignore_file_pattern_kwargs))
            except Exception:
            except Exception as e:
                print(f'wrap failed: {e}')
                all_available_modules.append(var)
        else:
            if has_from_pretrained and not hasattr(var,
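The comment in the hunk above is the key constraint: save_pretrained is an instance method, so it cannot simply be swapped out the way the from_pretrained classmethod is; the attribute has to be replaced on the class object itself (or the class wrapped) so that every instance resolves to the patched version. A toy sketch of that idea, using hypothetical names that are not part of this commit:

class ToyPipeline:
    def save_pretrained(self, save_directory):
        print(f'original save to {save_directory}')

# Keep a reference to the original and assign the wrapper on the class object;
# existing and future instances then resolve save_pretrained to the wrapper.
ToyPipeline._save_pretrained_origin = ToyPipeline.save_pretrained

def patched_save_pretrained(self, save_directory, **kwargs):
    print('patched: extra hub logic would go here')
    return self._save_pretrained_origin(save_directory, **kwargs)

ToyPipeline.save_pretrained = patched_save_pretrained
ToyPipeline().save_pretrained('/tmp/toy')  # runs the patched version, then the original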
@@ -295,6 +381,14 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                    ori_func=var._get_config_dict_origin,
                    **ignore_file_pattern_kwargs)

            if has_save_pretrained and not hasattr(var, '_save_pretrained_origin'):
                var._save_pretrained_origin = var.save_pretrained
                var.save_pretrained = partial(
                    save_pretrained,
                    ori_func=var._save_pretrained_origin,
                    obj=var,
                    **ignore_file_pattern_kwargs)

            all_available_modules.append(var)
    return all_available_modules
@@ -308,6 +402,7 @@ def _unpatch_pretrained_class(all_imported_modules):
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
            has_save_pretrained = hasattr(var, 'save_pretrained')
        except ImportError:
            continue
        if has_from_pretrained and hasattr(var, '_from_pretrained_origin'):
@@ -319,6 +414,9 @@ def _unpatch_pretrained_class(all_imported_modules):
        if has_get_config_dict and hasattr(var, '_get_config_dict_origin'):
            var.get_config_dict = var._get_config_dict_origin
            delattr(var, '_get_config_dict_origin')
        if has_save_pretrained and hasattr(var, '_save_pretrained_origin'):
            var.save_pretrained = var._save_pretrained_origin
            delattr(var, '_save_pretrained_origin')


def _patch_hub():
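For reference, a rough sketch of how the patched save_pretrained path is meant to be exercised once patch_hub() is active; the model id, local directory, repo id, and token below are placeholders, not values from this commit:

from modelscope.utils.hf_util import patch_hub

patch_hub()  # wrap transformers/diffusers classes with the ModelScope hub hooks

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained('some/model-id')  # placeholder model id

# With push_to_hub=True, the patched method first saves locally with push_to_hub=False,
# then creates the repo and uploads the directory via modelscope's push_to_hub.
model.save_pretrained(
    './local_dir',                          # placeholder directory
    push_to_hub=True,
    repo_id='your-namespace/your-repo',     # placeholder repo id
    token='your-modelscope-token')          # placeholder token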
modelscope/utils/hf_util/pipeline_builder.py (new file, 49 lines)
@@ -0,0 +1,49 @@
from typing import Optional, Union
import os
import torch
from modelscope.hub import snapshot_download
from transformers import (TFPreTrainedModel, PreTrainedModel, pipeline)
from transformers.pipelines import get_task, check_task
from transformers import Pipeline as PipelineHF
from modelscope.utils.hf_util.patcher import patch_hub, _patch_pretrained_class

def _get_hf_device(device):
    if isinstance(device, str):
        device_name = device.lower()
        eles = device_name.split(':')
        if eles[0] == 'gpu':
            eles = ['cuda'] + eles[1:]
        device = ':'.join(eles)
    return device

def _get_hf_pipeline_class(task, model):
    if not task:
        task = get_task(model)
    normalized_task, targeted_task, task_options = check_task(task)
    pipeline_class = targeted_task["impl"]
    pipeline_class = _patch_pretrained_class([pipeline_class])[0]
    return pipeline_class

def hf_pipeline(
    task: str = None,
    model: Optional[Union[str, "PreTrainedModel", "TFPreTrainedModel"]] = None,
    framework: Optional[str] = None,
    device: Optional[Union[int, str, "torch.device"]] = None,
    **kwargs,
) -> PipelineHF:
    if isinstance(model, str):
        if not os.path.exists(model):
            model = snapshot_download(model)

    framework = 'pt' if framework == 'pytorch' else framework

    device = _get_hf_device(device)
    pipeline_class = _get_hf_pipeline_class(task, model)

    return pipeline(task=task,
                    model=model,
                    framework=framework,
                    device=device,
                    pipeline_class=pipeline_class,
                    #pipeline_class=QuestionAnsweringPipeline,
                    **kwargs)
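Putting the pieces together, a hedged usage sketch of the new hf_pipeline entry point; the model id, repo id, and token are placeholders. hf_pipeline resolves a non-local model id through snapshot_download, normalizes the device string, and builds a transformers pipeline whose class has been wrapped by _patch_pretrained_class:

from modelscope.utils.hf_util import hf_pipeline

pipe = hf_pipeline(
    task='text-generation',
    model='qwen/some-model-id',   # placeholder ModelScope model id
    device='gpu:0')               # 'gpu' is converted to 'cuda' before reaching transformers

print(pipe('Hello')[0])

# The wrapped pipeline class routes push_to_hub through the ModelScope hub:
pipe.save_pretrained(
    './saved_pipeline',
    push_to_hub=True,
    repo_id='your-namespace/your-repo',   # placeholder repo id
    token='your-modelscope-token')        # placeholder token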