mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-25 12:39:25 +01:00
* fix type checking of inputs: the type of np.ndarray inputs was not checked correctly. * add docstring to /preprocessor.py describing the np.ndarray input, its shape and channel order
72 lines
2.6 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import Preprocessor, load_image
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile


@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.ocr_detection)
class OCRDetectionPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, mode: str = ModeKeys.INFERENCE):
        """The base constructor for all ocr detection preprocessors.

        Args:
            model_dir (str): the model directory used to initialize resources
            mode: the mode for the preprocessor.
        """
        super().__init__(mode)
        cfgs = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION))
        self.image_short_side = cfgs.model.inference_kwargs.image_short_side

    def __call__(self, inputs):
        """Process the raw input data.

        Args:
            inputs: one of the following:
                - A string containing an HTTP link pointing to an image
                - A string containing a local path to an image
                - An image already loaded with PIL (PIL.Image.Image) or
                  opencv (np.ndarray, shape HxWx3, RGB channel order)

        Returns:
            outputs: the preprocessed image
        """
        if isinstance(inputs, str):
            img = np.array(load_image(inputs))
        elif isinstance(inputs, PIL.Image.Image):
            img = np.array(inputs)
        elif isinstance(inputs, np.ndarray):
            img = inputs
        else:
            raise TypeError(
                f'inputs should be either str, PIL.Image.Image or '
                f'np.ndarray, but got {type(inputs)}')

        # reverse the channel order (RGB -> BGR)
        img = img[:, :, ::-1]
        # resize so the short side equals image_short_side and both sides
        # round up to multiples of 32
        height, width, _ = img.shape
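        # Worked example (image_short_side value assumed for illustration):
        # with image_short_side = 736, a 480x640 input yields
        # new_height = 736 and
        # new_width = ceil(736 / 480 * 640 / 32) * 32 = 31 * 32 = 992.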
        if height < width:
            new_height = self.image_short_side
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            new_width = self.image_short_side
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        resized_img = cv2.resize(img, (new_width, new_height))
        # subtract the per-channel mean, scale to [0, 1], then convert
        # HWC -> 1xCxHxW float tensor
        resized_img = resized_img - np.array([123.68, 116.78, 103.94],
                                             dtype=np.float32)
        resized_img /= 255.
        resized_img = torch.from_numpy(resized_img).permute(
            2, 0, 1).float().unsqueeze(0)

        result = {'img': resized_img, 'org_shape': [height, width]}
        return result
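For reference, a minimal usage sketch (the model directory path below is hypothetical; it only needs to contain the configuration file that the constructor reads):

import numpy as np

# hypothetical local checkpoint directory containing the model configuration
preprocessor = OCRDetectionPreprocessor('/path/to/ocr-detection-model')
image = np.zeros((480, 640, 3), dtype=np.uint8)  # HxWx3, RGB order
output = preprocessor(image)
# output['img'] is a 1x3xH'xW' float tensor, e.g. torch.Size([1, 3, 736, 992])
# for this input when image_short_side is 736
# output['org_shape'] is [480, 640]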