DALL-E 2: 修复dev/dalle2_1分支问题,增加测试代码,本地测试通过

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10037492
This commit is contained in:
xuangen.hlh
2022-09-06 20:47:23 +08:00
committed by yingda.chen
parent cd8ac57fdd
commit f7f29ed1ff
13 changed files with 2638 additions and 3 deletions

View File

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import numpy as np
import torch
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class MultiStageDiffusionTest(unittest.TestCase):
model_id = 'damo/cv_diffusion_text-to-image-synthesis'
test_text = {'text': 'Photograph of a baby chicken wearing sunglasses'}
@unittest.skip(
'skip test since the pretrained model is not publicly available')
def test_run_with_model_from_modelhub(self):
model = Model.from_pretrained(self.model_id)
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=model)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))
@unittest.skip(
'skip test since the pretrained model is not publicly available')
def test_run_with_model_name(self):
pipe_line_text_to_image_synthesis = pipeline(
task=Tasks.text_to_image_synthesis, model=self.model_id)
img = pipe_line_text_to_image_synthesis(
self.test_text)[OutputKeys.OUTPUT_IMG]
print(np.sum(np.abs(img)))
if __name__ == '__main__':
unittest.main()