# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import os
import json
import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange, repeat
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import UNet2DConditionLoadersMixin, AttnProcsLayers
from diffusers.utils import BaseOutput, logging
from diffusers.models.activations import get_activation
from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, LoRAAttnProcessor
from diffusers.models.embeddings import (
GaussianFourierProjection,
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from .unet_blocks import (
UNetMidBlock3DCrossAttn,
get_down_block,
get_up_block,
)
from animatediff.utils.util import zero_rank_print
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNet3DConditionOutput(BaseOutput):
"""
The output of [`UNet3DConditionModel`].
Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor = None
class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
shaped output.
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
for all models (such as downloading or saving).
Parameters:
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
Height and width of input/output sample.
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
            The tuple of downsample blocks to use.
        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock3DCrossAttn"`):
            Block type for the middle of the UNet. Only `UNetMidBlock3DCrossAttn` is supported; if `None`, the mid
            block layer is skipped.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
            The tuple of upsample blocks to use.
The tuple of upsample blocks to use.
        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
Whether to include self-attention in the basic transformer blocks, see
[`~models.attention.BasicTransformerBlock`].
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
            If `None`, normalization and activation layers are skipped in post-processing.
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            `CrossAttnDownBlock3D`, `CrossAttnUpBlock3D`, and `UNetMidBlock3DCrossAttn`.
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
num_attention_heads (`int`, *optional*):
The number of attention heads. If not defined, defaults to `attention_head_dim`
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
Dimension for the timestep embeddings.
num_class_embeds (`int`, *optional*, defaults to `None`):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
time_embedding_type (`str`, *optional*, defaults to `positional`):
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
time_embedding_dim (`int`, *optional*, defaults to `None`):
An optional override for the dimension of the projected time embedding.
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
Optional activation function to use only once on the time embeddings before they are passed to the rest of
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
timestep_post_act (`str`, *optional*, defaults to `None`):
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
The dimension of `cond_proj` layer in the timestep embedding.
        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
embeddings with the class embeddings.
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
            otherwise.
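
    Example (a minimal usage sketch added for illustration, not from the original docs; the
    checkpoint path is hypothetical and must point to a diffusers-layout folder containing a
    `config.json` and safetensors weights):

    ```py
    >>> unet = UNet3DConditionModel.from_pretrained_2d(
    ...     "path/to/stable-diffusion-xl", subfolder="unet", unet_additional_kwargs={}
    ... )
    ```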
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: Union[int, Tuple[int]] = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: Union[int, Tuple[int]] = 1280,
transformer_layers_per_block: Union[int, Tuple[int]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
resnet_skip_time_act: bool = False,
        resnet_out_scale_factor: float = 1.0,
time_embedding_type: str = "positional",
time_embedding_dim: Optional[int] = None,
time_embedding_act_fn: Optional[str] = None,
timestep_post_act: Optional[str] = None,
time_cond_proj_dim: Optional[int] = None,
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
attention_type: str = "default",
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
addition_embed_type_num_heads=64,
# motion module
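        # (AnimateDiff) `use_motion_module` inserts temporal attention layers into the blocks
        # below; they are only added at levels whose cumulative down/upsampling factor appears in
        # `motion_module_resolutions`, `motion_module_decoder_only` restricts them to the up
        # (decoder) blocks, and `motion_module_type`/`motion_module_kwargs` are forwarded to the
        # motion-module factory in `unet_blocks`.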
use_motion_module=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
motion_module_type=None,
motion_module_kwargs=None,
):
super().__init__()
self.sample_size = sample_size
if num_attention_heads is not None:
raise ValueError(
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
)
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
)
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
)
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
)
# input
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
)
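        # NOTE: `conv_in` (and `conv_out` below) stay 2D: `forward` folds the frame axis into the
        # batch ("b c f h w -> (b f) c h w") around these calls, so each frame is convolved
        # independently.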
# time
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
if time_embed_dim % 2 != 0:
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
self.time_proj = GaussianFourierProjection(
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
post_act_fn=timestep_post_act,
cond_proj_dim=time_cond_proj_dim,
)
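        # e.g. with the default `block_out_channels[0] == 320` and "positional" time embeddings,
        # `timestep_input_dim` is 320 and `time_embed_dim` is 1280 (= 320 * 4)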
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
            # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much,
            # they are set to `cross_attention_dim` here, as this is exactly the required dimension for the
            # currently only use case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1)
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2
self.encoder_hid_proj = ImageProjection(
image_embed_dim=encoder_hid_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif class_embed_type == "simple_projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
)
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__
            # too much, they are set to `cross_attention_dim` here, as this is exactly the required dimension for the
            # currently only use case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type == "image":
# Kandinsky 2.2
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type == "image_hint":
# Kandinsky 2.2 ControlNet
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
elif addition_embed_type is not None:
            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.")
if time_embedding_act_fn is None:
self.time_embed_act = None
else:
self.time_embed_act = get_activation(time_embedding_act_fn)
self.down_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = only_cross_attention
only_cross_attention = [only_cross_attention] * len(down_block_types)
if mid_block_only_cross_attention is None:
mid_block_only_cross_attention = False
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(cross_attention_dim, int):
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
if isinstance(layers_per_block, int):
layers_per_block = [layers_per_block] * len(down_block_types)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
if class_embeddings_concat:
# The time embeddings are concatenated with the class embeddings. The dimension of the
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
# regular time embeddings
blocks_time_embed_dim = time_embed_dim * 2
else:
blocks_time_embed_dim = time_embed_dim
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
res = 2 ** i
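            # cumulative downsampling factor at this level (1, 2, 4, ...); motion modules are only
            # inserted where it appears in `motion_module_resolutions`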
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block[i],
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=blocks_time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim[i],
num_attention_heads=num_attention_heads[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == "UNetMidBlock3DCrossAttn":
self.mid_block = UNetMidBlock3DCrossAttn(
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=block_out_channels[-1],
temb_channels=blocks_time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim[-1],
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
elif mid_block_type is None:
self.mid_block = None
else:
            raise ValueError(f"unknown mid_block_type: {mid_block_type}")
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_layers_per_block = list(reversed(layers_per_block))
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
res = 2 ** (2 - i)
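            # cumulative upsampling factor for this level, mirroring `res = 2 ** i` on the down
            # path (written for the three-level, SDXL-style block layout used by
            # `from_pretrained_2d`)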
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=reversed_layers_per_block[i] + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=blocks_time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=reversed_cross_attention_dim[i],
num_attention_heads=reversed_num_attention_heads[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
use_motion_module=use_motion_module and (res in motion_module_resolutions),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if norm_num_groups is not None:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = get_activation(act_fn)
else:
self.conv_norm_out = None
self.conv_act = None
conv_out_padding = (conv_out_kernel - 1) // 2
self.conv_out = nn.Conv2d(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)
if attention_type == "gated":
positive_len = 768
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
positive_len = cross_attention_dim[0]
self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
def set_image_layer_lora(self, image_layer_lora_rank: int = 128):
lora_attn_procs = {}
for name in self.attn_processors.keys():
zero_rank_print(f"(add lora) {name}")
cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = self.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
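                # NOTE: a small `image_layer_lora_rank` (<= 16) is interpreted as a divisor of the
                # hidden size rather than as an absolute LoRA rank (heuristic used by this repo)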
rank=image_layer_lora_rank if image_layer_lora_rank > 16 else hidden_size // image_layer_lora_rank,
)
self.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(self.attn_processors)
zero_rank_print(f"(lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M")
del lora_layers
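
    # Usage sketch (an assumption, not from the original code): attach image-layer LoRA with
    # `unet.set_image_layer_lora(128)`, then blend it in and out at inference time with
    # `unet.set_image_layer_lora_scale(0.8)`.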
def set_image_layer_lora_scale(self, lora_scale: float = 1.0):
for block in self.down_blocks: setattr(block, "lora_scale", lora_scale)
for block in self.up_blocks: setattr(block, "lora_scale", lora_scale)
setattr(self.mid_block, "lora_scale", lora_scale)
@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "set_processor"):
                if "motion_modules." not in name:
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], is_motion_module=False):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys()) if not is_motion_module else len(self.motion_module_attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
                if ((not is_motion_module) and ("motion_modules." not in name)) or (is_motion_module and ("motion_modules." in name)):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(AttnProcessor())
@property
def motion_module_attn_processors(self):
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
# filter out processors in motion module
if hasattr(module, "set_processor"):
if "motion_modules." in name:
processors[f"{name}.processor"] = module.processor
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
    def set_motion_module_lora(self, motion_module_lora_rank: int = 256, motion_lora_resolution=(32, 64, 128)):
lora_attn_procs = {}
#motion_name = []
#if 32 in motion_lora_resolution:
# motion_name.append('up_blocks.0')
# motion_name.append('down_blocks.2')
# if 64 in motion_lora_resolution:
# motion_name.append('up_blocks.1')
# motion_name.append('down_blocks.1')
# if 128 in motion_lora_resolution:
# motion_name.append('up_blocks.2')
# motion_name.append('down_blocks.0')
for name in self.motion_module_attn_processors.keys():
#prefix = '.'.join(name.split('.')[:2])
#if prefix not in motion_name:
# continue
print(f"(add motion lora) {name}")
if name.startswith("mid_block"):
hidden_size = self.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size,
cross_attention_dim=None,
rank=motion_module_lora_rank,
)
self.set_attn_processor(lora_attn_procs, is_motion_module=True)
lora_layers = AttnProcsLayers(self.motion_module_attn_processors)
print(f"(motion lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M")
del lora_layers
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
2023-11-10 11:57:39 +08:00
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
        The [`UNet3DConditionModel`] forward method.
Args:
            sample (`torch.FloatTensor`):
                The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
encoder_attention_mask (`torch.Tensor`):
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
which adds large negative values to the attention scores corresponding to "discard" tokens.
return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
added_cond_kwargs: (`dict`, *optional*):
                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
are passed along to the UNet blocks.
Returns:
            [`UNet3DConditionOutput`] or `tuple`:
                If `return_dict` is `True`, a [`UNet3DConditionOutput`] is returned, otherwise a `tuple` is
                returned where the first element is the sample tensor.
"""
        # By default samples have to be at least a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
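        # e.g. with `num_upsamplers == 2` the spatial dims of `sample` must be multiples of 4;
        # otherwise `forward_upsample_size` below forces the interpolation output size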
        # expand the per-video conditioning to per-frame conditioning: (b, ...) -> (b f, ...)
        # NOTE: `timestep` is expected to be a 1-D tensor of shape (batch,) here
        video_length = sample.shape[2]
        timestep = repeat(timestep, "b -> (b f)", f=video_length)
        encoder_hidden_states = repeat(encoder_hidden_states, "b c d -> (b f) c d", f=video_length)
        # (guarded: the original assumed SDXL-style `added_cond_kwargs` were always present)
        if added_cond_kwargs is not None:
            if "time_ids" in added_cond_kwargs:
                added_cond_kwargs["time_ids"] = repeat(added_cond_kwargs["time_ids"], "b c -> (b f) c", f=video_length)
            if "text_embeds" in added_cond_kwargs:
                added_cond_kwargs["text_embeds"] = repeat(added_cond_kwargs["text_embeds"], "b c -> (b f) c", f=video_length)
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None:
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
        # NOTE: the usual ONNX/Core ML-compatible broadcast (`timesteps.expand(sample.shape[0])`)
        # is skipped here because `timestep` was already repeated per frame at the top of `forward`
t_emb = self.time_proj(timesteps)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
# `Timesteps` does not contain any weights and will always return f32 tensors
# there might be better ways to encapsulate this.
class_labels = class_labels.to(dtype=sample.dtype)
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
if self.config.class_embeddings_concat:
emb = torch.cat([emb, class_emb], dim=-1)
else:
emb = emb + class_emb
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_image":
# Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
aug_emb = self.add_embedding(text_embs, image_embs)
elif self.config.addition_embed_type == "text_time":
# SDXL - style
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
elif self.config.addition_embed_type == "image":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
aug_emb = self.add_embedding(image_embs)
elif self.config.addition_embed_type == "image_hint":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
)
image_embs = added_cond_kwargs.get("image_embeds")
hint = added_cond_kwargs.get("hint")
aug_emb, hint = self.add_embedding(image_embs, hint)
sample = torch.cat([sample, hint], dim=1)
emb = emb + aug_emb if aug_emb is not None else emb
if self.time_embed_act is not None:
emb = self.time_embed_act(emb)
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
# Kandinsky 2.2 - style
if "image_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
image_embeds = added_cond_kwargs.get("image_embeds")
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
# 2. pre-process
video_length = sample.shape[2]
sample = rearrange(sample, "b c f h w -> (b f) c h w")
sample = self.conv_in(sample)
sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
# For t2i-adapter CrossAttnDownBlock2D
additional_residuals = {}
if is_adapter and len(down_block_additional_residuals) > 0:
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
down_block_res_samples += res_samples
if is_controlnet:
new_down_block_res_samples = ()
for down_block_res_sample, down_block_additional_residual in zip(
down_block_res_samples, down_block_additional_residuals
):
down_block_res_sample = down_block_res_sample + down_block_additional_residual
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
down_block_res_samples = new_down_block_res_samples
# 4. mid
if self.mid_block is not None:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
)
# To support T2I-Adapter-XL
if (
is_adapter
and len(down_block_additional_residuals) > 0
and sample.shape == down_block_additional_residuals[0].shape
):
sample += down_block_additional_residuals.pop(0)
if is_controlnet:
sample = sample + mid_block_additional_residual
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
)
video_length = sample.shape[2]
sample = rearrange(sample, "b c f h w -> (b f) c h w")
# 6. post-process
if self.conv_norm_out:
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
if not return_dict:
return (sample,)
return UNet3DConditionOutput(sample=sample)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded temporal unet's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
config["_class_name"] = cls.__name__
config["down_block_types"] = [
"DownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
]
config["up_block_types"] = [
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"UpBlock3D",
]
config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
        model = cls.from_config(config, **(unet_additional_kwargs or {}))
model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = {}
from safetensors import safe_open
with safe_open(model_file, framework='pt') as f:
for k in f.keys():
state_dict[k] = f.get_tensor(k)
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
# print(f"### missing keys:\n{m}\n### unexpected keys:\n{u}\n")
params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()]
print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")
return model
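

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of the original file): inflate a 2D UNet
    # checkpoint into this 3D UNet. "path/to/sdxl-base" is a hypothetical local directory with a
    # diffusers-layout `unet/` subfolder (config.json + safetensors weights).
    unet = UNet3DConditionModel.from_pretrained_2d(
        "path/to/sdxl-base",
        subfolder="unet",
        unet_additional_kwargs={},
    )
    unet.set_image_layer_lora(image_layer_lora_rank=128)
    print(f"num_upsamplers: {unet.num_upsamplers}")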