caption finetune done, need to add BLEU

行嗔
2022-09-30 15:49:21 +08:00
parent a799dd237d
commit ac653594d8
6 changed files with 46 additions and 22 deletions
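
Note: the "BLEU" in the commit message is the caption evaluation metric still to be wired in. A minimal sketch of how it could be computed, assuming NLTK (the repo may end up using a different scorer; all names below are illustrative, not from this commit):

    from nltk.translate.bleu_score import SmoothingFunction, corpus_bleu

    def caption_bleu(references, hypotheses):
        # references: one list of reference captions per image (strings)
        # hypotheses: one generated caption per image (strings)
        refs = [[r.split() for r in group] for group in references]
        hyps = [h.split() for h in hypotheses]
        # Smoothing keeps the score nonzero when a high-order n-gram never matches.
        return corpus_bleu(refs, hyps, smoothing_function=SmoothingFunction().method1)

    caption_bleu([['a dog runs on the grass']], ['a dog running on grass'])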

View File

@@ -126,7 +126,7 @@ class OfaForAllTasks(TorchModel):
             return ret
 
     def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        if self.cfg.task == Tasks.image_captioning:
+        if not self.model.training and self.cfg.task == Tasks.image_captioning:
             caption = input[OutputKeys.CAPTION]
             result_l = list()
             for cap in caption:
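
The new guard skips caption decoding while the model is in training mode, presumably because during finetuning the forward pass returns losses rather than generated captions. The flag it checks is the standard torch.nn.Module mode switch:

    import torch

    model = torch.nn.Linear(4, 2)
    model.train()
    assert model.training        # finetuning: skip caption decoding
    model.eval()
    assert not model.training    # inference: decode captions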

View File

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import re
+import string
 from os import path as osp
 
 import json
@@ -58,6 +59,9 @@ class OfaBasePreprocessor:
             self.mean = [0.5, 0.5, 0.5]
             self.std = [0.5, 0.5, 0.5]
         self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
+        self.transtab = str.maketrans(
+            {key: None
+             for key in string.punctuation})
         self.constraint_trie = None
         if self.cfg.model.get('answer2label', None):
             ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)
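
The new transtab maps every ASCII punctuation character to None, so captions can be stripped of punctuation with a single str.translate call. In isolation:

    import string

    transtab = str.maketrans({key: None for key in string.punctuation})
    'A dog, running!'.translate(transtab)   # -> 'A dog running'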

View File

@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 from typing import Any, Dict, Union
 
 import torch
@@ -43,6 +44,17 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
         else:
             return self._build_infer_sample(data)
 
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data['text']
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        return sample
+
     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = data['image'] if isinstance(
             data['image'], Image.Image) else load_image(data['image'])
@@ -55,12 +67,3 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         return sample
-
-    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        sample = self._build_infer_sample(data)
-        target = data['target']
-        target = target.translate(self.transtab).strip()
-        target_token_list = target.strip().split()
-        target = ' '.join(target_token_list[:self.max_tgt_length])
-        sample['target'] = self.tokenize_text(target)
-        return sample
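
The relocated _build_train_sample now reads the caption from data['text'] and also builds the decoder input by teacher forcing: prev_output_tokens is the target shifted right one position, with BOS prepended. A toy illustration (the token ids are made up):

    import torch

    bos_item = torch.tensor([0])            # assumed BOS id
    target = torch.tensor([11, 12, 13, 2])  # caption tokens ending in EOS
    prev_output_tokens = torch.cat([bos_item, target[:-1]])
    # target:             tensor([11, 12, 13,  2])
    # prev_output_tokens: tensor([ 0, 11, 12, 13])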

View File

@@ -79,5 +79,6 @@ class TorchAMPOptimizerHook(OptimizerHook):
             self.scaler.step(trainer.optimizer)
             self.scaler.update(self._scale_update_param)
             trainer.optimizer.zero_grad()
+            print('xcxcxcxcxc: optimizer step')
 
         setattr(self._model, 'forward', self._ori_model_forward)
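
For context, the hook follows the standard torch.cuda.amp recipe: backward on the scaled loss, scaler.step to unscale gradients and step the optimizer, then scaler.update to retune the loss scale. A minimal standalone sketch of that sequence (requires a CUDA device; not this repo's code):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(4, 8, device='cuda')
    with torch.cuda.amp.autocast():
        loss = model(x).pow(2).mean()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, then optimizer.step()
    scaler.update()                # adjust the loss scale for the next step
    optimizer.zero_grad()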

View File

@@ -1,6 +1,6 @@
+import math
 import os
 from functools import partial
 from typing import Dict, Optional
 
 from datasets import load_dataset
 from torch import distributed as dist
@@ -27,13 +27,7 @@ class OFATrainer(EpochBasedTrainer):
         model_dir = model.model_dir
         cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = load_dataset(
-            cfg.dataset.script,
-            data_files=cfg.dataset.hf_dataset,
-            sep=cfg.dataset.sep,
-        )
-        dataset = MsDataset.from_hf_dataset(
-            dataset.rename_columns(cfg.dataset.column_map))
+        dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
             OfaPreprocessor(
@@ -42,9 +36,11 @@ class OFATrainer(EpochBasedTrainer):
             OfaPreprocessor(
                 model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True),
         }
-        epoch_steps = len(dataset['train']) // (
-            cfg.train.optimizer_hook.cumulative_iters
-            * cfg.train.dataloader.batch_size_per_gpu)
+        # use torchrun launch
+        world_size = int(os.environ.get('WORLD_SIZE', 1))
+        epoch_steps = math.ceil(
+            len(dataset['train']) /  # noqa
+            (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
         self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(
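
The replaced step count no longer divides by cumulative_iters and instead accounts for data-parallel sharding via the WORLD_SIZE variable that torchrun sets. The same arithmetic in isolation, with hypothetical numbers:

    import math
    import os

    world_size = int(os.environ.get('WORLD_SIZE', 1))  # 1 when not launched by torchrun
    num_samples = 5000        # hypothetical len(dataset['train'])
    batch_size_per_gpu = 8    # hypothetical cfg.train.dataloader.batch_size_per_gpu
    epoch_steps = math.ceil(num_samples / (batch_size_per_gpu * world_size))
    # with two GPUs: ceil(5000 / 16) = 313 steps per epoch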
@@ -104,3 +100,24 @@ class OFATrainer(EpochBasedTrainer):
         else:
             self.log_buffer.update(train_outputs['log_vars'])
         self.train_outputs = train_outputs
+
+    def _build_dataset_with_config(self, cfg):
+        if hasattr(cfg.dataset, 'hf_dataset'):
+            dataset = load_dataset(
+                cfg.dataset.script,
+                data_files=cfg.dataset.hf_dataset,
+                sep=cfg.dataset.sep,
+            )
+            dataset = MsDataset.from_hf_dataset(
+                dataset.rename_columns(cfg.dataset.column_map))
+            return dataset
+        elif hasattr(cfg.dataset, 'ms_dataset'):
+            dataset_d = dict()
+            for key in cfg.dataset.ms_dataset.keys():
+                dataset_d[key] = MsDataset.load(**cfg.dataset.ms_dataset[key])
+                dataset_d[key] = MsDataset.from_hf_dataset(
+                    dataset_d[key]._hf_ds.rename_columns(
+                        cfg.dataset.column_map))
+            return dataset_d
+        else:
+            raise NotImplementedError
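
The new _build_dataset_with_config dispatches on which key cfg.dataset provides. Two hypothetical layouts it would accept, written as Python dicts (field names taken from the code above; all values purely illustrative):

    hf_style = {
        'script': 'csv',                                   # datasets loading script
        'hf_dataset': {'train': 'train.tsv', 'test': 'test.tsv'},
        'sep': '\t',
        'column_map': {'caption': 'text'},                 # rename_columns(old -> new)
    }
    ms_style = {
        'ms_dataset': {
            'train': {'dataset_name': 'coco_caption_sample', 'split': 'train'},
        },
        'column_map': {'caption': 'text'},
    }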

View File

@@ -216,7 +216,6 @@ class EpochBasedTrainer(BaseTrainer):
             self._max_epochs = self.cfg.train.max_epochs
         else:
             self._max_epochs = kwargs['max_epochs']
         self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None)
         self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None)
         if self._train_iters_per_epoch is None and hasattr(