training script

This commit is contained in:
Yuwei Guo
2023-08-20 17:02:57 +08:00
parent e559802fef
commit e816747d66
8 changed files with 744 additions and 1 deletions

View File

@@ -5,11 +5,16 @@ from typing import Union
import torch
import torchvision
import torch.distributed as dist
from tqdm import tqdm
from einops import rearrange
def zero_rank_print(s):
if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
videos = rearrange(videos, "b c t h w -> t b c h w")
outputs = []