mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 01:36:20 +02:00
update
This commit is contained in:
@@ -1,7 +1,9 @@
|
|||||||
ToonYou:
|
ToonYou:
|
||||||
base: ""
|
base: ""
|
||||||
path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
|
path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
|
||||||
motion_module: "models/Motion_Module/mm_sd_v14.ckpt"
|
motion_module:
|
||||||
|
- "models/Motion_Module/mm_sd_v14.ckpt"
|
||||||
|
- "models/Motion_Module/mm_sd_v15.ckpt"
|
||||||
|
|
||||||
seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
|
seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
|
||||||
steps: 25
|
steps: 25
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ torch==1.12.1+cu113
|
|||||||
torchvision==0.13.1+cu113
|
torchvision==0.13.1+cu113
|
||||||
diffusers[torch]==0.11.1
|
diffusers[torch]==0.11.1
|
||||||
transformers==4.25.1
|
transformers==4.25.1
|
||||||
|
imageio==2.27.0
|
||||||
|
gdown
|
||||||
einops
|
einops
|
||||||
omegaconf
|
omegaconf
|
||||||
safetensors
|
safetensors
|
||||||
imageio==2.27.0
|
|
||||||
|
|||||||
@@ -37,94 +37,103 @@ def main(args):
|
|||||||
|
|
||||||
config = OmegaConf.load(args.config)
|
config = OmegaConf.load(args.config)
|
||||||
samples = []
|
samples = []
|
||||||
|
|
||||||
|
sample_idx = 0
|
||||||
for model_idx, (config_key, model_config) in enumerate(list(config.items())):
|
for model_idx, (config_key, model_config) in enumerate(list(config.items())):
|
||||||
### >>> create validation pipeline >>> ###
|
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
|
|
||||||
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
|
|
||||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
|
|
||||||
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
|
|
||||||
|
|
||||||
pipeline = AnimationPipeline(
|
|
||||||
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
|
|
||||||
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
|
|
||||||
).to("cuda")
|
|
||||||
|
|
||||||
# 1. unet ckpt
|
|
||||||
# 1.1 motion module
|
|
||||||
motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
|
|
||||||
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
|
|
||||||
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
|
|
||||||
assert len(unexpected) == 0
|
|
||||||
|
|
||||||
# 1.2 T2I
|
motion_modules = model_config.motion_module
|
||||||
if model_config.path != "":
|
motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
|
||||||
if model_config.path.endswith(".ckpt"):
|
for motion_module in motion_modules:
|
||||||
state_dict = torch.load(model_config.path)
|
|
||||||
pipeline.unet.load_state_dict(state_dict)
|
### >>> create validation pipeline >>> ###
|
||||||
|
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
|
||||||
elif model_config.path.endswith(".safetensors"):
|
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
|
||||||
state_dict = {}
|
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
|
||||||
with safe_open(model_config.path, framework="pt", device="cpu") as f:
|
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
|
||||||
for key in f.keys():
|
|
||||||
state_dict[key] = f.get_tensor(key)
|
pipeline = AnimationPipeline(
|
||||||
|
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
|
||||||
is_lora = all("lora" in k for k in state_dict.keys())
|
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
|
||||||
if not is_lora:
|
).to("cuda")
|
||||||
base_state_dict = state_dict
|
|
||||||
else:
|
# 1. unet ckpt
|
||||||
base_state_dict = {}
|
# 1.1 motion module
|
||||||
with safe_open(model_config.base, framework="pt", device="cpu") as f:
|
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
|
||||||
|
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
|
||||||
|
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
|
||||||
|
assert len(unexpected) == 0
|
||||||
|
|
||||||
|
# 1.2 T2I
|
||||||
|
if model_config.path != "":
|
||||||
|
if model_config.path.endswith(".ckpt"):
|
||||||
|
state_dict = torch.load(model_config.path)
|
||||||
|
pipeline.unet.load_state_dict(state_dict)
|
||||||
|
|
||||||
|
elif model_config.path.endswith(".safetensors"):
|
||||||
|
state_dict = {}
|
||||||
|
with safe_open(model_config.path, framework="pt", device="cpu") as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
base_state_dict[key] = f.get_tensor(key)
|
state_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
# vae
|
is_lora = all("lora" in k for k in state_dict.keys())
|
||||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
|
if not is_lora:
|
||||||
pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
base_state_dict = state_dict
|
||||||
# unet
|
else:
|
||||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
|
base_state_dict = {}
|
||||||
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
with safe_open(model_config.base, framework="pt", device="cpu") as f:
|
||||||
# text_model
|
for key in f.keys():
|
||||||
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
|
base_state_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
# import pdb
|
# vae
|
||||||
# pdb.set_trace()
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
|
||||||
if is_lora:
|
pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
||||||
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
|
# unet
|
||||||
|
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
|
||||||
|
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
||||||
|
# text_model
|
||||||
|
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
|
||||||
|
|
||||||
|
# import pdb
|
||||||
|
# pdb.set_trace()
|
||||||
|
if is_lora:
|
||||||
|
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
|
||||||
|
|
||||||
pipeline.to("cuda")
|
pipeline.to("cuda")
|
||||||
### <<< create validation pipeline <<< ###
|
### <<< create validation pipeline <<< ###
|
||||||
|
|
||||||
prompts = model_config.prompt
|
prompts = model_config.prompt
|
||||||
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
|
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
|
||||||
|
|
||||||
random_seeds = model_config.pop("seed", [-1])
|
|
||||||
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
|
|
||||||
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
|
|
||||||
|
|
||||||
config[config_key].random_seed = []
|
|
||||||
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
|
|
||||||
|
|
||||||
# manually set random seed for reproduction
|
random_seeds = model_config.get("seed", [-1])
|
||||||
if random_seed != -1: torch.manual_seed(random_seed)
|
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
|
||||||
else: torch.seed()
|
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
|
||||||
config[config_key].random_seed.append(torch.initial_seed())
|
|
||||||
|
|
||||||
print(f"current seed: {torch.initial_seed()}")
|
config[config_key].random_seed = []
|
||||||
print(f"sampling {prompt} ...")
|
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
|
||||||
sample = pipeline(
|
|
||||||
prompt,
|
# manually set random seed for reproduction
|
||||||
negative_prompt = n_prompt,
|
if random_seed != -1: torch.manual_seed(random_seed)
|
||||||
num_inference_steps = model_config.steps,
|
else: torch.seed()
|
||||||
guidance_scale = model_config.guidance_scale,
|
config[config_key].random_seed.append(torch.initial_seed())
|
||||||
width = args.W,
|
|
||||||
height = args.H,
|
print(f"current seed: {torch.initial_seed()}")
|
||||||
video_length = args.L,
|
print(f"sampling {prompt} ...")
|
||||||
).videos
|
sample = pipeline(
|
||||||
samples.append(sample)
|
prompt,
|
||||||
|
negative_prompt = n_prompt,
|
||||||
|
num_inference_steps = model_config.steps,
|
||||||
|
guidance_scale = model_config.guidance_scale,
|
||||||
|
width = args.W,
|
||||||
|
height = args.H,
|
||||||
|
video_length = args.L,
|
||||||
|
).videos
|
||||||
|
samples.append(sample)
|
||||||
|
|
||||||
prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
|
prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
|
||||||
save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif")
|
save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
|
||||||
print(f"save to {savedir}/sample/{prompt}.gif")
|
print(f"save to {savedir}/sample/{prompt}.gif")
|
||||||
|
|
||||||
|
sample_idx += 1
|
||||||
|
|
||||||
samples = torch.concat(samples)
|
samples = torch.concat(samples)
|
||||||
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
|
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
|
||||||
|
|||||||
Reference in New Issue
Block a user