mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
support v2
This commit is contained in:
@@ -24,7 +24,7 @@ from .unet_blocks import (
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
)
|
||||
from .resnet import InflatedConv3d
|
||||
from .resnet import InflatedConv3d, InflatedGroupNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -77,6 +77,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
|
||||
use_inflated_groupnorm=False,
|
||||
|
||||
# Additional
|
||||
use_motion_module = False,
|
||||
motion_module_resolutions = ( 1,2,4,8 ),
|
||||
@@ -88,7 +90,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
unet_use_temporal_attention = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.sample_size = sample_size
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
|
||||
@@ -150,6 +152,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only),
|
||||
motion_module_type=motion_module_type,
|
||||
@@ -175,6 +178,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module and motion_module_mid_block,
|
||||
motion_module_type=motion_module_type,
|
||||
@@ -227,6 +231,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
|
||||
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
|
||||
unet_use_temporal_attention=unet_use_temporal_attention,
|
||||
use_inflated_groupnorm=use_inflated_groupnorm,
|
||||
|
||||
use_motion_module=use_motion_module and (res in motion_module_resolutions),
|
||||
motion_module_type=motion_module_type,
|
||||
@@ -236,7 +241,10 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
|
||||
prev_output_channel = output_channel
|
||||
|
||||
# out
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||
if use_inflated_groupnorm:
|
||||
self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||
else:
|
||||
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
|
||||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user