From 6b09cb3d7a34493b4d28690eeb51c2f87f93d82d Mon Sep 17 00:00:00 2001 From: "rujiao.lrj" Date: Tue, 19 Sep 2023 19:20:19 +0800 Subject: [PATCH] add model for card correction Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14049168 --- modelscope/__init__.py | 48 ++-- modelscope/exporters/__init__.py | 12 +- modelscope/exporters/cv/__init__.py | 3 +- modelscope/exporters/nlp/__init__.py | 3 +- .../nlp/csanmt_for_translation_exporter.py | 3 +- modelscope/metainfo.py | 3 + modelscope/metrics/__init__.py | 37 ++-- modelscope/outputs/outputs.py | 7 +- modelscope/pipeline_inputs.py | 2 + modelscope/pipelines/cv/__init__.py | 8 +- .../cv/card_detection_correction_pipeline.py | 208 ++++++++++++++++++ .../cv/ocr_utils/model_resnet18_half.py | 14 +- modelscope/utils/constant.py | 1 + .../test_card_detection_correction.py | 39 ++++ 14 files changed, 333 insertions(+), 55 deletions(-) create mode 100644 modelscope/pipelines/cv/card_detection_correction_pipeline.py create mode 100644 tests/pipelines/test_card_detection_correction.py diff --git a/modelscope/__init__.py b/modelscope/__init__.py index ac362be1..5a2f470e 100644 --- a/modelscope/__init__.py +++ b/modelscope/__init__.py @@ -4,36 +4,38 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: - from .version import __release_datetime__, __version__ - from .trainers import EpochBasedTrainer, TrainingArgs, build_dataset_from_file - from .trainers import Hook, Priority - from .exporters import Exporter - from .exporters import TfModelExporter - from .exporters import TorchModelExporter + from .exporters import Exporter, TfModelExporter, TorchModelExporter from .hub.api import HubApi - from .hub.snapshot_download import snapshot_download + from .hub.check_model import check_local_model_is_latest, check_model_is_id from .hub.push_to_hub import push_to_hub, push_to_hub_async - from .hub.check_model import check_model_is_id, check_local_model_is_latest - from .metrics import AudioNoiseMetric, Metric, task_default_metrics, ImageColorEnhanceMetric, ImageDenoiseMetric, \ - ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric, SequenceClassificationMetric, \ - TextGenerationMetric, TokenClassificationMetric, VideoSummarizationMetric, MovieSceneSegmentationMetric, \ - AccuracyMetric, BleuMetric, ImageInpaintingMetric, ReferringVideoObjectSegmentationMetric, \ - VideoFrameInterpolationMetric, VideoStabilizationMetric, VideoSuperResolutionMetric, PplMetric, \ - ImageQualityAssessmentDegradationMetric, ImageQualityAssessmentMosMetric, TextRankingMetric, \ - LossMetric, ImageColorizationMetric, OCRRecognitionMetric + from .hub.snapshot_download import snapshot_download + from .metrics import ( + AccuracyMetric, AudioNoiseMetric, BleuMetric, ImageColorEnhanceMetric, + ImageColorizationMetric, ImageDenoiseMetric, ImageInpaintingMetric, + ImageInstanceSegmentationCOCOMetric, ImagePortraitEnhancementMetric, + ImageQualityAssessmentDegradationMetric, + ImageQualityAssessmentMosMetric, LossMetric, Metric, + MovieSceneSegmentationMetric, OCRRecognitionMetric, PplMetric, + ReferringVideoObjectSegmentationMetric, SequenceClassificationMetric, + TextGenerationMetric, TextRankingMetric, TokenClassificationMetric, + VideoFrameInterpolationMetric, VideoStabilizationMetric, + VideoSummarizationMetric, VideoSuperResolutionMetric, + task_default_metrics) from .models import Model, TorchModel - from .preprocessors import Preprocessor + from .msdatasets import MsDataset from .pipelines 
import Pipeline, pipeline - from .utils.hub import read_config, create_model_if_not_exist - from .utils.logger import get_logger + from .preprocessors import Preprocessor + from .trainers import (EpochBasedTrainer, Hook, Priority, TrainingArgs, + build_dataset_from_file) from .utils.constant import Tasks - from .utils.hf_util import AutoConfig, GenerationConfig - from .utils.hf_util import (AutoModel, AutoModelForCausalLM, + from .utils.hf_util import (AutoConfig, AutoModel, AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, - AutoModelForTokenClassification) - from .utils.hf_util import AutoTokenizer - from .msdatasets import MsDataset + AutoModelForTokenClassification, AutoTokenizer, + GenerationConfig) + from .utils.hub import create_model_if_not_exist, read_config + from .utils.logger import get_logger + from .version import __release_datetime__, __version__ else: _import_structure = { diff --git a/modelscope/exporters/__init__.py b/modelscope/exporters/__init__.py index e5a10a0d..7fc094ac 100644 --- a/modelscope/exporters/__init__.py +++ b/modelscope/exporters/__init__.py @@ -7,13 +7,13 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .base import Exporter from .builder import build_exporter - from .cv import CartoonTranslationExporter - from .nlp import CsanmtForTranslationExporter - from .tf_model_exporter import TfModelExporter - from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter - from .torch_model_exporter import TorchModelExporter - from .cv import FaceDetectionSCRFDExporter + from .cv import CartoonTranslationExporter, FaceDetectionSCRFDExporter from .multi_modal import StableDiffuisonExporter + from .nlp import (CsanmtForTranslationExporter, + SbertForSequenceClassificationExporter, + SbertForZeroShotClassificationExporter) + from .tf_model_exporter import TfModelExporter + from .torch_model_exporter import TorchModelExporter else: _import_structure = { 'base': ['Exporter'], diff --git a/modelscope/exporters/cv/__init__.py b/modelscope/exporters/cv/__init__.py index 67a406db..a8c65f2f 100644 --- a/modelscope/exporters/cv/__init__.py +++ b/modelscope/exporters/cv/__init__.py @@ -6,8 +6,9 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .cartoon_translation_exporter import CartoonTranslationExporter - from .object_detection_damoyolo_exporter import ObjectDetectionDamoyoloExporter from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter + from .object_detection_damoyolo_exporter import \ + ObjectDetectionDamoyoloExporter else: _import_structure = { 'cartoon_translation_exporter': ['CartoonTranslationExporter'], diff --git a/modelscope/exporters/nlp/__init__.py b/modelscope/exporters/nlp/__init__.py index 26df5775..4f9c1bc5 100644 --- a/modelscope/exporters/nlp/__init__.py +++ b/modelscope/exporters/nlp/__init__.py @@ -6,7 +6,8 @@ from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: from .csanmt_for_translation_exporter import CsanmtForTranslationExporter - from .model_for_token_classification_exporter import ModelForSequenceClassificationExporter + from .model_for_token_classification_exporter import \ + ModelForSequenceClassificationExporter from .sbert_for_sequence_classification_exporter import \ SbertForSequenceClassificationExporter from .sbert_for_zero_shot_classification_exporter import \ diff --git a/modelscope/exporters/nlp/csanmt_for_translation_exporter.py 
b/modelscope/exporters/nlp/csanmt_for_translation_exporter.py index 65b55b43..c7a584d8 100644 --- a/modelscope/exporters/nlp/csanmt_for_translation_exporter.py +++ b/modelscope/exporters/nlp/csanmt_for_translation_exporter.py @@ -28,7 +28,8 @@ class CsanmtForTranslationExporter(TfModelExporter): tf.disable_eager_execution() super().__init__(model) - from modelscope.pipelines.nlp.translation_pipeline import TranslationPipeline + from modelscope.pipelines.nlp.translation_pipeline import \ + TranslationPipeline self.pipeline = TranslationPipeline(self.model) def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]: diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 23ffdab1..750b7aa0 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -293,6 +293,7 @@ class Pipelines(object): table_recognition = 'dla34-table-recognition' lineless_table_recognition = 'lore-lineless-table-recognition' license_plate_detection = 'resnet18-license-plate-detection' + card_detection_correction = 'resnet18-card-detection-correction' action_recognition = 'TAdaConv_action-recognition' animal_recognition = 'resnet101-animal-recognition' general_recognition = 'resnet101-general-recognition' @@ -677,6 +678,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.license_plate_detection: (Pipelines.license_plate_detection, 'damo/cv_resnet18_license-plate-detection_damo'), + Tasks.card_detection_correction: (Pipelines.card_detection_correction, + 'damo/cv_resnet18_card_correction'), Tasks.fill_mask: (Pipelines.fill_mask, 'damo/nlp_veco_fill-mask-large'), Tasks.feature_extraction: (Pipelines.feature_extraction, 'damo/pert_feature-extraction_base-test'), diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index 6f5dfbde..75ccfcf9 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -4,34 +4,39 @@ from typing import TYPE_CHECKING from modelscope.utils.import_utils import LazyImportModule if TYPE_CHECKING: + from .accuracy_metric import AccuracyMetric from .audio_noise_metric import AudioNoiseMetric from .base import Metric + from .bleu_metric import BleuMetric from .builder import METRICS, build_metric, task_default_metrics from .image_color_enhance_metric import ImageColorEnhanceMetric + from .image_colorization_metric import ImageColorizationMetric from .image_denoise_metric import ImageDenoiseMetric + from .image_inpainting_metric import ImageInpaintingMetric from .image_instance_segmentation_metric import \ ImageInstanceSegmentationCOCOMetric - from .image_portrait_enhancement_metric import ImagePortraitEnhancementMetric + from .image_portrait_enhancement_metric import \ + ImagePortraitEnhancementMetric + from .image_quality_assessment_degradation_metric import \ + ImageQualityAssessmentDegradationMetric + from .image_quality_assessment_mos_metric import \ + ImageQualityAssessmentMosMetric + from .loss_metric import LossMetric + from .movie_scene_segmentation_metric import MovieSceneSegmentationMetric + from .ocr_recognition_metric import OCRRecognitionMetric + from .ppl_metric import PplMetric + from .referring_video_object_segmentation_metric import \ + ReferringVideoObjectSegmentationMetric from .sequence_classification_metric import SequenceClassificationMetric from .text_generation_metric import TextGenerationMetric + from .text_ranking_metric import TextRankingMetric from .token_classification_metric import TokenClassificationMetric - from .video_summarization_metric import VideoSummarizationMetric - from .movie_scene_segmentation_metric 
import MovieSceneSegmentationMetric - from .accuracy_metric import AccuracyMetric - from .bleu_metric import BleuMetric - from .image_inpainting_metric import ImageInpaintingMetric - from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric + from .translation_evaluation_metric import TranslationEvaluationMetric from .video_frame_interpolation_metric import VideoFrameInterpolationMetric from .video_stabilization_metric import VideoStabilizationMetric - from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric - from .ppl_metric import PplMetric - from .image_quality_assessment_degradation_metric import ImageQualityAssessmentDegradationMetric - from .image_quality_assessment_mos_metric import ImageQualityAssessmentMosMetric - from .text_ranking_metric import TextRankingMetric - from .loss_metric import LossMetric - from .image_colorization_metric import ImageColorizationMetric - from .ocr_recognition_metric import OCRRecognitionMetric - from .translation_evaluation_metric import TranslationEvaluationMetric + from .video_summarization_metric import VideoSummarizationMetric + from .video_super_resolution_metric.video_super_resolution_metric import \ + VideoSuperResolutionMetric else: _import_structure = { 'audio_noise_metric': ['AudioNoiseMetric'], diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index d9c6147e..f1ae964c 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -442,6 +442,8 @@ TASK_OUTPUTS = { Tasks.table_recognition: [OutputKeys.POLYGONS], Tasks.lineless_table_recognition: [OutputKeys.POLYGONS, OutputKeys.BOXES], Tasks.license_plate_detection: [OutputKeys.POLYGONS, OutputKeys.TEXT], + Tasks.card_detection_correction: + [OutputKeys.POLYGONS, OutputKeys.OUTPUT_IMGS], # ocr recognition result for single sample # { @@ -669,8 +671,9 @@ TASK_OUTPUTS = { # np.array # 2D array containing only 0, 1 # ] # } - Tasks.image_segmentation: - [OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS], + Tasks.image_segmentation: [ + OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS + ], # video panoptic segmentation result for single sample # "scores": [[0.8, 0.25, 0.05, 0.05], [0.9, 0.1, 0.05, 0.05]] diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index df0d4794..3be03682 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -110,6 +110,8 @@ TASK_INPUTS = { InputType.IMAGE, Tasks.license_plate_detection: InputType.IMAGE, + Tasks.card_detection_correction: + InputType.IMAGE, Tasks.lineless_table_recognition: InputType.IMAGE, Tasks.table_recognition: diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 15ebb80e..00fc21d8 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -48,6 +48,7 @@ if TYPE_CHECKING: from .ocr_detection_pipeline import OCRDetectionPipeline from .ocr_recognition_pipeline import OCRRecognitionPipeline from .license_plate_detection_pipeline import LicensePlateDetectionPipeline + from .card_detection_correction_pipeline import CardDetectionCorrectionPipeline from .table_recognition_pipeline import TableRecognitionPipeline from .lineless_table_recognition_pipeline import LinelessTableRecognitionPipeline from .skin_retouching_pipeline import SkinRetouchingPipeline @@ -165,6 +166,8 @@ else: 'ocr_detection_pipeline': ['OCRDetectionPipeline'], 'ocr_recognition_pipeline': ['OCRRecognitionPipeline'], 
         'license_plate_detection_pipeline': ['LicensePlateDetectionPipeline'],
+        'card_detection_correction_pipeline':
+        ['CardDetectionCorrectionPipeline'],
         'table_recognition_pipeline': ['TableRecognitionPipeline'],
         'skin_retouching_pipeline': ['SkinRetouchingPipeline'],
         'face_reconstruction_pipeline': ['FaceReconstructionPipeline'],
@@ -184,8 +187,9 @@ else:
         'facial_landmark_confidence_pipeline':
         ['FacialLandmarkConfidencePipeline'],
         'face_processing_base_pipeline': ['FaceProcessingBasePipeline'],
-        'face_attribute_recognition_pipeline':
-        ['FaceAttributeRecognitionPipeline'],
+        'face_attribute_recognition_pipeline': [
+            'FaceAttributeRecognitionPipeline'
+        ],
         'mtcnn_face_detection_pipeline': ['MtcnnFaceDetectionPipeline'],
         'hand_static_pipeline': ['HandStaticPipeline'],
         'referring_video_object_segmentation_pipeline': [
diff --git a/modelscope/pipelines/cv/card_detection_correction_pipeline.py b/modelscope/pipelines/cv/card_detection_correction_pipeline.py
new file mode 100644
index 00000000..dac174de
--- /dev/null
+++ b/modelscope/pipelines/cv/card_detection_correction_pipeline.py
@@ -0,0 +1,208 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import math
+import os.path as osp
+from typing import Any, Dict
+
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.pipelines.cv.ocr_utils.model_resnet18_half import \
+    CardDetectionCorrectionModel
+from modelscope.pipelines.cv.ocr_utils.table_process import (
+    bbox_decode, bbox_post_process, decode_by_ind, get_affine_transform, nms)
+from modelscope.preprocessors import load_image
+from modelscope.preprocessors.image import LoadImage
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.device import (create_device, device_placement,
+                                     verify_device)
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.card_detection_correction,
+    module_name=Pipelines.card_detection_correction)
+class CardDetectionCorrectionPipeline(Pipeline):
+    r""" Card Detection and Correction Pipeline.
+
+    Examples:
+
+    >>> from modelscope.pipelines import pipeline
+
+    >>> detector = pipeline(Tasks.card_detection_correction, model='damo/cv_resnet18_card_correction')
+    >>> detector("https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/card_detection_correction.jpg")
+    >>> {
+    >>>    "polygons": array([[ 60.562023, 110.682144, 688.57715, 77.34028, 720.2409,
+    >>>        480.33508, 70.20054, 504.9171 ]], dtype=float32),
+    >>>    "output_imgs": [array([
+    >>>        [[[168, 176, 192],
+    >>>        [165, 173, 188],
+    >>>        [163, 172, 187],
+    >>>        ...,
+    >>>        [153, 153, 165],
+    >>>        [153, 153, 165],
+    >>>        [153, 153, 165]],
+    >>>        [[187, 194, 210],
+    >>>        [184, 192, 203],
+    >>>        [183, 191, 199],
+    >>>        ...,
+    >>>        [168, 166, 186],
+    >>>        [169, 166, 185],
+    >>>        [169, 165, 184]],
+    >>>        [[186, 193, 211],
+    >>>        [183, 191, 205],
+    >>>        [183, 192, 203],
+    >>>        ...,
+    >>>        [170, 167, 187],
+    >>>        [171, 165, 186],
+    >>>        [170, 164, 184]]]], dtype=uint8)}
+    """
+
+    def __init__(self,
+                 model: str,
+                 device: str = 'gpu',
+                 device_map=None,
+                 **kwargs):
+        """
+        Args:
+            model: Model id on the ModelScope hub.
+        """
+        super().__init__(model=model, **kwargs)
+        model_path = osp.join(self.model, ModelFile.TORCH_MODEL_FILE)
+        config_path = osp.join(self.model, ModelFile.CONFIGURATION)
+        logger.info(f'loading model from {model_path}')
+
+        self.cfg = Config.from_file(config_path)
+        self.K = self.cfg.K
+
+        if device_map is not None:
+            assert device == 'gpu', '`device` and `device_map` cannot be input at the same time!'
+        self.device_map = device_map
+        verify_device(device)
+        self.device_name = device
+        self.device = create_device(self.device_name)
+
+        self.infer_model = CardDetectionCorrectionModel()
+        checkpoint = torch.load(model_path, map_location=self.device)
+        if 'state_dict' in checkpoint:
+            self.infer_model.load_state_dict(checkpoint['state_dict'])
+        else:
+            self.infer_model.load_state_dict(checkpoint)
+        self.infer_model = self.infer_model.to(self.device)
+        self.infer_model.eval()
+
+    def preprocess(self, input: Input) -> Dict[str, Any]:
+        img = LoadImage.convert_to_ndarray(input)[:, :, ::-1]
+        self.image = np.array(img)
+
+        mean = np.array([0.408, 0.447, 0.470],
+                        dtype=np.float32).reshape(1, 1, 3)
+        std = np.array([0.289, 0.274, 0.278],
+                       dtype=np.float32).reshape(1, 1, 3)
+        height, width = img.shape[0:2]
+        inp_height, inp_width = self.cfg.input_h, self.cfg.input_w
+        c = np.array([width / 2., height / 2.], dtype=np.float32)
+        s = max(height, width) * 1.0
+
+        trans_input = get_affine_transform(c, s, 0, [inp_width, inp_height])
+        resized_image = cv2.resize(img, (width, height))
+        inp_image = cv2.warpAffine(
+            resized_image,
+            trans_input, (inp_width, inp_height),
+            flags=cv2.INTER_LINEAR)
+        inp_image = ((inp_image / 255. - mean) / std).astype(np.float32)
+
+        images = inp_image.transpose(2, 0, 1).reshape(1, 3, inp_height,
+                                                      inp_width)
+        images = torch.from_numpy(images).to(self.device)
+        meta = {
+            'c': c,
+            's': s,
+            'input_height': inp_height,
+            'input_width': inp_width,
+            'out_height': inp_height // 4,
+            'out_width': inp_width // 4
+        }
+
+        result = {'img': images, 'meta': meta}
+
+        return result
+
+    def distance(self, x1, y1, x2, y2):
+        return math.sqrt(pow(x1 - x2, 2) + pow(y1 - y2, 2))
+
+    def crop_image(self, img, position):
+        x0, y0 = position[0][0], position[0][1]
+        x1, y1 = position[1][0], position[1][1]
+        x2, y2 = position[2][0], position[2][1]
+        x3, y3 = position[3][0], position[3][1]
+
+        img_width = self.distance((x0 + x3) / 2, (y0 + y3) / 2, (x1 + x2) / 2,
+                                  (y1 + y2) / 2)
+        img_height = self.distance((x0 + x1) / 2, (y0 + y1) / 2, (x2 + x3) / 2,
+                                   (y2 + y3) / 2)
+
+        corners_trans = np.zeros((4, 2), np.float32)
+        corners_trans[0] = [0, 0]
+        corners_trans[1] = [img_width, 0]
+        corners_trans[2] = [img_width, img_height]
+        corners_trans[3] = [0, img_height]
+
+        transform = cv2.getPerspectiveTransform(position, corners_trans)
+        dst = cv2.warpPerspective(img, transform,
+                                  (int(img_width), int(img_height)))
+        return dst
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        pred = self.infer_model(input['img'])
+        return {'results': pred, 'meta': input['meta']}
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        output = inputs['results'][0]
+        meta = inputs['meta']
+        hm = output['hm'].sigmoid_()
+        wh = output['wh']
+        reg = output['reg']
+        angle_cls = output['cls'].sigmoid_()
+
+        bbox, inds = bbox_decode(hm, wh, reg=reg, K=self.K)
+        angle_cls = decode_by_ind(
+            angle_cls, inds, K=self.K).detach().cpu().numpy()
+        bbox = bbox.detach().cpu().numpy()
+        for i in range(bbox.shape[1]):
+            bbox[0][i][9] = angle_cls[0][i]
+        bbox = nms(bbox, 0.3)
+        bbox = 
bbox_post_process(bbox.copy(), [meta['c'].cpu().numpy()], + [meta['s']], meta['out_height'], + meta['out_width']) + + res = [] + angle = [] + sub_imgs = [] + for idx, box in enumerate(bbox[0]): + if box[8] > 0.3: + angle.append(int(box[9])) + res.append(box[0:8]) + sub_img = self.crop_image(self.image, + res[-1].copy().reshape(4, 2)) + if angle[-1] == 1: + sub_img = cv2.rotate(sub_img, 2) + if angle[-1] == 2: + sub_img = cv2.rotate(sub_img, 1) + if angle[-1] == 3: + sub_img = cv2.rotate(sub_img, 0) + sub_imgs.append(sub_img) + + result = { + OutputKeys.POLYGONS: np.array(res), + OutputKeys.OUTPUT_IMGS: np.array(sub_imgs) + } + return result diff --git a/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py b/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py index 2d771eb4..7a732674 100644 --- a/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py +++ b/modelscope/pipelines/cv/ocr_utils/model_resnet18_half.py @@ -91,10 +91,10 @@ class Bottleneck(nn.Module): class PoseResNet(nn.Module): - def __init__(self, block, layers, head_conv=64, **kwargs): + def __init__(self, block, layers, heads, head_conv=64, **kwargs): self.inplanes = 64 self.deconv_with_bias = False - self.heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2} + self.heads = heads super(PoseResNet, self).__init__() self.conv1 = nn.Conv2d( @@ -270,6 +270,14 @@ resnet_spec = { def LicensePlateDet(num_layers=18): + heads = {'hm': 1, 'cls': 4, 'ftype': 11, 'wh': 8, 'reg': 2} block_class, layers = resnet_spec[num_layers] - model = PoseResNet(block_class, layers) + model = PoseResNet(block_class, layers, heads) + return model + + +def CardDetectionCorrectionModel(num_layers=18): + heads = {'hm': 1, 'cls': 4, 'ftype': 2, 'wh': 8, 'reg': 2} + block_class, layers = resnet_spec[num_layers] + model = PoseResNet(block_class, layers, heads) return model diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 68801e79..fb315f4b 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -19,6 +19,7 @@ class CVTasks(object): table_recognition = 'table-recognition' lineless_table_recognition = 'lineless-table-recognition' license_plate_detection = 'license-plate-detection' + card_detection_correction = 'card-detection-correction' # human face body related animal_recognition = 'animal-recognition' diff --git a/tests/pipelines/test_card_detection_correction.py b/tests/pipelines/test_card_detection_correction.py new file mode 100644 index 00000000..cd979d4f --- /dev/null +++ b/tests/pipelines/test_card_detection_correction.py @@ -0,0 +1,39 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+
+import os.path as osp
+import unittest
+
+from modelscope.hub.snapshot_download import snapshot_download
+from modelscope.pipelines import pipeline
+from modelscope.pipelines.base import Pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.test_utils import test_level
+
+
+class CardDetectionCorrectionTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.model_id = 'damo/cv_resnet18_card_correction'
+        cache_path = snapshot_download(self.model_id)
+        self.test_image = osp.join(cache_path, 'data/demo.jpg')
+        self.task = Tasks.card_detection_correction
+
+    def pipeline_inference(self, pipe: Pipeline, input_location: str):
+        result = pipe(input_location)
+        print('card detection correction results: ')
+        print(result)
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_run_with_model_from_modelhub(self):
+        card_detection_correction = pipeline(
+            Tasks.card_detection_correction, model=self.model_id)
+        self.pipeline_inference(card_detection_correction, self.test_image)
+
+    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    def test_run_modelhub_default_model(self):
+        card_detection_correction = pipeline(Tasks.card_detection_correction)
+        self.pipeline_inference(card_detection_correction, self.test_image)
+
+
+if __name__ == '__main__':
+    unittest.main()
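
For reference, a minimal usage sketch of the new pipeline once this patch is applied (the model id comes from DEFAULT_MODEL_FOR_PIPELINE above; 'input.jpg' is a placeholder path, not part of the patch):

# Minimal sketch, assuming the patch is applied and the hub model is available.
import cv2

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

detector = pipeline(
    Tasks.card_detection_correction,
    model='damo/cv_resnet18_card_correction')
result = detector('input.jpg')  # placeholder input image path

# 'polygons' holds one 8-value quadrilateral per detected card;
# 'output_imgs' holds the cropped, rotation-corrected card images (BGR).
for idx, card in enumerate(result['output_imgs']):
    cv2.imwrite(f'card_{idx}.jpg', card)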