mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
@@ -56,7 +56,7 @@ if TYPE_CHECKING:
|
||||
AutoModelForPreTraining, AutoModelForTextEncoding,
|
||||
AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
|
||||
T5EncoderModel, Qwen2_5_VLForConditionalGeneration, LlamaModel,
|
||||
LlamaPreTrainedModel, LlamaForCausalLM)
|
||||
LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
|
||||
else:
|
||||
print(
|
||||
'transformer is not installed, please install it if you want to use related modules'
|
||||
|
||||
@@ -71,6 +71,10 @@ def check_local_model_is_latest(
|
||||
headers=snapshot_header,
|
||||
use_cookies=cookies,
|
||||
)
|
||||
model_cache = None
|
||||
# download via non-git method
|
||||
if not os.path.exists(os.path.join(model_root_path, '.git')):
|
||||
model_cache = ModelFileSystemCache(model_root_path)
|
||||
for model_file in model_files:
|
||||
if model_file['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
|
||||
git_env = os.environ.copy()
|
||||
git_env['GIT_TERMINAL_PROMPT'] = '0'
|
||||
command = [self.git_path, *args]
|
||||
command = [item for item in command if item]
|
||||
response = subprocess.run(
|
||||
command,
|
||||
stdout=subprocess.PIPE,
|
||||
|
||||
@@ -10,6 +10,8 @@ from modelscope.utils.config import ConfigDict, check_config
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
|
||||
ThirdParty)
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.import_utils import is_transformers_available
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
register_plugins_repo)
|
||||
from modelscope.utils.registry import Registry, build_from_cfg
|
||||
@@ -17,6 +19,7 @@ from .base import Pipeline
|
||||
from .util import is_official_hub_path
|
||||
|
||||
PIPELINES = Registry('pipelines')
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def normalize_model_input(model,
|
||||
@@ -72,7 +75,7 @@ def pipeline(task: str = None,
|
||||
config_file: str = None,
|
||||
pipeline_name: str = None,
|
||||
framework: str = None,
|
||||
device: str = 'gpu',
|
||||
device: str = None,
|
||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
ignore_file_pattern: List[str] = None,
|
||||
**kwargs) -> Pipeline:
|
||||
@@ -109,6 +112,7 @@ def pipeline(task: str = None,
|
||||
if task is None and pipeline_name is None:
|
||||
raise ValueError('task or pipeline_name is required')
|
||||
|
||||
pipeline_props = None
|
||||
if pipeline_name is None:
|
||||
# get default pipeline for this task
|
||||
if isinstance(model, str) \
|
||||
@@ -157,8 +161,11 @@ def pipeline(task: str = None,
|
||||
if pipeline_name:
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
else:
|
||||
try:
|
||||
check_config(cfg)
|
||||
pipeline_props = cfg.pipeline
|
||||
except AssertionError as e:
|
||||
logger.info(str(e))
|
||||
|
||||
elif model is not None:
|
||||
# get pipeline info from Model object
|
||||
@@ -166,8 +173,12 @@ def pipeline(task: str = None,
|
||||
if not hasattr(first_model, 'pipeline'):
|
||||
# model is instantiated by user, we should parse config again
|
||||
cfg = read_config(first_model.model_dir)
|
||||
try:
|
||||
check_config(cfg)
|
||||
first_model.pipeline = cfg.pipeline
|
||||
except AssertionError as e:
|
||||
logger.info(str(e))
|
||||
if first_model.__dict__.get('pipeline'):
|
||||
pipeline_props = first_model.pipeline
|
||||
else:
|
||||
pipeline_name, default_model_repo = get_default_pipeline_info(task)
|
||||
@@ -176,6 +187,23 @@ def pipeline(task: str = None,
|
||||
else:
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
|
||||
if not pipeline_props and is_transformers_available():
|
||||
try:
|
||||
from modelscope.utils.hf_util import hf_pipeline
|
||||
return hf_pipeline(
|
||||
task=task,
|
||||
model=model,
|
||||
framework=framework,
|
||||
device=device,
|
||||
**kwargs)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
|
||||
' but that also failed.')
|
||||
raise e
|
||||
|
||||
if not device:
|
||||
device = 'gpu'
|
||||
pipeline_props['model'] = model
|
||||
pipeline_props['device'] = device
|
||||
cfg = ConfigDict(pipeline_props)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .auto_class import *
|
||||
from .patcher import patch_context, patch_hub, unpatch_hub
|
||||
from .pipeline_builder import hf_pipeline
|
||||
|
||||
@@ -27,7 +27,8 @@ def get_all_imported_modules():
|
||||
transformers_include_names = [
|
||||
'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
|
||||
'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
|
||||
'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
|
||||
'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
|
||||
'Pipeline'
|
||||
]
|
||||
peft_include_names = ['.*PeftModel.*', '.*Config']
|
||||
diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
|
||||
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
model_dir, *model_args, **kwargs)
|
||||
return module_obj
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
safe_serialization: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
push_to_hub = kwargs.pop('push_to_hub', False)
|
||||
if push_to_hub:
|
||||
from modelscope.hub.push_to_hub import push_to_hub
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.hub.repository import Repository
|
||||
|
||||
token = kwargs.get('token')
|
||||
commit_message = kwargs.pop('commit_message', None)
|
||||
repo_name = kwargs.pop(
|
||||
'repo_id',
|
||||
save_directory.split(os.path.sep)[-1])
|
||||
|
||||
api = HubApi()
|
||||
api.login(token)
|
||||
api.create_repo(repo_name)
|
||||
# clone the repo
|
||||
Repository(save_directory, repo_name)
|
||||
|
||||
super().save_pretrained(
|
||||
save_directory=save_directory,
|
||||
safe_serialization=safe_serialization,
|
||||
push_to_hub=False,
|
||||
**kwargs)
|
||||
|
||||
# Class members may be unpatched, so push_to_hub is done separately here
|
||||
if push_to_hub:
|
||||
push_to_hub(
|
||||
repo_name=repo_name,
|
||||
output_dir=save_directory,
|
||||
commit_message=commit_message,
|
||||
token=token)
|
||||
|
||||
if not hasattr(module_class, 'from_pretrained'):
|
||||
del ClassWrapper.from_pretrained
|
||||
else:
|
||||
@@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
if not hasattr(module_class, 'get_config_dict'):
|
||||
del ClassWrapper.get_config_dict
|
||||
|
||||
if not hasattr(module_class, 'save_pretrained'):
|
||||
del ClassWrapper.save_pretrained
|
||||
|
||||
ClassWrapper.__name__ = module_class.__name__
|
||||
ClassWrapper.__qualname__ = module_class.__qualname__
|
||||
return ClassWrapper
|
||||
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
|
||||
has_from_pretrained = hasattr(var, 'from_pretrained')
|
||||
has_get_peft_type = hasattr(var, '_get_peft_type')
|
||||
has_get_config_dict = hasattr(var, 'get_config_dict')
|
||||
has_save_pretrained = hasattr(var, 'save_pretrained')
|
||||
except: # noqa
|
||||
continue
|
||||
|
||||
if wrap:
|
||||
# save_pretrained is not a classmethod and cannot be overridden by replacing
|
||||
# the class method. It requires replacing the class object method.
|
||||
if wrap or ('pipeline' in name.lower() and has_save_pretrained):
|
||||
try:
|
||||
if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
|
||||
if (not has_from_pretrained and not has_get_config_dict
|
||||
and not has_get_peft_type and not has_save_pretrained):
|
||||
all_available_modules.append(var)
|
||||
else:
|
||||
all_available_modules.append(
|
||||
|
||||
54
modelscope/utils/hf_util/pipeline_builder.py
Normal file
54
modelscope/utils/hf_util/pipeline_builder.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import os
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import Pipeline as PipelineHF
|
||||
from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
|
||||
from transformers.pipelines import check_task, get_task
|
||||
|
||||
from modelscope.hub import snapshot_download
|
||||
from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub
|
||||
|
||||
|
||||
def _get_hf_device(device):
|
||||
if isinstance(device, str):
|
||||
device_name = device.lower()
|
||||
eles = device_name.split(':')
|
||||
if eles[0] == 'gpu':
|
||||
eles = ['cuda'] + eles[1:]
|
||||
device = ''.join(eles)
|
||||
return device
|
||||
|
||||
|
||||
def _get_hf_pipeline_class(task, model):
|
||||
if not task:
|
||||
task = get_task(model)
|
||||
normalized_task, targeted_task, task_options = check_task(task)
|
||||
pipeline_class = targeted_task['impl']
|
||||
pipeline_class = _patch_pretrained_class([pipeline_class])[0]
|
||||
return pipeline_class
|
||||
|
||||
|
||||
def hf_pipeline(
|
||||
task: str = None,
|
||||
model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
|
||||
framework: Optional[str] = None,
|
||||
device: Optional[Union[int, str, 'torch.device']] = None,
|
||||
**kwargs,
|
||||
) -> PipelineHF:
|
||||
if isinstance(model, str):
|
||||
if not os.path.exists(model):
|
||||
model = snapshot_download(model)
|
||||
|
||||
framework = 'pt' if framework == 'pytorch' else framework
|
||||
|
||||
device = _get_hf_device(device)
|
||||
pipeline_class = _get_hf_pipeline_class(task, model)
|
||||
|
||||
return pipeline(
|
||||
task=task,
|
||||
model=model,
|
||||
framework=framework,
|
||||
device=device,
|
||||
pipeline_class=pipeline_class,
|
||||
**kwargs)
|
||||
@@ -40,6 +40,14 @@ class HFUtilTest(unittest.TestCase):
|
||||
with open(self.test_file2, 'w') as f:
|
||||
f.write('{}')
|
||||
|
||||
self.pipeline_qa_context = r"""
|
||||
Extractive Question Answering is the task of extracting an answer from a text given a question. An example
|
||||
of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
|
||||
like to fine-tune a model on a SQuAD task, you may leverage the
|
||||
examples/pytorch/question-answering/run_squad.py script.
|
||||
"""
|
||||
self.pipeline_qa_question = 'What is a good example of a question answering dataset?'
|
||||
|
||||
def tearDown(self):
|
||||
logger.info('TearDown')
|
||||
shutil.rmtree(self.model_dir, ignore_errors=True)
|
||||
@@ -235,6 +243,59 @@ class HFUtilTest(unittest.TestCase):
|
||||
'Qwen/Qwen1.5-0.5B-Chat', trust_remote_code=True)
|
||||
model.push_to_hub(self.create_model_name)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_pipeline_model_id(self):
|
||||
from modelscope import pipeline
|
||||
model_id = 'damotestx/distilbert-base-cased-distilled-squad'
|
||||
qa = pipeline('question-answering', model=model_id)
|
||||
assert qa(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_pipeline_auto_model(self):
|
||||
from modelscope import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
|
||||
model_id = 'damotestx/distilbert-base-cased-distilled-squad'
|
||||
model = AutoModelForQuestionAnswering.from_pretrained(model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
|
||||
assert qa(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_pipeline_save_pretrained(self):
|
||||
from modelscope import pipeline
|
||||
model_id = 'damotestx/distilbert-base-cased-distilled-squad'
|
||||
|
||||
pipe_ori = pipeline('question-answering', model=model_id)
|
||||
|
||||
result_ori = pipe_ori(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
# save_pretrained
|
||||
repo_id = self.create_model_name
|
||||
save_dir = './tmp_test_hf_pipeline'
|
||||
try:
|
||||
os.system(f'rm -rf {save_dir}')
|
||||
self.api.delete_model(repo_id)
|
||||
# wait for delete repo
|
||||
import time
|
||||
time.sleep(5)
|
||||
except Exception:
|
||||
# if repo not exists
|
||||
pass
|
||||
pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)
|
||||
|
||||
# load from saved
|
||||
pipe_new = pipeline('question-answering', model=repo_id)
|
||||
result_new = pipe_new(
|
||||
question=self.pipeline_qa_question,
|
||||
context=self.pipeline_qa_context)
|
||||
|
||||
assert result_new == result_ori
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user