diff --git a/data/test/images/image_structured_model_probing_test_image.jpg b/data/test/images/image_structured_model_probing_test_image.jpg new file mode 100644 index 00000000..54f79fea Binary files /dev/null and b/data/test/images/image_structured_model_probing_test_image.jpg differ diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index f8a788f3..fa475ebb 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -80,6 +80,7 @@ class Models(object): image_casmvs_depth_estimation = 'image-casmvs-depth-estimation' vop_retrieval_model = 'vop-retrieval-model' ddcolor = 'ddcolor' + image_probing_model = 'image-probing-model' defrcn = 'defrcn' image_face_fusion = 'image-face-fusion' ddpm = 'ddpm' @@ -310,6 +311,7 @@ class Pipelines(object): video_panoptic_segmentation = 'video-panoptic-segmentation' vop_retrieval = 'vop-video-text-retrieval' ddcolor_image_colorization = 'ddcolor-image-colorization' + image_structured_model_probing = 'image-structured-model-probing' image_fewshot_detection = 'image-fewshot-detection' image_face_fusion = 'image-face-fusion' ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation' diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index 44a6896f..0f4f33c2 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -9,13 +9,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, image_denoise, image_inpainting, image_instance_segmentation, image_matching, image_mvs_depth_estimation, image_panoptic_segmentation, image_portrait_enhancement, - image_quality_assessment_mos, image_reid_person, - image_restoration, image_semantic_segmentation, - image_to_image_generation, image_to_image_translation, - language_guided_video_summarization, movie_scene_segmentation, - object_detection, panorama_depth_estimation, - pointcloud_sceneflow_estimation, product_retrieval_embedding, - realtime_object_detection, referring_video_object_segmentation, + image_probing_model, image_quality_assessment_mos, + image_reid_person, image_restoration, + image_semantic_segmentation, image_to_image_generation, + image_to_image_translation, language_guided_video_summarization, + movie_scene_segmentation, object_detection, + panorama_depth_estimation, pointcloud_sceneflow_estimation, + product_retrieval_embedding, realtime_object_detection, + referring_video_object_segmentation, robust_image_classification, salient_detection, shop_segmentation, super_resolution, video_frame_interpolation, video_object_segmentation, video_panoptic_segmentation, diff --git a/modelscope/models/cv/image_probing_model/__init__.py b/modelscope/models/cv/image_probing_model/__init__.py new file mode 100644 index 00000000..e97a1b77 --- /dev/null +++ b/modelscope/models/cv/image_probing_model/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. 
+
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+
+    from .model import StructuredProbingModel
+
+else:
+    _import_structure = {
+        'model': ['StructuredProbingModel'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/image_probing_model/backbone.py b/modelscope/models/cv/image_probing_model/backbone.py
new file mode 100644
index 00000000..8f3ed5b6
--- /dev/null
+++ b/modelscope/models/cv/image_probing_model/backbone.py
@@ -0,0 +1,308 @@
+# The implementation is adapted from OpenAI-CLIP,
+# made publicly available under the MIT License at https://github.com/openai/CLIP
+
+import math
+import sys
+from collections import OrderedDict
+from functools import reduce
+from operator import mul
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+from torchvision import models
+
+from .utils import convert_weights, load_pretrained
+
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1):
+        super().__init__()
+
+        # all conv layers have stride 1. an avgpool is performed
+        # after the second convolution when stride > 1
+        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+
+        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = None
+        self.stride = stride
+
+        if stride > 1 or inplanes != planes * Bottleneck.expansion:
+            # downsampling layer is prepended with an avgpool,
+            # and the subsequent convolution has stride 1
+            self.downsample = nn.Sequential(
+                OrderedDict([('-1', nn.AvgPool2d(stride)),
+                             ('0',
+                              nn.Conv2d(
+                                  inplanes,
+                                  planes * self.expansion,
+                                  1,
+                                  stride=1,
+                                  bias=False)),
+                             ('1', nn.BatchNorm2d(planes * self.expansion))]))
+
+    def forward(self, x: torch.Tensor):
+        identity = x
+
+        out = self.relu(self.bn1(self.conv1(x)))
+        out = self.relu(self.bn2(self.conv2(out)))
+        out = self.avgpool(out)
+        out = self.bn3(self.conv3(out))
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+        return out
+
+
+class AttentionPool2d(nn.Module):
+
+    def __init__(self,
+                 spacial_dim: int,
+                 embed_dim: int,
+                 num_heads: int,
+                 output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.reshape(x.shape[0], x.shape[1],
+                      x.shape[2] * x.shape[3]).permute(2, 0, 1)
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)
+        x, _ = F.multi_head_attention_forward(
+            query=x,
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+
in_proj_weight=None, + in_proj_bias=torch.cat( + [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), + bias_k=None, + bias_v=None, + add_zero_attn=False, + dropout_p=0, + out_proj_weight=self.c_proj.weight, + out_proj_bias=self.c_proj.bias, + use_separate_proj_weight=True, + training=self.training, + need_weights=False) + + return x[0] + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: torch.Tensor): + orig_type = x.dtype + ret = super().forward(x.type(torch.float32)) + return ret.type(orig_type) + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model))])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: torch.Tensor): + self.attn_mask = self.attn_mask.to( + dtype=x.dtype, + device=x.device) if self.attn_mask is not None else None + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor, idx): + features = {} + x_norm = self.ln_1(x) + features['layer_{}_pre_attn'.format(idx)] = x_norm.permute(1, 0, 2) + attn = self.attention(x_norm) + features['layer_{}_attn'.format(idx)] = attn.permute(1, 0, 2) + x = x + attn + mlp = self.mlp(self.ln_2(x)) + features['layer_{}_mlp'.format(idx)] = mlp.permute(1, 0, 2) + x = x + mlp + return x, features + + +class Transformer(nn.Module): + + def __init__(self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.ModuleList() + for i in range(layers): + block = ResidualAttentionBlock(width, heads, attn_mask) + self.resblocks.append(block) + + def forward(self, x: torch.Tensor): + features = {} + for idx, block in enumerate(self.resblocks): + x, block_feats = block(x, idx) + features.update(block_feats) + return x, features + + +class VisualTransformer(nn.Module): + + def __init__(self, input_resolution: int, patch_size: int, width: int, + layers: int, heads: int, output_dim: int): + super().__init__() + print(input_resolution, patch_size, width, layers, heads, output_dim) + self.input_resolution = input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + scale = width**-0.5 + self.class_embedding = nn.Parameter(scale * torch.randn(width)) + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) + + def forward(self, x: torch.Tensor, return_all=True): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + zeros = torch.zeros( + x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + # shape = [*, grid ** 2 + 1, width] + x = 
torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1) + x = x + self.positional_embedding.to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x, features = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.ln_post(x[:, 0, :]) + + if return_all: + features['pre_logits'] = x + return features + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CLIPNet(nn.Module): + + def __init__(self, arch_name, pretrained, **kwargs): + super(CLIPNet, self).__init__() + + if arch_name == 'CLIP_ViTB32': + self.clip = VisualTransformer( + input_resolution=224, + patch_size=32, + width=768, + layers=12, + heads=12, + output_dim=512) + + elif arch_name in ('CLIP_ViTB16', 'CLIP_ViTB16_FP16'): + self.clip = VisualTransformer( + input_resolution=224, + patch_size=16, + width=768, + layers=12, + heads=12, + output_dim=512) + + elif arch_name in ('CLIP_ViTL14', 'CLIP_ViTL14_FP16'): + self.clip = VisualTransformer( + input_resolution=224, + patch_size=14, + width=1024, + layers=24, + heads=16, + output_dim=768) + + else: + raise KeyError(f'Unsupported arch_name for CLIP, {arch_name}') + + def forward(self, input_data): + output = self.clip(input_data) + return output + + +def CLIP(arch_name='CLIP_RN50', + use_pretrain=False, + load_from='', + state_dict=None, + **kwargs): + model = CLIPNet(arch_name=arch_name, pretrained=None, **kwargs) + if use_pretrain: + if arch_name.endswith('FP16'): + convert_weights(model.clip) + load_pretrained(model.clip, state_dict, load_from) + return model + + +class ProbingModel(torch.nn.Module): + + def __init__(self, feat_size, num_classes): + super(ProbingModel, self).__init__() + self.linear = torch.nn.Linear(feat_size, num_classes) + + def forward(self, x): + return self.linear(x) diff --git a/modelscope/models/cv/image_probing_model/model.py b/modelscope/models/cv/image_probing_model/model.py new file mode 100644 index 00000000..e7636f40 --- /dev/null +++ b/modelscope/models/cv/image_probing_model/model.py @@ -0,0 +1,93 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import os +from typing import Any, Dict + +import json +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +from modelscope.outputs import OutputKeys +from modelscope.utils.constant import ModelFile, Tasks +from .backbone import CLIP, ProbingModel + + +@MODELS.register_module( + Tasks.image_classification, module_name=Models.image_probing_model) +class StructuredProbingModel(TorchModel): + """ + The implementation of 'Structured Model Probing: Empowering + Efficient Adaptation by Structured Regularization'. + """ + + def __init__(self, model_dir, *args, **kwargs): + """ + Initialize a probing model. 
+        Args:
+            model_dir: the local model directory containing the probing checkpoint
+        """
+        super(StructuredProbingModel, self).__init__()
+        model_dir = os.path.join(model_dir, 'food101-clip-vitl14-full.pt')
+        model_file = torch.load(model_dir)
+        self.feature_size = model_file['meta_info']['feature_size']
+        self.num_classes = model_file['meta_info']['num_classes']
+        self.backbone = CLIP(
+            'CLIP_ViTL14_FP16',
+            use_pretrain=True,
+            state_dict=model_file['backbone_model_state_dict'])
+        self.probing_model = ProbingModel(self.feature_size, self.num_classes)
+        self.probing_model.load_state_dict(
+            model_file['probing_model_state_dict'])
+
+    def forward(self, x):
+        """
+        Forward function of SMP (Structured Model Probing).
+        Args:
+            x: the input images (B, 3, H, W)
+        """
+
+        keys = []
+        for idx in range(0, 24):
+            keys.append('layer_{}_pre_attn'.format(idx))
+            keys.append('layer_{}_attn'.format(idx))
+            keys.append('layer_{}_mlp'.format(idx))
+        keys.append('pre_logits')
+        features = self.backbone(x.half())
+        features_agg = []
+        for i in keys:
+            aggregated_feature = self.aggregate_token(features[i], 1024)
+            features_agg.append(aggregated_feature)
+        features_agg = torch.cat(features_agg, dim=1)
+        outputs = self.probing_model(features_agg.float())
+        return outputs
+
+    def aggregate_token(self, output, target_size):
+        """
+        Aggregate features from tokens.
+        Args:
+            output: intermediate features from a ViT model, either a
+                (B, N, C) token sequence or a (B, C) vector
+            target_size: target aggregated feature size
+        """
+        # 3-D token sequences (B, N, C) are pooled down to target_size per
+        # sample; 2-D features such as 'pre_logits' are only normalized.
+        if len(output.shape) == 3:
+            _, n_token, channels = output.shape
+            if channels >= target_size:
+                pool_size = 0
+            else:
+                n_groups = target_size / channels
+                pool_size = int(n_token / n_groups)
+
+            if pool_size > 0:
+                output = torch.permute(output, (0, 2, 1))
+                output = torch.nn.AvgPool1d(
+                    kernel_size=pool_size, stride=pool_size)(
+                        output)
+                output = torch.flatten(output, start_dim=1)
+            else:
+                output = torch.mean(output, dim=1)
+        output = torch.nn.functional.normalize(output, dim=1)
+        return output
diff --git a/modelscope/models/cv/image_probing_model/utils.py b/modelscope/models/cv/image_probing_model/utils.py
new file mode 100644
index 00000000..c2b13ae5
--- /dev/null
+++ b/modelscope/models/cv/image_probing_model/utils.py
@@ -0,0 +1,148 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import re
+
+import torch
+import torch.nn as nn
+
+
+def load_pretrained(model: torch.nn.Module,
+                    state_dict,
+                    local_path: str,
+                    map_location='cpu',
+                    logger=None,
+                    sub_level=None):
+    return load_pretrained_dict(model, state_dict, logger, sub_level=sub_level)
+
+
+def load_pretrained_dict(model: torch.nn.Module,
+                         state_dict: dict,
+                         logger=None,
+                         sub_level=None):
+    """
+    Load parameters into a model, with the following steps:
+    1. Substitute key names via revise_keys, e.g. strip the 'module.' prefix added by DataParallel wrappers.
+    2. Unwrap the checkpoint by the key 'state_dict' or 'model_state' if present.
+    3. Take sub-level keys from the source, e.g. load the 'backbone' part of a classifier into a backbone model.
+    4. Automatically drop invalid parameters from the source.
+    5. Log or warn if unexpected or missing keys exist.
+
+    Args:
+        model (torch.nn.Module):
+        state_dict (dict): dict of parameters
+        logger (logging.Logger, None):
+        sub_level (str, optional): If not None, the sub_level prefix is removed from matching keys
+            to fit the actual model keys. This is used when loading sub-module parameters
+            into a sub-module model.
+ """ + revise_keys = [(r'^module\.', '')] + + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + if 'model_state' in state_dict: + state_dict = state_dict['model_state'] + + for p, r in revise_keys: + state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()} + + if sub_level: + sub_level = sub_level if sub_level.endswith('.') else (sub_level + '.') + sub_level_len = len(sub_level) + state_dict = { + key[sub_level_len:]: value + for key, value in state_dict.items() if key.startswith(sub_level) + } + + state_dict = _auto_drop_invalid(model, state_dict, logger=logger) + + load_status = model.load_state_dict(state_dict, strict=False) + unexpected_keys = load_status.unexpected_keys + missing_keys = load_status.missing_keys + err_msgs = [] + if unexpected_keys: + err_msgs.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msgs.append('missing key in source ' + f'state_dict: {", ".join(missing_keys)}\n') + err_msgs = '\n'.join(err_msgs) + + if len(err_msgs) > 0: + if logger: + logger.warning(err_msgs) + else: + import warnings + warnings.warn(err_msgs) + + +def convert_weights(model: nn.Module): + """ + Convert applicable model parameters to fp16. + """ + + def _convert_weights_to_fp16(layer): + if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)): + layer.weight.data = layer.weight.data.half() + if layer.bias is not None: + layer.bias.data = layer.bias.data.half() + + if isinstance(layer, nn.MultiheadAttention): + for attr in [ + *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']], + 'in_proj_bias', 'bias_k', 'bias_v' + ]: + tensor = getattr(layer, attr) + if tensor is not None: + tensor.data = tensor.data.half() + + for name in ['text_projection', 'proj']: + if hasattr(layer, name): + attr = getattr(layer, name) + if attr is not None: + attr.data = attr.data.half() + + for name in ['prompt_embeddings']: + if hasattr(layer, name): + attr = getattr(layer, name) + if attr is not None: + attr.data = attr.data.half() + + model.apply(_convert_weights_to_fp16) + + +def _auto_drop_invalid(model: torch.nn.Module, state_dict: dict, logger=None): + """ + Strip unmatched parameters in state_dict, e.g. shape not matched, type not matched. + + Args: + model (torch.nn.Module): + state_dict (dict): + logger (logging.Logger, None): + + Returns: + A new state dict. + """ + ret_dict = state_dict.copy() + invalid_msgs = [] + for key, value in model.state_dict().items(): + if key in state_dict: + # Check shape + new_value = state_dict[key] + if value.shape != new_value.shape: + invalid_msgs.append( + f'{key}: invalid shape, dst {value.shape} vs. src {new_value.shape}' + ) + ret_dict.pop(key) + elif value.dtype != new_value.dtype: + invalid_msgs.append( + f'{key}: invalid dtype, dst {value.dtype} vs. 
src {new_value.dtype}'
+                )
+                ret_dict.pop(key)
+    if len(invalid_msgs) > 0:
+        warning_msg = 'ignore keys from source: \n' + '\n'.join(invalid_msgs)
+        if logger:
+            logger.warning(warning_msg)
+        else:
+            import warnings
+            warnings.warn(warning_msg)
+    return ret_dict
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index d5839eab..c37a5630 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
     from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline
     from .panorama_depth_estimation_pipeline import PanoramaDepthEstimationPipeline
     from .ddcolor_image_colorization_pipeline import DDColorImageColorizationPipeline
+    from .image_structured_model_probing_pipeline import ImageStructuredModelProbingPipeline
     from .video_colorization_pipeline import VideoColorizationPipeline
     from .image_defrcn_fewshot_pipeline import ImageDefrcnDetectionPipeline
     from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline
@@ -207,6 +208,9 @@ else:
         'ddcolor_image_colorization_pipeline': [
             'DDColorImageColorizationPipeline'
         ],
+        'image_structured_model_probing_pipeline': [
+            'ImageStructuredModelProbingPipeline'
+        ],
         'video_colorization_pipeline': ['VideoColorizationPipeline'],
         'image_defrcn_fewshot_pipeline': ['ImageDefrcnDetectionPipeline'],
         'image_quality_assessment_mos_pipeline': [
diff --git a/modelscope/pipelines/cv/image_structured_model_probing_pipeline.py b/modelscope/pipelines/cv/image_structured_model_probing_pipeline.py
new file mode 100644
index 00000000..bc2561e2
--- /dev/null
+++ b/modelscope/pipelines/cv/image_structured_model_probing_pipeline.py
@@ -0,0 +1,79 @@
+# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
+
+import math
+import os
+import os.path as osp
+from typing import Any, Dict
+
+import numpy as np
+import torch
+import torchvision.transforms as transforms
+from mmcv.parallel import collate, scatter
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.config import Config
+from modelscope.utils.constant import ModelFile, Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_classification,
+    module_name=Pipelines.image_structured_model_probing)
+class ImageStructuredModelProbingPipeline(Pipeline):
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create an image structured-model-probing pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+ Example: + >>> from modelscope.pipelines import pipeline + >>> recognition_pipeline = pipeline(self.task, self.model_id) + >>> file_name = 'data/test/images/\ + image_structured_model_probing_test_image.jpg' + >>> result = recognition_pipeline(file_name) + >>> print(f'recognition output: {result}.') + """ + super().__init__(model=model, **kwargs) + self.model.eval() + model_dir = os.path.join(model, 'food101-clip-vitl14-full.pt') + model_file = torch.load(model_dir) + self.label_map = model_file['meta_info']['label_map'] + logger.info('load model done') + + self.transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.48145466, 0.4578275, 0.40821073], + std=[0.26862954, 0.26130258, 0.27577711]) + ]) + + def preprocess(self, input: Input) -> Dict[str, Any]: + + img = LoadImage.convert_to_img(input) + + data = self.transform(img) + data = collate([data], samples_per_gpu=1) + if next(self.model.parameters()).is_cuda: + data = scatter(data, [next(self.model.parameters()).device])[0] + + return data + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + with torch.no_grad(): + results = self.model(input) + return results + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + scores = torch.softmax(inputs, dim=1).cpu() + labels = torch.argmax(scores, dim=1).cpu().tolist() + label_names = [self.label_map[label] for label in labels] + + return {OutputKeys.LABELS: label_names, OutputKeys.SCORES: scores} diff --git a/tests/pipelines/test_image_structured_model_probing.py b/tests/pipelines/test_image_structured_model_probing.py new file mode 100644 index 00000000..563e131c --- /dev/null +++ b/tests/pipelines/test_image_structured_model_probing.py @@ -0,0 +1,29 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import unittest + +from modelscope.pipelines import pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class ImageStructuredModelProbingTest(unittest.TestCase, + DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.image_classification + self.model_id = 'damo/structured_model_probing' + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + + recognition_pipeline = pipeline(self.task, self.model_id) + file_name = 'data/test/images/image_structured_model_probing_test_image.jpg' + result = recognition_pipeline(file_name) + + print(f'recognition output: {result}.') + + +if __name__ == '__main__': + unittest.main()
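
Reviewer note (not part of the patch): a minimal sketch of how the new pipeline can be exercised locally, mirroring the unit test above. The model id 'damo/structured_model_probing', the test image path, and the output keys (OutputKeys.LABELS / OutputKeys.SCORES from the pipeline's postprocess) are all taken from this diff; nothing else is assumed.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Build the pipeline from the hub model id used in the unit test.
probing_pipeline = pipeline(
    Tasks.image_classification, model='damo/structured_model_probing')

# Run it on the test image added by this change.
result = probing_pipeline(
    'data/test/images/image_structured_model_probing_test_image.jpg')

# postprocess() maps the argmax class through the checkpoint's label_map and
# also returns the full softmax score tensor of shape (1, num_classes).
print(result[OutputKeys.LABELS])
print(result[OutputKeys.SCORES].shape)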