diff --git a/configs/prompts/1-ToonYou.yaml b/configs/prompts/1-ToonYou.yaml
index b162276..2e938c2 100644
--- a/configs/prompts/1-ToonYou.yaml
+++ b/configs/prompts/1-ToonYou.yaml
@@ -1,7 +1,9 @@
 ToonYou:
-  base: ""
-  path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
-  motion_module: "models/Motion_Module/mm_sd_v14.ckpt"
+  base: ""
+  path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
+  motion_module:
+    - "models/Motion_Module/mm_sd_v14.ckpt"
+    - "models/Motion_Module/mm_sd_v15.ckpt"

   seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
   steps: 25

diff --git a/requirements.txt b/requirements.txt
index 14a47b2..d642cfc 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,8 @@ torch==1.12.1+cu113
 torchvision==0.13.1+cu113
 diffusers[torch]==0.11.1
 transformers==4.25.1
+imageio==2.27.0
+gdown
 einops
 omegaconf
 safetensors
-imageio==2.27.0

diff --git a/scripts/animate.py b/scripts/animate.py
index 9fe75c1..f22cc4a 100644
--- a/scripts/animate.py
+++ b/scripts/animate.py
@@ -37,94 +37,103 @@ def main(args):
     config = OmegaConf.load(args.config)
     samples = []
+
+    sample_idx = 0
     for model_idx, (config_key, model_config) in enumerate(list(config.items())):
-        ### >>> create validation pipeline >>> ###
-        tokenizer    = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
-        text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
-        vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
-        unet         = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
-
-        pipeline = AnimationPipeline(
-            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
-            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
-        ).to("cuda")
-
-        # 1. unet ckpt
-        # 1.1 motion module
-        motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
-        if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
-        missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
-        assert len(unexpected) == 0
-        # 1.2 T2I
-        if model_config.path != "":
-            if model_config.path.endswith(".ckpt"):
-                state_dict = torch.load(model_config.path)
-                pipeline.unet.load_state_dict(state_dict)
-
-            elif model_config.path.endswith(".safetensors"):
-                state_dict = {}
-                with safe_open(model_config.path, framework="pt", device="cpu") as f:
-                    for key in f.keys():
-                        state_dict[key] = f.get_tensor(key)
-
-                is_lora = all("lora" in k for k in state_dict.keys())
-                if not is_lora:
-                    base_state_dict = state_dict
-                else:
-                    base_state_dict = {}
-                    with safe_open(model_config.base, framework="pt", device="cpu") as f:
+        motion_modules = model_config.motion_module
+        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
+        for motion_module in motion_modules:
+
+            ### >>> create validation pipeline >>> ###
+            tokenizer    = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
+            vae          = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
+            unet         = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
+
+            pipeline = AnimationPipeline(
+                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
+                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
+            ).to("cuda")
+
+            # 1. unet ckpt
+            # 1.1 motion module
+            motion_module_state_dict = torch.load(motion_module, map_location="cpu")
+            if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
+            missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
+            assert len(unexpected) == 0
+
+            # 1.2 T2I
+            if model_config.path != "":
+                if model_config.path.endswith(".ckpt"):
+                    state_dict = torch.load(model_config.path)
+                    pipeline.unet.load_state_dict(state_dict)
+
+                elif model_config.path.endswith(".safetensors"):
+                    state_dict = {}
+                    with safe_open(model_config.path, framework="pt", device="cpu") as f:
                         for key in f.keys():
-                            base_state_dict[key] = f.get_tensor(key)
-
-                # vae
-                converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
-                pipeline.vae.load_state_dict(converted_vae_checkpoint)
-                # unet
-                converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
-                pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
-                # text_model
-                pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
-
-                # import pdb
-                # pdb.set_trace()
-                if is_lora:
-                    pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
+                            state_dict[key] = f.get_tensor(key)
+
+                    is_lora = all("lora" in k for k in state_dict.keys())
+                    if not is_lora:
+                        base_state_dict = state_dict
+                    else:
+                        base_state_dict = {}
+                        with safe_open(model_config.base, framework="pt", device="cpu") as f:
+                            for key in f.keys():
+                                base_state_dict[key] = f.get_tensor(key)
+
+                    # vae
+                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
+                    pipeline.vae.load_state_dict(converted_vae_checkpoint)
+                    # unet
+                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
+                    pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+                    # text_model
+                    pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
+
+                    # import pdb
+                    # pdb.set_trace()
+                    if is_lora:
+                        pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

-        pipeline.to("cuda")
-        ### <<< create validation pipeline <<< ###
+            pipeline.to("cuda")
+            ### <<< create validation pipeline <<< ###

-        prompts      = model_config.prompt
-        n_prompts    = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
-
-        random_seeds = model_config.pop("seed", [-1])
-        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
-        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
-
-        config[config_key].random_seed = []
-        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+            prompts      = model_config.prompt
+            n_prompts    = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

-            # manually set random seed for reproduction
-            if random_seed != -1: torch.manual_seed(random_seed)
-            else: torch.seed()
-            config[config_key].random_seed.append(torch.initial_seed())
+            random_seeds = model_config.get("seed", [-1])
+            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
+            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

-            print(f"current seed: {torch.initial_seed()}")
-            print(f"sampling {prompt} ...")
-            sample = pipeline(
-                prompt,
-                negative_prompt     = n_prompt,
-                num_inference_steps = model_config.steps,
-                guidance_scale      = model_config.guidance_scale,
-                width               = args.W,
-                height              = args.H,
-                video_length        = args.L,
-            ).videos
-            samples.append(sample)
+            config[config_key].random_seed = []
+            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+
+                # manually set random seed for reproduction
+                if random_seed != -1: torch.manual_seed(random_seed)
+                else: torch.seed()
+                config[config_key].random_seed.append(torch.initial_seed())
+
+                print(f"current seed: {torch.initial_seed()}")
+                print(f"sampling {prompt} ...")
+                sample = pipeline(
+                    prompt,
+                    negative_prompt     = n_prompt,
+                    num_inference_steps = model_config.steps,
+                    guidance_scale      = model_config.guidance_scale,
+                    width               = args.W,
+                    height              = args.H,
+                    video_length        = args.L,
+                ).videos
+                samples.append(sample)

-            prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
-            save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif")
-            print(f"save to {savedir}/sample/{prompt}.gif")
+                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
+                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
+                print(f"save to {savedir}/sample/{sample_idx}-{prompt}.gif")
+
+                sample_idx += 1

     samples = torch.concat(samples)
     save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
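
Note: with this change, motion_module in a prompt config accepts either a single
checkpoint path or a list of paths; scripts/animate.py normalizes both forms and
re-runs the whole prompt/seed set once per motion module, numbering the output
GIFs with a global sample_idx. A minimal sketch of the two accepted YAML forms
(paths are illustrative, mirroring the layout above):

    ToonYou:
      base: ""
      path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
      # single module:
      # motion_module: "models/Motion_Module/mm_sd_v14.ckpt"
      # multiple modules, each sampled in turn:
      motion_module:
        - "models/Motion_Module/mm_sd_v14.ckpt"
        - "models/Motion_Module/mm_sd_v15.ckpt"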
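Assuming the repo's usual entry point, a config is rendered with
python -m scripts.animate --config configs/prompts/1-ToonYou.yaml. The gdown
dependency added to requirements.txt is presumably for downloading the
Motion_Module checkpoints from Google Drive.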