mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #42322933] Add image-classification-imagenet and image-classification-dailylife pipelines
*Add image-classification-imagenet and image-classification-dailylife pipelines *Add models.cv.mmcls_model.ClassificaitonModel as a wrapper class for mmcls
This commit is contained in:
1
.gitattributes
vendored
1
.gitattributes
vendored
@@ -2,3 +2,4 @@
|
||||
*.jpg filter=lfs diff=lfs merge=lfs -text
|
||||
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
||||
*.wav filter=lfs diff=lfs merge=lfs -text
|
||||
*.JPEG filter=lfs diff=lfs merge=lfs -text
|
||||
|
||||
3
data/test/images/bird.JPEG
Executable file
3
data/test/images/bird.JPEG
Executable file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:19fb781a44aec9349a8b73850e53b7eb9b0623d54ebd0cd8577c13bf463b5004
|
||||
size 74237
|
||||
@@ -10,6 +10,7 @@ class Models(object):
|
||||
Model name should only contain model info but not task info.
|
||||
"""
|
||||
# vision models
|
||||
classification_model = 'ClassificationModel'
|
||||
nafnet = 'nafnet'
|
||||
csrnet = 'csrnet'
|
||||
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
|
||||
@@ -66,6 +67,8 @@ class Pipelines(object):
|
||||
action_recognition = 'TAdaConv_action-recognition'
|
||||
animal_recognation = 'resnet101-animal_recog'
|
||||
cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding'
|
||||
general_image_classification = 'vit-base_image-classification_ImageNet-labels'
|
||||
daily_image_classification = 'vit-base_image-classification_Dailylife-labels'
|
||||
image_color_enhance = 'csrnet-image-color-enhance'
|
||||
virtual_tryon = 'virtual_tryon'
|
||||
image_colorization = 'unet-image-colorization'
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from . import (action_recognition, animal_recognition, cartoon,
|
||||
cmdssl_video_embedding, face_generation, image_color_enhance,
|
||||
image_colorization, image_denoise, image_instance_segmentation,
|
||||
super_resolution, virual_tryon)
|
||||
cmdssl_video_embedding, face_generation, image_classification,
|
||||
image_color_enhance, image_colorization, image_denoise,
|
||||
image_instance_segmentation, super_resolution, virual_tryon)
|
||||
|
||||
22
modelscope/models/cv/image_classification/__init__.py
Normal file
22
modelscope/models/cv/image_classification/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .mmcls_model import ClassificationModel
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'mmcls_model': ['ClassificationModel'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
46
modelscope/models/cv/image_classification/mmcls_model.py
Normal file
46
modelscope/models/cv/image_classification/mmcls_model.py
Normal file
@@ -0,0 +1,46 @@
|
||||
import os
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_classification_imagenet,
|
||||
module_name=Models.classification_model)
|
||||
@MODELS.register_module(
|
||||
Tasks.image_classification_dailylife,
|
||||
module_name=Models.classification_model)
|
||||
class ClassificationModel(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str):
|
||||
import mmcv
|
||||
from mmcls.models import build_classifier
|
||||
|
||||
super().__init__(model_dir)
|
||||
|
||||
config = os.path.join(model_dir, 'config.py')
|
||||
|
||||
cfg = mmcv.Config.fromfile(config)
|
||||
cfg.model.pretrained = None
|
||||
self.cls_model = build_classifier(cfg.model)
|
||||
|
||||
self.cfg = cfg
|
||||
self.ms_model_dir = model_dir
|
||||
|
||||
self.load_pretrained_checkpoint()
|
||||
|
||||
def forward(self, Inputs):
|
||||
|
||||
return self.cls_model(**Inputs)
|
||||
|
||||
def load_pretrained_checkpoint(self):
|
||||
import mmcv
|
||||
checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth')
|
||||
if os.path.exists(checkpoint_path):
|
||||
checkpoint = mmcv.runner.load_checkpoint(
|
||||
self.cls_model, checkpoint_path, map_location='cpu')
|
||||
if 'CLASSES' in checkpoint.get('meta', {}):
|
||||
self.cls_model.CLASSES = checkpoint['meta']['CLASSES']
|
||||
self.CLASSES = self.cls_model.CLASSES
|
||||
@@ -94,6 +94,12 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_gan_face-image-generation'),
|
||||
Tasks.image_super_resolution: (Pipelines.image_super_resolution,
|
||||
'damo/cv_rrdb_image-super-resolution'),
|
||||
Tasks.image_classification_imagenet:
|
||||
(Pipelines.general_image_classification,
|
||||
'damo/cv_vit-base_image-classification_ImageNet-labels'),
|
||||
Tasks.image_classification_dailylife:
|
||||
(Pipelines.daily_image_classification,
|
||||
'damo/cv_vit-base_image-classification_Dailylife-labels'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ if TYPE_CHECKING:
|
||||
from .action_recognition_pipeline import ActionRecognitionPipeline
|
||||
from .animal_recog_pipeline import AnimalRecogPipeline
|
||||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
|
||||
from .image_classification_pipeline import GeneralImageClassificationPipeline
|
||||
from .face_image_generation_pipeline import FaceImageGenerationPipeline
|
||||
from .image_cartoon_pipeline import ImageCartoonPipeline
|
||||
from .image_denoise_pipeline import ImageDenoisePipeline
|
||||
@@ -23,6 +24,8 @@ else:
|
||||
'action_recognition_pipeline': ['ActionRecognitionPipeline'],
|
||||
'animal_recog_pipeline': ['AnimalRecogPipeline'],
|
||||
'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'],
|
||||
'image_classification_pipeline':
|
||||
['GeneralImageClassificationPipeline'],
|
||||
'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'],
|
||||
'virtual_tryon_pipeline': ['VirtualTryonPipeline'],
|
||||
'image_colorization_pipeline': ['ImageColorizationPipeline'],
|
||||
|
||||
87
modelscope/pipelines/cv/image_classification_pipeline.py
Normal file
87
modelscope/pipelines/cv/image_classification_pipeline.py
Normal file
@@ -0,0 +1,87 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input
|
||||
from modelscope.preprocessors import load_image
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ..base import Pipeline
|
||||
from ..builder import PIPELINES
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_classification_imagenet,
|
||||
module_name=Pipelines.general_image_classification)
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_classification_dailylife,
|
||||
module_name=Pipelines.daily_image_classification)
|
||||
class GeneralImageClassificationPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` and `preprocessor` to create a kws pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
logger.info('load model done')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
from mmcls.datasets.pipelines import Compose
|
||||
from mmcv.parallel import collate, scatter
|
||||
if isinstance(input, str):
|
||||
img = np.array(load_image(input))
|
||||
elif isinstance(input, PIL.Image.Image):
|
||||
img = np.array(input.convert('RGB'))
|
||||
elif isinstance(input, np.ndarray):
|
||||
if len(input.shape) == 2:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
|
||||
img = input[:, :, ::-1] # in rgb order
|
||||
else:
|
||||
raise TypeError(f'input should be either str, PIL.Image,'
|
||||
f' np.array, but got {type(input)}')
|
||||
|
||||
mmcls_cfg = self.model.cfg
|
||||
# build the data pipeline
|
||||
if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile':
|
||||
mmcls_cfg.data.test.pipeline.pop(0)
|
||||
data = dict(img=img)
|
||||
test_pipeline = Compose(mmcls_cfg.data.test.pipeline)
|
||||
data = test_pipeline(data)
|
||||
data = collate([data], samples_per_gpu=1)
|
||||
if next(self.model.parameters()).is_cuda:
|
||||
# scatter to specified GPU
|
||||
data = scatter(data, [next(self.model.parameters()).device])[0]
|
||||
|
||||
return data
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
with torch.no_grad():
|
||||
input['return_loss'] = False
|
||||
scores = self.model(input)
|
||||
|
||||
return {'scores': scores}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
scores = inputs['scores']
|
||||
pred_score = np.max(scores, axis=1)[0]
|
||||
pred_label = np.argmax(scores, axis=1)[0]
|
||||
result = {'pred_label': pred_label, 'pred_score': float(pred_score)}
|
||||
result['pred_class'] = self.model.CLASSES[result['pred_label']]
|
||||
|
||||
outputs = {
|
||||
OutputKeys.SCORES: [result['pred_score']],
|
||||
OutputKeys.LABELS: [result['pred_class']]
|
||||
}
|
||||
return outputs
|
||||
@@ -34,6 +34,8 @@ class CVTasks(object):
|
||||
face_image_generation = 'face-image-generation'
|
||||
image_super_resolution = 'image-super-resolution'
|
||||
style_transfer = 'style-transfer'
|
||||
image_classification_imagenet = 'image-classification-imagenet'
|
||||
image_classification_dailylife = 'image-classification-dailylife'
|
||||
|
||||
|
||||
class NLPTasks(object):
|
||||
|
||||
42
tests/pipelines/test_general_image_classification.py
Normal file
42
tests/pipelines/test_general_image_classification.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import unittest
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class GeneralImageClassificationTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_ImageNet(self):
|
||||
general_image_classification = pipeline(
|
||||
Tasks.image_classification_imagenet,
|
||||
model='damo/cv_vit-base_image-classification_ImageNet-labels')
|
||||
result = general_image_classification('data/test/images/bird.JPEG')
|
||||
print(result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_Dailylife(self):
|
||||
general_image_classification = pipeline(
|
||||
Tasks.image_classification_dailylife,
|
||||
model='damo/cv_vit-base_image-classification_Dailylife-labels')
|
||||
result = general_image_classification('data/test/images/bird.JPEG')
|
||||
print(result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_ImageNet_default_task(self):
|
||||
general_image_classification = pipeline(
|
||||
Tasks.image_classification_imagenet)
|
||||
result = general_image_classification('data/test/images/bird.JPEG')
|
||||
print(result)
|
||||
|
||||
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
|
||||
def test_run_Dailylife_default_task(self):
|
||||
general_image_classification = pipeline(
|
||||
Tasks.image_classification_dailylife)
|
||||
result = general_image_classification('data/test/images/bird.JPEG')
|
||||
print(result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user