mirror of https://github.com/modelscope/modelscope.git
synced 2025-12-25 04:29:22 +01:00
add TorchModel for vlpt and dbnet for ocr detection
Integrate the VLPT and DBNet models into the model module; the model choice and its post-processing are controlled through configuration.json. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11712249
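Since model selection and post-processing now hinge on configuration.json, here is a minimal sketch of the fields the new code reads. The values and the "type" key are illustrative; the concrete file shipped with each model may carry more entries. cfgs.model.backbone picks VLPTModel for 'resnet50' and DBModel for 'resnet18', and inference_kwargs feed the binarization threshold, polygon mode, and resize target:

# Illustrative configuration.json fragment (field names taken from the code below):
#
# "model": {
#     "type": "OCRDetection",
#     "model_type": "DBNet",
#     "backbone": "resnet18",
#     "inference_kwargs": {
#         "thresh": 0.2,
#         "return_polygon": false,
#         "image_short_side": 736
#     }
# }
from modelscope.utils.config import Config

cfgs = Config.from_file('/path/to/model_dir/configuration.json')  # placeholder path
print(cfgs.model.backbone, cfgs.model.inference_kwargs.thresh)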
modelscope/metainfo.py

@@ -99,6 +99,7 @@ class Models(object):
     object_detection_3d = 'object_detection_3d'
     ddpm = 'ddpm'
     ocr_recognition = 'OCRRecognition'
+    ocr_detection = 'OCRDetection'
     image_quality_assessment_mos = 'image-quality-assessment-mos'
     image_quality_assessment_degradation = 'image-quality-assessment-degradation'
     m2fp = 'm2fp'

@@ -903,6 +904,7 @@ class Preprocessors(object):
     image_sky_change_preprocessor = 'image-sky-change-preprocessor'
     image_demoire_preprocessor = 'image-demoire-preprocessor'
     ocr_recognition = 'ocr-recognition'
+    ocr_detection = 'ocr-detection'
     bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
     nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'
modelscope/models/cv/ocr_detection/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .model import OCRDetection

else:
    _import_structure = {
        'model': ['OCRDetection'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
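The LazyImportModule registration above keeps the package import cheap: the submodule body (and its torch dependencies) only executes when OCRDetection is first touched. A one-line sketch of the resulting import path:

from modelscope.models.cv.ocr_detection import OCRDetection  # module body loads on first attribute access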
modelscope/models/cv/ocr_detection/model.py (new file, 77 lines)
@@ -0,0 +1,77 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict

import numpy as np
import torch
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .modules.dbnet import DBModel, VLPTModel
from .utils import boxes_from_bitmap, polygons_from_bitmap

LOGGER = get_logger()


@MODELS.register_module(Tasks.ocr_detection, module_name=Models.ocr_detection)
class OCRDetection(TorchModel):

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the ocr detection model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, **kwargs)

        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        cfgs = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION))
        self.thresh = cfgs.model.inference_kwargs.thresh
        self.return_polygon = cfgs.model.inference_kwargs.return_polygon
        self.backbone = cfgs.model.backbone
        self.detector = None
        if self.backbone == 'resnet50':
            self.detector = VLPTModel()
        elif self.backbone == 'resnet18':
            self.detector = DBModel()
        else:
            raise TypeError(
                f'detector backbone should be either resnet18 or resnet50, but got {cfgs.model.backbone}'
            )
        if model_path != '':
            self.detector.load_state_dict(
                torch.load(model_path, map_location='cpu'))

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """
        Args:
            img (`torch.Tensor`): image tensor,
                shape of each tensor is [3, H, W].

        Return:
            results (`torch.Tensor`): bitmap tensor,
                shape of each tensor is [1, H, W].
            org_shape (`List`): image original shape,
                value is [height, width].
        """
        pred = self.detector(input['img'])
        return {'results': pred, 'org_shape': input['org_shape']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        pred = inputs['results'][0]
        height, width = inputs['org_shape']
        segmentation = pred > self.thresh
        if self.return_polygon:
            boxes, scores = polygons_from_bitmap(pred, segmentation, width,
                                                 height)
        else:
            boxes, scores = boxes_from_bitmap(pred, segmentation, width,
                                              height)
        result = {'det_polygons': np.array(boxes)}
        return result
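For orientation, a hedged sketch of exercising this class directly, outside the pipeline; the directory is a placeholder and must contain the torch checkpoint plus the configuration.json sketched earlier:

import torch
from modelscope.models.cv.ocr_detection import OCRDetection

model = OCRDetection(model_dir='/path/to/model_dir')  # placeholder snapshot dir
model.detector.eval()
batch = {'img': torch.randn(1, 3, 736, 736), 'org_shape': [720, 1280]}
with torch.no_grad():
    outputs = model.forward(batch)   # {'results': probability map, 'org_shape': [720, 1280]}
dets = model.postprocess(outputs)    # {'det_polygons': np.ndarray of boxes}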
modelscope/models/cv/ocr_detection/modules/dbnet.py (new file, 451 lines)
@@ -0,0 +1,451 @@
# ------------------------------------------------------------------------------
# Part of the implementation is adopted from ViLT,
# made publicly available under the Apache License 2.0 at https://github.com/dandelin/ViLT.
# ------------------------------------------------------------------------------

import math
import os
import sys

import torch
import torch.nn as nn

BatchNorm2d = nn.BatchNorm2d


def constant_init(module, constant, bias=0):
    nn.init.constant_(module.weight, constant)
    # guard against modules whose bias attribute exists but is None
    if hasattr(module, 'bias') and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
        super(BasicBlock, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        # self.conv2 = conv3x3(planes, planes)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(
                planes, planes, kernel_size=3, padding=1, bias=False)
        else:
            deformable_groups = dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                from assets.ops.dcn import DeformConv
                conv_op = DeformConv
                offset_channels = 18
            else:
                from assets.ops.dcn import ModulatedDeformConv
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                deformable_groups * offset_channels,
                kernel_size=3,
                padding=1)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                padding=1,
                deformable_groups=deformable_groups,
                bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # out = self.conv2(out)
        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
        super(Bottleneck, self).__init__()
        self.with_dcn = dcn is not None
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        fallback_on_stride = False
        self.with_modulated_dcn = False
        if self.with_dcn:
            fallback_on_stride = dcn.get('fallback_on_stride', False)
            self.with_modulated_dcn = dcn.get('modulated', False)
        if not self.with_dcn or fallback_on_stride:
            self.conv2 = nn.Conv2d(
                planes,
                planes,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False)
        else:
            deformable_groups = dcn.get('deformable_groups', 1)
            if not self.with_modulated_dcn:
                from assets.ops.dcn import DeformConv
                conv_op = DeformConv
                offset_channels = 18
            else:
                from assets.ops.dcn import ModulatedDeformConv
                conv_op = ModulatedDeformConv
                offset_channels = 27
            self.conv2_offset = nn.Conv2d(
                planes,
                deformable_groups * offset_channels,
                kernel_size=3,
                padding=1)
            self.conv2 = conv_op(
                planes,
                planes,
                kernel_size=3,
                padding=1,
                stride=stride,
                deformable_groups=deformable_groups,
                bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride
        self.dcn = dcn
        self.with_dcn = dcn is not None

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # out = self.conv2(out)
        if not self.with_dcn:
            out = self.conv2(out)
        elif self.with_modulated_dcn:
            offset_mask = self.conv2_offset(out)
            offset = offset_mask[:, :18, :, :]
            mask = offset_mask[:, -9:, :, :].sigmoid()
            out = self.conv2(out, offset, mask)
        else:
            offset = self.conv2_offset(out)
            out = self.conv2(out, offset)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 num_classes=1000,
                 dcn=None,
                 stage_with_dcn=(False, False, False, False)):
        self.dcn = dcn
        self.stage_with_dcn = stage_with_dcn
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dcn=dcn)
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dcn=dcn)
        self.layer4 = self._make_layer(
            block, 512, layers[3], stride=2, dcn=dcn)
        # self.avgpool = nn.AvgPool2d(7, stride=1)
        # self.fc = nn.Linear(512 * block.expansion, num_classes)

        # self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
        if self.dcn is not None:
            for m in self.modules():
                if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
                    if hasattr(m, 'conv2_offset'):
                        constant_init(m.conv2_offset, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, dcn=dcn))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, dcn=dcn))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x2 = self.layer1(x)
        x3 = self.layer2(x2)
        x4 = self.layer3(x3)
        x5 = self.layer4(x4)

        return x2, x3, x4, x5


class SegDetector(nn.Module):

    def __init__(self,
                 in_channels=[64, 128, 256, 512],
                 inner_channels=256,
                 k=10,
                 bias=False,
                 adaptive=False,
                 smooth=False,
                 serial=False,
                 *args,
                 **kwargs):
        '''
        bias: Whether conv layers have bias or not.
        adaptive: Whether to use adaptive threshold training or not.
        smooth: If true, use bilinear instead of deconv.
        serial: If true, thresh prediction will combine segmentation result as input.
        '''
        super(SegDetector, self).__init__()
        self.k = k
        self.serial = serial
        self.up5 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.up3 = nn.Upsample(scale_factor=2, mode='nearest')

        self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias)
        self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias)
        self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias)
        self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias)

        self.out5 = nn.Sequential(
            nn.Conv2d(
                inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.Upsample(scale_factor=8, mode='nearest'))
        self.out4 = nn.Sequential(
            nn.Conv2d(
                inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.Upsample(scale_factor=4, mode='nearest'))
        self.out3 = nn.Sequential(
            nn.Conv2d(
                inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            nn.Upsample(scale_factor=2, mode='nearest'))
        self.out2 = nn.Conv2d(
            inner_channels, inner_channels // 4, 3, padding=1, bias=bias)

        self.binarize = nn.Sequential(
            nn.Conv2d(
                inner_channels, inner_channels // 4, 3, padding=1, bias=bias),
            BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2),
            BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid())
        self.binarize.apply(self.weights_init)

        self.adaptive = adaptive
        if adaptive:
            self.thresh = self._init_thresh(
                inner_channels, serial=serial, smooth=smooth, bias=bias)
            self.thresh.apply(self.weights_init)

        self.in5.apply(self.weights_init)
        self.in4.apply(self.weights_init)
        self.in3.apply(self.weights_init)
        self.in2.apply(self.weights_init)
        self.out5.apply(self.weights_init)
        self.out4.apply(self.weights_init)
        self.out3.apply(self.weights_init)
        self.out2.apply(self.weights_init)

    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    def _init_thresh(self,
                     inner_channels,
                     serial=False,
                     smooth=False,
                     bias=False):
        in_channels = inner_channels
        if serial:
            in_channels += 1
        self.thresh = nn.Sequential(
            nn.Conv2d(
                in_channels, inner_channels // 4, 3, padding=1, bias=bias),
            BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True),
            self._init_upsample(
                inner_channels // 4,
                inner_channels // 4,
                smooth=smooth,
                bias=bias), BatchNorm2d(inner_channels // 4),
            nn.ReLU(inplace=True),
            self._init_upsample(
                inner_channels // 4, 1, smooth=smooth, bias=bias),
            nn.Sigmoid())
        return self.thresh

    def _init_upsample(self,
                       in_channels,
                       out_channels,
                       smooth=False,
                       bias=False):
        if smooth:
            inter_out_channels = out_channels
            if out_channels == 1:
                inter_out_channels = in_channels
            module_list = [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)
            ]
            if out_channels == 1:
                module_list.append(
                    nn.Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=1,
                        stride=1,
                        padding=1,
                        bias=True))

            # nn.Sequential takes modules as positional args, so unpack the list
            return nn.Sequential(*module_list)
        else:
            return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

    def forward(self, features, gt=None, masks=None, training=False):
        c2, c3, c4, c5 = features
        in5 = self.in5(c5)
        in4 = self.in4(c4)
        in3 = self.in3(c3)
        in2 = self.in2(c2)

        out4 = self.up5(in5) + in4  # 1/16
        out3 = self.up4(out4) + in3  # 1/8
        out2 = self.up3(out3) + in2  # 1/4

        p5 = self.out5(in5)
        p4 = self.out4(out4)
        p3 = self.out3(out3)
        p2 = self.out2(out2)

        fuse = torch.cat((p5, p4, p3, p2), 1)
        # this is the pred module, not binarization module;
        # We do not correct the name due to the trained model.
        binary = self.binarize(fuse)
        return binary

    def step_function(self, x, y):
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))


class VLPTModel(nn.Module):

    def __init__(self, *args, **kwargs):
        """
        VLPT-STD pretrained DBNet-resnet50 model,
        paper reference: https://arxiv.org/pdf/2204.13867.pdf
        """
        super(VLPTModel, self).__init__()
        self.backbone = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
        self.decoder = SegDetector(
            in_channels=[256, 512, 1024, 2048], adaptive=True, k=50, **kwargs)

    def forward(self, x):
        return self.decoder(self.backbone(x))


class DBModel(nn.Module):

    def __init__(self, *args, **kwargs):
        """
        DBNet-resnet18 model without deformable conv,
        paper reference: https://arxiv.org/pdf/1911.08947.pdf
        """
        super(DBModel, self).__init__()
        self.backbone = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
        self.decoder = SegDetector(
            in_channels=[64, 128, 256, 512], adaptive=True, k=50, **kwargs)

    def forward(self, x):
        return self.decoder(self.backbone(x))
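Both wrappers differ only in the backbone they feed the shared SegDetector. A quick shape check with random weights; the input sides are chosen as multiples of 32, matching the preprocessor's rounding:

import torch

net = DBModel()
net.eval()
with torch.no_grad():
    prob = net(torch.randn(1, 3, 96, 96))
print(prob.shape)  # torch.Size([1, 1, 96, 96]): binarize upsamples 4x back to the input resolution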
modelscope/models/cv/ocr_detection/preprocessor.py (new file, 69 lines)
@@ -0,0 +1,69 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import Preprocessor, load_image
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.utils.config import Config
from modelscope.utils.constant import Fields, ModeKeys, ModelFile


@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.ocr_detection)
class OCRDetectionPreprocessor(Preprocessor):

    def __init__(self, model_dir: str, mode: str = ModeKeys.INFERENCE):
        """The base constructor for all ocr detection preprocessors.

        Args:
            model_dir (str): model directory to initialize some resource
            mode: The mode for the preprocessor.
        """
        super().__init__(mode)
        cfgs = Config.from_file(
            os.path.join(model_dir, ModelFile.CONFIGURATION))
        self.image_short_side = cfgs.model.inference_kwargs.image_short_side

    def __call__(self, inputs):
        """process the raw input data
        Args:
            inputs:
                - A string containing an HTTP link pointing to an image
                - A string containing a local path to an image
                - An image loaded in PIL or opencv directly
        Returns:
            outputs: the preprocessed image
        """
        if isinstance(inputs, str):
            img = np.array(load_image(inputs))
        elif isinstance(inputs, PIL.Image.Image):
            img = np.array(inputs)
        elif isinstance(inputs, np.ndarray):
            # opencv-style input, already an array
            img = inputs
        else:
            raise TypeError(
                f'inputs should be either str, PIL.Image or np.ndarray, but got {type(inputs)}'
            )

        img = img[:, :, ::-1]
        height, width, _ = img.shape
        if height < width:
            new_height = self.image_short_side
            new_width = int(math.ceil(new_height / height * width / 32) * 32)
        else:
            new_width = self.image_short_side
            new_height = int(math.ceil(new_width / width * height / 32) * 32)
        resized_img = cv2.resize(img, (new_width, new_height))
        resized_img = resized_img - np.array([123.68, 116.78, 103.94],
                                             dtype=np.float32)
        resized_img /= 255.
        resized_img = torch.from_numpy(resized_img).permute(
            2, 0, 1).float().unsqueeze(0)

        result = {'img': resized_img, 'org_shape': [height, width]}
        return result
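A short sketch of the preprocessor on its own (paths are placeholders). The short side is pinned to the configured image_short_side, and the other side is rounded up to a multiple of 32, which the ResNet/decoder stack requires:

from modelscope.models.cv.ocr_detection.preprocessor import OCRDetectionPreprocessor

preprocessor = OCRDetectionPreprocessor(model_dir='/path/to/model_dir')  # placeholder dir
data = preprocessor('/path/to/text_image.jpg')  # placeholder image, e.g. 720x960
print(data['img'].shape, data['org_shape'])     # -> torch.Size([1, 3, 736, 992]) [720, 960]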
modelscope/models/cv/ocr_detection/utils.py (new file, 256 lines)
@@ -0,0 +1,256 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import cv2
import numpy as np
import pyclipper
from shapely.geometry import Polygon

def rboxes_to_polygons(rboxes):
    """
    Convert rboxes to polygons
    ARGS
        `rboxes`: [n, 5]
    RETURN
        `polygons`: [n, 8]
    """

    theta = rboxes[:, 4:5]
    cxcy = rboxes[:, :2]
    half_w = rboxes[:, 2:3] / 2.
    half_h = rboxes[:, 3:4] / 2.
    v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w])
    v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h])
    p1 = cxcy - v1 - v2
    p2 = cxcy + v1 - v2
    p3 = cxcy + v1 + v2
    p4 = cxcy - v1 + v2
    polygons = np.hstack([p1, p2, p3, p4])
    return polygons


def cal_width(box):
    pd1 = point_dist(box[0], box[1], box[2], box[3])
    pd2 = point_dist(box[4], box[5], box[6], box[7])
    return (pd1 + pd2) / 2


def point_dist(x1, y1, x2, y2):
    return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1))


def draw_polygons(img, polygons):
    for p in polygons.tolist():
        p = [int(o) for o in p]
        cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1)
        cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1)
        cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1)
        cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1)
    return img

def nms_python(boxes):
    boxes = sorted(boxes, key=lambda x: -x[8])
    nms_flag = [True] * len(boxes)
    for i, a in enumerate(boxes):
        if not nms_flag[i]:
            continue
        for j, b in enumerate(boxes):
            if not j > i:
                continue
            if not nms_flag[j]:
                continue
            score_a = a[8]
            score_b = b[8]
            rbox_a = polygon2rbox(a[:8])
            rbox_b = polygon2rbox(b[:8])
            if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox(
                    rbox_b[:2], rbox_a):
                if score_a > score_b:
                    nms_flag[j] = False
    boxes_nms = []
    for i, box in enumerate(boxes):
        if nms_flag[i]:
            boxes_nms.append(box)
    return boxes_nms

def point_in_rbox(c, rbox):
    cx0, cy0 = c[0], c[1]
    cx1, cy1 = rbox[0], rbox[1]
    w, h = rbox[2], rbox[3]
    theta = rbox[4]
    dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta))
    dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta))
    return ((dist_x < w / 2.0) and (dist_y < h / 2.0))

def polygon2rbox(polygon):
    x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6]
    y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7]
    c_x = (x1 + x2 + x3 + x4) / 4
    c_y = (y1 + y2 + y3 + y4) / 4
    w1 = point_dist(x1, y1, x2, y2)
    w2 = point_dist(x3, y3, x4, y4)
    h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2)
    h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4)
    h = h1 + h2
    w = (w1 + w2) / 2
    theta1 = np.arctan2(y2 - y1, x2 - x1)
    theta2 = np.arctan2(y3 - y4, x3 - x4)
    theta = (theta1 + theta2) / 2.0
    return [c_x, c_y, w, h, theta]


def point_line_dist(px, py, x1, y1, x2, y2):
    eps = 1e-6
    dx = x2 - x1
    dy = y2 - y1
    div = np.sqrt(dx * dx + dy * dy) + eps
    dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div
    return dist

# Part of the implementation is borrowed and modified from DB,
# publicly available at https://github.com/MhLiao/DB.
def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height):
    """
    _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
    """

    assert _bitmap.size(0) == 1
    bitmap = _bitmap.cpu().numpy()[0]
    pred = pred.cpu().detach().numpy()[0]
    height, width = bitmap.shape
    boxes = []
    scores = []

    contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
                                   cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    for contour in contours[:100]:
        epsilon = 0.01 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        points = approx.reshape((-1, 2))
        if points.shape[0] < 4:
            continue

        score = box_score_fast(pred, points.reshape(-1, 2))
        if score < 0.7:
            continue

        if points.shape[0] > 2:
            box = unclip(points, unclip_ratio=2.0)
            if len(box) > 1:
                continue
        else:
            continue
        box = box.reshape(-1, 2)
        _, sside = get_mini_boxes(box.reshape((-1, 1, 2)))
        if sside < 3 + 2:
            continue

        if not isinstance(dest_width, int):
            dest_width = dest_width.item()
            dest_height = dest_height.item()

        box[:, 0] = np.clip(
            np.round(box[:, 0] / width * dest_width), 0, dest_width)
        box[:, 1] = np.clip(
            np.round(box[:, 1] / height * dest_height), 0, dest_height)
        boxes.append(box.tolist())
        scores.append(score)
    return boxes, scores

def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
    """
    _bitmap: single map with shape (1, H, W),
        whose values are binarized as {0, 1}
    """

    assert _bitmap.size(0) == 1
    bitmap = _bitmap.cpu().numpy()[0]
    pred = pred.cpu().detach().numpy()[0]
    height, width = bitmap.shape
    boxes = []
    scores = []

    contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
                                   cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)

    for contour in contours[:100]:
        points, sside = get_mini_boxes(contour)
        if sside < 3:
            continue
        points = np.array(points)

        score = box_score_fast(pred, points.reshape(-1, 2))
        if score < 0.3:
            continue

        box = unclip(points, unclip_ratio=1.5).reshape(-1, 1, 2)
        box, sside = get_mini_boxes(box)

        if sside < 3 + 2:
            continue

        box = np.array(box).astype(np.int32)
        if not isinstance(dest_width, int):
            dest_width = dest_width.item()
            dest_height = dest_height.item()

        box[:, 0] = np.clip(
            np.round(box[:, 0] / width * dest_width), 0, dest_width)
        box[:, 1] = np.clip(
            np.round(box[:, 1] / height * dest_height), 0, dest_height)
        boxes.append(box.reshape(-1).tolist())
        scores.append(score)
    return boxes, scores

def box_score_fast(bitmap, _box):
    h, w = bitmap.shape[:2]
    box = _box.copy()
    xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1)
    xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1)
    ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1)
    ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1)

    mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
    box[:, 0] = box[:, 0] - xmin
    box[:, 1] = box[:, 1] - ymin
    cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1)
    return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]


def unclip(box, unclip_ratio=1.5):
    poly = Polygon(box)
    distance = poly.area * unclip_ratio / poly.length
    offset = pyclipper.PyclipperOffset()
    offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
    expanded = np.array(offset.Execute(distance))
    return expanded


def get_mini_boxes(contour):
    bounding_box = cv2.minAreaRect(contour)
    points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])

    index_1, index_2, index_3, index_4 = 0, 1, 2, 3
    if points[1][1] > points[0][1]:
        index_1 = 0
        index_4 = 1
    else:
        index_1 = 1
        index_4 = 0
    if points[3][1] > points[2][1]:
        index_2 = 2
        index_3 = 3
    else:
        index_2 = 3
        index_3 = 2

    box = [points[index_1], points[index_2], points[index_3], points[index_4]]
    return box, min(bounding_box[1])
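As a toy check of the bitmap-to-boxes path, a filled axis-aligned rectangle should come back as a single quadrilateral scaled to the destination size (sizes here are made up):

import torch
from modelscope.models.cv.ocr_detection.utils import boxes_from_bitmap

pred = torch.zeros(1, 32, 32)
pred[0, 8:24, 6:26] = 1.0                           # one solid text blob
boxes, scores = boxes_from_bitmap(pred, pred > 0.3, 320, 320)
print(len(boxes), scores)                           # 1 box, score close to 1.0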
modelscope/pipelines/cv/ocr_detection_pipeline.py

@@ -1,5 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import math
+import os
 import os.path as osp
 from typing import Any, Dict
 
@@ -9,11 +10,12 @@ import tensorflow as tf
 import torch
 
 from modelscope.metainfo import Pipelines
+from modelscope.models.cv.ocr_detection import OCRDetection
 from modelscope.outputs import OutputKeys
 from modelscope.pipelines.base import Input, Pipeline
 from modelscope.pipelines.builder import PIPELINES
-from modelscope.pipelines.cv.ocr_utils.model_vlpt import DBModel, VLPTModel
 from modelscope.preprocessors import LoadImage
+from modelscope.utils.config import Config
 from modelscope.utils.constant import ModelFile, Tasks
 from modelscope.utils.device import device_placement
 from modelscope.utils.logger import get_logger
@@ -48,7 +50,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6,
 @PIPELINES.register_module(
     Tasks.ocr_detection, module_name=Pipelines.ocr_detection)
 class OCRDetectionPipeline(Pipeline):
-    """ OCR Recognition Pipeline.
+    """ OCR Detection Pipeline.
 
     Example:
 
@@ -82,39 +84,19 @@ class OCRDetectionPipeline(Pipeline):
         Args:
             model: model id on modelscope hub.
         """
         assert isinstance(model, str), 'model must be a single str'
         super().__init__(model=model, **kwargs)
-        if 'vlpt' in self.model:
-            # for model cv_resnet50_ocr-detection-vlpt
-            model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
-            logger.info(f'loading model from {model_path}')
+        logger.info(f'loading model from dir {model}')
+        cfgs = Config.from_file(os.path.join(model, ModelFile.CONFIGURATION))
+        if hasattr(cfgs, 'model') and hasattr(cfgs.model, 'model_type'):
+            self.model_type = cfgs.model.model_type
+        else:
+            self.model_type = 'SegLink++'
 
-            self.thresh = 0.3
-            self.image_short_side = 736
-            self.device = torch.device(
-                'cuda' if torch.cuda.is_available() else 'cpu')
-            self.infer_model = VLPTModel().to(self.device)
-            self.infer_model.eval()
-            checkpoint = torch.load(model_path, map_location=self.device)
-            if 'state_dict' in checkpoint:
-                self.infer_model.load_state_dict(checkpoint['state_dict'])
-            else:
-                self.infer_model.load_state_dict(checkpoint)
-        elif 'db' in self.model:
-            # for model cv_resnet18_ocr-detection-db-line-level_damo (original dbnet)
-            model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
-            logger.info(f'loading model from {model_path}')
-
-            self.thresh = 0.2
-            self.image_short_side = 736
-            self.device = torch.device(
-                'cuda' if torch.cuda.is_available() else 'cpu')
-            self.infer_model = DBModel().to(self.device)
-            self.infer_model.eval()
-            checkpoint = torch.load(model_path, map_location=self.device)
-            if 'state_dict' in checkpoint:
-                self.infer_model.load_state_dict(checkpoint['state_dict'])
-            else:
-                self.infer_model.load_state_dict(checkpoint)
+        if self.model_type == 'DBNet':
+            self.ocr_detector = self.model.to(self.device)
+            self.ocr_detector.eval()
+            logger.info('loading model done')
         else:
             # for model seglink++
             tf.reset_default_graph()
@@ -191,26 +173,29 @@ class OCRDetectionPipeline(Pipeline):
                     variable_averages.variables_to_restore())
                 model_loader.restore(sess, model_path)
 
-    def preprocess(self, input: Input) -> Dict[str, Any]:
-        if 'vlpt' in self.model or 'db' in self.model:
-            img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
-            height, width, _ = img.shape
-            if height < width:
-                new_height = self.image_short_side
-                new_width = int(
-                    math.ceil(new_height / height * width / 32) * 32)
-            else:
-                new_width = self.image_short_side
-                new_height = int(
-                    math.ceil(new_width / width * height / 32) * 32)
-            resized_img = cv2.resize(img, (new_width, new_height))
-            resized_img = resized_img - np.array([123.68, 116.78, 103.94],
-                                                 dtype=np.float32)
-            resized_img /= 255.
-            resized_img = torch.from_numpy(resized_img).permute(
-                2, 0, 1).float().unsqueeze(0)
-
-            result = {'img': resized_img, 'org_shape': [height, width]}
+    def __call__(self, input, **kwargs):
+        """
+        Detect text instance in the text image.
+
+        Args:
+            input (`Image`):
+                The pipeline handles three types of images:
+
+                - A string containing an HTTP link pointing to an image
+                - A string containing a local path to an image
+                - An image loaded in PIL or opencv directly
+
+                The pipeline currently supports single image input.
+
+        Return:
+            An array of contour polygons of detected N text instances in image,
+            every row is [x1, y1, x2, y2, x3, y3, x4, y4, ...].
+        """
+        return super().__call__(input, **kwargs)
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        if self.model_type == 'DBNet':
+            result = self.preprocessor(input)
+            return result
         else:
             img = LoadImage.convert_to_ndarray(input)
@@ -235,9 +220,9 @@ class OCRDetectionPipeline(Pipeline):
             return result
 
     def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
-        if 'vlpt' in self.model or 'db' in self.model:
-            pred = self.infer_model(input['img'])
-            return {'results': pred, 'org_shape': input['org_shape']}
+        if self.model_type == 'DBNet':
+            outputs = self.ocr_detector(input)
+            return outputs
         else:
             with self._graph.as_default():
                 with self._session.as_default():
@@ -247,23 +232,8 @@ class OCRDetectionPipeline(Pipeline):
                     return sess_outputs
 
     def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
-        if 'vlpt' in self.model:
-            pred = inputs['results'][0]
-            height, width = inputs['org_shape']
-            segmentation = pred > self.thresh
-
-            boxes, scores = polygons_from_bitmap(pred, segmentation, width,
-                                                 height)
-            result = {OutputKeys.POLYGONS: np.array(boxes)}
-            return result
-        elif 'db' in self.model:
-            pred = inputs['results'][0]
-            height, width = inputs['org_shape']
-            segmentation = pred > self.thresh
-
-            boxes, scores = boxes_from_bitmap(pred, segmentation, width,
-                                              height)
-            result = {OutputKeys.POLYGONS: np.array(boxes)}
+        if self.model_type == 'DBNet':
+            result = {OutputKeys.POLYGONS: inputs['det_polygons']}
             return result
         else:
             rboxes = inputs['combined_rboxes'][0]
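Taken together, the branching on substrings of the model id ('vlpt' / 'db') is gone: the pipeline now consults model_type from configuration.json and delegates preprocessing and post-processing to the registered components. A hedged end-to-end sketch; the model id is an assumption (the diff's comments name cv_resnet18_ocr-detection-db-line-level_damo as one DBNet model, the damo/ namespace is assumed), and the URL is a placeholder:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.ocr_detection,
    model='damo/cv_resnet18_ocr-detection-db-line-level_damo')  # assumed model id
result = detector('https://example.com/text_image.jpg')  # placeholder URL
print(result[OutputKeys.POLYGONS])  # one [x1, y1, ..., x4, y4] row per detected text instance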