support v2

This commit is contained in:
Yuwei Guo
2023-09-10 21:26:51 +08:00
parent 1b50d640dc
commit 108921965d
3 changed files with 61 additions and 6 deletions

View File

@@ -18,6 +18,17 @@ class InflatedConv3d(nn.Conv2d):
return x
class InflatedGroupNorm(nn.GroupNorm):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
@@ -112,6 +123,7 @@ class ResnetBlock3D(nn.Module):
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=None,
):
super().__init__()
self.pre_norm = pre_norm
@@ -126,7 +138,11 @@ class ResnetBlock3D(nn.Module):
if groups_out is None:
groups_out = groups
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
assert use_inflated_groupnorm != None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
@@ -142,7 +158,11 @@ class ResnetBlock3D(nn.Module):
else:
self.time_emb_proj = None
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)