mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 08:27:51 +01:00
fix model
This commit is contained in:
@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
|
||||
pe = torch.zeros(1, max_len, d_model)
|
||||
pe[0, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[0, :, 1::2] = torch.cos(position * div_term)
|
||||
self.register_buffer('pe', pe)
|
||||
self.register_buffer('pe', pe, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[:, :x.size(1)]
|
||||
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
|
||||
attention_mode = None,
|
||||
cross_frame_attention_mode = None,
|
||||
temporal_position_encoding = False,
|
||||
temporal_position_encoding_max_len = 24,
|
||||
temporal_position_encoding_max_len = 32,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -475,16 +475,77 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
return UNet3DConditionOutput(sample=sample)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
|
||||
if subfolder is not None:
|
||||
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
|
||||
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
|
||||
def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
|
||||
from diffusers import __version__
|
||||
from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
|
||||
from diffusers.modeling_utils import load_state_dict
|
||||
print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...")
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
|
||||
user_agent = {
|
||||
"diffusers": __version__,
|
||||
"file_type": "model",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
model_file = None
|
||||
if is_safetensors_available():
|
||||
try:
|
||||
model_file = cls._get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=SAFETENSORS_WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
model_file = cls._get_model_file(
|
||||
pretrained_model_name_or_path,
|
||||
weights_name=WEIGHTS_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
config, unused_kwargs = cls.load_config(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
device_map=device_map,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
config_file = os.path.join(pretrained_model_path, 'config.json')
|
||||
if not os.path.isfile(config_file):
|
||||
raise RuntimeError(f"{config_file} does not exist")
|
||||
with open(config_file, "r") as f:
|
||||
config = json.load(f)
|
||||
config["_class_name"] = cls.__name__
|
||||
config["down_block_types"] = [
|
||||
"CrossAttnDownBlock3D",
|
||||
@@ -499,12 +560,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnUpBlock3D"
|
||||
]
|
||||
|
||||
from diffusers.utils import WEIGHTS_NAME
|
||||
model = cls.from_config(config, **unet_additional_kwargs)
|
||||
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
||||
if not os.path.isfile(model_file):
|
||||
raise RuntimeError(f"{model_file} does not exist")
|
||||
state_dict = torch.load(model_file, map_location="cpu")
|
||||
model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs)
|
||||
state_dict = load_state_dict(model_file)
|
||||
|
||||
m, u = model.load_state_dict(state_dict, strict=False)
|
||||
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
import torchvision
|
||||
import torch.distributed as dist
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
@@ -14,6 +15,45 @@ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, con
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
|
||||
|
||||
|
||||
MOTION_MODULES = [
|
||||
"mm_sd_v14.ckpt",
|
||||
"mm_sd_v15.ckpt",
|
||||
"mm_sd_v15_v2.ckpt",
|
||||
"v3_sd15_mm.ckpt",
|
||||
]
|
||||
|
||||
ADAPTERS = [
|
||||
# "mm_sd_v14.ckpt",
|
||||
# "mm_sd_v15.ckpt",
|
||||
# "mm_sd_v15_v2.ckpt",
|
||||
# "mm_sdxl_v10_beta.ckpt",
|
||||
"v2_lora_PanLeft.ckpt",
|
||||
"v2_lora_PanRight.ckpt",
|
||||
"v2_lora_RollingAnticlockwise.ckpt",
|
||||
"v2_lora_RollingClockwise.ckpt",
|
||||
"v2_lora_TiltDown.ckpt",
|
||||
"v2_lora_TiltUp.ckpt",
|
||||
"v2_lora_ZoomIn.ckpt",
|
||||
"v2_lora_ZoomOut.ckpt",
|
||||
"v3_sd15_adapter.ckpt",
|
||||
# "v3_sd15_mm.ckpt",
|
||||
"v3_sd15_sparsectrl_rgb.ckpt",
|
||||
"v3_sd15_sparsectrl_scribble.ckpt",
|
||||
]
|
||||
|
||||
BACKUP_DREAMBOOTH_MODELS = [
|
||||
"realisticVisionV60B1_v51VAE.safetensors",
|
||||
"majicmixRealistic_v4.safetensors",
|
||||
"leosamsFilmgirlUltra_velvia20Lora.safetensors",
|
||||
"toonyou_beta3.safetensors",
|
||||
"majicmixRealistic_v5Preview.safetensors",
|
||||
"rcnzCartoon3d_v10.safetensors",
|
||||
"lyriel_v16.safetensors",
|
||||
"leosamsHelloworldXL_filmGrain20.safetensors",
|
||||
"TUSUN.safetensors",
|
||||
]
|
||||
|
||||
|
||||
def zero_rank_print(s):
|
||||
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
||||
|
||||
@@ -33,64 +73,21 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
|
||||
imageio.mimsave(path, outputs, fps=fps)
|
||||
|
||||
|
||||
# DDIM Inversion
|
||||
@torch.no_grad()
|
||||
def init_prompt(prompt, pipeline):
|
||||
uncond_input = pipeline.tokenizer(
|
||||
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
||||
return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
||||
text_input = pipeline.tokenizer(
|
||||
[prompt],
|
||||
padding="max_length",
|
||||
max_length=pipeline.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
||||
context = torch.cat([uncond_embeddings, text_embeddings])
|
||||
def auto_download(local_path, is_dreambooth_lora=False):
|
||||
hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
|
||||
folder, filename = os.path.split(local_path)
|
||||
|
||||
return context
|
||||
if not os.path.exists(local_path):
|
||||
print(f"local file {local_path} does not exist. trying to download from {hf_repo}")
|
||||
|
||||
if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}"
|
||||
else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}"
|
||||
|
||||
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
||||
timestep, next_timestep = min(
|
||||
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
||||
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
||||
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
||||
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
||||
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
||||
return next_sample
|
||||
folder = "." if folder == "" else folder
|
||||
os.makedirs(folder, exist_ok=True)
|
||||
snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])
|
||||
|
||||
|
||||
def get_noise_pred_single(latents, t, context, unet):
|
||||
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
||||
return noise_pred
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
||||
context = init_prompt(prompt, pipeline)
|
||||
uncond_embeddings, cond_embeddings = context.chunk(2)
|
||||
all_latent = [latent]
|
||||
latent = latent.clone().detach()
|
||||
for i in tqdm(range(num_inv_steps)):
|
||||
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
||||
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
||||
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
||||
all_latent.append(latent)
|
||||
return all_latent
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
||||
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
||||
return ddim_latents
|
||||
|
||||
def load_weights(
|
||||
animation_pipeline,
|
||||
# motion module
|
||||
@@ -107,10 +104,16 @@ def load_weights(
|
||||
# motion module
|
||||
unet_state_dict = {}
|
||||
if motion_module_path != "":
|
||||
auto_download(motion_module_path, is_dreambooth_lora=False)
|
||||
|
||||
print(f"load motion module from {motion_module_path}")
|
||||
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
||||
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
||||
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
||||
# filter parameters
|
||||
for name, param in motion_module_state_dict.items():
|
||||
if not "motion_modules." in name: continue
|
||||
if "pos_encoder.pe" in name: continue
|
||||
unet_state_dict.update({name: param})
|
||||
unet_state_dict.pop("animatediff_config", "")
|
||||
|
||||
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
||||
@@ -119,6 +122,8 @@ def load_weights(
|
||||
|
||||
# base model
|
||||
if dreambooth_model_path != "":
|
||||
auto_download(dreambooth_model_path, is_dreambooth_lora=True)
|
||||
|
||||
print(f"load dreambooth model from {dreambooth_model_path}")
|
||||
if dreambooth_model_path.endswith(".safetensors"):
|
||||
dreambooth_state_dict = {}
|
||||
@@ -140,6 +145,8 @@ def load_weights(
|
||||
|
||||
# lora layers
|
||||
if lora_model_path != "":
|
||||
auto_download(lora_model_path, is_dreambooth_lora=True)
|
||||
|
||||
print(f"load lora model from {lora_model_path}")
|
||||
assert lora_model_path.endswith(".safetensors")
|
||||
lora_state_dict = {}
|
||||
@@ -152,6 +159,8 @@ def load_weights(
|
||||
|
||||
# domain adapter lora
|
||||
if adapter_lora_path != "":
|
||||
auto_download(adapter_lora_path, is_dreambooth_lora=False)
|
||||
|
||||
print(f"load domain lora from {adapter_lora_path}")
|
||||
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
||||
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
||||
@@ -162,6 +171,9 @@ def load_weights(
|
||||
# motion module lora
|
||||
for motion_module_lora_config in motion_module_lora_configs:
|
||||
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
||||
|
||||
auto_download(path, is_dreambooth_lora=False)
|
||||
|
||||
print(f"load motion LoRA from {path}")
|
||||
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
||||
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
||||
|
||||
Reference in New Issue
Block a user