mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 17:56:15 +02:00
fix model
This commit is contained in:
@@ -7,6 +7,7 @@ import torch
|
||||
import torchvision
|
||||
import torch.distributed as dist
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
@@ -14,6 +15,45 @@ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, con
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
|
||||
|
||||
|
||||
MOTION_MODULES = [
|
||||
"mm_sd_v14.ckpt",
|
||||
"mm_sd_v15.ckpt",
|
||||
"mm_sd_v15_v2.ckpt",
|
||||
"v3_sd15_mm.ckpt",
|
||||
]
|
||||
|
||||
ADAPTERS = [
|
||||
# "mm_sd_v14.ckpt",
|
||||
# "mm_sd_v15.ckpt",
|
||||
# "mm_sd_v15_v2.ckpt",
|
||||
# "mm_sdxl_v10_beta.ckpt",
|
||||
"v2_lora_PanLeft.ckpt",
|
||||
"v2_lora_PanRight.ckpt",
|
||||
"v2_lora_RollingAnticlockwise.ckpt",
|
||||
"v2_lora_RollingClockwise.ckpt",
|
||||
"v2_lora_TiltDown.ckpt",
|
||||
"v2_lora_TiltUp.ckpt",
|
||||
"v2_lora_ZoomIn.ckpt",
|
||||
"v2_lora_ZoomOut.ckpt",
|
||||
"v3_sd15_adapter.ckpt",
|
||||
# "v3_sd15_mm.ckpt",
|
||||
"v3_sd15_sparsectrl_rgb.ckpt",
|
||||
"v3_sd15_sparsectrl_scribble.ckpt",
|
||||
]
|
||||
|
||||
BACKUP_DREAMBOOTH_MODELS = [
|
||||
"realisticVisionV60B1_v51VAE.safetensors",
|
||||
"majicmixRealistic_v4.safetensors",
|
||||
"leosamsFilmgirlUltra_velvia20Lora.safetensors",
|
||||
"toonyou_beta3.safetensors",
|
||||
"majicmixRealistic_v5Preview.safetensors",
|
||||
"rcnzCartoon3d_v10.safetensors",
|
||||
"lyriel_v16.safetensors",
|
||||
"leosamsHelloworldXL_filmGrain20.safetensors",
|
||||
"TUSUN.safetensors",
|
||||
]
|
||||
|
||||
|
||||
def zero_rank_print(s):
|
||||
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
||||
|
||||
@@ -33,64 +73,21 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
|
||||
imageio.mimsave(path, outputs, fps=fps)
|
||||
|
||||
|
||||
# DDIM Inversion
|
||||
@torch.no_grad()
|
||||
def init_prompt(prompt, pipeline):
|
||||
uncond_input = pipeline.tokenizer(
|
||||
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
||||
return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
||||
text_input = pipeline.tokenizer(
|
||||
[prompt],
|
||||
padding="max_length",
|
||||
max_length=pipeline.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
||||
context = torch.cat([uncond_embeddings, text_embeddings])
|
||||
def auto_download(local_path, is_dreambooth_lora=False):
|
||||
hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
|
||||
folder, filename = os.path.split(local_path)
|
||||
|
||||
return context
|
||||
if not os.path.exists(local_path):
|
||||
print(f"local file {local_path} does not exist. trying to download from {hf_repo}")
|
||||
|
||||
if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}"
|
||||
else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}"
|
||||
|
||||
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
||||
timestep, next_timestep = min(
|
||||
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
||||
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
||||
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
||||
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
||||
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
||||
return next_sample
|
||||
folder = "." if folder == "" else folder
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])
|
||||
|
||||
|
||||
def get_noise_pred_single(latents, t, context, unet):
|
||||
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
||||
return noise_pred
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
||||
context = init_prompt(prompt, pipeline)
|
||||
uncond_embeddings, cond_embeddings = context.chunk(2)
|
||||
all_latent = [latent]
|
||||
latent = latent.clone().detach()
|
||||
for i in tqdm(range(num_inv_steps)):
|
||||
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
||||
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
||||
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
||||
all_latent.append(latent)
|
||||
return all_latent
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
||||
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
||||
return ddim_latents
|
||||
|
||||
def load_weights(
|
||||
animation_pipeline,
|
||||
# motion module
|
||||
@@ -107,10 +104,16 @@ def load_weights(
|
||||
# motion module
|
||||
unet_state_dict = {}
|
||||
if motion_module_path != "":
|
||||
auto_download(motion_module_path, is_dreambooth_lora=False)
|
||||
|
||||
print(f"load motion module from {motion_module_path}")
|
||||
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
||||
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
||||
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
||||
# filter parameters
|
||||
for name, param in motion_module_state_dict.items():
|
||||
if not "motion_modules." in name: continue
|
||||
if "pos_encoder.pe" in name: continue
|
||||
unet_state_dict.update({name: param})
|
||||
unet_state_dict.pop("animatediff_config", "")
|
||||
|
||||
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
||||
@@ -119,6 +122,8 @@ def load_weights(
|
||||
|
||||
# base model
|
||||
if dreambooth_model_path != "":
|
||||
auto_download(dreambooth_model_path, is_dreambooth_lora=True)
|
||||
|
||||
print(f"load dreambooth model from {dreambooth_model_path}")
|
||||
if dreambooth_model_path.endswith(".safetensors"):
|
||||
dreambooth_state_dict = {}
|
||||
@@ -140,6 +145,8 @@ def load_weights(
|
||||
|
||||
# lora layers
|
||||
if lora_model_path != "":
|
||||
auto_download(lora_model_path, is_dreambooth_lora=True)
|
||||
|
||||
print(f"load lora model from {lora_model_path}")
|
||||
assert lora_model_path.endswith(".safetensors")
|
||||
lora_state_dict = {}
|
||||
@@ -152,6 +159,8 @@ def load_weights(
|
||||
|
||||
# domain adapter lora
|
||||
if adapter_lora_path != "":
|
||||
auto_download(adapter_lora_path, is_dreambooth_lora=False)
|
||||
|
||||
print(f"load domain lora from {adapter_lora_path}")
|
||||
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
||||
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
||||
@@ -162,6 +171,9 @@ def load_weights(
|
||||
# motion module lora
|
||||
for motion_module_lora_config in motion_module_lora_configs:
|
||||
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
||||
|
||||
auto_download(path, is_dreambooth_lora=False)
|
||||
|
||||
print(f"load motion LoRA from {path}")
|
||||
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
||||
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
||||
|
||||
Reference in New Issue
Block a user