Fix NaN loss in GPT-3 fine-tuning

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12981998
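* Fix NaN loss in GPT-3 fine-tuning: when the loss mask sums to zero, the masked-mean loss was computed as 0/0 (NaN); it now returns a zero loss instead.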
This commit is contained in:
hemu.zp
2023-06-19 11:25:27 +08:00
committed by wenmeng.zwm
parent cc3c384d5e
commit ad5d1aeb62

View File

@@ -1023,7 +1023,11 @@ class DistributedGPT3(TorchModel, StreamingOutputMixin):
losses = losses.float()
loss_mask = loss_mask.view(-1).float()
-loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
+mask_sum = loss_mask.sum()
+if mask_sum == 0:
+loss = torch.sum(losses.view(-1)).zero_()
+else:
+loss = torch.sum(losses.view(-1) * loss_mask) / mask_sum
return TextGenerationModelOutput(logits=logits, loss=loss)