mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
support v2
This commit is contained in:
@@ -30,7 +30,8 @@ def get_down_block(
|
||||
|
||||
unet_use_cross_frame_attention=None,
|
||||
unet_use_temporal_attention=None,
|
||||
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
|
||||
motion_module_type=None,
|
||||
@@ -50,6 +51,8 @@ def get_down_block(
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
@@ -77,6 +80,7 @@ def get_down_block(
|
||||
|
||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
@@ -106,6 +110,7 @@ def get_up_block(
|
||||
|
||||
unet_use_cross_frame_attention=None,
|
||||
unet_use_temporal_attention=None,
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
@@ -125,6 +130,8 @@ def get_up_block(
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
motion_module_kwargs=motion_module_kwargs,
|
||||
@@ -152,6 +159,7 @@ def get_up_block(
|
||||
|
||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module,
|
||||
motion_module_type=motion_module_type,
|
||||
@@ -181,6 +189,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
|
||||
unet_use_cross_frame_attention=None,
|
||||
unet_use_temporal_attention=None,
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
|
||||
@@ -206,6 +215,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
@@ -248,6 +259,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -290,6 +303,7 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
|
||||
unet_use_cross_frame_attention=None,
|
||||
unet_use_temporal_attention=None,
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
|
||||
@@ -318,6 +332,8 @@ class CrossAttnDownBlock3D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
if dual_cross_attention:
|
||||
@@ -421,6 +437,8 @@ class DownBlock3D(nn.Module):
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
downsample_padding=1,
|
||||
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
@@ -444,6 +462,8 @@ class DownBlock3D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
@@ -526,6 +546,7 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
|
||||
unet_use_cross_frame_attention=None,
|
||||
unet_use_temporal_attention=None,
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
|
||||
@@ -556,6 +577,8 @@ class CrossAttnUpBlock3D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
if dual_cross_attention:
|
||||
@@ -661,6 +684,8 @@ class UpBlock3D(nn.Module):
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
|
||||
use_inflated_groupnorm=None,
|
||||
|
||||
use_motion_module=None,
|
||||
motion_module_type=None,
|
||||
motion_module_kwargs=None,
|
||||
@@ -685,6 +710,8 @@ class UpBlock3D(nn.Module):
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
)
|
||||
)
|
||||
motion_modules.append(
|
||||
|
||||
Reference in New Issue
Block a user