add structured model probing pipeline for image classification
Add support for the structured model probing pipeline. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11376544
Committed by wenmeng.zwm. Parent: 0967ece5a0. Commit: 7a65cf64e9.
BIN  data/test/images/image_structured_model_probing_test_image.jpg (new binary file, 28 KiB; content not shown)
modelscope/metainfo.py

@@ -80,6 +80,7 @@ class Models(object):
     image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
     vop_retrieval_model = 'vop-retrieval-model'
     ddcolor = 'ddcolor'
+    image_probing_model = 'image-probing-model'
     defrcn = 'defrcn'
     image_face_fusion = 'image-face-fusion'
     ddpm = 'ddpm'

@@ -310,6 +311,7 @@ class Pipelines(object):
     video_panoptic_segmentation = 'video-panoptic-segmentation'
     vop_retrieval = 'vop-video-text-retrieval'
     ddcolor_image_colorization = 'ddcolor-image-colorization'
+    image_structured_model_probing = 'image-structured-model-probing'
     image_fewshot_detection = 'image-fewshot-detection'
     image_face_fusion = 'image-face-fusion'
     ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation'
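The two new constants are plain registry strings; the register_module decorators in the model and pipeline files added below refer to them. A minimal sanity check, assuming a ModelScope checkout with this commit applied:

from modelscope.metainfo import Models, Pipelines

# the strings added in the hunks above
assert Models.image_probing_model == 'image-probing-model'
assert Pipelines.image_structured_model_probing == 'image-structured-model-probing'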
modelscope/models/cv/__init__.py

@@ -9,13 +9,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_denoise, image_inpainting, image_instance_segmentation,
                image_matching, image_mvs_depth_estimation,
                image_panoptic_segmentation, image_portrait_enhancement,
-               image_quality_assessment_mos, image_reid_person,
-               image_restoration, image_semantic_segmentation,
-               image_to_image_generation, image_to_image_translation,
-               language_guided_video_summarization, movie_scene_segmentation,
-               object_detection, panorama_depth_estimation,
-               pointcloud_sceneflow_estimation, product_retrieval_embedding,
-               realtime_object_detection, referring_video_object_segmentation,
+               image_probing_model, image_quality_assessment_mos,
+               image_reid_person, image_restoration,
+               image_semantic_segmentation, image_to_image_generation,
+               image_to_image_translation, language_guided_video_summarization,
+               movie_scene_segmentation, object_detection,
+               panorama_depth_estimation, pointcloud_sceneflow_estimation,
+               product_retrieval_embedding, realtime_object_detection,
+               referring_video_object_segmentation,
                robust_image_classification, salient_detection,
                shop_segmentation, super_resolution, video_frame_interpolation,
                video_object_segmentation, video_panoptic_segmentation,
modelscope/models/cv/image_probing_model/__init__.py (new file, 24 lines)

@@ -0,0 +1,24 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .model import StructuredProbingModel

else:
    _import_structure = {
        'model': ['StructuredProbingModel'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
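The LazyImportModule shim keeps importing the package cheap: the heavy torch/CLIP code in model.py is only pulled in when the symbol is first touched. A short sketch of that behaviour, assuming this commit is installed:

from modelscope.models.cv import image_probing_model

# attribute access resolves the 'model' entry of _import_structure on demand
model_cls = image_probing_model.StructuredProbingModel
print(model_cls.__name__)  # StructuredProbingModel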
modelscope/models/cv/image_probing_model/backbone.py (new file, 308 lines)

@@ -0,0 +1,308 @@
# The implementation is adopted from OpenAI-CLIP,
# made publicly available under the MIT License at https://github.com/openai/CLIP

import math
import sys
from collections import OrderedDict
from functools import reduce
from operator import mul

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models

from .utils import convert_weights, load_pretrained


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed
        # after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool,
            # and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict([('-1', nn.AvgPool2d(stride)),
                             ('0',
                              nn.Conv2d(
                                  inplanes,
                                  planes * self.expansion,
                                  1,
                                  stride=1,
                                  bias=False)),
                             ('1', nn.BatchNorm2d(planes * self.expansion))]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1],
                      x.shape[2] * x.shape[3]).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :].to(x.dtype)
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)

        return x[0]


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor, idx):
        features = {}
        x_norm = self.ln_1(x)
        features['layer_{}_pre_attn'.format(idx)] = x_norm.permute(1, 0, 2)
        attn = self.attention(x_norm)
        features['layer_{}_attn'.format(idx)] = attn.permute(1, 0, 2)
        x = x + attn
        mlp = self.mlp(self.ln_2(x))
        features['layer_{}_mlp'.format(idx)] = mlp.permute(1, 0, 2)
        x = x + mlp
        return x, features


class Transformer(nn.Module):

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        for i in range(layers):
            block = ResidualAttentionBlock(width, heads, attn_mask)
            self.resblocks.append(block)

    def forward(self, x: torch.Tensor):
        features = {}
        for idx, block in enumerate(self.resblocks):
            x, block_feats = block(x, idx)
            features.update(block_feats)
        return x, features


class VisualTransformer(nn.Module):

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        print(input_resolution, patch_size, width, layers, heads, output_dim)
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, return_all=True):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        zeros = torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        # shape = [*, grid ** 2 + 1, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, features = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if return_all:
            features['pre_logits'] = x
            return features

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIPNet(nn.Module):

    def __init__(self, arch_name, pretrained, **kwargs):
        super(CLIPNet, self).__init__()

        if arch_name == 'CLIP_ViTB32':
            self.clip = VisualTransformer(
                input_resolution=224,
                patch_size=32,
                width=768,
                layers=12,
                heads=12,
                output_dim=512)

        elif arch_name in ('CLIP_ViTB16', 'CLIP_ViTB16_FP16'):
            self.clip = VisualTransformer(
                input_resolution=224,
                patch_size=16,
                width=768,
                layers=12,
                heads=12,
                output_dim=512)

        elif arch_name in ('CLIP_ViTL14', 'CLIP_ViTL14_FP16'):
            self.clip = VisualTransformer(
                input_resolution=224,
                patch_size=14,
                width=1024,
                layers=24,
                heads=16,
                output_dim=768)

        else:
            raise KeyError(f'Unsupported arch_name for CLIP, {arch_name}')

    def forward(self, input_data):
        output = self.clip(input_data)
        return output


def CLIP(arch_name='CLIP_RN50',
         use_pretrain=False,
         load_from='',
         state_dict=None,
         **kwargs):
    model = CLIPNet(arch_name=arch_name, pretrained=None, **kwargs)
    if use_pretrain:
        if arch_name.endswith('FP16'):
            convert_weights(model.clip)
        load_pretrained(model.clip, state_dict, load_from)
    return model


class ProbingModel(torch.nn.Module):

    def __init__(self, feat_size, num_classes):
        super(ProbingModel, self).__init__()
        self.linear = torch.nn.Linear(feat_size, num_classes)

    def forward(self, x):
        return self.linear(x)
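With return_all=True (the default), the visual transformer returns a dictionary of intermediate features rather than a single embedding: three taps per residual block (layer_{i}_pre_attn, layer_{i}_attn, layer_{i}_mlp) plus the final pre_logits vector. A small sketch against a randomly initialized ViT-B/32 backbone, assuming the module above is importable:

import torch

from modelscope.models.cv.image_probing_model.backbone import CLIP

model = CLIP(arch_name='CLIP_ViTB32', use_pretrain=False).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

# 12 blocks x 3 taps + 'pre_logits' = 37 entries for ViT-B/32
print(len(feats))                   # 37
print(feats['layer_0_attn'].shape)  # torch.Size([1, 50, 768]): 49 patches + 1 class token
print(feats['pre_logits'].shape)    # torch.Size([1, 768])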
modelscope/models/cv/image_probing_model/model.py (new file, 93 lines)

@@ -0,0 +1,93 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import os
from typing import Any, Dict

import json
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import CLIP, ProbingModel


@MODELS.register_module(
    Tasks.image_classification, module_name=Models.image_probing_model)
class StructuredProbingModel(TorchModel):
    """
    The implementation of 'Structured Model Probing: Empowering
    Efficient Adaptation by Structured Regularization'.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """
        Initialize a probing model.
        Args:
            model_dir: the model directory (model id or local path)
        """
        super(StructuredProbingModel, self).__init__()
        model_dir = os.path.join(model_dir, 'food101-clip-vitl14-full.pt')
        model_file = torch.load(model_dir)
        self.feature_size = model_file['meta_info']['feature_size']
        self.num_classes = model_file['meta_info']['num_classes']
        self.backbone = CLIP(
            'CLIP_ViTL14_FP16',
            use_pretrain=True,
            state_dict=model_file['backbone_model_state_dict'])
        self.probing_model = ProbingModel(self.feature_size, self.num_classes)
        self.probing_model.load_state_dict(
            model_file['probing_model_state_dict'])

    def forward(self, x):
        """
        Forward function of SMP.
        Args:
            x: the input images (B, 3, H, W)
        """

        keys = []
        for idx in range(0, 24):
            keys.append('layer_{}_pre_attn'.format(idx))
            keys.append('layer_{}_attn'.format(idx))
            keys.append('layer_{}_mlp'.format(idx))
        keys.append('pre_logits')
        features = self.backbone(x.half())
        features_agg = []
        for key in keys:
            aggregated_feature = self.aggregate_token(features[key], 1024)
            features_agg.append(aggregated_feature)
        features_agg = torch.cat(features_agg, dim=1)
        outputs = self.probing_model(features_agg.float())
        return outputs

    def aggregate_token(self, output, target_size):
        """
        Aggregate features from tokens.
        Args:
            output: the intermediate feature map from a ViT model
            target_size: target aggregated feature size
        """
        if len(output.shape) == 3:
            _, n_token, channels = output.shape
            if channels >= target_size:
                pool_size = 0
            else:
                n_groups = target_size / channels
                pool_size = int(n_token / n_groups)

            if pool_size > 0:
                output = torch.permute(output, (0, 2, 1))
                output = torch.nn.AvgPool1d(
                    kernel_size=pool_size, stride=pool_size)(
                        output)
                output = torch.flatten(output, start_dim=1)
            else:
                output = torch.mean(output, dim=1)
        output = torch.nn.functional.normalize(output, dim=1)
        return output
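For the ViT-L/14 checkpoint loaded here, the width of the probe input follows directly from the loop above: 24 blocks times 3 taps plus pre_logits, each reduced to the 1024-dim backbone width (every tap already has channels >= target_size, so aggregate_token falls back to a mean over tokens). A back-of-the-envelope sketch; the authoritative value is still read from meta_info['feature_size'] in the checkpoint:

layers, width, taps = 24, 1024, 3       # taps: pre_attn, attn, mlp
num_features = layers * taps + 1        # + 'pre_logits'
probe_input_dim = num_features * width  # concatenated input to the linear probe
print(num_features, probe_input_dim)    # 73 74752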
modelscope/models/cv/image_probing_model/utils.py (new file, 148 lines)

@@ -0,0 +1,148 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import re

import torch
import torch.nn as nn


def load_pretrained(model: torch.nn.Module,
                    state_dict,
                    local_path: str,
                    map_location='cpu',
                    logger=None,
                    sub_level=None):
    return load_pretrained_dict(model, state_dict, logger, sub_level=sub_level)


def load_pretrained_dict(model: torch.nn.Module,
                         state_dict: dict,
                         logger=None,
                         sub_level=None):
    """
    Load parameters into a model:
    1. Strip prefixes via revise_keys (for DataParallel / DistributedDataParallel checkpoints).
    2. Unwrap the checkpoint if parameters are nested under 'state_dict' or 'model_state'.
    3. Take sub-level keys from the source, e.g. load the 'backbone' part of a classifier
       into a standalone backbone model.
    4. Automatically drop invalid parameters from the source.
    5. Log or warn when unexpected or missing keys are found.

    Args:
        model (torch.nn.Module):
        state_dict (dict): dict of parameters
        logger (logging.Logger, None):
        sub_level (str, optional): if not None, parameters whose key starts with sub_level
            have that prefix removed so the remaining keys fit the actual model. Useful when
            loading sub-module parameters into a sub-module model.
    """
    revise_keys = [(r'^module\.', '')]

    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    for p, r in revise_keys:
        state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}

    if sub_level:
        sub_level = sub_level if sub_level.endswith('.') else (sub_level + '.')
        sub_level_len = len(sub_level)
        state_dict = {
            key[sub_level_len:]: value
            for key, value in state_dict.items() if key.startswith(sub_level)
        }

    state_dict = _auto_drop_invalid(model, state_dict, logger=logger)

    load_status = model.load_state_dict(state_dict, strict=False)
    unexpected_keys = load_status.unexpected_keys
    missing_keys = load_status.missing_keys
    err_msgs = []
    if unexpected_keys:
        err_msgs.append('unexpected key in source '
                        f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msgs.append('missing key in source '
                        f'state_dict: {", ".join(missing_keys)}\n')
    err_msgs = '\n'.join(err_msgs)

    if len(err_msgs) > 0:
        if logger:
            logger.warning(err_msgs)
        else:
            import warnings
            warnings.warn(err_msgs)


def convert_weights(model: nn.Module):
    """
    Convert applicable model parameters to fp16.
    """

    def _convert_weights_to_fp16(layer):
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            layer.weight.data = layer.weight.data.half()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data.half()

        if isinstance(layer, nn.MultiheadAttention):
            for attr in [
                    *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
                    'in_proj_bias', 'bias_k', 'bias_v'
            ]:
                tensor = getattr(layer, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ['text_projection', 'proj']:
            if hasattr(layer, name):
                attr = getattr(layer, name)
                if attr is not None:
                    attr.data = attr.data.half()

        for name in ['prompt_embeddings']:
            if hasattr(layer, name):
                attr = getattr(layer, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def _auto_drop_invalid(model: torch.nn.Module, state_dict: dict, logger=None):
    """
    Strip unmatched parameters from state_dict, e.g. parameters whose shape or dtype
    does not match the model.

    Args:
        model (torch.nn.Module):
        state_dict (dict):
        logger (logging.Logger, None):

    Returns:
        A new state dict.
    """
    ret_dict = state_dict.copy()
    invalid_msgs = []
    for key, value in model.state_dict().items():
        if key in state_dict:
            # Check shape
            new_value = state_dict[key]
            if value.shape != new_value.shape:
                invalid_msgs.append(
                    f'{key}: invalid shape, dst {value.shape} vs. src {new_value.shape}'
                )
                ret_dict.pop(key)
            elif value.dtype != new_value.dtype:
                invalid_msgs.append(
                    f'{key}: invalid dtype, dst {value.dtype} vs. src {new_value.dtype}'
                )
                ret_dict.pop(key)
    if len(invalid_msgs) > 0:
        warning_msg = 'ignore keys from source: \n' + '\n'.join(invalid_msgs)
        if logger:
            logger.warning(warning_msg)
        else:
            import warnings
            warnings.warn(warning_msg)
    return ret_dict
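A small sketch of how the loader massages a checkpoint before load_state_dict (the tensors and the 'backbone' prefix are made up for illustration): the 'module.' prefix is stripped first, then sub_level removes the 'backbone.' prefix so the remaining keys match the target model.

import torch
import torch.nn as nn

from modelscope.models.cv.image_probing_model.utils import load_pretrained_dict

model = nn.Sequential(nn.Linear(4, 2))
checkpoint = {
    'state_dict': {
        'module.backbone.0.weight': torch.zeros(2, 4),
        'module.backbone.0.bias': torch.zeros(2),
    }
}
# keys end up as '0.weight' / '0.bias', matching the model; mismatched
# shapes or dtypes would be dropped by _auto_drop_invalid with a warning
load_pretrained_dict(model, checkpoint, sub_level='backbone')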
modelscope/pipelines/cv/__init__.py

@@ -86,6 +86,7 @@ if TYPE_CHECKING:
     from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline
     from .panorama_depth_estimation_pipeline import PanoramaDepthEstimationPipeline
     from .ddcolor_image_colorization_pipeline import DDColorImageColorizationPipeline
+    from .image_structured_model_probing_pipeline import ImageStructuredModelProbingPipeline
     from .video_colorization_pipeline import VideoColorizationPipeline
     from .image_defrcn_fewshot_pipeline import ImageDefrcnDetectionPipeline
     from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline

@@ -207,6 +208,9 @@ else:
         'ddcolor_image_colorization_pipeline': [
             'DDColorImageColorizationPipeline'
         ],
+        'image_structured_model_probing_pipeline': [
+            'ImageStructuredModelProbingPipeline'
+        ],
         'video_colorization_pipeline': ['VideoColorizationPipeline'],
         'image_defrcn_fewshot_pipeline': ['ImageDefrcnDetectionPipeline'],
         'image_quality_assessment_mos_pipeline': [
modelscope/pipelines/cv/image_structured_model_probing_pipeline.py (new file, 79 lines)

@@ -0,0 +1,79 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import math
import os
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch
import torchvision.transforms as transforms
from mmcv.parallel import collate, scatter

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_classification,
    module_name=Pipelines.image_structured_model_probing)
class ImageStructuredModelProbingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a structured model probing pipeline for prediction.
        Args:
            model: model id on modelscope hub.
        Example:
            >>> from modelscope.pipelines import pipeline
            >>> recognition_pipeline = pipeline(self.task, self.model_id)
            >>> file_name = 'data/test/images/\
                image_structured_model_probing_test_image.jpg'
            >>> result = recognition_pipeline(file_name)
            >>> print(f'recognition output: {result}.')
        """
        super().__init__(model=model, **kwargs)
        self.model.eval()
        model_dir = os.path.join(model, 'food101-clip-vitl14-full.pt')
        model_file = torch.load(model_dir)
        self.label_map = model_file['meta_info']['label_map']
        logger.info('load model done')

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def preprocess(self, input: Input) -> Dict[str, Any]:

        img = LoadImage.convert_to_img(input)

        data = self.transform(img)
        data = collate([data], samples_per_gpu=1)
        if next(self.model.parameters()).is_cuda:
            data = scatter(data, [next(self.model.parameters()).device])[0]

        return data

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            results = self.model(input)
        return results

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        scores = torch.softmax(inputs, dim=1).cpu()
        labels = torch.argmax(scores, dim=1).cpu().tolist()
        label_names = [self.label_map[label] for label in labels]

        return {OutputKeys.LABELS: label_names, OutputKeys.SCORES: scores}
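A usage sketch of the finished pipeline, reusing the model id from the test below; the output follows postprocess: one label name per input image plus the softmax scores over all classes.

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

probing = pipeline(Tasks.image_classification, 'damo/structured_model_probing')
result = probing(
    'data/test/images/image_structured_model_probing_test_image.jpg')
print(result[OutputKeys.LABELS])        # e.g. a food-101 category name
print(result[OutputKeys.SCORES].shape)  # torch.Size([1, num_classes])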
tests/pipelines/test_image_structured_model_probing.py (new file, 29 lines)

@@ -0,0 +1,29 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class ImageStructuredModelProbingTest(unittest.TestCase,
                                      DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.image_classification
        self.model_id = 'damo/structured_model_probing'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):

        recognition_pipeline = pipeline(self.task, self.model_id)
        file_name = 'data/test/images/image_structured_model_probing_test_image.jpg'
        result = recognition_pipeline(file_name)

        print(f'recognition output: {result}.')


if __name__ == '__main__':
    unittest.main()