mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
Support swift model inference for llm pipeline (#880)
* Support swift model inference for llm pipeline * Fix bug * Add 'swift' llm_framework and fix stream bug * Make the unittest run action pass
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
|
||||
@@ -112,13 +112,14 @@ def pipeline(task: str = None,
|
||||
third_party = kwargs.get(ThirdParty.KEY)
|
||||
if third_party is not None:
|
||||
kwargs.pop(ThirdParty.KEY)
|
||||
model = normalize_model_input(
|
||||
model,
|
||||
model_revision,
|
||||
third_party=third_party,
|
||||
ignore_file_pattern=ignore_file_pattern)
|
||||
if pipeline_name is None and kwargs.get('llm_first'):
|
||||
pipeline_name = llm_first_checker(model, model_revision)
|
||||
pipeline_name = llm_first_checker(model, model_revision, kwargs)
|
||||
else:
|
||||
model = normalize_model_input(
|
||||
model,
|
||||
model_revision,
|
||||
third_party=third_party,
|
||||
ignore_file_pattern=ignore_file_pattern)
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
if pipeline_name is None:
|
||||
# get default pipeline for this task
|
||||
@@ -132,8 +133,9 @@ def pipeline(task: str = None,
|
||||
model[0], revision=model_revision)
|
||||
register_plugins_repo(cfg.safe_get('plugins'))
|
||||
register_modelhub_repo(model, cfg.get('allow_remote', False))
|
||||
pipeline_name = llm_first_checker(model, model_revision) \
|
||||
if kwargs.get('llm_first') else None
|
||||
pipeline_name = llm_first_checker(
|
||||
model, model_revision,
|
||||
kwargs) if kwargs.get('llm_first') else None
|
||||
if pipeline_name is not None:
|
||||
pipeline_props = {'type': pipeline_name}
|
||||
else:
|
||||
@@ -157,9 +159,6 @@ def pipeline(task: str = None,
|
||||
pipeline_props['device'] = device
|
||||
cfg = ConfigDict(pipeline_props)
|
||||
|
||||
# support set llm_framework=None
|
||||
if pipeline_name == 'llm' and kwargs.get('llm_framework', '') == '':
|
||||
kwargs['llm_framework'] = 'vllm'
|
||||
clear_llm_info(kwargs)
|
||||
if kwargs:
|
||||
cfg.update(kwargs)
|
||||
@@ -209,14 +208,18 @@ def get_default_pipeline_info(task):
|
||||
return pipeline_name, default_model
|
||||
|
||||
|
||||
def llm_first_checker(model: Union[str, List[str], Model, List[Model]],
|
||||
revision: Optional[str]) -> Optional[str]:
|
||||
def llm_first_checker(model: Union[str, List[str], Model,
|
||||
List[Model]], revision: Optional[str],
|
||||
kwargs: Dict[str, Any]) -> Optional[str]:
|
||||
from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
|
||||
|
||||
if isinstance(model, list):
|
||||
model = model[0]
|
||||
if not isinstance(model, str):
|
||||
model = model.model_dir
|
||||
|
||||
if kwargs.get('llm_framework') == 'swift':
|
||||
return 'llm'
|
||||
model_type = ModelTypeHelper.get(
|
||||
model, revision, with_adapter=True, split='-', use_cache=True)
|
||||
if LLMAdapterRegistry.contains(model_type):
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from threading import Lock
|
||||
from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
@@ -18,7 +20,7 @@ from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.pipelines.util import is_model, is_official_hub_path
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import Frameworks, Invoke, ModelFile, Tasks
|
||||
from modelscope.utils.device import device_placement
|
||||
from modelscope.utils.device import create_device, device_placement
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.model_type_helper import ModelTypeHelper
|
||||
from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
|
||||
@@ -27,6 +29,8 @@ from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
SWIFT_MODEL_ID_MAPPING = {}
|
||||
|
||||
|
||||
class LLMAdapterRegistry:
|
||||
|
||||
@@ -79,6 +83,8 @@ class LLMAdapterRegistry:
|
||||
class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
|
||||
def initiate_single_model(self, model):
|
||||
from swift import Swift
|
||||
|
||||
if isinstance(model, str):
|
||||
logger.info(f'initiate model from {model}')
|
||||
if self._is_swift_model(model):
|
||||
@@ -86,7 +92,6 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
logger.warning(
|
||||
f'Cannot using swift with llm_framework, ignoring {self.llm_framework}.'
|
||||
)
|
||||
from swift import Swift
|
||||
|
||||
base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
|
||||
assert base_model is not None, 'Cannot get adapter_cfg.model_id_or_path from configuration.json file.'
|
||||
@@ -170,6 +175,10 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
self.device_map = 'cuda'
|
||||
self.torch_dtype = kwargs.pop('torch_dtype', None)
|
||||
self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
|
||||
|
||||
if llm_framework == 'swift':
|
||||
self._init_swift(kwargs['model'], kwargs.get('device', 'gpu'))
|
||||
return
|
||||
with self._temp_configuration_file(kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
@@ -195,6 +204,69 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
self.tokenizer = self._get_tokenizer(
|
||||
tokenizer_class) if tokenizer is None else tokenizer
|
||||
|
||||
def _init_swift(self, model_id, device) -> None:
    """Set this pipeline up to run a model through the swift framework.

    Loads the model and its chat template with swift's
    ``prepare_model_template``, wraps the model with
    ``add_stream_generate`` and then assigns the pipeline attributes
    (``model``, ``template``, ``tokenizer``, ``format_messages``,
    ``device`` ...) directly on ``self``.

    Args:
        model_id: Model id to run; it is looked up among the
            ``model_id_or_path`` values of swift's ``MODEL_MAPPING``.
        device: Device name stored as ``self.device_name`` and passed to
            ``create_device``.

    Raises:
        AssertionError: If ``model_id`` is not registered in swift.
    """
    from swift.llm import prepare_model_template
    from swift.llm.utils import MODEL_MAPPING, InferArguments

    # Build the reverse lookup (model id -> swift model type) only once
    # per process; later pipelines reuse the module-level cache.
    global SWIFT_MODEL_ID_MAPPING
    if not SWIFT_MODEL_ID_MAPPING:
        SWIFT_MODEL_ID_MAPPING = {
            v['model_id_or_path']: k
            for k, v in MODEL_MAPPING.items()
        }

    def format_messages(messages: Dict[str, List[Dict[str, str]]],
                        tokenizer: PreTrainedTokenizer,
                        **kwargs) -> Dict[str, torch.Tensor]:
        """Encode a ``{'messages': [...]}`` dict with the swift template.

        ``tokenizer`` and ``kwargs`` are accepted but not used here; the
        encoding is delegated to ``self.template``.
        """
        inputs, _ = self.template.encode(get_example(messages))
        inputs.pop('labels', None)
        token_len = None
        if 'input_ids' in inputs:
            # ``[None]`` prepends a dimension of size 1.
            input_ids = torch.tensor(inputs['input_ids'])[None]
            inputs['input_ids'] = input_ids
            token_len = input_ids.shape[1]
        if 'inputs_embeds' in inputs:
            inputs_embeds = inputs['inputs_embeds'][None]
            inputs['inputs_embeds'] = inputs_embeds
            token_len = inputs_embeds.shape[1]
        if token_len is None:
            # Without this guard the next line would fail with an opaque
            # UnboundLocalError.
            raise ValueError(
                'Swift template produced neither input_ids nor inputs_embeds!'
            )
        inputs['attention_mask'] = torch.ones(token_len)[None]
        if 'token_type_ids' in inputs:
            inputs['token_type_ids'] = torch.tensor(
                inputs['token_type_ids'])[None]
        return inputs

    def get_example(
            messages: Dict[str, List[Dict[str, str]]]) -> Dict[str, Any]:
        """Convert a messages dict into the system/prompt/history example.

        An optional leading ``system`` message is split off. The remaining
        messages must be odd in number: the last one becomes ``prompt`` and
        the ones before it are paired up, in order, as ``history``.
        """
        messages = messages['messages']
        # Raised explicitly (instead of ``assert``) so the checks survive
        # ``python -O`` while keeping the same exception type and text.
        if not len(messages) > 0:
            raise AssertionError('messages cannot be empty!')
        system = None
        if messages[0]['role'] == 'system':
            system = messages[0]['content']
            messages = messages[1:]
        if len(messages) % 2 != 1:
            raise AssertionError('Unsupported messages format!')
        contents = [message['content'] for message in messages]
        prompt = contents[-1]
        history = list(zip(contents[::2], contents[1::2]))
        return dict(system=system, prompt=prompt, history=history)

    if model_id not in SWIFT_MODEL_ID_MAPPING:
        raise AssertionError(
            'Swift framework does not support current model!')
    args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
    model, template = prepare_model_template(
        args, device_map=self.device_map)
    self.model = add_stream_generate(model)
    template.model = self.model
    self.template = template
    self.tokenizer = template.tokenizer
    self.format_messages = format_messages

    # NOTE(review): presumably these mirror the attributes the base
    # Pipeline constructor would set - confirm against Pipeline.__init__.
    self.has_multiple_models = False
    self.framework = Frameworks.torch
    self.device_name = device
    self.device = create_device(device)
    self._model_prepare = False
    self._model_prepare_lock = Lock()
    self._auto_collate = True
    self._compile = False
|
||||
|
||||
@contextmanager
|
||||
def _temp_configuration_file(self, kwargs: Dict[str, Any]):
|
||||
kwargs['model'] = model = self.initiate_single_model(kwargs['model'])
|
||||
@@ -217,7 +289,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
= isinstance(inputs, dict) and 'messages' in inputs
|
||||
tokens = self.preprocess(inputs, **preprocess_params)
|
||||
|
||||
if self.llm_framework is None:
|
||||
if self.llm_framework in (None, 'swift'):
|
||||
# pytorch model
|
||||
if hasattr(self.model, 'generate'):
|
||||
outputs = self.model.generate(**tokens, **forward_params)
|
||||
@@ -313,6 +385,9 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
|
||||
def postprocess(self, outputs, **kwargs):
|
||||
is_messages = kwargs.pop('is_messages')
|
||||
if not isinstance(outputs, str):
|
||||
shape_type = (torch.Tensor, np.ndarray)
|
||||
if isinstance(outputs, shape_type) and len(outputs.shape) > 1:
|
||||
outputs = outputs[0]
|
||||
response = self.tokenizer.decode(
|
||||
outputs, skip_special_tokens=True, **kwargs)
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.generation import GreedySearchDecoderOnlyOutput # noqa
|
||||
@@ -173,16 +175,24 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
|
||||
|
||||
@contextmanager
def _replace_generate(self, model: PreTrainedModel) -> Generator:
    """Temporarily swap the model's greedy-search and sample methods.

    While the context is active, the model's greedy-search and sample
    attributes are bound to ``self.stream_greedy_search`` and
    ``self.stream_sample``. The attribute names depend on the installed
    transformers version: ``_greedy_search`` / ``_sample`` from 4.39.0 on,
    ``greedy_search`` / ``sample`` before that.

    Args:
        model: The model whose generation methods are patched.

    Yields:
        None. The previous attributes are put back when the context
        exits, including when the body raises or the consumer stops
        iterating early.
    """
    if version.parse(transformers.__version__) >= version.parse('4.39.0'):
        greedy_search_name = '_greedy_search'
        sample_name = '_sample'
    else:
        greedy_search_name = 'greedy_search'
        sample_name = 'sample'

    origin_greedy_search = getattr(model, greedy_search_name)
    origin_sample = getattr(model, sample_name)
    setattr(model, greedy_search_name,
            types.MethodType(self.stream_greedy_search, model))
    setattr(model, sample_name,
            types.MethodType(self.stream_sample, model))
    try:
        yield
    finally:
        # Always undo the patch; otherwise an exception inside the
        # ``with`` body would leave the streaming methods on the model.
        setattr(model, greedy_search_name, origin_greedy_search)
        setattr(model, sample_name, origin_sample)
|
||||
|
||||
@staticmethod
|
||||
def _greedy_search(
|
||||
def stream_greedy_search(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
@@ -356,7 +366,7 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def _sample(
|
||||
def stream_sample(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
|
||||
@@ -44,7 +44,7 @@ class Human3DRenderTest(unittest.TestCase):
|
||||
human3d = pipeline(self.task, model=self.model_id)
|
||||
input = {
|
||||
'dataset_id': 'damo/3DHuman_synthetic_dataset',
|
||||
'case_id': '3f2a7538253e42a8',
|
||||
'case_id': '000039',
|
||||
'resolution': 1024,
|
||||
}
|
||||
output = human3d(input)
|
||||
|
||||
@@ -131,7 +131,13 @@ class LLMPipelineTest(unittest.TestCase):
|
||||
]
|
||||
}]
|
||||
}
|
||||
self.gen_cfg = {'do_sample': True, 'max_length': 512}
|
||||
self.messages_zh_one_round = {
|
||||
'messages': [{
|
||||
'role': 'user',
|
||||
'content': '你叫什么名字?'
|
||||
}]
|
||||
}
|
||||
self.gen_cfg = {'do_sample': True, 'max_new_tokens': 128}
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_chatglm2(self):
|
||||
@@ -365,6 +371,69 @@ class LLMPipelineTest(unittest.TestCase):
|
||||
**self.gen_cfg):
|
||||
print('messages: ', stream_output, end='\r')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_baichuan2_with_swift(self):
    """Run Baichuan2-13B-Chat via llm_framework='swift' on messages and a prompt."""
    swift_pipe = pipeline(
        task='chat',
        model='baichuan-inc/Baichuan2-13B-Chat',
        llm_framework='swift',
        llm_first=True)
    messages_reply = swift_pipe(self.messages_zh_with_system, **self.gen_cfg)
    print('messages: ', messages_reply)
    prompt_reply = swift_pipe(self.prompt_zh, **self.gen_cfg)
    print('prompt: ', prompt_reply)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_baichuan2_stream_gemerate(self):
    """Stream Baichuan2-13B-Chat output via llm_framework='swift'."""
    swift_pipe = pipeline(
        task='chat',
        model='baichuan-inc/Baichuan2-13B-Chat',
        llm_framework='swift',
        llm_first=True)
    partial_outputs = swift_pipe.stream_generate(self.messages_zh,
                                                 **self.gen_cfg)
    for partial in partial_outputs:
        # '\r' keeps successive partial outputs on the same console line.
        print('messages: ', partial, end='\r')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_yi_with_swift(self):
    """Run Yi-1.5-6B-Chat via llm_framework='swift' on messages and a prompt."""
    swift_pipe = pipeline(
        task='chat',
        model='01ai/Yi-1.5-6B-Chat',
        llm_framework='swift',
        llm_first=True)
    messages_reply = swift_pipe(self.messages_zh_with_system, **self.gen_cfg)
    print('messages: ', messages_reply)
    prompt_reply = swift_pipe(self.prompt_zh, **self.gen_cfg)
    print('prompt: ', prompt_reply)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_yi_stream_gemerate(self):
    """Stream Yi-1.5-6B-Chat output via llm_framework='swift'."""
    swift_pipe = pipeline(
        task='chat',
        model='01ai/Yi-1.5-6B-Chat',
        llm_framework='swift',
        llm_first=True)
    partial_outputs = swift_pipe.stream_generate(self.messages_zh,
                                                 **self.gen_cfg)
    for partial in partial_outputs:
        # '\r' keeps successive partial outputs on the same console line.
        print('messages: ', partial, end='\r')
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_internlm2_with_swift(self):
    """Run internlm2-1_8b via llm_framework='swift' on one-round messages and a prompt."""
    swift_pipe = pipeline(
        task='chat',
        model='Shanghai_AI_Laboratory/internlm2-1_8b',
        llm_framework='swift',
        llm_first=True)
    messages_reply = swift_pipe(self.messages_zh_one_round, **self.gen_cfg)
    print('messages: ', messages_reply)
    prompt_reply = swift_pipe(self.prompt_zh, **self.gen_cfg)
    print('prompt: ', prompt_reply)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_internlm2_stream_gemerate(self):
    """Stream internlm2-1_8b output via llm_framework='swift'."""
    swift_pipe = pipeline(
        task='chat',
        model='Shanghai_AI_Laboratory/internlm2-1_8b',
        llm_framework='swift',
        llm_first=True)
    partial_outputs = swift_pipe.stream_generate(self.messages_zh_one_round,
                                                 **self.gen_cfg)
    for partial in partial_outputs:
        # '\r' keeps successive partial outputs on the same console line.
        print('messages: ', partial, end='\r')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user