support sparsectrl

This commit is contained in:
Yuwei Guo
2023-12-15 20:55:51 +08:00
parent 6c8a01b148
commit 401bc45697
7 changed files with 697 additions and 45 deletions

View File

@@ -12,7 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Changes were made to this source code by Yuwei Guo.
""" Conversion script for the LoRA's safetensors checkpoints. """
import argparse
@@ -21,11 +22,9 @@ import torch
from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline
import pdb
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
def load_diffusers_lora(pipeline, state_dict, alpha=1.0):
# directly update weight in diffusers model
for key in state_dict:
# only process lora down key
@@ -48,7 +47,6 @@ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
return pipeline
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
# load base model
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

View File

@@ -11,7 +11,7 @@ from safetensors import safe_open
from tqdm import tqdm
from einops import rearrange
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
def zero_rank_print(s):
@@ -96,12 +96,15 @@ def load_weights(
# motion module
motion_module_path = "",
motion_module_lora_configs = [],
# domain adapter
adapter_lora_path = "",
adapter_lora_scale = 1.0,
# image layers
dreambooth_model_path = "",
lora_model_path = "",
lora_alpha = 0.8,
dreambooth_model_path = "",
lora_model_path = "",
lora_alpha = 0.8,
):
# 1.1 motion module
# motion module
unet_state_dict = {}
if motion_module_path != "":
print(f"load motion module from {motion_module_path}")
@@ -113,6 +116,7 @@ def load_weights(
assert len(unexpected) == 0
del unet_state_dict
# base model
if dreambooth_model_path != "":
print(f"load dreambooth model from {dreambooth_model_path}")
if dreambooth_model_path.endswith(".safetensors"):
@@ -133,6 +137,7 @@ def load_weights(
animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
del dreambooth_state_dict
# lora layers
if lora_model_path != "":
print(f"load lora model from {lora_model_path}")
assert lora_model_path.endswith(".safetensors")
@@ -144,14 +149,21 @@ def load_weights(
animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
del lora_state_dict
# domain adapter lora
if adapter_lora_path != "":
print(f"load domain lora from {adapter_lora_path}")
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
# motion module lora
for motion_module_lora_config in motion_module_lora_configs:
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
print(f"load motion LoRA from {path}")
motion_lora_state_dict = torch.load(path, map_location="cpu")
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
return animation_pipeline