From e816747d661bffc4c709f16975a4dc4a897ba4bd Mon Sep 17 00:00:00 2001
From: Yuwei Guo
Date: Sun, 20 Aug 2023 17:02:57 +0800
Subject: [PATCH] training script

---
 .gitignore                           |   2 +
 README.md                            |  31 +-
 animatediff/data/dataset.py          |  98 ++++++
 animatediff/utils/util.py            |   5 +
 configs/training/image_finetune.yaml |  48 +++
 configs/training/training.yaml       |  66 ++++
 environment.yaml                     |   2 +
 train.py                             | 493 +++++++++++++++++++++++++++
 8 files changed, 744 insertions(+), 1 deletion(-)
 create mode 100644 animatediff/data/dataset.py
 create mode 100644 configs/training/image_finetune.yaml
 create mode 100644 configs/training/training.yaml
 create mode 100644 train.py

diff --git a/.gitignore b/.gitignore
index 4a1bef6..f95c0f1 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,6 @@
 samples/
+wandb/
+outputs/
 __pycache__/
 models/StableDiffusion/stable-diffusion-v1-5
 scripts/animate_inter.py
diff --git a/README.md b/README.md
index 4ca23f1..37bd5c4 100644
--- a/README.md
+++ b/README.md
@@ -63,7 +63,7 @@ Contributions are always welcome!! The dev branch is for community
 
-## Setup for Inference
+## Setups for Inference
 
 ### Prepare Environment
 
@@ -139,6 +139,35 @@ Then run the following commands:
 python -m scripts.animate --config [path to the config file]
 ```
 
+
+## Steps for Training
+
+### Dataset
+Before training, download the video files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to your local machine.
+Note that the example training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`.
+
+### Configuration
+After preparing the dataset, update the data paths below in the `.yaml` config files under `configs/training/`:
+```
+train_data:
+  csv_path: [Replace with .csv Annotation File Path]
+  video_folder: [Replace with Video Folder Path]
+  sample_size: 256
+```
+Other training parameters (learning rate, epochs, validation settings, etc.) are also set in the config files.
+
+### Training
+To train motion modules:
+```
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml
+```
+
+To finetune the UNet's image layers:
+```
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml
+```
+
+
 ## Gradio Demo
 We have created a Gradio demo to make AnimateDiff easier to use.
To launch the demo, please run the following commands: ``` diff --git a/animatediff/data/dataset.py b/animatediff/data/dataset.py new file mode 100644 index 0000000..3f6ec10 --- /dev/null +++ b/animatediff/data/dataset.py @@ -0,0 +1,98 @@ +import os, io, csv, math, random +import numpy as np +from einops import rearrange +from decord import VideoReader + +import torch +import torchvision.transforms as transforms +from torch.utils.data.dataset import Dataset +from animatediff.utils.util import zero_rank_print + + + +class WebVid10M(Dataset): + def __init__( + self, + csv_path, video_folder, + sample_size=256, sample_stride=4, sample_n_frames=16, + is_image=False, + ): + zero_rank_print(f"loading annotations from {csv_path} ...") + with open(csv_path, 'r') as csvfile: + self.dataset = list(csv.DictReader(csvfile)) + self.length = len(self.dataset) + zero_rank_print(f"data scale: {self.length}") + + self.video_folder = video_folder + self.sample_stride = sample_stride + self.sample_n_frames = sample_n_frames + self.is_image = is_image + + sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) + self.pixel_transforms = transforms.Compose([ + transforms.RandomHorizontalFlip(), + transforms.Resize(sample_size[0]), + transforms.CenterCrop(sample_size), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), + ]) + + def get_batch(self, idx): + video_dict = self.dataset[idx] + videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] + + video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") + video_reader = VideoReader(video_dir) + video_length = len(video_reader) + + if not self.is_image: + clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) + start_idx = random.randint(0, video_length - clip_length) + batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) + else: + batch_index = [random.randint(0, video_length - 1)] + + pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() + pixel_values = pixel_values / 255. 
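+        # Frames are scaled to [0, 1] here; the Normalize(mean=0.5, std=0.5) transform
+        # applied in __getitem__ then maps them to [-1, 1], the range the Stable
+        # Diffusion VAE expects.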
+        del video_reader
+
+        if self.is_image:
+            pixel_values = pixel_values[0]
+
+        return pixel_values, name
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, idx):
+        while True:
+            try:
+                pixel_values, name = self.get_batch(idx)
+                break
+
+            except Exception as e:
+                idx = random.randint(0, self.length-1)
+
+        pixel_values = self.pixel_transforms(pixel_values)
+        sample = dict(pixel_values=pixel_values, text=name)
+        return sample
+
+
+
+if __name__ == "__main__":
+    from animatediff.utils.util import save_videos_grid
+
+    dataset = WebVid10M(
+        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
+        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
+        sample_size=256,
+        sample_stride=4, sample_n_frames=16,
+        is_image=True,
+    )
+    import pdb
+    pdb.set_trace()
+
+    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
+    for idx, batch in enumerate(dataloader):
+        print(batch["pixel_values"].shape, len(batch["text"]))
+        # for i in range(batch["pixel_values"].shape[0]):
+        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py
index 83f3161..ee2dd2b 100644
--- a/animatediff/utils/util.py
+++ b/animatediff/utils/util.py
@@ -5,11 +5,16 @@ from typing import Union
 
 import torch
 import torchvision
+import torch.distributed as dist
 
 from tqdm import tqdm
 from einops import rearrange
 
 
+def zero_rank_print(s):
+    # print only once: when torch.distributed is not initialized, or on rank 0
+    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
+
+
 def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
     videos = rearrange(videos, "b c t h w -> t b c h w")
     outputs = []
diff --git a/configs/training/image_finetune.yaml b/configs/training/image_finetune.yaml
new file mode 100644
index 0000000..ea05fd1
--- /dev/null
+++ b/configs/training/image_finetune.yaml
@@ -0,0 +1,48 @@
+image_finetune: true
+
+output_dir: "outputs"
+pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
+
+noise_scheduler_kwargs:
+  num_train_timesteps: 1000
+  beta_start: 0.00085
+  beta_end: 0.012
+  beta_schedule: "scaled_linear"
+  steps_offset: 1
+  clip_sample: false
+
+train_data:
+  csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
+  video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
+  sample_size: 256
+
+validation_data:
+  prompts:
+    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
+    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
+    - "Robot dancing in times square."
+    - "Pacific coast, carmel by the sea ocean and waves."
+  num_inference_steps: 25
+  guidance_scale: 8.
+
+trainable_modules:
+  - "."
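+  # train.py matches these entries as substrings of parameter names, so "." matches
+  # every parameter and the full UNet is finetuned.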
+ +unet_checkpoint_path: "" + +learning_rate: 1.e-5 +train_batch_size: 50 + +max_train_epoch: -1 +max_train_steps: 100 +checkpointing_epochs: -1 +checkpointing_steps: 60 + +validation_steps: 5000 +validation_steps_tuple: [2, 50] + +global_seed: 42 +mixed_precision_training: true +enable_xformers_memory_efficient_attention: True + +is_debug: False diff --git a/configs/training/training.yaml b/configs/training/training.yaml new file mode 100644 index 0000000..626f05c --- /dev/null +++ b/configs/training/training.yaml @@ -0,0 +1,66 @@ +image_finetune: false + +output_dir: "outputs" +pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" + +unet_additional_kwargs: + use_motion_module : true + motion_module_resolutions : [ 1,2,4,8 ] + unet_use_cross_frame_attention : false + unet_use_temporal_attention : false + + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads : 8 + num_transformer_block : 1 + attention_block_types : [ "Temporal_Self", "Temporal_Self" ] + temporal_position_encoding : true + temporal_position_encoding_max_len : 24 + temporal_attention_dim_div : 1 + zero_initialize : true + +noise_scheduler_kwargs: + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + beta_schedule: "linear" + steps_offset: 1 + clip_sample: false + +train_data: + csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv" + video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val" + sample_size: 256 + sample_stride: 4 + sample_n_frames: 16 + +validation_data: + prompts: + - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." + - "A drone view of celebration with Christma tree and fireworks, starry sky - background." + - "Robot dancing in times square." + - "Pacific coast, carmel by the sea ocean and waves." + num_inference_steps: 25 + guidance_scale: 8. + +trainable_modules: + - "motion_modules." 
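+  # Only parameters whose names contain "motion_modules." receive gradients; the
+  # pretrained Stable Diffusion weights stay frozen.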
+ +unet_checkpoint_path: "" + +learning_rate: 1.e-4 +train_batch_size: 4 + +max_train_epoch: -1 +max_train_steps: 100 +checkpointing_epochs: -1 +checkpointing_steps: 60 + +validation_steps: 5000 +validation_steps_tuple: [2, 50] + +global_seed: 42 +mixed_precision_training: true +enable_xformers_memory_efficient_attention: True + +is_debug: False diff --git a/environment.yaml b/environment.yaml index 8ab5514..64d18c7 100644 --- a/environment.yaml +++ b/environment.yaml @@ -14,8 +14,10 @@ dependencies: - transformers==4.25.1 - xformers==0.0.16 - imageio==2.27.0 + - decord==0.6.0 - gdown - einops - omegaconf - safetensors - gradio + - wandb diff --git a/train.py b/train.py new file mode 100644 index 0000000..094e419 --- /dev/null +++ b/train.py @@ -0,0 +1,493 @@ +import os +import math +import wandb +import random +import logging +import inspect +import argparse +import datetime +import subprocess + +from pathlib import Path +from tqdm.auto import tqdm +from einops import rearrange +from omegaconf import OmegaConf +from safetensors import safe_open +from typing import Dict, Optional, Tuple + +import torch +import torchvision +import torch.nn.functional as F +import torch.distributed as dist +from torch.optim.swa_utils import AveragedModel +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +import diffusers +from diffusers import AutoencoderKL, DDIMScheduler +from diffusers.models import UNet2DConditionModel +from diffusers.pipelines import StableDiffusionPipeline +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available + +import transformers +from transformers import CLIPTextModel, CLIPTokenizer + +from animatediff.data.dataset import WebVid10M +from animatediff.models.unet import UNet3DConditionModel +from animatediff.pipelines.pipeline_animation import AnimationPipeline +from animatediff.utils.util import save_videos_grid, zero_rank_print + + + +def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs): + """Initializes distributed environment.""" + if launcher == 'pytorch': + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + local_rank = rank % num_gpus + torch.cuda.set_device(local_rank) + dist.init_process_group(backend=backend, **kwargs) + + elif launcher == 'slurm': + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + local_rank = proc_id % num_gpus + torch.cuda.set_device(local_rank) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['RANK'] = str(proc_id) + port = os.environ.get('PORT', port) + os.environ['MASTER_PORT'] = str(port) + dist.init_process_group(backend=backend) + zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}") + + else: + raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!') + + return local_rank + + + +def main( + image_finetune: bool, + + name: str, + use_wandb: bool, + launcher: str, + + output_dir: str, + pretrained_model_path: str, + + train_data: Dict, + validation_data: Dict, + cfg_random_null_text: bool = True, + cfg_random_null_text_ratio: float = 0.1, + + unet_checkpoint_path: str = "", + 
unet_additional_kwargs: Dict = {}, + ema_decay: float = 0.9999, + noise_scheduler_kwargs = None, + + max_train_epoch: int = -1, + max_train_steps: int = 100, + validation_steps: int = 100, + validation_steps_tuple: Tuple = (-1,), + + learning_rate: float = 3e-5, + scale_lr: bool = False, + lr_warmup_steps: int = 0, + lr_scheduler: str = "constant", + + trainable_modules: Tuple[str] = (None, ), + num_workers: int = 32, + train_batch_size: int = 1, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_weight_decay: float = 1e-2, + adam_epsilon: float = 1e-08, + max_grad_norm: float = 1.0, + gradient_accumulation_steps: int = 1, + gradient_checkpointing: bool = False, + checkpointing_epochs: int = 5, + checkpointing_steps: int = -1, + + mixed_precision_training: bool = True, + enable_xformers_memory_efficient_attention: bool = True, + + global_seed: int = 42, + is_debug: bool = False, +): + check_min_version("0.10.0.dev0") + + # Initialize distributed training + local_rank = init_dist(launcher=launcher) + global_rank = dist.get_rank() + num_processes = dist.get_world_size() + is_main_process = global_rank == 0 + + seed = global_seed + global_rank + torch.manual_seed(seed) + + # Logging folder + folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S") + output_dir = os.path.join(output_dir, folder_name) + if is_debug and os.path.exists(output_dir): + os.system(f"rm -rf {output_dir}") + + *_, config = inspect.getargvalues(inspect.currentframe()) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + + if is_main_process and (not is_debug) and use_wandb: + run = wandb.init(project="animatediff", name=folder_name, config=config) + + # Handle the output folder creation + if is_main_process: + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/samples", exist_ok=True) + os.makedirs(f"{output_dir}/sanity_check", exist_ok=True) + os.makedirs(f"{output_dir}/checkpoints", exist_ok=True) + OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) + + # Load scheduler, tokenizer and models. 
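+    # The VAE and text encoder are loaded from the Stable Diffusion checkpoint and kept
+    # frozen below; only the UNet (image finetuning) or its motion modules (motion-module
+    # training) receive gradients, depending on `trainable_modules`.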
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) + + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + if not image_finetune: + unet = UNet3DConditionModel.from_pretrained_2d( + pretrained_model_path, subfolder="unet", + unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) + ) + else: + unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet") + + # Load pretrained unet weights + if unet_checkpoint_path != "": + zero_rank_print(f"from checkpoint: {unet_checkpoint_path}") + unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu") + if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}") + state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path + + m, u = unet.load_state_dict(state_dict, strict=False) + zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}") + assert len(u) == 0 + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + # Set unet trainable parameters + unet.requires_grad_(False) + for name, param in unet.named_parameters(): + for trainable_module_name in trainable_modules: + if trainable_module_name in name: + param.requires_grad = True + break + + trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters())) + optimizer = torch.optim.AdamW( + trainable_params, + lr=learning_rate, + betas=(adam_beta1, adam_beta2), + weight_decay=adam_weight_decay, + eps=adam_epsilon, + ) + + if is_main_process: + zero_rank_print(f"trainable params number: {len(trainable_params)}") + zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M") + + # Enable xformers + if enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. 
Make sure it is installed correctly") + + # Enable gradient checkpointing + if gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Move models to GPU + vae.to(local_rank) + text_encoder.to(local_rank) + + # Get the training dataset + train_dataset = WebVid10M(**train_data, is_image=image_finetune) + distributed_sampler = DistributedSampler( + train_dataset, + num_replicas=num_processes, + rank=global_rank, + shuffle=True, + seed=global_seed, + ) + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=train_batch_size, + shuffle=False, + sampler=distributed_sampler, + num_workers=num_workers, + pin_memory=True, + drop_last=True, + ) + + # Get the training iteration + if max_train_steps == -1: + assert max_train_epoch != -1 + max_train_steps = max_train_epoch * len(train_dataloader) + + if checkpointing_steps == -1: + assert checkpointing_epochs != -1 + checkpointing_steps = checkpointing_epochs * len(train_dataloader) + + if scale_lr: + learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes) + + # Scheduler + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, + num_training_steps=max_train_steps * gradient_accumulation_steps, + ) + + # Validation pipeline + if not image_finetune: + validation_pipeline = AnimationPipeline( + unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, + ).to("cuda") + else: + validation_pipeline = StableDiffusionPipeline.from_pretrained( + pretrained_model_path, + unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None, + ) + validation_pipeline.enable_vae_slicing() + + # DDP warpper + unet.to(local_rank) + unet = DDP(unet, device_ids=[local_rank], output_device=local_rank) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + + # Train! + total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps + + if is_main_process: + logging.info("***** Running training *****") + logging.info(f" Num examples = {len(train_dataset)}") + logging.info(f" Num Epochs = {num_train_epochs}") + logging.info(f" Instantaneous batch size per device = {train_batch_size}") + logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logging.info(f" Total optimization steps = {max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Only show the progress bar once on each machine. 
+ progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process) + progress_bar.set_description("Steps") + + # Support mixed-precision training + scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None + + for epoch in range(first_epoch, num_train_epochs): + train_dataloader.sampler.set_epoch(epoch) + unet.train() + + for step, batch in enumerate(train_dataloader): + if cfg_random_null_text: + batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']] + + # Data batch sanity check + if epoch == first_epoch and step == 0: + pixel_values, texts = batch['pixel_values'].cpu(), batch['text'] + if not image_finetune: + pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w") + for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): + pixel_value = pixel_value[None, ...] + save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True) + else: + for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)): + pixel_value = pixel_value / 2. + 0.5 + torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png") + + ### >>>> Training >>>> ### + + # Convert videos to latent space + pixel_values = batch["pixel_values"].to(local_rank) + video_length = pixel_values.shape[1] + with torch.no_grad(): + if not image_finetune: + pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") + latents = vae.encode(pixel_values).latent_dist + latents = latents.sample() + latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) + else: + latents = vae.encode(pixel_values).latent_dist + latents = latents.sample() + + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + with torch.no_grad(): + prompt_ids = tokenizer( + batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ).input_ids.to(latents.device) + encoder_hidden_states = text_encoder(prompt_ids)[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + raise NotImplementedError + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + # Predict the noise residual and compute loss + # Mixed-precision training + with torch.cuda.amp.autocast(enabled=mixed_precision_training): + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + optimizer.zero_grad() + + # Backpropagate + if mixed_precision_training: + scaler.scale(loss).backward() + """ >>> gradient clipping >>> """ + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm) + """ 
<<< gradient clipping <<< """ + scaler.step(optimizer) + scaler.update() + else: + loss.backward() + """ >>> gradient clipping >>> """ + torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm) + """ <<< gradient clipping <<< """ + optimizer.step() + + lr_scheduler.step() + progress_bar.update(1) + global_step += 1 + + ### <<<< Training <<<< ### + + # Wandb logging + if is_main_process and (not is_debug) and use_wandb: + wandb.log({"train_loss": loss.item()}, step=global_step) + + # Save checkpoint + if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1): + save_path = os.path.join(output_dir, f"checkpoints") + state_dict = { + "epoch": epoch, + "global_step": global_step, + "state_dict": unet.state_dict(), + } + if step == len(train_dataloader) - 1: + torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt")) + else: + torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt")) + logging.info(f"Saved state to {save_path} (global_step: {global_step})") + + # Periodically validation + if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple): + samples = [] + + generator = torch.Generator(device=latents.device) + generator.manual_seed(global_seed) + + height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size + width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size + + prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts + + for idx, prompt in enumerate(prompts): + if not image_finetune: + sample = validation_pipeline( + prompt, + generator = generator, + video_length = train_data.sample_n_frames, + height = height, + width = width, + **validation_data, + ).videos + save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") + samples.append(sample) + + else: + sample = validation_pipeline( + prompt, + generator = generator, + height = height, + width = width, + num_inference_steps = validation_data.get("num_inference_steps", 25), + guidance_scale = validation_data.get("guidance_scale", 8.), + ).images[0] + sample = torchvision.transforms.functional.to_tensor(sample) + samples.append(sample) + + if not image_finetune: + samples = torch.concat(samples) + save_path = f"{output_dir}/samples/sample-{global_step}.gif" + save_videos_grid(samples, save_path) + + else: + samples = torch.stack(samples) + save_path = f"{output_dir}/samples/sample-{global_step}.png" + torchvision.utils.save_image(samples, save_path, nrow=4) + + logging.info(f"Saved samples to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= max_train_steps: + break + + dist.destroy_process_group() + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch") + parser.add_argument("--wandb", action="store_true") + args = parser.parse_args() + + name = Path(args.config).stem + config = OmegaConf.load(args.config) + + main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
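Note on reusing the saved checkpoints: `train.py` stores the DDP-wrapped UNet under the `state_dict` key, so parameter names carry a `module.` prefix, and for motion-module training the temporal weights are exactly the entries whose names contain `motion_modules.`. The snippet below is a minimal sketch (not part of the patch) of how such a checkpoint could be filtered down to just the motion-module weights; the input path and output filename are placeholders.

```python
# Sketch only: extract motion-module weights from a checkpoint written by train.py.
# Assumes the {"epoch", "global_step", "state_dict"} layout shown above; paths are placeholders.
import torch

ckpt = torch.load("outputs/<run-name>/checkpoints/checkpoint.ckpt", map_location="cpu")
state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt

motion_state_dict = {
    # drop the "module." prefix added by the DDP wrapper
    k.replace("module.", "", 1): v
    for k, v in state_dict.items()
    # keep only the temporal (motion-module) parameters
    if "motion_modules." in k
}

print(f"kept {len(motion_state_dict)} motion-module tensors "
      f"(global_step={ckpt.get('global_step', 'n/a')})")
torch.save(motion_state_dict, "mm_checkpoint.ckpt")
```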