fix ocr prepreocess

This commit is contained in:
翎航
2022-10-31 16:57:49 +08:00
parent 2299f8fa65
commit 3b21ff10ec
4 changed files with 10 additions and 9 deletions

View File

@@ -93,7 +93,6 @@ class OfaPreprocessor(Preprocessor):
data = input
else:
data = self._build_dict(input)
data = self._ofa_input_compatibility_conversion(data)
sample = self.preprocess(data)
str_data = dict()
for k, v in data.items():

View File

@@ -2,12 +2,12 @@
from typing import Any, Dict
import torch
from PIL import Image
import unicodedata2
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from zhconv import convert
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
@@ -98,8 +98,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
def _build_train_sample(self, data: Dict[str, Any]) -> Dict[str, Any]:
sample = self._build_infer_sample(data)
target = data[self.column_map['text']]
target = target.translate(self.transtab).strip()
target = sample['label']
target_token_list = target.strip().split()
target = ' '.join(target_token_list[:self.max_tgt_length])
sample['target'] = self.tokenize_text(target, add_bos=False)
@@ -119,5 +118,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
'patch_mask': torch.tensor([True])
}
if 'text' in self.column_map and self.column_map['text'] in data:
sample['label'] = data[self.column_map['text']]
target = data[self.column_map['text']]
target = unicodedata2.normalize('NFKC', convert(target, 'zh-hans'))
sample['label'] = target
return sample

View File

@@ -11,3 +11,5 @@ timm
tokenizers
torchvision
transformers>=4.12.0
unicodedata2
zhconv

View File

@@ -5,7 +5,7 @@ import unittest
import json
from modelscope.metainfo import Trainers
from modelscope.metainfo import Metrics, Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.constant import DownloadMode, ModelFile
@@ -85,7 +85,7 @@ class TestOfaTrainer(unittest.TestCase):
'ocr_fudanvi_zh',
subset_name='scene',
namespace='modelscope',
split='train[:200]',
split='train[800:900]',
download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS),
eval_dataset=MsDataset.load(
'ocr_fudanvi_zh',
@@ -96,7 +96,6 @@ class TestOfaTrainer(unittest.TestCase):
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa, default_args=args)
trainer.train()
self.assertIn(
ModelFile.TORCH_MODEL_BIN_FILE,
os.listdir(os.path.join(WORKSPACE, ModelFile.TRAIN_OUTPUT_DIR)))