@@ -56,7 +56,7 @@ if TYPE_CHECKING:
                         AutoModelForPreTraining, AutoModelForTextEncoding,
                         AutoImageProcessor, BatchFeature, Qwen2VLForConditionalGeneration,
                         T5EncoderModel, Qwen2_5_VLForConditionalGeneration, LlamaModel,
-                        LlamaPreTrainedModel, LlamaForCausalLM)
+                        LlamaPreTrainedModel, LlamaForCausalLM, hf_pipeline)
 else:
     print(
         'transformer is not installed, please install it if you want to use related modules'
@@ -71,6 +71,10 @@ def check_local_model_is_latest(
             headers=snapshot_header,
             use_cookies=cookies,
         )
+        model_cache = None
+        # download via non-git method
+        if not os.path.exists(os.path.join(model_root_path, '.git')):
+            model_cache = ModelFileSystemCache(model_root_path)
         for model_file in model_files:
             if model_file['Type'] == 'tree':
                 continue
@@ -46,6 +46,7 @@ class GitCommandWrapper(metaclass=Singleton):
        git_env = os.environ.copy()
        git_env['GIT_TERMINAL_PROMPT'] = '0'
        command = [self.git_path, *args]
+       command = [item for item in command if item]
        response = subprocess.run(
            command,
            stdout=subprocess.PIPE,
@@ -10,6 +10,8 @@ from modelscope.utils.config import ConfigDict, check_config
 from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                        ThirdParty)
 from modelscope.utils.hub import read_config
+from modelscope.utils.import_utils import is_transformers_available
+from modelscope.utils.logger import get_logger
 from modelscope.utils.plugins import (register_modelhub_repo,
                                       register_plugins_repo)
 from modelscope.utils.registry import Registry, build_from_cfg
@@ -17,6 +19,7 @@ from .base import Pipeline
 from .util import is_official_hub_path

 PIPELINES = Registry('pipelines')
+logger = get_logger()


 def normalize_model_input(model,
@@ -72,7 +75,7 @@ def pipeline(task: str = None,
              config_file: str = None,
              pipeline_name: str = None,
              framework: str = None,
-             device: str = 'gpu',
+             device: str = None,
              model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
              ignore_file_pattern: List[str] = None,
              **kwargs) -> Pipeline:
@@ -109,6 +112,7 @@ def pipeline(task: str = None,
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')

+    pipeline_props = None
     if pipeline_name is None:
         # get default pipeline for this task
         if isinstance(model, str) \
@@ -157,8 +161,11 @@
            if pipeline_name:
                pipeline_props = {'type': pipeline_name}
            else:
-               check_config(cfg)
-               pipeline_props = cfg.pipeline
+               try:
+                   check_config(cfg)
+                   pipeline_props = cfg.pipeline
+               except AssertionError as e:
+                   logger.info(str(e))

    elif model is not None:
        # get pipeline info from Model object
@@ -166,9 +173,13 @@
        if not hasattr(first_model, 'pipeline'):
            # model is instantiated by user, we should parse config again
            cfg = read_config(first_model.model_dir)
-           check_config(cfg)
-           first_model.pipeline = cfg.pipeline
-       pipeline_props = first_model.pipeline
+           try:
+               check_config(cfg)
+               first_model.pipeline = cfg.pipeline
+           except AssertionError as e:
+               logger.info(str(e))
+       if first_model.__dict__.get('pipeline'):
+           pipeline_props = first_model.pipeline
    else:
        pipeline_name, default_model_repo = get_default_pipeline_info(task)
        model = normalize_model_input(default_model_repo, model_revision)
@@ -176,6 +187,23 @@
    else:
        pipeline_props = {'type': pipeline_name}

+   if not pipeline_props and is_transformers_available():
+       try:
+           from modelscope.utils.hf_util import hf_pipeline
+           return hf_pipeline(
+               task=task,
+               model=model,
+               framework=framework,
+               device=device,
+               **kwargs)
+       except Exception as e:
+           logger.error(
+               'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
+               ' but that also failed.')
+           raise e
+
+   if not device:
+       device = 'gpu'
    pipeline_props['model'] = model
    pipeline_props['device'] = device
    cfg = ConfigDict(pipeline_props)
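For context, a minimal usage sketch of the fallback added above (not part of the diff; the model id is the one used in the new tests at the bottom): a model that only ships a Hugging Face-style config now routes through hf_pipeline instead of failing the ModelScope pipeline lookup.

    # Illustrative sketch: pipeline() falls back to the transformers pipeline
    # when no ModelScope pipeline config (pipeline_props) can be resolved.
    from modelscope import pipeline

    qa = pipeline(
        'question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad')
    print(qa(question='What is extractive QA?',
             context='Extractive Question Answering extracts an answer from a given text.'))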
@@ -1,2 +1,3 @@
 from .auto_class import *
 from .patcher import patch_context, patch_hub, unpatch_hub
+from .pipeline_builder import hf_pipeline
@@ -27,7 +27,8 @@ def get_all_imported_modules():
     transformers_include_names = [
         'Auto.*', 'T5.*', 'BitsAndBytesConfig', 'GenerationConfig', 'Awq.*',
         'GPTQ.*', 'BatchFeature', 'Qwen.*', 'Llama.*', 'PretrainedConfig',
-        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast'
+        'PreTrainedTokenizer', 'PreTrainedModel', 'PreTrainedTokenizerFast',
+        'Pipeline'
     ]
     peft_include_names = ['.*PeftModel.*', '.*Config']
     diffusers_include_names = ['^(?!TF|Flax).*Pipeline$']
@@ -252,6 +253,44 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
                        model_dir, *model_args, **kwargs)
                return module_obj

+            def save_pretrained(
+                    self,
+                    save_directory: Union[str, os.PathLike],
+                    safe_serialization: bool = True,
+                    **kwargs,
+            ):
+                push_to_hub = kwargs.pop('push_to_hub', False)
+                if push_to_hub:
+                    from modelscope.hub.push_to_hub import push_to_hub
+                    from modelscope.hub.api import HubApi
+                    from modelscope.hub.repository import Repository
+
+                    token = kwargs.get('token')
+                    commit_message = kwargs.pop('commit_message', None)
+                    repo_name = kwargs.pop(
+                        'repo_id',
+                        save_directory.split(os.path.sep)[-1])
+
+                    api = HubApi()
+                    api.login(token)
+                    api.create_repo(repo_name)
+                    # clone the repo
+                    Repository(save_directory, repo_name)
+
+                super().save_pretrained(
+                    save_directory=save_directory,
+                    safe_serialization=safe_serialization,
+                    push_to_hub=False,
+                    **kwargs)
+
+                # Class members may be unpatched, so push_to_hub is done separately here
+                if push_to_hub:
+                    push_to_hub(
+                        repo_name=repo_name,
+                        output_dir=save_directory,
+                        commit_message=commit_message,
+                        token=token)
+
        if not hasattr(module_class, 'from_pretrained'):
            del ClassWrapper.from_pretrained
        else:
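A hedged sketch of what the patched save_pretrained enables (it mirrors the new test_pipeline_save_pretrained below; the repo id and token values are hypothetical placeholders):

    # Illustrative sketch: a transformers pipeline obtained through modelscope can
    # now push to the ModelScope hub directly from save_pretrained.
    from modelscope import pipeline

    pipe = pipeline('question-answering',
                    model='damotestx/distilbert-base-cased-distilled-squad')
    pipe.save_pretrained(
        './tmp_saved_pipeline',              # local output directory
        push_to_hub=True,                    # triggers create_repo + push_to_hub above
        repo_id='your-namespace/your-repo',  # hypothetical target repo
        token='YOUR_MODELSCOPE_TOKEN')       # hypothetical access token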
@@ -266,6 +305,9 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
        if not hasattr(module_class, 'get_config_dict'):
            del ClassWrapper.get_config_dict

+        if not hasattr(module_class, 'save_pretrained'):
+            del ClassWrapper.save_pretrained
+
        ClassWrapper.__name__ = module_class.__name__
        ClassWrapper.__qualname__ = module_class.__qualname__
        return ClassWrapper
@@ -289,12 +331,16 @@ def _patch_pretrained_class(all_imported_modules, wrap=False):
            has_from_pretrained = hasattr(var, 'from_pretrained')
            has_get_peft_type = hasattr(var, '_get_peft_type')
            has_get_config_dict = hasattr(var, 'get_config_dict')
+           has_save_pretrained = hasattr(var, 'save_pretrained')
        except:  # noqa
            continue

-       if wrap:
+       # save_pretrained is not a classmethod and cannot be overridden by replacing
+       # the class method. It requires replacing the class object method.
+       if wrap or ('pipeline' in name.lower() and has_save_pretrained):
            try:
-               if not has_from_pretrained and not has_get_config_dict and not has_get_peft_type:
+               if (not has_from_pretrained and not has_get_config_dict
+                       and not has_get_peft_type and not has_save_pretrained):
                    all_available_modules.append(var)
                else:
                    all_available_modules.append(
modelscope/utils/hf_util/pipeline_builder.py (new file)
@@ -0,0 +1,54 @@
+import os
+from typing import Optional, Union
+
+import torch
+from transformers import Pipeline as PipelineHF
+from transformers import PreTrainedModel, TFPreTrainedModel, pipeline
+from transformers.pipelines import check_task, get_task
+
+from modelscope.hub import snapshot_download
+from modelscope.utils.hf_util.patcher import _patch_pretrained_class, patch_hub
+
+
+def _get_hf_device(device):
+    if isinstance(device, str):
+        device_name = device.lower()
+        eles = device_name.split(':')
+        if eles[0] == 'gpu':
+            eles = ['cuda'] + eles[1:]
+        device = ''.join(eles)
+    return device
+
+
+def _get_hf_pipeline_class(task, model):
+    if not task:
+        task = get_task(model)
+    normalized_task, targeted_task, task_options = check_task(task)
+    pipeline_class = targeted_task['impl']
+    pipeline_class = _patch_pretrained_class([pipeline_class])[0]
+    return pipeline_class
+
+
+def hf_pipeline(
+    task: str = None,
+    model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
+    framework: Optional[str] = None,
+    device: Optional[Union[int, str, 'torch.device']] = None,
+    **kwargs,
+) -> PipelineHF:
+    if isinstance(model, str):
+        if not os.path.exists(model):
+            model = snapshot_download(model)
+
+    framework = 'pt' if framework == 'pytorch' else framework
+
+    device = _get_hf_device(device)
+    pipeline_class = _get_hf_pipeline_class(task, model)
+
+    return pipeline(
+        task=task,
+        model=model,
+        framework=framework,
+        device=device,
+        pipeline_class=pipeline_class,
+        **kwargs)
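A minimal sketch of calling the new helper directly (assumes transformers is installed; the model id is the one used in the tests below):

    # Illustrative sketch: hf_pipeline resolves the model id via snapshot_download
    # from the ModelScope hub and builds a patched transformers pipeline from it.
    from modelscope.utils.hf_util import hf_pipeline

    qa = hf_pipeline(
        task='question-answering',
        model='damotestx/distilbert-base-cased-distilled-squad',
        device='gpu')  # 'gpu' is normalized to 'cuda' by _get_hf_device
    print(qa(question='What is a good example of a question answering dataset?',
             context='The SQuAD dataset is entirely based on the question answering task.'))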
@@ -40,6 +40,14 @@ class HFUtilTest(unittest.TestCase):
         with open(self.test_file2, 'w') as f:
             f.write('{}')

+        self.pipeline_qa_context = r"""
+Extractive Question Answering is the task of extracting an answer from a text given a question. An example
+of a question answering dataset is the SQuAD dataset, which is entirely based on that task. If you would
+like to fine-tune a model on a SQuAD task, you may leverage the
+examples/pytorch/question-answering/run_squad.py script.
+"""
+        self.pipeline_qa_question = 'What is a good example of a question answering dataset?'
+
     def tearDown(self):
         logger.info('TearDown')
         shutil.rmtree(self.model_dir, ignore_errors=True)
@@ -235,6 +243,59 @@
             'Qwen/Qwen1.5-0.5B-Chat', trust_remote_code=True)
         model.push_to_hub(self.create_model_name)

+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_model_id(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        qa = pipeline('question-answering', model=model_id)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_auto_model(self):
+        from modelscope import pipeline, AutoModelForQuestionAnswering, AutoTokenizer
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+        model = AutoModelForQuestionAnswering.from_pretrained(model_id)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        qa = pipeline('question-answering', model=model, tokenizer=tokenizer)
+        assert qa(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
+    def test_pipeline_save_pretrained(self):
+        from modelscope import pipeline
+        model_id = 'damotestx/distilbert-base-cased-distilled-squad'
+
+        pipe_ori = pipeline('question-answering', model=model_id)
+
+        result_ori = pipe_ori(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        # save_pretrained
+        repo_id = self.create_model_name
+        save_dir = './tmp_test_hf_pipeline'
+        try:
+            os.system(f'rm -rf {save_dir}')
+            self.api.delete_model(repo_id)
+            # wait for delete repo
+            import time
+            time.sleep(5)
+        except Exception:
+            # if repo not exists
+            pass
+        pipe_ori.save_pretrained(save_dir, push_to_hub=True, repo_id=repo_id)
+
+        # load from saved
+        pipe_new = pipeline('question-answering', model=repo_id)
+        result_new = pipe_new(
+            question=self.pipeline_qa_question,
+            context=self.pipeline_qa_context)
+
+        assert result_new == result_ori
+

 if __name__ == '__main__':
     unittest.main()