mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 01:36:20 +02:00
support v2
This commit is contained in:
@@ -18,6 +18,17 @@ class InflatedConv3d(nn.Conv2d):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class InflatedGroupNorm(nn.GroupNorm):
    """GroupNorm that normalizes a 5-D ``b c f h w`` tensor one frame at a time.

    The ``f`` axis (axis 2) is folded into the batch axis, the inherited
    ``nn.GroupNorm.forward`` is run on the resulting 4-D tensor, and the
    original 5-D layout is restored afterwards. Statistics are therefore
    computed per ``(b, f)`` slice rather than across the ``f`` axis.
    """

    def forward(self, x):
        """Apply group normalization independently to every ``f`` slice of ``x``.

        Args:
            x: Tensor laid out as ``b c f h w`` (the rearrange patterns below
                require exactly five axes).

        Returns:
            Tensor with the same ``b c f h w`` layout as the input.
        """
        # Remember the size of axis 2 so the merged batch axis can be split again.
        num_frames = x.shape[2]

        # Merge f into the batch axis so the parent class sees an ordinary 4-D input.
        per_frame = rearrange(x, "b c f h w -> (b f) c h w")
        normalized = super().forward(per_frame)

        # Undo the merge, putting f back at axis 2.
        return rearrange(normalized, "(b f) c h w -> b c f h w", f=num_frames)
