mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2026-04-03 01:36:55 +02:00
avoid zero size batch caused by grad_accumulation
This commit is contained in:
@@ -462,6 +462,8 @@ class VoiceCraft(nn.Module):
|
||||
before padding.
|
||||
"""
|
||||
x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"]
|
||||
if len(x) == 0:
|
||||
return None
|
||||
x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x
|
||||
y = y[:, :, :y_lens.max()]
|
||||
assert x.ndim == 2, x.shape
|
||||
|
||||
Reference in New Issue
Block a user