optimize memory cost

This commit is contained in:
Yuwei Guo
2023-07-12 16:41:08 +08:00
parent 41a698ae8e
commit 05fdf470ad
2 changed files with 10 additions and 1 deletions

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass
import numpy as np
import torch
from tqdm import tqdm
from diffusers.utils import is_accelerate_available
from packaging import version
@@ -239,7 +240,11 @@ class AnimationPipeline(DiffusionPipeline):
video_length = latents.shape[2]
latents = 1 / 0.18215 * latents
latents = rearrange(latents, "b c f h w -> (b f) c h w")
video = self.vae.decode(latents).sample
# video = self.vae.decode(latents).sample
video = []
for frame_idx in tqdm(range(latents.shape[0])):
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
video = torch.cat(video)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16