diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index b18e0979..8c7d3780 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -204,7 +204,14 @@ TASK_INPUTS = { InputType.IMAGE, Tasks.video_embedding: InputType.VIDEO, - Tasks.virtual_try_on: (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE), + Tasks.virtual_try_on: [ + (InputType.IMAGE, InputType.IMAGE, InputType.IMAGE), + { + 'masked_model': InputType.IMAGE, + 'pose': InputType.IMAGE, + 'cloth': InputType.IMAGE, + } + ], Tasks.text_driven_segmentation: { InputKeys.IMAGE: InputType.IMAGE, InputKeys.TEXT: InputType.TEXT diff --git a/tests/pipelines/test_virtual_try_on.py b/tests/pipelines/test_virtual_try_on.py index c8a55f79..2e399d8f 100644 --- a/tests/pipelines/test_virtual_try_on.py +++ b/tests/pipelines/test_virtual_try_on.py @@ -20,7 +20,11 @@ class VirtualTryonTest(unittest.TestCase): masked_model = Image.open('data/test/images/virtual_tryon_model.jpg') pose = Image.open('data/test/images/virtual_tryon_pose.jpg') cloth = Image.open('data/test/images/virtual_tryon_cloth.jpg') - input_imgs = (masked_model, pose, cloth) + input_imgs = { + 'masked_model': masked_model, + 'pose': pose, + 'cloth': cloth, + } @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name(self):