support sparsectrl

This commit is contained in:
Yuwei Guo
2023-12-15 20:55:51 +08:00
parent 6c8a01b148
commit 401bc45697
7 changed files with 697 additions and 45 deletions

View File

@@ -28,9 +28,9 @@ def get_down_block(
upcast_attention=False,
resnet_time_scale_shift="default",
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
@@ -108,9 +108,9 @@ def get_up_block(
upcast_attention=False,
resnet_time_scale_shift="default",
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
@@ -187,9 +187,9 @@ class UNetMidBlock3DCrossAttn(nn.Module):
use_linear_projection=False,
upcast_attention=False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
@@ -301,9 +301,9 @@ class CrossAttnDownBlock3D(nn.Module):
only_cross_attention=False,
upcast_attention=False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
@@ -438,7 +438,7 @@ class DownBlock3D(nn.Module):
add_downsample=True,
downsample_padding=1,
use_inflated_groupnorm=None,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
@@ -544,9 +544,9 @@ class CrossAttnUpBlock3D(nn.Module):
only_cross_attention=False,
upcast_attention=False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
use_inflated_groupnorm=None,
unet_use_cross_frame_attention=False,
unet_use_temporal_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
@@ -684,7 +684,7 @@ class UpBlock3D(nn.Module):
output_scale_factor=1.0,
add_upsample=True,
use_inflated_groupnorm=None,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,