Format llm pipeline (#1094)

* format llm pipeline

Co-authored-by: Yingda Chen <yingda.chen@alibaba-inc.com>
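In short, the `llm_first` switch is renamed to `external_engine_for_llm`, and chat / text-generation tasks now prefer the LLM pipeline with `llm_framework` defaulting to swift. A minimal usage sketch of the renamed argument (model ids are taken from the tests below; behaviour as described by this diff, not re-verified here):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Chat and text-generation tasks now prefer the LLM pipeline and fall back to
# llm_framework='swift' when no framework is given.
chat_pipe = pipeline(task=Tasks.chat, model='qwen/Qwen-7B-Chat')

# Passing external_engine_for_llm=False keeps the task's regular in-library
# pipeline, exactly as the updated tests do for the older generation models.
tg_pipe = pipeline(
    task=Tasks.text_generation,
    model='langboat/bloom-1b4-zh',
    external_engine_for_llm=False)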
@@ -205,6 +205,13 @@ class Pipeline(ABC):
         kwargs['preprocess_params'] = preprocess_params
         kwargs['forward_params'] = forward_params
         kwargs['postprocess_params'] = postprocess_params
+
+        # for LLMPipeline, we shall support treating list of roles as a
+        # one single 'messages' input
+        if 'LLMPipeline' in type(self).__name__ and isinstance(input, list):
+            input = {'messages': input}
+            kwargs['is_message'] = True
+
         if isinstance(input, list):
             if batch_size is None:
                 output = []
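The hunk above lets `Pipeline.__call__` treat a bare list of chat messages as the single `messages` input that `LLMPipeline` expects, instead of interpreting the list as a batch. A hedged sketch of what that enables; only the 'content' key is visible in this diff, the 'role' key follows the usual messages convention and is assumed here:

from modelscope.pipelines import pipeline

pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat')

messages = [
    {'role': 'user', 'content': '你好'},  # assumed role/content schema
    {'role': 'assistant', 'content': '你好，有什么可以帮您？'},
    {'role': 'user', 'content': '讲个笑话'},
]

# Previously this had to be wrapped as {'messages': [...]}; the new branch
# wraps the list automatically and sets is_message=True for the LLM pipeline.
print(pipe(messages))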
@@ -7,7 +7,7 @@ from modelscope.hub.snapshot_download import snapshot_download
 from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
 from modelscope.models.base import Model
 from modelscope.utils.config import ConfigDict, check_config
-from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke,
+from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
                                        ThirdParty)
 from modelscope.utils.hub import read_config
 from modelscope.utils.plugins import (register_modelhub_repo,
@@ -108,12 +108,23 @@ def pipeline(task: str = None,
     """
     if task is None and pipeline_name is None:
         raise ValueError('task or pipeline_name is required')
+    prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
+    if task is not None and task.lower() in [
+            Tasks.text_generation, Tasks.chat
+    ]:
+        # if not specified, prefer llm pipeline for aforementioned tasks
+        if prefer_llm_pipeline is None:
+            prefer_llm_pipeline = True
+        # for llm pipeline, if llm_framework is not specified, default to swift instead
+        # TODO: port the swift infer based on transformer into ModelScope
+        if prefer_llm_pipeline and kwargs.get('llm_framework') is None:
+            kwargs['llm_framework'] = 'swift'
     third_party = kwargs.get(ThirdParty.KEY)
     if third_party is not None:
         kwargs.pop(ThirdParty.KEY)
-    if pipeline_name is None and kwargs.get('llm_first'):
-        pipeline_name = llm_first_checker(model, model_revision, kwargs)
+    if pipeline_name is None and prefer_llm_pipeline:
+        pipeline_name = external_engine_for_llm_checker(
+            model, model_revision, kwargs)
     else:
         model = normalize_model_input(
             model,
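This is the hunk where the new default behaviour lives: for `Tasks.text_generation` and `Tasks.chat`, `prefer_llm_pipeline` is switched on unless the caller says otherwise, and `llm_framework` falls back to 'swift'. A sketch of the resulting call patterns, read off the diff rather than verified end to end:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# 1. Default: chat resolves to the LLM pipeline with llm_framework='swift'.
p1 = pipeline(task=Tasks.chat, model='ZhipuAI/chatglm2-6b')

# 2. Pin the framework explicitly; the fallback is skipped because
#    llm_framework is already set (the swift tests below pass it this way).
p2 = pipeline(
    task=Tasks.chat,
    model='baichuan-inc/Baichuan2-13B-Chat',
    llm_framework='swift')

# 3. Opt out: external_engine_for_llm=False disables the LLM-first path, so the
#    task's regular pipeline is selected and no swift default is applied.
p3 = pipeline(
    task=Tasks.text_generation,
    model='damo/nlp_gpt2_text-generation_english-base',
    external_engine_for_llm=False)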
@@ -133,9 +144,9 @@ def pipeline(task: str = None,
                 model[0], revision=model_revision)
         register_plugins_repo(cfg.safe_get('plugins'))
         register_modelhub_repo(model, cfg.get('allow_remote', False))
-        pipeline_name = llm_first_checker(
+        pipeline_name = external_engine_for_llm_checker(
             model, model_revision,
-            kwargs) if kwargs.get('llm_first') else None
+            kwargs) if prefer_llm_pipeline else None
         if pipeline_name is not None:
             pipeline_props = {'type': pipeline_name}
         else:
@@ -208,8 +219,9 @@ def get_default_pipeline_info(task):
     return pipeline_name, default_model


-def llm_first_checker(model: Union[str, List[str], Model,
-                                   List[Model]], revision: Optional[str],
-                      kwargs: Dict[str, Any]) -> Optional[str]:
+def external_engine_for_llm_checker(model: Union[str, List[str], Model,
+                                                 List[Model]],
+                                    revision: Optional[str],
+                                    kwargs: Dict[str, Any]) -> Optional[str]:
     from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry

@@ -229,5 +241,5 @@ def llm_first_checker(model: Union[str, List[str], Model,
 def clear_llm_info(kwargs: Dict):
     from modelscope.utils.model_type_helper import ModelTypeHelper

-    kwargs.pop('llm_first', None)
+    kwargs.pop('external_engine_for_llm', None)
     ModelTypeHelper.clear_cache()
@@ -30,6 +30,7 @@ from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
 logger = get_logger()

 SWIFT_MODEL_ID_MAPPING = {}
+SWIFT_FRAMEWORK = 'swift'


 class LLMAdapterRegistry:
@@ -90,7 +91,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
         if self._is_swift_model(model):
             if self.llm_framework is not None:
                 logger.warning(
-                    f'Cannot using swift with llm_framework, ignoring {self.llm_framework}.'
+                    f'Cannot use swift with llm_framework, ignoring {self.llm_framework}.'
                 )

             base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
@@ -155,7 +156,8 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
             return False

         self.cfg = Config.from_file(cfg_file)
-        return self.cfg.safe_get('adapter_cfg.tuner_backend') == 'swift'
+        return self.cfg.safe_get(
+            'adapter_cfg.tuner_backend') == SWIFT_FRAMEWORK

     def _wrap_infer_framework(self, model_dir, framework='vllm'):
         from modelscope.pipelines.accelerate.base import InferFramework
@@ -184,7 +186,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
         self.torch_dtype = kwargs.pop('torch_dtype', None)
         self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)

-        if llm_framework == 'swift':
+        if llm_framework == SWIFT_FRAMEWORK:
             self._init_swift(kwargs['model'], kwargs.get('device', 'gpu'))
             return
         with self._temp_configuration_file(kwargs):
@@ -254,9 +256,13 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
         contents = [message['content'] for message in messages]
         prompt = contents[-1]
         history = list(zip(contents[::2], contents[1::2]))
-        return dict(system=system, prompt=prompt, history=history)
+        if self.llm_framework == SWIFT_FRAMEWORK:
+            return dict(system=system, query=prompt, history=history)
+        else:
+            return dict(system=system, prompt=prompt, history=history)

-        assert model_id in SWIFT_MODEL_ID_MAPPING, 'Swift framework does not support current model!'
+        assert model_id in SWIFT_MODEL_ID_MAPPING,\
+            f'Invalid model id {model_id} or Swift framework does not support this model.'
         args = InferArguments(model_type=SWIFT_MODEL_ID_MAPPING[model_id])
         model, template = prepare_model_template(
             args, device_map=self.device_map)
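The preprocessing hunk above adapts the flattened chat messages to swift's argument names: the latest user turn is passed as `query` for swift and as `prompt` otherwise. A standalone restatement of that branch, for illustration only (the real method lives on `LLMPipeline` and receives `system` from the surrounding code):

SWIFT_FRAMEWORK = 'swift'


def flatten_messages(messages, system, llm_framework=None):
    # Mirrors the diffed logic: the last content is the query/prompt, earlier
    # contents are paired up into (user, assistant) history tuples.
    contents = [message['content'] for message in messages]
    prompt = contents[-1]
    history = list(zip(contents[::2], contents[1::2]))
    if llm_framework == SWIFT_FRAMEWORK:
        return dict(system=system, query=prompt, history=history)
    return dict(system=system, prompt=prompt, history=history)


msgs = [{'role': 'user', 'content': '你好'},
        {'role': 'assistant', 'content': '你好！'},
        {'role': 'user', 'content': '讲个笑话'}]
print(flatten_messages(msgs, system='', llm_framework='swift'))
# -> {'system': '', 'query': '讲个笑话', 'history': [('你好', '你好！')]}
print(flatten_messages(msgs, system=''))
# -> {'system': '', 'prompt': '讲个笑话', 'history': [('你好', '你好！')]}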
@@ -297,7 +303,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
             = isinstance(inputs, dict) and 'messages' in inputs
         tokens = self.preprocess(inputs, **preprocess_params)

-        if self.llm_framework in (None, 'swift'):
+        if self.llm_framework in (None, SWIFT_FRAMEWORK):
             # pytorch model
             if hasattr(self.model, 'generate'):
                 outputs = self.model.generate(**tokens, **forward_params)
@@ -310,7 +316,7 @@ class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
             tokens = [list(tokens['inputs'].flatten().numpy())]
             outputs = self.model(tokens, **forward_params)[0]

-        if self.llm_framework is None:
+        if self.llm_framework in (None, SWIFT_FRAMEWORK):
             # pytorch model
             outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
         response = self.postprocess(outputs, **postprocess_params)
@@ -10,7 +10,7 @@ def add_server_args(parser: argparse.ArgumentParser):
     parser.add_argument('--port', type=int, default=8000, help='Server port')
     parser.add_argument('--debug', default='debug', help='Set debug level.')
     parser.add_argument(
-        '--llm_first',
+        '--external_engine_for_llm',
         type=bool,
         default=True,
         help='Use LLMPipeline first for llm models.')
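On the server CLI only the flag name changes. A quick self-contained check of the renamed argument and its default; `add_server_args` is redefined here with just the flags visible in the hunk above, because the module it lives in is not named in this diff:

import argparse


def add_server_args(parser: argparse.ArgumentParser):
    # Reduced copy of the diffed helper: only the flags shown above.
    parser.add_argument('--port', type=int, default=8000, help='Server port')
    parser.add_argument('--debug', default='debug', help='Set debug level.')
    parser.add_argument(
        '--external_engine_for_llm',
        type=bool,
        default=True,
        help='Use LLMPipeline first for llm models.')


parser = argparse.ArgumentParser()
add_server_args(parser)
args = parser.parse_args([])
print(args.external_engine_for_llm)  # True: the default is unchanged by the rename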
@@ -14,9 +14,9 @@ logger = get_logger()

 def _startup_model(app: FastAPI) -> None:
     logger.info('download model and create pipeline')
-    app.state.pipeline = create_pipeline(app.state.args.model_id,
-                                         app.state.args.revision,
-                                         app.state.args.llm_first)
+    app.state.pipeline = create_pipeline(
+        app.state.args.model_id, app.state.args.revision,
+        app.state.args.external_engine_for_llm)
     info = {}
     info['task_name'] = app.state.pipeline.group_key
     info['schema'] = get_task_schemas(app.state.pipeline.group_key)
@@ -43,7 +43,7 @@ class DeployChecker:
             task=task,
             model=model_id,
             model_revision=model_revision,
-            llm_first=True)
+            external_engine_for_llm=True)
         pipeline_info = get_pipeline_information_by_pipeline(ppl)

         # call pipeline
@@ -67,7 +67,9 @@ Todo:
     """


-def create_pipeline(model_id: str, revision: str, llm_first: bool = True):
+def create_pipeline(model_id: str,
+                    revision: str,
+                    external_engine_for_llm: bool = True):
     model_configuration_file = model_file_download(
         model_id=model_id,
         file_path=ModelFile.CONFIGURATION,
@@ -77,7 +79,7 @@ def create_pipeline(model_id: str, revision: str, llm_first: bool = True):
         task=cfg.task,
         model=model_id,
         model_revision=revision,
-        llm_first=llm_first)
+        external_engine_for_llm=external_engine_for_llm)


 def get_class_user_attributes(cls):
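`create_pipeline` in the server helpers gains the same keyword and simply forwards it to `pipeline()` together with the task read from the model configuration. The equivalent direct call, sketched with an example model id from the tests (not an exact reproduction of the server code path):

from modelscope.pipelines import pipeline

ppl = pipeline(
    task='chat',                   # create_pipeline reads cfg.task from the downloaded configuration
    model='qwen/Qwen-7B-Chat',     # example hub id taken from the tests
    external_engine_for_llm=True)  # forwarded unchanged by create_pipeline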
@@ -39,7 +39,7 @@ class ModelJsonTest:
             task=task,
             model=model_id,
             model_revision=model_revision,
-            llm_first=True)
+            external_engine_for_llm=True)
         pipeline_info = get_pipeline_information_by_pipeline(ppl)

         # call pipeline
@@ -17,26 +17,38 @@ class TextGPT3GenerationTest(unittest.TestCase):
         self.model_dir_13B = snapshot_download(self.model_id_13B)
         self.input = '好的'

-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skip('deprecated, skipped')
     def test_gpt3_1_3B(self):
-        pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
+        pipe = pipeline(
+            Tasks.text_generation,
+            model=self.model_id_1_3B,
+            external_engine_for_llm=False)
         print(pipe(self.input))

-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skip('deprecated, skipped')
     def test_gpt3_1_3B_with_streaming(self):
-        pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
+        pipe = pipeline(
+            Tasks.text_generation,
+            model=self.model_id_1_3B,
+            external_engine_for_llm=False)
         for output in pipe.stream_generate(self.input, max_length=64):
             print(output, end='\r')
         print()

-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skip('deprecated, skipped')
    def test_gpt3_2_7B(self):
-        pipe = pipeline(Tasks.text_generation, model=self.model_id_2_7B)
+        pipe = pipeline(
+            Tasks.text_generation,
+            model=self.model_id_2_7B,
+            external_engine_for_llm=False)
         print(pipe(self.input))

-    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    @unittest.skip('deprecated, skipped')
     def test_gpt3_1_3B_with_args(self):
-        pipe = pipeline(Tasks.text_generation, model=self.model_id_1_3B)
+        pipe = pipeline(
+            Tasks.text_generation,
+            model=self.model_id_1_3B,
+            external_engine_for_llm=False)
         print(pipe(self.input, top_p=0.9, temperature=0.9, max_length=32))

     @unittest.skip('distributed gpt3 13B, skipped')
@@ -62,7 +74,10 @@ class TextGPT3GenerationTest(unittest.TestCase):
            |_ mp_rank_06_model_states.pt
            |_ mp_rank_07_model_states.pt
         """
-        pipe = pipeline(Tasks.text_generation, model=self.model_dir_13B)
+        pipe = pipeline(
+            Tasks.text_generation,
+            model=self.model_dir_13B,
+            external_engine_for_llm=False)
         print(pipe(self.input))

@@ -24,6 +24,7 @@ class Llama2TextGenerationPipelineTest(unittest.TestCase):
                                   input,
                                   init_kwargs={},
                                   run_kwargs={}):
+        init_kwargs['external_engine_for_llm'] = False
         pipeline_ins = pipeline(task=Tasks.chat, model=model_id, **init_kwargs)
         pipeline_ins._model_prepare = True
         result = pipeline_ins(input, **run_kwargs)
@@ -36,6 +37,7 @@ class Llama2TextGenerationPipelineTest(unittest.TestCase):
             self.llama2_model_id_7B_chat_ms,
             self.llama2_input_chat_ch,
             init_kwargs={
+                'external_engine_for_llm': False,
                 'device_map': 'auto',
                 'torch_dtype': torch.float16,
                 'model_revision': 'v1.0.5',
@@ -141,29 +141,25 @@ class LLMPipelineTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm2(self):
-        pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm2-6b', llm_first=True)
+        pipe = pipeline(task='chat', model='ZhipuAI/chatglm2-6b')
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm2int4(self):
-        pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm2-6b-int4', llm_first=True)
+        pipe = pipeline(task='chat', model='ZhipuAI/chatglm2-6b-int4')
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm232k(self):
-        pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm2-6b-32k', llm_first=True)
+        pipe = pipeline(task='chat', model='ZhipuAI/chatglm2-6b-32k')
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_chatglm3(self):
-        pipe = pipeline(
-            task='chat', model='ZhipuAI/chatglm3-6b', llm_first=True)
+        pipe = pipeline(task='chat', model='ZhipuAI/chatglm3-6b')
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -174,8 +170,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='modelscope/Llama-2-7b-ms',
             torch_dtype=torch.float16,
             device_map='auto',
-            ignore_file_pattern=[r'.+\.bin$'],
-            llm_first=True)
+            ignore_file_pattern=[r'.+\.bin$'])
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))

@@ -187,8 +182,7 @@ class LLMPipelineTest(unittest.TestCase):
             revision='v1.0.2',
             torch_dtype=torch.float16,
             device_map='auto',
-            ignore_file_pattern=[r'.+\.bin$'],
-            llm_first=True)
+            ignore_file_pattern=[r'.+\.bin$'])
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))

@@ -199,8 +193,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='AI-ModelScope/CodeLlama-7b-Instruct-hf',
             torch_dtype=torch.float16,
             device_map='auto',
-            ignore_file_pattern=[r'.+\.bin$'],
-            llm_first=True)
+            ignore_file_pattern=[r'.+\.bin$'])
         print('messages: ', pipe(self.messages_code, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_code, **self.gen_cfg))

@@ -210,8 +203,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/baichuan-7B',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -221,8 +213,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan-13B-Base',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -232,8 +223,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan-13B-Chat',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -243,8 +233,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan2-7B-Base',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -254,8 +243,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan2-7B-Chat',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -265,8 +253,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan2-7B-Chat-4bits',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -276,8 +263,7 @@ class LLMPipelineTest(unittest.TestCase):
             task='chat',
             model='baichuan-inc/Baichuan2-13B-Chat-4bits',
             device_map='auto',
-            torch_dtype=torch.float16,
-            llm_first=True)
+            torch_dtype=torch.float16)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -288,8 +274,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='AI-ModelScope/WizardLM-13B-V1.2',
             device_map='auto',
             torch_dtype=torch.float16,
-            format_messages='wizardlm',
-            llm_first=True)
+            format_messages='wizardlm')
         print('messages: ', pipe(self.messages_en, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_en, **self.gen_cfg))

@@ -300,8 +285,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='AI-ModelScope/WizardMath-7B-V1.0',
             device_map='auto',
             torch_dtype=torch.float16,
-            format_messages='wizardcode',
-            llm_first=True)
+            format_messages='wizardcode')
         print('messages: ', pipe(self.message_wizard_math, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_wizard_math, **self.gen_cfg))

@@ -312,8 +296,7 @@ class LLMPipelineTest(unittest.TestCase):
             model='AI-ModelScope/WizardCoder-Python-13B-V1.0',
             device_map='auto',
             torch_dtype=torch.float16,
-            format_messages='wizardcode',
-            llm_first=True)
+            format_messages='wizardcode')
         print('messages: ', pipe(self.message_wizard_code, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_wizard_code, **self.gen_cfg))

@@ -329,20 +312,19 @@ class LLMPipelineTest(unittest.TestCase):

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen(self):
-        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
+        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat')
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skip('Need optimum and auto-gptq')
     def test_qwen_int4(self):
-        pipe = pipeline(
-            task='chat', model='qwen/Qwen-7B-Chat-Int4', llm_first=True)
+        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat-Int4')
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen_vl(self):
-        pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat', llm_first=True)
+        pipe = pipeline(task='chat', model='qwen/Qwen-VL-Chat')
         print('messages: ', pipe(self.messages_mm, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -352,21 +334,20 @@ class LLMPipelineTest(unittest.TestCase):
         model_type = ModelTypeHelper.get(model_id)
         assert not LLMAdapterRegistry.contains(model_type)

-        pipe = pipeline(task='chat', model=model_id, llm_first=True)
+        pipe = pipeline(task='chat', model=model_id)
         print('messages: ', pipe(self.messages_zh, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_qwen_stream_gemerate(self):
-        pipe = pipeline(task='chat', model='qwen/Qwen-7B-Chat', llm_first=True)
+        pipe = pipeline(task='chat', model='Qwen/Qwen-7B-Chat')
         for stream_output in pipe.stream_generate(self.messages_zh_with_system,
                                                   **self.gen_cfg):
             print('messages: ', stream_output, end='\r')

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
-    def test_qwen1_5_stream_gemerate(self):
-        pipe = pipeline(
-            task='chat', model='qwen/Qwen1.5-1.8B-Chat', llm_first=True)
+    def test_qwen1_5_stream_generate(self):
+        pipe = pipeline(task='chat', model='Qwen/Qwen1.5-1.8B-Chat')
         for stream_output in pipe.stream_generate(self.messages_zh_with_system,
                                                   **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -376,8 +357,7 @@ class LLMPipelineTest(unittest.TestCase):
         pipe = pipeline(
             task='chat',
             model='baichuan-inc/Baichuan2-13B-Chat',
-            llm_framework='swift',
-            llm_first=True)
+            llm_framework='swift')
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -386,8 +366,7 @@ class LLMPipelineTest(unittest.TestCase):
         pipe = pipeline(
             task='chat',
             model='baichuan-inc/Baichuan2-13B-Chat',
-            llm_framework='swift',
-            llm_first=True)
+            llm_framework='swift')
         for stream_output in pipe.stream_generate(self.messages_zh,
                                                   **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -395,20 +374,14 @@ class LLMPipelineTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_yi_with_swift(self):
         pipe = pipeline(
-            task='chat',
-            model='01ai/Yi-1.5-6B-Chat',
-            llm_framework='swift',
-            llm_first=True)
+            task='chat', model='01ai/Yi-1.5-6B-Chat', llm_framework='swift')
         print('messages: ', pipe(self.messages_zh_with_system, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

     @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
     def test_yi_stream_gemerate(self):
         pipe = pipeline(
-            task='chat',
-            model='01ai/Yi-1.5-6B-Chat',
-            llm_framework='swift',
-            llm_first=True)
+            task='chat', model='01ai/Yi-1.5-6B-Chat', llm_framework='swift')
         for stream_output in pipe.stream_generate(self.messages_zh,
                                                   **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -418,8 +391,7 @@ class LLMPipelineTest(unittest.TestCase):
         pipe = pipeline(
             task='chat',
             model='Shanghai_AI_Laboratory/internlm2-1_8b',
-            llm_framework='swift',
-            llm_first=True)
+            llm_framework='swift')
         print('messages: ', pipe(self.messages_zh_one_round, **self.gen_cfg))
         print('prompt: ', pipe(self.prompt_zh, **self.gen_cfg))

@@ -428,8 +400,7 @@ class LLMPipelineTest(unittest.TestCase):
         pipe = pipeline(
             task='chat',
             model='Shanghai_AI_Laboratory/internlm2-1_8b',
-            llm_framework='swift',
-            llm_first=True)
+            llm_framework='swift')
         for stream_output in pipe.stream_generate(self.messages_zh_one_round,
                                                   **self.gen_cfg):
             print('messages: ', stream_output, end='\r')
@@ -17,9 +17,7 @@ class MplugOwlMultimodalDialogueTest(unittest.TestCase):
         model = Model.from_pretrained(
             'damo/multi-modal_mplug_owl_multimodal-dialogue_7b')
         pipeline_multimodal_dialogue = pipeline(
-            task=Tasks.multimodal_dialogue,
-            model=model,
-        )
+            task=Tasks.multimodal_dialogue, model=model)
         image = 'data/resource/portrait_input.png'
         system_prompt_1 = 'The following is a conversation between a curious human and AI assistant.'
         system_prompt_2 = "The assistant gives helpful, detailed, and polite answers to the user's questions."
@@ -56,7 +56,10 @@ class TextGenerationTest(unittest.TestCase):
             first_sequence='sentence',
             second_sequence=None)
         pipeline_ins = pipeline(
-            task=Tasks.text_generation, model=model, preprocessor=preprocessor)
+            task=Tasks.text_generation,
+            model=model,
+            preprocessor=preprocessor,
+            external_engine_for_llm=False)
         print(pipeline_ins(input))

     def run_pipeline_with_model_id(self,
@@ -64,6 +67,7 @@ class TextGenerationTest(unittest.TestCase):
                                    input,
                                    init_kwargs={},
                                    run_kwargs={}):
+        init_kwargs['external_engine_for_llm'] = False
         pipeline_ins = pipeline(
             task=Tasks.text_generation, model=model_id, **init_kwargs)
         print(pipeline_ins(input, **run_kwargs))
@@ -73,12 +77,14 @@ class TextGenerationTest(unittest.TestCase):
                                    input,
                                    init_kwargs={},
                                    run_kwargs={}):
+        init_kwargs['external_engine_for_llm'] = False
         pipeline_ins = pipeline(
             task=Tasks.text_generation, model=model_id, **init_kwargs)

         # set stream inputs
         assert isinstance(pipeline_ins, StreamingOutputMixin)
-        for output in pipeline_ins.stream_generate(input, **run_kwargs):
+        for output in pipeline_ins.stream_generate(
+                input, **run_kwargs, external_engine_for_llm=False):
             print(output, end='\r')
         print()

@@ -256,7 +262,10 @@ class TextGenerationTest(unittest.TestCase):
             cache_path, first_sequence='sentence', second_sequence=None)
         pipeline1 = TextGenerationPipeline(model, preprocessor)
         pipeline2 = pipeline(
-            Tasks.text_generation, model=model, preprocessor=preprocessor)
+            Tasks.text_generation,
+            model=model,
+            preprocessor=preprocessor,
+            external_engine_for_llm=False)
         print(
             f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
         )
@@ -272,14 +281,18 @@ class TextGenerationTest(unittest.TestCase):
             second_sequence=None)
         pipeline1 = TextGenerationPipeline(model, preprocessor)
         pipeline2 = pipeline(
-            Tasks.text_generation, model=model, preprocessor=preprocessor)
+            Tasks.text_generation,
+            model=model,
+            preprocessor=preprocessor,
+            external_engine_for_llm=False)
         print(
             f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
         )

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_run_with_default_model(self):
-        pipeline_ins = pipeline(task=Tasks.text_generation)
+        pipeline_ins = pipeline(
+            task=Tasks.text_generation, external_engine_for_llm=False)
         print(
             pipeline_ins(
                 [self.palm_input_zh, self.palm_input_zh, self.palm_input_zh],
@@ -288,13 +301,17 @@ class TextGenerationTest(unittest.TestCase):
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_bloom(self):
         pipe = pipeline(
-            task=Tasks.text_generation, model='langboat/bloom-1b4-zh')
+            task=Tasks.text_generation,
+            model='langboat/bloom-1b4-zh',
+            external_engine_for_llm=False)
         print(pipe('中国的首都是'))

     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     def test_gpt_neo(self):
         pipe = pipeline(
-            task=Tasks.text_generation, model='langboat/mengzi-gpt-neo-base')
+            task=Tasks.text_generation,
+            model='langboat/mengzi-gpt-neo-base',
+            external_engine_for_llm=False)
         print(
             pipe(
                 '我是',
@@ -308,7 +325,8 @@ class TextGenerationTest(unittest.TestCase):
     def test_gpt2(self):
         pipe = pipeline(
             task=Tasks.text_generation,
-            model='damo/nlp_gpt2_text-generation_english-base')
+            model='damo/nlp_gpt2_text-generation_english-base',
+            external_engine_for_llm=False)
         print(pipe('My name is Teven and I am'))

     @unittest.skip('oom error for 7b model')