mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
support v2
This commit is contained in:
@@ -18,6 +18,17 @@ class InflatedConv3d(nn.Conv2d):
|
||||
return x
|
||||
|
||||
|
||||
class InflatedGroupNorm(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
video_length = x.shape[2]
|
||||
|
||||
x = rearrange(x, "b c f h w -> (b f) c h w")
|
||||
x = super().forward(x)
|
||||
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Upsample3D(nn.Module):
|
||||
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
|
||||
super().__init__()
|
||||
@@ -112,6 +123,7 @@ class ResnetBlock3D(nn.Module):
|
||||
time_embedding_norm="default",
|
||||
output_scale_factor=1.0,
|
||||
use_in_shortcut=None,
|
||||
use_inflated_groupnorm=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.pre_norm = pre_norm
|
||||
@@ -126,7 +138,11 @@ class ResnetBlock3D(nn.Module):
|
||||
if groups_out is None:
|
||||
groups_out = groups
|
||||
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
assert use_inflated_groupnorm != None
|
||||
if use_inflated_groupnorm:
|
||||
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
else:
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
|
||||
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
@@ -142,7 +158,11 @@ class ResnetBlock3D(nn.Module):
|
||||
else:
|
||||
self.time_emb_proj = None
|
||||
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
if use_inflated_groupnorm:
|
||||
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
else:
|
||||
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user