add vision_efficient_tuning models

1)新增任务:vision_efficient_tuning;
2)新增该任务下四个模型:
vision_efficient_tuning_adapter、
vision_efficient_tuning_prefix、
vision_efficient_tuning_prompt、
vision_efficient_tuning_lora

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11575894
This commit is contained in:
zeyinzi.jzyz
2023-02-09 07:59:33 +00:00
committed by wenmeng.zwm
parent b34e2cad86
commit 9faf588bc6
20 changed files with 1924 additions and 1 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8b28d9c33eff034a706534f195f4443f8c053a74d5553787a5cb9b20873c072f
size 1962

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bbd99f0253d6e0d10ec500cf781cc83b93809db58da54bd914b0b80b7fe8d8a4
size 2409

View File

@@ -87,6 +87,7 @@ class Models(object):
object_detection_3d = 'object_detection_3d'
ddpm = 'ddpm'
image_quality_assessment_mos = 'image-quality-assessment-mos'
vision_efficient_tuning = 'vision-efficient-tuning'
# EasyCV models
yolox = 'YOLOX'
@@ -322,8 +323,9 @@ class Pipelines(object):
video_colorization = 'video-colorization'
motion_generattion = 'mdm-motion-generation'
object_detection_3d_depe = 'object-detection-3d-depe'
image_quality_assessment_mos = 'image-quality-assessment-mos'
vision_efficient_tuning = 'vision-efficient-tuning'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'
translation_quality_estimation = 'translation-quality-estimation'
@@ -677,6 +679,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_quality_assessment_mos: (
Pipelines.image_quality_assessment_mos,
'damo/cv_resnet_image-quality-assessment-mos_youtubeUGC'),
Tasks.vision_efficient_tuning: (
Pipelines.vision_efficient_tuning,
'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'),
Tasks.object_detection_3d: (Pipelines.object_detection_3d_depe,
'damo/cv_object-detection-3d_depe'),
}

View File

@@ -0,0 +1,30 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Static-analysis-only imports: type checkers and IDEs see the real
    # classes, while at runtime the module is swapped for a lazy proxy below.
    from .vision_efficient_tuning_adapter import VisionEfficientTuningAdapterModel
    from .vision_efficient_tuning_prompt import VisionEfficientTuningPromptModel
    from .vision_efficient_tuning_prefix import VisionEfficientTuningPrefixModel
    from .vision_efficient_tuning_lora import VisionEfficientTuningLoRAModel
else:
    # Maps submodule name -> exported symbols; LazyImportModule defers the
    # actual import of each submodule until one of these names is accessed.
    _import_structure = {
        'vision_efficient_tuning_adapter':
        ['VisionEfficientTuningAdapterModel'],
        'vision_efficient_tuning_prompt': ['VisionEfficientTuningPromptModel'],
        'vision_efficient_tuning_prefix': ['VisionEfficientTuningPrefixModel'],
        'vision_efficient_tuning_lora': ['VisionEfficientTuningLoRAModel'],
    }
    import sys

    # Replace this package's module object with a lazy loader so importing the
    # package does not pull in torch-heavy submodules eagerly.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,351 @@
# The implementation here is modified based on timm,
# originally Apache 2.0 License and publicly available at
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from .petl import Adapter, LoRA, Prefix, Prompt
from .timm_vision_transformer import (Attention, Block, DropPath, LayerScale,
Mlp, PatchEmbed, VisionTransformer)
class AttentionPETL(nn.Module):
    """Multi-head self-attention with optional PETL branches.

    Prefix tuning optimizes a task-specific vector in the multi-head
    attention layer ('Prefix-tuning: Optimizing continuous prompts for
    generation', Li & Liang 2021, https://arxiv.org/abs/2101.00190).
    LoRA adds a low-rank decomposition of the projection weights
    ('LoRA: Low-Rank Adaptation of Large Language Models', Hu et al. 2021,
    https://arxiv.org/abs/2106.09685).

    Attributes:
        prefix_length (int): length of the prefix-tuning branch (disabled
            when None or 0).
        prefix_type (str): variant of prefix tuning.
        lora_length (int): rank/length of the LoRA branch (disabled when
            None or 0).
        lora_type (str): variant of LoRA tuning.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        attn_drop=0.,
        proj_drop=0.,
        prefix_length=None,
        prefix_type=None,
        lora_length=None,
        lora_type=None,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads)**-0.5
        # Fused query/key/value projection, as in timm's Attention.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Optional tuning branches; None keeps the vanilla attention path.
        self.lora = LoRA(
            dim=dim,
            num_heads=num_heads,
            lora_length=lora_length,
            lora_type=lora_type) if lora_length and lora_length > 0 else None
        self.prefix = Prefix(
            dim=dim,
            num_heads=num_heads,
            prefix_length=prefix_length,
            prefix_type=prefix_type
        ) if prefix_length and prefix_length > 0 else None

    def forward(self, x):
        """Apply (PETL-augmented) self-attention to x of shape (B, N, C)."""
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).view(batch, tokens, 3, self.num_heads,
                               head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # LoRA first, then prefix — same order as the original implementation.
        for branch in (self.lora, self.prefix):
            if branch is not None:
                q, k, v = branch(x, q, k, v)
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = torch.matmul(weights, v).transpose(1, 2).reshape(
            batch, tokens, channels)
        return self.proj_drop(self.proj(out))
class BlockPETL(nn.Module):
    """Transformer encoder block extended with PETL tuning hooks.

    Visual prompt tuning (VPT) injects tunable prompt tokens in the first
    layer or multiple layers ('Visual Prompt Tuning', Jia et al. 2022,
    https://arxiv.org/abs/2203.12119). Adapters project tokens through a
    small MLP ('Parameter-Efficient Transfer Learning for NLP', Houlsby
    et al. 2019, http://arxiv.org/abs/1902.00751).

    Attributes:
        adapter_length (int): bottleneck size of the adapter (disabled when
            None or 0).
        adapter_type (str): variant of adapter tuning.
        prompt_length (int): number of prompt tokens (disabled when None or 0).
        prompt_type (str): variant of prompt tuning.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.,
        qkv_bias=False,
        drop=0.,
        attn_drop=0.,
        init_values=None,
        drop_path=0.,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        attn_layer=Attention,
        layer_num=-1,
        prompt_length=None,
        prompt_type=None,
        prefix_length=None,
        prefix_type=None,
        adapter_length=None,
        adapter_type=None,
        lora_length=None,
        lora_type=None,
    ):
        super().__init__()
        self.layer_num = layer_num
        # Attention sub-block; prefix/LoRA settings are forwarded to the
        # (possibly PETL-aware) attention layer.
        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            prefix_length=prefix_length,
            prefix_type=prefix_type,
            lora_length=lora_length,
            lora_type=lora_type,
        )
        # LayerScale / stochastic depth collapse to identity when disabled.
        self.ls1 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        # MLP sub-block.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)
        self.ls2 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        # Optional adapter on the MLP branch.
        self.adapter_length = adapter_length
        self.adapter_type = adapter_type
        self.adapter = Adapter(
            dim=dim,
            adapter_length=adapter_length,
            adapter_type=adapter_type,
            act_layer=act_layer
        ) if adapter_length and adapter_length > 0 else None
        # Optional per-layer prompt tokens.
        self.prompt_length = prompt_length
        self.prompt_type = prompt_type
        self.prompt = Prompt(
            dim=dim,
            layer_num=layer_num,
            prompt_length=prompt_length,
            prompt_type=prompt_type
        ) if prompt_length and prompt_length > 0 else None

    def forward(self, x):
        """Pre-norm residual block; prompts (if any) are injected first."""
        if self.prompt is not None:
            x = self.prompt(x)
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        mlp_out = self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        if self.adapter is not None:
            return x + self.adapter(mlp_out)
        return x + mlp_out
