stable diffusion allow postprocess kwargs

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12418063
* stable diffusion allow postprocess kwargs
This commit is contained in:
zhangzhicheng.zzc
2023-04-21 18:34:29 +08:00
committed by hemu
parent 37b3d04824
commit fc9822ba85
3 changed files with 4 additions and 4 deletions

View File

@@ -39,10 +39,10 @@ class DiffusersPipeline(Pipeline):
self.models = [self.model]
self.has_multiple_models = len(self.models) > 1
def preprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def preprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return inputs
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
return inputs
def __call__(self, input: Union[Input, List[Input]], *args,

View File

@@ -75,7 +75,7 @@ class ChineseStableDiffusionPipeline(DiffusersPipeline):
callback=inputs.get('callback'),
callback_steps=inputs.get('callback_steps', 1))
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
images = []
for img in inputs.images:
if isinstance(img, Image.Image):

View File

@@ -65,7 +65,7 @@ class StableDiffusionWrapperPipeline(DiffusersPipeline):
callback=inputs.get('callback'),
callback_steps=inputs.get('callback_steps', 1))
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
images = []
for img in inputs.images:
if isinstance(img, Image.Image):