Avoid zero-size batch caused by gradient accumulation

This commit is contained in:
jason-on-salt-a40
2024-04-09 11:33:58 -07:00
parent 778db3443d
commit 17061636f2
3 changed files with 35 additions and 13 deletions

View File

@@ -90,6 +90,8 @@ class Trainer:
cur_batch = {key: batch[key][cur_ind] for key in batch}
with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32):
out = self.model(cur_batch)
if out == None:
continue
record_loss = out['loss'].detach().to(self.rank)
top10acc = out['top10acc'].to(self.rank)