diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py index 2359e71..37d8f2f 100644 --- a/animatediff/models/motion_module.py +++ b/animatediff/models/motion_module.py @@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module): pe = torch.zeros(1, max_len, d_model) pe[0, :, 0::2] = torch.sin(position * div_term) pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + self.register_buffer('pe', pe, persistent=False) def forward(self, x): x = x + self.pe[:, :x.size(1)] @@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention): attention_mode = None, cross_frame_attention_mode = None, temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, + temporal_position_encoding_max_len = 32, *args, **kwargs ): super().__init__(*args, **kwargs) diff --git a/animatediff/models/unet.py b/animatediff/models/unet.py index 1d77e78..8e5bb90 100644 --- a/animatediff/models/unet.py +++ b/animatediff/models/unet.py @@ -475,16 +475,77 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): return UNet3DConditionOutput(sample=sample) @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): - if subfolder is not None: - pretrained_model_path = os.path.join(pretrained_model_path, subfolder) - print(f"loaded 3D unet's pretrained weights from {pretrained_model_path} ...") + def from_pretrained_2d(cls, pretrained_model_name_or_path, unet_additional_kwargs={}, **kwargs): + from diffusers import __version__ + from diffusers.utils import DIFFUSERS_CACHE, SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, is_safetensors_available + from diffusers.modeling_utils import load_state_dict + print(f"loaded 3D unet's pretrained weights from {pretrained_model_name_or_path} ...") + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + device_map = kwargs.pop("device_map", None) + + user_agent = { + "diffusers": __version__, + "file_type": "model", + "framework": "pytorch", + } + + model_file = None + if is_safetensors_available(): + try: + model_file = cls._get_model_file( + pretrained_model_name_or_path, + weights_name=SAFETENSORS_WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + except: + pass + + if model_file is None: + model_file = cls._get_model_file( + pretrained_model_name_or_path, + weights_name=WEIGHTS_NAME, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + + config, unused_kwargs = cls.load_config( + pretrained_model_name_or_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + subfolder=subfolder, + device_map=device_map, + **kwargs, + ) - config_file = os.path.join(pretrained_model_path, 'config.json') - if not os.path.isfile(config_file): - raise RuntimeError(f"{config_file} does not exist") - with open(config_file, "r") as f: - config = json.load(f) config["_class_name"] = cls.__name__ config["down_block_types"] = [ "CrossAttnDownBlock3D", @@ -499,12 +560,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): "CrossAttnUpBlock3D" ] - from diffusers.utils import WEIGHTS_NAME - model = cls.from_config(config, **unet_additional_kwargs) - model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) - if not os.path.isfile(model_file): - raise RuntimeError(f"{model_file} does not exist") - state_dict = torch.load(model_file, map_location="cpu") + model = cls.from_config(config, **unused_kwargs, **unet_additional_kwargs) + state_dict = load_state_dict(model_file) m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py index c094483..e01ba58 100644 --- a/animatediff/utils/util.py +++ b/animatediff/utils/util.py @@ -7,6 +7,7 @@ import torch import torchvision import torch.distributed as dist +from huggingface_hub import snapshot_download from safetensors import safe_open from tqdm import tqdm from einops import rearrange @@ -14,6 +15,45 @@ from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, con from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora +MOTION_MODULES = [ + "mm_sd_v14.ckpt", + "mm_sd_v15.ckpt", + "mm_sd_v15_v2.ckpt", + "v3_sd15_mm.ckpt", +] + +ADAPTERS = [ + # "mm_sd_v14.ckpt", + # "mm_sd_v15.ckpt", + # "mm_sd_v15_v2.ckpt", + # "mm_sdxl_v10_beta.ckpt", + "v2_lora_PanLeft.ckpt", + "v2_lora_PanRight.ckpt", + "v2_lora_RollingAnticlockwise.ckpt", + "v2_lora_RollingClockwise.ckpt", + "v2_lora_TiltDown.ckpt", + "v2_lora_TiltUp.ckpt", + "v2_lora_ZoomIn.ckpt", + "v2_lora_ZoomOut.ckpt", + "v3_sd15_adapter.ckpt", + # "v3_sd15_mm.ckpt", + "v3_sd15_sparsectrl_rgb.ckpt", + "v3_sd15_sparsectrl_scribble.ckpt", +] + +BACKUP_DREAMBOOTH_MODELS = [ + "realisticVisionV60B1_v51VAE.safetensors", + "majicmixRealistic_v4.safetensors", + "leosamsFilmgirlUltra_velvia20Lora.safetensors", + "toonyou_beta3.safetensors", + "majicmixRealistic_v5Preview.safetensors", + "rcnzCartoon3d_v10.safetensors", + "lyriel_v16.safetensors", + "leosamsHelloworldXL_filmGrain20.safetensors", + "TUSUN.safetensors", +] + + def zero_rank_print(s): if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) @@ -33,64 +73,21 @@ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, f imageio.mimsave(path, outputs, fps=fps) -# DDIM Inversion -@torch.no_grad() -def init_prompt(prompt, pipeline): - uncond_input = pipeline.tokenizer( - [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, - return_tensors="pt" - ) - uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] - text_input = pipeline.tokenizer( - [prompt], - padding="max_length", - max_length=pipeline.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] - context = torch.cat([uncond_embeddings, text_embeddings]) +def auto_download(local_path, is_dreambooth_lora=False): + hf_repo = "guoyww/animatediff_t2i_backups" if is_dreambooth_lora else "guoyww/animatediff" + folder, filename = os.path.split(local_path) - return context + if not os.path.exists(local_path): + print(f"local file {local_path} does not exist. trying to download from {hf_repo}") + if is_dreambooth_lora: assert filename in BACKUP_DREAMBOOTH_MODELS, f"{filename} dose not exist in {hf_repo}" + else: assert filename in MOTION_MODULES + ADAPTERS, f"{filename} dose not exist in {hf_repo}" -def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, - sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler): - timestep, next_timestep = min( - timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep - alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod - alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] - beta_prod_t = 1 - alpha_prod_t - next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 - next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output - next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction - return next_sample + folder = "." if folder == "" else folder + os.makedirs(folder, exist_ok=True) + snapshot_download(repo_id=hf_repo, local_dir=folder, allow_patterns=[filename]) -def get_noise_pred_single(latents, t, context, unet): - noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] - return noise_pred - - -@torch.no_grad() -def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): - context = init_prompt(prompt, pipeline) - uncond_embeddings, cond_embeddings = context.chunk(2) - all_latent = [latent] - latent = latent.clone().detach() - for i in tqdm(range(num_inv_steps)): - t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] - noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) - latent = next_step(noise_pred, t, latent, ddim_scheduler) - all_latent.append(latent) - return all_latent - - -@torch.no_grad() -def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): - ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) - return ddim_latents - def load_weights( animation_pipeline, # motion module @@ -107,10 +104,16 @@ def load_weights( # motion module unet_state_dict = {} if motion_module_path != "": + auto_download(motion_module_path, is_dreambooth_lora=False) + print(f"load motion module from {motion_module_path}") motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict - unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) + # filter parameters + for name, param in motion_module_state_dict.items(): + if not "motion_modules." in name: continue + if "pos_encoder.pe" in name: continue + unet_state_dict.update({name: param}) unet_state_dict.pop("animatediff_config", "") missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) @@ -119,6 +122,8 @@ def load_weights( # base model if dreambooth_model_path != "": + auto_download(dreambooth_model_path, is_dreambooth_lora=True) + print(f"load dreambooth model from {dreambooth_model_path}") if dreambooth_model_path.endswith(".safetensors"): dreambooth_state_dict = {} @@ -140,6 +145,8 @@ def load_weights( # lora layers if lora_model_path != "": + auto_download(lora_model_path, is_dreambooth_lora=True) + print(f"load lora model from {lora_model_path}") assert lora_model_path.endswith(".safetensors") lora_state_dict = {} @@ -152,6 +159,8 @@ def load_weights( # domain adapter lora if adapter_lora_path != "": + auto_download(adapter_lora_path, is_dreambooth_lora=False) + print(f"load domain lora from {adapter_lora_path}") domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict @@ -162,6 +171,9 @@ def load_weights( # motion module lora for motion_module_lora_config in motion_module_lora_configs: path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] + + auto_download(path, is_dreambooth_lora=False) + print(f"load motion LoRA from {path}") motion_lora_state_dict = torch.load(path, map_location="cpu") motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict