[to #42322933] Add image-classification-imagenet and image-classification-dailylife pipelines

*Add image-classification-imagenet and image-classification-dailylife pipelines
*Add models.cv.mmcls_model.ClassificationModel as a wrapper class for mmcls
This commit is contained in:
zhanning.gzn
2022-07-27 22:29:13 +08:00
parent 087e684da5
commit 81718bd643
11 changed files with 218 additions and 3 deletions

1
.gitattributes vendored
View File

@@ -2,3 +2,4 @@
*.jpg filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
*.JPEG filter=lfs diff=lfs merge=lfs -text

3
data/test/images/bird.JPEG Executable file
View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:19fb781a44aec9349a8b73850e53b7eb9b0623d54ebd0cd8577c13bf463b5004
size 74237

View File

@@ -10,6 +10,7 @@ class Models(object):
Model name should only contain model info but not task info.
"""
# vision models
classification_model = 'ClassificationModel'
nafnet = 'nafnet'
csrnet = 'csrnet'
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
@@ -66,6 +67,8 @@ class Pipelines(object):
action_recognition = 'TAdaConv_action-recognition'
animal_recognation = 'resnet101-animal_recog'
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
general_image_classification = 'vit-base_image-classification_ImageNet-labels'
daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
image_color_enhance = 'csrnet-image-color-enhance'
virtual_tryon = 'virtual_tryon'
image_colorization = 'unet-image-colorization'

View File

@@ -1,5 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from . import (action_recognition, animal_recognition, cartoon,
cmdssl_video_embedding, face_generation, image_color_enhance,
image_colorization, image_denoise, image_instance_segmentation,
super_resolution, virual_tryon)
cmdssl_video_embedding, face_generation, image_classification,
image_color_enhance, image_colorization, image_denoise,
image_instance_segmentation, super_resolution, virual_tryon)

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
# Under TYPE_CHECKING the real import is visible, so static analysers can
# resolve ClassificationModel; at runtime the else-branch is taken instead.
if TYPE_CHECKING:
    from .mmcls_model import ClassificationModel
else:
    # Maps each submodule name to the public names it provides.
    _import_structure = {
        'mmcls_model': ['ClassificationModel'],
    }

    import sys

    # Replace this package's entry in sys.modules with a LazyImportModule
    # built from the mapping above.
    # NOTE(review): presumably this defers importing `mmcls_model` until one
    # of its names is first accessed -- confirm against LazyImportModule.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,46 @@
import os
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
@MODELS.register_module(
    Tasks.image_classification_imagenet,
    module_name=Models.classification_model)
@MODELS.register_module(
    Tasks.image_classification_dailylife,
    module_name=Models.classification_model)
class ClassificationModel(TorchModel):
    """Wrapper that builds an mmcls classifier from a local model directory.

    The directory is read for an mmcv-style ``config.py`` and, when present,
    a ``checkpoints.pth`` weight file. The class is registered under both the
    ImageNet and the Dailylife image-classification tasks.
    """

    def __init__(self, model_dir: str):
        """Build the classifier described by ``<model_dir>/config.py``.

        Args:
            model_dir: Local directory holding ``config.py`` and, optionally,
                ``checkpoints.pth``.
        """
        # Imported here rather than at module level so that importing this
        # module does not itself require mmcv / mmcls.
        import mmcv
        from mmcls.models import build_classifier

        super().__init__(model_dir)

        config = os.path.join(model_dir, 'config.py')
        cfg = mmcv.Config.fromfile(config)
        # Weights are loaded from the local checkpoint further down, so the
        # config's own `pretrained` entry is cleared before building.
        cfg.model.pretrained = None
        self.cls_model = build_classifier(cfg.model)

        self.cfg = cfg
        self.ms_model_dir = model_dir

        self.load_pretrained_checkpoint()

    def forward(self, Inputs):
        """Run the wrapped classifier.

        Args:
            Inputs: Mapping that is unpacked as keyword arguments into the
                underlying mmcls model.

        Returns:
            Whatever the underlying mmcls model returns for these arguments.
        """
        return self.cls_model(**Inputs)

    def load_pretrained_checkpoint(self):
        """Load ``checkpoints.pth`` from the model directory, if it exists.

        When the checkpoint's ``meta`` section carries ``CLASSES``, the class
        names are stored on both the wrapped classifier and this wrapper.
        If the checkpoint file is missing, nothing is loaded and the model
        keeps the weights it was built with.
        """
        # Import the function from its submodule explicitly: a bare
        # `import mmcv` does not guarantee that `mmcv.runner` is bound, so
        # `mmcv.runner.load_checkpoint` would only work if some other module
        # happened to import the submodule first.
        from mmcv.runner import load_checkpoint

        checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth')
        if not os.path.exists(checkpoint_path):
            return

        checkpoint = load_checkpoint(
            self.cls_model, checkpoint_path, map_location='cpu')
        if 'CLASSES' in checkpoint.get('meta', {}):
            self.cls_model.CLASSES = checkpoint['meta']['CLASSES']
            # Only mirror CLASSES onto the wrapper when the checkpoint
            # actually provided it; the classifier is not known to define
            # the attribute otherwise.
            self.CLASSES = self.cls_model.CLASSES

View File

@@ -94,6 +94,12 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_gan_face-image-generation'),
Tasks.image_super_resolution: (Pipelines.image_super_resolution,
'damo/cv_rrdb_image-super-resolution'),
Tasks.image_classification_imagenet:
(Pipelines.general_image_classification,
'damo/cv_vit-base_image-classification_ImageNet-labels'),
Tasks.image_classification_dailylife:
(Pipelines.daily_image_classification,
'damo/cv_vit-base_image-classification_Dailylife-labels'),
}

View File

@@ -7,6 +7,7 @@ if TYPE_CHECKING:
from .action_recognition_pipeline import ActionRecognitionPipeline
from .animal_recog_pipeline import AnimalRecogPipeline
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
from .image_classification_pipeline import GeneralImageClassificationPipeline
from .face_image_generation_pipeline import FaceImageGenerationPipeline
from .image_cartoon_pipeline import ImageCartoonPipeline
from .image_denoise_pipeline import ImageDenoisePipeline
@@ -23,6 +24,8 @@ else:
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
'animal_recog_pipeline': ['AnimalRecogPipeline'],
'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'],
'image_classification_pipeline':
['GeneralImageClassificationPipeline'],
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
'image_colorization_pipeline': ['ImageColorizationPipeline'],

View File

@@ -0,0 +1,87 @@
from typing import Any, Dict
import cv2
import numpy as np
import PIL
import torch
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input
from modelscope.preprocessors import load_image
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from ..base import Pipeline
from ..builder import PIPELINES
# Module-level logger obtained from modelscope's get_logger helper.
logger = get_logger()
@PIPELINES.register_module(
    Tasks.image_classification_imagenet,
    module_name=Pipelines.general_image_classification)
@PIPELINES.register_module(
    Tasks.image_classification_dailylife,
    module_name=Pipelines.daily_image_classification)
class GeneralImageClassificationPipeline(Pipeline):
    """Single-image classification pipeline backed by an mmcls model.

    Registered for both the ImageNet-label and the Dailylife-label
    image-classification tasks.
    """

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create an image classification pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)
        logger.info('load model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load one image and run it through the model's test data pipeline.

        Args:
            input: A string accepted by ``load_image``, a ``PIL.Image.Image``,
                or a ``np.ndarray``. A 2-D array is treated as grayscale.
                NOTE(review): array input is assumed to be in OpenCV's BGR
                channel order -- confirm with callers.

        Returns:
            The collated batch (one sample) produced by the mmcls test
            pipeline, moved to the model's device when the model is on CUDA.

        Raises:
            TypeError: If ``input`` is none of the supported types.
        """
        from mmcls.datasets.pipelines import Compose
        from mmcv.parallel import collate, scatter

        if isinstance(input, str):
            img = np.array(load_image(input))
        elif isinstance(input, PIL.Image.Image):
            img = np.array(input.convert('RGB'))
        elif isinstance(input, np.ndarray):
            bgr = input
            if len(bgr.shape) == 2:
                # Expand grayscale to three channels first; otherwise the
                # channel reversal below would fail on a 2-D array.
                bgr = cv2.cvtColor(bgr, cv2.COLOR_GRAY2BGR)
            img = bgr[:, :, ::-1]  # in rgb order
        else:
            raise TypeError(f'input should be either str, PIL.Image,'
                            f' np.array, but got {type(input)}')

        mmcls_cfg = self.model.cfg
        # build the data pipeline
        # The image is already decoded above, so a leading file-loading step
        # is dropped. This edits the model's config in place; on later calls
        # the first step is no longer 'LoadImageFromFile' and nothing is
        # removed.
        if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
            mmcls_cfg.data.test.pipeline.pop(0)
        data = dict(img=img)
        test_pipeline = Compose(mmcls_cfg.data.test.pipeline)
        data = test_pipeline(data)
        data = collate([data], samples_per_gpu=1)
        if next(self.model.parameters()).is_cuda:
            # scatter to specified GPU
            data = scatter(data, [next(self.model.parameters()).device])[0]

        return data

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model without gradient tracking.

        Args:
            input: Batch returned by ``preprocess``; ``return_loss`` is set to
                ``False`` on it before the model is called.

        Returns:
            A dict with the model output under the key ``'scores'``.
        """
        with torch.no_grad():
            input['return_loss'] = False
            scores = self.model(input)

        return {'scores': scores}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Turn raw scores into the top-1 score and class name.

        Args:
            inputs: Dict returned by ``forward``.

        Returns:
            A dict with one-element lists under ``OutputKeys.SCORES`` and
            ``OutputKeys.LABELS``.
        """
        scores = inputs['scores']
        # `preprocess` collates exactly one sample, so only the first row of
        # the per-row max / argmax is used.
        pred_score = np.max(scores, axis=1)[0]
        pred_label = np.argmax(scores, axis=1)[0]
        result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
        result['pred_class'] = self.model.CLASSES[result['pred_label']]

        outputs = {
            OutputKeys.SCORES: [result['pred_score']],
            OutputKeys.LABELS: [result['pred_class']]
        }
        return outputs

View File

@@ -34,6 +34,8 @@ class CVTasks(object):
face_image_generation = 'face-image-generation'
image_super_resolution = 'image-super-resolution'
style_transfer = 'style-transfer'
image_classification_imagenet = 'image-classification-imagenet'
image_classification_dailylife = 'image-classification-dailylife'
class NLPTasks(object):

View File

@@ -0,0 +1,42 @@
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class GeneralImageClassificationTest(unittest.TestCase):
    """Smoke tests for the two image-classification pipelines.

    Each test builds a pipeline, classifies the bundled bird image and prints
    the result; nothing is asserted about the prediction itself.
    """

    def _classify_and_print(self, task, **pipeline_kwargs):
        """Build a pipeline for `task`, run it on the bird image, print it.

        Extra keyword arguments are forwarded to `pipeline` untouched, so a
        test that passes no `model` relies on the task's default model.
        """
        classifier = pipeline(task, **pipeline_kwargs)
        prediction = classifier('data/test/images/bird.JPEG')
        print(prediction)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_ImageNet(self):
        self._classify_and_print(
            Tasks.image_classification_imagenet,
            model='damo/cv_vit-base_image-classification_ImageNet-labels')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_Dailylife(self):
        self._classify_and_print(
            Tasks.image_classification_dailylife,
            model='damo/cv_vit-base_image-classification_Dailylife-labels')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_ImageNet_default_task(self):
        self._classify_and_print(Tasks.image_classification_imagenet)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_Dailylife_default_task(self):
        self._classify_and_print(Tasks.image_classification_dailylife)
# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()