support v2

This commit is contained in:
Yuwei Guo
2023-09-10 21:26:51 +08:00
parent 1b50d640dc
commit 108921965d
3 changed files with 61 additions and 6 deletions

View File

@@ -30,7 +30,8 @@ def get_down_block(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
use_motion_module=None,
motion_module_type=None,
@@ -50,6 +51,8 @@ def get_down_block(
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
@@ -77,6 +80,7 @@ def get_down_block(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
@@ -106,6 +110,7 @@ def get_up_block(
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
use_motion_module=None,
motion_module_type=None,
@@ -125,6 +130,8 @@ def get_up_block(
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
@@ -152,6 +159,7 @@ def get_up_block(
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
@@ -181,6 +189,7 @@ class UNetMidBlock3DCrossAttn(nn.Module):
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
use_motion_module=None,
@@ -206,6 +215,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
]
attentions = []
@@ -248,6 +259,8 @@ class UNetMidBlock3DCrossAttn(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
@@ -290,6 +303,7 @@ class CrossAttnDownBlock3D(nn.Module):
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
use_motion_module=None,
@@ -318,6 +332,8 @@ class CrossAttnDownBlock3D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
@@ -421,6 +437,8 @@ class DownBlock3D(nn.Module):
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
use_inflated_groupnorm=None,
use_motion_module=None,
motion_module_type=None,
@@ -444,6 +462,8 @@ class DownBlock3D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
@@ -526,6 +546,7 @@ class CrossAttnUpBlock3D(nn.Module):
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
use_motion_module=None,
@@ -556,6 +577,8 @@ class CrossAttnUpBlock3D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
@@ -661,6 +684,8 @@ class UpBlock3D(nn.Module):
output_scale_factor=1.0,
add_upsample=True,
use_inflated_groupnorm=None,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
@@ -685,6 +710,8 @@ class UpBlock3D(nn.Module):
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(