From 3b21ff10ec824b7ad6d062ce35c8cb7e990deec5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=BF=8E=E8=88=AA?=
Date: Mon, 31 Oct 2022 16:57:49 +0800
Subject: [PATCH] fix ocr preprocess

---
 modelscope/preprocessors/multi_modal.py         |  1 -
 modelscope/preprocessors/ofa/ocr_recognition.py | 11 ++++++-----
 requirements/multi-modal.txt                    |  2 ++
 tests/trainers/test_ofa_trainer.py              |  5 ++---
 4 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py
index 256c5243..af241d83 100644
--- a/modelscope/preprocessors/multi_modal.py
+++ b/modelscope/preprocessors/multi_modal.py
@@ -93,7 +93,6 @@ class OfaPreprocessor(Preprocessor):
             data = input
         else:
             data = self._build_dict(input)
-        data = self._ofa_input_compatibility_conversion(data)
         sample = self.preprocess(data)
         str_data = dict()
         for k, v in data.items():
diff --git a/modelscope/preprocessors/ofa/ocr_recognition.py b/modelscope/preprocessors/ofa/ocr_recognition.py
index 26fff9d2..a0342c14 100644
--- a/modelscope/preprocessors/ofa/ocr_recognition.py
+++ b/modelscope/preprocessors/ofa/ocr_recognition.py
@@ -2,12 +2,12 @@
 from typing import Any, Dict
 
 import torch
-from PIL import Image
+import unicodedata2
 from torchvision import transforms
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms import functional as F
+from zhconv import convert
 
-from modelscope.preprocessors.image import load_image
 from modelscope.utils.constant import ModeKeys
 from .base import OfaBasePreprocessor
 
@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
 
     def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
         sample = self._build_infer_sample(data)
-        target = data[self.column_map['text']]
-        target = target.translate(self.transtab).strip()
+        target = sample['label']
         target_token_list = target.strip().split()
         target = ' '.join(target_token_list[:self.max_tgt_length])
         sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@
             'patch_mask': torch.tensor([True])
         }
         if 'text' in self.column_map and self.column_map['text'] in data:
-            sample['label'] = data[self.column_map['text']]
+            target = data[self.column_map['text']]
+            target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
+            sample['label'] = target
         return sample
diff --git a/requirements/multi-modal.txt b/requirements/multi-modal.txt
index 255f6155..578f0b54 100644
--- a/requirements/multi-modal.txt
+++ b/requirements/multi-modal.txt
@@ -11,3 +11,5 @@ timm
 tokenizers
 torchvision
 transformers>=4.12.0
+unicodedata2
+zhconv
diff --git a/tests/trainers/test_ofa_trainer.py b/tests/trainers/test_ofa_trainer.py
index 3f68a9fb..6f96aea1 100644
--- a/tests/trainers/test_ofa_trainer.py
+++ b/tests/trainers/test_ofa_trainer.py
@@ -5,7 +5,7 @@ import unittest
 
 import json
 
-from modelscope.metainfo import Trainers
+from modelscope.metainfo import Metrics, Trainers
 from modelscope.msdatasets import MsDataset
 from modelscope.trainers import build_trainer
 from modelscope.utils.constant import DownloadMode, ModelFile
@@ -85,7 +85,7 @@
                 'ocr_fudanvi_zh',
                 subset_name='scene',
                 namespace='modelscope',
-                split='train[:200]',
+                split='train[800:900]',
                 download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
             eval_dataset=MsDataset.load(
                 'ocr_fudanvi_zh',
@@ -96,7 +96,6 @@
             cfg_file=config_file)
         trainer = build_trainer(name=Trainers.ofa, default_args=args)
         trainer.train()
-
         self.assertIn(
             ModelFile.TORCH_MODEL_BIN_FILE,
             os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))
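
Note (outside the patch): a minimal standalone sketch of the label normalization this change introduces in ocr_recognition.py, assuming the unicodedata2 and zhconv packages added to requirements are installed. The helper name normalize_ocr_label is hypothetical, not code from this patch.

    import unicodedata2
    from zhconv import convert

    # Hypothetical helper mirroring the patch's label preprocessing:
    # traditional -> simplified Chinese, then Unicode NFKC folding.
    def normalize_ocr_label(text: str) -> str:
        simplified = convert(text, 'zh-hans')  # zhconv: map to simplified Chinese
        # NFKC folds compatibility forms, e.g. full-width 'ＯＣＲ' -> 'OCR'
        return unicodedata2.normalize('NFKC', simplified)

    print(normalize_ocr_label('漢字ＯＣＲ'))  # -> '汉字OCR'

Applying the same normalization when building both 'label' (inference) and 'target' (training) keeps train-time tokenization consistent with eval-time labels.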