mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Fix gpt3 finetune nan
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12981998 * fix gpt3 finetune nan
This commit is contained in:
@@ -1023,7 +1023,11 @@ class DistributedGPT3(TorchModel, StreamingOutputMixin):
|
||||
|
||||
losses = losses.float()
|
||||
loss_mask = loss_mask.view(-1).float()
|
||||
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
|
||||
mask_sum = loss_mask.sum()
|
||||
if mask_sum == 0:
|
||||
loss = torch.sum(losses.view(-1)).zero_()
|
||||
else:
|
||||
loss = torch.sum(losses.view(-1) * loss_mask) / mask_sum
|
||||
|
||||
return TextGenerationModelOutput(logits=logits, loss=loss)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user