mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
update
This commit is contained in:
@@ -92,10 +92,7 @@ class TemporalTransformer3DModel(nn.Module):
|
||||
attention_head_dim,
|
||||
|
||||
num_layers,
|
||||
attention_block_types=(
|
||||
"Temporal_Self",
|
||||
"Temporal_Self",
|
||||
),
|
||||
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
|
||||
dropout = 0.0,
|
||||
norm_num_groups = 32,
|
||||
cross_attention_dim = 768,
|
||||
@@ -228,10 +225,14 @@ class TemporalTransformerBlock(nn.Module):
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model: int, dropout: float = 0., max_len: int = 24):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
dropout = 0.,
|
||||
max_len = 24
|
||||
):
|
||||
super().__init__()
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
# print(f"d_model: {d_model}")
|
||||
position = torch.arange(max_len).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
|
||||
pe = torch.zeros(1, max_len, d_model)
|
||||
|
||||
Reference in New Issue
Block a user