fix conflict

翎航
2022-10-26 22:41:13 +08:00
6 changed files with 164 additions and 49 deletions

View File

@@ -27,15 +27,21 @@ class AccuracyMetric(Metric):
         label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
         ground_truths = inputs[label_name]
         eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
         assert type(ground_truths) == type(eval_results)
-        for truth in ground_truths:
-            self.labels.append(truth)
-        for result in eval_results:
-            if isinstance(truth, str):
-                self.preds.append(result.strip().replace(' ', ''))
-            else:
-                self.preds.append(result)
+        if isinstance(ground_truths, list):
+            self.preds.extend(eval_results)
+            self.labels.extend(ground_truths)
+        elif isinstance(ground_truths, np.ndarray):
+            self.preds.extend(eval_results.tolist())
+            self.labels.extend(ground_truths.tolist())
+        else:
+            raise TypeError('only list or np.ndarray are supported')

     def evaluate(self):
         assert len(self.preds) == len(self.labels)
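Note on the new accumulation logic: both isinstance branches flatten a batch into plain Python lists before accumulation, so evaluate() can zip predictions against labels whether the pipeline emitted lists or numpy arrays. A minimal standalone sketch of that contract (plain numpy, no modelscope imports; MockMetric is a hypothetical stand-in for illustration):

    import numpy as np

    class MockMetric:
        """Hypothetical stand-in for the accumulate-then-evaluate pattern."""

        def __init__(self):
            self.preds, self.labels = [], []

        def add(self, eval_results, ground_truths):
            # Flatten each batch into plain lists so evaluate() can zip them.
            if isinstance(ground_truths, list):
                self.preds.extend(eval_results)
                self.labels.extend(ground_truths)
            elif isinstance(ground_truths, np.ndarray):
                self.preds.extend(eval_results.tolist())
                self.labels.extend(ground_truths.tolist())
            else:
                raise TypeError('only list or np.ndarray are supported')

    m = MockMetric()
    m.add(['a', 'b'], ['a', 'c'])              # list batch
    m.add(np.array([1, 0]), np.array([1, 1]))  # ndarray batch
    assert len(m.preds) == len(m.labels) == 4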

View File

@@ -0,0 +1,87 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Dict
+
+import numpy as np
+
+from modelscope.metainfo import Metrics
+from modelscope.outputs import OutputKeys
+from modelscope.utils.registry import default_group
+from .base import Metric
+from .builder import METRICS, MetricKeys
+
+
+@METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
+class NedMetric(Metric):
+    """The NED (normalized edit distance) metric computation class.
+
+    This metric calculates the normalized Levenshtein distance between
+    predicted and reference sentences over the whole input batch.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preds = []
+        self.labels = []
+
+    def add(self, outputs: Dict, inputs: Dict):
+        label_name = OutputKeys.LABEL if OutputKeys.LABEL in inputs else OutputKeys.LABELS
+        ground_truths = inputs[label_name]
+        eval_results = outputs[label_name]
+        for key in [
+                OutputKeys.CAPTION, OutputKeys.TEXT, OutputKeys.BOXES,
+                OutputKeys.LABELS, OutputKeys.SCORES
+        ]:
+            if key in outputs and outputs[key] is not None:
+                eval_results = outputs[key]
+                break
+        assert type(ground_truths) == type(eval_results)
+        if isinstance(ground_truths, list):
+            self.preds.extend(eval_results)
+            self.labels.extend(ground_truths)
+        elif isinstance(ground_truths, np.ndarray):
+            self.preds.extend(eval_results.tolist())
+            self.labels.extend(ground_truths.tolist())
+        else:
+            raise TypeError('only list or np.ndarray are supported')
+
+    def evaluate(self):
+        assert len(self.preds) == len(self.labels)
+        return {
+            MetricKeys.NED: (np.asarray([
+                1.0 - NedMetric._distance(pred, ref)
+                for pred, ref in zip(self.preds, self.labels)
+            ])).mean().item()
+        }
+
+    @staticmethod
+    def _distance(pred, ref):
+        if pred is None or ref is None:
+            raise TypeError('Argument (pred or ref) is NoneType.')
+        if pred == ref:
+            return 0.0
+        # An empty string against a non-empty one is the maximum
+        # normalized distance, 1.0 (returning the raw length here would
+        # break the [0, 1] range that evaluate() assumes).
+        if len(pred) == 0 or len(ref) == 0:
+            return 1.0
+        m_len = max(len(pred), len(ref))
+
+        def levenshtein(s0, s1):
+            # Two-row dynamic-programming form of the edit distance.
+            v0 = [0] * (len(s1) + 1)
+            v1 = [0] * (len(s1) + 1)
+            for i in range(len(v0)):
+                v0[i] = i
+            for i in range(len(s0)):
+                v1[0] = i + 1
+                for j in range(len(s1)):
+                    cost = 1
+                    if s0[i] == s1[j]:
+                        cost = 0
+                    v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
+                v0, v1 = v1, v0
+            return v0[len(s1)]
+
+        return levenshtein(pred, ref) / m_len
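For intuition about the score this file produces: NED here is the Levenshtein distance divided by the longer string's length, and evaluate() averages 1.0 minus that distance, so 1.0 means identical strings. A quick pure-Python check using the classic kitten/sitting pair (edit distance 3, max length 7), independent of modelscope:

    def levenshtein(s0, s1):
        # Same two-row dynamic program as NedMetric._distance above.
        v0 = list(range(len(s1) + 1))
        for i, c0 in enumerate(s0):
            v1 = [i + 1]
            for j, c1 in enumerate(s1):
                v1.append(min(v1[j] + 1, v0[j + 1] + 1, v0[j] + (c0 != c1)))
            v0 = v1
        return v0[-1]

    pred, ref = 'kitten', 'sitting'
    score = 1.0 - levenshtein(pred, ref) / max(len(pred), len(ref))
    print(round(score, 4))  # 0.5714, the per-sample value NedMetric averages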

View File

@@ -91,8 +91,24 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
         ])

     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
-        image = data['image'] if isinstance(
-            data['image'], Image.Image) else load_image(data['image'])
+        if self.mode == ModeKeys.TRAIN:
+            return self._build_train_sample(data)
+        else:
+            return self._build_infer_sample(data)
+
+    def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        sample = self._build_infer_sample(data)
+        target = data[self.column_map['text']]
+        target = target.translate(self.transtab).strip()
+        target_token_list = target.strip().split()
+        target = ' '.join(target_token_list[:self.max_tgt_length])
+        sample['target'] = self.tokenize_text(target, add_bos=False)
+        sample['prev_output_tokens'] = torch.cat(
+            [self.bos_item, sample['target'][:-1]])
+        return sample
+
+    def _build_infer_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
+        image = self.get_img_pil(data[self.column_map['image']])
         patch_image = self.patch_resize_transform(image)
         prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
         inputs = self.tokenize_text(prompt)
@@ -102,4 +118,6 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
             'patch_image': patch_image,
             'patch_mask': torch.tensor([True])
         }
+        if 'text' in self.column_map and self.column_map['text'] in data:
+            sample['label'] = data[self.column_map['text']]
         return sample
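Two details worth noting in the new preprocessor methods: the default prompt string translates to "What is the text in the image?", and _build_train_sample builds the decoder input by prepending BOS and dropping the last target token (teacher forcing). A toy sketch of that shift with plain torch tensors (the token ids are made up for illustration):

    import torch

    bos_item = torch.tensor([0])         # toy BOS id
    target = torch.tensor([7, 8, 9, 2])  # toy target ids, EOS last
    # Same shift as _build_train_sample: at step t the decoder sees
    # prev_output_tokens[t] and is trained to predict target[t].
    prev_output_tokens = torch.cat([bos_item, target[:-1]])
    print(prev_output_tokens.tolist())   # [0, 7, 8, 9]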

View File

@@ -129,9 +129,7 @@ class OFATrainer(EpochBasedTrainer):
     def train_step(self, model, inputs):
         model.train()
-        model_outputs = model.forward(inputs)
-        loss, sample_size, logging_output = self.criterion(
-            model_outputs, inputs)
+        loss, sample_size, logging_output = self.criterion(model, inputs)
         train_outputs = {'loss': loss}
         # add model output info to log
         if 'log_vars' not in train_outputs:

View File

@@ -123,7 +123,7 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         self.padding_idx = args.tokenizer.pad_token_id
         self.args = args

-    def forward(self, output, sample, update_num=0, reduce=True):
+    def forward(self, model, sample, update_num=0, reduce=True):
         """Compute the loss for the given sample.

         Returns a tuple with three elements:
@@ -131,11 +131,16 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         2) the sample size, which is used as the denominator for the gradient
         3) logging outputs to display while training
         """
+        if 'labels' in sample:
+            del sample['labels']
+        if 'samples' in sample:
+            del sample['samples']
         if self.use_rdrop:
             construct_rdrop_sample(sample)
+        output = model.model(**sample['net_input'])
         loss, nll_loss, ntokens = self.compute_loss(
-            output, sample, update_num, reduce=reduce)
+            output.logits, sample, update_num, reduce=reduce)
         sample_size = (
             sample['target'].size(0) if self.sentence_avg else ntokens)
         logging_output = {
@@ -147,19 +152,18 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         }
         return loss, sample_size, logging_output

-    def get_lprobs_and_target(self, net_output, sample):
+    def get_lprobs_and_target(self, logits, sample):
         conf = sample['conf'][:, None, None] if 'conf' in sample and sample[
             'conf'] is not None else 1
         constraint_masks = None
         if 'constraint_masks' in sample and sample[
                 'constraint_masks'] is not None:
             constraint_masks = sample['constraint_masks']
-            net_output[0].masked_fill_(~constraint_masks, -math.inf)
+            logits.masked_fill_(~constraint_masks, -math.inf)
         if self.constraint_start is not None and self.constraint_end is not None:
-            net_output[0][:, :, 4:self.constraint_start] = -math.inf
-            net_output[0][:, :, self.constraint_end:] = -math.inf
-        lprobs = F.log_softmax(
-            net_output[0], dim=-1, dtype=torch.float32) * conf
+            logits[:, :, 4:self.constraint_start] = -math.inf
+            logits[:, :, self.constraint_end:] = -math.inf
+        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) * conf
         target = sample['target']
         if self.ignore_prefix_size > 0:
             lprobs = lprobs[:, self.ignore_prefix_size:, :].contiguous()
@@ -180,9 +184,9 @@ class AdjustLabelSmoothedCrossEntropyCriterion(_Loss):
         return lprobs.view(-1,
                            lprobs.size(-1)), target.view(-1), constraint_masks

-    def compute_loss(self, net_output, sample, update_num, reduce=True):
+    def compute_loss(self, logits, sample, update_num, reduce=True):
         lprobs, target, constraint_masks = self.get_lprobs_and_target(
-            net_output, sample)
+            logits, sample)
         if constraint_masks is not None:
             constraint_masks = constraint_masks[target != self.padding_idx]
         lprobs = lprobs[target != self.padding_idx]
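The refactored criterion now runs the model forward pass itself and operates on raw logits instead of a net_output tuple. A toy sketch of the masking-then-log-softmax step above (random logits; the shapes and the -inf fill mirror get_lprobs_and_target, but this is an illustration, not the modelscope API):

    import math

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 5, 10)  # (batch, seq_len, vocab)
    constraint_masks = torch.zeros(2, 5, 10, dtype=torch.bool)
    constraint_masks[..., :4] = True  # allow only the first four ids
    # Disallowed vocabulary entries are filled with -inf, so log_softmax
    # assigns them zero probability before the label-smoothed loss.
    logits.masked_fill_(~constraint_masks, -math.inf)
    lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    print(lprobs[0, 0, 4:].exp().sum().item())  # 0.0: masked ids excluded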

View File

@@ -5,10 +5,10 @@ import unittest
 import json

-from modelscope.metainfo import Metrics, Trainers
+from modelscope.metainfo import Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
-from modelscope.utils.constant import ModelFile
+from modelscope.utils.constant import DownloadMode, ModelFile
 from modelscope.utils.test_utils import test_level
@@ -17,26 +17,27 @@ class TestOfaTrainer(unittest.TestCase):
     def setUp(self) -> None:
         self.finetune_cfg = \
             {'framework': 'pytorch',
-             'task': 'image-captioning',
+             'task': 'ocr-recognition',
              'model': {'type': 'ofa',
                        'beam_search': {'beam_size': 5,
-                                       'max_len_b': 16,
+                                       'max_len_b': 64,
                                        'min_len': 1,
                                        'no_repeat_ngram_size': 0},
                        'seed': 7,
-                       'max_src_length': 256,
-                       'language': 'en',
+                       'max_src_length': 128,
+                       'language': 'zh',
                        'gen_type': 'generation',
                        'patch_image_size': 480,
+                       'is_document': False,
                        'max_image_size': 480,
                        'imagenet_default_mean_and_std': False},
-             'pipeline': {'type': 'image-captioning'},
-             'dataset': {'column_map': {'text': 'caption'}},
-             'train': {'work_dir': 'work/ckpts/caption',
+             'pipeline': {'type': 'ofa-ocr-recognition'},
+             'dataset': {'column_map': {'text': 'label'}},
+             'train': {'work_dir': 'work/ckpts/recognition',
                        # 'launcher': 'pytorch',
                        'max_epochs': 1,
                        'use_fp16': True,
-                       'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
+                       'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
                        'lr_scheduler': {'name': 'polynomial_decay',
                                         'warmup_proportion': 0.01,
                                         'lr_end': 1e-07},
@@ -57,47 +58,48 @@ class TestOfaTrainer(unittest.TestCase):
                        'report_accuracy': False,
                        'sample_patch_num': 196,
                        'sentence_avg': False,
-                       'use_rdrop': False},
+                       'use_rdrop': True},
              'hooks': [{'type': 'BestCkptSaverHook',
-                        'metric_key': 'bleu-4',
+                        'metric_key': 'accuracy',
                         'interval': 100},
                        {'type': 'TextLoggerHook', 'interval': 1},
                        {'type': 'IterTimerHook'},
                        {'type': 'EvaluationHook', 'by_epoch': True, 'interval': 1}]},
              'evaluation': {'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
-                            'metrics': [{'type': 'bleu',
-                                         'eval_tokenized_bleu': False,
-                                         'ref_name': 'labels',
-                                         'hyp_name': 'caption'}]},
+                            'metrics': [{'type': 'accuracy'}]},
              'preprocessor': []}

     @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
     def test_trainer_std(self):
-        WORKSPACE = './workspace/ckpts/caption'
+        WORKSPACE = './workspace/ckpts/recognition'
         os.makedirs(WORKSPACE, exist_ok=True)
         config_file = os.path.join(WORKSPACE, ModelFile.CONFIGURATION)
         with open(config_file, 'w') as writer:
             json.dump(self.finetune_cfg, writer)

-        pretrained_model = 'damo/ofa_image-caption_coco_distilled_en'
+        pretrained_model = 'damo/ofa_ocr-recognition_scene_base_zh'
         args = dict(
             model=pretrained_model,
             work_dir=WORKSPACE,
             train_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='train[:20]'),
+                split='train[:200]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
-                split='validation[:10]'),
-            metrics=[Metrics.BLEU],
+                split='test[:20]',
+                download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             cfg_file=config_file)
         trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
-        self.assertIn(ModelFile.TORCH_MODEL_BIN_FILE,
-                      os.listdir(os.path.join(WORKSPACE, 'output')))
+        self.assertIn(
+            ModelFile.TORCH_MODEL_BIN_FILE,
+            os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
         shutil.rmtree(WORKSPACE)