[to #42322933] Multimodal pre-trained model OFA: add support for the 6B model

Multimodal pre-trained model OFA: add support for the 6B model
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10574571
Author: xiaodongdeng.dxd, 2022-10-31 20:42:56 +08:00
Committed by: yingda.chen
Parent: ce08cfbea8
Commit: 64868bf2ad
5 changed files with 416 additions and 143 deletions


@@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig):
entangle_position_embedding=False,
interpolate_position=False,
orig_patch_image_size=224,
share_attn_bias=False,
use_image_feature=True,
disable_entangle=False,
use_ofasys=False,
vit_type='vit_base',
vit_drop_path_rate=0.0,
**kwargs):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
@@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig):
self.interpolate_position = interpolate_position
self.orig_patch_image_size = orig_patch_image_size
self.share_attn_bias = share_attn_bias
self.use_image_feature = use_image_feature
self.disable_entangle = disable_entangle
self.use_ofasys = use_ofasys
self.vit_type = vit_type
self.vit_drop_path_rate = vit_drop_path_rate
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,

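The new configuration fields above gate the OFA-Sys code path that the 6B checkpoint relies on. Below is a minimal sketch of how they might be set, using the import path implied by this commit; the flag values are chosen purely for illustration and are not a released 6B configuration.

from modelscope.models.multi_modal.ofa.configuration_ofa import OFAConfig

# Hypothetical values for illustration only; not the shipped 6B config.
config = OFAConfig(
    use_ofasys=True,           # enable the OFA-Sys variant of the model
    use_image_feature=True,    # keep the image branch of the encoder
    vit_type='vit_huge',       # one of: vit_base, vit_large, vit_large_336, vit_huge
    vit_drop_path_rate=0.1,    # stochastic depth inside the ViT image backbone
    share_attn_bias=False,     # or True: share one relative-position table across layers
    disable_entangle=False,
)
print(config.use_ofasys, config.vit_type)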
modelscope/models/multi_modal/ofa/modeling_ofa.py Executable file → Normal file

@@ -35,6 +35,8 @@ from transformers.utils import logging
from .configuration_ofa import OFAConfig
from .generate import utils
from .resnet import ResNet
from .utils.utils import DropPath
from .vit import vit_base, vit_huge, vit_large, vit_large_336
logger = logging.get_logger(__name__)
@@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList):
yield m
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args:
x (`nn.Modules`): input nn layers.
drop_prob (`float`): drop path ratio.
training (`bool`): whether is training or inference.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (1, x.shape[1], 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args:
drop_prob: drop path ratio.
"""
def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
class OFAAttention(nn.Module):
r"""
Multi-headed attention, with additional implementation for NormFormer.
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel):
self.padding_idx)
if config.add_type_embedding:
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
if config.use_image_feature:
self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
else:
self.type_embedding = Embedding(1, embed_dim, padding_idx=None)
else:
self.type_embedding = None
if config.resnet_type == 'resnet18':
self.embed_images = ResNet(
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet34':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet50':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet101':
self.embed_images = ResNet(
[3, 4, 23], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet152':
self.embed_images = ResNet(
[3, 8, 36], drop_path_rate=config.resnet_drop_path_rate)
else:
raise NotImplementedError
if config.use_image_feature:
if config.use_ofasys:
vit_backbone = {
'vit_base': vit_base,
'vit_large': vit_large,
'vit_large_336': vit_large_336,
'vit_huge': vit_huge,
}[config.vit_type]
self.embed_images = vit_backbone(config.vit_drop_path_rate)
self.image_proj = Linear(1024, embed_dim)
self.image_proj = Linear(self.embed_images.width, embed_dim)
if config.resnet_model_path:
else:
if config.resnet_type == 'resnet18':
self.embed_images = ResNet(
[2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet34':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet50':
self.embed_images = ResNet(
[3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet101':
self.embed_images = ResNet(
[3, 4, 23],
drop_path_rate=config.resnet_drop_path_rate)
elif config.resnet_type == 'resnet152':
self.embed_images = ResNet(
[3, 8, 36],
drop_path_rate=config.resnet_drop_path_rate)
else:
raise NotImplementedError
self.image_proj = Linear(1024, embed_dim)
if not config.use_ofasys and config.resnet_model_path:
print('load resnet {}'.format(config.resnet_model_path))
resnet_state_dict = torch.load(config.resnet_model_path)
self.embed_images.load_state_dict(resnet_state_dict)
@@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel):
self.embed_positions = Embedding(self.max_source_positions + 2,
embed_dim)
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
embed_dim)
self.pos_ln = LayerNorm(embed_dim)
self.image_pos_ln = LayerNorm(embed_dim)
if config.use_image_feature:
self.embed_image_positions = Embedding(
config.image_bucket_size**2 + 1, embed_dim)
if not config.use_ofasys:
self.pos_ln = LayerNorm(embed_dim)
if config.use_image_feature:
self.image_pos_ln = LayerNorm(embed_dim)
self.pos_scaling = float(embed_dim / self.num_attention_heads
* config.attn_scale_factor)**-0.5
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
if not (config.use_ofasys and config.entangle_position_embedding):
self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
self.pos_k_linear = nn.Linear(embed_dim, embed_dim)
if self.encoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
@@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel):
self.token_bucket_size = config.token_bucket_size
token_num_rel_dis = 2 * config.token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
self.share_attn_bias = config.share_attn_bias
num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers
self.token_rel_pos_table_list = nn.ModuleList([
Embedding(
token_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.encoder_layers)
for _ in range(num_rel_pos_tables)
])
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.encoder_layers)
])
if config.use_image_feature:
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(
config.image_bucket_size, image_num_rel_dis)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis,
self.num_attention_heads,
zero_init=True) for _ in range(num_rel_pos_tables)
])
self.register_buffer('image_rp_bucket', image_rp_bucket)
if config.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)
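With share_attn_bias=True the encoder (and, further below, the decoder) builds a single relative-position embedding table instead of one per layer, and every layer then looks up table 0. A standalone sketch of just that dispatch, with made-up sizes and names, independent of the OFA classes:

import torch
import torch.nn as nn

num_layers, num_heads, num_rel_dis = 6, 8, 31
share_attn_bias = True

num_tables = 1 if share_attn_bias else num_layers
rel_pos_tables = nn.ModuleList(
    [nn.Embedding(num_rel_dis, num_heads) for _ in range(num_tables)])

rp_bucket = torch.randint(0, num_rel_dis, (16, 16))  # fake bucketed relative distances

for idx in range(num_layers):
    real_idx = 0 if share_attn_bias else idx          # same table for every layer when sharing
    bias = rel_pos_tables[real_idx](rp_bucket)        # (16, 16, num_heads)
    assert bias.shape == (16, 16, num_heads)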
@@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel):
self.layernorm_embedding = None
self.register_buffer('token_rp_bucket', token_rp_bucket)
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.entangle_position_embedding = config.entangle_position_embedding
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
self.use_ofasys = config.use_ofasys
def get_input_embeddings(self):
r"""
@@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel):
if has_pads:
x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))
pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
if self.use_ofasys:
if patch_images is not None:
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
else:
pos_embed = self.pos_ln(pos_embed)
if patch_images is not None:
image_pos_embed = self.image_pos_ln(image_pos_embed)
pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
if patch_images_2 is not None:
image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
pos_q = self.pos_q_linear(pos_embed).view(
x.size(0), x.size(1), self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
x.size(0), x.size(1), self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
def build_abs_pos_bias(pos_embed):
batch_size, seq_length = pos_embed.size(0), pos_embed.size(1)
if not (self.use_ofasys and self.entangle_position_embedding):
pos_q = self.pos_q_linear(pos_embed).view(
batch_size, seq_length, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.pos_k_linear(pos_embed).view(
batch_size, seq_length, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
seq_length,
seq_length,
dtype=pos_embed.dtype,
device=pos_embed.device)
return abs_pos_bias
abs_pos_bias = build_abs_pos_bias(pos_embed)
# expand attention_mask
if has_pads:
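build_abs_pos_bias keeps the original behaviour, a scaled Q/K product over the positional embeddings, unless both use_ofasys and entangle_position_embedding are set, in which case the bias is all zeros and position information travels only through the embeddings added to x. A shape-only sketch with made-up dimensions (attn_scale_factor omitted for brevity):

import torch
import torch.nn as nn

batch, seq_len, heads, embed_dim = 2, 10, 4, 32
pos_embed = torch.randn(batch, seq_len, embed_dim)

pos_q_linear = nn.Linear(embed_dim, embed_dim)
pos_k_linear = nn.Linear(embed_dim, embed_dim)
pos_scaling = float(embed_dim / heads) ** -0.5

pos_q = pos_q_linear(pos_embed).view(batch, seq_len, heads, -1).transpose(1, 2) * pos_scaling
pos_k = pos_k_linear(pos_embed).view(batch, seq_len, heads, -1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
assert abs_pos_bias.shape == (batch, heads, seq_len, seq_len)

# With use_ofasys + entangle_position_embedding the bias is simply zeros of
# the same shape, leaving the attention scores unmodified.
zero_bias = torch.zeros(batch, heads, seq_len, seq_len)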
@@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel):
if output_hidden_states:
encoder_states += (x, )
self_attn_bias = abs_pos_bias.clone()
real_idx = 0 if self.share_attn_bias else idx
self_attn_bias[:, :, -input_ids.size(1):,
-input_ids.size(1):] += self.get_rel_pos_bias(
input_ids, idx)
input_ids, real_idx)
if patch_images_2 is not None:
self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
self.get_image_rel_pos_bias(image_position_ids_2, idx)
self.get_image_rel_pos_bias(image_position_ids_2, real_idx)
self_attn_bias[:, :,
image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa
image_num_patches_2:image_num_patches_2 + image_num_patches] += \
self.get_image_rel_pos_bias(image_position_ids, idx) # noqa
self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa
elif patch_images is not None:
self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \
self.get_image_rel_pos_bias(image_position_ids, idx)
self.get_image_rel_pos_bias(image_position_ids, real_idx)
self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1))
hidden_outputs = layer(
@@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel):
self._future_mask = torch.empty(0)
self.share_input_output_embed = config.share_decoder_input_output_embed
self.num_attention_heads = config.decoder_attention_heads
self.use_ofasys = config.use_ofasys
self.disable_entangle = config.disable_entangle
if embed_tokens is not None:
self.embed_tokens = embed_tokens
@@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel):
else:
self.layernorm_embedding = None
if config.use_ofasys:
if config.add_type_embedding:
self.type_embedding = Embedding(
1, self.embed_dim, padding_idx=None)
else:
self.type_embedding = None
self.window_size = config.code_image_size // 8
self.embed_positions = Embedding(self.max_target_positions + 2,
self.embed_dim)
self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
self.embed_dim)
self.pos_ln = LayerNorm(self.embed_dim)
self.image_pos_ln = LayerNorm(self.embed_dim)
if not config.use_ofasys:
self.embed_image_positions = Embedding(
config.image_bucket_size**2 + 1, self.embed_dim)
if not config.use_ofasys:
self.pos_ln = LayerNorm(self.embed_dim)
self.image_pos_ln = LayerNorm(self.embed_dim)
self.pos_scaling = float(self.embed_dim / self.num_attention_heads
* config.attn_scale_factor)**-0.5
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
if not (config.use_ofasys and config.entangle_position_embedding):
self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
@@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel):
self.token_bucket_size = config.token_bucket_size
token_num_rel_dis = 2 * config.token_bucket_size - 1
token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
self.share_attn_bias = config.share_attn_bias
num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers
self.token_rel_pos_table_list = nn.ModuleList([
Embedding(
token_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.decoder_layers)
for _ in range(num_rel_pos_tables)
])
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size
- 1) * (2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa
image_position_idx = torch.cat(
[torch.tensor([0]), image_position_idx.view(-1)])
image_position_idx = torch.cat(
[image_position_idx,
torch.tensor([1024] * 768)])
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis, self.num_attention_heads, zero_init=True)
for _ in range(config.decoder_layers)
])
if config.use_image_feature:
if not config.use_ofasys:
self.image_bucket_size = config.image_bucket_size
image_num_rel_dis = (2 * config.image_bucket_size - 1) * (
2 * config.image_bucket_size - 1) + 3
image_rp_bucket = make_image_bucket_position(
config.image_bucket_size, image_num_rel_dis)
image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa
image_position_idx = torch.cat(
[torch.tensor([0]),
image_position_idx.view(-1)])
image_position_idx = torch.cat(
[image_position_idx,
torch.tensor([1024] * 768)])
self.register_buffer('image_position_idx', image_position_idx)
self.image_rel_pos_table_list = nn.ModuleList([
Embedding(
image_num_rel_dis,
self.num_attention_heads,
zero_init=True) for _ in range(num_rel_pos_tables)
])
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.register_buffer('token_rp_bucket', token_rp_bucket)
self.register_buffer('image_rp_bucket', image_rp_bucket)
self.register_buffer('image_position_idx', image_position_idx)
self.entangle_position_embedding = config.entangle_position_embedding
self.gradient_checkpointing = False
@@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel):
batch_size = tgt_pos_embed.size(0)
tgt_len = tgt_pos_embed.size(1)
tgt_pos_embed = self.image_pos_ln(
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
if not self.use_ofasys:
tgt_pos_embed = self.image_pos_ln(
tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
if src_pos_embed is not None:
src_len = src_pos_embed.size(1)
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
if not (self.entangle_position_embedding and self.use_ofasys):
pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.cross_pos_k_linear(src_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
tgt_len,
src_len,
dtype=tgt_pos_embed.dtype,
device=tgt_pos_embed.device)
else:
src_len = tgt_pos_embed.size(1)
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads, -1).transpose(
1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, src_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
# batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1)
if not (self.entangle_position_embedding and self.use_ofasys):
pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2) * self.pos_scaling
pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
batch_size, tgt_len, self.num_attention_heads,
-1).transpose(1, 2)
abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
abs_pos_bias = torch.zeros(
batch_size,
self.num_attention_heads,
tgt_len,
tgt_len,
dtype=tgt_pos_embed.dtype,
device=tgt_pos_embed.device)
return abs_pos_bias
@@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel):
past_key_values) > 0 else None
self_attn_bias = self_abs_pos_bias.clone()
real_idx = 0 if self.share_attn_bias else idx
if code_masks is None or not code_masks.any():
self_attn_bias += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
elif code_masks is not None and code_masks.all():
self_attn_bias += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
else:
self_attn_bias[~code_masks] += self.get_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
all_prev_output_tokens, idx).unsqueeze(0)
all_prev_output_tokens, real_idx).unsqueeze(0)
self_attn_bias = self_attn_bias.reshape(
-1,
*self_attn_bias.size()[-2:])
@@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel):
self.encoder = OFAEncoder(config, shared)
self.decoder = OFADecoder(config, shared)
self.use_ofasys = config.use_ofasys
# Initialize weights and apply final processing
self.post_init()

modelscope/models/multi_modal/ofa/utils/utils.py

@@ -2,6 +2,7 @@
from typing import Optional
import torch
import torch.nn as nn
def expand_mask(mask: torch.Tensor,
@@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor,
src_len).to(dtype)
return expanded_mask.masked_fill(expanded_mask.bool(),
torch.finfo(dtype).min)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args:
x (`torch.Tensor`): input tensor.
drop_prob (`float`): drop path ratio.
training (`bool`): whether is training or inference.
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (1, x.shape[1], 1)
random_tensor = keep_prob + torch.rand(
shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
r"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Args:
drop_prob: drop path ratio.
"""
def __init__(self, drop_prob=None):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return 'p={}'.format(self.drop_prob)
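drop_path/DropPath now live in this shared utils module so that both modeling_ofa.py and the new ViT backbone can import them. A small usage sketch follows; the absolute module path is inferred from the relative import "from .utils.utils import DropPath" above, and the tensor sizes are arbitrary.

import torch
from modelscope.models.multi_modal.ofa.utils.utils import DropPath  # path inferred from the relative import

x = torch.ones(5, 8, 16)            # e.g. (seq_len, batch, hidden)
module = DropPath(drop_prob=0.5)

module.eval()                       # identity at inference time
assert torch.equal(module(x), x)

module.train()                      # whole samples along dim 1 are dropped
y = module(x)
# kept samples are rescaled to 1 / keep_prob = 2.0, dropped ones become 0
assert set(y.unique().tolist()) <= {0.0, 2.0}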

modelscope/models/multi_modal/ofa/vit.py

@@ -0,0 +1,155 @@
from collections import OrderedDict
import torch
import torch.nn.functional as F
from fairseq.modules import LayerNorm
from torch import nn
from .utils.utils import DropPath
__all__ = [
'vit_base',
'vit_large',
'vit_large_336',
'vit_huge',
]
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None,
drop_path_rate=0.0):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([
('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model)),
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
self.drop_path = DropPath(drop_path_rate)
def attention(self, x: torch.Tensor):
self.attn_mask = (
self.attn_mask.to(dtype=x.dtype, device=x.device)
if self.attn_mask is not None else None)
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.drop_path(self.attention(self.ln_1(x)))
x = x + self.drop_path(self.mlp(self.ln_2(x)))
return x
class Transformer(nn.Module):
def __init__(
self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None,
drop_path_rate: float = 0.0,
):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate)
for _ in range(layers)
])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(
self,
input_resolution: int,
patch_size: int,
width: int,
layers: int,
heads: int,
drop_path_rate: float = 0.0,
):
super().__init__()
self.input_resolution = input_resolution
self.patch_size = patch_size
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False,
)
scale = width**-0.5
self.width = width
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(
width, layers, heads, drop_path_rate=drop_path_rate)
def forward(self, x: torch.Tensor):
resolution = x.shape[-2]
height, width = x.shape[-2] // self.patch_size, x.shape[
-1] // self.patch_size
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
if resolution != self.input_resolution:
old_pe = self.positional_embedding[1:]
patch_num = self.input_resolution // self.patch_size
old_pe = old_pe.reshape(1, patch_num, patch_num,
-1).permute(0, 3, 1, 2)
new_pe = F.interpolate(
old_pe, size=(height, width), mode='bilinear')
new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1)
x = x + new_pe.to(x.dtype)
else:
x = x + self.positional_embedding[1:].to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
bz, seq, hidden = x.shape
x = x.transpose(1, 2).reshape(bz, hidden, height, width)
return x
def vit_base(drop_path_rate: float = 0.0):
return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate)
def vit_large(drop_path_rate: float = 0.0):
return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate)
def vit_large_336(drop_path_rate: float = 0.0):
return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate)
def vit_huge(drop_path_rate: float = 0.0):
return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate)
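The factories above return a trunk that keeps spatial layout: the forward pass folds the token sequence back into a (batch, width, grid, grid) feature map, which OFAEncoder then flattens and projects via Linear(self.embed_images.width, embed_dim). A rough shape check, with the module path inferred from the "from .vit import ..." line in modeling_ofa.py:

import torch
from modelscope.models.multi_modal.ofa.vit import vit_large  # path inferred from modeling_ofa.py's import

images = torch.randn(2, 3, 224, 224)
model = vit_large(drop_path_rate=0.0).eval()

with torch.no_grad():
    features = model(images)

# input_resolution=224, patch_size=14 -> a 16x16 grid of width-1024 features
assert features.shape == (2, 1024, 16, 16)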


@@ -53,8 +53,11 @@ class OfaForAllTasks(TorchModel):
raise NotImplementedError
# there is some diff between here and our ofa code,
# there will be no need to use param: use_bpe
self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
if not model.use_ofasys:
self.tokenizer.add_tokens(
['<code_{}>'.format(i) for i in range(8192)])
self.tokenizer.add_tokens(
['<bin_{}>'.format(i) for i in range(1000)])
self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
self.batch_size = self.cfg.model.get('batch_size', 1)
self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
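For non-OFA-Sys checkpoints the pipeline still registers the 8192 image-code tokens and 1000 coordinate-bin tokens; with use_ofasys set, those add_tokens calls are skipped. A generic illustration of what that registration does to the vocabulary size; the tokenizer name here is a placeholder, not the one OfaForAllTasks actually loads.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')  # placeholder tokenizer
before = len(tokenizer)

tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])

# vocabulary grows by the 8192 code tokens plus the 1000 bin tokens
assert len(tokenizer) == before + 8192 + 1000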