fix model

This commit is contained in:
Yuwei
2024-07-17 08:03:42 +00:00
parent cf80ddeb47
commit 786a99cc7f
3 changed files with 140 additions and 71 deletions

View File

@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
pe = torch.zeros(1, max_len, d_model) pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe) self.register_buffer('pe', pe, persistent=False)
def forward(self, x): def forward(self, x):
x = x + self.pe[:, :x.size(1)] x = x + self.pe[:, :x.size(1)]
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
attention_mode = None, attention_mode = None,
cross_frame_attention_mode = None, cross_frame_attention_mode = None,
temporal_position_encoding = False, temporal_position_encoding = False,
temporal_position_encoding_max_len = 24, temporal_position_encoding_max_len = 32,
*args, **kwargs *args, **kwargs
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)

View File

@@ -475,16 +475,77 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
return UNet3DConditionOutput(sample=sample) return UNet3DConditionOutput(sample=sample)
@classmethod @classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs):
if subfolder is not None: from diffusers import __version__
pretrained_model_path = os.path.join(pretrained_model_path, subfolder) from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available
print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") from diffusers.modeling_utils import load_state_dict
print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...")
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
device_map = kwargs.pop("device_map", None)
user_agent = {
"diffusers": __version__,
"file_type": "model",
"framework": "pytorch",
}
model_file = None
if is_safetensors_available():
try:
model_file = cls._get_model_file(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
except:
pass
if model_file is None:
model_file = cls._get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
config, unused_kwargs = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
device_map=device_map,
**kwargs,
)
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
config["_class_name"] = cls.__name__ config["_class_name"] = cls.__name__
config["down_block_types"] = [ config["down_block_types"] = [
"CrossAttnDownBlock3D", "CrossAttnDownBlock3D",
@@ -499,12 +560,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
"CrossAttnUpBlock3D" "CrossAttnUpBlock3D"
] ]
from diffusers.utils import WEIGHTS_NAME model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs)
model = cls.from_config(config, **unet_additional_kwargs) state_dict = load_state_dict(model_file)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
m, u = model.load_state_dict(state_dict, strict=False) m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

View File

