add ocr-finetune

This commit is contained in:
翎航
2022-10-26 10:52:10 +08:00
parent cc8b78eac8
commit c077dea072
6 changed files with 152 additions and 18 deletions

View File

@@ -27,6 +27,13 @@ class AccuracyMetric(Metric):
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
ground_truths = inputs[label_name]
eval_results = outputs[label_name]
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in outputs and outputs[key] is not None:
eval_results = outputs[key]
break
assert type(ground_truths) == type(eval_results)
if isinstance(ground_truths, list):
self.preds.extend(eval_results)

View File

@@ -0,0 +1,56 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict
import numpy as np
from similarity.normalized_levenshtein import NormalizedLevenshtein
from modelscope.metainfo import Metrics
from modelscope.outputs import OutputKeys
from modelscope.utils.registry import default_group
from .base import Metric
from .builder import METRICS, MetricKeys
@METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
class NedMetric(Metric):
"""The metric computation class for classification classes.
This metric class calculates accuracy for the whole input batches.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ned = NormalizedLevenshtein()
self.preds = []
self.labels = []
def add(self, outputs: Dict, inputs: Dict):
label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
ground_truths = inputs[label_name]
eval_results = outputs[label_name]
for key in [
OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
OutputKeys.LABELS, OutputKeys.SCORES
]:
if key in outputs and outputs[key] is not None:
eval_results = outputs[key]
break
assert type(ground_truths) == type(eval_results)
if isinstance(ground_truths, list):
self.preds.extend(eval_results)
self.labels.extend(ground_truths)
elif isinstance(ground_truths, np.ndarray):
self.preds.extend(eval_results.tolist())
self.labels.extend(ground_truths.tolist())
else:
raise 'only support list or np.ndarray'
def evaluate(self):
assert len(self.preds) == len(self.labels)
return {
MetricKeys.NED: (np.asarray([
self.ned.distance(pred, ref)
for pred, ref in zip(self.preds, self.labels)
])).mean().item()
}

View File

@@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
])
def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = data['image'] if isinstance(
data['image'], Image.Image) else load_image(data['image'])
if self.mode == ModeKeys.TRAIN:
return self._build_train_sample(data)
else:
return self._build_infer_sample(data)
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False)
sample['prev_output_tokens'] = torch.cat(
[self.bos_item, sample['target'][:-1]])
return sample
def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
image = self.get_img_pil(data[self.column_map['image']])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
inputs = self.tokenize_text(prompt)
@@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_image': patch_image,
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
return sample

View File

@@ -6,6 +6,7 @@ pycocotools>=2.0.4
# which introduced compatability issues that are being investigated
rouge_score<=0.0.4
sacrebleu
strsim
taming-transformers-rom1504
timm
tokenizers

View File

@@ -15,9 +15,64 @@ from modelscope.utils.test_utils import test_level
class TestOfaTrainer(unittest.TestCase):
def setUp(self) -> None:
# self.finetune_cfg = \
# {'framework': 'pytorch',
# 'task': 'image-captioning',
# 'model': {'type': 'ofa',
# 'beam_search': {'beam_size': 5,
# 'max_len_b': 16,
# 'min_len': 1,
# 'no_repeat_ngram_size': 0},
# 'seed': 7,
# 'max_src_length': 256,
# 'language': 'en',
# 'gen_type': 'generation',
# 'patch_image_size': 480,
# 'max_image_size': 480,
# 'imagenet_default_mean_and_std': False},
# 'pipeline': {'type': 'image-captioning'},
# 'dataset': {'column_map': {'text': 'caption'}},
# 'train': {'work_dir': 'work/ckpts/caption',
# # 'launcher': 'pytorch',
# 'max_epochs': 1,
# 'use_fp16': True,
# 'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
# 'lr_scheduler': {'name': 'polynomial_decay',
# 'warmup_proportion': 0.01,
# 'lr_end': 1e-07},
# 'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
# 'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
# 'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
# 'cumulative_iters': 1,
# 'grad_clip': {'max_norm': 1.0, 'norm_type': 2},
# 'loss_keys': 'loss'},
# 'criterion': {'name': 'AdjustLabelSmoothedCrossEntropyCriterion',
# 'constraint_range': None,
# 'drop_worst_after': 0,
# 'drop_worst_ratio': 0.0,
# 'ignore_eos': False,
# 'ignore_prefix_size': 0,
# 'label_smoothing': 0.1,
# 'reg_alpha': 1.0,
# 'report_accuracy': False,
# 'sample_patch_num': 196,
# 'sentence_avg': False,
# 'use_rdrop': True},
# 'hooks': [{'type': 'BestCkptSaverHook',
# 'metric_key': 'bleu-4',
# 'interval': 100},
# {'type': 'TextLoggerHook', 'interval': 1},
# {'type': 'IterTimerHook'},
# {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
# 'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
# 'metrics': [{'type': 'bleu',
# 'eval_tokenized_bleu': False,
# 'ref_name': 'labels',
# 'hyp_name': 'caption'}]},
# 'preprocessor': []}
self.finetune_cfg = \
{'framework': 'pytorch',
'task': 'image-captioning',
'task': 'ocr-recognition',
'model': {'type': 'ofa',
'beam_search': {'beam_size': 5,
'max_len_b': 16,
@@ -25,18 +80,19 @@ class TestOfaTrainer(unittest.TestCase):
'no_repeat_ngram_size': 0},
'seed': 7,
'max_src_length': 256,
'language': 'en',
'language': 'zh',
'gen_type': 'generation',
'patch_image_size': 480,
'is_document': False,
'max_image_size': 480,
'imagenet_default_mean_and_std': False},
'pipeline': {'type': 'image-captioning'},
'pipeline': {'type': 'ofa-ocr-recognition'},
'dataset': {'column_map': {'text': 'caption'}},
'train': {'work_dir': 'work/ckpts/caption',
'train': {'work_dir': 'work/ckpts/recognition',
# 'launcher': 'pytorch',
'max_epochs': 1,
'use_fp16': True,
'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
'lr_scheduler': {'name': 'polynomial_decay',
'warmup_proportion': 0.01,
'lr_end': 1e-07},
@@ -59,39 +115,36 @@ class TestOfaTrainer(unittest.TestCase):
'sentence_avg': False,
'use_rdrop': True},
'hooks': [{'type': 'BestCkptSaverHook',
'metric_key': 'bleu-4',
'metric_key': 'ned',
'rule': 'min',
'interval': 100},
{'type': 'TextLoggerHook', 'interval': 1},
{'type': 'IterTimerHook'},
{'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
'metrics': [{'type': 'bleu',
'eval_tokenized_bleu': False,
'ref_name': 'labels',
'hyp_name': 'caption'}]},
'metrics': [{'type': 'ned'}]},
'preprocessor': []}
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer_std(self):
WORKSPACE = './workspace/ckpts/caption'
WORKSPACE = './workspace/ckpts/recognition'
os.makedirs(WORKSPACE, exist_ok=True)
config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
with open(config_file, 'w') as writer:
json.dump(self.finetune_cfg, writer)
pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
args = dict(
model=pretrained_model,
work_dir=WORKSPACE,
train_dataset=MsDataset.load(
'coco_2014_caption',
namespace='modelscope',
split='train[:20]'),
split='train[:12]'),
eval_dataset=MsDataset.load(
'coco_2014_caption',
namespace='modelscope',
split='validation[:10]'),
metrics=[Metrics.BLEU],
split='validation[:4]'),
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()

View File

@@ -1 +0,0 @@
{"framework": "pytorch", "task": "image-captioning", "model": {"type": "ofa", "beam_search": {"beam_size": 5, "max_len_b": 16, "min_len": 1, "no_repeat_ngram_size": 0}, "seed": 7, "max_src_length": 256, "language": "en", "gen_type": "generation", "patch_image_size": 480, "max_image_size": 480, "imagenet_default_mean_and_std": false}, "pipeline": {"type": "image-captioning"}, "dataset": {"column_map": {"text": "caption"}}, "train": {"work_dir": "work/ckpts/caption", "max_epochs": 1, "use_fp16": true, "dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 0}, "lr_scheduler": {"name": "polynomial_decay", "warmup_proportion": 0.01, "lr_end": 1e-07}, "lr_scheduler_hook": {"type": "LrSchedulerHook", "by_epoch": false}, "optimizer": {"type": "AdamW", "lr": 5e-05, "weight_decay": 0.01}, "optimizer_hook": {"type": "TorchAMPOptimizerHook", "cumulative_iters": 1, "grad_clip": {"max_norm": 1.0, "norm_type": 2}, "loss_keys": "loss"}, "criterion": {"name": "AdjustLabelSmoothedCrossEntropyCriterion", "constraint_range": null, "drop_worst_after": 0, "drop_worst_ratio": 0.0, "ignore_eos": false, "ignore_prefix_size": 0, "label_smoothing": 0.0, "reg_alpha": 1.0, "report_accuracy": false, "sample_patch_num": 196, "sentence_avg": false, "use_rdrop": true}, "hooks": [{"type": "BestCkptSaverHook", "metric_key": "bleu-4", "interval": 100}, {"type": "TextLoggerHook", "interval": 1}, {"type": "IterTimerHook"}, {"type": "EvaluationHook", "by_epoch": true, "interval": 1}]}, "evaluation": {"dataloader": {"batch_size_per_gpu": 4, "workers_per_gpu": 0}, "metrics": [{"type": "bleu", "eval_tokenized_bleu": false, "ref_name": "labels", "hyp_name": "caption"}]}, "preprocessor": []}