Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-25 12:39:25 +01:00)
Commit: caption finetune done, need to add BLEU
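For context on the "need to add BLEU" note: below is a minimal sketch of how corpus BLEU could be computed over generated captions once an evaluation hook exists. It uses sacrebleu with invented strings and is not the metric wiring of this commit.

import sacrebleu

# Hypothetical model outputs and ground-truth captions (illustrative only).
hyps = ['a dog sitting on the grass', 'two people riding bikes']
refs = [['a dog sits on the grass', 'two people ride bicycles']]  # one reference stream

bleu = sacrebleu.corpus_bleu(hyps, refs)
print(round(bleu.score, 2))  # corpus-level BLEU on a 0-100 scale
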
@@ -126,7 +126,7 @@ class OfaForAllTasks(TorchModel):
         return ret

     def postprocess(self, input: Dict[str, Any], **kwargs) -> Dict[str, Any]:
-        if self.cfg.task == Tasks.image_captioning:
+        if not self.model.training and self.cfg.task == Tasks.image_captioning:
             caption = input[OutputKeys.CAPTION]
             result_l = list()
             for cap in caption:

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import re
+import string
 from os import path as osp

 import json

@@ -58,6 +59,9 @@ class OfaBasePreprocessor:
         self.mean = [0.5, 0.5, 0.5]
         self.std = [0.5, 0.5, 0.5]
         self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
+        self.transtab = str.maketrans(
+            {key: None
+             for key in string.punctuation})
         self.constraint_trie = None
         if self.cfg.model.get('answer2label', None):
             ans2label_file = osp.join(model_dir, self.cfg.model.answer2label)

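As a side note, the translation table added above maps every ASCII punctuation character to None, so str.translate() deletes it. A standalone sketch of the same idea (the caption string is made up):

import string

# Table that deletes all punctuation characters when used with str.translate().
transtab = str.maketrans({key: None for key in string.punctuation})

caption = 'A dog, sitting on the grass!'
print(caption.translate(transtab))  # 'A dog sitting on the grass'
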
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 from typing import Any, Dict, Union

 import torch

@@ -43,6 +44,17 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
         else:
             return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data['text']
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        return sample

     def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         image = data['image'] if isinstance(
             data['image'], Image.Image) else load_image(data['image'])

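The prev_output_tokens line above is the usual teacher-forcing shift: the decoder input is the target shifted right by one position with BOS prepended. A minimal sketch with made-up token ids (bos_item and the target values are illustrative, not from the tokenizer):

import torch

bos_item = torch.tensor([0])             # illustrative BOS id
target = torch.tensor([11, 12, 13, 2])   # illustrative target ids, EOS last

# Drop the final target token and prepend BOS, so the decoder at position i
# only sees target tokens before position i during training.
prev_output_tokens = torch.cat([bos_item, target[:-1]])
print(prev_output_tokens)  # tensor([ 0, 11, 12, 13])
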
@@ -55,12 +67,3 @@ class OfaImageCaptioningPreprocessor(OfaBasePreprocessor):
             'patch_mask': torch.tensor([True])
         }
         return sample
-
-    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        sample = self._build_infer_sample(data)
-        target = data['target']
-        target = target.translate(self.transtab).strip()
-        target_token_list = target.strip().split()
-        target = ' '.join(target_token_list[:self.max_tgt_length])
-        sample['target'] = self.tokenize_text(target)
-        return sample

@@ -79,5 +79,6 @@ class TorchAMPOptimizerHook(OptimizerHook):
         self.scaler.step(trainer.optimizer)
         self.scaler.update(self._scale_update_param)
         trainer.optimizer.zero_grad()
+        print('xcxcxcxcxc: optimizer step')

         setattr(self._model, 'forward', self._ori_model_forward)

@@ -1,6 +1,6 @@
 import math
 import os
 from functools import partial
 from typing import Dict, Optional

 from datasets import load_dataset
 from torch import distributed as dist

@@ -27,13 +27,7 @@ class OFATrainer(EpochBasedTrainer):
         model_dir = model.model_dir
         cfg_file = os.path.join(model_dir, ModelFile.CONFIGURATION)
         cfg = Config.from_file(cfg_file)
-        dataset = load_dataset(
-            cfg.dataset.script,
-            data_files=cfg.dataset.hf_dataset,
-            sep=cfg.dataset.sep,
-        )
-        dataset = MsDataset.from_hf_dataset(
-            dataset.rename_columns(cfg.dataset.column_map))
+        dataset = self._build_dataset_with_config(cfg)
         preprocessor = {
             ConfigKeys.train:
                 OfaPreprocessor(

@@ -42,9 +36,11 @@ class OFATrainer(EpochBasedTrainer):
                 OfaPreprocessor(
                     model_dir=model_dir, mode=ModeKeys.EVAL, no_collate=True),
         }
-        epoch_steps = len(dataset['train']) // (
-            cfg.train.optimizer_hook.cumulative_iters
-            * cfg.train.dataloader.batch_size_per_gpu)
+        # use torchrun launch
+        world_size = int(os.environ.get('WORLD_SIZE', 1))
+        epoch_steps = math.ceil(
+            len(dataset['train']) /  # noqa
+            (cfg.train.dataloader.batch_size_per_gpu * world_size))  # noqa
         cfg.train.lr_scheduler.num_train_steps = epoch_steps * cfg.train.max_epochs
         cfg.train.criterion.tokenizer = model.tokenizer
         self.criterion = AdjustLabelSmoothedCrossEntropyCriterion(

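For intuition, a quick worked example of the new step accounting; the dataset size, batch size, world size and epoch count below are assumed values, not taken from any configuration in this commit:

import math

num_samples = 10000        # assumed len(dataset['train'])
batch_size_per_gpu = 8     # assumed cfg.train.dataloader.batch_size_per_gpu
world_size = 4             # assumed WORLD_SIZE exported by torchrun
max_epochs = 3             # assumed cfg.train.max_epochs

# Each global step consumes batch_size_per_gpu * world_size samples.
epoch_steps = math.ceil(num_samples / (batch_size_per_gpu * world_size))   # 313
num_train_steps = epoch_steps * max_epochs                                 # 939
print(epoch_steps, num_train_steps)
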
@@ -104,3 +100,24 @@ class OFATrainer(EpochBasedTrainer):
         else:
             self.log_buffer.update(train_outputs['log_vars'])
         self.train_outputs = train_outputs
+
+    def _build_dataset_with_config(self, cfg):
+        if hasattr(cfg.dataset, 'hf_dataset'):
+            dataset = load_dataset(
+                cfg.dataset.script,
+                data_files=cfg.dataset.hf_dataset,
+                sep=cfg.dataset.sep,
+            )
+            dataset = MsDataset.from_hf_dataset(
+                dataset.rename_columns(cfg.dataset.column_map))
+            return dataset
+        elif hasattr(cfg.dataset, 'ms_dataset'):
+            dataset_d = dict()
+            for key in cfg.dataset.ms_dataset.keys():
+                dataset_d[key] = MsDataset.load(**cfg.dataset.ms_dataset[key])
+                dataset_d[key] = MsDataset.from_hf_dataset(
+                    dataset_d[key]._hf_ds.rename_columns(
+                        cfg.dataset.column_map))
+            return dataset_d
+        else:
+            raise NotImplementedError

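To make the two branches of _build_dataset_with_config concrete, here is a hypothetical shape of the dataset section it would read from the configuration, written as a Python dict; the keys mirror the attributes accessed in the code, but every file name and column name is invented for illustration:

# Hypothetical cfg.dataset contents for the 'hf_dataset' branch (values invented).
dataset_cfg = {
    'script': 'csv',                        # builder passed to datasets.load_dataset
    'hf_dataset': {
        'train': 'train_captions.tsv',      # data_files per split
        'validation': 'val_captions.tsv',
    },
    'sep': '\t',                            # column separator for the csv builder
    'column_map': {'caption': 'text'},      # rename columns to what the
                                            # preprocessor expects ('text', 'image')
}
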
@@ -216,7 +216,6 @@ class EpochBasedTrainer(BaseTrainer):
             self._max_epochs = self.cfg.train.max_epochs
         else:
             self._max_epochs = kwargs['max_epochs']

         self._train_iters_per_epoch = kwargs.get('train_iters_per_epoch', None)
         self._eval_iters_per_epoch = kwargs.get('val_iters_per_epoch', None)
         if self._train_iters_per_epoch is None and hasattr(