Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 04:29:22 +01:00)
sd-inpainting: pass prompt in kwargs & reduce GPU usage
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11616099

* pass prompt in kwargs
@@ -56,13 +56,29 @@ class ImageInpaintingSDV2Pipeline(DiffusersPipeline):
         """
         super().__init__(model, device, **kwargs)
-        torch_dtype = kwargs.get('torch_dtype', torch.float32)
+        torch_dtype = kwargs.get('torch_dtype', torch.float16)
 
         # build upon the diffuser stable diffusion pipeline
         self.pipeline = StableDiffusionInpaintPipeline.from_pretrained(
             model, torch_dtype=torch_dtype)
         self.pipeline.to(self.device)
 
+        enable_attention_slicing = kwargs.get('enable_attention_slicing', True)
+        if enable_attention_slicing:
+            self.pipeline.enable_attention_slicing()
+
+    def _sanitize_parameters(self, **pipeline_parameters):
+        """
+        this method should sanitize the keyword args to preprocessor params,
+        forward params and postprocess params on '__call__' or '_process_single' method
+
+        Returns:
+            Dict[str, str]: preprocess_params = {}
+            Dict[str, str]: forward_params = pipeline_parameters
+            Dict[str, str]: postprocess_params = pipeline_parameters
+        """
+        return {}, pipeline_parameters, pipeline_parameters
+
     def forward(self, inputs: Dict[str, Any],
                 **forward_params) -> Dict[str, Any]:
         if not isinstance(inputs, dict):
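The two __init__ changes above trade a little numeric precision and speed for GPU memory: loading the weights as float16 roughly halves the VRAM footprint versus float32, and attention slicing computes attention in chunks instead of one large matrix. A minimal standalone sketch of the same setup against diffusers directly (the checkpoint id is an assumption for illustration; the ModelScope pipeline passes in its own local model path):

    import torch
    from diffusers import StableDiffusionInpaintPipeline

    # Assumed checkpoint id, for illustration only.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        'stabilityai/stable-diffusion-2-inpainting',
        torch_dtype=torch.float16)  # fp16 weights: about half the VRAM of float32
    pipe.to('cuda')

    # Compute attention in slices rather than in one pass:
    # lower peak memory at the cost of slightly slower inference.
    pipe.enable_attention_slicing()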
@@ -76,13 +92,20 @@ class ImageInpaintingSDV2Pipeline(DiffusersPipeline):
         num_images_per_prompt = inputs.get('num_images_per_prompt', 1)
         eta = inputs.get('eta', 0.0)
 
+        if 'prompt' in inputs.keys():
+            prompt = inputs['prompt']
+        else:
+            # for demo_service
+            prompt = forward_params.get('prompt', 'background')
+        print(f'Test with prompt: {prompt}')
+
         image = load_image(inputs['image'])
         mask = load_image(inputs['mask'])
-        prompt = inputs['prompt']
 
         w, h = image.size
         print(f'loaded input image of size ({w}, {h})')
         width, height = map(lambda x: x - x % 64,
-                            (w, h))  # resize to integer multiple of 32
+                            (w, h))  # resize to integer multiple of 64
         image = image.resize((width, height))
         mask = mask.resize((width, height))
         out_image = self.pipeline(
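Because _sanitize_parameters now routes every extra call kwarg into forward_params, a caller can supply the prompt either inside the input dict or as a plain keyword argument; with neither present, forward falls back to the 'background' default. A usage sketch (the task name and model id are assumptions, not taken from this diff):

    from modelscope.pipelines import pipeline

    # Assumed task name and placeholder model id.
    inpaint = pipeline('image-inpainting', model='<sd-v2-inpainting-model-id>')

    # Prompt inside the input dict (the only option before this change):
    out = inpaint({'image': 'img.png', 'mask': 'mask.png', 'prompt': 'a wooden bench'})

    # Prompt as a kwarg, which _sanitize_parameters forwards to forward_params:
    out = inpaint({'image': 'img.png', 'mask': 'mask.png'}, prompt='a wooden bench')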
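The corrected comment now matches the code: x - x % 64 rounds each dimension down to the nearest multiple of 64, presumably so every stage of the SD v2 model stays integral (the VAE downsamples by a factor of 8, and the UNet halves the latent further). A quick check of the rounding:

    >>> w, h = 700, 513
    >>> width, height = map(lambda x: x - x % 64, (w, h))
    >>> (width, height)
    (640, 512)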
@@ -54,7 +54,7 @@ timm>=0.4.9
 torchmetrics>=0.6.2
 torchsummary>=1.5.1
 torchvision
-transformers>=4.19.2
+transformers>=4.26.0
 ujson
 utils
 videofeatures_clipit>=1.0