add ocr-finetune ned

翎航
2022-10-26 11:45:50 +08:00
parent c077dea072
commit 90d47832c0
3 changed files with 43 additions and 9 deletions

View File

@@ -42,7 +42,7 @@ class AccuracyMetric(Metric):
             self.preds.extend(eval_results.tolist())
             self.labels.extend(ground_truths.tolist())
         else:
-            raise 'only support list or np.ndarray'
+            raise Exception('only support list or np.ndarray')

     def evaluate(self):
         assert len(self.preds) == len(self.labels)
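
The raise fix above matters beyond style: Python 3 only allows raising instances of BaseException subclasses, so the old bare-string raise died with a TypeError before its intended message ever reached the caller. A minimal standalone demonstration (not part of the commit):

    try:
        raise 'only support list or np.ndarray'  # illegal in Python 3
    except TypeError as err:
        print(err)  # "exceptions must derive from BaseException"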

View File

@@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys
 @METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
 class NedMetric(Metric):
-    """The metric computation class for classification classes.
+    """The NED (normalized edit distance) metric computation class.

-    This metric class calculates accuracy for the whole input batches.
+    This metric class calculates the normalized Levenshtein distance between predicted and reference sentences over the whole input batches.
     """

     def __init__(self, *args, **kwargs):
@@ -44,13 +44,46 @@ class NedMetric(Metric):
             self.preds.extend(eval_results.tolist())
             self.labels.extend(ground_truths.tolist())
         else:
-            raise 'only support list or np.ndarray'
+            raise Exception('only support list or np.ndarray')

     def evaluate(self):
         assert len(self.preds) == len(self.labels)
         return {
             MetricKeys.NED: (np.asarray([
-                self.ned.distance(pred, ref)
+                1.0 - NedMetric._distance(pred, ref)
                 for pred, ref in zip(self.preds, self.labels)
             ])).mean().item()
         }
+    @staticmethod
+    def _distance(pred, ref):
+        """Normalized Levenshtein distance in [0, 1]."""
+        if pred is None or ref is None:
+            raise TypeError('pred and ref must not be None.')
+        if pred == ref:
+            return 0.0
+        # An empty string against a non-empty one is a complete mismatch.
+        if len(pred) == 0 or len(ref) == 0:
+            return 1.0
+        m_len = max(len(pred), len(ref))
+
+        def levenshtein(s0, s1):
+            # Two-row dynamic-programming edit distance.
+            v0 = list(range(len(s1) + 1))
+            v1 = [0] * (len(s1) + 1)
+            for i in range(len(s0)):
+                v1[0] = i + 1
+                for j in range(len(s1)):
+                    cost = 0 if s0[i] == s1[j] else 1
+                    v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
+                v0, v1 = v1, v0
+            return v0[len(s1)]
+
+        return levenshtein(pred, ref) / m_len
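
The score reported under MetricKeys.NED is therefore a similarity: 1.0 minus the Levenshtein distance normalized by the longer string, so identical strings score 1.0 and fully disjoint ones approach 0.0. A standalone sketch of the same computation for sanity-checking (not part of the commit):

    def ned_score(pred, ref):
        # 1.0 - normalized Levenshtein distance, mirroring NedMetric above.
        if pred == ref:
            return 1.0
        if not pred or not ref:
            return 0.0
        prev = list(range(len(ref) + 1))
        for i, c0 in enumerate(pred):
            cur = [i + 1] + [0] * len(ref)
            for j, c1 in enumerate(ref):
                cost = 0 if c0 == c1 else 1
                cur[j + 1] = min(cur[j] + 1, prev[j + 1] + 1, prev[j] + cost)
            prev = cur
        return 1.0 - prev[len(ref)] / max(len(pred), len(ref))

    assert ned_score('modelscope', 'modelscope') == 1.0
    assert abs(ned_score('kitten', 'sitting') - (1.0 - 3 / 7)) < 1e-9  # distance 3, length 7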

View File

@@ -87,7 +87,7 @@ class TestOfaTrainer(unittest.TestCase):
             'max_image_size': 480,
             'imagenet_default_mean_and_std': False},
         'pipeline': {'type': 'ofa-ocr-recognition'},
-        'dataset': {'column_map': {'text': 'caption'}},
+        'dataset': {'column_map': {'text': 'label'}},
         'train': {'work_dir': 'work/ckpts/recognition',
             # 'launcher': 'pytorch',
             'max_epochs': 1,
@@ -116,7 +116,6 @@ class TestOfaTrainer(unittest.TestCase):
                 'use_rdrop': True},
             'hooks': [{'type': 'BestCkptSaverHook',
                        'metric_key': 'ned',
-                       'rule': 'min',
                        'interval': 100},
                       {'type': 'TextLoggerHook', 'interval': 1},
                       {'type': 'IterTimerHook'},
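
Dropping 'rule': 'min' is a consequence of the metric change above: evaluate() now returns 1.0 minus the normalized distance, a score where higher is better, so the old 'min' rule would have kept the worst checkpoint; removing it presumably falls back to the hook's maximize-by-default behavior.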
@@ -138,11 +137,13 @@ class TestOfaTrainer(unittest.TestCase):
             model=pretrained_model,
             work_dir=WORKSPACE,
             train_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
                 split='train[:12]'),
             eval_dataset=MsDataset.load(
-                'coco_2014_caption',
+                'ocr_fudanvi_zh',
+                subset_name='scene',
                 namespace='modelscope',
                 split='validation[:4]'),
             cfg_file=config_file)
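
For a quick look at what the trainer now consumes, the swapped-in dataset can be loaded on its own. A sketch reusing only identifiers that appear in the diff (the tiny split mirrors the test):

    from modelscope.msdatasets import MsDataset

    # Load the same 12-sample slice the test trains on and check that the
    # ground-truth text is exposed under 'label', matching the
    # column_map {'text': 'label'} set earlier in the config.
    train = MsDataset.load(
        'ocr_fudanvi_zh',
        subset_name='scene',
        namespace='modelscope',
        split='train[:12]')
    print(next(iter(train)))  # expect a 'label' field alongside the image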