add RealEdit.txt; fixed masking warning to be compatible with torch 2.1.0

This commit is contained in:
jason-on-salt-a40
2024-04-05 12:52:28 -07:00
parent 2506954b64
commit e2e598d900
2 changed files with 315 additions and 0 deletions

View File

@@ -409,6 +409,10 @@ class VoiceCraft(nn.Module):
.expand(-1, self.args.nhead, -1, -1)
.reshape(bsz * self.args.nhead, 1, src_len)
)
# Check shapes and resize or broadcast as necessary
if xy_attn_mask.shape != _xy_padding_mask.shape:
# Assuming _xy_padding_mask has the correct shape and xy_attn_mask needs adjustment
xy_attn_mask = xy_attn_mask.expand_as(_xy_padding_mask) # Example approach
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
new_attn_mask = torch.zeros_like(xy_attn_mask)