mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
update
This commit is contained in:
@@ -1,7 +1,9 @@
|
||||
ToonYou:
|
||||
base: ""
|
||||
path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
|
||||
motion_module: "models/Motion_Module/mm_sd_v14.ckpt"
|
||||
motion_module:
|
||||
- "models/Motion_Module/mm_sd_v14.ckpt"
|
||||
- "models/Motion_Module/mm_sd_v15.ckpt"
|
||||
|
||||
seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
|
||||
steps: 25
|
||||
|
||||
@@ -3,7 +3,8 @@ torch==1.12.1+cu113
|
||||
torchvision==0.13.1+cu113
|
||||
diffusers[torch]==0.11.1
|
||||
transformers==4.25.1
|
||||
imageio==2.27.0
|
||||
gdown
|
||||
einops
|
||||
omegaconf
|
||||
safetensors
|
||||
imageio==2.27.0
|
||||
|
||||
@@ -37,7 +37,14 @@ def main(args):
|
||||
|
||||
config = OmegaConf.load(args.config)
|
||||
samples = []
|
||||
|
||||
sample_idx = 0
|
||||
for model_idx, (config_key, model_config) in enumerate(list(config.items())):
|
||||
|
||||
motion_modules = model_config.motion_module
|
||||
motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
|
||||
for motion_module in motion_modules:
|
||||
|
||||
### >>> create validation pipeline >>> ###
|
||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
|
||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
|
||||
@@ -51,7 +58,7 @@ def main(args):
|
||||
|
||||
# 1. unet ckpt
|
||||
# 1.1 motion module
|
||||
motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
|
||||
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
|
||||
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
|
||||
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
|
||||
assert len(unexpected) == 0
|
||||
@@ -97,7 +104,7 @@ def main(args):
|
||||
prompts = model_config.prompt
|
||||
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
|
||||
|
||||
random_seeds = model_config.pop("seed", [-1])
|
||||
random_seeds = model_config.get("seed", [-1])
|
||||
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
|
||||
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
|
||||
|
||||
@@ -123,9 +130,11 @@ def main(args):
|
||||
samples.append(sample)
|
||||
|
||||
prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
|
||||
save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif")
|
||||
save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
|
||||
print(f"save to {savedir}/sample/{prompt}.gif")
|
||||
|
||||
sample_idx += 1
|
||||
|
||||
samples = torch.concat(samples)
|
||||
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user