Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-25 20:49:37 +01:00

[to #42322933] Add support for the 6B model to the OFA multi-modal pre-trained model.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10574571

committed by yingda.chen
parent ce08cfbea8
commit 64868bf2ad
modelscope/models/multi_modal/ofa/configuration_ofa.py

@@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig):
                 entangle_position_embedding=False,
                 interpolate_position=False,
                 orig_patch_image_size=224,
                 share_attn_bias=False,
                 use_image_feature=True,
                 disable_entangle=False,
                 use_ofasys=False,
                 vit_type='vit_base',
                 vit_drop_path_rate=0.0,
                 **kwargs):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings

@@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig):
        self.interpolate_position = interpolate_position
        self.orig_patch_image_size = orig_patch_image_size

        self.share_attn_bias = share_attn_bias
        self.use_image_feature = use_image_feature
        self.disable_entangle = disable_entangle
        self.use_ofasys = use_ofasys
        self.vit_type = vit_type
        self.vit_drop_path_rate = vit_drop_path_rate

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
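As a quick orientation, here is a minimal, hypothetical sketch of how the new constructor arguments above could be passed when building a config for the 6B / OFASys-style setup. The flag values, and the assumption that every other argument keeps its default, are illustrative rather than settings taken from a released checkpoint.

# Hypothetical usage sketch of the OFAConfig options added in this commit.
from modelscope.models.multi_modal.ofa.configuration_ofa import OFAConfig

config = OFAConfig(
    share_attn_bias=True,       # share one relative-position bias table across layers
    use_image_feature=True,     # keep the image branch
    use_ofasys=True,            # switch on the OFASys-compatible code paths
    vit_type='vit_huge',        # one of: vit_base, vit_large, vit_large_336, vit_huge
    vit_drop_path_rate=0.0,
)
print(config.use_ofasys, config.vit_type)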
344 modelscope/models/multi_modal/ofa/modeling_ofa.py (Executable file → Normal file)
@@ -35,6 +35,8 @@ from transformers.utils import logging

from .configuration_ofa import OFAConfig
from .generate import utils
from .resnet import ResNet
from .utils.utils import DropPath
from .vit import vit_base, vit_huge, vit_large, vit_large_336

logger = logging.get_logger(__name__)

@@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList):
            yield m


(The local drop_path / DropPath definitions below are removed from modeling_ofa.py by this commit; the same code is added to .utils.utils, see the @@ -17,3 +18,42 @@ hunk further down.)

def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    r"""
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        x (`nn.Modules`): input nn layers.
        drop_prob (`float`): drop path ratio.
        training (`bool`): whether is training or inference.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (1, x.shape[1], 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    r"""
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: drop path ratio.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)


class OFAAttention(nn.Module):
    r"""
    Multi-headed attention, with additional implementation for NormFormer.
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel):
                                         self.padding_idx)

        if config.add_type_embedding:
            self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
            if config.use_image_feature:
                self.type_embedding = Embedding(2, embed_dim, padding_idx=None)
            else:
                self.type_embedding = Embedding(1, embed_dim, padding_idx=None)
        else:
            self.type_embedding = None

        if config.resnet_type == 'resnet18':
            self.embed_images = ResNet(
                [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
        elif config.resnet_type == 'resnet34':
            self.embed_images = ResNet(
                [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
        elif config.resnet_type == 'resnet50':
            self.embed_images = ResNet(
                [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
        elif config.resnet_type == 'resnet101':
            self.embed_images = ResNet(
                [3, 4, 23], drop_path_rate=config.resnet_drop_path_rate)
        elif config.resnet_type == 'resnet152':
            self.embed_images = ResNet(
                [3, 8, 36], drop_path_rate=config.resnet_drop_path_rate)
        else:
            raise NotImplementedError
        if config.use_image_feature:
            if config.use_ofasys:
                vit_backbone = {
                    'vit_base': vit_base,
                    'vit_large': vit_large,
                    'vit_large_336': vit_large_336,
                    'vit_huge': vit_huge,
                }[config.vit_type]
                self.embed_images = vit_backbone(config.vit_drop_path_rate)

        self.image_proj = Linear(1024, embed_dim)
                self.image_proj = Linear(self.embed_images.width, embed_dim)

        if config.resnet_model_path:
            else:
                if config.resnet_type == 'resnet18':
                    self.embed_images = ResNet(
                        [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate)
                elif config.resnet_type == 'resnet34':
                    self.embed_images = ResNet(
                        [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
                elif config.resnet_type == 'resnet50':
                    self.embed_images = ResNet(
                        [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate)
                elif config.resnet_type == 'resnet101':
                    self.embed_images = ResNet(
                        [3, 4, 23],
                        drop_path_rate=config.resnet_drop_path_rate)
                elif config.resnet_type == 'resnet152':
                    self.embed_images = ResNet(
                        [3, 8, 36],
                        drop_path_rate=config.resnet_drop_path_rate)
                else:
                    raise NotImplementedError

                self.image_proj = Linear(1024, embed_dim)

        if not config.use_ofasys and config.resnet_model_path:
            print('load resnet {}'.format(config.resnet_model_path))
            resnet_state_dict = torch.load(config.resnet_model_path)
            self.embed_images.load_state_dict(resnet_state_dict)
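The new encoder branch above replaces the hard-coded Linear(1024, embed_dim) image projection with one sized from the chosen ViT backbone. A standalone sketch of that dispatch pattern follows; build_image_embedder is a hypothetical helper written for illustration, the backbone widths come from the vit_base/vit_large/vit_huge constructors added in vit.py below, and nn.Linear stands in for the repo's Linear helper.

import torch.nn as nn

from modelscope.models.multi_modal.ofa.vit import (vit_base, vit_huge,
                                                   vit_large, vit_large_336)


def build_image_embedder(vit_type: str, vit_drop_path_rate: float,
                         embed_dim: int):
    # Same dispatch as the new OFAEncoder code path: pick the backbone by name,
    # then size the projection from the backbone width (768/1024/1024/1280)
    # instead of assuming a fixed 1024-dim ResNet feature.
    vit_backbone = {
        'vit_base': vit_base,
        'vit_large': vit_large,
        'vit_large_336': vit_large_336,
        'vit_huge': vit_huge,
    }[vit_type]
    embed_images = vit_backbone(vit_drop_path_rate)
    image_proj = nn.Linear(embed_images.width, embed_dim)
    return embed_images, image_proj


embed_images, image_proj = build_image_embedder('vit_huge', 0.0, 1024)
print(embed_images.width)  # 1280 for vit_huge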
@@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel):

        self.embed_positions = Embedding(self.max_source_positions + 2,
                                         embed_dim)
        self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
                                               embed_dim)
        self.pos_ln = LayerNorm(embed_dim)
        self.image_pos_ln = LayerNorm(embed_dim)

        if config.use_image_feature:
            self.embed_image_positions = Embedding(
                config.image_bucket_size**2 + 1, embed_dim)
        if not config.use_ofasys:
            self.pos_ln = LayerNorm(embed_dim)

        if config.use_image_feature:
            self.image_pos_ln = LayerNorm(embed_dim)
        self.pos_scaling = float(embed_dim / self.num_attention_heads
                                 * config.attn_scale_factor)**-0.5
        self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
        self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

        if not (config.use_ofasys and config.entangle_position_embedding):
            self.pos_q_linear = nn.Linear(embed_dim, embed_dim)
            self.pos_k_linear = nn.Linear(embed_dim, embed_dim)

        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
@@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel):
        self.token_bucket_size = config.token_bucket_size
        token_num_rel_dis = 2 * config.token_bucket_size - 1
        token_rp_bucket = make_token_bucket_position(config.token_bucket_size)
        self.share_attn_bias = config.share_attn_bias
        num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers
        self.token_rel_pos_table_list = nn.ModuleList([
            Embedding(
                token_num_rel_dis, self.num_attention_heads, zero_init=True)
            for _ in range(config.encoder_layers)
            for _ in range(num_rel_pos_tables)
        ])

        self.image_bucket_size = config.image_bucket_size
        image_num_rel_dis = (2 * config.image_bucket_size
                             - 1) * (2 * config.image_bucket_size - 1) + 3
        image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
                                                     image_num_rel_dis)
        self.image_rel_pos_table_list = nn.ModuleList([
            Embedding(
                image_num_rel_dis, self.num_attention_heads, zero_init=True)
            for _ in range(config.encoder_layers)
        ])
        if config.use_image_feature:
            self.image_bucket_size = config.image_bucket_size
            image_num_rel_dis = (2 * config.image_bucket_size
                                 - 1) * (2 * config.image_bucket_size - 1) + 3
            image_rp_bucket = make_image_bucket_position(
                config.image_bucket_size, image_num_rel_dis)
            self.image_rel_pos_table_list = nn.ModuleList([
                Embedding(
                    image_num_rel_dis,
                    self.num_attention_heads,
                    zero_init=True) for _ in range(num_rel_pos_tables)
            ])

            self.register_buffer('image_rp_bucket', image_rp_bucket)

        if config.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)
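The share_attn_bias flag introduced here swaps the per-layer relative-position tables for a single table that every layer reuses, which keeps the bias parameters small for the larger 6B configuration. A self-contained sketch of the pattern; num_rel_pos_tables and real_idx mirror the names in the diff, while the layer count, bucket size, and head count are illustrative assumptions.

import torch.nn as nn

share_attn_bias = True
encoder_layers = 24          # assumed layer count, not the real 6B value
token_bucket_size = 256      # assumed bucket size
num_attention_heads = 16     # assumed head count

token_num_rel_dis = 2 * token_bucket_size - 1
num_rel_pos_tables = 1 if share_attn_bias else encoder_layers
token_rel_pos_table_list = nn.ModuleList([
    nn.Embedding(token_num_rel_dis, num_attention_heads)
    for _ in range(num_rel_pos_tables)
])

for idx in range(encoder_layers):
    # Every layer indexes table 0 when the bias is shared, otherwise its own table.
    real_idx = 0 if share_attn_bias else idx
    table = token_rel_pos_table_list[real_idx]

print(len(token_rel_pos_table_list))  # 1 when shared, encoder_layers otherwise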
@@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel):
            self.layernorm_embedding = None

        self.register_buffer('token_rp_bucket', token_rp_bucket)
        self.register_buffer('image_rp_bucket', image_rp_bucket)
        self.entangle_position_embedding = config.entangle_position_embedding

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
        self.use_ofasys = config.use_ofasys

    def get_input_embeddings(self):
        r"""
@@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel):
        if has_pads:
            x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x))

        pos_embed = self.pos_ln(pos_embed)
        if patch_images is not None:
            image_pos_embed = self.image_pos_ln(image_pos_embed)
            pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
        if patch_images_2 is not None:
            image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
            pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
        if self.use_ofasys:
            if patch_images is not None:
                pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
            if patch_images_2 is not None:
                pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)
        else:
            pos_embed = self.pos_ln(pos_embed)
            if patch_images is not None:
                image_pos_embed = self.image_pos_ln(image_pos_embed)
                pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1)
            if patch_images_2 is not None:
                image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2)
                pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1)

        pos_q = self.pos_q_linear(pos_embed).view(
            x.size(0), x.size(1), self.num_attention_heads, -1).transpose(
                1, 2) * self.pos_scaling
        pos_k = self.pos_k_linear(pos_embed).view(
            x.size(0), x.size(1), self.num_attention_heads,
            -1).transpose(1, 2)
        abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
        def build_abs_pos_bias(pos_embed):
            batch_size, seq_length = pos_embed.size(0), pos_embed.size(1)
            if not (self.use_ofasys and self.entangle_position_embedding):
                pos_q = self.pos_q_linear(pos_embed).view(
                    batch_size, seq_length, self.num_attention_heads,
                    -1).transpose(1, 2) * self.pos_scaling
                pos_k = self.pos_k_linear(pos_embed).view(
                    batch_size, seq_length, self.num_attention_heads,
                    -1).transpose(1, 2)
                abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
            else:
                abs_pos_bias = torch.zeros(
                    batch_size,
                    self.num_attention_heads,
                    seq_length,
                    seq_length,
                    dtype=pos_embed.dtype,
                    device=pos_embed.device)
            return abs_pos_bias

        abs_pos_bias = build_abs_pos_bias(pos_embed)

        # expand attention_mask
        if has_pads:
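The build_abs_pos_bias helper above either computes the usual disentangled position bias or, when both use_ofasys and entangle_position_embedding are set, returns an all-zero bias because positions are already folded into the token embeddings. A toy reproduction of that branch with made-up shapes (the real OFA code also multiplies in attn_scale_factor, which is simplified away here):

import torch
import torch.nn as nn

batch_size, seq_length, num_heads, embed_dim = 2, 10, 4, 32  # illustrative sizes
pos_embed = torch.randn(batch_size, seq_length, embed_dim)
pos_q_linear = nn.Linear(embed_dim, embed_dim)
pos_k_linear = nn.Linear(embed_dim, embed_dim)
pos_scaling = float(embed_dim / num_heads) ** -0.5
use_ofasys, entangle_position_embedding = True, True

if not (use_ofasys and entangle_position_embedding):
    pos_q = pos_q_linear(pos_embed).view(
        batch_size, seq_length, num_heads, -1).transpose(1, 2) * pos_scaling
    pos_k = pos_k_linear(pos_embed).view(
        batch_size, seq_length, num_heads, -1).transpose(1, 2)
    abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
else:
    # Entangled positions are added to the inputs directly, so the explicit bias is zero.
    abs_pos_bias = torch.zeros(batch_size, num_heads, seq_length, seq_length)

print(abs_pos_bias.shape)  # torch.Size([2, 4, 10, 10])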
@@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel):
            if output_hidden_states:
                encoder_states += (x, )
            self_attn_bias = abs_pos_bias.clone()

            real_idx = 0 if self.share_attn_bias else idx

            self_attn_bias[:, :, -input_ids.size(1):,
                           -input_ids.size(1):] += self.get_rel_pos_bias(
                               input_ids, idx)
                               input_ids, real_idx)
            if patch_images_2 is not None:
                self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \
                    self.get_image_rel_pos_bias(image_position_ids_2, idx)
                    self.get_image_rel_pos_bias(image_position_ids_2, real_idx)
                self_attn_bias[:, :,
                               image_num_patches_2:image_num_patches_2 + image_num_patches,  # noqa
                               image_num_patches_2:image_num_patches_2 + image_num_patches] += \
                    self.get_image_rel_pos_bias(image_position_ids, idx)  # noqa
                    self.get_image_rel_pos_bias(image_position_ids, real_idx)  # noqa
            elif patch_images is not None:
                self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \
                    self.get_image_rel_pos_bias(image_position_ids, idx)
                    self.get_image_rel_pos_bias(image_position_ids, real_idx)
            self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1))

            hidden_outputs = layer(
@@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel):
        self._future_mask = torch.empty(0)
        self.share_input_output_embed = config.share_decoder_input_output_embed
        self.num_attention_heads = config.decoder_attention_heads
        self.use_ofasys = config.use_ofasys
        self.disable_entangle = config.disable_entangle

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
@@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel):
        else:
            self.layernorm_embedding = None

        if config.use_ofasys:
            if config.add_type_embedding:
                self.type_embedding = Embedding(
                    1, self.embed_dim, padding_idx=None)
            else:
                self.type_embedding = None

        self.window_size = config.code_image_size // 8

        self.embed_positions = Embedding(self.max_target_positions + 2,
                                         self.embed_dim)
        self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1,
                                               self.embed_dim)
        self.pos_ln = LayerNorm(self.embed_dim)
        self.image_pos_ln = LayerNorm(self.embed_dim)

        if not config.use_ofasys:
            self.embed_image_positions = Embedding(
                config.image_bucket_size**2 + 1, self.embed_dim)
        if not config.use_ofasys:
            self.pos_ln = LayerNorm(self.embed_dim)
            self.image_pos_ln = LayerNorm(self.embed_dim)
        self.pos_scaling = float(self.embed_dim / self.num_attention_heads
                                 * config.attn_scale_factor)**-0.5
        self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

        if not (config.use_ofasys and config.entangle_position_embedding):
            self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
            self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)

        self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim)
        self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim)
@@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel):
        self.token_bucket_size = config.token_bucket_size
        token_num_rel_dis = 2 * config.token_bucket_size - 1
        token_rp_bucket = make_token_bucket_position(config.token_bucket_size)

        self.share_attn_bias = config.share_attn_bias
        num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers
        self.token_rel_pos_table_list = nn.ModuleList([
            Embedding(
                token_num_rel_dis, self.num_attention_heads, zero_init=True)
            for _ in range(config.decoder_layers)
            for _ in range(num_rel_pos_tables)
        ])

        self.image_bucket_size = config.image_bucket_size
        image_num_rel_dis = (2 * config.image_bucket_size
                             - 1) * (2 * config.image_bucket_size - 1) + 3
        image_rp_bucket = make_image_bucket_position(config.image_bucket_size,
                                                     image_num_rel_dis)
        image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
            torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1  # noqa
        image_position_idx = torch.cat(
            [torch.tensor([0]), image_position_idx.view(-1)])
        image_position_idx = torch.cat(
            [image_position_idx,
             torch.tensor([1024] * 768)])
        self.image_rel_pos_table_list = nn.ModuleList([
            Embedding(
                image_num_rel_dis, self.num_attention_heads, zero_init=True)
            for _ in range(config.decoder_layers)
        ])
        if config.use_image_feature:
            if not config.use_ofasys:
                self.image_bucket_size = config.image_bucket_size
                image_num_rel_dis = (2 * config.image_bucket_size - 1) * (
                    2 * config.image_bucket_size - 1) + 3
                image_rp_bucket = make_image_bucket_position(
                    config.image_bucket_size, image_num_rel_dis)
                image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \
                    torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1  # noqa
                image_position_idx = torch.cat(
                    [torch.tensor([0]),
                     image_position_idx.view(-1)])
                image_position_idx = torch.cat(
                    [image_position_idx,
                     torch.tensor([1024] * 768)])
                self.register_buffer('image_position_idx', image_position_idx)

            self.image_rel_pos_table_list = nn.ModuleList([
                Embedding(
                    image_num_rel_dis,
                    self.num_attention_heads,
                    zero_init=True) for _ in range(num_rel_pos_tables)
            ])
            self.register_buffer('image_rp_bucket', image_rp_bucket)

        self.register_buffer('token_rp_bucket', token_rp_bucket)
        self.register_buffer('image_rp_bucket', image_rp_bucket)
        self.register_buffer('image_position_idx', image_position_idx)
        self.entangle_position_embedding = config.entangle_position_embedding

        self.gradient_checkpointing = False
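To make the image_position_idx buffer above easier to read, here is the same construction run with toy sizes (window_size=3 and image_bucket_size=8 instead of the real values): each patch at row r and column c maps to the flat bucket index r * image_bucket_size + c + 1, with index 0 reserved for the leading token.

import torch

window_size, image_bucket_size = 3, 8   # toy values for illustration only

image_position_idx = torch.arange(window_size).unsqueeze(0).expand(
    window_size, window_size) + torch.arange(window_size).unsqueeze(1) * image_bucket_size + 1
print(image_position_idx)
# tensor([[ 1,  2,  3],
#         [ 9, 10, 11],
#         [17, 18, 19]])

image_position_idx = torch.cat([torch.tensor([0]), image_position_idx.view(-1)])
print(image_position_idx.shape)  # torch.Size([10]) -> 1 + window_size ** 2 entries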
@@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel):

        batch_size = tgt_pos_embed.size(0)
        tgt_len = tgt_pos_embed.size(1)
        tgt_pos_embed = self.image_pos_ln(
            tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)
        if not self.use_ofasys:
            tgt_pos_embed = self.image_pos_ln(
                tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed)

        if src_pos_embed is not None:
            src_len = src_pos_embed.size(1)
            pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
                batch_size, tgt_len, self.num_attention_heads, -1).transpose(
                    1, 2) * self.pos_scaling
            pos_k = self.cross_pos_k_linear(src_pos_embed).view(
                batch_size, src_len, self.num_attention_heads,
                -1).transpose(1, 2)
            if not (self.entangle_position_embedding and self.use_ofasys):
                pos_q = self.cross_pos_q_linear(tgt_pos_embed).view(
                    batch_size, tgt_len, self.num_attention_heads,
                    -1).transpose(1, 2) * self.pos_scaling
                pos_k = self.cross_pos_k_linear(src_pos_embed).view(
                    batch_size, src_len, self.num_attention_heads,
                    -1).transpose(1, 2)
                abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
            else:
                abs_pos_bias = torch.zeros(
                    batch_size,
                    self.num_attention_heads,
                    tgt_len,
                    src_len,
                    dtype=tgt_pos_embed.dtype,
                    device=tgt_pos_embed.device)
        else:
            src_len = tgt_pos_embed.size(1)
            pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
                batch_size, tgt_len, self.num_attention_heads, -1).transpose(
                    1, 2) * self.pos_scaling
            pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
                batch_size, src_len, self.num_attention_heads,
                -1).transpose(1, 2)
            abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
            # batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1)
            if not (self.entangle_position_embedding and self.use_ofasys):
                pos_q = self.self_pos_q_linear(tgt_pos_embed).view(
                    batch_size, tgt_len, self.num_attention_heads,
                    -1).transpose(1, 2) * self.pos_scaling
                pos_k = self.self_pos_k_linear(tgt_pos_embed).view(
                    batch_size, tgt_len, self.num_attention_heads,
                    -1).transpose(1, 2)
                abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3))
            else:
                abs_pos_bias = torch.zeros(
                    batch_size,
                    self.num_attention_heads,
                    tgt_len,
                    tgt_len,
                    dtype=tgt_pos_embed.dtype,
                    device=tgt_pos_embed.device)

        return abs_pos_bias
@@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel):
                past_key_values) > 0 else None

            self_attn_bias = self_abs_pos_bias.clone()
            real_idx = 0 if self.share_attn_bias else idx
            if code_masks is None or not code_masks.any():
                self_attn_bias += self.get_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
                    all_prev_output_tokens, real_idx).unsqueeze(0)
            elif code_masks is not None and code_masks.all():
                self_attn_bias += self.get_image_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
                    all_prev_output_tokens, real_idx).unsqueeze(0)
            else:
                self_attn_bias[~code_masks] += self.get_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
                    all_prev_output_tokens, real_idx).unsqueeze(0)
                self_attn_bias[code_masks] += self.get_image_rel_pos_bias(
                    all_prev_output_tokens, idx).unsqueeze(0)
                    all_prev_output_tokens, real_idx).unsqueeze(0)
            self_attn_bias = self_attn_bias.reshape(
                -1,
                *self_attn_bias.size()[-2:])
@@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel):

        self.encoder = OFAEncoder(config, shared)
        self.decoder = OFADecoder(config, shared)
        self.use_ofasys = config.use_ofasys

        # Initialize weights and apply final processing
        self.post_init()
modelscope/models/multi_modal/ofa/utils/utils.py

@@ -2,6 +2,7 @@
from typing import Optional

import torch
import torch.nn as nn


def expand_mask(mask: torch.Tensor,
@@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor,
                                     src_len).to(dtype)
    return expanded_mask.masked_fill(expanded_mask.bool(),
                                     torch.finfo(dtype).min)


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    r"""
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        x (`nn.Modules`): input nn layers.
        drop_prob (`float`): drop path ratio.
        training (`bool`): whether is training or inference.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (1, x.shape[1], 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    r"""
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: drop path ratio.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return 'p={}'.format(self.drop_prob)
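A quick sanity check of the DropPath module added above: with drop_prob=0 or in eval mode it is the identity, while in training mode it zeroes whole samples along dimension 1 (the batch dimension in the seq-first layout used by vit.py) and rescales the survivors by 1/keep_prob. The tensor shape below is an assumption for illustration.

import torch

from modelscope.models.multi_modal.ofa.utils.utils import DropPath

x = torch.ones(5, 4, 8)       # (seq_len, batch, hidden); shape is illustrative
dp = DropPath(drop_prob=0.5)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference time

dp.train()
y = dp(x)                     # surviving samples are scaled by 1 / keep_prob = 2.0
print(dp)                     # DropPath(p=0.5), via extra_repr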
155 modelscope/models/multi_modal/ofa/vit.py (new file)
@@ -0,0 +1,155 @@
from collections import OrderedDict

import torch
import torch.nn.functional as F
from fairseq.modules import LayerNorm
from torch import nn

from .utils.utils import DropPath

__all__ = [
    'vit_base',
    'vit_large',
    'vit_large_336',
    'vit_huge',
]


class QuickGELU(nn.Module):

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):

    def __init__(self,
                 d_model: int,
                 n_head: int,
                 attn_mask: torch.Tensor = None,
                 drop_path_rate=0.0):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(
            OrderedDict([
                ('c_fc', nn.Linear(d_model, d_model * 4)),
                ('gelu', QuickGELU()),
                ('c_proj', nn.Linear(d_model * 4, d_model)),
            ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask
        self.drop_path = DropPath(drop_path_rate)

    def attention(self, x: torch.Tensor):
        self.attn_mask = (
            self.attn_mask.to(dtype=x.dtype, device=x.device)
            if self.attn_mask is not None else None)
        return self.attn(
            x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.drop_path(self.attention(self.ln_1(x)))
        x = x + self.drop_path(self.mlp(self.ln_2(x)))
        return x


class Transformer(nn.Module):

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        attn_mask: torch.Tensor = None,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[
            ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate)
            for _ in range(layers)
        ])

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)


class VisionTransformer(nn.Module):

    def __init__(
        self,
        input_resolution: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        drop_path_rate: float = 0.0,
    ):
        super().__init__()
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        scale = width**-0.5
        self.width = width
        self.positional_embedding = nn.Parameter(scale * torch.randn(
            (input_resolution // patch_size)**2 + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(
            width, layers, heads, drop_path_rate=drop_path_rate)

    def forward(self, x: torch.Tensor):
        resolution = x.shape[-2]
        height, width = x.shape[-2] // self.patch_size, x.shape[
            -1] // self.patch_size
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1],
                      -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        if resolution != self.input_resolution:
            old_pe = self.positional_embedding[1:]
            patch_num = self.input_resolution // self.patch_size
            old_pe = old_pe.reshape(1, patch_num, patch_num,
                                    -1).permute(0, 3, 1, 2)
            new_pe = F.interpolate(
                old_pe, size=(height, width), mode='bilinear')
            new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1)
            x = x + new_pe.to(x.dtype)
        else:
            x = x + self.positional_embedding[1:].to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        bz, seq, hidden = x.shape
        x = x.transpose(1, 2).reshape(bz, hidden, height, width)

        return x


def vit_base(drop_path_rate: float = 0.0):
    return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate)


def vit_large(drop_path_rate: float = 0.0):
    return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate)


def vit_large_336(drop_path_rate: float = 0.0):
    return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate)


def vit_huge(drop_path_rate: float = 0.0):
    return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate)
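A short usage sketch for the new backbone constructors (batch size and input resolution are illustrative): the forward pass returns a feature map of shape [batch, width, H/patch_size, W/patch_size], which is what the encoder's image_proj built from self.embed_images.width consumes, and inputs at a resolution other than the construction resolution are handled by bilinearly interpolating the positional embedding.

import torch

from modelscope.models.multi_modal.ofa.vit import vit_huge

model = vit_huge(drop_path_rate=0.0)   # VisionTransformer(224, 14, 1280, 24, 16, ...)
model.eval()

images = torch.randn(2, 3, 224, 224)   # illustrative batch of RGB images
with torch.no_grad():
    features = model(images)
print(features.shape)                  # torch.Size([2, 1280, 16, 16]); 224 / 14 = 16 patches per side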
modelscope/models/multi_modal/ofa_for_all_tasks.py

@@ -53,8 +53,11 @@ class OfaForAllTasks(TorchModel):
            raise NotImplementedError
        # there is some diff between here and our ofa code,
        # there will be no need to use param: use_bpe
        self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])
        self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])
        if not model.use_ofasys:
            self.tokenizer.add_tokens(
                ['<code_{}>'.format(i) for i in range(8192)])
            self.tokenizer.add_tokens(
                ['<bin_{}>'.format(i) for i in range(1000)])
        self.cfg.update({'num_bins': 1000, 'num_codes': 8192})
        self.batch_size = self.cfg.model.get('batch_size', 1)
        self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
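For reference, a standalone illustration of what the guarded branch above does when use_ofasys is false: it extends the vocabulary with 8192 <code_k> tokens (image codebook entries) and 1000 <bin_k> tokens (quantized coordinate bins), matching the num_codes/num_bins values written into the config. The tokenizer below is a stand-in loaded from a public BART checkpoint, not the tokenizer the OFA preprocessor actually provides.

from transformers import AutoTokenizer

use_ofasys = False                     # OFASys-style checkpoints skip the extra tokens
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')  # stand-in tokenizer

if not use_ofasys:
    tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)])  # image codebook tokens
    tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)])   # coordinate bins

print(len(tokenizer))                  # vocabulary grows by 9192 entries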