From a06297e8c4f3b3c19f672b1ec9c7211765c9aa03 Mon Sep 17 00:00:00 2001
From: "xixing.tj"
Date: Tue, 21 Feb 2023 22:40:00 +0800
Subject: [PATCH] add TorchModel for vlpt and dbnet for ocr detection
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Integrate the vlpt- and dbnet-related models into the model module; the model
and its post-processing are controlled via configuration.json.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11712249
---
 modelscope/metainfo.py                        |   2 +
 .../models/cv/ocr_detection/__init__.py       |  22 +
 modelscope/models/cv/ocr_detection/model.py   |  77 +++
 .../cv/ocr_detection/modules/__init__.py      |   0
 .../models/cv/ocr_detection/modules/dbnet.py  | 451 ++++++++++++++++++
 .../models/cv/ocr_detection/preprocessor.py   |  69 +++
 modelscope/models/cv/ocr_detection/utils.py   | 256 ++++
 .../pipelines/cv/ocr_detection_pipeline.py    | 114 ++---
 8 files changed, 919 insertions(+), 72 deletions(-)
 create mode 100644 modelscope/models/cv/ocr_detection/__init__.py
 create mode 100644 modelscope/models/cv/ocr_detection/model.py
 create mode 100644 modelscope/models/cv/ocr_detection/modules/__init__.py
 create mode 100644 modelscope/models/cv/ocr_detection/modules/dbnet.py
 create mode 100644 modelscope/models/cv/ocr_detection/preprocessor.py
 create mode 100644 modelscope/models/cv/ocr_detection/utils.py

diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 4f54f47c..97be486d 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -99,6 +99,7 @@ class Models(object):
     object_detection_3d = 'object_detection_3d'
     ddpm = 'ddpm'
     ocr_recognition = 'OCRRecognition'
+    ocr_detection = 'OCRDetection'
     image_quality_assessment_mos = 'image-quality-assessment-mos'
     image_quality_assessment_degradation = 'image-quality-assessment-degradation'
     m2fp = 'm2fp'
@@ -903,6 +904,7 @@ class Preprocessors(object):
     image_sky_change_preprocessor = 'image-sky-change-preprocessor'
     image_demoire_preprocessor = 'image-demoire-preprocessor'
     ocr_recognition = 'ocr-recognition'
+    ocr_detection = 'ocr-detection'
     bad_image_detecting_preprocessor = 'bad-image-detecting-preprocessor'
     nerf_recon_acc_preprocessor = 'nerf-recon-acc-preprocessor'

diff --git a/modelscope/models/cv/ocr_detection/__init__.py b/modelscope/models/cv/ocr_detection/__init__.py
new file mode 100644
index 00000000..9e9671fc
--- /dev/null
+++ b/modelscope/models/cv/ocr_detection/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .model import OCRDetection
+
+else:
+    _import_structure = {
+        'model': ['OCRDetection'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/ocr_detection/model.py b/modelscope/models/cv/ocr_detection/model.py
new file mode 100644
index 00000000..fdb4f8a1
--- /dev/null
+++ b/modelscope/models/cv/ocr_detection/model.py
@@ -0,0 +1,77 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
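+#
+# OCRDetection reads its runtime options from the configuration.json shipped
+# with the model. An illustrative excerpt (values are examples only, not
+# defaults) of the fields consumed by this model, OCRDetectionPreprocessor and
+# the ocr_detection pipeline:
+#
+#   "model": {
+#       "type": "OCRDetection",
+#       "model_type": "DBNet",
+#       "backbone": "resnet18",
+#       "inference_kwargs": {
+#           "image_short_side": 736,
+#           "thresh": 0.2,
+#           "return_polygon": false
+#       }
+#   }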
+import os +from typing import Any, Dict + +import numpy as np +import torch +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .modules.dbnet import DBModel, VLPTModel +from .utils import boxes_from_bitmap, polygons_from_bitmap + +LOGGER = get_logger() + + +@MODELS.register_module(Tasks.ocr_detection, module_name=Models.ocr_detection) +class OCRDetection(TorchModel): + + def __init__(self, model_dir: str, **kwargs): + """initialize the ocr recognition model from the `model_dir` path. + + Args: + model_dir (str): the model path. + """ + super().__init__(model_dir, **kwargs) + + model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE) + cfgs = Config.from_file( + os.path.join(model_dir, ModelFile.CONFIGURATION)) + self.thresh = cfgs.model.inference_kwargs.thresh + self.return_polygon = cfgs.model.inference_kwargs.return_polygon + self.backbone = cfgs.model.backbone + self.detector = None + if self.backbone == 'resnet50': + self.detector = VLPTModel() + elif self.backbone == 'resnet18': + self.detector = DBModel() + else: + raise TypeError( + f'detector backbone should be either resnet18, resnet50, but got {cfgs.model.backbone}' + ) + if model_path != '': + self.detector.load_state_dict( + torch.load(model_path, map_location='cpu')) + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + """ + Args: + img (`torch.Tensor`): image tensor, + shape of each tensor is [3, H, W]. + + Return: + results (`torch.Tensor`): bitmap tensor, + shape of each tensor is [1, H, W]. + org_shape (`List`): image original shape, + value is [height, width]. + """ + pred = self.detector(input['img']) + return {'results': pred, 'org_shape': input['org_shape']} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + pred = inputs['results'][0] + height, width = inputs['org_shape'] + segmentation = pred > self.thresh + if self.return_polygon: + boxes, scores = polygons_from_bitmap(pred, segmentation, width, + height) + else: + boxes, scores = boxes_from_bitmap(pred, segmentation, width, + height) + result = {'det_polygons': np.array(boxes)} + return result diff --git a/modelscope/models/cv/ocr_detection/modules/__init__.py b/modelscope/models/cv/ocr_detection/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/ocr_detection/modules/dbnet.py b/modelscope/models/cv/ocr_detection/modules/dbnet.py new file mode 100644 index 00000000..33888324 --- /dev/null +++ b/modelscope/models/cv/ocr_detection/modules/dbnet.py @@ -0,0 +1,451 @@ +# ------------------------------------------------------------------------------ +# Part of implementation is adopted from ViLT, +# made publicly available under the Apache License 2.0 at https://github.com/dandelin/ViLT. 
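+# This file defines the DBNet text-detection network used by OCRDetection:
+# a ResNet-18 or ResNet-50 backbone plus the FPN-style SegDetector decoder,
+# whose differentiable-binarization head outputs an [N, 1, H, W] text
+# probability map at the input resolution.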
+# ------------------------------------------------------------------------------ + +import math +import os +import sys + +import torch +import torch.nn as nn + +BatchNorm2d = nn.BatchNorm2d + + +def constant_init(module, constant, bias=0): + nn.init.constant_(module.weight, constant) + if hasattr(module, 'bias'): + nn.init.constant_(module.bias, bias) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): + super(BasicBlock, self).__init__() + self.with_dcn = dcn is not None + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = dcn.get('fallback_on_stride', False) + self.with_modulated_dcn = dcn.get('modulated', False) + # self.conv2 = conv3x3(planes, planes) + if not self.with_dcn or fallback_on_stride: + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, padding=1, bias=False) + else: + deformable_groups = dcn.get('deformable_groups', 1) + if not self.with_modulated_dcn: + from assets.ops.dcn import DeformConv + conv_op = DeformConv + offset_channels = 18 + else: + from assets.ops.dcn import ModulatedDeformConv + conv_op = ModulatedDeformConv + offset_channels = 27 + self.conv2_offset = nn.Conv2d( + planes, + deformable_groups * offset_channels, + kernel_size=3, + padding=1) + self.conv2 = conv_op( + planes, + planes, + kernel_size=3, + padding=1, + deformable_groups=deformable_groups, + bias=False) + self.bn2 = BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # out = self.conv2(out) + if not self.with_dcn: + out = self.conv2(out) + elif self.with_modulated_dcn: + offset_mask = self.conv2_offset(out) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, -9:, :, :].sigmoid() + out = self.conv2(out, offset, mask) + else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None): + super(Bottleneck, self).__init__() + self.with_dcn = dcn is not None + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes) + fallback_on_stride = False + self.with_modulated_dcn = False + if self.with_dcn: + fallback_on_stride = dcn.get('fallback_on_stride', False) + self.with_modulated_dcn = dcn.get('modulated', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + else: + deformable_groups = dcn.get('deformable_groups', 1) + if not self.with_modulated_dcn: + from assets.ops.dcn import DeformConv + conv_op = DeformConv + offset_channels = 18 + else: + from assets.ops.dcn import ModulatedDeformConv + conv_op = ModulatedDeformConv + offset_channels = 27 + self.conv2_offset = nn.Conv2d( + planes, + deformable_groups * offset_channels, + kernel_size=3, + padding=1) + self.conv2 = conv_op( + planes, + planes, + kernel_size=3, + 
padding=1, + stride=stride, + deformable_groups=deformable_groups, + bias=False) + self.bn2 = BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dcn = dcn + self.with_dcn = dcn is not None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + # out = self.conv2(out) + if not self.with_dcn: + out = self.conv2(out) + elif self.with_modulated_dcn: + offset_mask = self.conv2_offset(out) + offset = offset_mask[:, :18, :, :] + mask = offset_mask[:, -9:, :, :].sigmoid() + out = self.conv2(out, offset, mask) + else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, + block, + layers, + num_classes=1000, + dcn=None, + stage_with_dcn=(False, False, False, False)): + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer( + block, 128, layers[1], stride=2, dcn=dcn) + self.layer3 = self._make_layer( + block, 256, layers[2], stride=2, dcn=dcn) + self.layer4 = self._make_layer( + block, 512, layers[3], stride=2, dcn=dcn) + # self.avgpool = nn.AvgPool2d(7, stride=1) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + + # self.smooth = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + if self.dcn is not None: + for m in self.modules(): + if isinstance(m, Bottleneck) or isinstance(m, BasicBlock): + if hasattr(m, 'conv2_offset'): + constant_init(m.conv2_offset, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dcn=None): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, dcn=dcn)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dcn=dcn)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x2 = self.layer1(x) + x3 = self.layer2(x2) + x4 = self.layer3(x3) + x5 = self.layer4(x4) + + return x2, x3, x4, x5 + + +class SegDetector(nn.Module): + + def __init__(self, + in_channels=[64, 128, 256, 512], + inner_channels=256, + k=10, + bias=False, + adaptive=False, + smooth=False, + serial=False, + *args, + **kwargs): + ''' + bias: Whether conv layers have bias or not. + adaptive: Whether to use adaptive threshold training or not. 
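+        k: amplification factor of the differentiable binarization step function (see step_function below).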
+ smooth: If true, use bilinear instead of deconv. + serial: If true, thresh prediction will combine segmentation result as input. + ''' + super(SegDetector, self).__init__() + self.k = k + self.serial = serial + self.up5 = nn.Upsample(scale_factor=2, mode='nearest') + self.up4 = nn.Upsample(scale_factor=2, mode='nearest') + self.up3 = nn.Upsample(scale_factor=2, mode='nearest') + + self.in5 = nn.Conv2d(in_channels[-1], inner_channels, 1, bias=bias) + self.in4 = nn.Conv2d(in_channels[-2], inner_channels, 1, bias=bias) + self.in3 = nn.Conv2d(in_channels[-3], inner_channels, 1, bias=bias) + self.in2 = nn.Conv2d(in_channels[-4], inner_channels, 1, bias=bias) + + self.out5 = nn.Sequential( + nn.Conv2d( + inner_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.Upsample(scale_factor=8, mode='nearest')) + self.out4 = nn.Sequential( + nn.Conv2d( + inner_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.Upsample(scale_factor=4, mode='nearest')) + self.out3 = nn.Sequential( + nn.Conv2d( + inner_channels, inner_channels // 4, 3, padding=1, bias=bias), + nn.Upsample(scale_factor=2, mode='nearest')) + self.out2 = nn.Conv2d( + inner_channels, inner_channels // 4, 3, padding=1, bias=bias) + + self.binarize = nn.Sequential( + nn.Conv2d( + inner_channels, inner_channels // 4, 3, padding=1, bias=bias), + BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, inner_channels // 4, 2, 2), + BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + nn.ConvTranspose2d(inner_channels // 4, 1, 2, 2), nn.Sigmoid()) + self.binarize.apply(self.weights_init) + + self.adaptive = adaptive + if adaptive: + self.thresh = self._init_thresh( + inner_channels, serial=serial, smooth=smooth, bias=bias) + self.thresh.apply(self.weights_init) + + self.in5.apply(self.weights_init) + self.in4.apply(self.weights_init) + self.in3.apply(self.weights_init) + self.in2.apply(self.weights_init) + self.out5.apply(self.weights_init) + self.out4.apply(self.weights_init) + self.out3.apply(self.weights_init) + self.out2.apply(self.weights_init) + + def weights_init(self, m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.kaiming_normal_(m.weight.data) + elif classname.find('BatchNorm') != -1: + m.weight.data.fill_(1.) 
+ m.bias.data.fill_(1e-4) + + def _init_thresh(self, + inner_channels, + serial=False, + smooth=False, + bias=False): + in_channels = inner_channels + if serial: + in_channels += 1 + self.thresh = nn.Sequential( + nn.Conv2d( + in_channels, inner_channels // 4, 3, padding=1, bias=bias), + BatchNorm2d(inner_channels // 4), nn.ReLU(inplace=True), + self._init_upsample( + inner_channels // 4, + inner_channels // 4, + smooth=smooth, + bias=bias), BatchNorm2d(inner_channels // 4), + nn.ReLU(inplace=True), + self._init_upsample( + inner_channels // 4, 1, smooth=smooth, bias=bias), + nn.Sigmoid()) + return self.thresh + + def _init_upsample(self, + in_channels, + out_channels, + smooth=False, + bias=False): + if smooth: + inter_out_channels = out_channels + if out_channels == 1: + inter_out_channels = in_channels + module_list = [ + nn.Upsample(scale_factor=2, mode='nearest'), + nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias) + ] + if out_channels == 1: + module_list.append( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=1, + bias=True)) + + return nn.Sequential(module_list) + else: + return nn.ConvTranspose2d(in_channels, out_channels, 2, 2) + + def forward(self, features, gt=None, masks=None, training=False): + c2, c3, c4, c5 = features + in5 = self.in5(c5) + in4 = self.in4(c4) + in3 = self.in3(c3) + in2 = self.in2(c2) + + out4 = self.up5(in5) + in4 # 1/16 + out3 = self.up4(out4) + in3 # 1/8 + out2 = self.up3(out3) + in2 # 1/4 + + p5 = self.out5(in5) + p4 = self.out4(out4) + p3 = self.out3(out3) + p2 = self.out2(out2) + + fuse = torch.cat((p5, p4, p3, p2), 1) + # this is the pred module, not binarization module; + # We do not correct the name due to the trained model. + binary = self.binarize(fuse) + return binary + + def step_function(self, x, y): + return torch.reciprocal(1 + torch.exp(-self.k * (x - y))) + + +class VLPTModel(nn.Module): + + def __init__(self, *args, **kwargs): + """ + VLPT-STD pretrained DBNet-resnet50 model, + paper reference: https://arxiv.org/pdf/2204.13867.pdf + """ + super(VLPTModel, self).__init__() + self.backbone = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + self.decoder = SegDetector( + in_channels=[256, 512, 1024, 2048], adaptive=True, k=50, **kwargs) + + def forward(self, x): + return self.decoder(self.backbone(x)) + + +class DBModel(nn.Module): + + def __init__(self, *args, **kwargs): + """ + DBNet-resnet18 model without deformable conv, + paper reference: https://arxiv.org/pdf/1911.08947.pdf + """ + super(DBModel, self).__init__() + self.backbone = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + self.decoder = SegDetector( + in_channels=[64, 128, 256, 512], adaptive=True, k=50, **kwargs) + + def forward(self, x): + return self.decoder(self.backbone(x)) diff --git a/modelscope/models/cv/ocr_detection/preprocessor.py b/modelscope/models/cv/ocr_detection/preprocessor.py new file mode 100644 index 00000000..4b95b3c9 --- /dev/null +++ b/modelscope/models/cv/ocr_detection/preprocessor.py @@ -0,0 +1,69 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
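+#
+# OCRDetectionPreprocessor resizes the input so that its short side equals
+# `image_short_side` from configuration.json and the long side is rounded up
+# to a multiple of 32 (the DBNet backbones downsample by a factor of 32).
+# For example, with image_short_side=736 a 1080x1920 (HxW) image becomes
+# 736 x (ceil(736 / 1080 * 1920 / 32) * 32) = 736x1312; the result is then
+# mean-subtracted ([123.68, 116.78, 103.94]) and scaled by 1/255.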
+import math +import os +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors import Preprocessor, load_image +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.utils.config import Config +from modelscope.utils.constant import Fields, ModeKeys, ModelFile + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.ocr_detection) +class OCRDetectionPreprocessor(Preprocessor): + + def __init__(self, model_dir: str, mode: str = ModeKeys.INFERENCE): + """The base constructor for all ocr recognition preprocessors. + + Args: + model_dir (str): model directory to initialize some resource + mode: The mode for the preprocessor. + """ + super().__init__(mode) + cfgs = Config.from_file( + os.path.join(model_dir, ModelFile.CONFIGURATION)) + self.image_short_side = cfgs.model.inference_kwargs.image_short_side + + def __call__(self, inputs): + """process the raw input data + Args: + inputs: + - A string containing an HTTP link pointing to an image + - A string containing a local path to an image + - An image loaded in PIL or opencv directly + Returns: + outputs: the preprocessed image + """ + if isinstance(inputs, str): + img = np.array(load_image(inputs)) + elif isinstance(inputs, PIL.Image.Image): + img = np.array(inputs) + else: + raise TypeError( + f'inputs should be either str, PIL.Image, np.array, but got {type(inputs)}' + ) + + img = img[:, :, ::-1] + height, width, _ = img.shape + if height < width: + new_height = self.image_short_side + new_width = int(math.ceil(new_height / height * width / 32) * 32) + else: + new_width = self.image_short_side + new_height = int(math.ceil(new_width / width * height / 32) * 32) + resized_img = cv2.resize(img, (new_width, new_height)) + resized_img = resized_img - np.array([123.68, 116.78, 103.94], + dtype=np.float32) + resized_img /= 255. + resized_img = torch.from_numpy(resized_img).permute( + 2, 0, 1).float().unsqueeze(0) + + result = {'img': resized_img, 'org_shape': [height, width]} + return result diff --git a/modelscope/models/cv/ocr_detection/utils.py b/modelscope/models/cv/ocr_detection/utils.py new file mode 100644 index 00000000..6de22b3f --- /dev/null +++ b/modelscope/models/cv/ocr_detection/utils.py @@ -0,0 +1,256 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + + +def rboxes_to_polygons(rboxes): + """ + Convert rboxes to polygons + ARGS + `rboxes`: [n, 5] + RETURN + `polygons`: [n, 8] + """ + + theta = rboxes[:, 4:5] + cxcy = rboxes[:, :2] + half_w = rboxes[:, 2:3] / 2. + half_h = rboxes[:, 3:4] / 2. 
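+    # v1 / v2 are the half-width and half-height vectors rotated by theta; e.g.
+    # the axis-aligned rbox [50, 30, 40, 20, 0] yields the polygon
+    # [30, 20, 70, 20, 70, 40, 30, 40].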
+ v1 = np.hstack([np.cos(theta) * half_w, np.sin(theta) * half_w]) + v2 = np.hstack([-np.sin(theta) * half_h, np.cos(theta) * half_h]) + p1 = cxcy - v1 - v2 + p2 = cxcy + v1 - v2 + p3 = cxcy + v1 + v2 + p4 = cxcy - v1 + v2 + polygons = np.hstack([p1, p2, p3, p4]) + return polygons + + +def cal_width(box): + pd1 = point_dist(box[0], box[1], box[2], box[3]) + pd2 = point_dist(box[4], box[5], box[6], box[7]) + return (pd1 + pd2) / 2 + + +def point_dist(x1, y1, x2, y2): + return np.sqrt((x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1)) + + +def draw_polygons(img, polygons): + for p in polygons.tolist(): + p = [int(o) for o in p] + cv2.line(img, (p[0], p[1]), (p[2], p[3]), (0, 255, 0), 1) + cv2.line(img, (p[2], p[3]), (p[4], p[5]), (0, 255, 0), 1) + cv2.line(img, (p[4], p[5]), (p[6], p[7]), (0, 255, 0), 1) + cv2.line(img, (p[6], p[7]), (p[0], p[1]), (0, 255, 0), 1) + return img + + +def nms_python(boxes): + boxes = sorted(boxes, key=lambda x: -x[8]) + nms_flag = [True] * len(boxes) + for i, a in enumerate(boxes): + if not nms_flag[i]: + continue + else: + for j, b in enumerate(boxes): + if not j > i: + continue + if not nms_flag[j]: + continue + score_a = a[8] + score_b = b[8] + rbox_a = polygon2rbox(a[:8]) + rbox_b = polygon2rbox(b[:8]) + if point_in_rbox(rbox_a[:2], rbox_b) or point_in_rbox( + rbox_b[:2], rbox_a): + if score_a > score_b: + nms_flag[j] = False + boxes_nms = [] + for i, box in enumerate(boxes): + if nms_flag[i]: + boxes_nms.append(box) + return boxes_nms + + +def point_in_rbox(c, rbox): + cx0, cy0 = c[0], c[1] + cx1, cy1 = rbox[0], rbox[1] + w, h = rbox[2], rbox[3] + theta = rbox[4] + dist_x = np.abs((cx1 - cx0) * np.cos(theta) + (cy1 - cy0) * np.sin(theta)) + dist_y = np.abs(-(cx1 - cx0) * np.sin(theta) + (cy1 - cy0) * np.cos(theta)) + return ((dist_x < w / 2.0) and (dist_y < h / 2.0)) + + +def polygon2rbox(polygon): + x1, x2, x3, x4 = polygon[0], polygon[2], polygon[4], polygon[6] + y1, y2, y3, y4 = polygon[1], polygon[3], polygon[5], polygon[7] + c_x = (x1 + x2 + x3 + x4) / 4 + c_y = (y1 + y2 + y3 + y4) / 4 + w1 = point_dist(x1, y1, x2, y2) + w2 = point_dist(x3, y3, x4, y4) + h1 = point_line_dist(c_x, c_y, x1, y1, x2, y2) + h2 = point_line_dist(c_x, c_y, x3, y3, x4, y4) + h = h1 + h2 + w = (w1 + w2) / 2 + theta1 = np.arctan2(y2 - y1, x2 - x1) + theta2 = np.arctan2(y3 - y4, x3 - x4) + theta = (theta1 + theta2) / 2.0 + return [c_x, c_y, w, h, theta] + + +def point_line_dist(px, py, x1, y1, x2, y2): + eps = 1e-6 + dx = x2 - x1 + dy = y2 - y1 + div = np.sqrt(dx * dx + dy * dy) + eps + dist = np.abs(px * dy - py * dx + x2 * y1 - y2 * x1) / div + return dist + + +# Part of the implementation is borrowed and modified from DB, +# publicly available at https://github.com/MhLiao/DB. 
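+#
+# Both *_from_bitmap helpers below share the same flow: cv2.findContours on the
+# binarized map -> score each candidate with box_score_fast -> grow the region
+# with unclip (offset distance = area * unclip_ratio / perimeter, e.g. a 10x10
+# region with ratio 1.5 is grown by 100 * 1.5 / 40 = 3.75 px) -> rescale the
+# points to the original image size. polygons_from_bitmap keeps the expanded
+# contour polygon, while boxes_from_bitmap reduces it to the minimum-area
+# quadrilateral from get_mini_boxes.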
+def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + assert _bitmap.size(0) == 1 + bitmap = _bitmap.cpu().numpy()[0] + pred = pred.cpu().detach().numpy()[0] + height, width = bitmap.shape + boxes = [] + scores = [] + + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:100]: + epsilon = 0.01 * cv2.arcLength(contour, True) + approx = cv2.approxPolyDP(contour, epsilon, True) + points = approx.reshape((-1, 2)) + if points.shape[0] < 4: + continue + + score = box_score_fast(pred, points.reshape(-1, 2)) + if 0.7 > score: + continue + + if points.shape[0] > 2: + box = unclip(points, unclip_ratio=2.0) + if len(box) > 1: + continue + else: + continue + box = box.reshape(-1, 2) + _, sside = get_mini_boxes(box.reshape((-1, 1, 2))) + if sside < 3 + 2: + continue + + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.tolist()) + scores.append(score) + return boxes, scores + + +def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height): + """ + _bitmap: single map with shape (1, H, W), + whose values are binarized as {0, 1} + """ + + assert _bitmap.size(0) == 1 + bitmap = _bitmap.cpu().numpy()[0] + pred = pred.cpu().detach().numpy()[0] + height, width = bitmap.shape + boxes = [] + scores = [] + + contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8), + cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours[:100]: + points, sside = get_mini_boxes(contour) + if sside < 3: + continue + points = np.array(points) + + score = box_score_fast(pred, points.reshape(-1, 2)) + if 0.3 > score: + continue + + box = unclip(points, unclip_ratio=1.5).reshape(-1, 1, 2) + box, sside = get_mini_boxes(box) + + if sside < 3 + 2: + continue + + box = np.array(box).astype(np.int32) + if not isinstance(dest_width, int): + dest_width = dest_width.item() + dest_height = dest_height.item() + + box[:, 0] = np.clip( + np.round(box[:, 0] / width * dest_width), 0, dest_width) + box[:, 1] = np.clip( + np.round(box[:, 1] / height * dest_height), 0, dest_height) + boxes.append(box.reshape(-1).tolist()) + scores.append(score) + return boxes, scores + + +def box_score_fast(bitmap, _box): + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int32), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + +def unclip(box, unclip_ratio=1.5): + poly = Polygon(box) + distance = poly.area * unclip_ratio / poly.length + offset = pyclipper.PyclipperOffset() + offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + expanded = np.array(offset.Execute(distance)) + return expanded + + +def get_mini_boxes(contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), 
key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [points[index_1], points[index_2], points[index_3], points[index_4]] + return box, min(bounding_box[1]) diff --git a/modelscope/pipelines/cv/ocr_detection_pipeline.py b/modelscope/pipelines/cv/ocr_detection_pipeline.py index ed198b66..cb7522c0 100644 --- a/modelscope/pipelines/cv/ocr_detection_pipeline.py +++ b/modelscope/pipelines/cv/ocr_detection_pipeline.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import math +import os import os.path as osp from typing import Any, Dict @@ -9,11 +10,12 @@ import tensorflow as tf import torch from modelscope.metainfo import Pipelines +from modelscope.models.cv.ocr_detection import OCRDetection from modelscope.outputs import OutputKeys from modelscope.pipelines.base import Input, Pipeline from modelscope.pipelines.builder import PIPELINES -from modelscope.pipelines.cv.ocr_utils.model_vlpt import DBModel, VLPTModel from modelscope.preprocessors import LoadImage +from modelscope.utils.config import Config from modelscope.utils.constant import ModelFile, Tasks from modelscope.utils.device import device_placement from modelscope.utils.logger import get_logger @@ -48,7 +50,7 @@ tf.app.flags.DEFINE_float('link_threshold', 0.6, @PIPELINES.register_module( Tasks.ocr_detection, module_name=Pipelines.ocr_detection) class OCRDetectionPipeline(Pipeline): - """ OCR Recognition Pipeline. + """ OCR Detection Pipeline. Example: @@ -82,39 +84,19 @@ class OCRDetectionPipeline(Pipeline): Args: model: model id on modelscope hub. """ + assert isinstance(model, str), 'model must be a single str' super().__init__(model=model, **kwargs) - if 'vlpt' in self.model: - # for model cv_resnet50_ocr-detection-vlpt - model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) - logger.info(f'loading model from {model_path}') + logger.info(f'loading model from dir {model}') + cfgs = Config.from_file(os.path.join(model, ModelFile.CONFIGURATION)) + if hasattr(cfgs, 'model') and hasattr(cfgs.model, 'model_type'): + self.model_type = cfgs.model.model_type + else: + self.model_type = 'SegLink++' - self.thresh = 0.3 - self.image_short_side = 736 - self.device = torch.device( - 'cuda' if torch.cuda.is_available() else 'cpu') - self.infer_model = VLPTModel().to(self.device) - self.infer_model.eval() - checkpoint = torch.load(model_path, map_location=self.device) - if 'state_dict' in checkpoint: - self.infer_model.load_state_dict(checkpoint['state_dict']) - else: - self.infer_model.load_state_dict(checkpoint) - elif 'db' in self.model: - # for model cv_resnet18_ocr-detection-db-line-level_damo (original dbnet) - model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE) - logger.info(f'loading model from {model_path}') - - self.thresh = 0.2 - self.image_short_side = 736 - self.device = torch.device( - 'cuda' if torch.cuda.is_available() else 'cpu') - self.infer_model = DBModel().to(self.device) - self.infer_model.eval() - checkpoint = torch.load(model_path, map_location=self.device) - if 'state_dict' in checkpoint: - self.infer_model.load_state_dict(checkpoint['state_dict']) - else: - self.infer_model.load_state_dict(checkpoint) + if self.model_type == 'DBNet': + self.ocr_detector = self.model.to(self.device) + self.ocr_detector.eval() + logger.info('loading model done') else: # for model 
seglink++ tf.reset_default_graph() @@ -191,26 +173,29 @@ class OCRDetectionPipeline(Pipeline): variable_averages.variables_to_restore()) model_loader.restore(sess, model_path) - def preprocess(self, input: Input) -> Dict[str, Any]: - if 'vlpt' in self.model or 'db' in self.model: - img = LoadImage.convert_to_ndarray(input)[:, :, ::-1] - height, width, _ = img.shape - if height < width: - new_height = self.image_short_side - new_width = int( - math.ceil(new_height / height * width / 32) * 32) - else: - new_width = self.image_short_side - new_height = int( - math.ceil(new_width / width * height / 32) * 32) - resized_img = cv2.resize(img, (new_width, new_height)) - resized_img = resized_img - np.array([123.68, 116.78, 103.94], - dtype=np.float32) - resized_img /= 255. - resized_img = torch.from_numpy(resized_img).permute( - 2, 0, 1).float().unsqueeze(0) + def __call__(self, input, **kwargs): + """ + Detect text instance in the text image. - result = {'img': resized_img, 'org_shape': [height, width]} + Args: + input (`Image`): + The pipeline handles three types of images: + + - A string containing an HTTP link pointing to an image + - A string containing a local path to an image + - An image loaded in PIL or opencv directly + + The pipeline currently supports single image input. + + Return: + An array of contour polygons of detected N text instances in image, + every row is [x1, y1, x2, y2, x3, y3, x4, y4, ...]. + """ + return super().__call__(input, **kwargs) + + def preprocess(self, input: Input) -> Dict[str, Any]: + if self.model_type == 'DBNet': + result = self.preprocessor(input) return result else: img = LoadImage.convert_to_ndarray(input) @@ -235,9 +220,9 @@ class OCRDetectionPipeline(Pipeline): return result def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: - if 'vlpt' in self.model or 'db' in self.model: - pred = self.infer_model(input['img']) - return {'results': pred, 'org_shape': input['org_shape']} + if self.model_type == 'DBNet': + outputs = self.ocr_detector(input) + return outputs else: with self._graph.as_default(): with self._session.as_default(): @@ -247,23 +232,8 @@ class OCRDetectionPipeline(Pipeline): return sess_outputs def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: - if 'vlpt' in self.model: - pred = inputs['results'][0] - height, width = inputs['org_shape'] - segmentation = pred > self.thresh - - boxes, scores = polygons_from_bitmap(pred, segmentation, width, - height) - result = {OutputKeys.POLYGONS: np.array(boxes)} - return result - elif 'db' in self.model: - pred = inputs['results'][0] - height, width = inputs['org_shape'] - segmentation = pred > self.thresh - - boxes, scores = boxes_from_bitmap(pred, segmentation, width, - height) - result = {OutputKeys.POLYGONS: np.array(boxes)} + if self.model_type == 'DBNet': + result = {OutputKeys.POLYGONS: inputs['det_polygons']} return result else: rboxes = inputs['combined_rboxes'][0]
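A minimal usage sketch of the refactored pipeline (the hub namespace in the model id below is an assumption; cv_resnet18_ocr-detection-db-line-level_damo is the DBNet line-level model mentioned above):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Models whose configuration.json sets model.model_type to 'DBNet' go through
# the new OCRDetectionPreprocessor / OCRDetection TorchModel path; all other
# models fall back to the original SegLink++ TensorFlow branch.
ocr_detection = pipeline(
    Tasks.ocr_detection,
    model='damo/cv_resnet18_ocr-detection-db-line-level_damo')
result = ocr_detection('ocr_text.jpg')  # URL, local path, or PIL/opencv image
print(result[OutputKeys.POLYGONS])  # one row per text instance: [x1, y1, ..., x4, y4]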