diff --git a/modelscope/metrics/accuracy_metric.py b/modelscope/metrics/accuracy_metric.py
index 1761786e..953ece4c 100644
--- a/modelscope/metrics/accuracy_metric.py
+++ b/modelscope/metrics/accuracy_metric.py
@@ -27,15 +27,21 @@ class AccuracyMetric(Metric):
         label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
         ground_truths = inputs[label_name]
         eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
         assert type(ground_truths) == type(eval_results)
-        if isinstance(ground_truths, list):
-            self.preds.extend(eval_results)
-            self.labels.extend(ground_truths)
-        elif isinstance(ground_truths, np.ndarray):
-            self.preds.extend(eval_results.tolist())
-            self.labels.extend(ground_truths.tolist())
-        else:
-            raise 'only support list or np.ndarray'
+        for truth in ground_truths:
+            self.labels.append(truth)
+        for result in eval_results:
+            if isinstance(result, str):
+                self.preds.append(result.strip().replace(' ', ''))
+            else:
+                self.preds.append(result)

     def evaluate(self):
         assert len(self.preds) == len(self.labels)
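Note on the AccuracyMetric change: string predictions are now whitespace-stripped before comparison, so OCR outputs with stray spaces are not penalized. A standalone sketch of that normalization (it does not import modelscope; the sample strings are invented):

# Standalone illustration of the normalization in AccuracyMetric.add()/evaluate().
ground_truths = ['你好', '世界']       # labels without internal spaces
eval_results = ['你 好', '世 界']      # raw OCR predictions containing stray spaces

labels = list(ground_truths)
preds = [r.strip().replace(' ', '') if isinstance(r, str) else r for r in eval_results]

accuracy = sum(p == t for p, t in zip(preds, labels)) / len(labels)
print(accuracy)  # 1.0 -- the stray spaces no longer count as mismatches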
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.preds = [] + self.labels = [] + + def add(self, outputs: Dict, inputs: Dict): + label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS + ground_truths = inputs[label_name] + eval_results = outputs[label_name] + for key in [ + OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES, + OutputKeys.LABELS, OutputKeys.SCORES + ]: + if key in outputs and outputs[key] is not None: + eval_results = outputs[key] + break + assert type(ground_truths) == type(eval_results) + if isinstance(ground_truths, list): + self.preds.extend(eval_results) + self.labels.extend(ground_truths) + elif isinstance(ground_truths, np.ndarray): + self.preds.extend(eval_results.tolist()) + self.labels.extend(ground_truths.tolist()) + else: + raise Exception('only support list or np.ndarray') + + def evaluate(self): + assert len(self.preds) == len(self.labels) + return { + MetricKeys.NED: (np.asarray([ + 1.0 - NedMetric._distance(pred, ref) + for pred, ref in zip(self.preds, self.labels) + ])).mean().item() + } + + @staticmethod + def _distance(pred, ref): + if pred is None or ref is None: + raise TypeError('Argument (pred or ref) is NoneType.') + if pred == ref: + return 0.0 + if len(pred) == 0: + return len(ref) + if len(ref) == 0: + return len(pred) + m_len = max(len(pred), len(ref)) + if m_len == 0: + return 0.0 + + def levenshtein(s0, s1): + v0 = [0] * (len(s1) + 1) + v1 = [0] * (len(s1) + 1) + + for i in range(len(v0)): + v0[i] = i + + for i in range(len(s0)): + v1[0] = i + 1 + for j in range(len(s1)): + cost = 1 + if s0[i] == s1[j]: + cost = 0 + v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost) + v0, v1 = v1, v0 + return v0[len(s1)] + + return levenshtein(pred, ref) / m_len diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py index 1761dbd4..26fff9d2 100644 --- a/modelscope/preprocessors/ofa/ocr_recognition.py +++ b/modelscope/preprocessors/ofa/ocr_recognition.py @@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): ]) def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]: - image = data['image'] if isinstance( - data['image'], Image.Image) else load_image(data['image']) + if self.mode == ModeKeys.TRAIN: + return self._build_train_sample(data) + else: + return self._build_infer_sample(data) + + def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + sample = self._build_infer_sample(data) + target = data[self.column_map['text']] + target = target.translate(self.transtab).strip() + target_token_list = target.strip().split() + target = ' '.join(target_token_list[:self.max_tgt_length]) + sample['target'] = self.tokenize_text(target, add_bos=False) + sample['prev_output_tokens'] = torch.cat( + [self.bos_item, sample['target'][:-1]]) + return sample + + def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: + image = self.get_img_pil(data[self.column_map['image']]) patch_image = self.patch_resize_transform(image) prompt = self.cfg.model.get('prompt', '图片上的文字是什么?') inputs = self.tokenize_text(prompt) @@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor): 'patch_image': patch_image, 'patch_mask': torch.tensor([True]) } + if 'text' in self.column_map and self.column_map['text'] in data: + sample['label'] = data[self.column_map['text']] return sample diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py 
diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py
index 1761dbd4..26fff9d2 100644
--- a/modelscope/preprocessors/ofa/ocr_recognition.py
+++ b/modelscope/preprocessors/ofa/ocr_recognition.py
@@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data[self.column_map['text']]
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
         inputs = self.tokenize_text(prompt)
@@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
+        if 'text' in self.column_map and self.column_map['text'] in data:
+            sample['label'] = data[self.column_map['text']]
         return sample
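The new _build_train_sample follows the usual teacher-forcing layout: 'target' is the tokenized label (no BOS), and 'prev_output_tokens' is that sequence shifted right with BOS prepended. A minimal tensor sketch of the shift, with invented token ids and a stand-in for self.bos_item:

import torch

bos_item = torch.tensor([0])              # stand-in for self.bos_item (invented id)
target = torch.tensor([11, 12, 13, 2])    # invented token ids for the tokenized label

prev_output_tokens = torch.cat([bos_item, target[:-1]])
print(prev_output_tokens.tolist())        # [0, 11, 12, 13] -- decoder input is the target shifted right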
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
index 02853925..f8028c6c 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer.py
@@ -129,9 +129,7 @@ class OFATrainer(EpochBasedTrainer):

     def train_step(self, model, inputs):
         model.train()
-        model_outputs = model.forward(inputs)
-        loss, sample_size, logging_output = self.criterion(
-            model_outputs, inputs)
+        loss, sample_size, logging_output = self.criterion(model, inputs)
         train_outputs = {'loss': loss}
         # add model output info to log
         if 'log_vars' not in train_outputs:
diff --git a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
index 2189a5db..3c38884c 100644
--- a/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
+++ b/modelscope/trainers/multi_modal/ofa/ofa_trainer_utils.py
@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.padding_idx = args.tokenizer.pad_token_id
         self.args = args

-    def forward(self, output, sample, update_num=0, reduce=True):
+    def forward(self, model, sample, update_num=0, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -131,11 +131,16 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
+        if 'labels' in sample:
+            del sample['labels']
+        if 'samples' in sample:
+            del sample['samples']
+
         if self.use_rdrop:
             construct_rdrop_sample(sample)
-
+        output = model.model(**sample['net_input'])
         loss, nll_loss, ntokens = self.compute_loss(
-            output, sample, update_num, reduce=reduce)
+            output.logits, sample, update_num, reduce=reduce)
         sample_size = (
             sample['target'].size(0) if self.sentence_avg else ntokens)
         logging_output = {
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         }
         return loss, sample_size, logging_output

-    def get_lprobs_and_target(self, net_output, sample):
+    def get_lprobs_and_target(self, logits, sample):
         conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
             'conf'] is not None else 1
         constraint_masks = None
         if 'constraint_masks' in sample and sample[
                 'constraint_masks'] is not None:
             constraint_masks = sample['constraint_masks']
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+            logits.masked_fill_(~constraint_masks, -math.inf)
         if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = F.log_softmax(
-            net_output[0], dim=-1, dtype=torch.float32) * conf
+            logits[:, :, 4:self.constraint_start] = -math.inf
+            logits[:, :, self.constraint_end:] = -math.inf
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
         target = sample['target']
         if self.ignore_prefix_size > 0:
             lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return lprobs.view(-1, lprobs.size(-1)), target.view(-1), constraint_masks

-    def compute_loss(self, net_output, sample, update_num, reduce=True):
+    def compute_loss(self, logits, sample, update_num, reduce=True):
         lprobs, target, constraint_masks = self.get_lprobs_and_target(
-            net_output, sample)
+            logits, sample)
         if constraint_masks is not None:
             constraint_masks = constraint_masks[target != self.padding_idx]
         lprobs = lprobs[target != self.padding_idx]
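With these two changes the criterion owns the forward pass: train_step hands it the wrapper model, and the criterion calls model.model(**sample['net_input']) and computes the loss from output.logits. A toy, runnable illustration of that division of labor; every Toy* class below is an invented stand-in, not modelscope code:

import torch
import torch.nn.functional as F


class ToyOutput:
    """Mimics a model output object that exposes .logits."""

    def __init__(self, logits):
        self.logits = logits


class ToyInnerModel(torch.nn.Module):
    """Plays the role of the network reached via model.model."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 16)

    def forward(self, input_ids):
        return ToyOutput(self.proj(input_ids))


class ToyWrapper(torch.nn.Module):
    """Plays the role of the trainer's model object."""

    def __init__(self):
        super().__init__()
        self.model = ToyInnerModel()


def toy_criterion(model, sample):
    # Same shape of call as the patched criterion.forward: run the forward pass
    # here and compute the loss from the logits of the returned output object.
    output = model.model(**sample['net_input'])
    lprobs = F.log_softmax(output.logits, dim=-1)
    return F.nll_loss(lprobs.view(-1, lprobs.size(-1)), sample['target'].view(-1))


wrapper = ToyWrapper()
sample = {
    'net_input': {'input_ids': torch.randn(2, 8)},
    'target': torch.randint(0, 16, (2,)),
}
print(toy_criterion(wrapper, sample).item())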
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
index 06003625..3f68a9fb 100644
--- a/tests/trainers/test_ofa_trainer.py
+++ b/tests/trainers/test_ofa_trainer.py
@@ -5,10 +5,10 @@ import unittest

 import json

-from modelscope.metainfo import Metrics, Trainers
+from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level


@@ -17,26 +17,27 @@ class TestOfaTrainer(unittest.TestCase):

     def setUp(self) -> None:
         self.finetune_cfg = \
             {'framework': 'pytorch',
-             'task': 'image-captioning',
+             'task': 'ocr-recognition',
              'model': {'type': 'ofa',
                        'beam_search': {'beam_size': 5,
-                                       'max_len_b': 16,
+                                       'max_len_b': 64,
                                        'min_len': 1,
                                        'no_repeat_ngram_size': 0},
                        'seed': 7,
-                       'max_src_length': 256,
-                       'language': 'en',
+                       'max_src_length': 128,
+                       'language': 'zh',
                        'gen_type': 'generation',
                        'patch_image_size': 480,
+                       'is_document': False,
                        'max_image_size': 480,
                        'imagenet_default_mean_and_std': False},
-             'pipeline': {'type': 'image-captioning'},
-             'dataset': {'column_map': {'text': 'caption'}},
-             'train': {'work_dir': 'work/ckpts/caption',
+             'pipeline': {'type': 'ofa-ocr-recognition'},
+             'dataset': {'column_map': {'text': 'label'}},
+             'train': {'work_dir': 'work/ckpts/recognition',
                        # 'launcher': 'pytorch',
                        'max_epochs': 1,
                        'use_fp16': True,
-                       'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
+                       'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
                        'lr_scheduler': {'name': 'polynomial_decay',
                                         'warmup_proportion': 0.01,
                                         'lr_end': 1e-07},
@@ -57,47 +58,48 @@ class TestOfaTrainer(unittest.TestCase):
                        'report_accuracy': False,
                        'sample_patch_num': 196,
                        'sentence_avg': False,
-                       'use_rdrop': False},
+                       'use_rdrop': True},
                        'hooks': [{'type': 'BestCkptSaverHook',
-                                  'metric_key': 'bleu-4',
+                                  'metric_key': 'accuracy',
                                   'interval': 100},
                                  {'type': 'TextLoggerHook',
                                   'interval': 1},
                                  {'type': 'IterTimerHook'},
                                  {'type': 'EvaluationHook',
                                   'by_epoch': True,
                                   'interval': 1}]},
              'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
-                            'metrics': [{'type': 'bleu',
-                                         'eval_tokenized_bleu': False,
-                                         'ref_name': 'labels',
-                                         'hyp_name': 'caption'}]},
+                            'metrics': [{'type': 'accuracy'}]},
              'preprocessor': []}

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_std(self):
-        WORKSPACE = './workspace/ckpts/caption'
+        WORKSPACE = './workspace/ckpts/recognition'
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
             json.dump(self.finetune_cfg, writer)

-        pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
+        pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,
             train_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='train[:20]'),
+                split='train[:200]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='validation[:10]'),
-            metrics=[Metrics.BLEU],
+                split='test[:20]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             cfg_file=config_file)
         trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
-        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
-                      os.listdir(os.path.join(WORKSPACE, 'output')))
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
         shutil.rmtree(WORKSPACE)
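A possible follow-up once the test has produced a checkpoint (not part of the test): load the fine-tuned weights from the trainer's output directory and run inference. This is a hedged sketch assuming the standard modelscope pipeline entry point accepts that directory as a local model path; the image path is a placeholder:

from modelscope.pipelines import pipeline

# Assumption: the trainer's output directory can be passed as a local model path.
ocr = pipeline('ocr-recognition', model='./workspace/ckpts/recognition/output')
print(ocr('path/to/scene_text.jpg'))  # expected to return the recognized text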