modelscope/tests/pipelines/test_llama2_text_generation_pipeline.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

import torch

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class Llama2TextGenerationPipelineTest(unittest.TestCase):

    def setUp(self) -> None:
        self.llama2_model_id_7B_chat_ms = 'modelscope/Llama-2-7b-chat-ms'
        self.llama2_input_chat_ch = 'What companies are there?'
        self.history_demo = [(
            'Where is the capital of Zhejiang?',
            'Thank you for asking! The capital of Zhejiang Province is Hangzhou.'
        )]

    def run_pipeline_with_model_id(self,
                                   model_id,
                                   input,
                                   init_kwargs=None,
                                   run_kwargs=None):
        # Use None defaults: this helper mutates init_kwargs below, so a
        # shared mutable default dict would leak state between calls.
        init_kwargs = {} if init_kwargs is None else init_kwargs
        run_kwargs = {} if run_kwargs is None else run_kwargs
        # Run the model through the native pipeline rather than delegating
        # to an external LLM engine.
        init_kwargs['external_engine_for_llm'] = False
        pipeline_ins = pipeline(task=Tasks.chat, model=model_id, **init_kwargs)
        pipeline_ins._model_prepare = True
        result = pipeline_ins(input, **run_kwargs)
        print(result['response'])

    # 7B_ms_chat
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_llama2_7B_chat_ms_with_model_name_with_chat_ch_with_args(self):
        self.run_pipeline_with_model_id(
            self.llama2_model_id_7B_chat_ms,
            self.llama2_input_chat_ch,
            init_kwargs={
                'external_engine_for_llm': False,
                'device_map': 'auto',
                'torch_dtype': torch.float16,
                'model_revision': 'v1.0.5',
                'ignore_file_pattern': [r'.+\.bin$']
            },
            run_kwargs={
                'max_length': 512,
                'do_sample': True,
                'top_p': 0.9,
                'history': self.history_demo
            })
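
    # Illustrative sketch, not part of the original suite: a greedy-decoding
    # variant of the test above with no chat history. It reuses only arguments
    # the existing test already passes through (do_sample, max_length); the
    # method name and the chosen test level are assumptions.
    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_llama2_7B_chat_ms_greedy_without_history(self):
        self.run_pipeline_with_model_id(
            self.llama2_model_id_7B_chat_ms,
            self.llama2_input_chat_ch,
            init_kwargs={
                'device_map': 'auto',
                'torch_dtype': torch.float16,
                'model_revision': 'v1.0.5',
                'ignore_file_pattern': [r'.+\.bin$']
            },
            run_kwargs={
                'max_length': 512,
                'do_sample': False
            })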


if __name__ == '__main__':
    unittest.main()