diff --git a/.gitattributes b/.gitattributes index b2724f28..0d4c368e 100644 --- a/.gitattributes +++ b/.gitattributes @@ -2,3 +2,4 @@ *.jpg filter=lfs diff=lfs merge=lfs -text *.mp4 filter=lfs diff=lfs merge=lfs -text *.wav filter=lfs diff=lfs merge=lfs -text +*.JPEG filter=lfs diff=lfs merge=lfs -text diff --git a/data/test/images/bird.JPEG b/data/test/images/bird.JPEG new file mode 100755 index 00000000..897eb3c8 --- /dev/null +++ b/data/test/images/bird.JPEG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:19fb781a44aec9349a8b73850e53b7eb9b0623d54ebd0cd8577c13bf463b5004 +size 74237 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 1e67c885..c6858794 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -10,6 +10,7 @@ class Models(object): Model name should only contain model info but not task info. """ # vision models + classification_model = 'ClassificationModel' nafnet = 'nafnet' csrnet = 'csrnet' cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin' @@ -66,6 +67,8 @@ class Pipelines(object): action_recognition = 'TAdaConv_action-recognition' animal_recognation = 'resnet101-animal_recog' cmdssl_video_embedding = 'cmdssl-r2p1d_video_embedding' + general_image_classification = 'vit-base_image-classification_ImageNet-labels' + daily_image_classification = 'vit-base_image-classification_Dailylife-labels' image_color_enhance = 'csrnet-image-color-enhance' virtual_tryon = 'virtual_tryon' image_colorization = 'unet-image-colorization' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index db09dc77..88177746 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from . import (action_recognition, animal_recognition, cartoon, - cmdssl_video_embedding, face_generation, image_color_enhance, - image_colorization, image_denoise, image_instance_segmentation, - super_resolution, virual_tryon) + cmdssl_video_embedding, face_generation, image_classification, + image_color_enhance, image_colorization, image_denoise, + image_instance_segmentation, super_resolution, virual_tryon) diff --git a/modelscope/models/cv/image_classification/__init__.py b/modelscope/models/cv/image_classification/__init__.py new file mode 100644 index 00000000..7afe44bb --- /dev/null +++ b/modelscope/models/cv/image_classification/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .mmcls_model import ClassificationModel + +else: + _import_structure = { + 'mmcls_model': ['ClassificationModel'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_classification/mmcls_model.py b/modelscope/models/cv/image_classification/mmcls_model.py new file mode 100644 index 00000000..371c9d41 --- /dev/null +++ b/modelscope/models/cv/image_classification/mmcls_model.py @@ -0,0 +1,46 @@ +import os + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import Tasks + + +@MODELS.register_module( + Tasks.image_classification_imagenet, + module_name=Models.classification_model) +@MODELS.register_module( + Tasks.image_classification_dailylife, + module_name=Models.classification_model) +class ClassificationModel(TorchModel): + + def __init__(self, model_dir: str): + import mmcv + from mmcls.models import build_classifier + + super().__init__(model_dir) + + config = os.path.join(model_dir, 'config.py') + + cfg = mmcv.Config.fromfile(config) + cfg.model.pretrained = None + self.cls_model = build_classifier(cfg.model) + + self.cfg = cfg + self.ms_model_dir = model_dir + + self.load_pretrained_checkpoint() + + def forward(self, Inputs): + + return self.cls_model(**Inputs) + + def load_pretrained_checkpoint(self): + import mmcv + checkpoint_path = os.path.join(self.ms_model_dir, 'checkpoints.pth') + if os.path.exists(checkpoint_path): + checkpoint = mmcv.runner.load_checkpoint( + self.cls_model, checkpoint_path, map_location='cpu') + if 'CLASSES' in checkpoint.get('meta', {}): + self.cls_model.CLASSES = checkpoint['meta']['CLASSES'] + self.CLASSES = self.cls_model.CLASSES diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 557af27d..eb5f0e6d 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -94,6 +94,12 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_gan_face-image-generation'), Tasks.image_super_resolution: (Pipelines.image_super_resolution, 'damo/cv_rrdb_image-super-resolution'), + Tasks.image_classification_imagenet: + (Pipelines.general_image_classification, + 'damo/cv_vit-base_image-classification_ImageNet-labels'), + Tasks.image_classification_dailylife: + (Pipelines.daily_image_classification, + 'damo/cv_vit-base_image-classification_Dailylife-labels'), } diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 4019522e..5d5f93c1 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from .action_recognition_pipeline import ActionRecognitionPipeline from .animal_recog_pipeline import AnimalRecogPipeline from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline + from .image_classification_pipeline import GeneralImageClassificationPipeline from .face_image_generation_pipeline import FaceImageGenerationPipeline from .image_cartoon_pipeline import ImageCartoonPipeline from .image_denoise_pipeline import ImageDenoisePipeline @@ -23,6 +24,8 @@ else: 'action_recognition_pipeline': ['ActionRecognitionPipeline'], 'animal_recog_pipeline': ['AnimalRecogPipeline'], 'cmdssl_video_embedding_pipleline': ['CMDSSLVideoEmbeddingPipeline'], + 'image_classification_pipeline': + ['GeneralImageClassificationPipeline'], 'image_color_enhance_pipeline': ['ImageColorEnhancePipeline'], 'virtual_tryon_pipeline': ['VirtualTryonPipeline'], 'image_colorization_pipeline': ['ImageColorizationPipeline'], diff --git a/modelscope/pipelines/cv/image_classification_pipeline.py b/modelscope/pipelines/cv/image_classification_pipeline.py new file mode 100644 index 00000000..169187fe --- /dev/null +++ b/modelscope/pipelines/cv/image_classification_pipeline.py @@ -0,0 +1,87 @@ +from typing import Any, Dict + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input +from modelscope.preprocessors import load_image +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger +from ..base import Pipeline +from ..builder import PIPELINES + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_classification_imagenet, + module_name=Pipelines.general_image_classification) +@PIPELINES.register_module( + Tasks.image_classification_dailylife, + module_name=Pipelines.daily_image_classification) +class GeneralImageClassificationPipeline(Pipeline): + + def __init__(self, model: str, **kwargs): + """ + use `model` and `preprocessor` to create a kws pipeline for prediction + Args: + model: model id on modelscope hub. + """ + super().__init__(model=model, **kwargs) + + logger.info('load model done') + + def preprocess(self, input: Input) -> Dict[str, Any]: + from mmcls.datasets.pipelines import Compose + from mmcv.parallel import collate, scatter + if isinstance(input, str): + img = np.array(load_image(input)) + elif isinstance(input, PIL.Image.Image): + img = np.array(input.convert('RGB')) + elif isinstance(input, np.ndarray): + if len(input.shape) == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + img = input[:, :, ::-1] # in rgb order + else: + raise TypeError(f'input should be either str, PIL.Image,' + f' np.array, but got {type(input)}') + + mmcls_cfg = self.model.cfg + # build the data pipeline + if mmcls_cfg.data.test.pipeline[0]['type'] == 'LoadImageFromFile': + mmcls_cfg.data.test.pipeline.pop(0) + data = dict(img=img) + test_pipeline = Compose(mmcls_cfg.data.test.pipeline) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(self.model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + + with torch.no_grad(): + input['return_loss'] = False + scores = self.model(input) + + return {'scores': scores} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + + scores = inputs['scores'] + pred_score = np.max(scores, axis=1)[0] + pred_label = np.argmax(scores, axis=1)[0] + result = {'pred_label': pred_label, 'pred_score': float(pred_score)} + result['pred_class'] = self.model.CLASSES[result['pred_label']] + + outputs = { + OutputKeys.SCORES: [result['pred_score']], + OutputKeys.LABELS: [result['pred_class']] + } + return outputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index b08977f9..fafb762f 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -34,6 +34,8 @@ class CVTasks(object): face_image_generation = 'face-image-generation' image_super_resolution = 'image-super-resolution' style_transfer = 'style-transfer' + image_classification_imagenet = 'image-classification-imagenet' + image_classification_dailylife = 'image-classification-dailylife' class NLPTasks(object): diff --git a/tests/pipelines/test_general_image_classification.py b/tests/pipelines/test_general_image_classification.py new file mode 100644 index 00000000..cf4ac3c0 --- /dev/null +++ b/tests/pipelines/test_general_image_classification.py @@ -0,0 +1,42 @@ +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + + +class GeneralImageClassificationTest(unittest.TestCase): + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_ImageNet(self): + general_image_classification = pipeline( + Tasks.image_classification_imagenet, + model='damo/cv_vit-base_image-classification_ImageNet-labels') + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_Dailylife(self): + general_image_classification = pipeline( + Tasks.image_classification_dailylife, + model='damo/cv_vit-base_image-classification_Dailylife-labels') + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_ImageNet_default_task(self): + general_image_classification = pipeline( + Tasks.image_classification_imagenet) + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + @unittest.skipUnless(test_level() >= 1, 'skip test in current test level') + def test_run_Dailylife_default_task(self): + general_image_classification = pipeline( + Tasks.image_classification_dailylife) + result = general_image_classification('data/test/images/bird.JPEG') + print(result) + + +if __name__ == '__main__': + unittest.main()