mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-24 03:59:23 +01:00
fix a bug
This commit is contained in:
@@ -8,6 +8,7 @@ from torchvision.transforms import InterpolationMode
|
||||
from torchvision.transforms import functional as F
|
||||
|
||||
from modelscope.preprocessors.image import load_image
|
||||
from modelscope.utils.constant import ModeKeys
|
||||
from .base import OfaBasePreprocessor
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
@@ -57,14 +58,21 @@ def ocr_resize(img, patch_image_size, is_document=False):
|
||||
|
||||
class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
|
||||
|
||||
def __init__(self, cfg, model_dir):
|
||||
def __init__(self,
|
||||
cfg,
|
||||
model_dir,
|
||||
mode=ModeKeys.INFERENCE,
|
||||
*args,
|
||||
**kwargs):
|
||||
"""preprocess the data
|
||||
|
||||
Args:
|
||||
cfg(modelscope.utils.config.ConfigDict) : model config
|
||||
model_dir (str): model path
|
||||
model_dir (str): model path,
|
||||
mode: preprocessor mode (model mode)
|
||||
"""
|
||||
super(OfaOcrRecognitionPreprocessor, self).__init__(cfg, model_dir)
|
||||
super(OfaOcrRecognitionPreprocessor,
|
||||
self).__init__(cfg, model_dir, mode, *args, **kwargs)
|
||||
# Initialize transform
|
||||
if self.cfg.model.imagenet_default_mean_and_std:
|
||||
mean = IMAGENET_DEFAULT_MEAN
|
||||
@@ -87,7 +95,7 @@ class OfaOcrRecognitionPreprocessor(OfaBasePreprocessor):
|
||||
data['image'], Image.Image) else load_image(data['image'])
|
||||
patch_image = self.patch_resize_transform(image)
|
||||
prompt = self.cfg.model.get('prompt', '图片上的文字是什么?')
|
||||
inputs = self.get_inputs(prompt)
|
||||
inputs = self.tokenize_text(prompt)
|
||||
|
||||
sample = {
|
||||
'source': inputs,
|
||||
|
||||
@@ -36,10 +36,10 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
# 'launcher': 'pytorch',
|
||||
'max_epochs': 1,
|
||||
'use_fp16': True,
|
||||
'dataloader': {'batch_size_per_gpu': 4, 'workers_per_gpu': 0},
|
||||
'dataloader': {'batch_size_per_gpu': 1, 'workers_per_gpu': 0},
|
||||
'lr_scheduler': {'name': 'polynomial_decay',
|
||||
'warmup_proportion': 0.01,
|
||||
'lr_end': 1e-07},
|
||||
'lr_endo': 1e-07},
|
||||
'lr_scheduler_hook': {'type': 'LrSchedulerHook', 'by_epoch': False},
|
||||
'optimizer': {'type': 'AdamW', 'lr': 5e-05, 'weight_decay': 0.01},
|
||||
'optimizer_hook': {'type': 'TorchAMPOptimizerHook',
|
||||
@@ -86,11 +86,11 @@ class TestOfaTrainer(unittest.TestCase):
|
||||
train_dataset=MsDataset.load(
|
||||
'coco_2014_caption',
|
||||
namespace='modelscope',
|
||||
split='train[:100]'),
|
||||
split='train[:20]'),
|
||||
eval_dataset=MsDataset.load(
|
||||
'coco_2014_caption',
|
||||
namespace='modelscope',
|
||||
split='validation[:20]'),
|
||||
split='validation[:10]'),
|
||||
metrics=[Metrics.BLEU],
|
||||
cfg_file=config_file)
|
||||
trainer = build_trainer(name=Trainers.ofa_tasks, default_args=args)
|
||||
|
||||
Reference in New Issue
Block a user