Support swift model inference for llm pipeline (#880)

* Support swift model inference for llm pipeline

* fix bug

* Add 'swfit' llm_framework and fix stream bug

* For pass the unittest run action
This commit is contained in:
Firmament-cyou
2024-06-30 19:43:13 +08:00
committed by GitHub
parent f65f45959d
commit 8cdbeec3f6
5 changed files with 184 additions and 27 deletions

View File

@@ -1,7 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
@@ -112,13 +112,14 @@ def pipeline(task: str = None,
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
if pipeline_name is None and kwargs.get('llm_first'):
pipeline_name = llm_first_checker(model, model_revision)
pipeline_name = llm_first_checker(model, model_revision, kwargs)
else:
model = normalize_model_input(
model,
model_revision,
third_party=third_party,
ignore_file_pattern=ignore_file_pattern)
pipeline_props = {'type': pipeline_name}
if pipeline_name is None:
# get default pipeline for this task
@@ -132,8 +133,9 @@ def pipeline(task: str = None,
model[0], revision=model_revision)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model, cfg.get('allow_remote', False))
pipeline_name = llm_first_checker(model, model_revision) \
if kwargs.get('llm_first') else None
pipeline_name = llm_first_checker(
model, model_revision,
kwargs) if kwargs.get('llm_first') else None
if pipeline_name is not None:
pipeline_props = {'type': pipeline_name}
else:
@@ -157,9 +159,6 @@ def pipeline(task: str = None,
pipeline_props['device'] = device
cfg = ConfigDict(pipeline_props)
# support set llm_framework=None
if pipeline_name == 'llm' and kwargs.get('llm_framework', '') == '':
kwargs['llm_framework'] = 'vllm'
clear_llm_info(kwargs)
if kwargs:
cfg.update(kwargs)
@@ -209,14 +208,18 @@ def get_default_pipeline_info(task):
return pipeline_name, default_model
def llm_first_checker(model: Union[str, List[str], Model, List[Model]],
revision: Optional[str]) -> Optional[str]:
def llm_first_checker(model: Union[str, List[str], Model,
List[Model]], revision: Optional[str],
kwargs: Dict[str, Any]) -> Optional[str]:
from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
if isinstance(model, list):
model = model[0]
if not isinstance(model, str):
model = model.model_dir
if kwargs.get('llm_framework') == 'swift':
return 'llm'
model_type = ModelTypeHelper.get(
model, revision, with_adapter=True, split='-', use_cache=True)
if LLMAdapterRegistry.contains(model_type):

View File

@@ -1,9 +1,11 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from contextlib import contextmanager
from threading import Lock
from typing import Any, Callable, Dict, Generator, Iterator, List, Tuple, Union
import json
import numpy as np
import torch
from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -18,7 +20,7 @@ from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.util import is_model, is_official_hub_path
from modelscope.utils.config import Config
from modelscope.utils.constant import Frameworks, Invoke, ModelFile, Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.device import create_device, device_placement
from modelscope.utils.logger import get_logger
from modelscope.utils.model_type_helper import ModelTypeHelper
from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
@@ -27,6 +29,8 @@ from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
logger = get_logger()
SWIFT_MODEL_ID_MAPPING = {}
class LLMAdapterRegistry:
@@ -79,6 +83,8 @@ class LLMAdapterRegistry:
class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
def initiate_single_model(self, model):
from swift import Swift
if isinstance(model, str):
logger.info(f'initiate model from {model}')
if self._is_swift_model(model):
@@ -86,7 +92,6 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
logger.warning(
f'Cannot using swift with llm_framework, ignoring {self.llm_framework}.'
)
from swift import Swift
base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
assert base_model is not None, 'Cannot get adapter_cfg.model_id_or_path from configuration.json file.'
@@ -170,6 +175,10 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
self.device_map = 'cuda'
self.torch_dtype = kwargs.pop('torch_dtype', None)
self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
if llm_framework == 'swift':
self._init_swift(kwargs['model'], kwargs.get('device', 'gpu'))
return
with self._temp_configuration_file(kwargs):
super().__init__(*args, **kwargs)
if isinstance(self.model, PreTrainedModel):
@@ -195,6 +204,69 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
self.tokenizer = self._get_tokenizer(
tokenizer_class) if tokenizer is None else tokenizer
def _init_swift(self, model_id, device) -> None:
from swift.llm import prepare_model_template
from swift.llm.utils import MODEL_MAPPING, InferArguments
global SWIFT_MODEL_ID_MAPPING
if not SWIFT_MODEL_ID_MAPPING:
SWIFT_MODEL_ID_MAPPING = {
v['model_id_or_path']: k
for k, v in MODEL_MAPPING.items()
}
def format_messages(messages: Dict[str, List[Dict[str, str]]],
tokenizer: PreTrainedTokenizer,
**kwargs) -> Dict[str, torch.Tensor]:
inputs, _ = self.template.encode(get_example(messages))
inputs.pop('labels', None)
if 'input_ids' in inputs:
input_ids = torch.tensor(inputs['input_ids'])[None]
inputs['input_ids'] = input_ids
token_len = input_ids.shape[1]
if 'inputs_embeds' in inputs:
inputs_embeds = inputs['inputs_embeds'][None]
inputs['inputs_embeds'] = inputs_embeds
token_len = inputs_embeds.shape[1]
inputs['attention_mask'] = torch.ones(token_len)[None]
if 'token_type_ids' in inputs:
inputs['token_type_ids'] = torch.tensor(
inputs['token_type_ids'])[None]
return inputs
def get_example(
messages: Dict[str, List[Dict[str, str]]]) -> Dict[str, str]:
messages = messages['messages']
assert len(messages) > 0, 'messages cannot be empty!'
system = None
if messages[0]['role'] == 'system':
system = messages[0]['content']
messages = messages[1:]
assert len(messages) % 2 == 1, 'Unsupported messages format!'
contents = [message['content'] for message in messages]
prompt = contents[-1]
history = list(zip(contents[::2], contents[1::2]))
return dict(system=system, prompt=prompt, history=history)
assert model_id in SWIFT_MODEL_ID_MAPPING, 'Swift framework does not support current model!'
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
model, template = prepare_model_template(
args, device_map=self.device_map)
self.model = add_stream_generate(model)
template.model = self.model
self.template = template
self.tokenizer = template.tokenizer
self.format_messages = format_messages
self.has_multiple_models = False
self.framework = Frameworks.torch
self.device_name = device
self.device = create_device(device)
self._model_prepare = False
self._model_prepare_lock = Lock()
self._auto_collate = True
self._compile = False
@contextmanager
def _temp_configuration_file(self, kwargs: Dict[str, Any]):
kwargs['model'] = model = self.initiate_single_model(kwargs['model'])
@@ -217,7 +289,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
= isinstance(inputs, dict) and 'messages' in inputs
tokens = self.preprocess(inputs, **preprocess_params)
if self.llm_framework is None:
if self.llm_framework in (None, 'swift'):
# pytorch model
if hasattr(self.model, 'generate'):
outputs = self.model.generate(**tokens, **forward_params)
@@ -313,6 +385,9 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
def postprocess(self, outputs, **kwargs):
is_messages = kwargs.pop('is_messages')
if not isinstance(outputs, str):
shape_type = (torch.Tensor, np.ndarray)
if isinstance(outputs, shape_type) and len(outputs.shape) > 1:
outputs = outputs[0]
response = self.tokenizer.decode(
outputs, skip_special_tokens=True, **kwargs)
else:

View File

@@ -6,6 +6,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
import torch
import torch.distributed as dist
import transformers
from packaging import version
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GreedySearchDecoderOnlyOutput # noqa
@@ -173,16 +175,24 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
@contextmanager
def _replace_generate(self, model: PreTrainedModel) -> Generator:
greedy_search = model.greedy_search
sample = model.sample
model.greedy_search = types.MethodType(self._greedy_search, model)
model.sample = types.MethodType(self._sample, model)
if version.parse(transformers.__version__) >= version.parse('4.39.0'):
greedy_search_name = '_greedy_search'
sample_name = '_sample'
else:
greedy_search_name = 'greedy_search'
sample_name = 'sample'
origin_greedy_search = getattr(model, greedy_search_name)
origin_sample = getattr(model, sample_name)
setattr(model, greedy_search_name,
types.MethodType(self.stream_greedy_search, model))
setattr(model, sample_name, types.MethodType(self.stream_sample,
model))
yield
model.greedy_search = greedy_search
model.sample = sample
setattr(model, greedy_search_name, origin_greedy_search)
setattr(model, sample_name, origin_sample)
@staticmethod
def _greedy_search(
def stream_greedy_search(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,
@@ -356,7 +366,7 @@ class PretrainedModelStreamingOutputMixin(StreamingOutputMixin):
break
@staticmethod
def _sample(
def stream_sample(
self,
input_ids: torch.LongTensor,
logits_processor: Optional[LogitsProcessorList] = None,

View File

@@ -44,7 +44,7 @@ class Human3DRenderTest(unittest.TestCase):
human3d = pipeline(self.task, model=self.model_id)
input = {
'dataset_id': 'damo/3DHuman_synthetic_dataset',
'case_id': '3f2a7538253e42a8',
'case_id': '000039',
'resolution': 1024,
}
output = human3d(input)

View File

@@ -131,7 +131,13 @@ class LLMPipelineTest(unittest.TestCase):
]
}]
}
self.gen_cfg = {'do_sample': True, 'max_length': 512}
self.messages_zh_one_round = {
'messages': [{
'role': 'user',
'content': '你叫什么名字?'
}]
}
self.gen_cfg = {'do_sample': True, 'max_new_tokens': 128}
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_chatglm2(self):
@@ -365,6 +371,69 @@ class LLMPipelineTest(unittest.TestCase):
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_baichuan2_with_swift(self):
pipe = pipeline(
task='chat',
model='baichuan-inc/Baichuan2-13B-Chat',
llm_framework='swift',
llm_first=True)
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_baichuan2_stream_gemerate(self):
pipe = pipeline(
task='chat',
model='baichuan-inc/Baichuan2-13B-Chat',
llm_framework='swift',
llm_first=True)
for stream_output in pipe.stream_generate(self.messages_zh,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_yi_with_swift(self):
pipe = pipeline(
task='chat',
model='01ai/Yi-1.5-6B-Chat',
llm_framework='swift',
llm_first=True)
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_yi_stream_gemerate(self):
pipe = pipeline(
task='chat',
model='01ai/Yi-1.5-6B-Chat',
llm_framework='swift',
llm_first=True)
for stream_output in pipe.stream_generate(self.messages_zh,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_internlm2_with_swift(self):
pipe = pipeline(
task='chat',
model='Shanghai_AI_Laboratory/internlm2-1_8b',
llm_framework='swift',
llm_first=True)
print('messages: ', pipe(self.messages_zh_one_round, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_internlm2_stream_gemerate(self):
pipe = pipeline(
task='chat',
model='Shanghai_AI_Laboratory/internlm2-1_8b',
llm_framework='swift',
llm_first=True)
for stream_output in pipe.stream_generate(self.messages_zh_one_round,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
if __name__ == '__main__':
unittest.main()