mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 08:17:45 +01:00
support safetensors weight pipeline (#421)
This commit is contained in:
@@ -37,7 +37,9 @@ class StableDiffusionPipeline(DiffusersPipeline):
|
||||
lora_dir: lora weight dir for unet.
|
||||
custom_dir: custom diffusion weight dir for unet.
|
||||
modifier_token: token to use as a modifier for the concept of custom diffusion.
|
||||
use_safetensors: load safetensors weights.
|
||||
"""
|
||||
use_safetensors = kwargs.pop('use_safetensors', False)
|
||||
# check custom diffusion input value
|
||||
if custom_dir is None and modifier_token is not None:
|
||||
raise ValueError(
|
||||
@@ -50,7 +52,7 @@ class StableDiffusionPipeline(DiffusersPipeline):
|
||||
# load pipeline
|
||||
torch_type = torch.float16 if self.device == 'cuda' else torch.float32
|
||||
self.pipeline = DiffusionPipeline.from_pretrained(
|
||||
model, torch_dtype=torch_type)
|
||||
model, use_safetensors=use_safetensors, torch_dtype=torch_type)
|
||||
self.pipeline = self.pipeline.to(self.device)
|
||||
|
||||
# load lora moudle to unet
|
||||
|
||||
Reference in New Issue
Block a user