diff --git a/modelscope/models/nlp/gpt3/distributed_gpt3.py b/modelscope/models/nlp/gpt3/distributed_gpt3.py index 570487df..c7917b31 100644 --- a/modelscope/models/nlp/gpt3/distributed_gpt3.py +++ b/modelscope/models/nlp/gpt3/distributed_gpt3.py @@ -1023,7 +1023,11 @@ class DistributedGPT3(TorchModel, StreamingOutputMixin): losses = losses.float() loss_mask = loss_mask.view(-1).float() - loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + mask_sum = loss_mask.sum() + if mask_sum == 0: + loss = torch.sum(losses.view(-1)).zero_() + else: + loss = torch.sum(losses.view(-1) * loss_mask) / mask_sum return TextGenerationModelOutput(logits=logits, loss=loss)