This commit is contained in:
Yuwei Guo
2023-07-11 18:12:56 +08:00
parent 81f2422dc3
commit 41a698ae8e

View File

@@ -51,14 +51,14 @@ class VanillaTemporalModule(nn.Module):
def __init__( def __init__(
self, self,
in_channels, in_channels,
num_attention_heads = 8, num_attention_heads = 8,
num_transformer_block = 2, num_transformer_block = 2,
attention_block_types =( "Temporal_Self", "Temporal_Self" ), attention_block_types =( "Temporal_Self", "Temporal_Self" ),
cross_frame_attention_mode = None, cross_frame_attention_mode = None,
temporal_position_encoding = False, temporal_position_encoding = False,
temporal_position_encoding_max_len = 24, temporal_position_encoding_max_len = 24,
temporal_attention_dim_div = 1, temporal_attention_dim_div = 1,
zero_initialize = True, zero_initialize = True,
): ):
super().__init__() super().__init__()
@@ -92,20 +92,17 @@ class TemporalTransformer3DModel(nn.Module):
attention_head_dim, attention_head_dim,
num_layers, num_layers,
attention_block_types=( attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
"Temporal_Self", dropout = 0.0,
"Temporal_Self", norm_num_groups = 32,
), cross_attention_dim = 768,
dropout=0.0, activation_fn = "geglu",
norm_num_groups=32, attention_bias = False,
cross_attention_dim=768, upcast_attention = False,
activation_fn="geglu",
attention_bias=False, cross_frame_attention_mode = None,
upcast_attention=False, temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
): ):
super().__init__() super().__init__()
@@ -228,10 +225,14 @@ class TemporalTransformerBlock(nn.Module):
class PositionalEncoding(nn.Module): class PositionalEncoding(nn.Module):
def __init__(self, d_model: int, dropout: float = 0., max_len: int = 24): def __init__(
self,
d_model,
dropout = 0.,
max_len = 24
):
super().__init__() super().__init__()
self.dropout = nn.Dropout(p=dropout) self.dropout = nn.Dropout(p=dropout)
# print(f"d_model: {d_model}")
position = torch.arange(max_len).unsqueeze(1) position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model) pe = torch.zeros(1, max_len, d_model)
@@ -247,10 +248,10 @@ class PositionalEncoding(nn.Module):
class VersatileAttention(CrossAttention): class VersatileAttention(CrossAttention):
def __init__( def __init__(
self, self,
attention_mode=None, attention_mode = None,
cross_frame_attention_mode=None, cross_frame_attention_mode = None,
temporal_position_encoding=False, temporal_position_encoding = False,
temporal_position_encoding_max_len=24, temporal_position_encoding_max_len = 24,
*args, **kwargs *args, **kwargs
): ):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)