Mirror of https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 16:38:01 +01:00
fix model
@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
         pe = torch.zeros(1, max_len, d_model)
         pe[0, :, 0::2] = torch.sin(position * div_term)
         pe[0, :, 1::2] = torch.cos(position * div_term)
-        self.register_buffer('pe', pe)
+        self.register_buffer('pe', pe, persistent=False)

     def forward(self, x):
         x = x + self.pe[:, :x.size(1)]
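
A buffer registered with persistent=False is excluded from the module's state_dict, so the precomputed sinusoidal table is no longer written into (or expected from) checkpoints, and changing max_len cannot trip a size mismatch when older weights are loaded. A minimal standalone sketch (toy module, not the repository's class) showing the effect:

    import torch
    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self, max_len=32, d_model=4):
            super().__init__()
            # non-persistent: rebuilt in __init__, never serialized
            self.register_buffer("pe", torch.zeros(1, max_len, d_model), persistent=False)

    m = Toy()
    assert "pe" not in m.state_dict()  # the table never enters checkpoints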
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
            attention_mode                     = None,
            cross_frame_attention_mode         = None,
            temporal_position_encoding         = False,
-           temporal_position_encoding_max_len = 24,
+           temporal_position_encoding_max_len = 32,
            *args, **kwargs
        ):
        super().__init__(*args, **kwargs)
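
Raising temporal_position_encoding_max_len from 24 to 32 sizes the temporal positional table for the newer motion modules, which ship with a longer (32-entry) encoding. For reference, the buffer holds the standard sinusoidal table built by the lines in the first hunk; a small sketch with illustrative dimensions:

    import math
    import torch

    d_model, max_len = 8, 32  # illustrative sizes, not the module's real ones
    position = torch.arange(max_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(1, max_len, d_model)
    pe[0, :, 0::2] = torch.sin(position * div_term)
    pe[0, :, 1::2] = torch.cos(position * div_term)
    print(pe.shape)  # torch.Size([1, 32, 8])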
@@ -475,16 +475,77 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         return UNet3DConditionOutput(sample=sample)

     @classmethod
-    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
-        if subfolder is not None:
-            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
-        print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...")
+    def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
+        from diffusers import __version__
+        from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
+        from diffusers.modeling_utils import load_state_dict
+        print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...")

+        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+        force_download = kwargs.pop("force_download", False)
+        resume_download = kwargs.pop("resume_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        use_auth_token = kwargs.pop("use_auth_token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        device_map = kwargs.pop("device_map", None)
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        model_file = None
+        if is_safetensors_available():
+            try:
+                model_file = cls._get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=SAFETENSORS_WEIGHTS_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+            except:
+                pass
+
+        if model_file is None:
+            model_file = cls._get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=WEIGHTS_NAME,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                use_auth_token=use_auth_token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+            )
+
+        config, unused_kwargs = cls.load_config(
+            pretrained_model_name_or_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            use_auth_token=use_auth_token,
+            revision=revision,
+            subfolder=subfolder,
+            device_map=device_map,
+            **kwargs,
+        )
+
-        config_file = os.path.join(pretrained_model_path, 'config.json')
-        if not os.path.isfile(config_file):
-            raise RuntimeError(f"{config_file} does not exist")
-        with open(config_file, "r") as f:
-            config = json.load(f)
         config["_class_name"] = cls.__name__
         config["down_block_types"] = [
             "CrossAttnDownBlock3D",
@@ -499,12 +560,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
             "CrossAttnUpBlock3D"
         ]

-        from diffusers.utils import WEIGHTS_NAME
-        model = cls.from_config(config, **unet_additional_kwargs)
-        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
-        if not os.path.isfile(model_file):
-            raise RuntimeError(f"{model_file} does not exist")
-        state_dict = torch.load(model_file, map_location="cpu")
+        model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs)
+        state_dict = load_state_dict(model_file)

         m, u = model.load_state_dict(state_dict, strict=False)
         print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
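
With the new signature, from_pretrained_2d resolves weights through the diffusers hub machinery (cache, revisions, safetensors-first with a .bin fallback) instead of requiring a local folder. One caution: unet_additional_kwargs={} is a mutable default argument; it is only read here, but a None default would be the safer idiom. A hypothetical call, assuming a diffusers-format Stable Diffusion checkpoint (the model id and the unet_additional_kwargs key below are illustrative, not prescribed by this commit):

    from animatediff.models.unet import UNet3DConditionModel

    unet = UNet3DConditionModel.from_pretrained_2d(
        "runwayml/stable-diffusion-v1-5",  # any diffusers-format SD 1.5 repo
        subfolder="unet",                  # popped from **kwargs above
        unet_additional_kwargs={"use_motion_module": True},  # assumed key
    )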
@@ -7,6 +7,7 @@ import torch
 import torchvision
 import torch.distributed as dist

+from huggingface_hub import snapshot_download
 from safetensors import safe_open
 from tqdm import tqdm
 from einops import rearrange
@@ -14,6 +15,45 @@ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, con
 from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora


+MOTION_MODULES = [
+    "mm_sd_v14.ckpt",
+    "mm_sd_v15.ckpt",
+    "mm_sd_v15_v2.ckpt",
+    "v3_sd15_mm.ckpt",
+]
+
+ADAPTERS = [
+    # "mm_sd_v14.ckpt",
+    # "mm_sd_v15.ckpt",
+    # "mm_sd_v15_v2.ckpt",
+    # "mm_sdxl_v10_beta.ckpt",
+    "v2_lora_PanLeft.ckpt",
+    "v2_lora_PanRight.ckpt",
+    "v2_lora_RollingAnticlockwise.ckpt",
+    "v2_lora_RollingClockwise.ckpt",
+    "v2_lora_TiltDown.ckpt",
+    "v2_lora_TiltUp.ckpt",
+    "v2_lora_ZoomIn.ckpt",
+    "v2_lora_ZoomOut.ckpt",
+    "v3_sd15_adapter.ckpt",
+    # "v3_sd15_mm.ckpt",
+    "v3_sd15_sparsectrl_rgb.ckpt",
+    "v3_sd15_sparsectrl_scribble.ckpt",
+]
+
+BACKUP_DREAMBOOTH_MODELS = [
+    "realisticVisionV60B1_v51VAE.safetensors",
+    "majicmixRealistic_v4.safetensors",
+    "leosamsFilmgirlUltra_velvia20Lora.safetensors",
+    "toonyou_beta3.safetensors",
+    "majicmixRealistic_v5Preview.safetensors",
+    "rcnzCartoon3d_v10.safetensors",
+    "lyriel_v16.safetensors",
+    "leosamsHelloworldXL_filmGrain20.safetensors",
+    "TUSUN.safetensors",
+]
+
+
 def zero_rank_print(s):
     if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)

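An aside on the unchanged context line: the guard in zero_rank_print joins its two clauses with "and", so the condition can never be true and the function never prints. The presumable intent, given the module-level import torch.distributed as dist, is "or" (print when not running distributed, or on rank 0):

    def zero_rank_print(s):
        # presumed intent; the committed line uses "and" and is a no-op
        if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0):
            print("### " + s)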
@@ -33,64 +73,21 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
     imageio.mimsave(path, outputs, fps=fps)


-# DDIM Inversion
-@torch.no_grad()
-def init_prompt(prompt, pipeline):
-    uncond_input = pipeline.tokenizer(
-        [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
-        return_tensors="pt"
-    )
-    uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
-    text_input = pipeline.tokenizer(
-        [prompt],
-        padding="max_length",
-        max_length=pipeline.tokenizer.model_max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
-    context = torch.cat([uncond_embeddings, text_embeddings])
-
-    return context
-
-
-def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
-              sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
-    timestep, next_timestep = min(
-        timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
-    alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
-    alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
-    beta_prod_t = 1 - alpha_prod_t
-    next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
-    next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
-    next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
-    return next_sample
-
-
-def get_noise_pred_single(latents, t, context, unet):
-    noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
-    return noise_pred
-
-
-@torch.no_grad()
-def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
-    context = init_prompt(prompt, pipeline)
-    uncond_embeddings, cond_embeddings = context.chunk(2)
-    all_latent = [latent]
-    latent = latent.clone().detach()
-    for i in tqdm(range(num_inv_steps)):
-        t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
-        noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
-        latent = next_step(noise_pred, t, latent, ddim_scheduler)
-        all_latent.append(latent)
-    return all_latent
-
-
-@torch.no_grad()
-def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
-    ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
-    return ddim_latents
+def auto_download(local_path, is_dreambooth_lora=False):
+    hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
+    folder, filename = os.path.split(local_path)
+
+    if not os.path.exists(local_path):
+        print(f"local file {local_path} does not exist. trying to download from {hf_repo}")
+
+        if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} does not exist in {hf_repo}"
+        else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} does not exist in {hf_repo}"
+
+        folder = "." if folder == "" else folder
+        os.makedirs(folder, exist_ok=True)
+        snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])


 def load_weights(
     animation_pipeline,
     # motion module
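
auto_download fetches a single missing file from one of the official Hugging Face repos by passing its filename as an allow_patterns entry to snapshot_download, then places it in the directory part of local_path. A hypothetical call (the directory layout is an assumption; the filename is from MOTION_MODULES above):

    # downloads v3_sd15_mm.ckpt from guoyww/animatediff into the folder
    # if it is not already present locally
    auto_download("models/Motion_Module/v3_sd15_mm.ckpt", is_dreambooth_lora=False)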
@@ -107,10 +104,16 @@ def load_weights(
     # motion module
     unet_state_dict = {}
     if motion_module_path != "":
+        auto_download(motion_module_path, is_dreambooth_lora=False)
+
         print(f"load motion module from {motion_module_path}")
         motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
         motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
-        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
+        # filter parameters
+        for name, param in motion_module_state_dict.items():
+            if not "motion_modules." in name: continue
+            if "pos_encoder.pe" in name: continue
+            unet_state_dict.update({name: param})
         unet_state_dict.pop("animatediff_config", "")

     missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
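
The dict comprehension becomes an explicit loop so that pos_encoder.pe entries can be skipped as well: since the encoding buffer is now non-persistent (first hunk), dropping any pe tensors still present in older checkpoints keeps load_state_dict(strict=False) from colliding with the freshly built, possibly longer, table. The loop is equivalent to this filtered comprehension (a sketch, not in the commit):

    unet_state_dict.update({
        name: param
        for name, param in motion_module_state_dict.items()
        if "motion_modules." in name and "pos_encoder.pe" not in name
    })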
@@ -119,6 +122,8 @@ def load_weights(

     # base model
     if dreambooth_model_path != "":
+        auto_download(dreambooth_model_path, is_dreambooth_lora=True)
+
         print(f"load dreambooth model from {dreambooth_model_path}")
         if dreambooth_model_path.endswith(".safetensors"):
             dreambooth_state_dict = {}
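
For reference, the .safetensors branch that continues past this hunk typically fills the dict via safe_open, which this file already imports; a sketch of that pattern under the surrounding function's variables:

    # inside the "if dreambooth_model_path.endswith('.safetensors')" branch
    dreambooth_state_dict = {}
    with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            dreambooth_state_dict[key] = f.get_tensor(key)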
@@ -140,6 +145,8 @@ def load_weights(

     # lora layers
     if lora_model_path != "":
+        auto_download(lora_model_path, is_dreambooth_lora=True)
+
         print(f"load lora model from {lora_model_path}")
         assert lora_model_path.endswith(".safetensors")
         lora_state_dict = {}
@@ -152,6 +159,8 @@ def load_weights(

     # domain adapter lora
     if adapter_lora_path != "":
+        auto_download(adapter_lora_path, is_dreambooth_lora=False)
+
         print(f"load domain lora from {adapter_lora_path}")
         domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
         domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
@@ -162,6 +171,9 @@ def load_weights(
     # motion module lora
     for motion_module_lora_config in motion_module_lora_configs:
         path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
+
+        auto_download(path, is_dreambooth_lora=False)
+
         print(f"load motion LoRA from {path}")
         motion_lora_state_dict = torch.load(path, map_location="cpu")
         motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
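
Past the end of this hunk, each motion LoRA state dict is presumably handed to load_diffusers_lora with its alpha, the helper imported at the top of the file; a sketch under that assumption:

    # assumed continuation of the loop above; signature per the import
    # "from animatediff.utils.convert_lora_safetensor_to_diffusers import load_diffusers_lora"
    animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)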