Format llm pipeline (#1094)

* format llm pipeline

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
This commit is contained in:
Yingda Chen
2024-11-22 20:04:59 +08:00
committed by GitHub
parent ddc5fab311
commit 2b1c839918
13 changed files with 135 additions and 104 deletions

View File

@@ -205,6 +205,13 @@ class Pipeline(ABC):
kwargs['preprocess_params'] = preprocess_params
kwargs['forward_params'] = forward_params
kwargs['postprocess_params'] = postprocess_params
# for LLMPipeline, we shall support treating list of roles as a
# one single 'messages' input
if 'LLMPipeline' in type(self).__name__ and isinstance(input, list):
input = {'messages': input}
kwargs['is_message'] = True
if isinstance(input, list):
if batch_size is None:
output = []

View File

@@ -7,7 +7,7 @@ from modelscope.hub.snapshot_download import snapshot_download
from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
from modelscope.models.base import Model
from modelscope.utils.config import ConfigDict, check_config
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
ThirdParty)
from modelscope.utils.hub import read_config
from modelscope.utils.plugins import (register_modelhub_repo,
@@ -108,12 +108,23 @@ def pipeline(task: str = None,
"""
if task is None and pipeline_name is None:
raise ValueError('task or pipeline_name is required')
prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
if task is not None and task.lower() in [
Tasks.text_generation, Tasks.chat
]:
# if not specified, prefer llm pipeline for aforementioned tasks
if prefer_llm_pipeline is None:
prefer_llm_pipeline = True
# for llm pipeline, if llm_framework is not specified, default to swift instead
# TODO: port the swift infer based on transformer into ModelScope
if prefer_llm_pipeline and kwargs.get('llm_framework') is None:
kwargs['llm_framework'] = 'swift'
third_party = kwargs.get(ThirdParty.KEY)
if third_party is not None:
kwargs.pop(ThirdParty.KEY)
if pipeline_name is None and kwargs.get('llm_first'):
pipeline_name = llm_first_checker(model, model_revision, kwargs)
if pipeline_name is None and prefer_llm_pipeline:
pipeline_name = external_engine_for_llm_checker(
model, model_revision, kwargs)
else:
model = normalize_model_input(
model,
@@ -133,9 +144,9 @@ def pipeline(task: str = None,
model[0], revision=model_revision)
register_plugins_repo(cfg.safe_get('plugins'))
register_modelhub_repo(model, cfg.get('allow_remote', False))
pipeline_name = llm_first_checker(
pipeline_name = external_engine_for_llm_checker(
model, model_revision,
kwargs) if kwargs.get('llm_first') else None
kwargs) if prefer_llm_pipeline else None
if pipeline_name is not None:
pipeline_props = {'type': pipeline_name}
else:
@@ -208,9 +219,10 @@ def get_default_pipeline_info(task):
return pipeline_name, default_model
def llm_first_checker(model: Union[str, List[str], Model,
List[Model]], revision: Optional[str],
kwargs: Dict[str, Any]) -> Optional[str]:
def external_engine_for_llm_checker(model: Union[str, List[str], Model,
List[Model]],
revision: Optional[str],
kwargs: Dict[str, Any]) -> Optional[str]:
from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
if isinstance(model, list):
@@ -229,5 +241,5 @@ def llm_first_checker(model: Union[str, List[str], Model,
def clear_llm_info(kwargs: Dict):
from modelscope.utils.model_type_helper import ModelTypeHelper
kwargs.pop('llm_first', None)
kwargs.pop('external_engine_for_llm', None)
ModelTypeHelper.clear_cache()

View File

@@ -30,6 +30,7 @@ from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
logger = get_logger()
SWIFT_MODEL_ID_MAPPING = {}
SWIFT_FRAMEWORK = 'swift'
class LLMAdapterRegistry:
@@ -90,7 +91,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
if self._is_swift_model(model):
if self.llm_framework is not None:
logger.warning(
f'Cannot using swift with llm_framework, ignoring {self.llm_framework}.'
f'Cannot use swift with llm_framework, ignoring {self.llm_framework}.'
)
base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
@@ -155,7 +156,8 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
return False
self.cfg = Config.from_file(cfg_file)
return self.cfg.safe_get('adapter_cfg.tuner_backend') == 'swift'
return self.cfg.safe_get(
'adapter_cfg.tuner_backend') == SWIFT_FRAMEWORK
def _wrap_infer_framework(self, model_dir, framework='vllm'):
from modelscope.pipelines.accelerate.base import InferFramework
@@ -184,7 +186,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
self.torch_dtype = kwargs.pop('torch_dtype', None)
self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
if llm_framework == 'swift':
if llm_framework == SWIFT_FRAMEWORK:
self._init_swift(kwargs['model'], kwargs.get('device', 'gpu'))
return
with self._temp_configuration_file(kwargs):
@@ -254,9 +256,13 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
contents = [message['content'] for message in messages]
prompt = contents[-1]
history = list(zip(contents[::2], contents[1::2]))
return dict(system=system, prompt=prompt, history=history)
if self.llm_framework == SWIFT_FRAMEWORK:
return dict(system=system, query=prompt, history=history)
else:
return dict(system=system, prompt=prompt, history=history)
assert model_id in SWIFT_MODEL_ID_MAPPING, 'Swift framework does not support current model!'
assert model_id in SWIFT_MODEL_ID_MAPPING,\
f'Invalid model id {model_id} or Swift framework does not support this model.'
args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
model, template = prepare_model_template(
args, device_map=self.device_map)
@@ -297,7 +303,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
= isinstance(inputs, dict) and 'messages' in inputs
tokens = self.preprocess(inputs, **preprocess_params)
if self.llm_framework in (None, 'swift'):
if self.llm_framework in (None, SWIFT_FRAMEWORK):
# pytorch model
if hasattr(self.model, 'generate'):
outputs = self.model.generate(**tokens, **forward_params)
@@ -310,7 +316,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
tokens = [list(tokens['inputs'].flatten().numpy())]
outputs = self.model(tokens, **forward_params)[0]
if self.llm_framework is None:
if self.llm_framework in (None, SWIFT_FRAMEWORK):
# pytorch model
outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
response = self.postprocess(outputs, **postprocess_params)

View File

@@ -10,7 +10,7 @@ def add_server_args(parser: argparse.ArgumentParser):
parser.add_argument('--port', type=int, default=8000, help='Server port')
parser.add_argument('--debug', default='debug', help='Set debug level.')
parser.add_argument(
'--llm_first',
'--external_engine_for_llm',
type=bool,
default=True,
help='Use LLMPipeline first for llm models.')

View File

@@ -14,9 +14,9 @@ logger = get_logger()
def _startup_model(app: FastAPI) -> None:
logger.info('download model and create pipeline')
app.state.pipeline = create_pipeline(app.state.args.model_id,
app.state.args.revision,
app.state.args.llm_first)
app.state.pipeline = create_pipeline(
app.state.args.model_id, app.state.args.revision,
app.state.args.external_engine_for_llm)
info = {}
info['task_name'] = app.state.pipeline.group_key
info['schema'] = get_task_schemas(app.state.pipeline.group_key)

View File

@@ -43,7 +43,7 @@ class DeployChecker:
task=task,
model=model_id,
model_revision=model_revision,
llm_first=True)
external_engine_for_llm=True)
pipeline_info = get_pipeline_information_by_pipeline(ppl)
# call pipeline

View File

@@ -67,7 +67,9 @@ Todo:
"""
def create_pipeline(model_id: str, revision: str, llm_first: bool = True):
def create_pipeline(model_id: str,
revision: str,
external_engine_for_llm: bool = True):
model_configuration_file = model_file_download(
model_id=model_id,
file_path=ModelFile.CONFIGURATION,
@@ -77,7 +79,7 @@ def create_pipeline(model_id: str, revision: str, llm_first: bool = True):
task=cfg.task,
model=model_id,
model_revision=revision,
llm_first=llm_first)
external_engine_for_llm=external_engine_for_llm)
def get_class_user_attributes(cls):

View File

@@ -39,7 +39,7 @@ class ModelJsonTest:
task=task,
model=model_id,
model_revision=model_revision,
llm_first=True)
external_engine_for_llm=True)
pipeline_info = get_pipeline_information_by_pipeline(ppl)
# call pipeline

