mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
Merge remote-tracking branch 'origin/ofa/finetune_loss' into ofa/finetune
# Conflicts: # tests/trainers/test_ofa_trainer.py
This commit is contained in:
@@ -129,10 +129,9 @@ class OFATrainer(EpochBasedTrainer):
|
||||
|
||||
def train_step(self, model, inputs):
|
||||
model.train()
|
||||
model_outputs = model.forward(inputs)
|
||||
loss, sample_size, logging_output = self.criterion(
|
||||
model_outputs, inputs)
|
||||
train_outputs = {'loss': loss}
|
||||
# model_outputs = model.forward(inputs)
|
||||
loss, sample_size, logging_output = self.criterion(model, inputs)
|
||||
train_outputs = {'loss': loss / 100}
|
||||
# add model output info to log
|
||||
if 'log_vars' not in train_outputs:
|
||||
default_keys_pattern = ['loss']
|
||||
|
||||
@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
|
||||
self.padding_idx = args.tokenizer.pad_token_id
|
||||
self.args = args
|
||||
|
||||
def forward(self, output, sample, update_num=0, reduce=True):
|
||||
def forward(self, model, sample, update_num=0, reduce=True):
|
||||
"""Compute the loss for the given sample.
|
||||
|
||||
Returns a tuple with three elements:
|
||||
@@ -131,15 +131,20 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
|
||||
2) the sample size, which is used as the denominator for the gradient
|
||||
3) logging outputs to display while training
|
||||
"""
|
||||
if 'labels' in sample:
|
||||
del sample['labels']
|
||||
if 'samples' in sample:
|
||||
del sample['samples']
|
||||
|
||||
if self.use_rdrop:
|
||||
construct_rdrop_sample(sample)
|
||||
|
||||
output = model.model(**sample['net_input'])
|
||||
loss, nll_loss, ntokens = self.compute_loss(
|
||||
output, sample, update_num, reduce=reduce)
|
||||
output.logits, sample, update_num, reduce=reduce)
|
||||
sample_size = (
|
||||
sample['target'].size(0) if self.sentence_avg else ntokens)
|
||||
logging_output = {
|
||||
'loss': loss.data,
|
||||
'loss': loss.data / 100,
|
||||
'nll_loss': nll_loss.data,
|
||||
'ntokens': sample['ntokens'],
|
||||
'nsentences': sample['nsentences'],
|
||||
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
|
||||
}
|
||||
return loss, sample_size, logging_output
|
||||
|
||||
def get_lprobs_and_target(self, net_output, sample):
|
||||
def get_lprobs_and_target(self, logits, sample):
|
||||
conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
|
||||
'conf'] is not None else 1
|
||||
constraint_masks = None
|
||||
if 'constraint_masks' in sample and sample[
|
||||
'constraint_masks'] is not None:
|
||||
constraint_masks = sample['constraint_masks']
|
||||
net_output[0].masked_fill_(~constraint_masks, -math.inf)
|
||||
logits.masked_fill_(~constraint_masks, -math.inf)
|
||||
if self.constraint_start is not None and self.constraint_end is not None:
|
||||
net_output[0][:, :, 4:self.constraint_start] = -math.inf
|
||||
net_output[0][:, :, self.constraint_end:] = -math.inf
|
||||
lprobs = F.log_softmax(
|
||||
net_output[0], dim=-1, dtype=torch.float32) * conf
|
||||
logits[:, :, 4:self.constraint_start] = -math.inf
|
||||
logits[:, :, self.constraint_end:] = -math.inf
|
||||
lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
|
||||
target = sample['target']
|
||||
if self.ignore_prefix_size > 0:
|
||||
lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
|
||||
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
|
||||
return lprobs.view(-1,
|
||||
lprobs.size(-1)), target.view(-1), constraint_masks
|
||||
|
||||
def compute_loss(self, net_output, sample, update_num, reduce=True):
|
||||
def compute_loss(self, logits, sample, update_num, reduce=True):
|
||||
lprobs, target, constraint_masks = self.get_lprobs_and_target(
|
||||
net_output, sample)
|
||||
logits, sample)
|
||||
if constraint_masks is not None:
|
||||
constraint_masks = constraint_masks[target != self.padding_idx]
|
||||
lprobs = lprobs[target != self.padding_idx]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import unittest
|
||||
|
||||
@@ -57,7 +59,7 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
'report_accuracy': False,
|
||||
'sample_patch_num': 196,
|
||||
'sentence_avg': False,
|
||||
'use_rdrop': False},
|
||||
'use_rdrop': True},
|
||||
'hooks': [{'type': 'BestCkptSaverHook',
|
||||
'metric_key': 'bleu-4',
|
||||
'interval': 100},
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
{"framework": "pytorch", "task": "image-captioning", "model": {"type": "ofa", "beam_search": {"beam_size": 5, "max_len_b": 16, "min_len": 1, "no_repeat_ngram_size": 0}, "seed": 7, "max_src_length": 256, "language": "en", "gen_type": "generation", "patch_image_size": 480, "max_image_size": 480, "imagenet_default_mean_and_std": false}, "pipeline": {"type": "image-captioning"}, "dataset": {"column_map": {"text": "caption"}}, "train": {"work_dir": "work/ckpts/caption", "max_epochs": 1, "use_fp16": true, "dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 0}, "lr_scheduler": {"name": "polynomial_decay", "warmup_proportion": 0.01, "lr_end": 1e-07}, "lr_scheduler_hook": {"type": "LrSchedulerHook", "by_epoch": false}, "optimizer": {"type": "AdamW", "lr": 5e-05, "weight_decay": 0.01}, "optimizer_hook": {"type": "TorchAMPOptimizerHook", "cumulative_iters": 1, "grad_clip": {"max_norm": 1.0, "norm_type": 2}, "loss_keys": "loss"}, "criterion": {"name": "AdjustLabelSmoothedCrossEntropyCriterion", "constraint_range": null, "drop_worst_after": 0, "drop_worst_ratio": 0.0, "ignore_eos": false, "ignore_prefix_size": 0, "label_smoothing": 0.0, "reg_alpha": 1.0, "report_accuracy": false, "sample_patch_num": 196, "sentence_avg": false, "use_rdrop": true}, "hooks": [{"type": "BestCkptSaverHook", "metric_key": "bleu-4", "interval": 100}, {"type": "TextLoggerHook", "interval": 1}, {"type": "IterTimerHook"}, {"type": "EvaluationHook", "by_epoch": true, "interval": 1}]}, "evaluation": {"dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 0}, "metrics": [{"type": "bleu", "eval_tokenized_bleu": false, "ref_name": "labels", "hyp_name": "caption"}]}, "preprocessor": []}
|
||||
Reference in New Issue
Block a user