This commit is contained in:
Yuwei Guo
2023-07-09 23:25:46 +08:00
parent 5b702ae4e9
commit ebfd7b74f7
3 changed files with 96 additions and 84 deletions

View File

@@ -1,7 +1,9 @@
ToonYou: ToonYou:
base: "" base: ""
path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors" path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
motion_module: "models/Motion_Module/mm_sd_v14.ckpt" motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751] seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
steps: 25 steps: 25

View File

@@ -3,7 +3,8 @@ torch==1.12.1+cu113
torchvision==0.13.1+cu113 torchvision==0.13.1+cu113
diffusers[torch]==0.11.1 diffusers[torch]==0.11.1
transformers==4.25.1 transformers==4.25.1
imageio==2.27.0
gdown
einops einops
omegaconf omegaconf
safetensors safetensors
imageio==2.27.0

View File

@@ -37,94 +37,103 @@ def main(args):
config = OmegaConf.load(args.config) config = OmegaConf.load(args.config)
samples = [] samples = []
sample_idx = 0
for model_idx, (config_key, model_config) in enumerate(list(config.items())): for model_idx, (config_key, model_config) in enumerate(list(config.items())):
### >>> create validation pipeline >>> ###
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
pipeline = AnimationPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
).to("cuda")
# 1. unet ckpt
# 1.1 motion module
motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
assert len(unexpected) == 0
# 1.2 T2I motion_modules = model_config.motion_module
if model_config.path != "": motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
if model_config.path.endswith(".ckpt"): for motion_module in motion_modules:
state_dict = torch.load(model_config.path)
pipeline.unet.load_state_dict(state_dict) ### >>> create validation pipeline >>> ###
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
elif model_config.path.endswith(".safetensors"): text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
state_dict = {} vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
with safe_open(model_config.path, framework="pt", device="cpu") as f: unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
for key in f.keys():
state_dict[key] = f.get_tensor(key) pipeline = AnimationPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
is_lora = all("lora" in k for k in state_dict.keys()) scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
if not is_lora: ).to("cuda")
base_state_dict = state_dict
else: # 1. unet ckpt
base_state_dict = {} # 1.1 motion module
with safe_open(model_config.base, framework="pt", device="cpu") as f: motion_module_state_dict = torch.load(motion_module, map_location="cpu")
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
assert len(unexpected) == 0
# 1.2 T2I
if model_config.path != "":
if model_config.path.endswith(".ckpt"):
state_dict = torch.load(model_config.path)
pipeline.unet.load_state_dict(state_dict)
elif model_config.path.endswith(".safetensors"):
state_dict = {}
with safe_open(model_config.path, framework="pt", device="cpu") as f:
for key in f.keys(): for key in f.keys():
base_state_dict[key] = f.get_tensor(key) state_dict[key] = f.get_tensor(key)
# vae is_lora = all("lora" in k for k in state_dict.keys())
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config) if not is_lora:
pipeline.vae.load_state_dict(converted_vae_checkpoint) base_state_dict = state_dict
# unet else:
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config) base_state_dict = {}
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) with safe_open(model_config.base, framework="pt", device="cpu") as f:
# text_model for key in f.keys():
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict) base_state_dict[key] = f.get_tensor(key)
# import pdb # vae
# pdb.set_trace() converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
if is_lora: pipeline.vae.load_state_dict(converted_vae_checkpoint)
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha) # unet
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
# text_model
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
# import pdb
# pdb.set_trace()
if is_lora:
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
pipeline.to("cuda") pipeline.to("cuda")
### <<< create validation pipeline <<< ### ### <<< create validation pipeline <<< ###
prompts = model_config.prompt prompts = model_config.prompt
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
random_seeds = model_config.pop("seed", [-1])
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
config[config_key].random_seed = []
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
# manually set random seed for reproduction random_seeds = model_config.get("seed", [-1])
if random_seed != -1: torch.manual_seed(random_seed) random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
else: torch.seed() random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
config[config_key].random_seed.append(torch.initial_seed())
print(f"current seed: {torch.initial_seed()}") config[config_key].random_seed = []
print(f"sampling {prompt} ...") for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
sample = pipeline(
prompt, # manually set random seed for reproduction
negative_prompt = n_prompt, if random_seed != -1: torch.manual_seed(random_seed)
num_inference_steps = model_config.steps, else: torch.seed()
guidance_scale = model_config.guidance_scale, config[config_key].random_seed.append(torch.initial_seed())
width = args.W,
height = args.H, print(f"current seed: {torch.initial_seed()}")
video_length = args.L, print(f"sampling {prompt} ...")
).videos sample = pipeline(
samples.append(sample) prompt,
negative_prompt = n_prompt,
num_inference_steps = model_config.steps,
guidance_scale = model_config.guidance_scale,
width = args.W,
height = args.H,
video_length = args.L,
).videos
samples.append(sample)
prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
save_videos_grid(sample, f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif") save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
print(f"save to {savedir}/sample/{prompt}.gif") print(f"save to {savedir}/sample/{prompt}.gif")
sample_idx += 1
samples = torch.concat(samples) samples = torch.concat(samples)
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4) save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)