From 41a698ae8e544b8f3733f3dee66da4ced1c974e2 Mon Sep 17 00:00:00 2001
From: Yuwei Guo
Date: Tue, 11 Jul 2023 18:12:56 +0800
Subject: [PATCH] update

---
 animatediff/models/motion_module.py | 57 +++++++++++++++--------------
 1 file changed, 29 insertions(+), 28 deletions(-)

diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py
index 94be95b..2359e71 100644
--- a/animatediff/models/motion_module.py
+++ b/animatediff/models/motion_module.py
@@ -51,14 +51,14 @@ class VanillaTemporalModule(nn.Module):
     def __init__(
         self,
         in_channels,
-        num_attention_heads = 8,
-        num_transformer_block = 2,
-        attention_block_types =( "Temporal_Self", "Temporal_Self" ),
-        cross_frame_attention_mode = None,
-        temporal_position_encoding = False,
-        temporal_position_encoding_max_len = 24,
-        temporal_attention_dim_div = 1,
-        zero_initialize = True,
+        num_attention_heads                = 8,
+        num_transformer_block              = 2,
+        attention_block_types              =( "Temporal_Self", "Temporal_Self" ),
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
+        temporal_attention_dim_div         = 1,
+        zero_initialize                    = True,
     ):
         super().__init__()
 
@@ -92,20 +92,17 @@ class TemporalTransformer3DModel(nn.Module):
         attention_head_dim,
 
         num_layers,
-        attention_block_types=(
-            "Temporal_Self",
-            "Temporal_Self",
-        ),
-        dropout=0.0,
-        norm_num_groups=32,
-        cross_attention_dim=768,
-        activation_fn="geglu",
-        attention_bias=False,
-        upcast_attention=False,
-
-        cross_frame_attention_mode=None,
-        temporal_position_encoding=False,
-        temporal_position_encoding_max_len=24,
+        attention_block_types              = ( "Temporal_Self", "Temporal_Self", ),
+        dropout                            = 0.0,
+        norm_num_groups                    = 32,
+        cross_attention_dim                = 768,
+        activation_fn                      = "geglu",
+        attention_bias                     = False,
+        upcast_attention                   = False,
+
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
     ):
         super().__init__()
 
@@ -228,10 +225,14 @@ class TemporalTransformerBlock(nn.Module):
 
 
 class PositionalEncoding(nn.Module):
-    def __init__(self, d_model: int, dropout: float = 0., max_len: int = 24):
+    def __init__(
+        self,
+        d_model,
+        dropout = 0.,
+        max_len = 24
+    ):
         super().__init__()
         self.dropout = nn.Dropout(p=dropout)
-        # print(f"d_model: {d_model}")
         position = torch.arange(max_len).unsqueeze(1)
         div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
         pe = torch.zeros(1, max_len, d_model)
@@ -247,10 +248,10 @@ class PositionalEncoding(nn.Module):
 class VersatileAttention(CrossAttention):
     def __init__(
         self,
-        attention_mode=None,
-        cross_frame_attention_mode=None,
-        temporal_position_encoding=False,
-        temporal_position_encoding_max_len=24,
+        attention_mode                     = None,
+        cross_frame_attention_mode         = None,
+        temporal_position_encoding         = False,
+        temporal_position_encoding_max_len = 24,
         *args, **kwargs
     ):
         super().__init__(*args, **kwargs)
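
Note on the PositionalEncoding hunk: this is the standard sinusoidal encoding from the Transformer literature, applied here along the frame axis (the default max_len = 24 matches temporal_position_encoding_max_len above). The diff shows only the buffer setup, so the sketch below completes the class with the usual PyTorch recipe; the sin/cos fill, register_buffer, forward(), and the example numbers are assumptions for illustration, not lines from this patch.

    import math
    import torch
    import torch.nn as nn

    class PositionalEncoding(nn.Module):
        def __init__(self, d_model, dropout = 0., max_len = 24):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)
            position = torch.arange(max_len).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
            pe = torch.zeros(1, max_len, d_model)
            # Assumed completion: even channels get sin, odd channels get cos,
            # with geometrically spaced frequencies; pe stays (1, max_len, d_model).
            pe[0, :, 0::2] = torch.sin(position * div_term)
            pe[0, :, 1::2] = torch.cos(position * div_term)
            # A buffer moves with .to(device) and is saved in state_dict,
            # but is not a trainable parameter.
            self.register_buffer("pe", pe)

        def forward(self, x):
            # x: (batch, seq_len, d_model), seq_len <= max_len frames.
            x = x + self.pe[:, :x.size(1)]
            return self.dropout(x)

    # Illustrative usage: tag a batch of 16-frame feature sequences.
    enc = PositionalEncoding(d_model=320, max_len=24)
    out = enc(torch.randn(2, 16, 320))  # same shape, frames now position-encoded

Registering pe as a buffer rather than a parameter is what lets the temporal attention distinguish frame order without adding trainable weights, which fits the patch's zero_initialize convention of keeping the motion module's initial effect minimal.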