mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 01:36:20 +02:00
optimize memory cost
This commit is contained in:
@@ -6,6 +6,7 @@ from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from diffusers.utils import is_accelerate_available
|
||||
from packaging import version
|
||||
@@ -239,7 +240,11 @@ class AnimationPipeline(DiffusionPipeline):
|
||||
video_length = latents.shape[2]
|
||||
latents = 1 / 0.18215 * latents
|
||||
latents = rearrange(latents, "b c f h w -> (b f) c h w")
|
||||
video = self.vae.decode(latents).sample
|
||||
# video = self.vae.decode(latents).sample
|
||||
video = []
|
||||
for frame_idx in tqdm(range(latents.shape[0])):
|
||||
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
|
||||
video = torch.cat(video)
|
||||
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
|
||||
video = (video / 2 + 0.5).clamp(0, 1)
|
||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||
|
||||
@@ -17,6 +17,7 @@ from animatediff.pipelines.pipeline_animation import AnimationPipeline
|
||||
from animatediff.utils.util import save_videos_grid
|
||||
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
@@ -51,6 +52,9 @@ def main(args):
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
|
||||
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
|
||||
|
||||
if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
|
||||
else: assert False
|
||||
|
||||
pipeline = AnimationPipeline(
|
||||
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
|
||||
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
|
||||
|
||||
Reference in New Issue
Block a user