sd-inpainting: pass prompt in kwargs & reduce GPU usage

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11616099

* pass prompt in kwargs
This commit is contained in:
zhicheng.sc
2023-02-10 02:16:21 +00:00
committed by wenmeng.zwm
parent 29e47e5030
commit 55d8d1b67c
2 changed files with 27 additions and 4 deletions

View File

@@ -56,13 +56,29 @@ class ImageInpaintingSDV2Pipeline(DiffusersPipeline):
"""
super().__init__(model, device, **kwargs)
torch_dtype = kwargs.get('torch_dtype', torch.float32)
torch_dtype = kwargs.get('torch_dtype', torch.float16)
# build upon the diffuser stable diffusion pipeline
self.pipeline = StableDiffusionInpaintPipeline.from_pretrained(
model, torch_dtype=torch_dtype)
self.pipeline.to(self.device)
enable_attention_slicing = kwargs.get('enable_attention_slicing', True)
if enable_attention_slicing:
self.pipeline.enable_attention_slicing()
def _sanitize_parameters(self, **pipeline_parameters):
"""
this method should sanitize the keyword args to preprocessor params,
forward params and postprocess params on '__call__' or '_process_single' method
Returns:
Dict[str, str]: preprocess_params = {}
Dict[str, str]: forward_params = pipeline_parameters
Dict[str, str]: postprocess_params = pipeline_parameters
"""
return {}, pipeline_parameters, pipeline_parameters
def forward(self, inputs: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
if not isinstance(inputs, dict):
@@ -76,13 +92,20 @@ class ImageInpaintingSDV2Pipeline(DiffusersPipeline):
num_images_per_prompt = inputs.get('num_images_per_prompt', 1)
eta = inputs.get('eta', 0.0)
if 'prompt' in inputs.keys():
prompt = inputs['prompt']
else:
# for demo_service
prompt = forward_params.get('prompt', 'background')
print(f'Test with prompt: {prompt}')
image = load_image(inputs['image'])
mask = load_image(inputs['mask'])
prompt = inputs['prompt']
w, h = image.size
print(f'loaded input image of size ({w}, {h})')
width, height = map(lambda x: x - x % 64,
(w, h)) # resize to integer multiple of 32
(w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
mask = mask.resize((width, height))
out_image = self.pipeline(

View File

@@ -54,7 +54,7 @@ timm>=0.4.9
torchmetrics>=0.6.2
torchsummary>=1.5.1
torchvision
transformers>=4.19.2
transformers>=4.26.0
ujson
utils
videofeatures_clipit>=1.0