mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
training script
This commit is contained in:
@@ -5,11 +5,16 @@ from typing import Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
import torch.distributed as dist
|
||||
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
def zero_rank_print(s):
|
||||
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
|
||||
|
||||
|
||||
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
|
||||
videos = rearrange(videos, "b c t h w -> t b c h w")
|
||||
outputs = []
|
||||
|
||||
Reference in New Issue
Block a user