Files
modelscope/tests/pipelines/test_fid_dialogue.py
xingjun.wang 48c0d2a9af add 1.6
2023-05-22 10:53:18 +08:00

72 lines
3.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class FidDialogueTest(unittest.TestCase):
def setUp(self) -> None:
self.task = Tasks.fid_dialogue
# 240M
self.model_id_240m = 'damo/ChatPLUG-240M'
self.model_revision_240m = 'v1.0.0'
# 3.7B
self.model_id_3_7b = 'damo/ChatPLUG-3.7B'
self.model_revision_3_7b = 'v1.0.0'
# sample
know_list = [
'李白701年—762年字太白号青莲居士又号“谪仙人”。是唐代伟大的浪漫主义诗人被后人誉为“诗仙”。与杜甫并称为“李杜”为了与另两位诗人李商隐与杜牧即“小李杜”区别杜甫与',
'李白701年2月28日762字太白号青莲居士唐朝诗人有“诗仙”之称最伟大的浪漫主义诗人。汉族出生于西域碎叶城今吉尔吉斯斯坦托克马克5岁随父迁至剑南道之绵州巴西郡',
'李白701─762字太白号青莲居士祖籍陇西成纪今甘肃省天水县附近。先世于隋末流徙中亚。李白即生于中亚的碎叶城今吉尔吉斯斯坦境内。五岁时随其父迁居绵州彰明县今四川省江油'
]
self.input = {
'history': '你好[SEP]你好,我是娜娜,很高兴认识你![SEP]李白是谁',
'bot_profile': '我是娜娜;我是女生;我是单身',
'knowledge': '[SEP]'.join(know_list),
'user_profile': '你是小明'
}
preprocess_params = {'max_encoder_length': 300, 'context_turn': 3}
forward_params = {
'min_length': 10,
'max_length': 512,
'num_beams': 1,
'temperature': 0.8,
'do_sample': True,
'early_stopping': True,
'top_k': 50,
'top_p': 0.8,
'repetition_penalty': 1.2,
'length_penalty': 1.2,
'no_repeat_ngram_size': 6
}
self.kwargs = {
'preprocess_params': preprocess_params,
'forward_params': forward_params
}
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_240m_pipeline(self):
pipeline_ins = pipeline(
task=self.task,
model=self.model_id_240m,
model_revision=self.model_revision_240m)
result = pipeline_ins(self.input, **self.kwargs)
print(result)
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_3_7b_pipeline(self):
pipeline_ins = pipeline(
task=self.task,
model=self.model_id_3_7b,
model_revision=self.model_revision_3_7b)
result = pipeline_ins(self.input, **self.kwargs)
print(result)
if __name__ == '__main__':
unittest.main()