mirror of https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 16:38:01 +01:00

training script

.gitignore (vendored, 2 lines changed)
@@ -1,4 +1,6 @@
samples/
wandb/
outputs/
__pycache__/
models/StableDiffusion/stable-diffusion-v1-5
scripts/animate_inter.py
README.md (31 lines changed)
@@ -63,7 +63,7 @@ Contributions are always welcome!! The <code>dev</code> branch is for community

</details>

-## Setup for Inference
+## Setups for Inference

### Prepare Environment
@@ -139,6 +139,35 @@ Then run the following commands:

python -m scripts.animate --config [path to the config file]
```

## Steps for Training

### Dataset

Before training, download the video files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local machine.

Note that our exemplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`.

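For illustration only: the `WebVid10M` dataset class added below in this commit resolves each annotated clip as `{video_folder}/{videoid}.mp4`, using the `videoid` column of the WebVid `.csv`. A quick check along these lines (the paths are placeholders) can confirm that the downloaded files match the annotations before launching training:

```
import csv, os

# Placeholder paths; point these at the local WebVid10M download.
csv_path     = "/path/to/webvid_annotations.csv"
video_folder = "/path/to/webvid_videos"

with open(csv_path, "r") as f:
    rows = list(csv.DictReader(f))

# WebVid10M.get_batch() opens {video_folder}/{videoid}.mp4 for every annotation row.
missing = [r["videoid"] for r in rows
           if not os.path.isfile(os.path.join(video_folder, f"{r['videoid']}.mp4"))]
print(f"{len(rows) - len(missing)}/{len(rows)} annotated clips found, {len(missing)} missing")
```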

### Configuration

After dataset preparation, update the data paths below in the `.yaml` config files under `configs/training/`:

```
train_data:
  csv_path:     [Replace with .csv Annotation File Path]
  video_folder: [Replace with Video Folder Path]
  sample_size:  256
```

Other training parameters (learning rate, epochs, validation settings, etc.) are also included in the config files.

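For reference, `train.py` (added below in this commit) loads the whole `.yaml` file with OmegaConf and unpacks it into the keyword arguments of `main()`, so every key in these files corresponds to a training parameter. A minimal sketch of that flow, assuming the default `configs/training/training.yaml`:

```
from omegaconf import OmegaConf

# Load a training config the same way train.py does.
config = OmegaConf.load("configs/training/training.yaml")

print(config.train_data.csv_path)   # the annotation path filled in above
print(config.learning_rate)         # 1.e-4 for motion-module training

# train.py then dispatches everything into main():
#   main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
```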

### Training

To train motion modules:

```
torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml
```

To finetune the unet's image layers:

```
torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml
```

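Checkpoints are written by `train.py` to `outputs/<run-name>/checkpoints/` as plain `torch.save` dictionaries holding the epoch, the global step, and the unet weights (taken from the DDP-wrapped model). A small, illustrative sketch of inspecting such a file (the path is a placeholder); `unet_checkpoint_path` in the training configs is where a previously trained checkpoint would be supplied:

```
import torch

# Placeholder path; adjust to an actual run directory.
ckpt = torch.load("outputs/my-run/checkpoints/checkpoint.ckpt", map_location="cpu")

# train.py stores three entries per checkpoint.
print(ckpt["epoch"], ckpt["global_step"])
print(len(ckpt["state_dict"]), "parameter tensors")
```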

## Gradio Demo

We have created a Gradio demo to make AnimateDiff easier to use. To launch the demo, please run the following commands:

```

animatediff/data/dataset.py (new file, 98 lines)
@@ -0,0 +1,98 @@
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print


class WebVid10M(Dataset):
    def __init__(
        self,
        csv_path, video_folder,
        sample_size=256, sample_stride=4, sample_n_frames=16,
        is_image=False,
    ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder    = video_folder
        self.sample_stride   = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image        = is_image

        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0]),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])

    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        # All videos are expected in a single folder, named by their WebVid video id.
        video_dir    = os.path.join(self.video_folder, f"{videoid}.mp4")
        video_reader = VideoReader(video_dir)
        video_length = len(video_reader)

        if not self.is_image:
            # Sample `sample_n_frames` frames evenly from a window of at most
            # (sample_n_frames - 1) * sample_stride + 1 consecutive frames.
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx   = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]

        return pixel_values, name

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # Retry with a random index if a clip is missing or unreadable.
        while True:
            try:
                pixel_values, name = self.get_batch(idx)
                break
            except Exception:
                idx = random.randint(0, self.length - 1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name)
        return sample


if __name__ == "__main__":
    from animatediff.utils.util import save_videos_grid

    dataset = WebVid10M(
        csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
        video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
        sample_size=256,
        sample_stride=4, sample_n_frames=16,
        is_image=True,
    )
    import pdb
    pdb.set_trace()

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
    for idx, batch in enumerate(dataloader):
        print(batch["pixel_values"].shape, len(batch["text"]))
        # for i in range(batch["pixel_values"].shape[0]):
        #     save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
animatediff/utils/util.py
@@ -5,11 +5,16 @@ from typing import Union

import torch
import torchvision
import torch.distributed as dist

from tqdm import tqdm
from einops import rearrange


def zero_rank_print(s):
    # Print only on the main process (rank 0), or unconditionally when
    # torch.distributed has not been initialized.
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
configs/training/image_finetune.yaml (new file, 48 lines)
@@ -0,0 +1,48 @@
image_finetune: true

output_dir: "outputs"
pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "scaled_linear"
  steps_offset: 1
  clip_sample: false

train_data:
  csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
  video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
  sample_size: 256

validation_data:
  prompts:
    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
    - "Robot dancing in times square."
    - "Pacific coast, carmel by the sea ocean and waves."
  num_inference_steps: 25
  guidance_scale: 8.

trainable_modules:
  - "."

unet_checkpoint_path: ""

learning_rate: 1.e-5
train_batch_size: 50

max_train_epoch: -1
max_train_steps: 100
checkpointing_epochs: -1
checkpointing_steps: 60

validation_steps: 5000
validation_steps_tuple: [2, 50]

global_seed: 42
mixed_precision_training: true
enable_xformers_memory_efficient_attention: True

is_debug: False
configs/training/training.yaml (new file, 66 lines)
@@ -0,0 +1,66 @@
image_finetune: false

output_dir: "outputs"
pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"

unet_additional_kwargs:
  use_motion_module : true
  motion_module_resolutions : [ 1,2,4,8 ]
  unet_use_cross_frame_attention : false
  unet_use_temporal_attention : false

  motion_module_type: Vanilla
  motion_module_kwargs:
    num_attention_heads : 8
    num_transformer_block : 1
    attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
    temporal_position_encoding : true
    temporal_position_encoding_max_len : 24
    temporal_attention_dim_div : 1
    zero_initialize : true

noise_scheduler_kwargs:
  num_train_timesteps: 1000
  beta_start: 0.00085
  beta_end: 0.012
  beta_schedule: "linear"
  steps_offset: 1
  clip_sample: false

train_data:
  csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
  video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
  sample_size: 256
  sample_stride: 4
  sample_n_frames: 16

validation_data:
  prompts:
    - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
    - "A drone view of celebration with Christma tree and fireworks, starry sky - background."
    - "Robot dancing in times square."
    - "Pacific coast, carmel by the sea ocean and waves."
  num_inference_steps: 25
  guidance_scale: 8.

trainable_modules:
  - "motion_modules."

unet_checkpoint_path: ""

learning_rate: 1.e-4
train_batch_size: 4

max_train_epoch: -1
max_train_steps: 100
checkpointing_epochs: -1
checkpointing_steps: 60

validation_steps: 5000
validation_steps_tuple: [2, 50]

global_seed: 42
mixed_precision_training: true
enable_xformers_memory_efficient_attention: True

is_debug: False
environment.yaml
@@ -14,8 +14,10 @@ dependencies:
  - transformers==4.25.1
  - xformers==0.0.16
  - imageio==2.27.0
  - decord==0.6.0
  - gdown
  - einops
  - omegaconf
  - safetensors
  - gradio
  - wandb
train.py (new file, 493 lines)
@@ -0,0 +1,493 @@
import os
import math
import wandb
import random
import logging
import inspect
import argparse
import datetime
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from safetensors import safe_open
from typing import Dict, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.data.dataset import WebVid10M
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print


def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        port = os.environ.get('PORT', port)
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group(backend=backend)
        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    return local_rank

def main(
|
||||||
|
image_finetune: bool,
|
||||||
|
|
||||||
|
name: str,
|
||||||
|
use_wandb: bool,
|
||||||
|
launcher: str,
|
||||||
|
|
||||||
|
output_dir: str,
|
||||||
|
pretrained_model_path: str,
|
||||||
|
|
||||||
|
train_data: Dict,
|
||||||
|
validation_data: Dict,
|
||||||
|
cfg_random_null_text: bool = True,
|
||||||
|
cfg_random_null_text_ratio: float = 0.1,
|
||||||
|
|
||||||
|
unet_checkpoint_path: str = "",
|
||||||
|
unet_additional_kwargs: Dict = {},
|
||||||
|
ema_decay: float = 0.9999,
|
||||||
|
noise_scheduler_kwargs = None,
|
||||||
|
|
||||||
|
max_train_epoch: int = -1,
|
||||||
|
max_train_steps: int = 100,
|
||||||
|
validation_steps: int = 100,
|
||||||
|
validation_steps_tuple: Tuple = (-1,),
|
||||||
|
|
||||||
|
learning_rate: float = 3e-5,
|
||||||
|
scale_lr: bool = False,
|
||||||
|
lr_warmup_steps: int = 0,
|
||||||
|
lr_scheduler: str = "constant",
|
||||||
|
|
||||||
|
trainable_modules: Tuple[str] = (None, ),
|
||||||
|
num_workers: int = 32,
|
||||||
|
train_batch_size: int = 1,
|
||||||
|
adam_beta1: float = 0.9,
|
||||||
|
adam_beta2: float = 0.999,
|
||||||
|
adam_weight_decay: float = 1e-2,
|
||||||
|
adam_epsilon: float = 1e-08,
|
||||||
|
max_grad_norm: float = 1.0,
|
||||||
|
gradient_accumulation_steps: int = 1,
|
||||||
|
gradient_checkpointing: bool = False,
|
||||||
|
checkpointing_epochs: int = 5,
|
||||||
|
checkpointing_steps: int = -1,
|
||||||
|
|
||||||
|
mixed_precision_training: bool = True,
|
||||||
|
enable_xformers_memory_efficient_attention: bool = True,
|
||||||
|
|
||||||
|
global_seed: int = 42,
|
||||||
|
is_debug: bool = False,
|
||||||
|
):
|
||||||
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    local_rank      = init_dist(launcher=launcher)
    global_rank     = dist.get_rank()
    num_processes   = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")

    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if is_main_process and (not is_debug) and use_wandb:
        run = wandb.init(project="animatediff", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))

    vae          = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    tokenizer    = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load pretrained unet weights
    if unet_checkpoint_path != "":
        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path

        m, u = unet.load_state_dict(state_dict, strict=False)
        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set unet trainable parameters: a parameter is unfrozen if its name contains
    # any of the configured substrings (e.g. "motion_modules." for motion-module training).
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                param.requires_grad = True
                break

    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    if is_main_process:
        zero_rank_print(f"trainable params number: {len(trainable_params)}")
        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=global_seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Get the training iteration
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Validation pipeline
    if not image_finetune:
        validation_pipeline = AnimationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
        ).to("cuda")
    else:
        validation_pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_path,
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
        )
    validation_pipeline.enable_vae_slicing()

    # DDP wrapper
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f"  Num examples = {len(train_dataset)}")
        logging.info(f"  Num Epochs = {num_train_epochs}")
        logging.info(f"  Instantaneous batch size per device = {train_batch_size}")
        logging.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f"  Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            if cfg_random_null_text:
                # Randomly drop captions so the model also learns the unconditional branch.
                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]

            # Data batch sanity check
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value[None, ...]
                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
                else:
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value / 2. + 0.5
                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")

            ### >>>> Training >>>> ###

            # Convert videos to latent space
            pixel_values = batch["pixel_values"].to(local_rank)
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                else:
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()

                latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.zero_grad()

            # Backpropagate
            if mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Wandb logging
            if is_main_process and (not is_debug) and use_wandb:
                wandb.log({"train_loss": loss.item()}, step=global_step)

            # Save checkpoint
            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
                save_path = os.path.join(output_dir, "checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                if step == len(train_dataloader) - 1:
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
                else:
                    torch.save(state_dict, os.path.join(save_path, "checkpoint.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            # Periodic validation
            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
                samples = []

                generator = torch.Generator(device=latents.device)
                generator.manual_seed(global_seed)

                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
                width  = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size

                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts

                for idx, prompt in enumerate(prompts):
                    if not image_finetune:
                        sample = validation_pipeline(
                            prompt,
                            generator    = generator,
                            video_length = train_data.sample_n_frames,
                            height       = height,
                            width        = width,
                            **validation_data,
                        ).videos
                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
                        samples.append(sample)

                    else:
                        sample = validation_pipeline(
                            prompt,
                            generator           = generator,
                            height              = height,
                            width               = width,
                            num_inference_steps = validation_data.get("num_inference_steps", 25),
                            guidance_scale      = validation_data.get("guidance_scale", 8.),
                        ).images[0]
                        sample = torchvision.transforms.functional.to_tensor(sample)
                        samples.append(sample)

                if not image_finetune:
                    samples = torch.concat(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
                    save_videos_grid(samples, save_path)

                else:
                    samples = torch.stack(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
                    torchvision.utils.save_image(samples, save_path, nrow=4)

                logging.info(f"Saved samples to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    dist.destroy_process_group()
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--config", type=str, required=True)
|
||||||
|
parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
|
||||||
|
parser.add_argument("--wandb", action="store_true")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
name = Path(args.config).stem
|
||||||
|
config = OmegaConf.load(args.config)
|
||||||
|
|
||||||
|
main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
|
||||||