|
||||||
|
|
||||||
|
|
||||||
class Upsample3D(nn.Module):
|
class Upsample3D(nn.Module):
|
||||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -112,6 +123,7 @@ class ResnetBlock3D(nn.Module):
|
|||||||
time_embedding_norm="default",
|
time_embedding_norm="default",
|
||||||
output_scale_factor=1.0,
|
output_scale_factor=1.0,
|
||||||
use_in_shortcut=None,
|
use_in_shortcut=None,
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.pre_norm = pre_norm
|
self.pre_norm = pre_norm
|
||||||
@@ -126,7 +138,11 @@ class ResnetBlock3D(nn.Module):
|
|||||||
if groups_out is None:
|
if groups_out is None:
|
||||||
groups_out = groups
|
groups_out = groups
|
||||||
|
|
||||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
assert use_inflated_groupnorm != None
|
||||||
|
if use_inflated_groupnorm:
|
||||||
|
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||||
|
else:
|
||||||
|
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||||
|
|
||||||
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
@@ -142,7 +158,11 @@ class ResnetBlock3D(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.time_emb_proj = None
|
self.time_emb_proj = None
|
||||||
|
|
||||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
if use_inflated_groupnorm:
|
||||||
|
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||||
|
else:
|
||||||
|
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||||
|
|
||||||
self.dropout = torch.nn.Dropout(dropout)
|
self.dropout = torch.nn.Dropout(dropout)
|
||||||
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||||
|
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from .unet_blocks import (
|
|||||||
get_down_block,
|
get_down_block,
|
||||||
get_up_block,
|
get_up_block,
|
||||||
)
|
)
|
||||||
from .resnet import InflatedConv3d
|
from .resnet import InflatedConv3d, InflatedGroupNorm
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
@@ -77,6 +77,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|||||||
upcast_attention: bool = False,
|
upcast_attention: bool = False,
|
||||||
resnet_time_scale_shift: str = "default",
|
resnet_time_scale_shift: str = "default",
|
||||||
|
|
||||||
|
use_inflated_groupnorm=False,
|
||||||
|
|
||||||
# Additional
|
# Additional
|
||||||
use_motion_module = False,
|
use_motion_module = False,
|
||||||
motion_module_resolutions = ( 1,2,4,8 ),
|
motion_module_resolutions = ( 1,2,4,8 ),
|
||||||
@@ -150,6 +152,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
|
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
@@ -175,6 +178,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module and motion_module_mid_block,
|
use_motion_module=use_motion_module and motion_module_mid_block,
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
@@ -227,6 +231,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
@@ -236,7 +241,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
|||||||
prev_output_channel = output_channel
|
prev_output_channel = output_channel
|
||||||
|
|
||||||
# out
|
# out
|
||||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
if use_inflated_groupnorm:
|
||||||
|
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||||
|
else:
|
||||||
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||||
self.conv_act = nn.SiLU()
|
self.conv_act = nn.SiLU()
|
||||||
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ def get_down_block(
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=None,
|
unet_use_cross_frame_attention=None,
|
||||||
unet_use_temporal_attention=None,
|
unet_use_temporal_attention=None,
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
|
|
||||||
@@ -50,6 +51,8 @@ def get_down_block(
|
|||||||
downsample_padding=downsample_padding,
|
downsample_padding=downsample_padding,
|
||||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module,
|
use_motion_module=use_motion_module,
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
motion_module_kwargs=motion_module_kwargs,
|
motion_module_kwargs=motion_module_kwargs,
|
||||||
@@ -77,6 +80,7 @@ def get_down_block(
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module,
|
use_motion_module=use_motion_module,
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
@@ -106,6 +110,7 @@ def get_up_block(
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=None,
|
unet_use_cross_frame_attention=None,
|
||||||
unet_use_temporal_attention=None,
|
unet_use_temporal_attention=None,
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
motion_module_type=None,
|
motion_module_type=None,
|
||||||
@@ -125,6 +130,8 @@ def get_up_block(
|
|||||||
resnet_groups=resnet_groups,
|
resnet_groups=resnet_groups,
|
||||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module,
|
use_motion_module=use_motion_module,
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
motion_module_kwargs=motion_module_kwargs,
|
motion_module_kwargs=motion_module_kwargs,
|
||||||
@@ -152,6 +159,7 @@ def get_up_block(
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
|
|
||||||
use_motion_module=use_motion_module,
|
use_motion_module=use_motion_module,
|
||||||
motion_module_type=motion_module_type,
|
motion_module_type=motion_module_type,
|
||||||
@@ -181,6 +189,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=None,
|
unet_use_cross_frame_attention=None,
|
||||||
unet_use_temporal_attention=None,
|
unet_use_temporal_attention=None,
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
|
|
||||||
@@ -206,6 +215,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
|||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
pre_norm=resnet_pre_norm,
|
pre_norm=resnet_pre_norm,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
attentions = []
|
attentions = []
|
||||||
@@ -248,6 +259,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
|||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
pre_norm=resnet_pre_norm,
|
pre_norm=resnet_pre_norm,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -290,6 +303,7 @@ class CrossAttnDownBlock3D(nn.Module):
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=None,
|
unet_use_cross_frame_attention=None,
|
||||||
unet_use_temporal_attention=None,
|
unet_use_temporal_attention=None,
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
|
|
||||||
@@ -318,6 +332,8 @@ class CrossAttnDownBlock3D(nn.Module):
|
|||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
pre_norm=resnet_pre_norm,
|
pre_norm=resnet_pre_norm,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if dual_cross_attention:
|
if dual_cross_attention:
|
||||||
@@ -422,6 +438,8 @@ class DownBlock3D(nn.Module):
|
|||||||
add_downsample=True,
|
add_downsample=True,
|
||||||
downsample_padding=1,
|
downsample_padding=1,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
motion_module_type=None,
|
motion_module_type=None,
|
||||||
motion_module_kwargs=None,
|
motion_module_kwargs=None,
|
||||||
@@ -444,6 +462,8 @@ class DownBlock3D(nn.Module):
|
|||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
pre_norm=resnet_pre_norm,
|
pre_norm=resnet_pre_norm,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
motion_modules.append(
|
motion_modules.append(
|
||||||
@@ -526,6 +546,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
|||||||
|
|
||||||
unet_use_cross_frame_attention=None,
|
unet_use_cross_frame_attention=None,
|
||||||
unet_use_temporal_attention=None,
|
unet_use_temporal_attention=None,
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
|
|
||||||
@@ -556,6 +577,8 @@ class CrossAttnUpBlock3D(nn.Module):
|
|||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
pre_norm=resnet_pre_norm,
|
pre_norm=resnet_pre_norm,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if dual_cross_attention:
|
if dual_cross_attention:
|
||||||
@@ -661,6 +684,8 @@ class UpBlock3D(nn.Module):
|
|||||||
output_scale_factor=1.0,
|
output_scale_factor=1.0,
|
||||||
add_upsample=True,
|
add_upsample=True,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=None,
|
||||||
|
|
||||||
use_motion_module=None,
|
use_motion_module=None,
|
||||||
motion_module_type=None,
|
motion_module_type=None,
|
||||||
motion_module_kwargs=None,
|
motion_module_kwargs=None,
|
||||||
@@ -685,6 +710,8 @@ class UpBlock3D(nn.Module):
|
|||||||
non_linearity=resnet_act_fn,
|
non_linearity=resnet_act_fn,
|
||||||
output_scale_factor=output_scale_factor,
|
output_scale_factor=output_scale_factor,
|
||||||
pre_norm=resnet_pre_norm,
|
pre_norm=resnet_pre_norm,
|
||||||
|
|
||||||
|
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
motion_modules.append(
|
motion_modules.append(
|
||||||
|
|||||||
Reference in New Issue
Block a user