# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import json

import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange, repeat

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin, AttnProcsLayers
from diffusers.utils import BaseOutput, logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, LoRAAttnProcessor
from diffusers.models.embeddings import (
    GaussianFourierProjection,
    ImageHintTimeEmbedding,
    ImageProjection,
    ImageTimeEmbedding,
    PositionNet,
    TextImageProjection,
    TextImageTimeEmbedding,
    TextTimeEmbedding,
    TimestepEmbedding,
    Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin

from .unet_blocks import (
    UNetMidBlock3DCrossAttn,
    get_down_block,
    get_up_block,
)
from animatediff.utils.util import zero_rank_print


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    """
    The output of [`UNet3DConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
            The hidden states output conditioned on the `encoder_hidden_states` input. Output of the last layer of
            the model.
    """

    sample: torch.FloatTensor = None


class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
    sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
        flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock3DCrossAttn"`):
            Block type for the middle of the UNet. If `None`, the mid block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
            Whether to include self-attention in the basic transformer blocks, see
            [`~models.attention.BasicTransformerBlock`].
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
        mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
        norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            the cross-attention down, up, and mid blocks.
        encoder_hid_dim (`int`, *optional*, defaults to None):
            If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
            dimension to `cross_attention_dim`.
        encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
            If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
        attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
        num_attention_heads (`int`, *optional*):
            The number of attention heads. If not defined, defaults to `attention_head_dim`.
        resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
            for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
        class_embed_type (`str`, *optional*, defaults to `None`):
            The type of class embedding to use which is ultimately summed with the time embeddings. Choose from
            `None`, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
        addition_embed_type (`str`, *optional*, defaults to `None`):
            Configures an optional embedding which will be summed with the time embeddings. Choose from `None`,
            `"text"`, `"text_image"`, `"text_time"`, `"image"`, or `"image_hint"`. `"text"` will use the
            `TextTimeEmbedding` layer.
        addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
            Dimension for the timestep embeddings.
        num_class_embeds (`int`, *optional*, defaults to `None`):
            Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
            class conditioning with `class_embed_type` equal to `None`.
        time_embedding_type (`str`, *optional*, defaults to `positional`):
            The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
        time_embedding_dim (`int`, *optional*, defaults to `None`):
            An optional override for the dimension of the projected time embedding.
        time_embedding_act_fn (`str`, *optional*, defaults to `None`):
            Optional activation function to use only once on the time embeddings before they are passed to the rest
            of the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
        timestep_post_act (`str`, *optional*, defaults to `None`):
            The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
            The dimension of the `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
        projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
            `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
        class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
            embeddings with the class embeddings.
        mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
            Whether to use cross attention with the mid block. If `only_cross_attention` is given as a single
            boolean and `mid_block_only_cross_attention` is `None`, the `only_cross_attention` value is used as the
            value for `mid_block_only_cross_attention`. Defaults to `False` otherwise.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: Union[int, Tuple[int]] = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: Union[int, Tuple[int]] = 1280,
        transformer_layers_per_block: Union[int, Tuple[int]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
        time_embedding_type: str = "positional",
        time_embedding_dim: Optional[int] = None,
        time_embedding_act_fn: Optional[str] = None,
        timestep_post_act: Optional[str] = None,
        time_cond_proj_dim: Optional[int] = None,
        conv_in_kernel: int = 3,
        conv_out_kernel: int = 3,
        projection_class_embeddings_input_dim: Optional[int] = None,
        attention_type: str = "default",
        class_embeddings_concat: bool = False,
        mid_block_only_cross_attention: Optional[bool] = None,
        cross_attention_norm: Optional[str] = None,
        addition_embed_type_num_heads=64,
        # motion module
        use_motion_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.sample_size = sample_size

        if num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
            )

        # If `num_attention_heads` is not defined (which is the case for most models)
        # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
        # The reason for this behavior is to correct for incorrectly named variables that were introduced
        # when this library was created. The incorrect naming was only discovered much later in
        # https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
        # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards
        # breaking, which is why we correct for the naming here.
        num_attention_heads = num_attention_heads or attention_head_dim
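        # For example, a stock SD-1.x config stores `attention_head_dim=8`, which under this
        # naming quirk actually means 8 attention heads per block, so `num_attention_heads`
        # resolves to 8 here as well.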

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )

        # time
        if time_embedding_type == "fourier":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
            if time_embed_dim % 2 != 0:
                raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
            self.time_proj = GaussianFourierProjection(
                time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
            )
            timestep_input_dim = time_embed_dim
        elif time_embedding_type == "positional":
            time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
            self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
            timestep_input_dim = block_out_channels[0]
        else:
            raise ValueError(
                f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
            )

        self.time_embedding = TimestepEmbedding(
            timestep_input_dim,
            time_embed_dim,
            act_fn=act_fn,
            post_act_fn=timestep_post_act,
            cond_proj_dim=time_cond_proj_dim,
        )
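        # For example, with the default `block_out_channels[0] == 320` and `positional` embeddings,
        # timesteps are first mapped to 320-dim sinusoidal features and then projected to a
        # 1280-dim (`320 * 4`) time embedding by the MLP above.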

        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
                f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
            )

        if encoder_hid_dim_type == "text_proj":
            self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
        elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # it is set to `cross_attention_dim` here as this is exactly the required dimension for the
            # currently only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
            self.encoder_hid_proj = TextImageProjection(
                text_embed_dim=encoder_hid_dim,
                image_embed_dim=cross_attention_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2
            self.encoder_hid_proj = ImageProjection(
                image_embed_dim=encoder_hid_dim,
                cross_attention_dim=cross_attention_dim,
            )
        elif encoder_hid_dim_type is not None:
            raise ValueError(
                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
            )
        else:
            self.encoder_hid_proj = None

        # class embedding
        if class_embed_type is None and num_class_embeds is not None:
            self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
        elif class_embed_type == "timestep":
            self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
        elif class_embed_type == "identity":
            self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
        elif class_embed_type == "projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
                )
            # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
            # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
            # 2. it projects from an arbitrary input dimension.
            #
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif class_embed_type == "simple_projection":
            if projection_class_embeddings_input_dim is None:
                raise ValueError(
                    "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
                )
            self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
        else:
            self.class_embedding = None

        if addition_embed_type == "text":
            if encoder_hid_dim is not None:
                text_time_embedding_from_dim = encoder_hid_dim
            else:
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the
            # __init__ too much, they are set to `cross_attention_dim` here as this is exactly the required
            # dimension for the currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1).
            self.add_embedding = TextImageTimeEmbedding(
                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
            )
        elif addition_embed_type == "text_time":
            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
        elif addition_embed_type == "image":
            # Kandinsky 2.2
            self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet
            self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
        elif addition_embed_type is not None:
            raise ValueError(
                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
            )

        if time_embedding_act_fn is None:
            self.time_embed_act = None
        else:
            self.time_embed_act = get_activation(time_embedding_act_fn)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        if isinstance(only_cross_attention, bool):
            if mid_block_only_cross_attention is None:
                mid_block_only_cross_attention = only_cross_attention

            only_cross_attention = [only_cross_attention] * len(down_block_types)

        if mid_block_only_cross_attention is None:
            mid_block_only_cross_attention = False

        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(attention_head_dim, int):
            attention_head_dim = (attention_head_dim,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        if class_embeddings_concat:
            # The time embeddings are concatenated with the class embeddings. The dimension of the
            # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
            # regular time embeddings
            blocks_time_embed_dim = time_embed_dim * 2
        else:
            blocks_time_embed_dim = time_embed_dim
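        # For example, with `time_embed_dim == 1280` and `class_embeddings_concat=True`, every
        # block receives a 2560-dim `temb`; otherwise the blocks see the plain 1280-dim embedding.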

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            res = 2 ** i

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                downsample_padding=downsample_padding,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.down_blocks.append(down_block)

        # mid
        if mid_block_type == "UNetMidBlock3DCrossAttn":
            self.mid_block = UNetMidBlock3DCrossAttn(
                transformer_layers_per_block=transformer_layers_per_block[-1],
                in_channels=block_out_channels[-1],
                temb_channels=blocks_time_embed_dim,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                output_scale_factor=mid_block_scale_factor,
                resnet_time_scale_shift=resnet_time_scale_shift,
                cross_attention_dim=cross_attention_dim[-1],
                num_attention_heads=num_attention_heads[-1],
                resnet_groups=norm_num_groups,
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                upcast_attention=upcast_attention,
                attention_type=attention_type,
                use_motion_module=use_motion_module and motion_module_mid_block,
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
        elif mid_block_type is None:
            self.mid_block = None
        else:
            raise ValueError(f"unknown mid_block_type : {mid_block_type}")

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1
            res = 2 ** (2 - i)

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                dual_cross_attention=dual_cross_attention,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
                attention_type=attention_type,
                resnet_skip_time_act=resnet_skip_time_act,
                resnet_out_scale_factor=resnet_out_scale_factor,
                cross_attention_norm=cross_attention_norm,
                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
                use_motion_module=use_motion_module and (res in motion_module_resolutions),
                motion_module_type=motion_module_type,
                motion_module_kwargs=motion_module_kwargs,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_num_groups is not None:
            self.conv_norm_out = nn.GroupNorm(
                num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
            )
            self.conv_act = get_activation(act_fn)
        else:
            self.conv_norm_out = None
            self.conv_act = None

        conv_out_padding = (conv_out_kernel - 1) // 2
        self.conv_out = nn.Conv2d(
            block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
        )

        if attention_type == "gated":
            positive_len = 768
            if isinstance(cross_attention_dim, int):
                positive_len = cross_attention_dim
            elif isinstance(cross_attention_dim, (tuple, list)):
                positive_len = cross_attention_dim[0]
            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)

    def set_image_layer_lora(self, image_layer_lora_rank: int = 128):
        lora_attn_procs = {}
        for name in self.attn_processors.keys():
            zero_rank_print(f"(add lora) {name}")
            cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = self.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=cross_attention_dim,
                rank=image_layer_lora_rank if image_layer_lora_rank > 16 else hidden_size // image_layer_lora_rank,
            )

        self.set_attn_processor(lora_attn_procs)
        lora_layers = AttnProcsLayers(self.attn_processors)
        zero_rank_print(f"(lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M")
        del lora_layers
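    # A hedged usage sketch: attach image-layer LoRA processors, then set their global scale.
    # Note that `image_layer_lora_rank` values of 16 or below are interpreted as a divisor of
    # `hidden_size` rather than an absolute rank (see the `rank=` expression above).
    #
    #     unet.set_image_layer_lora(image_layer_lora_rank=128)
    #     unet.set_image_layer_lora_scale(0.8)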

    def set_image_layer_lora_scale(self, lora_scale: float = 1.0):
        for block in self.down_blocks:
            setattr(block, "lora_scale", lora_scale)
        for block in self.up_blocks:
            setattr(block, "lora_scale", lora_scale)
        setattr(self.mid_block, "lora_scale", lora_scale)

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names. Processors inside motion modules are excluded.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "set_processor"):
                if "motion_modules." not in name:
                    processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], is_motion_module: bool = False):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.
        """
        count = len(self.attn_processors.keys()) if not is_motion_module else len(self.motion_module_attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if ((not is_motion_module) and ("motion_modules." not in name)) or (is_motion_module and ("motion_modules." in name)):
                    if not isinstance(processor, dict):
                        module.set_processor(processor)
                    else:
                        module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        self.set_attn_processor(AttnProcessor())
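    # A hedged usage sketch (`procs_dict` is a hypothetical dict keyed by the names returned
    # from `attn_processors` / `motion_module_attn_processors`):
    #
    #     unet.set_attn_processor(AttnProcessor())                    # all image-attention layers
    #     unet.set_attn_processor(procs_dict, is_motion_module=True)  # per-layer motion processors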

    @property
    def motion_module_attn_processors(self):
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # keep only processors that live inside a motion module
            if hasattr(module, "set_processor"):
                if "motion_modules." in name:
                    processors[f"{name}.processor"] = module.processor

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_motion_module_lora(self, motion_module_lora_rank: int = 256, motion_lora_resolution=[32, 64, 128]):
        lora_attn_procs = {}
        # motion_name = []
        # if 32 in motion_lora_resolution:
        #     motion_name.append('up_blocks.0')
        #     motion_name.append('down_blocks.2')
        # if 64 in motion_lora_resolution:
        #     motion_name.append('up_blocks.1')
        #     motion_name.append('down_blocks.1')
        # if 128 in motion_lora_resolution:
        #     motion_name.append('up_blocks.2')
        #     motion_name.append('down_blocks.0')
        for name in self.motion_module_attn_processors.keys():
            # prefix = '.'.join(name.split('.')[:2])
            # if prefix not in motion_name:
            #     continue
            print(f"(add motion lora) {name}")
            if name.startswith("mid_block"):
                hidden_size = self.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size,
                cross_attention_dim=None,
                rank=motion_module_lora_rank,
            )

        self.set_attn_processor(lora_attn_procs, is_motion_module=True)
        lora_layers = AttnProcsLayers(self.motion_module_attn_processors)
        print(f"(motion lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M")
        del lora_layers
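    # A hedged usage sketch: motion-module attention is temporal self-attention, so these LoRA
    # processors are created with `cross_attention_dim=None`. The `motion_lora_resolution`
    # argument is currently unused (the per-resolution filtering above is commented out).
    #
    #     unet.set_motion_module_lora(motion_module_lora_rank=256)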

    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module splits the input tensor in slices to compute attention in
        several steps. This is useful for saving some memory in exchange for a small decrease in speed.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
                `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
                `attention_head_dim` must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_sliceable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_sliceable_dims(module)

        num_sliceable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)
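    # For example, `unet.set_attention_slice("auto")` halves each sliceable head dimension
    # (two slices per attention call), while `unet.set_attention_slice("max")` runs one slice
    # at a time for maximum memory savings.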

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
        mid_block_additional_residual: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        The [`UNet3DConditionModel`] forward method.

        Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.FloatTensor`):
                The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
            encoder_attention_mask (`torch.Tensor`):
                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
                which adds large negative values to the attention scores corresponding to "discard" tokens.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
            added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings
                that are passed along to the UNet blocks.

        Returns:
            [`UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`UNet3DConditionOutput`] is returned, otherwise a `tuple` is returned
                where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2 ** self.num_upsamplers

        # repeat the per-batch time, size, and text conditioning so that every frame gets a copy,
        # i.e. the batch axis becomes (b f)
        video_length = sample.shape[2]
        timestep = repeat(timestep, "b -> (b f)", f=video_length)
        encoder_hidden_states = repeat(encoder_hidden_states, "b c d -> (b f) c d", f=video_length)
        added_cond_kwargs["time_ids"] = repeat(added_cond_kwargs["time_ids"], "b c -> (b f) c", f=video_length)
        added_cond_kwargs["text_embeds"] = repeat(added_cond_kwargs["text_embeds"], "b c -> (b f) c", f=video_length)
        # sample = rearrange(sample, "b c f h w -> (b f) c h w")

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
        # expects mask of shape:
        #   [batch, key_tokens]
        # adds singleton query_tokens dimension:
        #   [batch, 1, key_tokens]
        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
        #   [batch,  heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)
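        # e.g. a mask of [1, 0] becomes a bias of [0.0, -10000.0], which effectively zeroes
        # the softmax weight of the discarded token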

        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcasting to the batch dimension (done upstream in a way that's compatible with
        # ONNX/Core ML) is already handled by the frame-wise repeat of `timestep` above
        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 ControlNet - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        video_length = sample.shape[2]
        sample = rearrange(sample, "b c f h w -> (b f) c h w")
        sample = self.conv_in(sample)
        sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual
# 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
                )

        video_length = sample.shape[2]
        sample = rearrange(sample, "b c f h w -> (b f) c h w")

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)
        sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        if subfolder is not None:
            pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
        print(f"Loading temporal UNet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, "config.json")
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "DownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
        ]
        config["up_block_types"] = [
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "UpBlock3D",
        ]
        config["mid_block_type"] = "UNetMidBlock3DCrossAttn"

        from diffusers.utils import SAFETENSORS_WEIGHTS_NAME

        model = cls.from_config(config, **unet_additional_kwargs)
        model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")

        state_dict = {}
        from safetensors import safe_open

        with safe_open(model_file, framework="pt") as f:
            for k in f.keys():
                state_dict[k] = f.get_tensor(k)

        m, u = model.load_state_dict(state_dict, strict=False)
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
        # print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")

        params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
        print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")

        return model
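

# A hedged usage sketch (the checkpoint path is hypothetical; assumes a diffusers-format
# 2D UNet saved as safetensors, whose spatial weights are loaded non-strictly into this
# 3D UNet while the temporal/motion modules stay randomly initialized):
#
#     unet = UNet3DConditionModel.from_pretrained_2d(
#         "checkpoints/stable-diffusion-xl-base-1.0",
#         subfolder="unet",
#         unet_additional_kwargs={"use_motion_module": True},
#     )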