optimize memory cost
animatediff/pipelines/pipeline_animation.py
@@ -6,6 +6,7 @@ from dataclasses import dataclass
 
 import numpy as np
 import torch
+from tqdm import tqdm
 
 from diffusers.utils import is_accelerate_available
 from packaging import version
@@ -239,7 +240,11 @@ class AnimationPipeline(DiffusionPipeline):
         video_length = latents.shape[2]
         latents = 1 / 0.18215 * latents
         latents = rearrange(latents, "b c f h w -> (b f) c h w")
-        video = self.vae.decode(latents).sample
+        # video = self.vae.decode(latents).sample
+        video = []
+        for frame_idx in tqdm(range(latents.shape[0])):
+            video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
+        video = torch.cat(video)
         video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
         video = (video / 2 + 0.5).clamp(0, 1)
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
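The decode_latents change above trades a little speed for a much lower memory peak: instead of pushing all b*f frames through the VAE decoder in one batch, it decodes one frame at a time, so only a single frame's decoder activations are resident on the GPU at once. A minimal standalone sketch of the same pattern, assuming a diffusers-style AutoencoderKL (the function name and the scale argument are illustrative, not part of the commit):

import torch
from einops import rearrange
from tqdm import tqdm

@torch.no_grad()
def decode_frame_by_frame(vae, latents, scale=0.18215):
    # latents: (b, c, f, h, w) video latents produced by the denoising loop
    video_length = latents.shape[2]
    latents = latents / scale                               # undo SD latent scaling
    latents = rearrange(latents, "b c f h w -> (b f) c h w")
    frames = []
    for i in tqdm(range(latents.shape[0])):
        # one frame per VAE call: peak memory is roughly one frame's activations
        frames.append(vae.decode(latents[i : i + 1]).sample)
    video = torch.cat(frames)
    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
    return (video / 2 + 0.5).clamp(0, 1)                    # [-1, 1] -> [0, 1]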
scripts/animate.py
@@ -17,6 +17,7 @@ from animatediff.pipelines.pipeline_animation import AnimationPipeline
 from animatediff.utils.util import save_videos_grid
 from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
 from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+from diffusers.utils.import_utils import is_xformers_available
 
 from einops import rearrange, repeat
 
@@ -51,6 +52,9 @@ def main(args):
         vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
         unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
 
+        if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
+        else: assert False
+
         pipeline = AnimationPipeline(
             vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
             scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
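The script change makes xformers mandatory: when the package is importable, the UNet is switched to memory-efficient attention, and otherwise the hard "assert False" aborts the run. If you want the memory saving without the hard dependency, a softer fallback (a suggestion, not what this commit does) could look like:

from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    # memory-efficient attention substantially reduces attention's
    # activation memory at long sequence lengths
    unet.enable_xformers_memory_efficient_attention()
else:
    # fall back to standard attention instead of asserting; slower but still correct
    print("xformers not available, using default attention")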