mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 17:56:15 +02:00
fix model
This commit is contained in:
@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
|
||||
pe = torch.zeros(1, max_len, d_model)
|
||||
pe[0, :, 0::2] = torch.sin(position * div_term)
|
||||
pe[0, :, 1::2] = torch.cos(position * div_term)
|
||||
self.register_buffer('pe', pe)
|
||||
self.register_buffer('pe', pe, persistent=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.pe[:, :x.size(1)]
|
||||
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
|
||||
attention_mode = None,
|
||||
cross_frame_attention_mode = None,
|
||||
temporal_position_encoding = False,
|
||||
temporal_position_encoding_max_len = 24,
|
||||
temporal_position_encoding_max_len = 32,
|
||||
*args, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
Reference in New Issue
Block a user