Add TorchModel for VLPT and DBNet for OCR detection

Integrate the VLPT and DBNet models into the model module; model selection and post-processing are controlled via configuration.json.
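
A minimal configuration.json sketch for reference (the nesting under "model" and the field names model_type, backbone and inference_kwargs are exactly what the new code reads; the "type" value mirrors the Models registry key and is assumed, as is the overall file layout of a given model card):

    {
        "model": {
            "type": "OCRDetection",
            "model_type": "DBNet",
            "backbone": "resnet18",
            "inference_kwargs": {
                "thresh": 0.2,
                "return_polygon": false,
                "image_short_side": 736
            }
        }
    }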

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11712249
This commit is contained in:
xixing.tj
2023-02-21 22:40:00 +08:00
committed by wenmeng.zwm
parent 386be89f3d
commit a06297e8c4
8 changed files with 919 additions and 72 deletions


@@ -99,6 +99,7 @@ class Models(object):
object_detection_3d = 'object_detection_3d'
ddpm = 'ddpm'
ocr_recognition = 'OCRRecognition'
ocr_detection = 'OCRDetection'
image_quality_assessment_mos = 'image-quality-assessment-mos'
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
m2fp = 'm2fp'
@@ -903,6 +904,7 @@ class Preprocessors(object):
image_sky_change_preprocessor = 'image-sky-change-preprocessor'
image_demoire_preprocessor = 'image-demoire-preprocessor'
ocr_recognition = 'ocr-recognition'
ocr_detection = 'ocr-detection'
bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'


@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .model import OCRDetection
else:
_import_structure = {
'model': ['OCRDetection'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
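
With this stub in place, the class is only materialized on first access; a usage sketch (the import path below is the one the updated pipeline itself uses):

    # Resolved lazily through LazyImportModule; model.py and its torch
    # dependencies are only loaded at this import.
    from modelscope.models.cv.ocr_detection import OCRDetection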


@@ -0,0 +1,77 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict
import numpy as np
import torch
import torch.nn.functional as F
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .modules.dbnet import DBModel, VLPTModel
from .utils import boxes_from_bitmap, polygons_from_bitmap
LOGGER = get_logger()
@MODELS.register_module(Tasks.ocr_detection, module_name=Models.ocr_detection)
class OCRDetection(TorchModel):
def __init__(self, model_dir: str, **kwargs):
"""initialize the ocr recognition model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, **kwargs)
model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
cfgs = Config.from_file(
os.path.join(model_dir, ModelFile.CONFIGURATION))
self.thresh = cfgs.model.inference_kwargs.thresh
self.return_polygon = cfgs.model.inference_kwargs.return_polygon
self.backbone = cfgs.model.backbone
self.detector = None
if self.backbone == 'resnet50':
self.detector = VLPTModel()
elif self.backbone == 'resnet18':
self.detector = DBModel()
else:
            raise TypeError(
                f'detector backbone should be either resnet18 or resnet50, but got {cfgs.model.backbone}'
            )
if model_path != '':
self.detector.load_state_dict(
torch.load(model_path, map_location='cpu'))
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            input (`Dict[str, Any]`): preprocessed data with keys:
                img (`torch.Tensor`): image tensor,
                    shape of each tensor is [3, H, W].
                org_shape (`List`): image original shape,
                    value is [height, width].
        Return:
            results (`torch.Tensor`): bitmap tensor,
                shape of each tensor is [1, H, W].
            org_shape (`List`): image original shape,
                value is [height, width].
        """
pred = self.detector(input['img'])
return {'results': pred, 'org_shape': input['org_shape']}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
pred = inputs['results'][0]
height, width = inputs['org_shape']
segmentation = pred > self.thresh
if self.return_polygon:
boxes, scores = polygons_from_bitmap(pred, segmentation, width,
height)
else:
boxes, scores = boxes_from_bitmap(pred, segmentation, width,
height)
result = {'det_polygons': np.array(boxes)}
return result
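
Once registered, the model is reachable through the standard pipeline API; a usage sketch (not part of this commit — the model id is assumed from the comment in the pipeline diff below, with a 'damo/' hub namespace):

    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    ocr_det = pipeline(
        Tasks.ocr_detection,
        model='damo/cv_resnet18_ocr-detection-db-line-level_damo')
    result = ocr_det('path/to/image.jpg')
    print(result['polygons'])  # [N, 8] array, one [x1, y1, ..., x4, y4] row per box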


@@ -0,0 +1,451 @@
# ------------------------------------------------------------------------------
# Part of the implementation is adapted from ViLT,
# made publicly available under the Apache License 2.0 at https://github.com/dandelin/ViLT.
# ------------------------------------------------------------------------------
import math
import os
import sys
import torch
import torch.nn as nn
BatchNorm2d = nn.BatchNorm2d
def constant_init(module, constant, bias=0):
nn.init.constant_(module.weight, constant)
    if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
super(BasicBlock, self).__init__()
self.with_dcn = dcn is not None
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = dcn.get('fallback_on_stride', False)
self.with_modulated_dcn = dcn.get('modulated', False)
# self.conv2 = conv3x3(planes, planes)
if not self.with_dcn or fallback_on_stride:
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, bias=False)
else:
deformable_groups = dcn.get('deformable_groups', 1)
if not self.with_modulated_dcn:
from assets.ops.dcn import DeformConv
conv_op = DeformConv
offset_channels = 18
else:
from assets.ops.dcn import ModulatedDeformConv
conv_op = ModulatedDeformConv
offset_channels = 27
self.conv2_offset = nn.Conv2d(
planes,
deformable_groups * offset_channels,
kernel_size=3,
padding=1)
self.conv2 = conv_op(
planes,
planes,
kernel_size=3,
padding=1,
deformable_groups=deformable_groups,
bias=False)
self.bn2 = BatchNorm2d(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
# out = self.conv2(out)
if not self.with_dcn:
out = self.conv2(out)
elif self.with_modulated_dcn:
offset_mask = self.conv2_offset(out)
offset = offset_mask[:, :18, :, :]
mask = offset_mask[:, -9:, :, :].sigmoid()
out = self.conv2(out, offset, mask)
else:
offset = self.conv2_offset(out)
out = self.conv2(out, offset)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
super(Bottleneck, self).__init__()
self.with_dcn = dcn is not None
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = BatchNorm2d(planes)
fallback_on_stride = False
self.with_modulated_dcn = False
if self.with_dcn:
fallback_on_stride = dcn.get('fallback_on_stride', False)
self.with_modulated_dcn = dcn.get('modulated', False)
if not self.with_dcn or fallback_on_stride:
self.conv2 = nn.Conv2d(
planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
else:
deformable_groups = dcn.get('deformable_groups', 1)
if not self.with_modulated_dcn:
from assets.ops.dcn import DeformConv
conv_op = DeformConv
offset_channels = 18
else:
from assets.ops.dcn import ModulatedDeformConv
conv_op = ModulatedDeformConv
offset_channels = 27
self.conv2_offset = nn.Conv2d(
planes,
deformable_groups * offset_channels,
kernel_size=3,
padding=1)
self.conv2 = conv_op(
planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
deformable_groups=deformable_groups,
bias=False)
self.bn2 = BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
self.bn3 = BatchNorm2d(planes * 4)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
self.dcn = dcn
self.with_dcn = dcn is not None
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
# out = self.conv2(out)
if not self.with_dcn:
out = self.conv2(out)
elif self.with_modulated_dcn:
offset_mask = self.conv2_offset(out)
offset = offset_mask[:, :18, :, :]
mask = offset_mask[:, -9:, :, :].sigmoid()
out = self.conv2(out, offset, mask)
else:
offset = self.conv2_offset(out)
out = self.conv2(out, offset)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self,
block,
layers,
num_classes=1000,
dcn=None,
stage_with_dcn=(False, False, False, False)):
self.dcn = dcn
self.stage_with_dcn = stage_with_dcn
self.inplanes = 64
super(ResNet, self).__init__()
self.conv1 = nn.Conv2d(
3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(
block, 128, layers[1], stride=2, dcn=dcn)
self.layer3 = self._make_layer(
block, 256, layers[2], stride=2, dcn=dcn)
self.layer4 = self._make_layer(
block, 512, layers[3], stride=2, dcn=dcn)
# self.avgpool = nn.AvgPool2d(7, stride=1)
# self.fc = nn.Linear(512 * block.expansion, num_classes)
# self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
if self.dcn is not None:
for m in self.modules():
if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
if hasattr(m, 'conv2_offset'):
constant_init(m.conv2_offset, 0)
def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(
self.inplanes,
planes * block.expansion,
kernel_size=1,
stride=stride,
bias=False),
BatchNorm2d(planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, stride, downsample, dcn=dcn))
self.inplanes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(self.inplanes, planes, dcn=dcn))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x2 = self.layer1(x)
x3 = self.layer2(x2)
x4 = self.layer3(x3)
x5 = self.layer4(x4)
return x2, x3, x4, x5
class SegDetector(nn.Module):
def __init__(self,
in_channels=[64, 128, 256, 512],
inner_channels=256,
k=10,
bias=False,
adaptive=False,
smooth=False,
serial=False,
*args,
**kwargs):
'''
bias: Whether conv layers have bias or not.
adaptive: Whether to use adaptive threshold training or not.
smooth: If true, use bilinear instead of deconv.
serial: If true, thresh prediction will combine segmentation result as input.
'''
super(SegDetector, self).__init__()
self.k = k
self.serial = serial
self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
self.up3 = nn.Upsample(scale_factor=2, mode='nearest')
self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias)
self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias)
self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias)
self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias)
self.out5 = nn.Sequential(
nn.Conv2d(
inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.Upsample(scale_factor=8, mode='nearest'))
self.out4 = nn.Sequential(
nn.Conv2d(
inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.Upsample(scale_factor=4, mode='nearest'))
self.out3 = nn.Sequential(
nn.Conv2d(
inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
nn.Upsample(scale_factor=2, mode='nearest'))
self.out2 = nn.Conv2d(
inner_channels, inner_channels // 4, 3, padding=1, bias=bias)
self.binarize = nn.Sequential(
nn.Conv2d(
inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
self.binarize.apply(self.weights_init)
self.adaptive = adaptive
if adaptive:
self.thresh = self._init_thresh(
inner_channels, serial=serial, smooth=smooth, bias=bias)
self.thresh.apply(self.weights_init)
self.in5.apply(self.weights_init)
self.in4.apply(self.weights_init)
self.in3.apply(self.weights_init)
self.in2.apply(self.weights_init)
self.out5.apply(self.weights_init)
self.out4.apply(self.weights_init)
self.out3.apply(self.weights_init)
self.out2.apply(self.weights_init)
def weights_init(self, m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.kaiming_normal_(m.weight.data)
elif classname.find('BatchNorm') != -1:
m.weight.data.fill_(1.)
m.bias.data.fill_(1e-4)
def _init_thresh(self,
inner_channels,
serial=False,
smooth=False,
bias=False):
in_channels = inner_channels
if serial:
in_channels += 1
self.thresh = nn.Sequential(
nn.Conv2d(
in_channels, inner_channels // 4, 3, padding=1, bias=bias),
BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
self._init_upsample(
inner_channels // 4,
inner_channels // 4,
smooth=smooth,
bias=bias), BatchNorm2d(inner_channels // 4),
nn.ReLU(inplace=True),
self._init_upsample(
inner_channels // 4, 1, smooth=smooth, bias=bias),
nn.Sigmoid())
return self.thresh
def _init_upsample(self,
in_channels,
out_channels,
smooth=False,
bias=False):
if smooth:
inter_out_channels = out_channels
if out_channels == 1:
inter_out_channels = in_channels
module_list = [
nn.Upsample(scale_factor=2, mode='nearest'),
nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)
]
if out_channels == 1:
module_list.append(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=1,
bias=True))
            return nn.Sequential(*module_list)
else:
return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)
def forward(self, features, gt=None, masks=None, training=False):
c2, c3, c4, c5 = features
in5 = self.in5(c5)
in4 = self.in4(c4)
in3 = self.in3(c3)
in2 = self.in2(c2)
out4 = self.up5(in5) + in4 # 1/16
out3 = self.up4(out4) + in3 # 1/8
out2 = self.up3(out3) + in2 # 1/4
p5 = self.out5(in5)
p4 = self.out4(out4)
p3 = self.out3(out3)
p2 = self.out2(out2)
fuse = torch.cat((p5, p4, p3, p2), 1)
# this is the pred module, not binarization module;
# We do not correct the name due to the trained model.
binary = self.binarize(fuse)
return binary
def step_function(self, x, y):
return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
class VLPTModel(nn.Module):
def __init__(self, *args, **kwargs):
"""
VLPT-STD pretrained DBNet-resnet50 model,
paper reference: https://arxiv.org/pdf/2204.13867.pdf
"""
super(VLPTModel, self).__init__()
self.backbone = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
self.decoder = SegDetector(
in_channels=[256, 512, 1024, 2048], adaptive=True, k=50, **kwargs)
def forward(self, x):
return self.decoder(self.backbone(x))
class DBModel(nn.Module):
def __init__(self, *args, **kwargs):
"""
DBNet-resnet18 model without deformable conv,
paper reference: https://arxiv.org/pdf/1911.08947.pdf
"""
super(DBModel, self).__init__()
self.backbone = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
self.decoder = SegDetector(
in_channels=[64, 128, 256, 512], adaptive=True, k=50, **kwargs)
def forward(self, x):
return self.decoder(self.backbone(x))
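
As a quick sanity check (a sketch, not part of the commit): both detectors map an RGB tensor whose sides are multiples of 32 to a single-channel probability map of the same spatial size, since the decoder fuses 1/4-scale features and the binarize head upsamples by 4:

    import torch

    model = DBModel()
    model.eval()
    x = torch.randn(1, 3, 736, 736)  # H and W must be multiples of 32
    with torch.no_grad():
        prob_map = model(x)
    print(prob_map.shape)  # torch.Size([1, 1, 736, 736])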


@@ -0,0 +1,69 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Any, Dict
import cv2
import numpy as np
import PIL
import torch
from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import Preprocessor, load_image
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile
@PREPROCESSORS.register_module(
Fields.cv, module_name=Preprocessors.ocr_detection)
class OCRDetectionPreprocessor(Preprocessor):
def __init__(self, model_dir: str, mode: str = ModeKeys.INFERENCE):
"""The base constructor for all ocr recognition preprocessors.
Args:
model_dir (str): model directory to initialize some resource
mode: The mode for the preprocessor.
"""
super().__init__(mode)
cfgs = Config.from_file(
os.path.join(model_dir, ModelFile.CONFIGURATION))
self.image_short_side = cfgs.model.inference_kwargs.image_short_side
def __call__(self, inputs):
"""process the raw input data
Args:
inputs:
- A string containing an HTTP link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL or opencv directly
Returns:
outputs: the preprocessed image
"""
if isinstance(inputs, str):
img = np.array(load_image(inputs))
elif isinstance(inputs, PIL.Image.Image):
img = np.array(inputs)
else:
raise TypeError(
                f'inputs should be either a str or PIL.Image, but got {type(inputs)}'
)
img = img[:, :, ::-1]
height, width, _ = img.shape
if height < width:
new_height = self.image_short_side
new_width = int(math.ceil(new_height / height * width / 32) * 32)
else:
new_width = self.image_short_side
new_height = int(math.ceil(new_width / width * height / 32) * 32)
resized_img = cv2.resize(img, (new_width, new_height))
resized_img = resized_img - np.array([123.68, 116.78, 103.94],
dtype=np.float32)
resized_img /= 255.
resized_img = torch.from_numpy(resized_img).permute(
2, 0, 1).float().unsqueeze(0)
result = {'img': resized_img, 'org_shape': [height, width]}
return result
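
To make the resize rule concrete (a sketch; 736 is the image_short_side value both DB configs in this commit use): the short side is pinned to image_short_side and the long side is scaled proportionally, then rounded up to a multiple of 32:

    import math

    height, width = 1080, 1920          # landscape input, height is the short side
    image_short_side = 736
    new_height = image_short_side
    new_width = int(math.ceil(new_height / height * width / 32) * 32)
    print(new_height, new_width)        # 736 1312 -- both sides are multiples of 32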


@@ -0,0 +1,256 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon
def rboxes_to_polygons(rboxes):
"""
Convert rboxes to polygons
ARGS
`rboxes`: [n, 5]
RETURN
`polygons`: [n, 8]
"""
theta = rboxes[:, 4:5]
cxcy = rboxes[:, :2]
half_w = rboxes[:, 2:3] / 2.
half_h = rboxes[:, 3:4] / 2.
v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w])
v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h])
p1 = cxcy - v1 - v2
p2 = cxcy + v1 - v2
p3 = cxcy + v1 + v2
p4 = cxcy - v1 + v2
polygons = np.hstack([p1, p2, p3, p4])
return polygons
def cal_width(box):
pd1 = point_dist(box[0], box[1], box[2], box[3])
pd2 = point_dist(box[4], box[5], box[6], box[7])
return (pd1 + pd2) / 2
def point_dist(x1, y1, x2, y2):
return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1))
def draw_polygons(img, polygons):
for p in polygons.tolist():
p = [int(o) for o in p]
cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1)
cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1)
cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1)
cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1)
return img
def nms_python(boxes):
boxes = sorted(boxes, key=lambda x: -x[8])
nms_flag = [True] * len(boxes)
for i, a in enumerate(boxes):
if not nms_flag[i]:
continue
else:
for j, b in enumerate(boxes):
                if j <= i:
continue
if not nms_flag[j]:
continue
score_a = a[8]
score_b = b[8]
rbox_a = polygon2rbox(a[:8])
rbox_b = polygon2rbox(b[:8])
if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox(
rbox_b[:2], rbox_a):
if score_a > score_b:
nms_flag[j] = False
boxes_nms = []
for i, box in enumerate(boxes):
if nms_flag[i]:
boxes_nms.append(box)
return boxes_nms
def point_in_rbox(c, rbox):
cx0, cy0 = c[0], c[1]
cx1, cy1 = rbox[0], rbox[1]
w, h = rbox[2], rbox[3]
theta = rbox[4]
dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta))
dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta))
return ((dist_x < w / 2.0) and (dist_y < h / 2.0))
def polygon2rbox(polygon):
x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6]
y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7]
c_x = (x1 + x2 + x3 + x4) / 4
c_y = (y1 + y2 + y3 + y4) / 4
w1 = point_dist(x1, y1, x2, y2)
w2 = point_dist(x3, y3, x4, y4)
h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2)
h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4)
h = h1 + h2
w = (w1 + w2) / 2
theta1 = np.arctan2(y2 - y1, x2 - x1)
theta2 = np.arctan2(y3 - y4, x3 - x4)
theta = (theta1 + theta2) / 2.0
return [c_x, c_y, w, h, theta]
def point_line_dist(px, py, x1, y1, x2, y2):
eps = 1e-6
dx = x2 - x1
dy = y2 - y1
div = np.sqrt(dx * dx + dy * dy) + eps
dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div
return dist
# Part of the implementation is borrowed and modified from DB,
# publicly available at https://github.com/MhLiao/DB.
def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height):
"""
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
assert _bitmap.size(0) == 1
bitmap = _bitmap.cpu().numpy()[0]
pred = pred.cpu().detach().numpy()[0]
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:100]:
epsilon = 0.01 * cv2.arcLength(contour, True)
approx = cv2.approxPolyDP(contour, epsilon, True)
points = approx.reshape((-1, 2))
if points.shape[0] < 4:
continue
score = box_score_fast(pred, points.reshape(-1, 2))
        if score < 0.7:
continue
if points.shape[0] > 2:
box = unclip(points, unclip_ratio=2.0)
if len(box) > 1:
continue
else:
continue
box = box.reshape(-1, 2)
_, sside = get_mini_boxes(box.reshape((-1, 1, 2)))
if sside < 3 + 2:
continue
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.tolist())
scores.append(score)
return boxes, scores
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
"""
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
assert _bitmap.size(0) == 1
bitmap = _bitmap.cpu().numpy()[0]
pred = pred.cpu().detach().numpy()[0]
height, width = bitmap.shape
boxes = []
scores = []
contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
for contour in contours[:100]:
points, sside = get_mini_boxes(contour)
if sside < 3:
continue
points = np.array(points)
score = box_score_fast(pred, points.reshape(-1, 2))
        if score < 0.3:
continue
box = unclip(points, unclip_ratio=1.5).reshape(-1, 1, 2)
box, sside = get_mini_boxes(box)
if sside < 3 + 2:
continue
box = np.array(box).astype(np.int32)
if not isinstance(dest_width, int):
dest_width = dest_width.item()
dest_height = dest_height.item()
box[:, 0] = np.clip(
np.round(box[:, 0] / width * dest_width), 0, dest_width)
box[:, 1] = np.clip(
np.round(box[:, 1] / height * dest_height), 0, dest_height)
boxes.append(box.reshape(-1).tolist())
scores.append(score)
return boxes, scores
def box_score_fast(bitmap, _box):
h, w = bitmap.shape[:2]
box = _box.copy()
xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)
mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
box[:, 0] = box[:, 0] - xmin
box[:, 1] = box[:, 1] - ymin
cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
def unclip(box, unclip_ratio=1.5):
poly = Polygon(box)
distance = poly.area * unclip_ratio / poly.length
offset = pyclipper.PyclipperOffset()
offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance))
return expanded
def get_mini_boxes(contour):
bounding_box = cv2.minAreaRect(contour)
points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
index_1, index_2, index_3, index_4 = 0, 1, 2, 3
if points[1][1] > points[0][1]:
index_1 = 0
index_4 = 1
else:
index_1 = 1
index_4 = 0
if points[3][1] > points[2][1]:
index_2 = 2
index_3 = 3
else:
index_2 = 3
index_3 = 2
box = [points[index_1], points[index_2], points[index_3], points[index_4]]
return box, min(bounding_box[1])
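
For intuition about unclip (a sketch using the helpers above, not part of the commit): following DB, a shrunk text kernel is grown back by an offset of area * ratio / perimeter; for a 10x10 square with unclip_ratio=1.5 that is 100 * 1.5 / 40 = 3.75 px per side:

    import numpy as np

    box = np.array([[0, 0], [10, 0], [10, 10], [0, 10]])
    expanded = unclip(box, unclip_ratio=1.5)
    # one expanded polygon, roughly a 17.5 x 17.5 square with rounded corners
    print(expanded.shape)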


@@ -1,5 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
import os.path as osp
from typing import Any, Dict
@@ -9,11 +10,12 @@ import tensorflow as tf
import torch
from modelscope.metainfo import Pipelines
from modelscope.models.cv.ocr_detection import OCRDetection
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_vlpt import DBModel, VLPTModel
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import device_placement
from modelscope.utils.logger import get_logger
@@ -48,7 +50,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6,
@PIPELINES.register_module(
Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
class OCRDetectionPipeline(Pipeline):
""" OCR Recognition Pipeline.
""" OCR Detection Pipeline.
Example:
@@ -82,39 +84,19 @@ class OCRDetectionPipeline(Pipeline):
Args:
model: model id on modelscope hub.
"""
assert isinstance(model, str), 'model must be a single str'
super().__init__(model=model, **kwargs)
if 'vlpt' in self.model:
# for model cv_resnet50_ocr-detection-vlpt
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model from {model_path}')
logger.info(f'loading model from dir {model}')
cfgs = Config.from_file(os.path.join(model, ModelFile.CONFIGURATION))
if hasattr(cfgs, 'model') and hasattr(cfgs.model, 'model_type'):
self.model_type = cfgs.model.model_type
else:
self.model_type = 'SegLink++'
self.thresh = 0.3
self.image_short_side = 736
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.infer_model = VLPTModel().to(self.device)
self.infer_model.eval()
checkpoint = torch.load(model_path, map_location=self.device)
if 'state_dict' in checkpoint:
self.infer_model.load_state_dict(checkpoint['state_dict'])
else:
self.infer_model.load_state_dict(checkpoint)
elif 'db' in self.model:
# for model cv_resnet18_ocr-detection-db-line-level_damo (original dbnet)
model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
logger.info(f'loading model from {model_path}')
self.thresh = 0.2
self.image_short_side = 736
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu')
self.infer_model = DBModel().to(self.device)
self.infer_model.eval()
checkpoint = torch.load(model_path, map_location=self.device)
if 'state_dict' in checkpoint:
self.infer_model.load_state_dict(checkpoint['state_dict'])
else:
self.infer_model.load_state_dict(checkpoint)
if self.model_type == 'DBNet':
self.ocr_detector = self.model.to(self.device)
self.ocr_detector.eval()
logger.info('loading model done')
else:
# for model seglink++
tf.reset_default_graph()
@@ -191,26 +173,29 @@ class OCRDetectionPipeline(Pipeline):
variable_averages.variables_to_restore())
model_loader.restore(sess, model_path)
def preprocess(self, input: Input) -> Dict[str, Any]:
if 'vlpt' in self.model or 'db' in self.model:
img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
height, width, _ = img.shape
if height < width:
new_height = self.image_short_side
new_width = int(
math.ceil(new_height / height * width / 32) * 32)
else:
new_width = self.image_short_side
new_height = int(
math.ceil(new_width / width * height / 32) * 32)
resized_img = cv2.resize(img, (new_width, new_height))
resized_img = resized_img - np.array([123.68, 116.78, 103.94],
dtype=np.float32)
resized_img /= 255.
resized_img = torch.from_numpy(resized_img).permute(
2, 0, 1).float().unsqueeze(0)
            result = {'img': resized_img, 'org_shape': [height, width]}
    def __call__(self, input, **kwargs):
        """
        Detect text instances in the input image.
        Args:
            input (`Image`):
                The pipeline handles three types of images:
                - A string containing an HTTP link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL or opencv directly
                The pipeline currently supports single image input.
        Return:
            An array of contour polygons for the N detected text instances;
            each row is [x1, y1, x2, y2, x3, y3, x4, y4, ...].
        """
        return super().__call__(input, **kwargs)
def preprocess(self, input: Input) -> Dict[str, Any]:
if self.model_type == 'DBNet':
result = self.preprocessor(input)
return result
else:
img = LoadImage.convert_to_ndarray(input)
@@ -235,9 +220,9 @@ class OCRDetectionPipeline(Pipeline):
return result
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
if 'vlpt' in self.model or 'db' in self.model:
pred = self.infer_model(input['img'])
return {'results': pred, 'org_shape': input['org_shape']}
if self.model_type == 'DBNet':
outputs = self.ocr_detector(input)
return outputs
else:
with self._graph.as_default():
with self._session.as_default():
@@ -247,23 +232,8 @@ class OCRDetectionPipeline(Pipeline):
return sess_outputs
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if 'vlpt' in self.model:
pred = inputs['results'][0]
height, width = inputs['org_shape']
segmentation = pred > self.thresh
boxes, scores = polygons_from_bitmap(pred, segmentation, width,
height)
result = {OutputKeys.POLYGONS: np.array(boxes)}
return result
elif 'db' in self.model:
pred = inputs['results'][0]
height, width = inputs['org_shape']
segmentation = pred > self.thresh
boxes, scores = boxes_from_bitmap(pred, segmentation, width,
height)
result = {OutputKeys.POLYGONS: np.array(boxes)}
if self.model_type == 'DBNet':
result = {OutputKeys.POLYGONS: inputs['det_polygons']}
return result
else:
rboxes = inputs['combined_rboxes'][0]