@@ -7,6 +7,7 @@ import torch
import torchvision import torchvision
import torch.distributed as dist import torch.distributed as dist
from huggingface_hub import snapshot_download
from safetensors import safe_open from safetensors import safe_open
from tqdm import tqdm from tqdm import tqdm
from einops import rearrange from einops import rearrange
@@ -14,6 +15,45 @@ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, con
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
MOTION_MODULES = [
"mm_sd_v14.ckpt",
"mm_sd_v15.ckpt",
"mm_sd_v15_v2.ckpt",
"v3_sd15_mm.ckpt",
]
ADAPTERS = [
# "mm_sd_v14.ckpt",
# "mm_sd_v15.ckpt",
# "mm_sd_v15_v2.ckpt",
# "mm_sdxl_v10_beta.ckpt",
"v2_lora_PanLeft.ckpt",
"v2_lora_PanRight.ckpt",
"v2_lora_RollingAnticlockwise.ckpt",
"v2_lora_RollingClockwise.ckpt",
"v2_lora_TiltDown.ckpt",
"v2_lora_TiltUp.ckpt",
"v2_lora_ZoomIn.ckpt",
"v2_lora_ZoomOut.ckpt",
"v3_sd15_adapter.ckpt",
# "v3_sd15_mm.ckpt",
"v3_sd15_sparsectrl_rgb.ckpt",
"v3_sd15_sparsectrl_scribble.ckpt",
]
BACKUP_DREAMBOOTH_MODELS = [
"realisticVisionV60B1_v51VAE.safetensors",
"majicmixRealistic_v4.safetensors",
"leosamsFilmgirlUltra_velvia20Lora.safetensors",
"toonyou_beta3.safetensors",
"majicmixRealistic_v5Preview.safetensors",
"rcnzCartoon3d_v10.safetensors",
"lyriel_v16.safetensors",
"leosamsHelloworldXL_filmGrain20.safetensors",
"TUSUN.safetensors",
]
def zero_rank_print(s): def zero_rank_print(s):
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
@@ -33,64 +73,21 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f
imageio.mimsave(path, outputs, fps=fps) imageio.mimsave(path, outputs, fps=fps)
# DDIM Inversion def auto_download(local_path, is_dreambooth_lora=False):
@torch.no_grad() hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff"
def init_prompt(prompt, pipeline): folder, filename = os.path.split(local_path)
uncond_input = pipeline.tokenizer(
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
return_tensors="pt"
)
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
text_input = pipeline.tokenizer(
[prompt],
padding="max_length",
max_length=pipeline.tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
context = torch.cat([uncond_embeddings, text_embeddings])
return context if not os.path.exists(local_path):
print(f"local file {local_path} does not exist. trying to download from {hf_repo}")
if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}"
else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}"
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, folder = "." if folder == "" else folder
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): os.makedirs(folder, exist_ok=True)
timestep, next_timestep = min( snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename])
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
beta_prod_t = 1 - alpha_prod_t
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
return next_sample
def get_noise_pred_single(latents, t, context, unet):
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
return noise_pred
@torch.no_grad()
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
context = init_prompt(prompt, pipeline)
uncond_embeddings, cond_embeddings = context.chunk(2)
all_latent = [latent]
latent = latent.clone().detach()
for i in tqdm(range(num_inv_steps)):
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
latent = next_step(noise_pred, t, latent, ddim_scheduler)
all_latent.append(latent)
return all_latent
@torch.no_grad()
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
return ddim_latents
def load_weights( def load_weights(
animation_pipeline, animation_pipeline,
# motion module # motion module
@@ -107,10 +104,16 @@ def load_weights(
# motion module # motion module
unet_state_dict = {} unet_state_dict = {}
if motion_module_path != "": if motion_module_path != "":
auto_download(motion_module_path, is_dreambooth_lora=False)
print(f"load motion module from {motion_module_path}") print(f"load motion module from {motion_module_path}")
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) # filter parameters
for name, param in motion_module_state_dict.items():
if not "motion_modules." in name: continue
if "pos_encoder.pe" in name: continue
unet_state_dict.update({name: param})
unet_state_dict.pop("animatediff_config", "") unet_state_dict.pop("animatediff_config", "")
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
@@ -119,6 +122,8 @@ def load_weights(
# base model # base model
if dreambooth_model_path != "": if dreambooth_model_path != "":
auto_download(dreambooth_model_path, is_dreambooth_lora=True)
print(f"load dreambooth model from {dreambooth_model_path}") print(f"load dreambooth model from {dreambooth_model_path}")
if dreambooth_model_path.endswith(".safetensors"): if dreambooth_model_path.endswith(".safetensors"):
dreambooth_state_dict = {} dreambooth_state_dict = {}
@@ -140,6 +145,8 @@ def load_weights(
# lora layers # lora layers
if lora_model_path != "": if lora_model_path != "":
auto_download(lora_model_path, is_dreambooth_lora=True)
print(f"load lora model from {lora_model_path}") print(f"load lora model from {lora_model_path}")
assert lora_model_path.endswith(".safetensors") assert lora_model_path.endswith(".safetensors")
lora_state_dict = {} lora_state_dict = {}
@@ -152,6 +159,8 @@ def load_weights(
# domain adapter lora # domain adapter lora
if adapter_lora_path != "": if adapter_lora_path != "":
auto_download(adapter_lora_path, is_dreambooth_lora=False)
print(f"load domain lora from {adapter_lora_path}") print(f"load domain lora from {adapter_lora_path}")
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
@@ -162,6 +171,9 @@ def load_weights(
# motion module lora # motion module lora
for motion_module_lora_config in motion_module_lora_configs: for motion_module_lora_config in motion_module_lora_configs:
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
auto_download(path, is_dreambooth_lora=False)
print(f"load motion LoRA from {path}") print(f"load motion LoRA from {path}")
motion_lora_state_dict = torch.load(path, map_location="cpu") motion_lora_state_dict = torch.load(path, map_location="cpu")
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict