mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933] add content check model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11473049
This commit is contained in:
3
data/test/images/content_check.jpg
Normal file
3
data/test/images/content_check.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7d486900ecca027d70453322d0f22de4b36f9534a324b8b1cda3ea86bb72bac6
|
||||
size 353096
|
||||
@@ -92,6 +92,7 @@ class Models(object):
|
||||
image_probing_model = 'image-probing-model'
|
||||
defrcn = 'defrcn'
|
||||
image_face_fusion = 'image-face-fusion'
|
||||
content_check = 'content-check'
|
||||
open_vocabulary_detection_vild = 'open-vocabulary-detection-vild'
|
||||
ecbsr = 'ecbsr'
|
||||
msrresnet_lite = 'msrresnet-lite'
|
||||
@@ -298,6 +299,7 @@ class Pipelines(object):
|
||||
face_recognition_onnx_fm = 'manual-face-recognition-frfm'
|
||||
arc_face_recognition = 'ir50-face-recognition-arcface'
|
||||
mask_face_recognition = 'resnet-face-recognition-facemask'
|
||||
content_check = 'resnet50-image-classification-cc'
|
||||
image_instance_segmentation = 'cascade-mask-rcnn-swin-image-instance-segmentation'
|
||||
maskdino_instance_segmentation = 'maskdino-swin-image-instance-segmentation'
|
||||
image2image_translation = 'image-to-image-translation'
|
||||
@@ -596,6 +598,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_hrnetw18_hand-pose-keypoints_coco-wholebody'),
|
||||
Tasks.card_detection: (Pipelines.card_detection,
|
||||
'damo/cv_resnet_carddetection_scrfd34gkps'),
|
||||
Tasks.content_check: (Pipelines.content_check,
|
||||
'damo/cv_resnet50_content-check_cc'),
|
||||
Tasks.face_detection:
|
||||
(Pipelines.mog_face_detection,
|
||||
'damo/cv_resnet101_face-detection_cvpr22papermogface'),
|
||||
|
||||
@@ -5,10 +5,12 @@ from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .mmcls_model import ClassificationModel
|
||||
from .resnet50_cc import ContentCheckBackbone
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'mmcls_model': ['ClassificationModel'],
|
||||
'resnet50_cc': ['ContentCheckBackbone'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
50
modelscope/models/cv/image_classification/resnet50_cc.py
Normal file
50
modelscope/models/cv/image_classification/resnet50_cc.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from math import lgamma
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import models
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import MODELS
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.image_classification, Models.content_check)
|
||||
class ContentCheckBackbone(TorchModel):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(ContentCheckBackbone, self).__init__()
|
||||
cc_model = models.resnet50()
|
||||
cc_model.fc = nn.Sequential(
|
||||
nn.Linear(2048, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(512, 10),
|
||||
)
|
||||
self.model = cc_model
|
||||
|
||||
def forward(self, img):
|
||||
x = self.model(img)
|
||||
return x
|
||||
|
||||
@classmethod
|
||||
def _instantiate(cls, **kwargs):
|
||||
model_file = kwargs.get('model_name', ModelFile.TORCH_MODEL_FILE)
|
||||
ckpt_path = os.path.join(kwargs['model_dir'], model_file)
|
||||
logger.info(f'loading model from {ckpt_path}')
|
||||
model_dir = kwargs.pop('model_dir')
|
||||
model = cls(**kwargs)
|
||||
ckpt_path = os.path.join(model_dir, model_file)
|
||||
load_dict = torch.load(ckpt_path, map_location='cpu')
|
||||
new_dict = {}
|
||||
for k, v in load_dict.items():
|
||||
new_dict['model.' + k] = v
|
||||
model.load_state_dict(new_dict)
|
||||
return model
|
||||
@@ -147,6 +147,12 @@ TASK_OUTPUTS = {
|
||||
Tasks.card_detection:
|
||||
[OutputKeys.SCORES, OutputKeys.BOXES, OutputKeys.KEYPOINTS],
|
||||
|
||||
# content check result for single sample
|
||||
# {
|
||||
# "scores": [0.9] # non sexy probability
|
||||
# }
|
||||
Tasks.content_check: [OutputKeys.SCORES],
|
||||
|
||||
# image driving perception result for single sample
|
||||
# {
|
||||
# "boxes": [
|
||||
|
||||
74
modelscope/pipelines/cv/content_check_pipeline.py
Normal file
74
modelscope/pipelines/cv/content_check_pipeline.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os.path as osp
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_classification, module_name=Pipelines.content_check)
|
||||
class ContentCheckPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a content check pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
Example:
|
||||
ContentCheckPipeline can judge whether the picture is pornographic
|
||||
|
||||
```python
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> cc_func = pipeline('image_classification', 'damo/cv_resnet50_image-classification_cc')
|
||||
>>> cc_func("https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/content_check.jpg")
|
||||
{'scores': [0.2789826989173889], 'labels': 'pornographic'}
|
||||
```
|
||||
"""
|
||||
|
||||
# content check model
|
||||
super().__init__(model=model, **kwargs)
|
||||
self.test_transforms = transforms.Compose([
|
||||
transforms.Resize(224),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
logger.info('content check model loaded!')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
img = LoadImage.convert_to_img(input)
|
||||
img = self.test_transforms(img).float()
|
||||
result = {}
|
||||
result['img'] = img
|
||||
return result
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
img = input['img'].unsqueeze(0)
|
||||
result = self.model(img)
|
||||
score = [1 - F.softmax(result[:, :5])[0][-1].tolist()]
|
||||
if score[0] < 0.5:
|
||||
label = 'pornographic'
|
||||
else:
|
||||
label = 'normal'
|
||||
return {OutputKeys.SCORES: score, OutputKeys.LABELS: label}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return inputs
|
||||
@@ -131,8 +131,13 @@ class CVTasks(object):
|
||||
|
||||
# domain specific object detection
|
||||
domain_specific_object_detection = 'domain-specific-object-detection'
|
||||
# 3d reconstruction
|
||||
|
||||
# content check
|
||||
content_check = 'content-check'
|
||||
|
||||
# 3d face reconstruction
|
||||
face_reconstruction = 'face-reconstruction'
|
||||
|
||||
# image quality assessment mos
|
||||
image_quality_assessment_mos = 'image-quality-assessment-mos'
|
||||
# motion generation
|
||||
|
||||
29
tests/pipelines/test_content_check.py
Normal file
29
tests/pipelines/test_content_check.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class ContentCheckTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = Tasks.image_classification
|
||||
self.model_id = 'damo/cv_resnet50_image-classification_cc'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_run(self):
|
||||
content_check_func = pipeline(self.task, model=self.model_id)
|
||||
result = content_check_func('data/test/images/content_check.jpg')
|
||||
print(result)
|
||||
|
||||
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
|
||||
def test_demo_compatibility(self):
|
||||
self.compatibility_check()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user