Mirror of https://github.com/modelscope/modelscope.git, synced 2026-02-25 04:30:48 +01:00
refactor: rename the llm_first pipeline option to external_engine_for_llm
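The change is a keyword rename: every place that passed llm_first to pipeline() (or stored it in kwargs) now uses external_engine_for_llm, and the helper llm_first_checker becomes external_engine_for_llm_checker. A minimal usage sketch of the renamed argument follows; the model id is taken from the tests in this commit, and the input dict format is assumed to match the messages fixtures used there:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Route a chat model through the external LLM engine path, previously
# requested with llm_first=True. The input dict follows the messages
# format used by the chat tests; adjust it to your model's expectations.
pipe = pipeline(
    task=Tasks.chat,
    model='qwen/Qwen-7B-Chat',
    external_engine_for_llm=True)
print(pipe({'messages': [{'role': 'user', 'content': 'Hello!'}]}))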
@@ -108,7 +108,7 @@ def pipeline(task: str = None,
     """
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')
-    prefer_llm_pipeline = kwargs.get('llm_first')
+    prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
     if task is not None and task.lower() in [
             Tasks.text_generation, Tasks.chat
     ]:
@@ -123,7 +123,8 @@ def pipeline(task: str = None,
         if third_party is not None:
             kwargs.pop(ThirdParty.KEY)
         if pipeline_name is None and prefer_llm_pipeline:
-            pipeline_name = llm_first_checker(model, model_revision, kwargs)
+            pipeline_name = external_engine_for_llm_checker(
+                model, model_revision, kwargs)
     else:
         model = normalize_model_input(
             model,
@@ -143,7 +144,7 @@ def pipeline(task: str = None,
             model[0], revision=model_revision)
         register_plugins_repo(cfg.safe_get('plugins'))
         register_modelhub_repo(model, cfg.get('allow_remote', False))
-        pipeline_name = llm_first_checker(
+        pipeline_name = external_engine_for_llm_checker(
             model, model_revision,
             kwargs) if prefer_llm_pipeline else None
         if pipeline_name is not None:
@@ -218,9 +219,10 @@ def get_default_pipeline_info(task):
     return pipeline_name, default_model


-def llm_first_checker(model: Union[str, List[str], Model,
-                                   List[Model]], revision: Optional[str],
-                      kwargs: Dict[str, Any]) -> Optional[str]:
+def external_engine_for_llm_checker(model: Union[str, List[str], Model,
+                                                 List[Model]],
+                                    revision: Optional[str],
+                                    kwargs: Dict[str, Any]) -> Optional[str]:
     from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry

     if isinstance(model, list):
@@ -239,5 +241,5 @@ def llm_first_checker(model: Union[str, List[str], Model,
 def clear_llm_info(kwargs: Dict):
     from modelscope.utils.model_type_helper import ModelTypeHelper

-    kwargs.pop('llm_first', None)
+    kwargs.pop('external_engine_for_llm', None)
     ModelTypeHelper.clear_cache()
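After this hunk, clear_llm_info() strips only the new keyword, so an old-style llm_first argument is no longer consumed on this path. Where downstream callers cannot be updated immediately, a small translation shim at the call site keeps them working; the helper below is purely illustrative and is not part of this commit:

from modelscope.pipelines import pipeline


def pipeline_compat(task, model, **kwargs):
    # Hypothetical shim, not part of this commit: translate the pre-rename
    # keyword into the new one so old call sites keep working.
    if 'llm_first' in kwargs and 'external_engine_for_llm' not in kwargs:
        kwargs['external_engine_for_llm'] = kwargs.pop('llm_first')
    return pipeline(task=task, model=model, **kwargs)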
@@ -10,7 +10,7 @@ def add_server_args(parser: argparse.ArgumentParser):
     parser.add_argument('--port', type=int, default=8000, help='Server port')
     parser.add_argument('--debug', default='debug', help='Set debug level.')
     parser.add_argument(
-        '--llm_first',
+        '--external_engine_for_llm',
         type=bool,
         default=True,
         help='Use LLMPipeline first for llm models.')
@@ -14,9 +14,9 @@ logger = get_logger()

 def _startup_model(app: FastAPI) -> None:
     logger.info('download model and create pipeline')
-    app.state.pipeline = create_pipeline(app.state.args.model_id,
-                                         app.state.args.revision,
-                                         app.state.args.llm_first)
+    app.state.pipeline = create_pipeline(
+        app.state.args.model_id, app.state.args.revision,
+        app.state.args.external_engine_for_llm)
     info = {}
     info['task_name'] = app.state.pipeline.group_key
     info['schema'] = get_task_schemas(app.state.pipeline.group_key)
@@ -43,7 +43,7 @@ class DeployChecker:
             task=task,
             model=model_id,
             model_revision=model_revision,
-            llm_first=True)
+            external_engine_for_llm=True)
         pipeline_info = get_pipeline_information_by_pipeline(ppl)

         # call pipeline
@@ -67,7 +67,9 @@ Todo:
     """


-def create_pipeline(model_id: str, revision: str, llm_first: bool = True):
+def create_pipeline(model_id: str,
+                    revision: str,
+                    external_engine_for_llm: bool = True):
     model_configuration_file = model_file_download(
         model_id=model_id,
         file_path=ModelFile.CONFIGURATION,
@@ -77,7 +79,7 @@ def create_pipeline(model_id: str, revision: str, llm_first: bool = True):
         task=cfg.task,
         model=model_id,
         model_revision=revision,
-        llm_first=llm_first)
+        external_engine_for_llm=external_engine_for_llm)


 def get_class_user_attributes(cls):
@@ -39,7 +39,7 @@ class ModelJsonTest:
             task=task,
             model=model_id,
             model_revision=model_revision,
-            llm_first=True)
+            external_engine_for_llm=True)
         pipeline_info = get_pipeline_information_by_pipeline(ppl)

         # call pipeline
@@ -20,13 +20,17 @@ class TextGPT3GenerationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_gpt3_1_3B(self):
         pipe = pipeline(
-            Tasks.text_generation, model=self.model_id_1_3B, llm_first=False)
+            Tasks.text_generation,
+            model=self.model_id_1_3B,
+            external_engine_for_llm=False)
         print(pipe(self.input))

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_gpt3_1_3B_with_streaming(self):
         pipe = pipeline(
-            Tasks.text_generation, model=self.model_id_1_3B, llm_first=False)
+            Tasks.text_generation,
+            model=self.model_id_1_3B,
+            external_engine_for_llm=False)
         for output in pipe.stream_generate(self.input, max_length=64):
             print(output, end='\r')
         print()
@@ -34,13 +38,17 @@ class TextGPT3GenerationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_gpt3_2_7B(self):
         pipe = pipeline(
-            Tasks.text_generation, model=self.model_id_2_7B, llm_first=False)
+            Tasks.text_generation,
+            model=self.model_id_2_7B,
+            external_engine_for_llm=False)
         print(pipe(self.input))

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_gpt3_1_3B_with_args(self):
         pipe = pipeline(
-            Tasks.text_generation, model=self.model_id_1_3B, llm_first=False)
+            Tasks.text_generation,
+            model=self.model_id_1_3B,
+            external_engine_for_llm=False)
         print(pipe(self.input, top_p=0.9, temperature=0.9, max_length=32))

     @unittest.skip('distributed gpt3 13B, skipped')
@@ -67,7 +75,9 @@ class TextGPT3GenerationTest(unittest.TestCase):
            |_ mp_rank_07_model_states.pt
         """
         pipe = pipeline(
-            Tasks.text_generation, model=self.model_dir_13B, llm_first=False)
+            Tasks.text_generation,
+            model=self.model_dir_13B,
+            external_engine_for_llm=False)
         print(pipe(self.input))

@@ -24,6 +24,7 @@ class Llama2TextGenerationPipelineTest(unittest.TestCase):
                                    input,
                                    init_kwargs={},
                                    run_kwargs={}):
+        init_kwargs['external_engine_for_llm'] = False
         pipeline_ins = pipeline(task=Tasks.chat, model=model_id, **init_kwargs)
         pipeline_ins._model_prepare = True
         result = pipeline_ins(input, **run_kwargs)
@@ -36,6 +37,7 @@ class Llama2TextGenerationPipelineTest(unittest.TestCase):
             self.llama2_model_id_7B_chat_ms,
             self.llama2_input_chat_ch,
             init_kwargs={
+                'external_engine_for_llm': False,
                 'device_map': 'auto',
                 'torch_dtype': torch.float16,
                 'model_revision': 'v1.0.5',
@@ -142,28 +142,36 @@ class LLMPipelineTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm2(self):
         pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True)
+            task='chat',
+            model='ZhipuAI/chatglm2-6b',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm2int4(self):
         pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True)
+            task='chat',
+            model='ZhipuAI/chatglm2-6b-int4',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm232k(self):
         pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True)
+            task='chat',
+            model='ZhipuAI/chatglm2-6b-32k',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm3(self):
         pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm3-6b', llm_first=True)
+            task='chat',
+            model='ZhipuAI/chatglm3-6b',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -175,7 +183,7 @@ class LLMPipelineTest(unittest.TestCase):
             torch_dtype=torch.float16,
             device_map='auto',
             ignore_file_pattern=[r'.+\.bin$'],
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))

@@ -188,7 +196,7 @@ class LLMPipelineTest(unittest.TestCase):
             torch_dtype=torch.float16,
             device_map='auto',
             ignore_file_pattern=[r'.+\.bin$'],
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))

@@ -200,7 +208,7 @@ class LLMPipelineTest(unittest.TestCase):
             torch_dtype=torch.float16,
             device_map='auto',
             ignore_file_pattern=[r'.+\.bin$'],
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_code, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_code, **self.gen_cfg))

@@ -211,7 +219,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/baichuan-7B',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -222,7 +230,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/Baichuan-13B-Base',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -233,7 +241,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/Baichuan-13B-Chat',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -244,7 +252,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/Baichuan2-7B-Base',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -255,7 +263,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/Baichuan2-7B-Chat',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -266,7 +274,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/Baichuan2-7B-Chat-4bits',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -277,7 +285,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='baichuan-inc/Baichuan2-13B-Chat-4bits',
             device_map='auto',
             torch_dtype=torch.float16,
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -289,7 +297,7 @@ class LLMPipelineTest(unittest.TestCase):
             device_map='auto',
             torch_dtype=torch.float16,
             format_messages='wizardlm',
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))

@@ -301,7 +309,7 @@ class LLMPipelineTest(unittest.TestCase):
             device_map='auto',
             torch_dtype=torch.float16,
             format_messages='wizardcode',
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg))

@@ -313,7 +321,7 @@ class LLMPipelineTest(unittest.TestCase):
             device_map='auto',
             torch_dtype=torch.float16,
             format_messages='wizardcode',
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg))

@@ -329,20 +337,28 @@ class LLMPipelineTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen(self):
-        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
+        pipe = pipeline(
+            task='chat',
+            model='qwen/Qwen-7B-Chat',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skip('Need optimum and auto-gptq')
     def test_qwen_int4(self):
         pipe = pipeline(
-            task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True)
+            task='chat',
+            model='qwen/Qwen-7B-Chat-Int4',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen_vl(self):
-        pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True)
+        pipe = pipeline(
+            task='chat',
+            model='qwen/Qwen-VL-Chat',
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_mm, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -352,13 +368,17 @@ class LLMPipelineTest(unittest.TestCase):
         model_type = ModelTypeHelper.get(model_id)
         assert not LLMAdapterRegistry.contains(model_type)

-        pipe = pipeline(task='chat', model=model_id, llm_first=True)
+        pipe = pipeline(
+            task='chat', model=model_id, external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen_stream_gemerate(self):
-        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
+        pipe = pipeline(
+            task='chat',
+            model='qwen/Qwen-7B-Chat',
+            external_engine_for_llm=True)
         for stream_output in pipe.stream_generate(self.messages_zh_with_system,
                                                    **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -366,7 +386,9 @@ class LLMPipelineTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_qwen1_5_stream_gemerate(self):
         pipe = pipeline(
-            task='chat', model='qwen/Qwen1.5-1.8B-Chat', llm_first=True)
+            task='chat',
+            model='qwen/Qwen1.5-1.8B-Chat',
+            external_engine_for_llm=True)
         for stream_output in pipe.stream_generate(self.messages_zh_with_system,
                                                    **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -377,7 +399,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan2-13B-Chat',
             llm_framework='swift',
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -387,7 +409,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan2-13B-Chat',
             llm_framework='swift',
-            llm_first=True)
+            external_engine_for_llm=True)
         for stream_output in pipe.stream_generate(self.messages_zh,
                                                    **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -398,7 +420,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='01ai/Yi-1.5-6B-Chat',
             llm_framework='swift',
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -408,7 +430,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='01ai/Yi-1.5-6B-Chat',
             llm_framework='swift',
-            llm_first=True)
+            external_engine_for_llm=True)
         for stream_output in pipe.stream_generate(self.messages_zh,
                                                    **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -419,7 +441,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='Shanghai_AI_Laboratory/internlm2-1_8b',
             llm_framework='swift',
-            llm_first=True)
+            external_engine_for_llm=True)
         print('messages: ', pipe(self.messages_zh_one_round, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -429,7 +451,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='Shanghai_AI_Laboratory/internlm2-1_8b',
             llm_framework='swift',
-            llm_first=True)
+            external_engine_for_llm=True)
         for stream_output in pipe.stream_generate(self.messages_zh_one_round,
                                                    **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -19,7 +19,7 @@ class MplugOwlMultimodalDialogueTest(unittest.TestCase):
         pipeline_multimodal_dialogue = pipeline(
             task=Tasks.multimodal_dialogue,
             model=model,
-        )
+            external_engine_for_llm=False)
         image = 'data/resource/portrait_input.png'
         system_prompt_1 = 'The following is a conversation between a curious human and AI assistant.'
         system_prompt_2 = "The assistant gives helpful, detailed, and polite answers to the user's questions."
@@ -41,14 +41,16 @@ class MplugOwlMultimodalDialogueTest(unittest.TestCase):
                 },
             ]
         }
-        result = pipeline_multimodal_dialogue(messages, max_length=5)
+        result = pipeline_multimodal_dialogue(
+            messages, max_length=5, external_engine_for_llm=False)
         print(result[OutputKeys.TEXT])

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_run_with_multimodal_dialogue_with_name(self):
         pipeline_multimodal_dialogue = pipeline(
             Tasks.multimodal_dialogue,
-            model='damo/multi-modal_mplug_owl_multimodal-dialogue_7b')
+            model='damo/multi-modal_mplug_owl_multimodal-dialogue_7b',
+            external_engine_for_llm=False)
         image = 'data/resource/portrait_input.png'
         system_prompt_1 = 'The following is a conversation between a curious human and AI assistant.'
         system_prompt_2 = "The assistant gives helpful, detailed, and polite answers to the user's questions."
@@ -77,7 +79,8 @@ class MplugOwlMultimodalDialogueTest(unittest.TestCase):
     def test_run_with_multimodal_dialogue_with_text(self):
         pipeline_multimodal_dialogue = pipeline(
             Tasks.multimodal_dialogue,
-            model='damo/multi-modal_mplug_owl_multimodal-dialogue_7b')
+            model='damo/multi-modal_mplug_owl_multimodal-dialogue_7b',
+            external_engine_for_llm=False)
         system_prompt_1 = 'The following is a conversation between a curious human and AI assistant.'
         system_prompt_2 = "The assistant gives helpful, detailed, and polite answers to the user's questions."
         messages = {
@@ -59,7 +59,7 @@ class TextGenerationTest(unittest.TestCase):
             task=Tasks.text_generation,
             model=model,
             preprocessor=preprocessor,
-            llm_first=False)
+            external_engine_for_llm=False)
         print(pipeline_ins(input))

     def run_pipeline_with_model_id(self,
@@ -67,7 +67,7 @@ class TextGenerationTest(unittest.TestCase):
                                    input,
                                    init_kwargs={},
                                    run_kwargs={}):
-        init_kwargs['llm_first'] = False
+        init_kwargs['external_engine_for_llm'] = False
         pipeline_ins = pipeline(
             task=Tasks.text_generation, model=model_id, **init_kwargs)
         print(pipeline_ins(input, **run_kwargs))
@@ -77,14 +77,14 @@ class TextGenerationTest(unittest.TestCase):
                                    input,
                                    init_kwargs={},
                                    run_kwargs={}):
-        init_kwargs['llm_first'] = False
+        init_kwargs['external_engine_for_llm'] = False
         pipeline_ins = pipeline(
             task=Tasks.text_generation, model=model_id, **init_kwargs)

         # set stream inputs
         assert isinstance(pipeline_ins, StreamingOutputMixin)
         for output in pipeline_ins.stream_generate(
-                input, **run_kwargs, llm_first=False):
+                input, **run_kwargs, external_engine_for_llm=False):
             print(output, end='\r')
         print()

@@ -265,7 +265,7 @@ class TextGenerationTest(unittest.TestCase):
             Tasks.text_generation,
             model=model,
             preprocessor=preprocessor,
-            llm_first=False)
+            external_engine_for_llm=False)
         print(
             f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
         )
@@ -284,14 +284,15 @@ class TextGenerationTest(unittest.TestCase):
             Tasks.text_generation,
             model=model,
             preprocessor=preprocessor,
-            llm_first=False)
+            external_engine_for_llm=False)
         print(
             f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
         )

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(task=Tasks.text_generation, llm_first=False)
+        pipeline_ins = pipeline(
+            task=Tasks.text_generation, external_engine_for_llm=False)
         print(
             pipeline_ins(
                 [self.palm_input_zh, self.palm_input_zh, self.palm_input_zh],
@@ -302,7 +303,7 @@ class TextGenerationTest(unittest.TestCase):
         pipe = pipeline(
             task=Tasks.text_generation,
             model='langboat/bloom-1b4-zh',
-            llm_first=False)
+            external_engine_for_llm=False)
         print(pipe('中国的首都是'))

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@@ -310,7 +311,7 @@ class TextGenerationTest(unittest.TestCase):
         pipe = pipeline(
             task=Tasks.text_generation,
             model='langboat/mengzi-gpt-neo-base',
-            llm_first=False)
+            external_engine_for_llm=False)
         print(
             pipe(
                 '我是',
@@ -325,7 +326,7 @@ class TextGenerationTest(unittest.TestCase):
         pipe = pipeline(
             task=Tasks.text_generation,
             model='damo/nlp_gpt2_text-generation_english-base',
-            llm_first=False)
+            external_engine_for_llm=False)
         print(pipe('My name is Teven and I am'))

     @unittest.skip('oom error for 7b model')