support safetensors weight pipeline (#421)

This commit is contained in:
Wang Qiang
2023-07-27 16:49:01 +08:00
committed by GitHub
parent 9802dfe93b
commit dca6143b8b

View File

@@ -37,7 +37,9 @@ class StableDiffusionPipeline(DiffusersPipeline):
lora_dir: lora weight dir for unet.
custom_dir: custom diffusion weight dir for unet.
modifier_token: token to use as a modifier for the concept of custom diffusion.
use_safetensors: load safetensors weights.
"""
use_safetensors = kwargs.pop('use_safetensors', False)
# check custom diffusion input value
if custom_dir is None and modifier_token is not None:
raise ValueError(
@@ -50,7 +52,7 @@ class StableDiffusionPipeline(DiffusersPipeline):
# load pipeline
torch_type = torch.float16 if self.device == 'cuda' else torch.float32
self.pipeline = DiffusionPipeline.from_pretrained(
model, torch_dtype=torch_type)
model, use_safetensors=use_safetensors, torch_dtype=torch_type)
self.pipeline = self.pipeline.to(self.device)
# load lora moudle to unet