add structured model probing pipeline for image classification

Add support for the structured model probing pipeline.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11376544
wuzhifan.wzf
2023-02-08 08:29:56 +00:00
committed by wenmeng.zwm
parent 0967ece5a0
commit 7a65cf64e9
10 changed files with 695 additions and 7 deletions
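
For reviewers, a minimal end-to-end sketch of what this commit adds; the task constant, model id, and test image path are taken from the new unit test at the bottom of this diff, and the output keys from the pipeline's postprocess:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Model id and image path mirror the unit test added in this commit.
classifier = pipeline(Tasks.image_classification, 'damo/structured_model_probing')
result = classifier('data/test/images/image_structured_model_probing_test_image.jpg')
# postprocess returns a dict of predicted label names and softmax scores.
print(result[OutputKeys.LABELS], result[OutputKeys.SCORES].shape)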

Binary file not shown (new test image, 28 KiB; the unit test below references it as data/test/images/image_structured_model_probing_test_image.jpg).


@@ -80,6 +80,7 @@ class Models(object):
image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
vop_retrieval_model = 'vop-retrieval-model'
ddcolor = 'ddcolor'
image_probing_model = 'image-probing-model'
defrcn = 'defrcn'
image_face_fusion = 'image-face-fusion'
ddpm = 'ddpm'
@@ -310,6 +311,7 @@ class Pipelines(object):
video_panoptic_segmentation = 'video-panoptic-segmentation'
vop_retrieval = 'vop-video-text-retrieval'
ddcolor_image_colorization = 'ddcolor-image-colorization'
image_structured_model_probing = 'image-structured-model-probing'
image_fewshot_detection = 'image-fewshot-detection'
image_face_fusion = 'image-face-fusion'
ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation'


@@ -9,13 +9,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
image_denoise, image_inpainting, image_instance_segmentation,
image_matching, image_mvs_depth_estimation,
image_panoptic_segmentation, image_portrait_enhancement,
image_probing_model, image_quality_assessment_mos,
image_reid_person, image_restoration,
image_semantic_segmentation, image_to_image_generation,
image_to_image_translation, language_guided_video_summarization,
movie_scene_segmentation, object_detection,
panorama_depth_estimation, pointcloud_sceneflow_estimation,
product_retrieval_embedding, realtime_object_detection,
referring_video_object_segmentation,
robust_image_classification, salient_detection,
shop_segmentation, super_resolution, video_frame_interpolation,
video_object_segmentation, video_panoptic_segmentation,


@@ -0,0 +1,24 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .model import StructuredProbingModel
else:
_import_structure = {
'model': ['StructuredProbingModel'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)
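
For context: replacing the module object in sys.modules with a LazyImportModule defers loading model.py (and its torch dependencies) until the symbol is first accessed. A sketch of the consumer-side effect, with the import path inferred from the models/cv/__init__.py change above:

# Import path inferred from this diff's package layout; resolved lazily on first access.
from modelscope.models.cv.image_probing_model import StructuredProbingModel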


@@ -0,0 +1,308 @@
# The implementation is adapted from OpenAI CLIP,
# made publicly available under the MIT License at https://github.com/openai/CLIP
import math
import sys
from collections import OrderedDict
from functools import reduce
from operator import mul
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models
from .utils import convert_weights, load_pretrained
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed
# after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool,
# and the subsequent convolution has stride 1
self.downsample = nn.Sequential(
OrderedDict([('-1', nn.AvgPool2d(stride)),
('0',
nn.Conv2d(
inplanes,
planes * self.expansion,
1,
stride=1,
bias=False)),
('1', nn.BatchNorm2d(planes * self.expansion))]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self,
spacial_dim: int,
embed_dim: int,
num_heads: int,
output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(
torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1],
x.shape[2] * x.shape[3]).permute(2, 0, 1)
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
x = x + self.positional_embedding[:, None, :].to(x.dtype)
x, _ = F.multi_head_attention_forward(
query=x,
key=x,
value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat(
[self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False)
return x[0]
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor, idx):
features = {}
x_norm = self.ln_1(x)
features['layer_{}_pre_attn'.format(idx)] = x_norm.permute(1, 0, 2)
attn = self.attention(x_norm)
features['layer_{}_attn'.format(idx)] = attn.permute(1, 0, 2)
x = x + attn
mlp = self.mlp(self.ln_2(x))
features['layer_{}_mlp'.format(idx)] = mlp.permute(1, 0, 2)
x = x + mlp
return x, features
class Transformer(nn.Module):
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.ModuleList()
for i in range(layers):
block = ResidualAttentionBlock(width, heads, attn_mask)
self.resblocks.append(block)
def forward(self, x: torch.Tensor):
features = {}
for idx, block in enumerate(self.resblocks):
x, block_feats = block(x, idx)
features.update(block_feats)
return x, features
class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor, return_all=True):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
zeros = torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
# shape = [*, grid ** 2 + 1, width]
x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x, features = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if return_all:
features['pre_logits'] = x
return features
if self.proj is not None:
x = x @ self.proj
return x
class CLIPNet(nn.Module):
def __init__(self, arch_name, pretrained, **kwargs):
super(CLIPNet, self).__init__()
if arch_name == 'CLIP_ViTB32':
self.clip = VisualTransformer(
input_resolution=224,
patch_size=32,
width=768,
layers=12,
heads=12,
output_dim=512)
elif arch_name in ('CLIP_ViTB16', 'CLIP_ViTB16_FP16'):
self.clip = VisualTransformer(
input_resolution=224,
patch_size=16,
width=768,
layers=12,
heads=12,
output_dim=512)
elif arch_name in ('CLIP_ViTL14', 'CLIP_ViTL14_FP16'):
self.clip = VisualTransformer(
input_resolution=224,
patch_size=14,
width=1024,
layers=24,
heads=16,
output_dim=768)
else:
raise KeyError(f'Unsupported arch_name for CLIP, {arch_name}')
def forward(self, input_data):
output = self.clip(input_data)
return output
def CLIP(arch_name='CLIP_RN50',
use_pretrain=False,
load_from='',
state_dict=None,
**kwargs):
model = CLIPNet(arch_name=arch_name, pretrained=None, **kwargs)
if use_pretrain:
if arch_name.endswith('FP16'):
convert_weights(model.clip)
load_pretrained(model.clip, state_dict, load_from)
return model
class ProbingModel(torch.nn.Module):
def __init__(self, feat_size, num_classes):
super(ProbingModel, self).__init__()
self.linear = torch.nn.Linear(feat_size, num_classes)
def forward(self, x):
return self.linear(x)
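
Because the probing head consumes the intermediate-feature dictionary rather than the final CLIP embedding, a shape sketch of what the backbone returns may help review. Shapes assume the CLIP_ViTL14 configuration above (224px input, patch size 14, width 1024, 24 layers); the random input and untrained weights are illustrative only:

import torch

model = CLIPNet('CLIP_ViTL14', pretrained=None)  # ViT-L/14 visual tower
x = torch.randn(2, 3, 224, 224)  # dummy batch, for shape checking only
features = model(x)  # VisualTransformer.forward with return_all=True
# 24 layers x {pre_attn, attn, mlp} = 72 token tensors of shape (2, 257, 1024)
# (257 = (224 // 14) ** 2 + 1 class token), plus 'pre_logits' of shape (2, 1024).
assert len(features) == 73
assert features['layer_0_attn'].shape == (2, 257, 1024)
assert features['pre_logits'].shape == (2, 1024)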


@@ -0,0 +1,93 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
from typing import Any, Dict
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import CLIP, ProbingModel
@MODELS.register_module(
Tasks.image_classification, module_name=Models.image_probing_model)
class StructuredProbingModel(TorchModel):
"""
The implementation of 'Structured Model Probing: Empowering
Efficient Adaptation by Structured Regularization'.
"""
def __init__(self, model_dir, *args, **kwargs):
"""
Initialize a probing model.
Args:
model_dir: model id or path
"""
super(StructuredProbingModel, self).__init__()
model_path = os.path.join(model_dir, 'food101-clip-vitl14-full.pt')
model_file = torch.load(model_path)
self.feature_size = model_file['meta_info']['feature_size']
self.num_classes = model_file['meta_info']['num_classes']
self.backbone = CLIP(
'CLIP_ViTL14_FP16',
use_pretrain=True,
state_dict=model_file['backbone_model_state_dict'])
self.probing_model = ProbingModel(self.feature_size, self.num_classes)
self.probing_model.load_state_dict(
model_file['probing_model_state_dict'])
def forward(self, x):
"""
Forward Function of SMP.
Args:
x: the input images (B, 3, H, W)
"""
keys = []
for idx in range(0, 24):
keys.append('layer_{}_pre_attn'.format(idx))
keys.append('layer_{}_attn'.format(idx))
keys.append('layer_{}_mlp'.format(idx))
keys.append('pre_logits')
features = self.backbone(x.half())
features_agg = []
for i in keys:
aggregated_feature = self.aggregate_token(features[i], 1024)
features_agg.append(aggregated_feature)
features_agg = torch.cat(features_agg, dim=1)
outputs = self.probing_model(features_agg.float())
return outputs
def aggregate_token(self, output, target_size):
"""
Aggregating features from tokens.
Args:
output: intermediate token features from a ViT model
target_size: target aggregated feature size
"""
if len(output.shape) == 3:
_, n_token, channels = output.shape
if channels >= target_size:
pool_size = 0
else:
n_groups = target_size / channels
pool_size = int(n_token / n_groups)
if pool_size > 0:
output = torch.permute(output, (0, 2, 1))
output = torch.nn.AvgPool1d(
kernel_size=pool_size, stride=pool_size)(
output)
output = torch.flatten(output, start_dim=1)
else:
output = torch.mean(output, dim=1)
output = torch.nn.functional.normalize(output, dim=1)
return output
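
The dimensions work out as follows for the ViT-L/14 backbone: forward collects 24 x 3 + 1 = 73 feature keys, and aggregate_token reduces each to a 1024-dim vector (channels == 1024 >= target_size, so pooling degenerates to a mean over tokens followed by L2 normalization). A sketch of the resulting arithmetic; 74752 is derived from the code above, and the checkpoint's meta_info feature_size is assumed to match:

n_keys = 24 * 3 + 1   # pre_attn/attn/mlp per layer, plus 'pre_logits'
feat_dim = 1024       # ViT-L/14 width; aggregate_token keeps it at 1024
assert n_keys * feat_dim == 74752  # expected self.feature_size for this checkpoint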


@@ -0,0 +1,148 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import re
import torch
import torch.nn as nn
def load_pretrained(model: torch.nn.Module,
state_dict,
local_path: str,
map_location='cpu',
logger=None,
sub_level=None):
return load_pretrained_dict(model, state_dict, logger, sub_level=sub_level)
def load_pretrained_dict(model: torch.nn.Module,
state_dict: dict,
logger=None,
sub_level=None):
"""
Load parameters into a model:
1. Rewrite key names via revise_keys, e.g. strip the 'module.' prefix added by DataParallel/DistributedDataParallel.
2. Unwrap checkpoints nested under a 'state_dict' or 'model_state' key.
3. Take sub-level keys from the source, e.g. load the 'backbone' part of a classifier checkpoint into a backbone model.
4. Automatically drop invalid (shape- or dtype-mismatched) parameters from the source.
5. Log or warn about unexpected or missing keys.
Args:
model (torch.nn.Module):
state_dict (dict): dict of parameters
logger (logging.Logger, None):
sub_level (str, optional): If not None, the sub_level prefix is stripped from matching parameter keys
    so they fit the actual model keys. Useful when loading sub-module parameters
    into a sub-module model.
"""
revise_keys = [(r'^module\.', '')]
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
if 'model_state' in state_dict:
state_dict = state_dict['model_state']
for p, r in revise_keys:
state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}
if sub_level:
sub_level = sub_level if sub_level.endswith('.') else (sub_level + '.')
sub_level_len = len(sub_level)
state_dict = {
key[sub_level_len:]: value
for key, value in state_dict.items() if key.startswith(sub_level)
}
state_dict = _auto_drop_invalid(model, state_dict, logger=logger)
load_status = model.load_state_dict(state_dict, strict=False)
unexpected_keys = load_status.unexpected_keys
missing_keys = load_status.missing_keys
err_msgs = []
if unexpected_keys:
err_msgs.append('unexpected key in source '
f'state_dict: {", ".join(unexpected_keys)}\n')
if missing_keys:
err_msgs.append('missing key in source '
f'state_dict: {", ".join(missing_keys)}\n')
err_msgs = '\n'.join(err_msgs)
if len(err_msgs) > 0:
if logger:
logger.warning(err_msgs)
else:
import warnings
warnings.warn(err_msgs)
def convert_weights(model: nn.Module):
"""
Convert applicable model parameters to fp16.
"""
def _convert_weights_to_fp16(layer):
if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
layer.weight.data = layer.weight.data.half()
if layer.bias is not None:
layer.bias.data = layer.bias.data.half()
if isinstance(layer, nn.MultiheadAttention):
for attr in [
*[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
'in_proj_bias', 'bias_k', 'bias_v'
]:
tensor = getattr(layer, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ['text_projection', 'proj']:
if hasattr(layer, name):
attr = getattr(layer, name)
if attr is not None:
attr.data = attr.data.half()
for name in ['prompt_embeddings']:
if hasattr(layer, name):
attr = getattr(layer, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def _auto_drop_invalid(model: torch.nn.Module, state_dict: dict, logger=None):
"""
Strip mismatched parameters from state_dict, e.g. those whose shape or dtype does not match the model.
Args:
model (torch.nn.Module):
state_dict (dict):
logger (logging.Logger, None):
Returns:
A new state dict.
"""
ret_dict = state_dict.copy()
invalid_msgs = []
for key, value in model.state_dict().items():
if key in state_dict:
# Check shape
new_value = state_dict[key]
if value.shape != new_value.shape:
invalid_msgs.append(
f'{key}: invalid shape, dst {value.shape} vs. src {new_value.shape}'
)
ret_dict.pop(key)
elif value.dtype != new_value.dtype:
invalid_msgs.append(
f'{key}: invalid dtype, dst {value.dtype} vs. src {new_value.dtype}'
)
ret_dict.pop(key)
if len(invalid_msgs) > 0:
warning_msg = 'ignore keys from source: \n' + '\n'.join(invalid_msgs)
if logger:
logger.warning(warning_msg)
else:
import warnings
warnings.warn(warning_msg)
return ret_dict
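
A usage sketch of the loaders above, with a hypothetical checkpoint and module (the nn.Sequential model, the nested 'state_dict', and the 'backbone' prefix are illustrative, not part of this commit):

import torch
import torch.nn as nn

# Hypothetical backbone; only load_pretrained_dict/convert_weights come from this file.
backbone = nn.Sequential(nn.Conv2d(3, 16, 3))
ckpt = {'state_dict': {'module.backbone.0.weight': torch.randn(16, 3, 3, 3)}}

# The checkpoint is unwrapped from 'state_dict', the 'module.' prefix is stripped,
# sub_level removes 'backbone.', then mismatched keys are dropped and missing
# keys (here '0.bias') trigger a warning.
load_pretrained_dict(backbone, ckpt, logger=None, sub_level='backbone')

# Convert eligible parameters (Conv/Linear/attention weights) to fp16 in place.
convert_weights(backbone)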


@@ -86,6 +86,7 @@ if TYPE_CHECKING:
from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline
from .panorama_depth_estimation_pipeline import PanoramaDepthEstimationPipeline
from .ddcolor_image_colorization_pipeline import DDColorImageColorizationPipeline
from .image_structured_model_probing_pipeline import ImageStructuredModelProbingPipeline
from .video_colorization_pipeline import VideoColorizationPipeline
from .image_defrcn_fewshot_pipeline import ImageDefrcnDetectionPipeline
from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline
@@ -207,6 +208,9 @@ else:
'ddcolor_image_colorization_pipeline': [
    'DDColorImageColorizationPipeline'
],
'image_structured_model_probing_pipeline': [
    'ImageStructuredModelProbingPipeline'
],
'video_colorization_pipeline': ['VideoColorizationPipeline'],
'image_defrcn_fewshot_pipeline': ['ImageDefrcnDetectionPipeline'],
'image_quality_assessment_mos_pipeline': [


@@ -0,0 +1,79 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
import os
import os.path as osp
from typing import Any, Dict
import numpy as np
import torch
import torchvision.transforms as transforms
from mmcv.parallel import collate, scatter
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.image_classification,
module_name=Pipelines.image_structured_model_probing)
class ImageStructuredModelProbingPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""
Use `model` to create a structured model probing pipeline for image classification.
Args:
model: model id on modelscope hub.
Example:
>>> from modelscope.pipelines import pipeline
>>> from modelscope.utils.constant import Tasks
>>> recognition_pipeline = pipeline(
>>>     Tasks.image_classification, 'damo/structured_model_probing')
>>> file_name = 'data/test/images/' \
...     'image_structured_model_probing_test_image.jpg'
>>> result = recognition_pipeline(file_name)
>>> print(f'recognition output: {result}.')
"""
super().__init__(model=model, **kwargs)
self.model.eval()
model_path = os.path.join(model, 'food101-clip-vitl14-full.pt')
model_file = torch.load(model_path)
self.label_map = model_file['meta_info']['label_map']
logger.info('load model done')
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.48145466, 0.4578275, 0.40821073],
std=[0.26862954, 0.26130258, 0.27577711])
])
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_img(input)
data = self.transform(img)
data = collate([data], samples_per_gpu=1)
if next(self.model.parameters()).is_cuda:
data = scatter(data, [next(self.model.parameters()).device])[0]
return data
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
with torch.no_grad():
results = self.model(input)
return results
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
scores = torch.softmax(inputs, dim=1).cpu()
labels = torch.argmax(scores, dim=1).cpu().tolist()
label_names = [self.label_map[label] for label in labels]
return {OutputKeys.LABELS: label_names, OutputKeys.SCORES: scores}


@@ -0,0 +1,29 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class ImageStructuredModelProbingTest(unittest.TestCase,
DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.image_classification
self.model_id = 'damo/structured_model_probing'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
recognition_pipeline = pipeline(self.task, self.model_id)
file_name = 'data/test/images/image_structured_model_probing_test_image.jpg'
result = recognition_pipeline(file_name)
print(f'recognition output: {result}.')
if __name__ == '__main__':
unittest.main()