mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
add ocr-finetune ned
This commit is contained in:
@@ -42,7 +42,7 @@ class AccuracyMetric(Metric):
|
||||
self.preds.extend(eval_results.tolist())
|
||||
self.labels.extend(ground_truths.tolist())
|
||||
else:
|
||||
raise 'only support list or np.ndarray'
|
||||
raise Exception('only support list or np.ndarray')
|
||||
|
||||
def evaluate(self):
|
||||
assert len(self.preds) == len(self.labels)
|
||||
|
||||
@@ -14,9 +14,9 @@ from .builder import METRICS, MetricKeys
|
||||
|
||||
@METRICS.register_module(group_key=default_group, module_name=Metrics.NED)
|
||||
class NedMetric(Metric):
|
||||
"""The metric computation class for classification classes.
|
||||
"""The ned metric computation class for classification classes.
|
||||
|
||||
This metric class calculates accuracy for the whole input batches.
|
||||
This metric class calculates the levenshtein distance between sentences for the whole input batches.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -44,13 +44,46 @@ class NedMetric(Metric):
|
||||
self.preds.extend(eval_results.tolist())
|
||||
self.labels.extend(ground_truths.tolist())
|
||||
else:
|
||||
raise 'only support list or np.ndarray'
|
||||
raise Exception('only support list or np.ndarray')
|
||||
|
||||
def evaluate(self):
|
||||
assert len(self.preds) == len(self.labels)
|
||||
return {
|
||||
MetricKeys.NED: (np.asarray([
|
||||
self.ned.distance(pred, ref)
|
||||
1.0 - NedMetric._distance(pred, ref)
|
||||
for pred, ref in zip(self.preds, self.labels)
|
||||
])).mean().item()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _distance(pred, ref):
|
||||
if pred is None or ref is None:
|
||||
raise TypeError('Argument s0 is NoneType.')
|
||||
if pred == ref:
|
||||
return 0.0
|
||||
if len(pred) == 0:
|
||||
return len(ref)
|
||||
if len(ref) == 0:
|
||||
return len(pred)
|
||||
m_len = max(len(pred), len(ref))
|
||||
if m_len == 0:
|
||||
return 0.0
|
||||
|
||||
def levenshtein(s0, s1):
|
||||
v0 = [0] * (len(s1) + 1)
|
||||
v1 = [0] * (len(s1) + 1)
|
||||
|
||||
for i in range(len(v0)):
|
||||
v0[i] = i
|
||||
|
||||
for i in range(len(s0)):
|
||||
v1[0] = i + 1
|
||||
for j in range(len(s1)):
|
||||
cost = 1
|
||||
if s0[i] == s1[j]:
|
||||
cost = 0
|
||||
v1[j + 1] = min(v1[j] + 1, v0[j + 1] + 1, v0[j] + cost)
|
||||
v0, v1 = v1, v0
|
||||
return v0[len(s1)]
|
||||
|
||||
return levenshtein(pred, ref) / m_len
|
||||
|
||||
@@ -87,7 +87,7 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
'max_image_size': 480,
|
||||
'imagenet_default_mean_and_std': False},
|
||||
'pipeline': {'type': 'ofa-ocr-recognition'},
|
||||
'dataset': {'column_map': {'text': 'caption'}},
|
||||
'dataset': {'column_map': {'text': 'label'}},
|
||||
'train': {'work_dir': 'work/ckpts/recognition',
|
||||
# 'launcher': 'pytorch',
|
||||
'max_epochs': 1,
|
||||
@@ -116,7 +116,6 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
'use_rdrop': True},
|
||||
'hooks': [{'type': 'BestCkptSaverHook',
|
||||
'metric_key': 'ned',
|
||||
'rule': 'min',
|
||||
'interval': 100},
|
||||
{'type': 'TextLoggerHook', 'interval': 1},
|
||||
{'type': 'IterTimerHook'},
|
||||
@@ -138,11 +137,13 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
model=pretrained_model,
|
||||
work_dir=WORKSPACE,
|
||||
train_dataset=MsDataset.load(
|
||||
'coco_2014_caption',
|
||||
'ocr_fudanvi_zh',
|
||||
subset_name='scene',
|
||||
namespace='modelscope',
|
||||
split='train[:12]'),
|
||||
eval_dataset=MsDataset.load(
|
||||
'coco_2014_caption',
|
||||
'ocr_fudanvi_zh',
|
||||
subset_name='scene',
|
||||
namespace='modelscope',
|
||||
split='validation[:4]'),
|
||||
cfg_file=config_file)
|
||||
|
||||
Reference in New Issue
Block a user