fix a bug

This commit is contained in:
行嗔
2022-10-24 23:19:23 +08:00
parent 428599f3e5
commit 46c3bdcfe8
2 changed files with 16 additions and 8 deletions

View File

@@ -8,6 +8,7 @@ from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F
from modelscope.preprocessors.image import load_image
from modelscope.utils.constant import ModeKeys
from .base import OfaBasePreprocessor
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
@@ -57,14 +58,21 @@ def ocr_resize(img, patch_image_size, is_document=False):
class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
def __init__(self, cfg, model_dir):
def __init__(self,
cfg,
model_dir,
mode=ModeKeys.INFERENCE,
*args,
**kwargs):
"""preprocess the data
Args:
cfg(modelscope.utils.config.ConfigDict) : model config
model_dir (str): model path
model_dir (str): model path,
mode: preprocessor mode (model mode)
"""
super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir)
super(OfaOcrRecognitionPreprocessor,
self).__init__(cfg, model_dir, mode, *args, **kwargs)
# Initialize transform
if self.cfg.model.imagenet_default_mean_and_std:
mean = IMAGENET_DEFAULT_MEAN
@@ -87,7 +95,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
data['image'], Image.Image) else load_image(data['image'])
patch_image = self.patch_resize_transform(image)
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
inputs = self.get_inputs(prompt)
inputs = self.tokenize_text(prompt)
sample = {
'source': inputs,

View File

@@ -36,10 +36,10 @@ class TestOfaTrainer(unittest.TestCase):
# 'launcher': 'pytorch',
'max_epochs': 1,
'use_fp16': True,
'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
'lr_scheduler': {'name': 'polynomial_decay',
'warmup_proportion': 0.01,
'lr_end': 1e-07},
'lr_endo': 1e-07},
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
@@ -86,11 +86,11 @@ class TestOfaTrainer(unittest.TestCase):
train_dataset=MsDataset.load(
'coco_2014_caption',
namespace='modelscope',
split='train[:100]'),
split='train[:20]'),
eval_dataset=MsDataset.load(
'coco_2014_caption',
namespace='modelscope',
split='validation[:20]'),
split='validation[:10]'),
metrics=[Metrics.BLEU],
cfg_file=config_file)
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args)