Mirror of https://github.com/modelscope/modelscope.git
Strip the trailing sep from the tokenizer output at generate time; fixes the bug where the GPT-3 model's continuation is unrelated to the preceding text.
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9696570
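A minimal sketch of the idea behind that fix, assuming a BERT-style tokenizer whose sep_token_id marks end-of-sequence (hypothetical helper, not the actual MaaS-lib code): if the preprocessor leaves a trailing [SEP] in the encoded prompt, the model treats the prompt as already finished and its continuation drifts away from the input, so the trailing id is trimmed before generation.

def strip_trailing_sep(input_ids, sep_token_id):
    # Hypothetical helper: drop a trailing sep token from the encoded prompt
    # so generation continues from the prompt text rather than from an
    # end-of-sequence marker. `input_ids` is a plain list of token ids.
    if input_ids and input_ids[-1] == sep_token_id:
        return input_ids[:-1]
    return input_ids
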
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.models.nlp import GPT3ForTextGeneration, PalmForTextGeneration
from modelscope.pipelines import pipeline
from modelscope.pipelines.nlp import TextGenerationPipeline
from modelscope.preprocessors import TextGenerationPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level

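# Pipeline tests for the text generation task, covering PALM 2.0 (Chinese and
# English) and GPT-3 (Chinese base and large) models, run either by model id
# or from an explicitly constructed model instance.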
class TextGenerationTest(unittest.TestCase):

    def setUp(self) -> None:
        self.palm_model_id_zh = 'damo/nlp_palm2.0_text-generation_chinese-base'
        self.palm_model_id_en = 'damo/nlp_palm2.0_text-generation_english-base'
        self.palm_input_zh = """
        本文总结了十个可穿戴产品的设计原则,而这些原则,同样也是笔者认为是这个行业最吸引人的地方:
        1.为人们解决重复性问题;2.从人开始,而不是从机器开始;3.要引起注意,但不要刻意;4.提升用户能力,而不是取代
        """
        self.palm_input_en = """
        The Director of Public Prosecutions who let off Lord Janner over alleged child sex abuse started
        her career at a legal chambers when the disgraced Labour peer was a top QC there . Alison Saunders ,
        54 , sparked outrage last week when she decided the 86-year-old should not face astring of charges
        of paedophilia against nine children because he has dementia . Today , newly-released documents
        revealed damning evidence that abuse was covered up by police andsocial workers for more than 20 years .
        And now it has emerged Mrs Saunders ' law career got off to a flying start when she secured her
        pupillage -- a barrister 's training contract at 1 Garden Court Chambers in London in 1983 .
        """

        self.gpt3_base_model_id = 'damo/nlp_gpt3_text-generation_chinese-base'
        self.gpt3_large_model_id = 'damo/nlp_gpt3_text-generation_chinese-large'
        self.gpt3_input = '《故乡》。深蓝的天空中挂着一轮金黄的圆月,下面是海边的沙地,'

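    # Build the preprocessor and pipeline from an already instantiated Model.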
    def run_pipeline_with_model_instance(self, model_id, input):
        model = Model.from_pretrained(model_id)
        preprocessor = TextGenerationPreprocessor(
            model.model_dir,
            model.tokenizer,
            first_sequence='sentence',
            second_sequence=None)
        pipeline_ins = pipeline(
            task=Tasks.text_generation, model=model, preprocessor=preprocessor)
        print(pipeline_ins(input))

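    # Let the pipeline factory resolve the model and preprocessor from the model id.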
    def run_pipeline_with_model_id(self, model_id, input):
        pipeline_ins = pipeline(task=Tasks.text_generation, model=model_id)
        print(pipeline_ins(input))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_palm_zh_with_model_name(self):
        self.run_pipeline_with_model_id(self.palm_model_id_zh,
                                        self.palm_input_zh)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_palm_en_with_model_name(self):
        self.run_pipeline_with_model_id(self.palm_model_id_en,
                                        self.palm_input_en)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_gpt_base_with_model_name(self):
        self.run_pipeline_with_model_id(self.gpt3_base_model_id,
                                        self.gpt3_input)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_gpt_large_with_model_name(self):
        self.run_pipeline_with_model_id(self.gpt3_large_model_id,
                                        self.gpt3_input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_palm_zh_with_model_instance(self):
        self.run_pipeline_with_model_instance(self.palm_model_id_zh,
                                              self.palm_input_zh)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_palm_en_with_model_instance(self):
        self.run_pipeline_with_model_instance(self.palm_model_id_en,
                                              self.palm_input_en)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt_base_with_model_instance(self):
        self.run_pipeline_with_model_instance(self.gpt3_base_model_id,
                                              self.gpt3_input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_gpt_large_with_model_instance(self):
        self.run_pipeline_with_model_instance(self.gpt3_large_model_id,
                                              self.gpt3_input)

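    # Build the PALM pipeline by hand from a downloaded snapshot and print the
    # output of both the direct TextGenerationPipeline and the pipeline() factory.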
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_palm(self):
        for model_id, input in ((self.palm_model_id_zh, self.palm_input_zh),
                                (self.palm_model_id_en, self.palm_input_en)):
            cache_path = snapshot_download(model_id)
            model = PalmForTextGeneration.from_pretrained(cache_path)
            preprocessor = TextGenerationPreprocessor(
                cache_path,
                model.tokenizer,
                first_sequence='sentence',
                second_sequence=None)
            pipeline1 = TextGenerationPipeline(model, preprocessor)
            pipeline2 = pipeline(
                Tasks.text_generation, model=model, preprocessor=preprocessor)
            print(
                f'pipeline1: {pipeline1(input)}\npipeline2: {pipeline2(input)}'
            )

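    # Same manual construction for GPT-3; this exercises the generation path
    # that the trailing-sep fix described in the commit message is concerned with.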
    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_gpt3(self):
        cache_path = snapshot_download(self.gpt3_base_model_id)
        model = GPT3ForTextGeneration(cache_path)
        preprocessor = TextGenerationPreprocessor(
            cache_path,
            model.tokenizer,
            first_sequence='sentence',
            second_sequence=None)
        pipeline1 = TextGenerationPipeline(model, preprocessor)
        pipeline2 = pipeline(
            Tasks.text_generation, model=model, preprocessor=preprocessor)
        print(
            f'pipeline1: {pipeline1(self.gpt3_input)}\npipeline2: {pipeline2(self.gpt3_input)}'
        )

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.text_generation)
        print(pipeline_ins(self.palm_input_zh))

if __name__ == '__main__':
    unittest.main()