class VisionTransformerPETL(VisionTransformer):
""" Extend the parameter-efficient transfer learning (PETL) method to the original Vision Transformer.
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
- https://arxiv.org/abs/2010.11929
The implementation of several tuning methods (prompt, prefix, adapter, and LoRA) based on ViT.
"""
def __init__(
    self,
    img_size=224,
    patch_size=16,
    in_chans=3,
    num_classes=1000,
    global_pool='token',
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_ratio=4.,
    qkv_bias=True,
    init_values=None,
    class_token=True,
    no_embed_class=False,
    pre_norm=False,
    fc_norm=None,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.,
    weight_init='',
    embed_layer=PatchEmbed,
    norm_layer=None,
    act_layer=None,
    block_fn=Block,
    prompt_length=None,
    prompt_type=None,
    prefix_length=None,
    prefix_type=None,
    adapter_length=None,
    adapter_type=None,
    lora_length=None,
    lora_type=None,
):
    """Build a ViT whose blocks are optionally PETL-augmented.

    When any of ``prompt_length`` / ``prefix_length`` / ``adapter_length`` /
    ``lora_length`` is not None, ``BlockPETL`` + ``AttentionPETL`` replace the
    plain timm ``Block`` / ``Attention``. Each of those four lengths may be a
    single int (shared by all ``depth`` layers) or a per-layer list.

    Args mirror timm's VisionTransformer; the trailing eight arguments
    configure the PETL tuning branches (None disables a branch).
    """
    super().__init__()
    assert global_pool in ('', 'avg', 'token')
    assert class_token or global_pool != 'token'
    # pre-fc norm defaults to on only for average pooling
    use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
    norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
    act_layer = act_layer or nn.GELU

    self.num_classes = num_classes
    self.global_pool = global_pool
    self.num_features = self.embed_dim = embed_dim
    self.num_prefix_tokens = 1 if class_token else 0
    self.no_embed_class = no_embed_class
    self.grad_checkpointing = False
    self.depth = depth
    self.img_size = img_size
    self.class_token = class_token
    # PETL configuration is kept on the model for downstream inspection.
    self.prompt_length = prompt_length
    self.prompt_type = prompt_type
    self.prefix_length = prefix_length
    self.prefix_type = prefix_type
    self.adapter_length = adapter_length
    self.adapter_type = adapter_type
    self.lora_length = lora_length
    self.lora_type = lora_type

    self.patch_embed = embed_layer(
        img_size=img_size,
        patch_size=patch_size,
        in_chans=in_chans,
        embed_dim=embed_dim,
        bias=not pre_norm,
    )
    num_patches = self.patch_embed.num_patches

    self.cls_token = nn.Parameter(torch.zeros(
        1, 1, embed_dim)) if class_token else None
    # Positional table covers the cls token unless no_embed_class is set.
    embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
    self.pos_embed = nn.Parameter(
        torch.randn(1, embed_len, embed_dim) * .02)
    self.pos_drop = nn.Dropout(p=drop_rate)
    self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
    # Linearly increasing stochastic-depth rate per block.
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

    if prompt_length is not None or prefix_length is not None \
            or adapter_length is not None or lora_length is not None:
        # At least one PETL branch requested: use the PETL-aware layers.
        attn_layer = AttentionPETL
        block_fn = BlockPETL
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                attn_layer=attn_layer,
                layer_num=i,
                # Each length may be per-layer (list) or shared (scalar).
                prompt_length=prompt_length[i] if isinstance(
                    prompt_length, list) else prompt_length,
                prompt_type=prompt_type,
                prefix_length=prefix_length[i] if isinstance(
                    prefix_length, list) else prefix_length,
                prefix_type=prefix_type,
                adapter_length=adapter_length[i] if isinstance(
                    adapter_length, list) else adapter_length,
                adapter_type=adapter_type,
                lora_length=lora_length[i] if isinstance(
                    lora_length, list) else lora_length,
                lora_type=lora_type) for i in range(depth)
        ])
    else:
        # Plain timm blocks when no tuning branch is requested.
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer) for i in range(depth)
        ])
    # Final norm sits either before pooling (norm) or before the fc (fc_norm).
    self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
    self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
    self.head = nn.Linear(
        self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    if weight_init != 'skip':
        self.init_weights(weight_init)

View File

@@ -0,0 +1,25 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import torch.nn as nn
class ClassifierHead(nn.Module):
    """Linear classification head with optional dropout.

    Fixes over the original: the docstring wrongly described ``num_classes``
    as a string, and the optional dropout was modeled by conditionally
    *creating* the attribute and probing it with ``hasattr`` in ``forward``.
    Using ``nn.Identity`` when dropout is disabled keeps ``forward``
    branch-free while producing identical outputs and an identical
    state_dict (neither Dropout nor Identity has parameters).

    Attributes:
        dim (int): input feature dimension.
        num_classes (int): number of target classes.
        dropout_rate (float): dropout probability applied before the linear
            layer; 0 disables dropout.
    """

    def __init__(self, dim, num_classes, dropout_rate=0):
        super().__init__()
        self.dim = dim
        self.num_classes = num_classes
        # Identity when disabled — same behavior, no hasattr probing needed.
        self.dropout = nn.Dropout(
            dropout_rate) if dropout_rate > 0.0 else nn.Identity()
        self.fc = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Project features of shape (..., dim) to (..., num_classes)."""
        return self.fc(self.dropout(x))

View File

@@ -0,0 +1,174 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
class Prompt(nn.Module):
    """Vision prompt tuning (VPT) token injector.

    'Visual Prompt Tuning' by Jia et al.(2022),
    https://arxiv.org/abs/2203.12119

    The learned prompt tokens are concatenated at the *end* of the token
    sequence: the first layer appends them, deeper layers replace the
    prompt slots left by the previous layer with this layer's tokens.

    Attributes:
        dim (int): embedding dimension.
        layer_num (int): index of the layer this prompt belongs to.
        prompt_length (int): number of prompt tokens.
        prompt_type (str): variant of vision prompt tuning.
    """

    def __init__(self, dim, layer_num, prompt_length=None, prompt_type=None):
        super().__init__()
        self.dim = dim
        self.layer_num = layer_num
        self.prompt_length = prompt_length
        self.prompt_type = prompt_type
        self.prompt_token = nn.Parameter(torch.zeros(1, prompt_length, dim))
        nn.init.xavier_uniform_(self.prompt_token)

    def forward(self, x):
        """Append/refresh prompt tokens; x is (B, N, C)."""
        batch = x.size(0)
        tokens = self.prompt_token.expand(batch, -1, -1)
        if self.layer_num == 0:
            # First layer: append prompts after the original tokens.
            return torch.cat((x, tokens), dim=1)
        # Deeper layers: drop the previous layer's prompt slots, append ours.
        return torch.cat((x[:, :-self.prompt_length, :], tokens), dim=1)
class Adapter(nn.Module):
    """Bottleneck adapter tuning module.

    'Parameter-Efficient Transfer Learning for NLP' by Houlsby et al.(2019),
    http://arxiv.org/abs/1902.00751

    Tokens are down-projected to ``adapter_length``, passed through the
    activation, up-projected back to ``dim``, and added to a residual.

    Attributes:
        dim (int): embedding dimension.
        adapter_length (int): bottleneck width of the adapter.
        adapter_type (str): variant of adapter tuning.
    """

    def __init__(
        self,
        dim,
        adapter_length=None,
        adapter_type=None,
        act_layer=nn.GELU,
    ):
        super().__init__()
        self.dim = dim
        self.adapter_length = adapter_length
        self.adapter_type = adapter_type
        # down-project -> nonlinearity -> up-project
        self.ln1 = nn.Linear(dim, adapter_length)
        self.activate = act_layer()
        self.ln2 = nn.Linear(adapter_length, dim)
        self.init_weights()

    def init_weights(self):
        """Xavier weights and near-zero biases on every linear layer."""

        def _reset(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                nn.init.normal_(module.bias, std=1e-6)

        self.apply(_reset)

    def forward(self, x, identity=None):
        """Residual adapter; the residual defaults to the input itself."""
        residual = x if identity is None else identity
        return residual + self.ln2(self.activate(self.ln1(x)))
class LoRA(nn.Module):
    """Low-rank adaptation of the fused qkv projection.

    'LoRA: Low-Rank Adaptation of Large Language Models' by Hu et al.(2021),
    https://arxiv.org/abs/2106.09685

    ``lora_a`` down-projects (Kaiming init) and ``lora_b`` up-projects to
    ``3 * dim`` (zero init, so the branch is a no-op at initialization);
    the result is split into per-head q/k/v deltas.

    Attributes:
        dim (int): embedding dimension.
        num_heads (int): number of attention heads.
        lora_length (int): rank of the decomposition.
        lora_type (str): variant of LoRA tuning.
    """

    def __init__(
        self,
        dim,
        num_heads,
        lora_length=None,
        lora_type=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.lora_a = nn.Linear(dim, lora_length, bias=False)
        nn.init.kaiming_uniform_(self.lora_a.weight, a=math.sqrt(5))
        # Zero init makes the initial delta exactly zero.
        self.lora_b = nn.Linear(lora_length, dim * 3, bias=False)
        nn.init.zeros_(self.lora_b.weight)
        self.lora_length = lora_length
        self.lora_type = lora_type

    def forward(self, x, q, k, v):
        """Add low-rank deltas to q/k/v; x is (B, N, C)."""
        batch, tokens, channels = x.shape
        delta = self.lora_b(self.lora_a(x))
        delta = delta.view(batch, tokens, 3, self.num_heads,
                           channels // self.num_heads).permute(2, 0, 3, 1, 4)
        return q + delta[0], k + delta[1], v + delta[2]
class Prefix(nn.Module):
    """Prefix tuning: learned key/value tokens added to the attention.

    'Prefix-tuning: Optimizing continuous prompts for generation' by
    Li & Liang(2021), https://arxiv.org/abs/2101.00190

    Learned per-head key/value tokens are concatenated to k and v along the
    token axis; q is returned unchanged.

    Attributes:
        dim (int): embedding dimension.
        num_heads (int): number of attention heads.
        prefix_length (int): number of prefix tokens.
        prefix_type (str): variant of prefix tuning.
    """

    def __init__(
        self,
        dim,
        num_heads,
        prefix_length=None,
        prefix_type=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.prefix_length = prefix_length
        self.prefix_type = prefix_type
        self.prefix_key = nn.Parameter(torch.zeros(1, prefix_length, dim))
        self.prefix_value = nn.Parameter(torch.zeros(1, prefix_length, dim))
        nn.init.xavier_uniform_(self.prefix_key)
        nn.init.xavier_uniform_(self.prefix_value)

    def _per_head(self, token, batch):
        # (1, L, C) -> (B, num_heads, L, head_dim)
        head_dim = self.dim // self.num_heads
        expanded = token.expand(batch, -1, -1)
        return expanded.reshape(batch, self.prefix_length, self.num_heads,
                                head_dim).permute(0, 2, 1, 3)

    def forward(self, x, q, k, v):
        """Append prefix keys/values; q passes through untouched."""
        batch = x.shape[0]
        k = torch.cat((k, self._per_head(self.prefix_key, batch)), dim=2)
        v = torch.cat((v, self._per_head(self.prefix_value, batch)), dim=2)
        return q, k, v

View File

@@ -0,0 +1,132 @@
# The implementation is adopted from timm (version: 0.6.11),
# made publicly available under the Apache 2.0 License at
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/helpers.py
import math
from itertools import chain
from typing import Callable
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
def named_apply(fn: Callable,
                module: nn.Module,
                name='',
                depth_first=True,
                include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    ``name`` accumulates the dotted path of each submodule. With
    ``depth_first`` the root (when included) is visited after its children,
    otherwise before. Children are always visited (``include_root=True`` is
    forced in the recursion). Returns ``module`` for chaining.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f'{name}.{child_name}' if name else child_name
        named_apply(
            fn=fn,
            module=child,
            name=qualified,
            depth_first=depth_first,
            include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
def adapt_input_conv(in_chans, conv_weight):
    """Adapt a pretrained stem conv weight to a different input channel count.

    For ``in_chans == 1`` the (RGB) channels are summed; for other counts the
    RGB weights are tiled and rescaled so the expected activation magnitude
    is preserved. Raises NotImplementedError when the source weight does not
    have 3 input channels and ``in_chans`` is neither 1 nor 3.
    """
    orig_dtype = conv_weight.dtype
    # Some weights are torch.half; float is needed for summation on CPU.
    conv_weight = conv_weight.float()
    out_ch, in_ch, kh, kw = conv_weight.shape
    if in_chans == 1:
        if in_ch > 3:
            # Models with space2depth stems: fold each RGB triple.
            assert in_ch % 3 == 0
            conv_weight = conv_weight.reshape(out_ch, in_ch // 3, 3, kh,
                                              kw).sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if in_ch != 3:  # noqa
            raise NotImplementedError(
                'Weight format not supported by conversion.')
        # NOTE better than random init, though other combinations of the
        # original RGB weights could suit specific cases better.
        repeat = int(math.ceil(in_chans / 3))
        conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
        conv_weight *= (3 / float(in_chans))
    return conv_weight.to(orig_dtype)
def checkpoint_seq(functions,
                   x,
                   every=1,
                   flatten=False,
                   skip_last=False,
                   preserve_rng_state=True):
    r"""Run a sequence of modules/functions with segment-wise checkpointing.

    The sequence is split into segments of ``every`` functions and each
    segment is wrapped in :func:`torch.utils.checkpoint.checkpoint`: only the
    segment inputs are stored, and the intermediate activations inside a
    segment are recomputed during the backward pass.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed.
        :func:`torch.autograd.grad` is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True`
        if grads are needed for model inputs, otherwise the checkpointed
        part of the model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential`, or a list/iterable of
            modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`.
        every: checkpoint every-n functions (default: 1).
        flatten (bool): flatten nn.Sequential of nn.Sequentials.
        skip_last (bool): run the trailing segment without checkpointing.
        preserve_rng_state (bool, optional, default=True): stash and restore
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`x`.

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """

    def make_segment(begin, stop):
        # Composes functions[begin..stop] (inclusive) into one callable.
        def segment(inp):
            for idx in range(begin, stop + 1):
                inp = functions[idx](inp)
            return inp

        return segment

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    limit = len(functions) - 1 if skip_last else len(functions)
    stop = -1
    for begin in range(0, limit, every):
        stop = min(begin + every - 1, limit - 1)
        x = checkpoint(
            make_segment(begin, stop),
            x,
            preserve_rng_state=preserve_rng_state)
    if skip_last:
        # Final segment runs normally so its activations are kept.
        return make_segment(stop + 1, len(functions) - 1)(x)
    return x

View File

@@ -0,0 +1,755 @@
# The implementation is adopted from timm (version: 0.6.11),
# made publicly available under the Apache 2.0 License at
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/vision_transformer.py,
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/mlp.py,
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/mlp.py,
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/patch_embed.py,
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/drop.py,
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/helpers.py
import collections.abc
import logging
import math
from collections import OrderedDict
from functools import partial
from itertools import repeat
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import _assert
from .timm_helpers import adapt_input_conv, checkpoint_seq, named_apply
from .timm_weight_init import lecun_normal_, trunc_normal_
_logger = logging.getLogger(__name__)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(repeat(x, n))
return parse
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
class PatchEmbed(nn.Module):
    """2D image-to-patch embedding via a strided convolution.

    Splits a (B, C, H, W) image into non-overlapping patches and projects
    each to ``embed_dim``; with ``flatten`` the result is (B, N, embed_dim).
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        # Number of patches along (H, W); input size must divide evenly.
        self.grid_size = (img_size[0] // patch_size[0],
                          img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Embed patches; rejects inputs whose H/W differ from img_size."""
        B, C, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        )
        out = self.proj(x)
        if self.flatten:
            out = out.flatten(2).transpose(1, 2)  # BCHW -> BNC
        return self.norm(out)
class Mlp(nn.Module):
    """Feed-forward block used in ViT / MLP-Mixer and related networks.

    fc1 -> activation -> dropout -> fc2 -> dropout; ``bias`` and ``drop``
    may each be a scalar or a pair applied to the two linear layers.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 bias=True,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias_pair[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias_pair[1])
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(hidden))
def drop_path(x,
              drop_prob: float = 0.,
              training: bool = False,
              scale_by_keep: bool = True):
    """Stochastic depth: drop whole samples on the residual path.

    This is the same as the DropConnect impl created for EfficientNet etc.,
    however the original name is misleading as 'Drop Connect' is a different
    form of dropout in a separate paper (see discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
    Hence 'drop path' here, with 'survival rate' as the argument.

    A no-op when ``drop_prob`` is 0 or when not training.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        # Rescale so the expected activation magnitude is unchanged.
        mask.div_(keep_prob)
    return x * mask
class DropPath(nn.Module):
    """Module wrapper around :func:`drop_path` (stochastic depth).

    Active only when the module is in training mode.
    """

    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'drop_prob={round(self.drop_prob,3):0.3f}'
class Attention(nn.Module):
    """Standard multi-head self-attention with a fused qkv projection."""

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.scale = (dim // num_heads)**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over x of shape (B, N, C); output shape is unchanged."""
        batch, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads,
                                  channels // self.num_heads).permute(
                                      2, 0, 3, 1, 4)
        # unbind keeps torchscript happy (cannot unpack a tensor as a tuple)
        q, k, v = qkv.unbind(0)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
class LayerScale(nn.Module):
    """Learnable per-channel scaling by a ``gamma`` parameter."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            # Mutates x in place and returns it.
            return x.mul_(self.gamma)
        return x * self.gamma
class Block(nn.Module):
    """Pre-norm transformer encoder block: attention + MLP with residuals."""

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 init_values=None,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop)
        # LayerScale / stochastic depth collapse to identity when disabled.
        self.ls1 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        # MLP branch.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)
        self.ls2 = LayerScale(
            dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
        return x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
class ResPostBlock(nn.Module):
    """Residual-post-norm transformer block.

    Normalization is applied after attention/MLP (instead of before), and
    the LayerScale effect is emulated by initializing the norm weights to
    ``init_values``.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 drop=0.,
                 attn_drop=0.,
                 init_values=None,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.init_values = init_values
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop)
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop)
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.init_weights()

    def init_weights(self):
        # NOTE this init overrides the base model init with block-type
        # specific scaling of the norm weights.
        if self.init_values is not None:
            nn.init.constant_(self.norm1.weight, self.init_values)
            nn.init.constant_(self.norm2.weight, self.init_values)

    def forward(self, x):
        x = x + self.drop_path1(self.norm1(self.attn(x)))
        return x + self.drop_path2(self.norm2(self.mlp(x)))
class ParallelBlock(nn.Module):
    """Transformer block with ``num_parallel`` attention and MLP branches.

    Each branch is a pre-norm sub-block (norm -> attn/mlp -> layerscale ->
    droppath); the residual update adds the sum of all branch outputs.
    """

    def __init__(self,
                 dim,
                 num_heads,
                 num_parallel=2,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 init_values=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm):
        super().__init__()
        self.num_parallel = num_parallel
        self.attns = nn.ModuleList()
        self.ffns = nn.ModuleList()
        # OrderedDict keys ('norm', 'attn'/'mlp', 'ls', 'drop_path') fix the
        # submodule names inside each Sequential and hence the state_dict keys.
        for _ in range(num_parallel):
            self.attns.append(
                nn.Sequential(
                    OrderedDict([('norm', norm_layer(dim)),
                                 ('attn',
                                  Attention(
                                      dim,
                                      num_heads=num_heads,
                                      qkv_bias=qkv_bias,
                                      attn_drop=attn_drop,
                                      proj_drop=drop)),
                                 ('ls',
                                  LayerScale(dim, init_values=init_values)
                                  if init_values else nn.Identity()),
                                 ('drop_path', DropPath(drop_path)
                                  if drop_path > 0. else nn.Identity())])))
            self.ffns.append(
                nn.Sequential(
                    OrderedDict([('norm', norm_layer(dim)),
                                 ('mlp',
                                  Mlp(dim,
                                      hidden_features=int(dim * mlp_ratio),
                                      act_layer=act_layer,
                                      drop=drop)),
                                 ('ls',
                                  LayerScale(dim, init_values=init_values)
                                  if init_values else nn.Identity()),
                                 ('drop_path', DropPath(drop_path)
                                  if drop_path > 0. else nn.Identity())])))

    def _forward_jit(self, x):
        # stack+sum keeps the reduction expressible under torchscript/tracing.
        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
        return x

    @torch.jit.ignore
    def _forward(self, x):
        # Plain Python sum avoids the extra stack allocation in eager mode.
        x = x + sum(attn(x) for attn in self.attns)
        x = x + sum(ffn(x) for ffn in self.ffns)
        return x

    def forward(self, x):
        # Dispatch on execution mode: scripted/traced vs eager.
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            return self._forward_jit(x)
        else:
            return self._forward(x)
class VisionTransformer(nn.Module):
    """ Vision Transformer

    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
        - https://arxiv.org/abs/2010.11929
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=True,
            init_values=None,
            class_token=True,
            no_embed_class=False,
            pre_norm=False,
            fc_norm=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            weight_init='',
            embed_layer=PatchEmbed,
            norm_layer=None,
            act_layer=None,
            block_fn=Block,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            global_pool (str): type of global pooling for final sequence (default: 'token')
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            init_values: (float): layer-scale init values
            class_token (bool): use class token
            no_embed_class (bool): if True, pos_embed covers only the patch grid,
                not the class token
            pre_norm (bool): apply a norm layer right after patch embedding (e.g. CLIP)
            fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            weight_init (str): weight init scheme ('jax', 'jax_nlhb', 'moco', '' or 'skip')
            embed_layer (nn.Module): patch embedding layer
            norm_layer: (nn.Module): normalization layer
            act_layer: (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class (default: Block)
        """
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        # 'token' pooling reads x[:, 0] and therefore requires a class token.
        assert class_token or global_pool != 'token'
        use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.no_embed_class = no_embed_class
        self.grad_checkpointing = False
        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=not pre_norm, # disable bias if pre-norm is used (e.g. CLIP)
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(
            1, 1, embed_dim)) if class_token else None
        # Position embedding length depends on whether the class token shares it.
        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(
            torch.randn(1, embed_len, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ] # stochastic depth decay rule
        self.blocks = nn.Sequential(*[
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                init_values=init_values,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer) for i in range(depth)
        ])
        # Final norm sits either before pooling (here) or after pooling
        # (fc_norm) — exactly one of the two is active.
        self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()

        # Classifier Head
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)

    def init_weights(self, mode=''):
        """Initialize pos_embed/cls_token, then apply the chosen per-module scheme."""
        assert mode in ('jax', 'jax_nlhb', 'moco', '')
        # 'nlhb' = negative log of head bias; biases the head toward uniform logits.
        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
        trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode, head_bias), self)

    def _init_weights(self, m):
        # this fn left here for compat with downstream users
        init_weights_vit_timm(m)

    @torch.jit.ignore()
    def load_pretrained(self, checkpoint_path, prefix=''):
        """Load weights from an official JAX/Flax .npz checkpoint."""
        _load_weights(self, checkpoint_path, prefix)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameters excluded from weight decay by optimizer factories.
        return {'pos_embed', 'cls_token', 'dist_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Regex groups used for layer-wise LR decay / parameter grouping.
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999, ))])

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        """Replace the classification head (and optionally the pooling mode)."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(
            self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x):
        """Add position embedding and (optionally) prepend the class token."""
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + self.pos_embed
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x),
                              dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            if self.cls_token is not None:
                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x),
                              dim=1)
            x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Run patch embedding + transformer blocks; returns the token sequence."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.norm_pre(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Gradient checkpointing trades compute for memory; not scriptable.
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool the token sequence and apply the classifier (unless pre_logits)."""
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(
                dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """Initialize one ViT submodule the way the original timm code does.

    Linear layers get a truncated-normal weight (std=0.02) and a zero bias;
    any other module exposing an ``init_weights`` hook delegates to it.
    """
    if not isinstance(module, nn.Linear):
        if hasattr(module, 'init_weights'):
            module.init_weights()
        return
    trunc_normal_(module.weight, std=.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
def init_weights_vit_jax(module: nn.Module,
                         name: str = '',
                         head_bias: float = 0.):
    """Initialize one ViT submodule following the JAX (Flax) reference code.

    The classifier head is zero-initialized with a configurable bias; other
    linear layers use xavier-uniform weights (MLP biases get a tiny normal);
    convolutions use lecun-normal weights.
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            # Head starts at zero weights with a configurable constant bias.
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
            return
        nn.init.xavier_uniform_(module.weight)
        if module.bias is None:
            return
        if 'mlp' in name:
            nn.init.normal_(module.bias, std=1e-6)
        else:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()
def init_weights_vit_moco(module: nn.Module, name: str = ''):
    """Initialize one ViT submodule following moco-v3 (fixed PatchEmbed excluded)."""
    if isinstance(module, nn.Linear):
        if 'qkv' in name:
            # Q, K and V share one fused weight matrix; size the uniform bound
            # as if each of the three projections were a separate matrix.
            fan = module.weight.shape[0] // 3 + module.weight.shape[1]
            bound = math.sqrt(6. / float(fan))
            nn.init.uniform_(module.weight, -bound, bound)
        else:
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()
def get_init_weights_vit(mode='jax', head_bias: float = 0.):
    """Return the per-module init function matching *mode*.

    Modes containing 'jax' get the Flax scheme (with *head_bias* bound in),
    modes containing 'moco' get the moco-v3 scheme, anything else falls back
    to the original timm scheme.
    """
    if 'jax' in mode:
        return partial(init_weights_vit_jax, head_bias=head_bias)
    if 'moco' in mode:
        return init_weights_vit_moco
    return init_weights_vit_timm
@torch.no_grad()
def _load_weights(model: VisionTransformer,
                  checkpoint_path: str,
                  prefix: str = ''):
    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
    """
    import numpy as np

    def _n2p(w, t=True):
        # Convert a numpy array from the Flax layout to PyTorch layout:
        # squeeze 1x1x1 leading dims, then (when t=True) move the output
        # channel axis to the front as PyTorch expects.
        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
            w = w.flatten()
        if t:
            if w.ndim == 4:
                w = w.transpose([3, 2, 0, 1])
            elif w.ndim == 3:
                w = w.transpose([2, 0, 1])
            elif w.ndim == 2:
                w = w.transpose([1, 0])
        return torch.from_numpy(w)

    w = np.load(checkpoint_path)
    # Checkpoints saved with the optimizer state nest params under 'opt/target/'.
    if not prefix and 'opt/target/embedding/kernel' in w:
        prefix = 'opt/target/'

    if hasattr(model.patch_embed, 'backbone'):
        # hybrid: patch embedding is a conv backbone (ResNet-style stem/stages).
        backbone = model.patch_embed.backbone
        stem_only = not hasattr(backbone, 'stem')
        stem = backbone if stem_only else backbone.stem
        stem.conv.weight.copy_(
            adapt_input_conv(stem.conv.weight.shape[1],
                             _n2p(w[f'{prefix}conv_root/kernel'])))
        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
        if not stem_only:
            # Copy every conv/groupnorm pair of every residual unit, plus the
            # projection shortcut when the unit downsamples.
            for i, stage in enumerate(backbone.stages):
                for j, block in enumerate(stage.blocks):
                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
                    for r in range(3):
                        getattr(block, f'conv{r + 1}').weight.copy_(
                            _n2p(w[f'{bp}conv{r + 1}/kernel']))
                        getattr(block, f'norm{r + 1}').weight.copy_(
                            _n2p(w[f'{bp}gn{r + 1}/scale']))
                        getattr(block, f'norm{r + 1}').bias.copy_(
                            _n2p(w[f'{bp}gn{r + 1}/bias']))
                    if block.downsample is not None:
                        block.downsample.conv.weight.copy_(
                            _n2p(w[f'{bp}conv_proj/kernel']))
                        block.downsample.norm.weight.copy_(
                            _n2p(w[f'{bp}gn_proj/scale']))
                        block.downsample.norm.bias.copy_(
                            _n2p(w[f'{bp}gn_proj/bias']))
        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
    else:
        # Plain patchify conv; adapt to the model's input channel count.
        embed_conv_w = adapt_input_conv(model.patch_embed.proj.weight.shape[1],
                                        _n2p(w[f'{prefix}embedding/kernel']))
    model.patch_embed.proj.weight.copy_(embed_conv_w)
    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
    pos_embed_w = _n2p(
        w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
    if pos_embed_w.shape != model.pos_embed.shape:
        pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
            pos_embed_w, model.pos_embed, getattr(model, 'num_prefix_tokens',
                                                  1),
            model.patch_embed.grid_size)
    model.pos_embed.copy_(pos_embed_w)
    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
    # Only copy the head if its class count matches the checkpoint's.
    if isinstance(
            model.head, nn.Linear
    ) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
    # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
    # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
    #     model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
    #     model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
    for i, block in enumerate(model.blocks.children()):
        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
        # Flax stores Q/K/V separately; fuse them into the single qkv matrix.
        block.attn.qkv.weight.copy_(
            torch.cat([
                _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T
                for n in ('query', 'key', 'value')
            ]))
        block.attn.qkv.bias.copy_(
            torch.cat([
                _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1)
                for n in ('query', 'key', 'value')
            ]))
        block.attn.proj.weight.copy_(
            _n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
        for r in range(2):
            getattr(block.mlp, f'fc{r + 1}').weight.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
            getattr(block.mlp, f'fc{r + 1}').bias.copy_(
                _n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
    """Rescale a pretrained position-embedding grid to a new token count.

    Adapted from
    https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    """
    _logger.info('Resized position embedding: %s to %s', posemb.shape,
                 posemb_new.shape)
    target_tokens = posemb_new.shape[1]
    # Split off the class/prefix tokens; only the grid part is interpolated.
    if num_prefix_tokens:
        prefix = posemb[:, :num_prefix_tokens]
        grid = posemb[0, num_prefix_tokens:]
        target_tokens -= num_prefix_tokens
    else:
        prefix = posemb[:, :0]
        grid = posemb[0]
    gs_old = int(math.sqrt(len(grid)))
    if not len(gs_new): # backwards compatibility
        gs_new = [int(math.sqrt(target_tokens))] * 2
    assert len(gs_new) >= 2
    _logger.info('Position embedding grid-size from %s to %s',
                 [gs_old, gs_old], gs_new)
    # (N, C) -> (1, C, H, W) for bicubic interpolation, then back to (1, N', C).
    grid = grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(
        grid, size=gs_new, mode='bicubic', align_corners=False)
    grid = grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
    return torch.cat([prefix, grid], dim=1)

View File

@@ -0,0 +1,131 @@
# The implementation is adopted from timm (version: 0.6.11),
# made publicly available under the Apache 2.0 License at
# https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/layers/weight_init.py
import math
import warnings
import torch
from torch.nn.init import _calculate_fan_in_and_fan_out
def _trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1. + math.erf(x / math.sqrt(2.))) / 2.
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
stacklevel=2)
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std) # noqa
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""In-place truncated-normal initializer.

    Samples from :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted
    to :math:`[a, b]`; values outside the bounds are effectively redrawn.
    Works best when :math:`a \leq \text{mean} \leq b`.

    NOTE: this impl is similar to the PyTorch trunc_normal_; the bounds
    [a, b] are applied while sampling with mean/std already in effect, so
    a and b should be expressed in the same units as mean and std.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    with torch.no_grad():
        return _trunc_normal_(tensor, mean, std, a, b)
def trunc_normal_tf_(tensor, mean=0., std=1., a=-2., b=2.):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""In-place truncated-normal initializer, TensorFlow/JAX flavour.

    First draws from a standard normal truncated to :math:`[a, b]`
    (mean=0, std=1.0), then scales by *std* and shifts by *mean* — matching
    the TF/JAX behaviour where the bounds are applied before scaling.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    with torch.no_grad():
        _trunc_normal_(tensor, 0, 1.0, a, b)
        tensor.mul_(std).add_(mean)
    return tensor
def variance_scaling_(tensor, scale=1.0, mode='fan_in', distribution='normal'):
    """Generic variance-scaling initializer (in-place).

    Variance is *scale* divided by the chosen fan ('fan_in', 'fan_out' or
    'fan_avg'); samples come from a truncated-normal, normal or uniform
    distribution with that variance.

    Raises:
        ValueError: on an unknown *mode* or *distribution*.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fans = {
        'fan_in': fan_in,
        'fan_out': fan_out,
        'fan_avg': (fan_in + fan_out) / 2,
    }
    if mode not in fans:
        raise ValueError(f'invalid mode {mode}')
    variance = scale / fans[mode]
    if distribution == 'truncated_normal':
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_tf_(tensor, std=math.sqrt(variance) / .87962566103423978)
    elif distribution == 'normal':
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == 'uniform':
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f'invalid distribution {distribution}')
def lecun_normal_(tensor):
    """LeCun normal init: truncated normal with variance 1/fan_in (in-place)."""
    variance_scaling_(tensor, mode='fan_in', distribution='truncated_normal')

View File

@@ -0,0 +1,65 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import torch
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
@MODELS.register_module(
    Tasks.vision_efficient_tuning, module_name=Models.vision_efficient_tuning)
class VisionEfficientTuningModel(TorchModel):
    """ The implementation of vision efficient tuning.

    This model is constructed with the following parts:
    - 'backbone': pre-trained backbone model with parameters.
    - 'head': classification head with fine-tuning.
    """

    def __init__(self, model_dir: str, **kwargs):
        """ Initialize a vision efficient tuning model.

        Args:
          model_dir: model id or path, where model_dir/pytorch_model.pt contains:
            - 'backbone_cfg': config of backbone.
            - 'backbone_weight': parameters of backbone.
            - 'head_cfg': config of head.
            - 'head_weight': parameters of head.
            - 'CLASSES': list of label name.
        """
        from .backbone import VisionTransformerPETL
        from .head import ClassifierHead
        super().__init__(model_dir)

        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        # map_location='cpu' so checkpoints saved on GPU still load on
        # CPU-only hosts; the pipeline moves the model to its device later.
        model_dict = torch.load(model_path, map_location='cpu')

        backbone_cfg = model_dict['backbone_cfg']
        # 'type' is registry metadata, not a constructor argument.
        if 'type' in backbone_cfg:
            backbone_cfg.pop('type')
        self.backbone_model = VisionTransformerPETL(**backbone_cfg)
        self.backbone_model.load_state_dict(
            model_dict['backbone_weight'], strict=True)

        head_cfg = model_dict['head_cfg']
        if 'type' in head_cfg:
            head_cfg.pop('type')
        self.head_model = ClassifierHead(**head_cfg)
        self.head_model.load_state_dict(model_dict['head_weight'], strict=True)

        # Class names, indexed by label id, used by the pipeline postprocess.
        self.CLASSES = model_dict['CLASSES']

    def forward(self, inputs):
        """ Dynamic forward function of vision efficient tuning.

        Args:
          inputs: the input images (B, 3, H, W).

        Returns:
          The classification output of the head for the batch.
        """
        backbone_output = self.backbone_model(inputs)
        head_output = self.head_model(backbone_output)
        return head_output

View File

@@ -993,6 +993,13 @@ TASK_OUTPUTS = {
# "output_video": "path_to_rendered_video"
# }
Tasks.motion_generation: [OutputKeys.KEYPOINTS, OutputKeys.OUTPUT_VIDEO],
# vision efficient tuning result for single sample
# {
# "scores": [0.9, 0.1, 0.05, 0.05]
# "labels": ["dog", "horse", "cow", "cat"],
# }
Tasks.vision_efficient_tuning: [OutputKeys.SCORES, OutputKeys.LABELS],
}

View File

@@ -86,6 +86,8 @@ TASK_INPUTS = {
InputType.IMAGE,
Tasks.image_fewshot_detection:
InputType.IMAGE,
Tasks.vision_efficient_tuning:
InputType.IMAGE,
# image editing task result for a single image
Tasks.skin_retouching:

View File

@@ -73,6 +73,10 @@ if TYPE_CHECKING:
from .hand_static_pipeline import HandStaticPipeline
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
from .vision_efficient_tuning_adapter_pipeline import VisionEfficientTuningAdapterPipeline
from .vision_efficient_tuning_prompt_pipeline import VisionEfficientTuningPromptPipeline
from .vision_efficient_tuning_prefix_pipeline import VisionEfficientTuningPrefixPipeline
from .vision_efficient_tuning_lora_pipeline import VisionEfficientTuningLoRAPipeline
from .vision_middleware_pipeline import VisionMiddlewarePipeline
from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline
from .image_skychange_pipeline import ImageSkychangePipeline
@@ -187,6 +191,18 @@ else:
'language_guided_video_summarization_pipeline': [
'LanguageGuidedVideoSummarizationPipeline'
],
'vision_efficient_tuning_adapter_pipeline': [
'VisionEfficientTuningAdapterPipeline'
],
'vision_efficient_tuning_prompt_pipeline': [
'VisionEfficientTuningPromptPipeline'
],
'vision_efficient_tuning_prefix_pipeline': [
'VisionEfficientTuningPrefixPipeline'
],
'vision_efficient_tuning_lora_pipeline': [
'VisionEfficientTuningLoRAPipeline'
],
'vision_middleware_pipeline': ['VisionMiddlewarePipeline'],
'video_frame_interpolation_pipeline': [
'VideoFrameInterpolationPipeline'

View File

@@ -0,0 +1,74 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import Any, Dict
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
    Tasks.vision_efficient_tuning,
    module_name=Pipelines.vision_efficient_tuning)
class VisionEfficientTuningPipeline(Pipeline):
    """Image-classification pipeline for vision efficient-tuning models."""

    def __init__(self, model: str, **kwargs):
        """
        use `model` to create a vision efficient tuning pipeline for prediction
        Args:
            model: model id on modelscope hub.
            topk: number of top predictions to return (default: 5).
        Example:
            >>> from modelscope.pipelines import pipeline
            >>> petl_pipeline = pipeline('vision-efficient-tuning',
                'damo/cv_vitb16_classification_vision-efficient-tuning-adapter')
            >>> result = petl_pipeline(
                'data/test/images/vision_efficient_tuning_test_1.png')
            >>> print(f'Output: {result}.')
        """
        # Pop before super().__init__ so the base class never sees 'topk';
        # default 5 preserves the previous hard-coded behavior.
        self.topk = kwargs.pop('topk', 5)
        super().__init__(model=model, **kwargs)
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.model = self.model.to(self.device)
        self.model.eval()
        # ImageNet-style preprocessing: resize, tensorize, normalize.
        self.transform = transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def preprocess(self, input: Input) -> torch.Tensor:
        """Load the image and return a (1, 3, H, W) tensor on self.device."""
        img = LoadImage.convert_to_img(input)
        data = self.transform(img).unsqueeze(0).to(self.device)
        return data

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run the model without tracking gradients; returns the raw output."""
        with torch.no_grad():
            results = self.model(input)
        return results

    def postprocess(self, inputs: torch.Tensor) -> Dict[str, Any]:
        """Convert model output to top-k class probabilities and label names.

        Returns:
            Dict with OutputKeys.SCORES (descending probabilities) and
            OutputKeys.LABELS (matching class names from self.model.CLASSES).
        """
        probs = F.softmax(inputs, dim=1).cpu().numpy()[0]
        # A single argsort yields both ranking and scores, keeping the two
        # output lists aligned by construction (original sorted twice).
        top_idx = np.argsort(probs)[::-1][:self.topk]
        outputs = {
            OutputKeys.SCORES: [float(probs[i]) for i in top_idx],
            OutputKeys.LABELS: [self.model.CLASSES[i] for i in top_idx]
        }
        return outputs

View File

@@ -133,6 +133,9 @@ class CVTasks(object):
# motion generation
motion_generation = 'motion-generation'
# vision efficient tuning
vision_efficient_tuning = 'vision-efficient-tuning'
class NLPTasks(object):
# nlp tasks

View File

@@ -0,0 +1,37 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from modelscope.models import Model
from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
VisionEfficientTuningModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VisionEfficientTuningAdapterTest(unittest.TestCase,
                                       DemoCompatibilityCheck):
    """Tests for the adapter variant of the vision-efficient-tuning task."""

    def setUp(self) -> None:
        # Task name and default model id shared by the tests below.
        self.task = Tasks.vision_efficient_tuning
        self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-adapter'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline(self):
        # Smoke test: build the pipeline and classify a bundled test image.
        petl_pipeline = pipeline(self.task, self.model_id)
        result = petl_pipeline(
            'data/test/images/vision_efficient_tuning_test_1.png')
        print(f'Vision-efficient-tuning-adapter output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        # Model.from_pretrained must resolve to the registered model class.
        model = Model.from_pretrained(
            'damo/cv_vitb16_classification_vision-efficient-tuning-adapter')
        self.assertTrue(model.__class__ == VisionEfficientTuningModel)


if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,36 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from modelscope.models import Model
from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
VisionEfficientTuningModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VisionEfficientTuningLoRATest(unittest.TestCase, DemoCompatibilityCheck):
    """Tests for the LoRA variant of the vision-efficient-tuning task."""

    def setUp(self) -> None:
        # Task name and default model id shared by the tests below.
        self.task = Tasks.vision_efficient_tuning
        self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-lora'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline(self):
        # Smoke test: build the pipeline and classify a bundled test image.
        petl_pipeline = pipeline(self.task, self.model_id)
        result = petl_pipeline(
            'data/test/images/vision_efficient_tuning_test_1.png')
        print(f'Vision-efficient-tuning-lora output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        # Model.from_pretrained must resolve to the registered model class.
        model = Model.from_pretrained(
            'damo/cv_vitb16_classification_vision-efficient-tuning-lora')
        self.assertTrue(model.__class__ == VisionEfficientTuningModel)


if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,37 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from modelscope.models import Model
from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
VisionEfficientTuningModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VisionEfficientTuningPrefixTest(unittest.TestCase,
                                      DemoCompatibilityCheck):
    """Tests for the prefix variant of the vision-efficient-tuning task."""

    def setUp(self) -> None:
        # Task name and default model id shared by the tests below.
        self.task = Tasks.vision_efficient_tuning
        self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prefix'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline(self):
        # Smoke test: build the pipeline and classify a bundled test image.
        petl_pipeline = pipeline(self.task, self.model_id)
        result = petl_pipeline(
            'data/test/images/vision_efficient_tuning_test_1.png')
        print(f'Vision-efficient-tuning-prefix output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        # Model.from_pretrained must resolve to the registered model class.
        model = Model.from_pretrained(
            'damo/cv_vitb16_classification_vision-efficient-tuning-prefix')
        self.assertTrue(model.__class__ == VisionEfficientTuningModel)


if __name__ == '__main__':
    unittest.main()

View File

@@ -0,0 +1,37 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from modelscope.models import Model
from modelscope.models.cv.vision_efficient_tuning.vision_efficient_tuning import \
VisionEfficientTuningModel
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VisionEfficientTuningPromptTest(unittest.TestCase,
                                      DemoCompatibilityCheck):
    """Tests for the prompt variant of the vision-efficient-tuning task."""

    def setUp(self) -> None:
        # Task name and default model id shared by the tests below.
        self.task = Tasks.vision_efficient_tuning
        self.model_id = 'damo/cv_vitb16_classification_vision-efficient-tuning-prompt'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_pipeline(self):
        # Smoke test: build the pipeline and classify a bundled test image.
        petl_pipeline = pipeline(self.task, self.model_id)
        result = petl_pipeline(
            'data/test/images/vision_efficient_tuning_test_1.png')
        print(f'Vision-efficient-tuning-prompt output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        # Model.from_pretrained must resolve to the registered model class.
        model = Model.from_pretrained(
            'damo/cv_vitb16_classification_vision-efficient-tuning-prompt')
        self.assertTrue(model.__class__ == VisionEfficientTuningModel)


if __name__ == '__main__':
    unittest.main()