Avoid zero-size batches caused by gradient accumulation

This commit is contained in:
jason-on-salt-a40
2024-04-09 11:33:58 -07:00
parent 778db3443d
commit 17061636f2
3 changed files with 35 additions and 13 deletions

View File

@@ -462,6 +462,8 @@ class VoiceCraft(nn.Module):
before padding.
"""
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
if len(x) == 0:
return None
x = x[:, :x_lens.max()] # this deals with gradient accumulation, where x_lens.max() might be shorter than the padded length of the current slice of x
y = y[:, :, :y_lens.max()]
assert x.ndim == 2, x.shape