Add llm_first parameter to pipeline()

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14264249

* support llm_first parameter (usage sketch below)

* register_module(Tasks.text_generation) (registration sketch below)

* fix bug

* update format & fix out_base64 for int4

* pre-commit
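
Usage sketch for the new flag, taken from the test diff in this commit: instead of constructing LLMPipeline directly, the tests now call the pipeline() factory with llm_first=True, which makes the factory prefer the LLM pipeline implementation for the task. The generation kwargs and the messages payload below are illustrative placeholders (the tests use self.messages_* / self.prompt_* fixtures and **self.gen_cfg), and the exact message schema is an assumption.

from modelscope import pipeline

# New style introduced by this commit: go through the pipeline() factory
# and opt into the LLM pipeline with llm_first=True.
pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)

# Placeholder inputs; the real tests pass fixture messages/prompts and **gen_cfg.
messages = {'messages': [{'role': 'user', 'content': 'Hello, please introduce yourself.'}]}  # assumed schema
print('messages: ', pipe(messages, max_new_tokens=64))
print('prompt: ', pipe('Hello, please introduce yourself.', max_new_tokens=64))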
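The register_module(Tasks.text_generation) bullet suggests the LLM pipeline class is now also registered under the text_generation task, not only chat. A hedged sketch of what such a registration looks like with modelscope's pipeline registry; the module_name value and the Tasks.chat constant are assumptions, not taken from this diff.

from modelscope.pipelines.base import Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks

# Registering one class under both task keys lets
# pipeline(task='text_generation', ..., llm_first=True) resolve to it as well.
@PIPELINES.register_module(Tasks.chat, module_name='llm')             # assumed existing registration
@PIPELINES.register_module(Tasks.text_generation, module_name='llm')  # registration named in the commit bullet (sketch)
class LLMPipeline(Pipeline):
    ...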
Author:    hemu.zp
Date:      2023-10-13 14:04:04 +08:00
Committer: wenmeng.zwm
Parent:    582c3a5415
Commit:    4cad376298
10 changed files with 150 additions and 69 deletions (only the LLM pipeline test file is shown below)


@@ -3,6 +3,7 @@ import unittest
 import torch
+from modelscope import pipeline
 from modelscope.pipelines.nlp.llm_pipeline import LLMPipeline
 from modelscope.utils.test_utils import test_level
@@ -132,143 +133,172 @@ class LLMPipelineTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm2(self):
-        pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b', device_map='auto')
+        pipe = pipeline(
+            task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm2int4(self):
-        pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-int4')
+        pipe = pipeline(
+            task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm232k(self):
-        pipe = LLMPipeline(model='ZhipuAI/chatglm2-6b-32k', device_map='auto')
+        pipe = pipeline(
+            task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_llama2(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='modelscope/Llama-2-7b-ms',
             torch_dtype=torch.float16,
             device_map='auto',
-            ignore_file_pattern=[r'.+\.bin$'])
+            ignore_file_pattern=[r'.+\.bin$'],
+            llm_first=True)
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_llama2chat(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='modelscope/Llama-2-7b-chat-ms',
             revision='v1.0.2',
             torch_dtype=torch.float16,
             device_map='auto',
-            ignore_file_pattern=[r'.+\.bin$'])
+            ignore_file_pattern=[r'.+\.bin$'],
+            llm_first=True)
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_codellama(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='AI-ModelScope/CodeLlama-7b-Instruct-hf',
             torch_dtype=torch.float16,
             device_map='auto',
-            ignore_file_pattern=[r'.+\.bin$'])
+            ignore_file_pattern=[r'.+\.bin$'],
+            llm_first=True)
         print('messages: ', pipe(self.messages_code, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_code, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_baichuan_7b(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/baichuan-7B',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_baichuan_13b(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/Baichuan-13B-Base',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_baichuan_13bchat(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/Baichuan-13B-Chat',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_baichuan2_7b(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/Baichuan2-7B-Base',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_baichuan2_7bchat(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/Baichuan2-7B-Chat',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skip('Need bitsandbytes')
     def test_baichuan2_7bchat_int4(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/Baichuan2-7B-Chat-4bits',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skip('Need bitsandbytes')
     def test_baichuan2_13bchat_int4(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='baichuan-inc/Baichuan2-13B-Chat-4bits',
             device_map='auto',
-            torch_dtype=torch.float16)
+            torch_dtype=torch.float16,
+            llm_first=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_wizardlm_13b(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='AI-ModelScope/WizardLM-13B-V1.2',
             device_map='auto',
             torch_dtype=torch.float16,
-            format_messages='wizardlm')
+            format_messages='wizardlm',
+            llm_first=True)
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_wizardmath(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='AI-ModelScope/WizardMath-7B-V1.0',
             device_map='auto',
             torch_dtype=torch.float16,
-            format_messages='wizardcode')
+            format_messages='wizardcode',
+            llm_first=True)
         print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_wizardcode_13b(self):
-        pipe = LLMPipeline(
+        pipe = pipeline(
+            task='chat',
             model='AI-ModelScope/WizardCoder-Python-13B-V1.0',
             device_map='auto',
             torch_dtype=torch.float16,
-            format_messages='wizardcode')
+            format_messages='wizardcode',
+            llm_first=True)
         print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg))
@@ -284,19 +314,20 @@ class LLMPipelineTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen(self):
-        pipe = LLMPipeline(model='qwen/Qwen-7B-Chat', device_map='auto')
+        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skip('Need optimum and auto-gptq')
     def test_qwen_int4(self):
-        pipe = LLMPipeline(model='qwen/Qwen-7B-Chat-Int4', device_map='auto')
+        pipe = pipeline(
+            task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True)
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen_vl(self):
-        pipe = LLMPipeline(model='qwen/Qwen-VL-Chat', device_map='auto')
+        pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True)
         print('messages: ', pipe(self.messages_mm, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))