Files
modelscope/modelscope/models/cv/ocr_detection/preprocessor.py
xfanplus cb35b9c04e fix type checking of inputs for np.array (#271)
* fix type checking of inputs for np.array

The np.ndarray input type was not being checked correctly.

* add docstr for /preprocessor.py

add docstr about np.ndarray, shape and order
2023-05-10 10:34:42 +08:00

72 lines
2.6 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Any, Dict
import cv2
import numpy as np
import PIL
import torch
from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import Preprocessor, load_image
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile
@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.ocr_detection)
class OCRDetectionPreprocessor(Preprocessor):
    """Preprocessor for OCR detection: loads an image, resizes it so the
    short side matches the configured length (long side rounded up to a
    multiple of 32), mean-subtracts, scales to [0, 1], and packs it into a
    NCHW float tensor.
    """

    def __init__(self, model_dir: str, mode: str = ModeKeys.INFERENCE):
        """Initialize the OCR detection preprocessor.

        Args:
            model_dir (str): Model directory containing the configuration
                file from which ``image_short_side`` is read.
            mode: The mode for the preprocessor.
        """
        super().__init__(mode)
        config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
        cfgs = Config.from_file(config_path)
        # Target length of the image's shorter side after resizing.
        self.image_short_side = cfgs.model.inference_kwargs.image_short_side

    def __call__(self, inputs):
        """Preprocess the raw input into a model-ready tensor.

        Args:
            inputs:
                - A string containing an HTTP link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL(PIL.Image.Image) or opencv(np.ndarray) directly, 3 channels RGB

        Returns:
            dict: ``{'img': tensor, 'org_shape': [height, width]}`` where
            ``tensor`` is a (1, 3, H, W) float32 tensor and ``org_shape``
            records the pre-resize dimensions.
        """
        if isinstance(inputs, str):
            img = np.array(load_image(inputs))
        elif isinstance(inputs, PIL.Image.Image):
            img = np.array(inputs)
        elif isinstance(inputs, np.ndarray):
            img = inputs
        else:
            raise TypeError(
                f'inputs should be either str, PIL.Image, np.array, but got {type(inputs)}'
            )

        # Reverse channel order (RGB -> BGR) before mean subtraction.
        img = img[:, :, ::-1]
        height, width, _ = img.shape

        # Scale the short side to the configured length; round the long
        # side up to the nearest multiple of 32 (network stride constraint).
        short = self.image_short_side
        if height < width:
            new_height = short
            new_width = int(math.ceil(short / height * width / 32) * 32)
        else:
            new_width = short
            new_height = int(math.ceil(short / width * height / 32) * 32)

        resized = cv2.resize(img, (new_width, new_height))

        # Mean subtraction promotes the uint8 image to float32; then scale
        # into [0, 1].
        channel_mean = np.array([123.68, 116.78, 103.94], dtype=np.float32)
        normalized = (resized - channel_mean) / 255.

        # HWC -> NCHW.
        tensor = torch.from_numpy(normalized).float().permute(
            2, 0, 1).unsqueeze(0)

        return {'img': tensor, 'org_shape': [height, width]}