View File

@@ -17,26 +17,38 @@ class TextGPT3GenerationTest(unittest.TestCase):
self.model_dir_13B = snapshot_download(self.model_id_13B)
self.input = '好的'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip('deprecated, skipped')
def test_gpt3_1_3B(self):
pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
pipe = pipeline(
Tasks.text_generation,
model=self.model_id_1_3B,
external_engine_for_llm=False)
print(pipe(self.input))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip('deprecated, skipped')
def test_gpt3_1_3B_with_streaming(self):
pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
pipe = pipeline(
Tasks.text_generation,
model=self.model_id_1_3B,
external_engine_for_llm=False)
for output in pipe.stream_generate(self.input, max_length=64):
print(output, end='\r')
print()
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skip('deprecated, skipped')
def test_gpt3_2_7B(self):
pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B)
pipe = pipeline(
Tasks.text_generation,
model=self.model_id_2_7B,
external_engine_for_llm=False)
print(pipe(self.input))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
@unittest.skip('deprecated, skipped')
def test_gpt3_1_3B_with_args(self):
pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
pipe = pipeline(
Tasks.text_generation,
model=self.model_id_1_3B,
external_engine_for_llm=False)
print(pipe(self.input, top_p=0.9, temperature=0.9, max_length=32))
@unittest.skip('distributed gpt3 13B, skipped')
@@ -62,7 +74,10 @@ class TextGPT3GenerationTest(unittest.TestCase):
|_ mp_rank_06_model_states.pt
|_ mp_rank_07_model_states.pt
"""
pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B)
pipe = pipeline(
Tasks.text_generation,
model=self.model_dir_13B,
external_engine_for_llm=False)
print(pipe(self.input))

View File

@@ -24,6 +24,7 @@ class Llama2TextGenerationPipelineTest(unittest.TestCase):
input,
init_kwargs={},
run_kwargs={}):
init_kwargs['external_engine_for_llm'] = False
pipeline_ins = pipeline(task=Tasks.chat, model=model_id, **init_kwargs)
pipeline_ins._model_prepare = True
result = pipeline_ins(input, **run_kwargs)
@@ -36,6 +37,7 @@ class Llama2TextGenerationPipelineTest(unittest.TestCase):
self.llama2_model_id_7B_chat_ms,
self.llama2_input_chat_ch,
init_kwargs={
'external_engine_for_llm': False,
'device_map': 'auto',
'torch_dtype': torch.float16,
'model_revision': 'v1.0.5',

View File

@@ -141,29 +141,25 @@ class LLMPipelineTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_chatglm2(self):
pipe = pipeline(
task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True)
pipe = pipeline(task='chat', model='ZhipuAI/chatglm2-6b')
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_chatglm2int4(self):
pipe = pipeline(
task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True)
pipe = pipeline(task='chat', model='ZhipuAI/chatglm2-6b-int4')
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_chatglm232k(self):
pipe = pipeline(
task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True)
pipe = pipeline(task='chat', model='ZhipuAI/chatglm2-6b-32k')
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_chatglm3(self):
pipe = pipeline(
task='chat', model='ZhipuAI/chatglm3-6b', llm_first=True)
pipe = pipeline(task='chat', model='ZhipuAI/chatglm3-6b')
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -174,8 +170,7 @@ class LLMPipelineTest(unittest.TestCase):
model='modelscope/Llama-2-7b-ms',
torch_dtype=torch.float16,
device_map='auto',
ignore_file_pattern=[r'.+\.bin$'],
llm_first=True)
ignore_file_pattern=[r'.+\.bin$'])
print('messages: ', pipe(self.messages_en, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
@@ -187,8 +182,7 @@ class LLMPipelineTest(unittest.TestCase):
revision='v1.0.2',
torch_dtype=torch.float16,
device_map='auto',
ignore_file_pattern=[r'.+\.bin$'],
llm_first=True)
ignore_file_pattern=[r'.+\.bin$'])
print('messages: ', pipe(self.messages_en, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
@@ -199,8 +193,7 @@ class LLMPipelineTest(unittest.TestCase):
model='AI-ModelScope/CodeLlama-7b-Instruct-hf',
torch_dtype=torch.float16,
device_map='auto',
ignore_file_pattern=[r'.+\.bin$'],
llm_first=True)
ignore_file_pattern=[r'.+\.bin$'])
print('messages: ', pipe(self.messages_code, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_code, **self.gen_cfg))
@@ -210,8 +203,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/baichuan-7B',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -221,8 +213,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/Baichuan-13B-Base',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -232,8 +223,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/Baichuan-13B-Chat',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -243,8 +233,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/Baichuan2-7B-Base',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -254,8 +243,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/Baichuan2-7B-Chat',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -265,8 +253,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/Baichuan2-7B-Chat-4bits',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -276,8 +263,7 @@ class LLMPipelineTest(unittest.TestCase):
task='chat',
model='baichuan-inc/Baichuan2-13B-Chat-4bits',
device_map='auto',
torch_dtype=torch.float16,
llm_first=True)
torch_dtype=torch.float16)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -288,8 +274,7 @@ class LLMPipelineTest(unittest.TestCase):
model='AI-ModelScope/WizardLM-13B-V1.2',
device_map='auto',
torch_dtype=torch.float16,
format_messages='wizardlm',
llm_first=True)
format_messages='wizardlm')
print('messages: ', pipe(self.messages_en, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
@@ -300,8 +285,7 @@ class LLMPipelineTest(unittest.TestCase):
model='AI-ModelScope/WizardMath-7B-V1.0',
device_map='auto',
torch_dtype=torch.float16,
format_messages='wizardcode',
llm_first=True)
format_messages='wizardcode')
print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg))
@@ -312,8 +296,7 @@ class LLMPipelineTest(unittest.TestCase):
model='AI-ModelScope/WizardCoder-Python-13B-V1.0',
device_map='auto',
torch_dtype=torch.float16,
format_messages='wizardcode',
llm_first=True)
format_messages='wizardcode')
print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg))
@@ -329,20 +312,19 @@ class LLMPipelineTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_qwen(self):
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat')
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skip('Need optimum and auto-gptq')
def test_qwen_int4(self):
pipe = pipeline(
task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True)
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat-Int4')
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_qwen_vl(self):
pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True)
pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat')
print('messages: ', pipe(self.messages_mm, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -352,21 +334,20 @@ class LLMPipelineTest(unittest.TestCase):
model_type = ModelTypeHelper.get(model_id)
assert not LLMAdapterRegistry.contains(model_type)
pipe = pipeline(task='chat', model=model_id, llm_first=True)
pipe = pipeline(task='chat', model=model_id)
print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_qwen_stream_gemerate(self):
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
pipe = pipeline(task='chat', model='Qwen/Qwen-7B-Chat')
for stream_output in pipe.stream_generate(self.messages_zh_with_system,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_qwen1_5_stream_gemerate(self):
pipe = pipeline(
task='chat', model='qwen/Qwen1.5-1.8B-Chat', llm_first=True)
def test_qwen1_5_stream_generate(self):
pipe = pipeline(task='chat', model='Qwen/Qwen1.5-1.8B-Chat')
for stream_output in pipe.stream_generate(self.messages_zh_with_system,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@@ -376,8 +357,7 @@ class LLMPipelineTest(unittest.TestCase):
pipe = pipeline(
task='chat',
model='baichuan-inc/Baichuan2-13B-Chat',
llm_framework='swift',
llm_first=True)
llm_framework='swift')
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -386,8 +366,7 @@ class LLMPipelineTest(unittest.TestCase):
pipe = pipeline(
task='chat',
model='baichuan-inc/Baichuan2-13B-Chat',
llm_framework='swift',
llm_first=True)
llm_framework='swift')
for stream_output in pipe.stream_generate(self.messages_zh,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@@ -395,20 +374,14 @@ class LLMPipelineTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_yi_with_swift(self):
pipe = pipeline(
task='chat',
model='01ai/Yi-1.5-6B-Chat',
llm_framework='swift',
llm_first=True)
task='chat', model='01ai/Yi-1.5-6B-Chat', llm_framework='swift')
print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_yi_stream_gemerate(self):
pipe = pipeline(
task='chat',
model='01ai/Yi-1.5-6B-Chat',
llm_framework='swift',
llm_first=True)
task='chat', model='01ai/Yi-1.5-6B-Chat', llm_framework='swift')
for stream_output in pipe.stream_generate(self.messages_zh,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')
@@ -418,8 +391,7 @@ class LLMPipelineTest(unittest.TestCase):
pipe = pipeline(
task='chat',
model='Shanghai_AI_Laboratory/internlm2-1_8b',
llm_framework='swift',
llm_first=True)
llm_framework='swift')
print('messages: ', pipe(self.messages_zh_one_round, **self.gen_cfg))
print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
@@ -428,8 +400,7 @@ class LLMPipelineTest(unittest.TestCase):
pipe = pipeline(
task='chat',
model='Shanghai_AI_Laboratory/internlm2-1_8b',
llm_framework='swift',
llm_first=True)
llm_framework='swift')
for stream_output in pipe.stream_generate(self.messages_zh_one_round,
**self.gen_cfg):
print('messages: ', stream_output, end='\r')

View File

@@ -17,9 +17,7 @@ class MplugOwlMultimodalDialogueTest(unittest.TestCase):
model = Model.from_pretrained(
'damo/multi-modal_mplug_owl_multimodal-dialogue_7b')
pipeline_multimodal_dialogue = pipeline(
task=Tasks.multimodal_dialogue,
model=model,
)
task=Tasks.multimodal_dialogue, model=model)
image = 'data/resource/portrait_input.png'
system_prompt_1 = 'The following is a conversation between a curious human and AI assistant.'
system_prompt_2 = "The assistant gives helpful, detailed, and polite answers to the user's questions."

View File

@@ -56,7 +56,10 @@ class TextGenerationTest(unittest.TestCase):
first_sequence='sentence',
second_sequence=None)
pipeline_ins = pipeline(
task=Tasks.text_generation, model=model, preprocessor=preprocessor)
task=Tasks.text_generation,
model=model,
preprocessor=preprocessor,
external_engine_for_llm=False)
print(pipeline_ins(input))
def run_pipeline_with_model_id(self,
@@ -64,6 +67,7 @@ class TextGenerationTest(unittest.TestCase):
input,
init_kwargs={},
run_kwargs={}):
init_kwargs['external_engine_for_llm'] = False
pipeline_ins = pipeline(
task=Tasks.text_generation, model=model_id, **init_kwargs)
print(pipeline_ins(input, **run_kwargs))
@@ -73,12 +77,14 @@ class TextGenerationTest(unittest.TestCase):
input,
init_kwargs={},
run_kwargs={}):
init_kwargs['external_engine_for_llm'] = False
pipeline_ins = pipeline(
task=Tasks.text_generation, model=model_id, **init_kwargs)
# set stream inputs
assert isinstance(pipeline_ins, StreamingOutputMixin)
for output in pipeline_ins.stream_generate(input, **run_kwargs):
for output in pipeline_ins.stream_generate(
input, **run_kwargs, external_engine_for_llm=False):
print(output, end='\r')
print()
@@ -256,7 +262,10 @@ class TextGenerationTest(unittest.TestCase):
cache_path, first_sequence='sentence', second_sequence=None)
pipeline1 = TextGenerationPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.text_generation, model=model, preprocessor=preprocessor)
Tasks.text_generation,
model=model,
preprocessor=preprocessor,
external_engine_for_llm=False)
print(
f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
)
@@ -272,14 +281,18 @@ class TextGenerationTest(unittest.TestCase):
second_sequence=None)
pipeline1 = TextGenerationPipeline(model, preprocessor)
pipeline2 = pipeline(
Tasks.text_generation, model=model, preprocessor=preprocessor)
Tasks.text_generation,
model=model,
preprocessor=preprocessor,
external_engine_for_llm=False)
print(
f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
)
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.text_generation)
pipeline_ins = pipeline(
task=Tasks.text_generation, external_engine_for_llm=False)
print(
pipeline_ins(
[self.palm_input_zh, self.palm_input_zh, self.palm_input_zh],
@@ -288,13 +301,17 @@ class TextGenerationTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_bloom(self):
pipe = pipeline(
task=Tasks.text_generation, model='langboat/bloom-1b4-zh')
task=Tasks.text_generation,
model='langboat/bloom-1b4-zh',
external_engine_for_llm=False)
print(pipe('中国的首都是'))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_gpt_neo(self):
pipe = pipeline(
task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base')
task=Tasks.text_generation,
model='langboat/mengzi-gpt-neo-base',
external_engine_for_llm=False)
print(
pipe(
'我是',
@@ -308,7 +325,8 @@ class TextGenerationTest(unittest.TestCase):
def test_gpt2(self):
pipe = pipeline(
task=Tasks.text_generation,
model='damo/nlp_gpt2_text-generation_english-base')
model='damo/nlp_gpt2_text-generation_english-base',
external_engine_for_llm=False)
print(pipe('My name is Teven and I am'))
@unittest.skip('oom error for 7b model')