Mirror of https://github.com/guoyww/AnimateDiff.git · synced 2026-04-03 01:36:20 +02:00
add code
animatediff/models/attention.py · 300 lines · new file
@@ -0,0 +1,300 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm

from einops import rearrange, repeat
import pdb


@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            self.attn1 = SparseCausalAttention2D(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        if self.attn2 is not None:
            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
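Note (editor's addition, not part of the commit): a minimal smoke test for the inflated transformer above, assuming the older diffusers API these files import (diffusers.modeling_utils, CrossAttention) and that the repo root is on PYTHONPATH. Both unet_use_* flags must be set explicitly, since BasicTransformerBlock asserts they are not None; the 768-wide text embedding is an assumption for this sketch.

import torch
from animatediff.models.attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,        # inner_dim = 8 * 40 = 320 = in_channels
    in_channels=320,
    cross_attention_dim=768,      # assumed text-embedding width
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=False,
)

latents = torch.randn(1, 320, 16, 32, 32)   # (batch, channels, frames, height, width)
text_emb = torch.randn(1, 77, 768)           # (batch, tokens, cross_attention_dim)
out = model(latents, encoder_hidden_states=text_emb).sample
assert out.shape == latents.shape            # frames are folded into the batch axis and restored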
animatediff/models/motion_module.py · 330 lines · new file
@@ -0,0 +1,330 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import numpy as np
import torch.nn.functional as F
from torch import nn
import torchvision

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward

from einops import rearrange, repeat
import math


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def get_motion_module(
    in_channels,
    motion_module_type: str,
    motion_module_kwargs: dict
):
    if motion_module_type == "Vanilla":
        return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
    else:
        raise ValueError


class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads = 8,
        num_transformer_block = 2,
        attention_block_types = ("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 24,
        temporal_attention_dim_div = 1,
        zero_initialize = True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
            temporal_position_encoding_max_len=temporal_position_encoding_max_len,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)

        output = hidden_states
        return output


class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,

        num_layers,
        attention_block_types=(
            "Temporal_Self",
            "Temporal_Self",
        ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,

        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        hidden_states = self.proj_in(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)

        # output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

        return output


class TemporalTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types = ("Temporal_Self", "Temporal_Self",),
        dropout = 0.0,
        norm_num_groups = 32,
        cross_attention_dim = 768,
        activation_fn = "geglu",
        attention_bias = False,
        upcast_attention = False,
        cross_frame_attention_mode = None,
        temporal_position_encoding = False,
        temporal_position_encoding_max_len = 24,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,

                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,

                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                    temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = attention_block(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
                video_length=video_length,
            ) + hidden_states

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0., max_len: int = 24):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # print(f"d_model: {d_model}")
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)


class VersatileAttention(CrossAttention):
    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_position_encoding_max_len=24,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.,
            max_len=temporal_position_encoding_max_len
        ) if (temporal_position_encoding and attention_mode == "Temporal") else None

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        batch_size, sequence_length, _ = hidden_states.shape

        if self.attention_mode == "Temporal":
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)

            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
        else:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states

        if self.group_norm is not None:
            hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

        # attention, what we cannot get enough of
        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
            # Some versions of xformers return output in fp32, cast it back to the dtype of the input
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
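Note (editor's addition, not part of the commit): a hedged sketch of attaching a Vanilla motion module to a (b, c, f, h, w) feature map, mirroring how the UNet blocks call it. The kwargs are illustrative only; the "Temporal_Self" blocks attend over the frame axis at each spatial location, and with zero_initialize=True the zeroed proj_out makes the module an exact identity at initialization, which gives a cheap sanity check.

import torch
from animatediff.models.motion_module import get_motion_module

mm = get_motion_module(
    in_channels=320,
    motion_module_type="Vanilla",
    motion_module_kwargs={
        "num_attention_heads": 8,
        "num_transformer_block": 1,
        "temporal_position_encoding": True,
        "temporal_position_encoding_max_len": 24,
    },
)

feat = torch.randn(2, 320, 16, 32, 32)    # (batch, channels, frames, height, width)
out = mm(feat, temb=None, encoder_hidden_states=None)
assert torch.allclose(out, feat)           # zero-initialized proj_out => identity at init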
animatediff/models/resnet.py · 197 lines · new file
@@ -0,0 +1,197 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class Upsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # if self.use_conv:
        #     if self.name == "conv":
        #         hidden_states = self.conv(hidden_states)
        #     else:
        #         hidden_states = self.Conv2d_0(hidden_states)
        hidden_states = self.conv(hidden_states)

        return hidden_states


class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor

        if groups_out is None:
            groups_out = groups

        self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor


class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
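Note (editor's addition, not part of the commit): InflatedConv3d folds the frame axis into the batch before calling the inherited Conv2d, so it is frame-wise equivalent to a plain 2D convolution with the same weights. This is what lets pretrained 2D Stable Diffusion weights be loaded into the 3D UNet unchanged. A quick check of that equivalence:

import torch
import torch.nn.functional as F
from animatediff.models.resnet import InflatedConv3d

conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)
x = torch.randn(1, 4, 16, 64, 64)        # (batch, channels, frames, height, width)
y = conv(x)                               # -> (1, 320, 16, 64, 64)

# per-frame equivalence with the underlying 2D convolution
frame0 = F.conv2d(x[:, :, 0], conv.weight, conv.bias, padding=1)
assert torch.allclose(y[:, :, 0], frame0)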
animatediff/models/unet.py · 489 lines · new file
@@ -0,0 +1,489 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import os
import json
import pdb

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .resnet import InflatedConv3d


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        # Additional
        use_motion_module = False,
        motion_module_resolutions = (1, 2, 4, 8),
        motion_module_mid_block = False,
        motion_module_decoder_only = False,
        motion_module_type = None,
        motion_module_kwargs = {},
        unet_use_cross_frame_attention = None,
        unet_use_temporal_attention = None,
    ):
        super().__init__()

        self.sample_size = sample_size
        time_embed_dim = block_out_channels[0] * 4

        # input
        self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))

        # time
        self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        else:
            self.class_embedding = None

        self.down_blocks = nn.ModuleList([])
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            res = 2 ** i
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,

                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                in_channels=block_out_channels[-1],
                temb_channels=time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=attention_head_dim[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,

                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the videos
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            res = 2 ** (3 - i)
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                attn_num_head_channels=reversed_attention_head_dim[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,

                unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                unet_use_temporal_attention=unet_use_temporal_attention,

                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
        self.conv_act = nn.SiLU()
        self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.

        Returns:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
            [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)
        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states)

            down_block_res_samples += res_samples

        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D"
        ]
        config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ]

        from diffusers.utils import WEIGHTS_NAME
        model = cls.from_config(config, **unet_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        state_dict = torch.load(model_file, map_location="cpu")

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")

        params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
        print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")

        return model
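Note (editor's addition, not part of the commit): a hedged sketch of calling from_pretrained_2d. The loader reads config.json and the diffusers weight file directly from disk, so the path must be a local copy of a 2D Stable Diffusion UNet; the checkpoint path and the additional kwargs below are placeholders patterned after the repo's inference configs, not values shipped in this commit.

from animatediff.models.unet import UNet3DConditionModel

unet = UNet3DConditionModel.from_pretrained_2d(
    "./models/StableDiffusion",            # assumed local SD 1.x checkpoint directory
    subfolder="unet",
    unet_additional_kwargs={
        "use_motion_module": True,
        "motion_module_resolutions": [1, 2, 4, 8],
        "motion_module_type": "Vanilla",
        "motion_module_kwargs": {
            "num_attention_heads": 8,
            "num_transformer_block": 1,
            "attention_block_types": ["Temporal_Self", "Temporal_Self"],
            "temporal_position_encoding": True,
            "temporal_position_encoding_max_len": 24,
        },
        "unet_use_cross_frame_attention": False,
        "unet_use_temporal_attention": False,
    },
)
# missing keys reported by load_state_dict correspond to the newly added motion modules,
# which are trained from scratch; the inherited 2D weights load unchanged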
animatediff/models/unet_blocks.py · 733 lines · new file (listing truncated below)
@@ -0,0 +1,733 @@
|
||||
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import Transformer3DModel
|
||||
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
|
||||
from .motion_module import get_motion_module
|
||||
|
||||
import pdb
|
||||
|
||||
def get_down_block(
|
||||
down_block_type,
|
||||
num_layers,
|
||||
in_channels,
|
||||
out_channels,
|
||||
temb_channels,
|
||||
add_downsample,
|
||||
resnet_eps,
|
||||
resnet_act_fn,
|
||||
attn_num_head_channels,
|
||||
resnet_groups=None,
|
||||
cross_attention_dim=None,
|
||||
downsample_padding=None,
|
||||
dual_cross_attention=False,
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
|
||||
unet_use_cross_frame_attention=None,
|
||||
unet_use_temporal_attention=None,
|
||||
|
||||
use_motion_module=None,
|
||||
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock3D":
|
||||
return DownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlock3D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
|
||||
return CrossAttnDownBlock3D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
|
||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",

    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,

    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,

            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,

            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,

            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{up_block_type} does not exist.")


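# A minimal usage sketch (the channel sizes, eps, activation, and motion-module
# settings below are illustrative placeholders, not values taken from this
# repository's configs): `get_down_block` / `get_up_block` are small factories that
# map a block-type string from the UNet config onto one of the classes below,
# forwarding the motion-module settings unchanged.
#
#     block = get_up_block(
#         "CrossAttnUpBlock3D",
#         num_layers=3, in_channels=640, out_channels=1280, prev_output_channel=1280,
#         temb_channels=1280, add_upsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
#         attn_num_head_channels=8, cross_attention_dim=768,
#         use_motion_module=True, motion_module_type=..., motion_module_kwargs=...,
#     )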
class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,

        use_motion_module=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
            )
        ]
        attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
            hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
            hidden_states = resnet(hidden_states, temb)

        return hidden_states


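# Summary of the mid-block forward pass above (a restatement, not additional logic):
# hidden_states is a 5D video tensor (batch, channels, frames, height, width); after
# the initial resnet, each layer applies spatial/cross attention (Transformer3DModel),
# then the optional temporal motion module (skipped when it is None), then another
# ResnetBlock3D:
#
#     h = resnets[0](h, temb)
#     for attn, resnet, motion in layers:
#         h = attn(h, text).sample
#         h = motion(h, temb, text) if motion is not None else h
#         h = resnet(h, temb)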
class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,

        use_motion_module=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        output_states = ()

        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


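# Note on the gradient-checkpointing branch above (an explanation of the existing
# code, hedged rather than authoritative): torch.utils.checkpoint.checkpoint discards
# intermediate activations and recomputes the wrapped module during the backward
# pass. Calling hidden_states.requires_grad_() before checkpointing the motion module
# keeps that segment attached to the autograd graph, since checkpoint does not
# propagate gradients when none of its tensor inputs require grad.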
class DownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,

        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)

                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states


class CrossAttnUpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,

        use_motion_module=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,

        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states)
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
423
animatediff/pipelines/pipeline_animation.py
Normal file
@@ -0,0 +1,423 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

import inspect
from typing import Callable, List, Optional, Union
from dataclasses import dataclass

import numpy as np
import torch

from diffusers.utils import is_accelerate_available
from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer

from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)
from diffusers.utils import deprecate, logging, BaseOutput

from einops import rearrange

from ..models.unet import UNet3DConditionModel


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class AnimationPipelineOutput(BaseOutput):
    videos: Union[torch.Tensor, np.ndarray]


class AnimationPipeline(DiffusionPipeline):
    _optional_components = []

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet3DConditionModel,
        scheduler: Union[
            DDIMScheduler,
            PNDMScheduler,
            LMSDiscreteScheduler,
            EulerDiscreteScheduler,
            EulerAncestralDiscreteScheduler,
            DPMSolverMultistepScheduler,
        ],
    ):
        super().__init__()

        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse("0.9.0.dev0")
        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                "The configuration file of the unet has set the default `sample_size` to smaller than"
                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
                " in the config might lead to incorrect results in future versions. If you have downloaded this"
                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
                " the `unet/config.json` file"
            )
            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config["sample_size"] = 64
            unet._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

    def enable_vae_slicing(self):
        self.vae.enable_slicing()

    def disable_vae_slicing(self):
        self.vae.disable_slicing()

    def enable_sequential_cpu_offload(self, gpu_id=0):
        if is_accelerate_available():
            from accelerate import cpu_offload
        else:
            raise ImportError("Please install accelerate via `pip install accelerate`")

        device = torch.device(f"cuda:{gpu_id}")

        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
            if cpu_offloaded_model is not None:
                cpu_offload(cpu_offloaded_model, device)


    @property
    def _execution_device(self):
        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
        batch_size = len(prompt) if isinstance(prompt, list) else 1

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        text_embeddings = self.text_encoder(
            text_input_ids.to(device),
            attention_mask=attention_mask,
        )
        text_embeddings = text_embeddings[0]

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = text_embeddings.shape
        text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
        text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            max_length = text_input_ids.shape[-1]
            uncond_input = self.tokenizer(
                uncond_tokens,
                padding="max_length",
                max_length=max_length,
                truncation=True,
                return_tensors="pt",
            )

            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None

            uncond_embeddings = self.text_encoder(
                uncond_input.input_ids.to(device),
                attention_mask=attention_mask,
            )
            uncond_embeddings = uncond_embeddings[0]

            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = uncond_embeddings.shape[1]
            uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
            uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        return text_embeddings

    def decode_latents(self, latents):
        video_length = latents.shape[2]
        latents = 1 / 0.18215 * latents
        latents = rearrange(latents, "b c f h w -> (b f) c h w")
        video = self.vae.decode(latents).sample
        video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
        video = (video / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        video = video.cpu().float().numpy()
        return video

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]

        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    def check_inputs(self, prompt, height, width, callback_steps):
        if not isinstance(prompt, str) and not isinstance(prompt, list):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

        if (callback_steps is None) or (
            callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
        shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )
        if latents is None:
            rand_device = "cpu" if device.type == "mps" else device

            if isinstance(generator, list):
                shape = shape
                # shape = (1,) + shape[1:]
                latents = [
                    torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
                    for i in range(batch_size)
                ]
                latents = torch.cat(latents, dim=0).to(device)
            else:
                latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
        else:
            if latents.shape != shape:
                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        video_length: Optional[int],
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_videos_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "tensor",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        **kwargs,
    ):
        # Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps)

        # Define call parameters
        # batch_size = 1 if isinstance(prompt, str) else len(prompt)
        batch_size = 1
        if latents is not None:
            batch_size = latents.shape[0]
        if isinstance(prompt, list):
            batch_size = len(prompt)

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # Encode input prompt
        prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
        if negative_prompt is not None:
            negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
        text_embeddings = self._encode_prompt(
            prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
        )

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "tensor":
            video = torch.from_numpy(video)

        if not return_dict:
            return video

        return AnimationPipelineOutput(videos=video)
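

# A minimal usage sketch (the prompt, scheduler settings, and device are placeholders;
# in practice the vae, text_encoder, tokenizer, and unet are loaded from a Stable
# Diffusion checkpoint plus the AnimateDiff motion-module weights):
#
#     pipeline = AnimationPipeline(
#         vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
#         scheduler=DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="linear"),
#     ).to("cuda")
#     sample = pipeline(
#         "a corgi running on the beach",
#         negative_prompt="low quality",
#         video_length=16, height=512, width=512,
#         num_inference_steps=25, guidance_scale=7.5,
#     ).videos
#
# `sample` is then a tensor shaped (batch, channels, frames, height, width) with
# values in [0, 1], suitable for save_videos_grid in animatediff/utils/util.py.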
1382
animatediff/utils/convert_from_ckpt.py
Normal file
File diff suppressed because it is too large
132
animatediff/utils/convert_lora_safetensor_to_diffusers.py
Normal file
@@ -0,0 +1,132 @@
# coding=utf-8
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Conversion script for LoRA safetensors checkpoints. """

import argparse

import torch
from safetensors.torch import load_file

from diffusers import StableDiffusionPipeline


def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
    # load base model
    # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

    # load LoRA weight from .safetensors
    # state_dict = load_file(checkpoint_path)

    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # it can help to print out the key; it usually looks like
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # alpha has already been set above, so just skip these entries
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
            # lora_dim = weight_up.shape[1]
            # curr_layer.weight.data += (1/lora_dim) * alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
            # lora_dim = weight_up.shape[1]
            # curr_layer.weight.data += (1/lora_dim) * alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline
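
# What the loop above computes, stated compactly (a summary using the same symbols as
# the `--alpha` help string below): for every LoRA up/down pair, the corresponding
# base weight is updated in place as
#
#     W  <-  W + alpha * (lora_up @ lora_down)
#
# with 1x1-conv weights handled by squeezing the trailing spatial dims before the
# matrix product and unsqueezing them again afterwards.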


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument(
        "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
    )
    parser.add_argument(
        "--lora_prefix_text_encoder",
        default="lora_te",
        type=str,
        help="The prefix of text encoder weight in safetensors",
    )
    parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
    parser.add_argument(
        "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
    )
    parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")

    args = parser.parse_args()

    base_model_path = args.base_model_path
    checkpoint_path = args.checkpoint_path
    dump_path = args.dump_path
    lora_prefix_unet = args.lora_prefix_unet
    lora_prefix_text_encoder = args.lora_prefix_text_encoder
    alpha = args.alpha

    # load the base pipeline and the LoRA weights, then merge them with `convert_lora`
    pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
    state_dict = load_file(checkpoint_path)
    pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)

    pipe = pipe.to(args.device)
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
84
animatediff/utils/util.py
Normal file
@@ -0,0 +1,84 @@
import os
import imageio
import numpy as np
from typing import Union

import torch
import torchvision

from tqdm import tqdm
from einops import rearrange


def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)

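# Example call (a sketch; the path, grid width, and fps are arbitrary): `videos` is
# expected in (batch, channels, frames, height, width) layout with values in [0, 1],
# or in [-1, 1] when rescale=True.
#
#     save_videos_grid(sample.videos, "samples/demo.gif", n_rows=4, fps=8)
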
# DDIM Inversion
|
||||
@torch.no_grad()
|
||||
def init_prompt(prompt, pipeline):
|
||||
uncond_input = pipeline.tokenizer(
|
||||
[""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
|
||||
return_tensors="pt"
|
||||
)
|
||||
uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
|
||||
text_input = pipeline.tokenizer(
|
||||
[prompt],
|
||||
padding="max_length",
|
||||
max_length=pipeline.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
|
||||
context = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
|
||||
sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
|
||||
timestep, next_timestep = min(
|
||||
timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
|
||||
alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
|
||||
alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
|
||||
next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
|
||||
next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
|
||||
return next_sample
|
||||
|
||||
|
||||
def get_noise_pred_single(latents, t, context, unet):
|
||||
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
|
||||
return noise_pred
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
||||
context = init_prompt(prompt, pipeline)
|
||||
uncond_embeddings, cond_embeddings = context.chunk(2)
|
||||
all_latent = [latent]
|
||||
latent = latent.clone().detach()
|
||||
for i in tqdm(range(num_inv_steps)):
|
||||
t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
|
||||
noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
|
||||
latent = next_step(noise_pred, t, latent, ddim_scheduler)
|
||||
all_latent.append(latent)
|
||||
return all_latent
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
||||
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
||||
return ddim_latents
|
||||
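
# The update implemented by `next_step` above, written out (standard DDIM notation,
# where alpha_t is the cumulative product alphas_cumprod[t] and eps is the predicted
# noise):
#
#     x0_pred = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t)
#     x_{t+1} = sqrt(alpha_{t+1}) * x0_pred + sqrt(1 - alpha_{t+1}) * eps
#
# i.e. the deterministic DDIM step run in reverse, which is what lets `ddim_loop`
# map a clean video latent back toward the noise that would generate it.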