mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2026-04-03 01:36:55 +02:00
add RealEdit.txt; fixed masking warning to be compatible with torch 2.1.0
This commit is contained in:
@@ -409,6 +409,10 @@ class VoiceCraft(nn.Module):
|
||||
.expand(-1, self.args.nhead, -1, -1)
|
||||
.reshape(bsz * self.args.nhead, 1, src_len)
|
||||
)
|
||||
# Check shapes and resize or broadcast as necessary
|
||||
if xy_attn_mask.shape != _xy_padding_mask.shape:
|
||||
# Assuming _xy_padding_mask has the correct shape and xy_attn_mask needs adjustment
|
||||
xy_attn_mask = xy_attn_mask.expand_as(_xy_padding_mask) # Example approach
|
||||
xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)
|
||||
|
||||
new_attn_mask = torch.zeros_like(xy_attn_mask)
|
||||
|
||||
Reference in New Issue
Block a user