fix model

This commit is contained in:
Yuwei
2024-07-17 08:03:42 +00:00
parent cf80ddeb47
commit 786a99cc7f
3 changed files with 140 additions and 71 deletions

View File

@@ -238,7 +238,7 @@ class PositionalEncoding(nn.Module):
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
self.register_buffer('pe', pe, persistent=False)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
@@ -251,7 +251,7 @@ class VersatileAttention(CrossAttention):
attention_mode = None,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
temporal_position_encoding_max_len = 32,
*args, **kwargs
):
super().__init__(*args, **kwargs)