mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
format llm pipeline
This commit is contained in:
@@ -7,7 +7,7 @@ from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.config import ConfigDict, check_config
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
|
||||
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
|
||||
ThirdParty)
|
||||
from modelscope.utils.hub import read_config
|
||||
from modelscope.utils.plugins import (register_modelhub_repo,
|
||||
@@ -108,11 +108,21 @@ def pipeline(task: str = None,
|
||||
"""
|
||||
if task is None and pipeline_name is None:
|
||||
raise ValueError('task or pipeline_name is required')
|
||||
|
||||
prefer_llm_pipeline = kwargs.get('llm_first')
|
||||
if task is not None and task.lower() in [
|
||||
Tasks.text_generation, Tasks.text2text_generation, Tasks.chat
|
||||
]:
|
||||
# if not specified, prefer llm pipeline for aforementioned tasks
|
||||
if prefer_llm_pipeline is None:
|
||||
prefer_llm_pipeline = True
|
||||
# for the llm pipeline, default llm_framework to swift when it is not specified
|
||||
# TODO: port swift's transformer-based inference into ModelScope
|
||||
if prefer_llm_pipeline and kwargs.get('llm_framework') is None:
|
||||
kwargs['llm_framework'] = 'swift'
|
||||
third_party = kwargs.get(ThirdParty.KEY)
|
||||
if third_party is not None:
|
||||
kwargs.pop(ThirdParty.KEY)
|
||||
if pipeline_name is None and kwargs.get('llm_first'):
|
||||
if pipeline_name is None and prefer_llm_pipeline:
|
||||
pipeline_name = llm_first_checker(model, model_revision, kwargs)
|
||||
else:
|
||||
model = normalize_model_input(
|
||||
@@ -135,7 +145,7 @@ def pipeline(task: str = None,
|
||||
register_modelhub_repo(model, cfg.get('allow_remote', False))
|
||||
pipeline_name = llm_first_checker(
|
||||
model, model_revision,
|
||||
kwargs) if kwargs.get('llm_first') else None
|
||||
kwargs) if prefer_llm_pipeline else None
|
||||
if pipeline_name is not None:
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
else:
|
||||
|
||||
@@ -90,7 +90,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
if self._is_swift_model(model):
|
||||
if self.llm_framework is not None:
|
||||
logger.warning(
|
||||
f'Cannot using swift with llm_framework, ignoring {self.llm_framework}.'
|
||||
f'Cannot swift with llm_framework, ignoring {self.llm_framework}.'
|
||||
)
|
||||
|
||||
base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
|
||||
@@ -223,9 +223,13 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
for k, v in MODEL_MAPPING.items()
|
||||
}
|
||||
|
||||
def format_messages(messages: Dict[str, List[Dict[str, str]]],
|
||||
def format_messages(messages: Union[List, Dict[str, List[Dict[str,
|
||||
str]]]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
**kwargs) -> Dict[str, torch.Tensor]:
|
||||
# for compatibility, also accept a list input and wrap it into a Dict
|
||||
if isinstance(messages, list):
|
||||
messages = {'messages': messages}
|
||||
inputs, _ = self.template.encode(get_example(messages))
|
||||
inputs.pop('labels', None)
|
||||
if 'input_ids' in inputs:
|
||||
@@ -256,7 +260,8 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
history = list(zip(contents[::2], contents[1::2]))
|
||||
return dict(system=system, prompt=prompt, history=history)
|
||||
|
||||
assert model_id in SWIFT_MODEL_ID_MAPPING, 'Swift framework does not support current model!'
|
||||
assert model_id in SWIFT_MODEL_ID_MAPPING,\
|
||||
f'Invalid model id {model_id} or Swift framework does not this model.'
|
||||
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
|
||||
model, template = prepare_model_template(
|
||||
args, device_map=self.device_map)
|
||||
|
||||
Reference in New Issue
Block a user