From 55d8d1b67cb309067eba90b9219be521641177d5 Mon Sep 17 00:00:00 2001 From: "zhicheng.sc" Date: Fri, 10 Feb 2023 02:16:21 +0000 Subject: [PATCH] sd-inpainting: pass prompt in kwargs & reduce GPU usage Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11616099 * pass prompt in kwargs --- .../cv/image_inpainting_sdv2_pipeline.py | 29 +++++++++++++++++-- requirements/cv.txt | 2 +- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/modelscope/pipelines/cv/image_inpainting_sdv2_pipeline.py b/modelscope/pipelines/cv/image_inpainting_sdv2_pipeline.py index 2d3e55d5..e5f2d052 100644 --- a/modelscope/pipelines/cv/image_inpainting_sdv2_pipeline.py +++ b/modelscope/pipelines/cv/image_inpainting_sdv2_pipeline.py @@ -56,13 +56,29 @@ class ImageInpaintingSDV2Pipeline(DiffusersPipeline): """ super().__init__(model, device, **kwargs) - torch_dtype = kwargs.get('torch_dtype', torch.float32) + torch_dtype = kwargs.get('torch_dtype', torch.float16) # build upon the diffuser stable diffusion pipeline self.pipeline = StableDiffusionInpaintPipeline.from_pretrained( model, torch_dtype=torch_dtype) self.pipeline.to(self.device) + enable_attention_slicing = kwargs.get('enable_attention_slicing', True) + if enable_attention_slicing: + self.pipeline.enable_attention_slicing() + + def _sanitize_parameters(self, **pipeline_parameters): + """ + this method should sanitize the keyword args to preprocessor params, + forward params and postprocess params on '__call__' or '_process_single' method + + Returns: + Dict[str, str]: preprocess_params = {} + Dict[str, str]: forward_params = pipeline_parameters + Dict[str, str]: postprocess_params = pipeline_parameters + """ + return {}, pipeline_parameters, pipeline_parameters + def forward(self, inputs: Dict[str, Any], **forward_params) -> Dict[str, Any]: if not isinstance(inputs, dict): @@ -76,13 +92,20 @@ class ImageInpaintingSDV2Pipeline(DiffusersPipeline): num_images_per_prompt = inputs.get('num_images_per_prompt', 1) eta = inputs.get('eta', 0.0) + if 'prompt' in inputs.keys(): + prompt = inputs['prompt'] + else: + # for demo_service + prompt = forward_params.get('prompt', 'background') + print(f'Test with prompt: {prompt}') + image = load_image(inputs['image']) mask = load_image(inputs['mask']) - prompt = inputs['prompt'] + w, h = image.size print(f'loaded input image of size ({w}, {h})') width, height = map(lambda x: x - x % 64, - (w, h)) # resize to integer multiple of 32 + (w, h)) # resize to integer multiple of 64 image = image.resize((width, height)) mask = mask.resize((width, height)) out_image = self.pipeline( diff --git a/requirements/cv.txt b/requirements/cv.txt index d40d7db5..23107e7c 100644 --- a/requirements/cv.txt +++ b/requirements/cv.txt @@ -54,7 +54,7 @@ timm>=0.4.9 torchmetrics>=0.6.2 torchsummary>=1.5.1 torchvision -transformers>=4.19.2 +transformers>=4.26.0 ujson utils videofeatures_clipit>=1.0