Files
modelscope/tests/pipelines/test_fid_dialogue.py

72 lines
3.0 KiB
Python
Raw Permalink Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
2023-05-22 10:53:18 +08:00
class FidDialogueTest(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.fid_dialogue
# 240M
self.model_id_240m = 'damo/ChatPLUG-240M'
self.model_revision_240m = 'v1.0.0'
# 3.7B
self.model_id_3_7b = 'damo/ChatPLUG-3.7B'
self.model_revision_3_7b = 'v1.0.0'
# sample
know_list = [
'李白701年—762年字太白号青莲居士又号“谪仙人”。是唐代伟大的浪漫主义诗人被后人誉为“诗仙”。与杜甫并称为“李杜”为了与另两位诗人李商隐与杜牧即“小李杜”区别杜甫与',
'李白701年2月28日762字太白号青莲居士唐朝诗人有“诗仙”之称最伟大的浪漫主义诗人。汉族出生于西域碎叶城今吉尔吉斯斯坦托克马克5岁随父迁至剑南道之绵州巴西郡',
'李白701─762字太白号青莲居士祖籍陇西成纪今甘肃省天水县附近。先世于隋末流徙中亚。李白即生于中亚的碎叶城今吉尔吉斯斯坦境内。五岁时随其父迁居绵州彰明县今四川省江油'
]
self.input = {
'history': '你好[SEP]你好,我是娜娜,很高兴认识你![SEP]李白是谁',
'bot_profile': '我是娜娜;我是女生;我是单身',
'knowledge': '[SEP]'.join(know_list),
'user_profile': '你是小明'
}
preprocess_params = {'max_encoder_length': 300, 'context_turn': 3}
forward_params = {
'min_length': 10,
'max_length': 512,
'num_beams': 1,
'temperature': 0.8,
'do_sample': True,
'early_stopping': True,
'top_k': 50,
'top_p': 0.8,
'repetition_penalty': 1.2,
'length_penalty': 1.2,
'no_repeat_ngram_size': 6
}
self.kwargs = {
'preprocess_params': preprocess_params,
'forward_params': forward_params
}
2023-05-22 10:53:18 +08:00
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_240m_pipeline(self):
pipeline_ins = pipeline(
task=self.task,
model=self.model_id_240m,
model_revision=self.model_revision_240m)
result = pipeline_ins(self.input, **self.kwargs)
print(result)
2023-05-22 10:53:18 +08:00
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_3_7b_pipeline(self):
pipeline_ins = pipeline(
task=self.task,
model=self.model_id_3_7b,
model_revision=self.model_revision_3_7b)
result = pipeline_ins(self.input, **self.kwargs)
print(result)
if __name__ == '__main__':
unittest.main()