diff --git a/modelscope/models/multi_modal/ofa/configuration_ofa.py b/modelscope/models/multi_modal/ofa/configuration_ofa.py index 4899f416..2edc651e 100644 --- a/modelscope/models/multi_modal/ofa/configuration_ofa.py +++ b/modelscope/models/multi_modal/ofa/configuration_ofa.py @@ -136,6 +136,12 @@ class OFAConfig(PretrainedConfig): entangle_position_embedding=False, interpolate_position=False, orig_patch_image_size=224, + share_attn_bias=False, + use_image_feature=True, + disable_entangle=False, + use_ofasys=False, + vit_type='vit_base', + vit_drop_path_rate=0.0, **kwargs): self.vocab_size = vocab_size self.max_position_embeddings = max_position_embeddings @@ -178,6 +184,13 @@ class OFAConfig(PretrainedConfig): self.interpolate_position = interpolate_position self.orig_patch_image_size = orig_patch_image_size + self.share_attn_bias = share_attn_bias + self.use_image_feature = use_image_feature + self.disable_entangle = disable_entangle + self.use_ofasys = use_ofasys + self.vit_type = vit_type + self.vit_drop_path_rate = vit_drop_path_rate + super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, diff --git a/modelscope/models/multi_modal/ofa/modeling_ofa.py b/modelscope/models/multi_modal/ofa/modeling_ofa.py old mode 100755 new mode 100644 index 0a7a2ce6..69005ef0 --- a/modelscope/models/multi_modal/ofa/modeling_ofa.py +++ b/modelscope/models/multi_modal/ofa/modeling_ofa.py @@ -35,6 +35,8 @@ from transformers.utils import logging from .configuration_ofa import OFAConfig from .generate import utils from .resnet import ResNet +from .utils.utils import DropPath +from .vit import vit_base, vit_huge, vit_large, vit_large_336 logger = logging.get_logger(__name__) @@ -249,45 +251,6 @@ class LayerDropModuleList(nn.ModuleList): yield m -def drop_path(x, drop_prob: float = 0.0, training: bool = False): - r""" - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - Args: - x (`nn.Modules`): input nn layers. - drop_prob (`float`): drop path ratio. - training (`bool`): whether is training or inference. - """ - if drop_prob == 0.0 or not training: - return x - keep_prob = 1 - drop_prob - shape = (1, x.shape[1], 1) - random_tensor = keep_prob + torch.rand( - shape, dtype=x.dtype, device=x.device) - random_tensor.floor_() # binarize - output = x.div(keep_prob) * random_tensor - return output - - -class DropPath(nn.Module): - r""" - Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). - - Args: - drop_prob: drop path ratio. - """ - - def __init__(self, drop_prob=None): - super().__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return 'p={}'.format(self.drop_prob) - - class OFAAttention(nn.Module): r""" Multi-headed attention, with additional implementation for NormFormer. 
@@ -898,31 +861,49 @@ class OFAEncoder(OFAPreTrainedModel): self.padding_idx) if config.add_type_embedding: - self.type_embedding = Embedding(2, embed_dim, padding_idx=None) + if config.use_image_feature: + self.type_embedding = Embedding(2, embed_dim, padding_idx=None) + else: + self.type_embedding = Embedding(1, embed_dim, padding_idx=None) else: self.type_embedding = None - if config.resnet_type == 'resnet18': - self.embed_images = ResNet( - [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet34': - self.embed_images = ResNet( - [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet50': - self.embed_images = ResNet( - [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet101': - self.embed_images = ResNet( - [3, 4, 23], drop_path_rate=config.resnet_drop_path_rate) - elif config.resnet_type == 'resnet152': - self.embed_images = ResNet( - [3, 8, 36], drop_path_rate=config.resnet_drop_path_rate) - else: - raise NotImplementedError + if config.use_image_feature: + if config.use_ofasys: + vit_backbone = { + 'vit_base': vit_base, + 'vit_large': vit_large, + 'vit_large_336': vit_large_336, + 'vit_huge': vit_huge, + }[config.vit_type] + self.embed_images = vit_backbone(config.vit_drop_path_rate) - self.image_proj = Linear(1024, embed_dim) + self.image_proj = Linear(self.embed_images.width, embed_dim) - if config.resnet_model_path: + else: + if config.resnet_type == 'resnet18': + self.embed_images = ResNet( + [2, 2, 2], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet34': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet50': + self.embed_images = ResNet( + [3, 4, 6], drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet101': + self.embed_images = ResNet( + [3, 4, 23], + drop_path_rate=config.resnet_drop_path_rate) + elif config.resnet_type == 'resnet152': + self.embed_images = ResNet( + [3, 8, 36], + drop_path_rate=config.resnet_drop_path_rate) + else: + raise NotImplementedError + + self.image_proj = Linear(1024, embed_dim) + + if not config.use_ofasys and config.resnet_model_path: print('load resnet {}'.format(config.resnet_model_path)) resnet_state_dict = torch.load(config.resnet_model_path) self.embed_images.load_state_dict(resnet_state_dict) @@ -933,14 +914,21 @@ class OFAEncoder(OFAPreTrainedModel): self.embed_positions = Embedding(self.max_source_positions + 2, embed_dim) - self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, - embed_dim) - self.pos_ln = LayerNorm(embed_dim) - self.image_pos_ln = LayerNorm(embed_dim) + + if config.use_image_feature: + self.embed_image_positions = Embedding( + config.image_bucket_size**2 + 1, embed_dim) + if not config.use_ofasys: + self.pos_ln = LayerNorm(embed_dim) + + if config.use_image_feature: + self.image_pos_ln = LayerNorm(embed_dim) self.pos_scaling = float(embed_dim / self.num_attention_heads * config.attn_scale_factor)**-0.5 - self.pos_q_linear = nn.Linear(embed_dim, embed_dim) - self.pos_k_linear = nn.Linear(embed_dim, embed_dim) + + if not (config.use_ofasys and config.entangle_position_embedding): + self.pos_q_linear = nn.Linear(embed_dim, embed_dim) + self.pos_k_linear = nn.Linear(embed_dim, embed_dim) if self.encoder_layerdrop > 0.0: self.layers = LayerDropModuleList(p=self.encoder_layerdrop) @@ -965,22 +953,28 @@ class OFAEncoder(OFAPreTrainedModel): self.token_bucket_size 
= config.token_bucket_size token_num_rel_dis = 2 * config.token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + self.share_attn_bias = config.share_attn_bias + num_rel_pos_tables = 1 if config.share_attn_bias else config.encoder_layers self.token_rel_pos_table_list = nn.ModuleList([ Embedding( token_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.encoder_layers) + for _ in range(num_rel_pos_tables) ]) - self.image_bucket_size = config.image_bucket_size - image_num_rel_dis = (2 * config.image_bucket_size - - 1) * (2 * config.image_bucket_size - 1) + 3 - image_rp_bucket = make_image_bucket_position(config.image_bucket_size, - image_num_rel_dis) - self.image_rel_pos_table_list = nn.ModuleList([ - Embedding( - image_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.encoder_layers) - ]) + if config.use_image_feature: + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size + - 1) * (2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + config.image_bucket_size, image_num_rel_dis) + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range(num_rel_pos_tables) + ]) + + self.register_buffer('image_rp_bucket', image_rp_bucket) if config.layernorm_embedding: self.layernorm_embedding = LayerNorm(embed_dim) @@ -988,12 +982,12 @@ class OFAEncoder(OFAPreTrainedModel): self.layernorm_embedding = None self.register_buffer('token_rp_bucket', token_rp_bucket) - self.register_buffer('image_rp_bucket', image_rp_bucket) self.entangle_position_embedding = config.entangle_position_embedding self.gradient_checkpointing = False # Initialize weights and apply final processing self.post_init() + self.use_ofasys = config.use_ofasys def get_input_embeddings(self): r""" @@ -1305,21 +1299,41 @@ class OFAEncoder(OFAPreTrainedModel): if has_pads: x = x * (1 - encoder_padding_mask.unsqueeze(-1).type_as(x)) - pos_embed = self.pos_ln(pos_embed) - if patch_images is not None: - image_pos_embed = self.image_pos_ln(image_pos_embed) - pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) - if patch_images_2 is not None: - image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) - pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + if self.use_ofasys: + if patch_images is not None: + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) + else: + pos_embed = self.pos_ln(pos_embed) + if patch_images is not None: + image_pos_embed = self.image_pos_ln(image_pos_embed) + pos_embed = torch.cat([image_pos_embed, pos_embed], dim=1) + if patch_images_2 is not None: + image_pos_embed_2 = self.image_pos_ln(image_pos_embed_2) + pos_embed = torch.cat([image_pos_embed_2, pos_embed], dim=1) - pos_q = self.pos_q_linear(pos_embed).view( - x.size(0), x.size(1), self.num_attention_heads, -1).transpose( - 1, 2) * self.pos_scaling - pos_k = self.pos_k_linear(pos_embed).view( - x.size(0), x.size(1), self.num_attention_heads, - -1).transpose(1, 2) - abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + def build_abs_pos_bias(pos_embed): + batch_size, seq_length = pos_embed.size(0), pos_embed.size(1) + if not (self.use_ofasys and self.entangle_position_embedding): + pos_q = self.pos_q_linear(pos_embed).view( + batch_size, seq_length, self.num_attention_heads, + 
-1).transpose(1, 2) * self.pos_scaling + pos_k = self.pos_k_linear(pos_embed).view( + batch_size, seq_length, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + seq_length, + seq_length, + dtype=pos_embed.dtype, + device=pos_embed.device) + return abs_pos_bias + + abs_pos_bias = build_abs_pos_bias(pos_embed) # expand attention_mask if has_pads: @@ -1334,19 +1348,22 @@ class OFAEncoder(OFAPreTrainedModel): if output_hidden_states: encoder_states += (x, ) self_attn_bias = abs_pos_bias.clone() + + real_idx = 0 if self.share_attn_bias else idx + self_attn_bias[:, :, -input_ids.size(1):, -input_ids.size(1):] += self.get_rel_pos_bias( - input_ids, idx) + input_ids, real_idx) if patch_images_2 is not None: self_attn_bias[:, :, :image_num_patches_2, :image_num_patches_2] += \ - self.get_image_rel_pos_bias(image_position_ids_2, idx) + self.get_image_rel_pos_bias(image_position_ids_2, real_idx) self_attn_bias[:, :, image_num_patches_2:image_num_patches_2 + image_num_patches, # noqa image_num_patches_2:image_num_patches_2 + image_num_patches] += \ - self.get_image_rel_pos_bias(image_position_ids, idx) # noqa + self.get_image_rel_pos_bias(image_position_ids, real_idx) # noqa elif patch_images is not None: self_attn_bias[:, :, :x.size(1) - input_ids.size(1), :x.size(1) - input_ids.size(1)] += \ - self.get_image_rel_pos_bias(image_position_ids, idx) + self.get_image_rel_pos_bias(image_position_ids, real_idx) self_attn_bias = self_attn_bias.reshape(-1, x.size(1), x.size(1)) hidden_outputs = layer( @@ -1398,6 +1415,8 @@ class OFADecoder(OFAPreTrainedModel): self._future_mask = torch.empty(0) self.share_input_output_embed = config.share_decoder_input_output_embed self.num_attention_heads = config.decoder_attention_heads + self.use_ofasys = config.use_ofasys + self.disable_entangle = config.disable_entangle if embed_tokens is not None: self.embed_tokens = embed_tokens @@ -1415,18 +1434,31 @@ class OFADecoder(OFAPreTrainedModel): else: self.layernorm_embedding = None + if config.use_ofasys: + if config.add_type_embedding: + self.type_embedding = Embedding( + 1, self.embed_dim, padding_idx=None) + else: + self.type_embedding = None + self.window_size = config.code_image_size // 8 self.embed_positions = Embedding(self.max_target_positions + 2, self.embed_dim) - self.embed_image_positions = Embedding(config.image_bucket_size**2 + 1, - self.embed_dim) - self.pos_ln = LayerNorm(self.embed_dim) - self.image_pos_ln = LayerNorm(self.embed_dim) + + if not config.use_ofasys: + self.embed_image_positions = Embedding( + config.image_bucket_size**2 + 1, self.embed_dim) + if not config.use_ofasys: + self.pos_ln = LayerNorm(self.embed_dim) + self.image_pos_ln = LayerNorm(self.embed_dim) self.pos_scaling = float(self.embed_dim / self.num_attention_heads * config.attn_scale_factor)**-0.5 - self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) - self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + + if not (config.use_ofasys and config.entangle_position_embedding): + self.self_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.self_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) + self.cross_pos_q_linear = nn.Linear(self.embed_dim, self.embed_dim) self.cross_pos_k_linear = nn.Linear(self.embed_dim, self.embed_dim) @@ -1463,33 +1495,41 @@ class OFADecoder(OFAPreTrainedModel): self.token_bucket_size = config.token_bucket_size 
token_num_rel_dis = 2 * config.token_bucket_size - 1 token_rp_bucket = make_token_bucket_position(config.token_bucket_size) + + self.share_attn_bias = config.share_attn_bias + num_rel_pos_tables = 1 if config.share_attn_bias else config.decoder_layers self.token_rel_pos_table_list = nn.ModuleList([ Embedding( token_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.decoder_layers) + for _ in range(num_rel_pos_tables) ]) - self.image_bucket_size = config.image_bucket_size - image_num_rel_dis = (2 * config.image_bucket_size - - 1) * (2 * config.image_bucket_size - 1) + 3 - image_rp_bucket = make_image_bucket_position(config.image_bucket_size, - image_num_rel_dis) - image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ - torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa - image_position_idx = torch.cat( - [torch.tensor([0]), image_position_idx.view(-1)]) - image_position_idx = torch.cat( - [image_position_idx, - torch.tensor([1024] * 768)]) - self.image_rel_pos_table_list = nn.ModuleList([ - Embedding( - image_num_rel_dis, self.num_attention_heads, zero_init=True) - for _ in range(config.decoder_layers) - ]) + if config.use_image_feature: + if not config.use_ofasys: + self.image_bucket_size = config.image_bucket_size + image_num_rel_dis = (2 * config.image_bucket_size - 1) * ( + 2 * config.image_bucket_size - 1) + 3 + image_rp_bucket = make_image_bucket_position( + config.image_bucket_size, image_num_rel_dis) + image_position_idx = torch.arange(self.window_size).unsqueeze(0).expand(self.window_size, self.window_size) + \ + torch.arange(self.window_size).unsqueeze(1) * config.image_bucket_size + 1 # noqa + image_position_idx = torch.cat( + [torch.tensor([0]), + image_position_idx.view(-1)]) + image_position_idx = torch.cat( + [image_position_idx, + torch.tensor([1024] * 768)]) + self.register_buffer('image_position_idx', image_position_idx) + + self.image_rel_pos_table_list = nn.ModuleList([ + Embedding( + image_num_rel_dis, + self.num_attention_heads, + zero_init=True) for _ in range(num_rel_pos_tables) + ]) + self.register_buffer('image_rp_bucket', image_rp_bucket) self.register_buffer('token_rp_bucket', token_rp_bucket) - self.register_buffer('image_rp_bucket', image_rp_bucket) - self.register_buffer('image_position_idx', image_position_idx) self.entangle_position_embedding = config.entangle_position_embedding self.gradient_checkpointing = False @@ -1556,26 +1596,46 @@ class OFADecoder(OFAPreTrainedModel): batch_size = tgt_pos_embed.size(0) tgt_len = tgt_pos_embed.size(1) - tgt_pos_embed = self.image_pos_ln( - tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) + if not self.use_ofasys: + tgt_pos_embed = self.image_pos_ln( + tgt_pos_embed) if use_image else self.pos_ln(tgt_pos_embed) if src_pos_embed is not None: src_len = src_pos_embed.size(1) - pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( - batch_size, tgt_len, self.num_attention_heads, -1).transpose( - 1, 2) * self.pos_scaling - pos_k = self.cross_pos_k_linear(src_pos_embed).view( - batch_size, src_len, self.num_attention_heads, - -1).transpose(1, 2) + if not (self.entangle_position_embedding and self.use_ofasys): + pos_q = self.cross_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.cross_pos_k_linear(src_pos_embed).view( + batch_size, src_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = 
torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + tgt_len, + src_len, + dtype=tgt_pos_embed.dtype, + device=tgt_pos_embed.device) else: - src_len = tgt_pos_embed.size(1) - pos_q = self.self_pos_q_linear(tgt_pos_embed).view( - batch_size, tgt_len, self.num_attention_heads, -1).transpose( - 1, 2) * self.pos_scaling - pos_k = self.self_pos_k_linear(tgt_pos_embed).view( - batch_size, src_len, self.num_attention_heads, - -1).transpose(1, 2) - abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + # batch_size, seq_length = tgt_pos_embed.size(0), tgt_pos_embed.size(1) + if not (self.entangle_position_embedding and self.use_ofasys): + pos_q = self.self_pos_q_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) * self.pos_scaling + pos_k = self.self_pos_k_linear(tgt_pos_embed).view( + batch_size, tgt_len, self.num_attention_heads, + -1).transpose(1, 2) + abs_pos_bias = torch.matmul(pos_q, pos_k.transpose(2, 3)) + else: + abs_pos_bias = torch.zeros( + batch_size, + self.num_attention_heads, + tgt_len, + tgt_len, + dtype=tgt_pos_embed.dtype, + device=tgt_pos_embed.device) return abs_pos_bias @@ -1809,17 +1869,18 @@ class OFADecoder(OFAPreTrainedModel): past_key_values) > 0 else None self_attn_bias = self_abs_pos_bias.clone() + real_idx = 0 if self.share_attn_bias else idx if code_masks is None or not code_masks.any(): self_attn_bias += self.get_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) elif code_masks is not None and code_masks.all(): self_attn_bias += self.get_image_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) else: self_attn_bias[~code_masks] += self.get_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) self_attn_bias[code_masks] += self.get_image_rel_pos_bias( - all_prev_output_tokens, idx).unsqueeze(0) + all_prev_output_tokens, real_idx).unsqueeze(0) self_attn_bias = self_attn_bias.reshape( -1, *self_attn_bias.size()[-2:]) @@ -1892,6 +1953,7 @@ class OFAModel(OFAPreTrainedModel): self.encoder = OFAEncoder(config, shared) self.decoder = OFADecoder(config, shared) + self.use_ofasys = config.use_ofasys # Initialize weights and apply final processing self.post_init() diff --git a/modelscope/models/multi_modal/ofa/utils/utils.py b/modelscope/models/multi_modal/ofa/utils/utils.py index 6d8943a1..c5aa8483 100644 --- a/modelscope/models/multi_modal/ofa/utils/utils.py +++ b/modelscope/models/multi_modal/ofa/utils/utils.py @@ -2,6 +2,7 @@ from typing import Optional import torch +import torch.nn as nn def expand_mask(mask: torch.Tensor, @@ -17,3 +18,42 @@ def expand_mask(mask: torch.Tensor, src_len).to(dtype) return expanded_mask.masked_fill(expanded_mask.bool(), torch.finfo(dtype).min) + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + x (`nn.Modules`): input nn layers. + drop_prob (`float`): drop path ratio. + training (`bool`): whether is training or inference. 
+ """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (1, x.shape[1], 1) + random_tensor = keep_prob + torch.rand( + shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + r""" + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Args: + drop_prob: drop path ratio. + """ + + def __init__(self, drop_prob=None): + super().__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return 'p={}'.format(self.drop_prob) diff --git a/modelscope/models/multi_modal/ofa/vit.py b/modelscope/models/multi_modal/ofa/vit.py new file mode 100644 index 00000000..b6bba7ee --- /dev/null +++ b/modelscope/models/multi_modal/ofa/vit.py @@ -0,0 +1,155 @@ +from collections import OrderedDict + +import torch +import torch.nn.functional as F +from fairseq.modules import LayerNorm +from torch import nn + +from .utils.utils import DropPath + +__all__ = [ + 'vit_base', + 'vit_large', + 'vit_large_336', + 'vit_huge', +] + + +class QuickGELU(nn.Module): + + def forward(self, x: torch.Tensor): + return x * torch.sigmoid(1.702 * x) + + +class ResidualAttentionBlock(nn.Module): + + def __init__(self, + d_model: int, + n_head: int, + attn_mask: torch.Tensor = None, + drop_path_rate=0.0): + super().__init__() + + self.attn = nn.MultiheadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + OrderedDict([ + ('c_fc', nn.Linear(d_model, d_model * 4)), + ('gelu', QuickGELU()), + ('c_proj', nn.Linear(d_model * 4, d_model)), + ])) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + self.drop_path = DropPath(drop_path_rate) + + def attention(self, x: torch.Tensor): + self.attn_mask = ( + self.attn_mask.to(dtype=x.dtype, device=x.device) + if self.attn_mask is not None else None) + return self.attn( + x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] + + def forward(self, x: torch.Tensor): + x = x + self.drop_path(self.attention(self.ln_1(x))) + x = x + self.drop_path(self.mlp(self.ln_2(x))) + return x + + +class Transformer(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + attn_mask: torch.Tensor = None, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ + ResidualAttentionBlock(width, heads, attn_mask, drop_path_rate) + for _ in range(layers) + ]) + + def forward(self, x: torch.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Module): + + def __init__( + self, + input_resolution: int, + patch_size: int, + width: int, + layers: int, + heads: int, + drop_path_rate: float = 0.0, + ): + super().__init__() + self.input_resolution = input_resolution + self.patch_size = patch_size + self.conv1 = nn.Conv2d( + in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False, + ) + + scale = width**-0.5 + self.width = width + self.positional_embedding = nn.Parameter(scale * torch.randn( + (input_resolution // patch_size)**2 + 1, width)) + self.ln_pre = LayerNorm(width) + self.transformer = Transformer( + width, layers, heads, drop_path_rate=drop_path_rate) + + def forward(self, x: torch.Tensor): + resolution = x.shape[-2] + height, width = x.shape[-2] // self.patch_size, x.shape[ + -1] // self.patch_size + x = self.conv1(x) # shape = 
[*, width, grid, grid] x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + if resolution != self.input_resolution: + old_pe = self.positional_embedding[1:] + patch_num = self.input_resolution // self.patch_size + old_pe = old_pe.reshape(1, patch_num, patch_num, + -1).permute(0, 3, 1, 2) + new_pe = F.interpolate( + old_pe, size=(height, width), mode='bilinear') + new_pe = new_pe.permute(0, 2, 3, 1).reshape(height * width, -1) + x = x + new_pe.to(x.dtype) + else: + x = x + self.positional_embedding[1:].to(x.dtype) + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + bz, seq, hidden = x.shape + x = x.transpose(1, 2).reshape(bz, hidden, height, width) + + return x + + +def vit_base(drop_path_rate: float = 0.0): + return VisionTransformer(224, 16, 768, 9, 12, drop_path_rate) + + +def vit_large(drop_path_rate: float = 0.0): + return VisionTransformer(224, 14, 1024, 18, 16, drop_path_rate) + + +def vit_large_336(drop_path_rate: float = 0.0): + return VisionTransformer(336, 14, 1024, 18, 16, drop_path_rate) + + +def vit_huge(drop_path_rate: float = 0.0): + return VisionTransformer(224, 14, 1280, 24, 16, drop_path_rate) diff --git a/modelscope/models/multi_modal/ofa_for_all_tasks.py b/modelscope/models/multi_modal/ofa_for_all_tasks.py index 56d19ad8..2c6034e8 100644 --- a/modelscope/models/multi_modal/ofa_for_all_tasks.py +++ b/modelscope/models/multi_modal/ofa_for_all_tasks.py @@ -53,8 +53,11 @@ class OfaForAllTasks(TorchModel): raise NotImplementedError # there is some diff between here and our ofa code, # there will be no need to use param: use_bpe - self.tokenizer.add_tokens(['<code_{}>'.format(i) for i in range(8192)]) - self.tokenizer.add_tokens(['<bin_{}>'.format(i) for i in range(1000)]) + if not model.use_ofasys: + self.tokenizer.add_tokens( + ['<code_{}>'.format(i) for i in range(8192)]) + self.tokenizer.add_tokens( + ['<bin_{}>'.format(i) for i in range(1000)]) self.cfg.update({'num_bins': 1000, 'num_codes': 8192}) self.batch_size = self.cfg.model.get('batch_size', 1) self.patch_image_size = self.cfg.model.get('patch_image_size', 480)
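Note on the new ViT path (a reviewer sketch, not part of the patch): when use_ofasys is enabled, OFAEncoder.__init__ now resolves config.vit_type to one of the backbones defined in vit.py and sizes image_proj from the backbone's width attribute instead of the previously hard-coded Linear(1024, embed_dim). A minimal sketch of that dispatch follows, assuming a checkout with this patch applied and fairseq installed (vit.py imports LayerNorm from fairseq.modules); embed_dim=768 is a hypothetical text embedding size, and the flatten/transpose step stands in for the encoder's own conversion of the (B, C, H, W) feature map into patch tokens.

# Reviewer sketch, not part of this patch: mirrors the vit_type dispatch added to
# OFAEncoder.__init__ for the config.use_ofasys=True case.
import torch
import torch.nn as nn

from modelscope.models.multi_modal.ofa.vit import (vit_base, vit_huge,
                                                   vit_large, vit_large_336)

vit_backbones = {
    'vit_base': vit_base,            # width 768,  patch 16, input 224
    'vit_large': vit_large,          # width 1024, patch 14, input 224
    'vit_large_336': vit_large_336,  # width 1024, patch 14, input 336
    'vit_huge': vit_huge,            # width 1280, patch 14, input 224
}

embed_dim = 768                                        # hypothetical text embed dim
embed_images = vit_backbones['vit_base'](0.0)          # vit_drop_path_rate=0.0
image_proj = nn.Linear(embed_images.width, embed_dim)  # replaces Linear(1024, embed_dim)

feature_map = embed_images(torch.randn(1, 3, 224, 224))  # -> (1, 768, 14, 14)
patch_tokens = feature_map.flatten(2).transpose(1, 2)    # -> (1, 196, 768)
print(image_proj(patch_tokens).shape)                    # torch.Size([1, 196, 768])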