From ad5d1aeb624aad3701fe3ccb5c699e41b033f24f Mon Sep 17 00:00:00 2001 From: "hemu.zp" Date: Mon, 19 Jun 2023 11:25:27 +0800 Subject: [PATCH] Fix gpt3 finetune nan Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/12981998 * fix gpt3 finetune nan --- modelscope/models/nlp/gpt3/distributed_gpt3.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py index 570487df..c7917b31 100644 --- a/modelscope/models/nlp/gpt3/distributed_gpt3.py +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -1023,7 +1023,11 @@ class DistributedGPT3(TorchModel, StreamingOutputMixin): losses = losses.float() loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + mask_sum = loss_mask.sum() + if mask_sum == 0: + loss = torch.sum(losses.view(-1)).zero_() + else: + loss = torch.sum(losses.view(-1) * loss_mask) / mask_sum return TextGenerationModelOutput(logits=logits, loss=loss)