2023-07-22 21:53:04 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from modelscope.pipelines import pipeline
|
|
|
|
|
from modelscope.utils.constant import Tasks
|
|
|
|
|
from modelscope.utils.test_utils import test_level
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Llama2TextGenerationPipelineTest(unittest.TestCase):
    """Tests for the Llama-2 7B chat model run through the ``Tasks.chat`` pipeline."""

    def setUp(self) -> None:
        # Model id of the Llama-2 7B chat checkpoint under test.
        self.llama2_model_id_7B_chat_ms = 'modelscope/Llama-2-7b-chat-ms'
        # Follow-up query sent to the pipeline on top of ``history_demo``.
        self.llama2_input_chat_ch = 'What are the company there?'
        # Prior conversation as a list of (query, response) pairs; passed to
        # the pipeline call through the ``history`` run kwarg.
        self.history_demo = [(
            'Where is the capital of Zhejiang?',
            'Thank you for asking! The capital of Zhejiang Province is Hangzhou.'
        )]

    def run_pipeline_with_model_id(self,
                                   model_id,
                                   input,
                                   init_kwargs=None,
                                   run_kwargs=None):
        """Build a ``Tasks.chat`` pipeline for ``model_id``, run it once and print the response.

        Args:
            model_id: Model id passed as ``model=`` to ``pipeline``.
            input: Query forwarded positionally to the pipeline call.
            init_kwargs: Optional extra keyword arguments for ``pipeline(...)``.
                The caller's mapping is never modified; a copy is used in which
                ``external_engine_for_llm`` is always forced to ``False``.
            run_kwargs: Optional extra keyword arguments for the pipeline call
                (generation settings, ``history``, ...).

        Returns:
            The pipeline result; its ``'response'`` entry is printed.
        """
        # Copy instead of mutating: the old ``{}`` defaults were shared across
        # calls, and callers' dicts were silently modified.
        init_kwargs = dict(init_kwargs or {})
        run_kwargs = run_kwargs or {}
        init_kwargs['external_engine_for_llm'] = False
        pipeline_ins = pipeline(task=Tasks.chat, model=model_id, **init_kwargs)
        # NOTE(review): sets a private pipeline attribute; presumably this
        # marks the model as already prepared so the pipeline skips its own
        # preparation step — confirm against the Pipeline implementation.
        pipeline_ins._model_prepare = True
        result = pipeline_ins(input, **run_kwargs)
        print(result['response'])
        return result

    # 7B_ms_chat
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_llama2_7B_chat_ms_with_model_name_with_chat_ch_with_args(self):
        """Chat with the 7B model in fp16, with sampling and a one-turn history."""
        self.run_pipeline_with_model_id(
            self.llama2_model_id_7B_chat_ms,
            self.llama2_input_chat_ch,
            init_kwargs={
                'external_engine_for_llm': False,
                'device_map': 'auto',
                'torch_dtype': torch.float16,
                'model_revision': 'v1.0.5',
                # Presumably excludes files matching this regex (``*.bin``
                # checkpoints) when fetching the model — verify in modelscope.
                'ignore_file_pattern': [r'.+\.bin$']
            },
            run_kwargs={
                'max_length': 512,
                'do_sample': True,
                'top_p': 0.9,
                'history': self.history_demo
            })
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Running this file directly as a script discovers and runs the tests above.
if '__main__' == __name__:
    unittest.main()
|