mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2026-04-03 01:36:55 +02:00
Avoid zero-size batches caused by grad_accumulation
This commit is contained in:
@@ -90,6 +90,8 @@ class Trainer:
|
||||
cur_batch = {key: batch[key][cur_ind] for key in batch}
|
||||
with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32):
|
||||
out = self.model(cur_batch)
|
||||
if out == None:
|
||||
continue
|
||||
|
||||
record_loss = out['loss'].detach().to(self.rank)
|
||||
top10acc = out['top10acc'].to(self.rank)
|
||||
|
||||
Reference in New Issue
Block a user