add model for card correction
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14049168
modelscope/__init__.py
@@ -4,36 +4,38 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
-    from .version import __release_datetime__, __version__
-    from .trainers import EpochBasedTrainer, TrainingArgs, build_dataset_from_file
-    from .trainers import Hook, Priority
-    from .exporters import Exporter
-    from .exporters import TfModelExporter
-    from .exporters import TorchModelExporter
+    from .exporters import Exporter, TfModelExporter, TorchModelExporter
     from .hub.api import HubApi
-    from .hub.snapshot_download import snapshot_download
+    from .hub.check_model import check_local_model_is_latest, check_model_is_id
     from .hub.push_to_hub import push_to_hub, push_to_hub_async
-    from .hub.check_model import check_model_is_id, check_local_model_is_latest
-    from .metrics import AudioNoiseMetric, Metric, task_default_metrics, ImageColorEnhanceMetric, ImageDenoiseMetric, \
-        ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric, SequenceClassificationMetric, \
-        TextGenerationMetric, TokenClassificationMetric, VideoSummarizationMetric, MovieSceneSegmentationMetric, \
-        AccuracyMetric, BleuMetric, ImageInpaintingMetric, ReferringVideoObjectSegmentationMetric, \
-        VideoFrameInterpolationMetric, VideoStabilizationMetric, VideoSuperResolutionMetric, PplMetric, \
-        ImageQualityAssessmentDegradationMetric, ImageQualityAssessmentMosMetric, TextRankingMetric, \
-        LossMetric, ImageColorizationMetric, OCRRecognitionMetric
+    from .hub.snapshot_download import snapshot_download
+    from .metrics import (
+        AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric,
+        ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric,
+        ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric,
+        ImageQualityAssessmentDegradationMetric,
+        ImageQualityAssessmentMosMetric, LossMetric, Metric,
+        MovieSceneSegmentationMetric, OCRRecognitionMetric, PplMetric,
+        ReferringVideoObjectSegmentationMetric, SequenceClassificationMetric,
+        TextGenerationMetric, TextRankingMetric, TokenClassificationMetric,
+        VideoFrameInterpolationMetric, VideoStabilizationMetric,
+        VideoSummarizationMetric, VideoSuperResolutionMetric,
+        task_default_metrics)
     from .models import Model, TorchModel
-    from .preprocessors import Preprocessor
+    from .msdatasets import MsDataset
     from .pipelines import Pipeline, pipeline
-    from .utils.hub import read_config, create_model_if_not_exist
-    from .utils.logger import get_logger
+    from .preprocessors import Preprocessor
+    from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs,
+                           build_dataset_from_file)
     from .utils.constant import Tasks
-    from .utils.hf_util import AutoConfig, GenerationConfig
-    from .utils.hf_util import (AutoModel, AutoModelForCausalLM,
+    from .utils.hf_util import (AutoConfig, AutoModel, AutoModelForCausalLM,
                                 AutoModelForSeq2SeqLM,
                                 AutoModelForSequenceClassification,
-                                AutoModelForTokenClassification)
-    from .utils.hf_util import AutoTokenizer
-    from .msdatasets import MsDataset
+                                AutoModelForTokenClassification, AutoTokenizer,
+                                GenerationConfig)
+    from .utils.hub import create_model_if_not_exist, read_config
+    from .utils.logger import get_logger
+    from .version import __release_datetime__, __version__
 
 else:
     _import_structure = {
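Most of the movement in this hunk is isort/yapf normalization; what it preserves is the lazy-import pattern shared by all of the touched __init__ files. A minimal sketch of that pattern, assuming a HuggingFace-style LazyImportModule whose exact keyword arguments may differ from modelscope's real implementation:

import sys
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Eager imports are seen only by type checkers.
    from .pipelines import Pipeline, pipeline
else:
    # At runtime the module is swapped for a proxy that imports each
    # submodule on first attribute access, keeping `import modelscope` cheap.
    _import_structure = {'pipelines': ['Pipeline', 'pipeline']}
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__)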
modelscope/exporters/__init__.py
@@ -7,13 +7,13 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .base import Exporter
     from .builder import build_exporter
-    from .cv import CartoonTranslationExporter
-    from .nlp import CsanmtForTranslationExporter
-    from .tf_model_exporter import TfModelExporter
-    from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
-    from .torch_model_exporter import TorchModelExporter
-    from .cv import FaceDetectionSCRFDExporter
+    from .cv import CartoonTranslationExporter, FaceDetectionSCRFDExporter
     from .multi_modal import StableDiffuisonExporter
+    from .nlp import (CsanmtForTranslationExporter,
+                      SbertForSequenceClassificationExporter,
+                      SbertForZeroShotClassificationExporter)
+    from .tf_model_exporter import TfModelExporter
+    from .torch_model_exporter import TorchModelExporter
 else:
     _import_structure = {
         'base': ['Exporter'],
modelscope/exporters/cv/__init__.py
@@ -6,8 +6,9 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .cartoon_translation_exporter import CartoonTranslationExporter
-    from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter
     from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter
+    from .object_detection_damoyolo_exporter import \
+        ObjectDetectionDamoyoloExporter
 else:
     _import_structure = {
         'cartoon_translation_exporter': ['CartoonTranslationExporter'],
modelscope/exporters/nlp/__init__.py
@@ -6,7 +6,8 @@ from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
     from .csanmt_for_translation_exporter import CsanmtForTranslationExporter
-    from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter
+    from .model_for_token_classification_exporter import \
+        ModelForSequenceClassificationExporter
     from .sbert_for_sequence_classification_exporter import \
         SbertForSequenceClassificationExporter
     from .sbert_for_zero_shot_classification_exporter import \
modelscope/exporters/nlp/csanmt_for_translation_exporter.py
@@ -28,7 +28,8 @@ class CsanmtForTranslationExporter(TfModelExporter):
         tf.disable_eager_execution()
         super().__init__(model)
 
-        from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline
+        from modelscope.pipelines.nlp.translation_pipeline import \
+            TranslationPipeline
         self.pipeline = TranslationPipeline(self.model)
 
     def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
modelscope/metainfo.py
@@ -293,6 +293,7 @@ class Pipelines(object):
     table_recognition = 'dla34-table-recognition'
     lineless_table_recognition = 'lore-lineless-table-recognition'
     license_plate_detection = 'resnet18-license-plate-detection'
+    card_detection_correction = 'resnet18-card-detection-correction'
     action_recognition = 'TAdaConv_action-recognition'
     animal_recognition = 'resnet101-animal-recognition'
     general_recognition = 'resnet101-general-recognition'
modelscope/pipelines/builder.py
@@ -677,6 +678,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.license_plate_detection:
     (Pipelines.license_plate_detection,
      'damo/cv_resnet18_license-plate-detection_damo'),
+    Tasks.card_detection_correction: (Pipelines.card_detection_correction,
+                                      'damo/cv_resnet18_card_correction'),
     Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'),
     Tasks.feature_extraction: (Pipelines.feature_extraction,
                                'damo/pert_feature-extraction_base-test'),
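The two registrations above are what let callers omit the model id: the Pipelines constant names the implementation, and DEFAULT_MODEL_FOR_PIPELINE maps the task to a default checkpoint. A usage sketch (the input path is hypothetical):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# No explicit model argument: the task resolves to
# 'damo/cv_resnet18_card_correction' through DEFAULT_MODEL_FOR_PIPELINE.
detector = pipeline(Tasks.card_detection_correction)
result = detector('data/demo_card.jpg')  # hypothetical local image
print(result['polygons'])          # one 8-value quadrilateral per card
print(len(result['output_imgs']))  # rectified card crops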
modelscope/metrics/__init__.py
@@ -4,34 +4,39 @@ from typing import TYPE_CHECKING
 from modelscope.utils.import_utils import LazyImportModule
 
 if TYPE_CHECKING:
+    from .accuracy_metric import AccuracyMetric
     from .audio_noise_metric import AudioNoiseMetric
     from .base import Metric
+    from .bleu_metric import BleuMetric
     from .builder import METRICS, build_metric, task_default_metrics
     from .image_color_enhance_metric import ImageColorEnhanceMetric
+    from .image_colorization_metric import ImageColorizationMetric
     from .image_denoise_metric import ImageDenoiseMetric
+    from .image_inpainting_metric import ImageInpaintingMetric
     from .image_instance_segmentation_metric import \
         ImageInstanceSegmentationCOCOMetric
-    from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric
+    from .image_portrait_enhancement_metric import \
+        ImagePortraitEnhancementMetric
+    from .image_quality_assessment_degradation_metric import \
+        ImageQualityAssessmentDegradationMetric
+    from .image_quality_assessment_mos_metric import \
+        ImageQualityAssessmentMosMetric
+    from .loss_metric import LossMetric
+    from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
+    from .ocr_recognition_metric import OCRRecognitionMetric
+    from .ppl_metric import PplMetric
+    from .referring_video_object_segmentation_metric import \
+        ReferringVideoObjectSegmentationMetric
     from .sequence_classification_metric import SequenceClassificationMetric
     from .text_generation_metric import TextGenerationMetric
+    from .text_ranking_metric import TextRankingMetric
     from .token_classification_metric import TokenClassificationMetric
-    from .video_summarization_metric import VideoSummarizationMetric
-    from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric
-    from .accuracy_metric import AccuracyMetric
-    from .bleu_metric import BleuMetric
-    from .image_inpainting_metric import ImageInpaintingMetric
-    from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric
+    from .translation_evaluation_metric import TranslationEvaluationMetric
     from .video_frame_interpolation_metric import VideoFrameInterpolationMetric
     from .video_stabilization_metric import VideoStabilizationMetric
-    from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
-    from .ppl_metric import PplMetric
-    from .image_quality_assessment_degradation_metric import ImageQualityAssessmentDegradationMetric
-    from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric
-    from .text_ranking_metric import TextRankingMetric
-    from .loss_metric import LossMetric
-    from .image_colorization_metric import ImageColorizationMetric
-    from .ocr_recognition_metric import OCRRecognitionMetric
-    from .translation_evaluation_metric import TranslationEvaluationMetric
+    from .video_summarization_metric import VideoSummarizationMetric
+    from .video_super_resolution_metric.video_super_resolution_metric import \
+        VideoSuperResolutionMetric
 else:
     _import_structure = {
         'audio_noise_metric': ['AudioNoiseMetric'],
modelscope/outputs/outputs.py
@@ -442,6 +442,8 @@ TASK_OUTPUTS = {
     Tasks.table_recognition: [OutputKeys.POLYGONS],
     Tasks.lineless_table_recognition: [OutputKeys.POLYGONS, OutputKeys.BOXES],
     Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT],
+    Tasks.card_detection_correction:
+    [OutputKeys.POLYGONS, OutputKeys.OUTPUT_IMGS],
 
     # ocr recognition result for single sample
     # {
@@ -669,8 +671,9 @@ TASK_OUTPUTS = {
     #     np.array # 2D array containing only 0, 1
     #  ]
     # }
-    Tasks.image_segmentation:
-    [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS],
+    Tasks.image_segmentation: [
+        OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS
+    ],
 
     # video panoptic segmentation result for single sample
     # "scores": [[0.8, 0.25, 0.05, 0.05], [0.9, 0.1, 0.05, 0.05]]
modelscope/pipeline_inputs.py
@@ -110,6 +110,8 @@ TASK_INPUTS = {
     InputType.IMAGE,
     Tasks.license_plate_detection:
     InputType.IMAGE,
+    Tasks.card_detection_correction:
+    InputType.IMAGE,
     Tasks.lineless_table_recognition:
     InputType.IMAGE,
     Tasks.table_recognition:
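Together, the TASK_OUTPUTS and TASK_INPUTS entries fix the task contract: a single image in, POLYGONS plus OUTPUT_IMGS out. A small consumer sketch relying only on that contract (save_crops is a hypothetical helper, not part of this commit):

import cv2

from modelscope.outputs import OutputKeys


def save_crops(result: dict, prefix: str = 'card') -> None:
    # POLYGONS: (N, 8) float32, four (x, y) corners per detected card.
    # OUTPUT_IMGS: N perspective-corrected, rotation-normalized crops.
    polygons = result[OutputKeys.POLYGONS]
    crops = result[OutputKeys.OUTPUT_IMGS]
    assert len(polygons) == len(crops)
    for i, crop in enumerate(crops):
        cv2.imwrite(f'{prefix}_{i}.jpg', crop)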
modelscope/pipelines/cv/__init__.py
@@ -48,6 +48,7 @@ if TYPE_CHECKING:
     from .ocr_detection_pipeline import OCRDetectionPipeline
     from .ocr_recognition_pipeline import OCRRecognitionPipeline
     from .license_plate_detection_pipeline import LicensePlateDetectionPipeline
+    from .card_detection_correction_pipeline import CardDetectionCorrectionPipeline
     from .table_recognition_pipeline import TableRecognitionPipeline
     from .lineless_table_recognition_pipeline import LinelessTableRecognitionPipeline
     from .skin_retouching_pipeline import SkinRetouchingPipeline
@@ -165,6 +166,8 @@ else:
         'ocr_detection_pipeline': ['OCRDetectionPipeline'],
         'ocr_recognition_pipeline': ['OCRRecognitionPipeline'],
         'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
+        'card_detection_correction_pipeline':
+        ['CardDetectionCorrectionPipeline'],
         'table_recognition_pipeline': ['TableRecognitionPipeline'],
         'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
         'face_reconstruction_pipeline': ['FaceReconstructionPipeline'],
@@ -184,8 +187,9 @@ else:
         'facial_landmark_confidence_pipeline':
         ['FacialLandmarkConfidencePipeline'],
         'face_processing_base_pipeline': ['FaceProcessingBasePipeline'],
-        'face_attribute_recognition_pipeline':
-        ['FaceAttributeRecognitionPipeline'],
+        'face_attribute_recognition_pipeline': [
+            'FaceAttributeRecognitionPipeline'
+        ],
         'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
         'hand_static_pipeline': ['HandStaticPipeline'],
         'referring_video_object_segmentation_pipeline': [
modelscope/pipelines/cv/card_detection_correction_pipeline.py (new file, 208 lines)
@@ -0,0 +1,208 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os.path as osp
from typing import Any, Dict

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
    CardDetectionCorrectionModel
from modelscope.pipelines.cv.ocr_utils.table_process import (
    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
from modelscope.preprocessors import load_image
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.device import (create_device, device_placement,
                                     verify_device)
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.card_detection_correction,
    module_name=Pipelines.card_detection_correction)
class CardDetectionCorrection(Pipeline):
    r""" Card Detection and Correction Pipeline.

    Examples:

    >>> from modelscope.pipelines import pipeline

    >>> detector = pipeline(Tasks.card_detection_correction, model='damo/cv_resnet18_card_correction')
    >>> detector("https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/card_detection_correction.jpg")
    >>>   {
    >>>    "polygons": array([[ 60.562023, 110.682144, 688.57715, 77.34028, 720.2409,
    >>>        480.33508, 70.20054, 504.9171 ]], dtype=float32),
    >>>    "output_imgs": [array([
    >>>        [[[168, 176, 192],
    >>>          [165, 173, 188],
    >>>          [163, 172, 187],
    >>>          ...,
    >>>          [153, 153, 165],
    >>>          [153, 153, 165],
    >>>          [153, 153, 165]],
    >>>         [[187, 194, 210],
    >>>          [184, 192, 203],
    >>>          [183, 191, 199],
    >>>          ...,
    >>>          [168, 166, 186],
    >>>          [169, 166, 185],
    >>>          [169, 165, 184]],
    >>>         [[186, 193, 211],
    >>>          [183, 191, 205],
    >>>          [183, 192, 203],
    >>>          ...,
    >>>          [170, 167, 187],
    >>>          [171, 165, 186],
    >>>          [170, 164, 184]]]], dtype=uint8)}
    """

    def __init__(self,
                 model: str,
                 device: str = 'gpu',
                 device_map=None,
                 **kwargs):
        """
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
        logger.info(f'loading model from {model_path}')

        self.cfg = Config.from_file(config_path)
        self.K = self.cfg.K

        if device_map is not None:
            assert device == 'gpu', '`device` and `device_map` cannot be input at the same time!'
        self.device_map = device_map
        verify_device(device)
        self.device_name = device
        self.device = create_device(self.device_name)

        self.infer_model = CardDetectionCorrectionModel()
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'state_dict' in checkpoint:
            self.infer_model.load_state_dict(checkpoint['state_dict'])
        else:
            self.infer_model.load_state_dict(checkpoint)
        self.infer_model = self.infer_model.to(self.device)
        self.infer_model.eval()

    def preprocess(self, input: Input) -> Dict[str, Any]:
        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
        self.image = np.array(img)

        mean = np.array([0.408, 0.447, 0.470],
                        dtype=np.float32).reshape(1, 1, 3)
        std = np.array([0.289, 0.274, 0.278],
                       dtype=np.float32).reshape(1, 1, 3)
        height, width = img.shape[0:2]
        inp_height, inp_width = self.cfg.input_h, self.cfg.input_w
        c = np.array([width / 2., height / 2.], dtype=np.float32)
        s = max(height, width) * 1.0

        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
        resized_image = cv2.resize(img, (width, height))
        inp_image = cv2.warpAffine(
            resized_image,
            trans_input, (inp_width, inp_height),
            flags=cv2.INTER_LINEAR)
        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)

        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
                                                      inp_width)
        images = torch.from_numpy(images).to(self.device)
        meta = {
            'c': c,
            's': s,
            'input_height': inp_height,
            'input_width': inp_width,
            'out_height': inp_height // 4,
            'out_width': inp_width // 4
        }

        result = {'img': images, 'meta': meta}

        return result

    def distance(self, x1, y1, x2, y2):
        return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))

    def crop_image(self, img, position):
        x0, y0 = position[0][0], position[0][1]
        x1, y1 = position[1][0], position[1][1]
        x2, y2 = position[2][0], position[2][1]
        x3, y3 = position[3][0], position[3][1]

        img_width = self.distance((x0 + x3) / 2, (y0 + y3) / 2, (x1 + x2) / 2,
                                  (y1 + y2) / 2)
        img_height = self.distance((x0 + x1) / 2, (y0 + y1) / 2, (x2 + x3) / 2,
                                   (y2 + y3) / 2)

        corners_trans = np.zeros((4, 2), np.float32)
        corners_trans[0] = [0, 0]
        corners_trans[1] = [img_width, 0]
        corners_trans[2] = [img_width, img_height]
        corners_trans[3] = [0, img_height]

        transform = cv2.getPerspectiveTransform(position, corners_trans)
        dst = cv2.warpPerspective(img, transform,
                                  (int(img_width), int(img_height)))
        return dst

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        pred = self.infer_model(input['img'])
        return {'results': pred, 'meta': input['meta']}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        output = inputs['results'][0]
        meta = inputs['meta']
        hm = output['hm'].sigmoid_()
        wh = output['wh']
        reg = output['reg']
        angle_cls = output['cls'].sigmoid_()

        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
        angle_cls = decode_by_ind(
            angle_cls, inds, K=self.K).detach().cpu().numpy()
        bbox = bbox.detach().cpu().numpy()
        for i in range(bbox.shape[1]):
            bbox[0][i][9] = angle_cls[0][i]
        bbox = nms(bbox, 0.3)
        bbox = bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()],
                                 [meta['s']], meta['out_height'],
                                 meta['out_width'])

        res = []
        angle = []
        sub_imgs = []
        for idx, box in enumerate(bbox[0]):
            if box[8] > 0.3:
                angle.append(int(box[9]))
                res.append(box[0:8])
                sub_img = self.crop_image(self.image,
                                          res[-1].copy().reshape(4, 2))
                if angle[-1] == 1:
                    sub_img = cv2.rotate(sub_img, 2)
                if angle[-1] == 2:
                    sub_img = cv2.rotate(sub_img, 1)
                if angle[-1] == 3:
                    sub_img = cv2.rotate(sub_img, 0)
                sub_imgs.append(sub_img)

        result = {
            OutputKeys.POLYGONS: np.array(res),
            OutputKeys.OUTPUT_IMGS: np.array(sub_imgs)
        }
        return result
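The correction half of the pipeline is crop_image above: it averages opposite edge lengths of the detected quadrilateral to pick the output size, then warps the quad onto an upright rectangle. The same idea as a standalone sketch, independent of the pipeline class:

import cv2
import numpy as np


def rectify_quad(img: np.ndarray, quad: np.ndarray) -> np.ndarray:
    # quad: (4, 2) float32 corners ordered tl, tr, br, bl.
    tl, tr, br, bl = quad
    width = int((np.linalg.norm(tr - tl) + np.linalg.norm(br - bl)) / 2)
    height = int((np.linalg.norm(bl - tl) + np.linalg.norm(br - tr)) / 2)
    dst = np.array([[0, 0], [width, 0], [width, height], [0, height]],
                   dtype=np.float32)
    # Homography from the detected corners to the axis-aligned rectangle.
    transform = cv2.getPerspectiveTransform(quad.astype(np.float32), dst)
    return cv2.warpPerspective(img, transform, (width, height))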
modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py
@@ -91,10 +91,10 @@ class Bottleneck(nn.Module):
 
 class PoseResNet(nn.Module):
 
-    def __init__(self, block, layers, head_conv=64, **kwargs):
+    def __init__(self, block, layers, heads, head_conv=64, **kwargs):
         self.inplanes = 64
         self.deconv_with_bias = False
-        self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}
+        self.heads = heads
 
         super(PoseResNet, self).__init__()
        self.conv1 = nn.Conv2d(
@@ -270,6 +270,14 @@ resnet_spec = {
 
 
 def LicensePlateDet(num_layers=18):
+    heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2}
     block_class, layers = resnet_spec[num_layers]
-    model = PoseResNet(block_class, layers)
+    model = PoseResNet(block_class, layers, heads)
     return model
+
+
+def CardDetectionCorrectionModel(num_layers=18):
+    heads = {'hm': 1, 'cls': 4, 'ftype': 2, 'wh': 8, 'reg': 2}
+    block_class, layers = resnet_spec[num_layers]
+    model = PoseResNet(block_class, layers, heads)
+    return model
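The refactor threads a heads dict through PoseResNet instead of hardcoding one, so the license-plate and card models share the ResNet-18 backbone and differ only in output channels (ftype is 11 plate types vs. 2 card types). A hedged sketch of how such a dict typically turns into CenterNet-style prediction heads; the actual PoseResNet internals may differ:

import torch.nn as nn


def build_heads(in_channels: int, heads: dict,
                head_conv: int = 64) -> nn.ModuleDict:
    # One 3x3 conv -> ReLU -> 1x1 conv branch per named output, e.g.
    # {'hm': 1, 'cls': 4, 'ftype': 2, 'wh': 8, 'reg': 2} for cards.
    return nn.ModuleDict({
        name: nn.Sequential(
            nn.Conv2d(in_channels, head_conv, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_conv, out_channels, 1))
        for name, out_channels in heads.items()
    })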
modelscope/utils/constant.py
@@ -19,6 +19,7 @@ class CVTasks(object):
     table_recognition = 'table-recognition'
     lineless_table_recognition = 'lineless-table-recognition'
     license_plate_detection = 'license-plate-detection'
+    card_detection_correction = 'card-detection-correction'
 
     # human face body related
     animal_recognition = 'animal-recognition'
tests/pipelines/test_card_detection_correction.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os.path as osp
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level


class CardDetectionCorrectionTest(unittest.TestCase):

    def setUp(self) -> None:
        self.model_id = 'damo/cv_resnet18_card_correction'
        cache_path = snapshot_download(self.model_id)
        self.test_image = osp.join(cache_path, 'data/demo.jpg')
        self.task = Tasks.card_detection_correction

    def pipeline_inference(self, pipe: Pipeline, input_location: str):
        result = pipe(input_location)
        print('card detection results: ')
        print(result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        card_detection_correction = pipeline(
            Tasks.card_detection_correction, model=self.model_id)
        self.pipeline_inference(card_detection_correction, self.test_image)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        card_detection_correction = pipeline(Tasks.card_detection_correction)
        self.pipeline_inference(card_detection_correction, self.test_image)


if __name__ == '__main__':
    unittest.main()