Mirror of https://github.com/modelscope/modelscope.git
add structured model probing pipeline for image classification
Add support for the structured model probing pipeline. Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11376544
Committed by: wenmeng.zwm
Parent: 0967ece5a0
Commit: 7a65cf64e9
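
For orientation, a usage sketch of the pipeline added by this commit follows; the task, model id and sample image path are taken from the test case included below, and the output keys match the pipeline's postprocess method.

# Usage sketch based on the test added in this commit; assumes the
# 'damo/structured_model_probing' model has been published to the hub.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

recognition_pipeline = pipeline(Tasks.image_classification,
                                'damo/structured_model_probing')
result = recognition_pipeline(
    'data/test/images/image_structured_model_probing_test_image.jpg')
print(result)  # dict with OutputKeys.LABELS and OutputKeys.SCORES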
BIN  data/test/images/image_structured_model_probing_test_image.jpg  (new file, 28 KiB)
Binary file not shown.
modelscope/metainfo.py
@@ -80,6 +80,7 @@ class Models(object):
     image_casmvs_depth_estimation = 'image-casmvs-depth-estimation'
     vop_retrieval_model = 'vop-retrieval-model'
     ddcolor = 'ddcolor'
+    image_probing_model = 'image-probing-model'
     defrcn = 'defrcn'
     image_face_fusion = 'image-face-fusion'
     ddpm = 'ddpm'
@@ -310,6 +311,7 @@ class Pipelines(object):
     video_panoptic_segmentation = 'video-panoptic-segmentation'
     vop_retrieval = 'vop-video-text-retrieval'
     ddcolor_image_colorization = 'ddcolor-image-colorization'
+    image_structured_model_probing = 'image-structured-model-probing'
     image_fewshot_detection = 'image-fewshot-detection'
     image_face_fusion = 'image-face-fusion'
     ddpm_image_semantic_segmentation = 'ddpm-image-semantic-segmentation'
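
The two new entries are plain string keys; the model and pipeline classes added later in this commit register themselves under them via @MODELS.register_module and @PIPELINES.register_module. A minimal sketch of what the constants resolve to:

# The new registry keys are ordinary strings; the decorators in model.py and
# the pipeline module further down bind the new classes to them.
from modelscope.metainfo import Models, Pipelines

print(Models.image_probing_model)                # 'image-probing-model'
print(Pipelines.image_structured_model_probing)  # 'image-structured-model-probing'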
modelscope/models/cv/__init__.py
@@ -9,13 +9,14 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                image_denoise, image_inpainting, image_instance_segmentation,
                image_matching, image_mvs_depth_estimation,
                image_panoptic_segmentation, image_portrait_enhancement,
-               image_quality_assessment_mos, image_reid_person,
-               image_restoration, image_semantic_segmentation,
-               image_to_image_generation, image_to_image_translation,
-               language_guided_video_summarization, movie_scene_segmentation,
-               object_detection, panorama_depth_estimation,
-               pointcloud_sceneflow_estimation, product_retrieval_embedding,
-               realtime_object_detection, referring_video_object_segmentation,
+               image_probing_model, image_quality_assessment_mos,
+               image_reid_person, image_restoration,
+               image_semantic_segmentation, image_to_image_generation,
+               image_to_image_translation, language_guided_video_summarization,
+               movie_scene_segmentation, object_detection,
+               panorama_depth_estimation, pointcloud_sceneflow_estimation,
+               product_retrieval_embedding, realtime_object_detection,
+               referring_video_object_segmentation,
                robust_image_classification, salient_detection,
                shop_segmentation, super_resolution, video_frame_interpolation,
                video_object_segmentation, video_panoptic_segmentation,
modelscope/models/cv/image_probing_model/__init__.py  (new file, 24 lines)
@@ -0,0 +1,24 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:

    from .model import StructuredProbingModel

else:
    _import_structure = {
        'model': ['StructuredProbingModel'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
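
This follows the lazy-import convention used throughout modelscope: model.py (and its torch/CLIP dependencies) is only imported when StructuredProbingModel is first accessed. A small illustration, assuming the package is importable:

# Accessing the attribute triggers the deferred import of .model; until then
# only the lightweight LazyImportModule wrapper is loaded.
from modelscope.models.cv import image_probing_model

model_cls = image_probing_model.StructuredProbingModel
print(model_cls.__module__)  # modelscope.models.cv.image_probing_model.model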
modelscope/models/cv/image_probing_model/backbone.py  (new file, 308 lines)
@@ -0,0 +1,308 @@
# The implementation is adopted from OpenAI-CLIP,
# made publicly available under the MIT License at https://github.com/openai/CLIP

import math
import sys
from collections import OrderedDict
from functools import reduce
from operator import mul

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import models

from .utils import convert_weights, load_pretrained


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed
        # after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool,
            # and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(
                OrderedDict([('-1', nn.AvgPool2d(stride)),
                             ('0',
                              nn.Conv2d(
                                  inplanes,
                                  planes * self.expansion,
                                  1,
                                  stride=1,
                                  bias=False)),
                             ('1', nn.BatchNorm2d(planes * self.expansion))]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        return out


class AttentionPool2d(nn.Module):

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1],
                      x.shape[2] * x.shape[3]).permute(2, 0, 1)
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :].to(x.dtype)
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)

        return x[0]


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
                         ('gelu', QuickGELU()),
                         ('c_proj', nn.Linear(d_model * 4, d_model))]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(
            dtype=x.dtype,
            device=x.device) if self.attn_mask is not None else None
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor, idx):
        features = {}
        x_norm = self.ln_1(x)
        features['layer_{}_pre_attn'.format(idx)] = x_norm.permute(1, 0, 2)
        attn = self.attention(x_norm)
        features['layer_{}_attn'.format(idx)] = attn.permute(1, 0, 2)
        x = x + attn
        mlp = self.mlp(self.ln_2(x))
        features['layer_{}_mlp'.format(idx)] = mlp.permute(1, 0, 2)
        x = x + mlp
        return x, features


class Transformer(nn.Module):

    def __init__(self,
                 width: int,
                 layers: int,
                 heads: int,
                 attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.ModuleList()
        for i in range(layers):
            block = ResidualAttentionBlock(width, heads, attn_mask)
            self.resblocks.append(block)

    def forward(self, x: torch.Tensor):
        features = {}
        for idx, block in enumerate(self.resblocks):
            x, block_feats = block(x, idx)
            features.update(block_feats)
        return x, features


class VisualTransformer(nn.Module):

    def __init__(self, input_resolution: int, patch_size: int, width: int,
                 layers: int, heads: int, output_dim: int):
        super().__init__()
        print(input_resolution, patch_size, width, layers, heads, output_dim)
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False)

        scale = width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor, return_all=True):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        zeros = torch.zeros(
            x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        # shape = [*, grid ** 2 + 1, width]
        x = torch.cat([self.class_embedding.to(x.dtype) + zeros, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x, features = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        x = self.ln_post(x[:, 0, :])

        if return_all:
            features['pre_logits'] = x
            return features

        if self.proj is not None:
            x = x @ self.proj

        return x


class CLIPNet(nn.Module):

    def __init__(self, arch_name, pretrained, **kwargs):
        super(CLIPNet, self).__init__()

        if arch_name == 'CLIP_ViTB32':
            self.clip = VisualTransformer(
                input_resolution=224,
                patch_size=32,
                width=768,
                layers=12,
                heads=12,
                output_dim=512)

        elif arch_name in ('CLIP_ViTB16', 'CLIP_ViTB16_FP16'):
            self.clip = VisualTransformer(
                input_resolution=224,
                patch_size=16,
                width=768,
                layers=12,
                heads=12,
                output_dim=512)

        elif arch_name in ('CLIP_ViTL14', 'CLIP_ViTL14_FP16'):
            self.clip = VisualTransformer(
                input_resolution=224,
                patch_size=14,
                width=1024,
                layers=24,
                heads=16,
                output_dim=768)

        else:
            raise KeyError(f'Unsupported arch_name for CLIP, {arch_name}')

    def forward(self, input_data):
        output = self.clip(input_data)
        return output


def CLIP(arch_name='CLIP_RN50',
         use_pretrain=False,
         load_from='',
         state_dict=None,
         **kwargs):
    model = CLIPNet(arch_name=arch_name, pretrained=None, **kwargs)
    if use_pretrain:
        if arch_name.endswith('FP16'):
            convert_weights(model.clip)
        load_pretrained(model.clip, state_dict, load_from)
    return model


class ProbingModel(torch.nn.Module):

    def __init__(self, feat_size, num_classes):
        super(ProbingModel, self).__init__()
        self.linear = torch.nn.Linear(feat_size, num_classes)

    def forward(self, x):
        return self.linear(x)
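
For context, a sketch of exercising this backbone on its own: it builds a randomly initialised ViT-B/32 visual tower (use_pretrain=False, so no checkpoint is needed) and shows the per-layer feature dictionary that the probing model consumes.

# Sketch only: random weights, smallest supported architecture.
import torch

from modelscope.models.cv.image_probing_model.backbone import CLIP

model = CLIP(arch_name='CLIP_ViTB32', use_pretrain=False)
images = torch.randn(1, 3, 224, 224)
features = model(images)  # dict: 'layer_{i}_pre_attn/attn/mlp' plus 'pre_logits'
print(features['layer_0_attn'].shape, features['pre_logits'].shape)
# torch.Size([1, 50, 768]) torch.Size([1, 768])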
modelscope/models/cv/image_probing_model/model.py  (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import os
from typing import Any, Dict

import json
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .backbone import CLIP, ProbingModel


@MODELS.register_module(
    Tasks.image_classification, module_name=Models.image_probing_model)
class StructuredProbingModel(TorchModel):
    """
    The implementation of 'Structured Model Probing: Empowering
    Efficient Adaptation by Structured Regularization'.
    """

    def __init__(self, model_dir, *args, **kwargs):
        """
        Initialize a probing model.
        Args:
            model_dir: model id or path
        """
        super(StructuredProbingModel, self).__init__()
        model_dir = os.path.join(model_dir, 'food101-clip-vitl14-full.pt')
        model_file = torch.load(model_dir)
        self.feature_size = model_file['meta_info']['feature_size']
        self.num_classes = model_file['meta_info']['num_classes']
        self.backbone = CLIP(
            'CLIP_ViTL14_FP16',
            use_pretrain=True,
            state_dict=model_file['backbone_model_state_dict'])
        self.probing_model = ProbingModel(self.feature_size, self.num_classes)
        self.probing_model.load_state_dict(
            model_file['probing_model_state_dict'])

    def forward(self, x):
        """
        Forward Function of SMP.
        Args:
            x: the input images (B, 3, H, W)
        """

        keys = []
        for idx in range(0, 24):
            keys.append('layer_{}_pre_attn'.format(idx))
            keys.append('layer_{}_attn'.format(idx))
            keys.append('layer_{}_mlp'.format(idx))
        keys.append('pre_logits')
        features = self.backbone(x.half())
        features_agg = []
        for i in keys:
            aggregated_feature = self.aggregate_token(features[i], 1024)
            features_agg.append(aggregated_feature)
        features_agg = torch.cat((features_agg), dim=1)
        outputs = self.probing_model(features_agg.float())
        return outputs

    def aggregate_token(self, output, target_size):
        """
        Aggregating features from tokens.
        Args:
            output: the intermediate feature output
                from a ViT model
            target_size: target aggregated feature size
        """
        if len(output.shape) == 3:
            _, n_token, channels = output.shape
            if channels >= target_size:
                pool_size = 0
            else:
                n_groups = target_size / channels
                pool_size = int(n_token / n_groups)

            if pool_size > 0:
                output = torch.permute(output, (0, 2, 1))
                output = torch.nn.AvgPool1d(
                    kernel_size=pool_size, stride=pool_size)(
                        output)
                output = torch.flatten(output, start_dim=1)
            else:
                output = torch.mean(output, dim=1)
        output = torch.nn.functional.normalize(output, dim=1)
        return output
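
To make the numbers in forward concrete: 24 layers × 3 hooked outputs plus pre_logits gives 73 feature tensors. For ViT-L/14 the transformer width is already 1024, so aggregate_token takes the mean-pooling branch and each tensor contributes a 1024-dim, L2-normalized vector, suggesting a probe input of 73 × 1024 = 74752 dims (the actual feature_size is read from the checkpoint's meta_info). A standalone check of that branch:

# Standalone check of the mean-pooling branch of aggregate_token
# (channels >= target_size, so pool_size stays 0).
import torch
import torch.nn.functional as F

tokens = torch.randn(2, 257, 1024)   # [batch, grid**2 + 1, width] for ViT-L/14 at 224px
pooled = F.normalize(tokens.mean(dim=1), dim=1)
print(pooled.shape)                  # torch.Size([2, 1024])
print(73 * 1024)                     # 74752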
modelscope/models/cv/image_probing_model/utils.py  (new file, 148 lines)
@@ -0,0 +1,148 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import re

import torch
import torch.nn as nn


def load_pretrained(model: torch.nn.Module,
                    state_dict,
                    local_path: str,
                    map_location='cpu',
                    logger=None,
                    sub_level=None):
    return load_pretrained_dict(model, state_dict, logger, sub_level=sub_level)


def load_pretrained_dict(model: torch.nn.Module,
                         state_dict: dict,
                         logger=None,
                         sub_level=None):
    """
    Load parameters into the model, with the following steps:
    1. Substitute key names via revise_keys for DataParallelModel or DistributedParallelModel.
    2. Unwrap the inner dict by key 'state_dict' or 'model_state' if present.
    3. Take sub level keys from source, e.g. load the 'backbone' part from a classifier into a backbone model.
    4. Automatically remove invalid parameters from source.
    5. Log or warn if unexpected or missing keys exist.

    Args:
        model (torch.nn.Module):
        state_dict (dict): dict of parameters
        logger (logging.Logger, None):
        sub_level (str, optional): If not None, parameters whose key starts with sub_level have the prefix
            removed to fit the actual model keys. This is used to load sub-module parameters
            into a sub-module model.
    """
    revise_keys = [(r'^module\.', '')]

    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    for p, r in revise_keys:
        state_dict = {re.sub(p, r, k): v for k, v in state_dict.items()}

    if sub_level:
        sub_level = sub_level if sub_level.endswith('.') else (sub_level + '.')
        sub_level_len = len(sub_level)
        state_dict = {
            key[sub_level_len:]: value
            for key, value in state_dict.items() if key.startswith(sub_level)
        }

    state_dict = _auto_drop_invalid(model, state_dict, logger=logger)

    load_status = model.load_state_dict(state_dict, strict=False)
    unexpected_keys = load_status.unexpected_keys
    missing_keys = load_status.missing_keys
    err_msgs = []
    if unexpected_keys:
        err_msgs.append('unexpected key in source '
                        f'state_dict: {", ".join(unexpected_keys)}\n')
    if missing_keys:
        err_msgs.append('missing key in source '
                        f'state_dict: {", ".join(missing_keys)}\n')
    err_msgs = '\n'.join(err_msgs)

    if len(err_msgs) > 0:
        if logger:
            logger.warning(err_msgs)
        else:
            import warnings
            warnings.warn(err_msgs)


def convert_weights(model: nn.Module):
    """
    Convert applicable model parameters to fp16.
    """

    def _convert_weights_to_fp16(layer):
        if isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            layer.weight.data = layer.weight.data.half()
            if layer.bias is not None:
                layer.bias.data = layer.bias.data.half()

        if isinstance(layer, nn.MultiheadAttention):
            for attr in [
                    *[f'{s}_proj_weight' for s in ['in', 'q', 'k', 'v']],
                    'in_proj_bias', 'bias_k', 'bias_v'
            ]:
                tensor = getattr(layer, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ['text_projection', 'proj']:
            if hasattr(layer, name):
                attr = getattr(layer, name)
                if attr is not None:
                    attr.data = attr.data.half()

        for name in ['prompt_embeddings']:
            if hasattr(layer, name):
                attr = getattr(layer, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def _auto_drop_invalid(model: torch.nn.Module, state_dict: dict, logger=None):
    """
    Strip unmatched parameters in state_dict, e.g. mismatched shape or dtype.

    Args:
        model (torch.nn.Module):
        state_dict (dict):
        logger (logging.Logger, None):

    Returns:
        A new state dict.
    """
    ret_dict = state_dict.copy()
    invalid_msgs = []
    for key, value in model.state_dict().items():
        if key in state_dict:
            # Check shape
            new_value = state_dict[key]
            if value.shape != new_value.shape:
                invalid_msgs.append(
                    f'{key}: invalid shape, dst {value.shape} vs. src {new_value.shape}'
                )
                ret_dict.pop(key)
            elif value.dtype != new_value.dtype:
                invalid_msgs.append(
                    f'{key}: invalid dtype, dst {value.dtype} vs. src {new_value.dtype}'
                )
                ret_dict.pop(key)
    if len(invalid_msgs) > 0:
        warning_msg = 'ignore keys from source: \n' + '\n'.join(invalid_msgs)
        if logger:
            logger.warning(warning_msg)
        else:
            import warnings
            warnings.warn(warning_msg)
    return ret_dict
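
To make the key handling above concrete, here is a small self-contained illustration of steps 1 and 3 (stripping the '(Distributed)DataParallel' 'module.' prefix and selecting a sub_level); the key names and tensors are hypothetical.

# Illustration only: made-up keys exercising the prefix handling above.
import re

import torch

state_dict = {
    'module.backbone.linear.weight': torch.zeros(4, 4),
    'module.head.weight': torch.zeros(4, 4),
}

# step 1: drop the 'module.' prefix added by (Distributed)DataParallel
state_dict = {re.sub(r'^module\.', '', k): v for k, v in state_dict.items()}

# step 3: keep only the 'backbone.' sub-tree and strip its prefix
sub_level = 'backbone.'
state_dict = {
    k[len(sub_level):]: v
    for k, v in state_dict.items() if k.startswith(sub_level)
}
print(list(state_dict))  # ['linear.weight']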
modelscope/pipelines/cv/__init__.py
@@ -86,6 +86,7 @@ if TYPE_CHECKING:
     from .image_mvs_depth_estimation_pipeline import ImageMultiViewDepthEstimationPipeline
     from .panorama_depth_estimation_pipeline import PanoramaDepthEstimationPipeline
     from .ddcolor_image_colorization_pipeline import DDColorImageColorizationPipeline
+    from .image_structured_model_probing_pipeline import ImageStructuredModelProbingPipeline
     from .video_colorization_pipeline import VideoColorizationPipeline
     from .image_defrcn_fewshot_pipeline import ImageDefrcnDetectionPipeline
     from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline
@@ -207,6 +208,9 @@ else:
         'ddcolor_image_colorization_pipeline': [
             'DDColorImageColorizationPipeline'
         ],
+        'image_structured_model_probing_pipeline': [
+            'ImageStructuredModelProbingPipeline'
+        ],
         'video_colorization_pipeline': ['VideoColorizationPipeline'],
         'image_defrcn_fewshot_pipeline': ['ImageDefrcnDetectionPipeline'],
         'image_quality_assessment_mos_pipeline': [

modelscope/pipelines/cv/image_structured_model_probing_pipeline.py  (new file, 79 lines)
@@ -0,0 +1,79 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import math
import os
import os.path as osp
from typing import Any, Dict

import numpy as np
import torch
import torchvision.transforms as transforms
from mmcv.parallel import collate, scatter

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_classification,
    module_name=Pipelines.image_structured_model_probing)
class ImageStructuredModelProbingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        Use `model` to create a structured model probing pipeline for prediction.
        Args:
            model: model id on modelscope hub.
        Example:
            >>> from modelscope.pipelines import pipeline
            >>> recognition_pipeline = pipeline(self.task, self.model_id)
            >>> file_name = 'data/test/images/\
                image_structured_model_probing_test_image.jpg'
            >>> result = recognition_pipeline(file_name)
            >>> print(f'recognition output: {result}.')
        """
        super().__init__(model=model, **kwargs)
        self.model.eval()
        model_dir = os.path.join(model, 'food101-clip-vitl14-full.pt')
        model_file = torch.load(model_dir)
        self.label_map = model_file['meta_info']['label_map']
        logger.info('load model done')

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.48145466, 0.4578275, 0.40821073],
                std=[0.26862954, 0.26130258, 0.27577711])
        ])

    def preprocess(self, input: Input) -> Dict[str, Any]:

        img = LoadImage.convert_to_img(input)

        data = self.transform(img)
        data = collate([data], samples_per_gpu=1)
        if next(self.model.parameters()).is_cuda:
            data = scatter(data, [next(self.model.parameters()).device])[0]

        return data

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        with torch.no_grad():
            results = self.model(input)
        return results

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        scores = torch.softmax(inputs, dim=1).cpu()
        labels = torch.argmax(scores, dim=1).cpu().tolist()
        label_names = [self.label_map[label] for label in labels]

        return {OutputKeys.LABELS: label_names, OutputKeys.SCORES: scores}
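
The preprocessing is the standard CLIP normalization at 224x224. The short sketch below reproduces the transform without mmcv's collate/scatter to show the tensor shape the model receives; the synthetic PIL image is a stand-in for a real input.

# Preprocessing sketch: same transform as the pipeline, minus collate/scatter.
import torch
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.48145466, 0.4578275, 0.40821073],
        std=[0.26862954, 0.26130258, 0.27577711])
])

img = Image.new('RGB', (640, 480))   # stand-in for LoadImage.convert_to_img(input)
batch = transform(img).unsqueeze(0)  # collate([data], samples_per_gpu=1) stacks to the same shape
print(batch.shape)                   # torch.Size([1, 3, 224, 224])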
tests/pipelines/test_image_structured_model_probing.py  (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import unittest

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class ImageStructuredModelProbingTest(unittest.TestCase,
                                      DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.image_classification
        self.model_id = 'damo/structured_model_probing'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):

        recognition_pipeline = pipeline(self.task, self.model_id)
        file_name = 'data/test/images/image_structured_model_probing_test_image.jpg'
        result = recognition_pipeline(file_name)

        print(f'recognition output: {result}.')


if __name__ == '__main__':
    unittest.main()