diff --git a/data/test/videos/video_caption_and_qa_test.mp4 b/data/test/videos/video_caption_and_qa_test.mp4 new file mode 100644 index 00000000..125783af --- /dev/null +++ b/data/test/videos/video_caption_and_qa_test.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7c822c66fcf04de28016b224ef372cb1c93b7f13f2cba4e11f53a37fec8e769e +size 828272 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index c79b1e42..97d8a059 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -120,6 +120,7 @@ class Models(object): multi_stage_diffusion = 'multi-stage-diffusion-text-to-image-synthesis' team = 'team-multi-modal-similarity' video_clip = 'video-clip-multi-modal-embedding' + hitea = 'hitea' # science models unifold = 'unifold' @@ -322,6 +323,8 @@ class Pipelines(object): image_text_retrieval = 'image-text-retrieval' ofa_ocr_recognition = 'ofa-ocr-recognition' ofa_asr = 'ofa-asr' + video_captioning = 'video-captioning' + video_question_answering = 'video-question-answering' # science tasks protein_structure = 'unifold-protein-structure' @@ -446,6 +449,7 @@ class Preprocessors(object): ofa_tasks_preprocessor = 'ofa-tasks-preprocessor' clip_preprocessor = 'clip-preprocessor' mplug_tasks_preprocessor = 'mplug-tasks-preprocessor' + hitea_tasks_preprocessor = 'hitea-tasks-preprocessor' # science preprocessor unifold_preprocessor = 'unifold-preprocessor' diff --git a/modelscope/models/multi_modal/__init__.py b/modelscope/models/multi_modal/__init__.py index 0053da43..ba30f7b7 100644 --- a/modelscope/models/multi_modal/__init__.py +++ b/modelscope/models/multi_modal/__init__.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from .team import TEAMForMultiModalSimilarity from .diffusion import DiffusionForTextToImageSynthesis from .mmr import VideoCLIPForMultiModalEmbedding - from .mplug_for_all_tasks import MPlugForAllTasks + from .mplug_for_all_tasks import MPlugForAllTasks, HiTeAForAllTasks from .ofa_for_all_tasks import OfaForAllTasks from 
class HiTeAConfig(PretrainedConfig):
    """Configuration container for the HiTeA video-language models.

    Every constructor argument is stored unchanged on the instance under the
    same name. Keyword arguments that are not listed explicitly are forwarded
    to ``PretrainedConfig.__init__``.
    """

    model_type = 'hitea'

    def __init__(
            self,
            task=Tasks.video_question_answering,
            bert_config='config_bert.json',
            image_res=224,
            num_frames=16,
            batch_size_train=32,
            vision_width=768,
            distill=True,
            batch_size_test=64,
            k_test=128,
            alpha=0.4,
            warm_up=True,
            eos='[SEP]',
            optimizer=None,
            schedular=None,
            min_length=1,
            max_length=10,
            beam_size=5,
            text_encoder='bert-base-uncased',
            text_decoder='bert-base-uncased',
            # retrieval
            queue_size=65536,
            embed_dim=256,
            temp=0.07,
            **kwargs):

        super().__init__(**kwargs)
        self.task = task
        self.bert_config = bert_config
        self.image_res = image_res
        self.num_frames = num_frames
        self.batch_size_train = batch_size_train
        self.vision_width = vision_width
        self.distill = distill
        self.batch_size_test = batch_size_test
        self.k_test = k_test
        self.alpha = alpha
        self.warm_up = warm_up
        self.eos = eos
        self.optimizer = optimizer
        # NOTE: the spelling "schedular" is part of the public keyword
        # interface (and of existing config files), so it is kept as is.
        self.schedular = schedular
        self.min_length = min_length
        self.max_length = max_length
        self.beam_size = beam_size
        self.text_encoder = text_encoder
        self.text_decoder = text_decoder
        # retrieval
        self.queue_size = queue_size
        self.embed_dim = embed_dim
        self.temp = temp

    @classmethod
    def from_yaml_file(
            cls, yaml_file: Union[str, os.PathLike]) -> 'HiTeAConfig':
        """Build a config from a YAML file.

        Args:
            yaml_file: Path of a YAML document whose top-level mapping is
                passed as keyword arguments to the constructor.

        Returns:
            A new instance of ``cls`` (not a plain dict).
        """
        # NOTE(review): yaml.Loader can construct arbitrary Python objects, so
        # this must only be pointed at trusted model directories. Switching to
        # yaml.SafeLoader would be safer but may reject existing configs that
        # rely on Python-specific tags -- confirm before changing.
        with open(yaml_file, 'r', encoding='utf-8') as reader:
            config_dict = yaml.load(reader, Loader=yaml.Loader)
        return cls(**config_dict)
def init_distill(self, config):
    """Create the momentum ("_m") networks used for distillation.

    Does nothing beyond recording the flag when ``config.distill`` is falsy.
    Otherwise a second copy of the visual encoder, text encoder, fusion
    encoder and text decoder is built, each is paired with its online
    counterpart in ``self.model_pairs``, the online weights are copied into
    the momentum copies via ``copy_params`` and the momentum coefficient is
    set.

    ``self.text_decoder`` must already exist when this is called.

    Args:
        config: Model configuration; reads ``distill``, ``image_res`` and
            ``num_frames``.
    """
    self.distill = config.distill
    if self.distill:
        self.visual_encoder_m = MViTv2(
            img_size=config.image_res,
            config=MViTv2_Base_config,
            num_frames=config.num_frames)
        self.text_encoder_m = BertModel(
            self.config_encoder, add_pooling_layer=False)
        self.fusion_encoder_m = FusionModel(
            self.config_fusion, add_pooling_layer=False)
        self.text_decoder_m = BertLMHeadModel(self.config_decoder)
        self.model_pairs = [
            [self.visual_encoder, self.visual_encoder_m],
            [self.text_encoder, self.text_encoder_m],
            # The fusion encoder pair was previously missing: the momentum
            # fusion encoder is used to produce the soft labels, so it has
            # to be initialised from, and track, the online fusion encoder
            # like every other momentum module (and be frozen by
            # copy_params).
            [self.fusion_encoder, self.fusion_encoder_m],
            [self.text_decoder, self.text_decoder_m],
        ]
        self.copy_params()
        self.momentum = 0.995
@staticmethod
def _tile(x, dim, n_tile):
    """Repeat each slice of ``x`` along ``dim`` ``n_tile`` times in place.

    The copies of one slice stay adjacent, e.g. tiling rows ``[a, b]`` twice
    gives ``[a, a, b, b]`` (not ``[a, b, a, b]``).

    Args:
        x: Input tensor.
        dim: Dimension along which the slices are repeated.
        n_tile: Number of copies of every slice.

    Returns:
        A new tensor whose size along ``dim`` is ``x.size(dim) * n_tile``.
    """
    # torch.repeat_interleave yields exactly the ordering the former
    # repeat + NumPy index + index_select sequence produced, without the
    # host-side index construction and without failing on empty dimensions.
    return torch.repeat_interleave(x, n_tile, dim=dim)
question_output.shape[0] + question_states = [] + question_atts = [] + for b, n in enumerate(k): + question_states += [question_output[b]] * n + question_atts += [merge_text_attention[b]] * n + question_states = torch.stack(question_states, 0) + question_atts = torch.stack(question_atts, 0) + + if self.distill: + with torch.no_grad(): + self._momentum_update() + video_embeds_m = self.visual_encoder_m(video) + text_output_m = self.text_encoder_m( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds_m = text_output_m.last_hidden_state + fusion_output_m = self.fusion_encoder_m( + encoder_embeds=text_embeds_m, + attention_mask=question.attention_mask, + encoder_hidden_states=video_embeds_m, + encoder_attention_mask=video_atts, + return_dict=False) + + image_output_m, question_output_m = fusion_output_m + question_output_m = torch.cat( + [image_output_m, question_output_m], 1) + + question_states_m = [] + for b, n in enumerate(k): + question_states_m += [question_output_m[b]] * n + question_states_m = torch.stack(question_states_m, 0) + + logits_m = self.text_decoder_m( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states_m, + encoder_attention_mask=question_atts, + return_logits=True, + ) + + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + soft_labels=F.softmax(logits_m, dim=-1), + reduction='none', + ) + else: + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=question_states, + encoder_attention_mask=question_atts, + labels=answer_targets, + return_dict=True, + reduction='none', + ) + if weights is None: + weights = 1 + loss = weights * answer_output.loss + loss = loss.sum() / video.size(0) + + return loss + + else: + text_output = 
self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + return_dict=False) + video_output, question_output = fusion_output + question_output = torch.cat([video_output, question_output], 1) + merge_text_attention = torch.cat( + [video_atts, question.attention_mask], 1) + topk_ids, topk_probs = self.generation(question_output, + merge_text_attention) + return topk_ids, topk_probs + + +class HiTeAForVideoCaption(HiTeA): + + def __init__(self, config): + super().__init__(config) + self.text_decoder = BertPrefixModel(self.config_decoder) + self.beam_generator = TextGenerator(config, self.text_decoder) + + def beam_search(self, + video, + question, + answer=None, + train=True, + out_size=5): + video_embeds = self.visual_encoder(video) + video_atts = torch.ones( + video_embeds.size()[:-1], dtype=torch.long).to(video.device) + text_output = self.text_encoder( + question.input_ids, + attention_mask=question.attention_mask, + return_dict=True) + text_embeds = text_output.last_hidden_state + fusion_output = self.fusion_encoder( + encoder_embeds=text_embeds, + attention_mask=question.attention_mask, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + return_dict=False) + video_output, question_output = fusion_output + question_output = torch.cat([video_output, question_output], 1) + merge_text_attention = torch.cat([video_atts, question.attention_mask], + 1) + topk_ids, topk_probs = self.generation( + question_output, merge_text_attention, out_size=out_size) + return topk_ids, topk_probs + + def forward(self, + video, + question, + answer=None, + train=True, + out_size=5, + scst=False): + if (scst): + return self.beam_search( + video, question, answer, train=True, 
out_size=out_size) + video = video.to(dtype=next(self.parameters()).dtype) + video_embeds = self.visual_encoder(video) + video_atts = torch.ones( + video_embeds.size()[:-1], dtype=torch.long).to(video.device) + + if train: + answer_targets = answer.input_ids.masked_fill( + answer.input_ids == self.tokenizer.pad_token_id, -100) + answer_output = self.text_decoder( + answer.input_ids, + attention_mask=answer.attention_mask, + encoder_hidden_states=video_embeds, + encoder_attention_mask=video_atts, + labels=answer_targets, + return_dict=True, + reduction='none') + loss = answer_output.loss + + return loss + else: + topk_ids, topk_probs = self.generation(video_embeds, video_atts) + return topk_ids, topk_probs diff --git a/modelscope/models/multi_modal/mplug/mvit.py b/modelscope/models/multi_modal/mplug/mvit.py new file mode 100644 index 00000000..f3140ce4 --- /dev/null +++ b/modelscope/models/multi_modal/mplug/mvit.py @@ -0,0 +1,1007 @@ +# Copyright 2021-2022 The Alibaba DAMO NLP Team Authors. +# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved. 
def interpolate_rel_pos_embed(state_dict_origin,
                              state_dict_model,
                              temporal=True,
                              verbose=False):
    """Resize relative-position tables of a checkpoint to fit a model.

    For every key of ``state_dict_origin`` whose name contains ``rel_pos_h``
    or ``rel_pos_w`` (and ``rel_pos_t`` when ``temporal`` is true), the
    tensor is linearly interpolated along its first dimension to the length
    of the tensor stored under the same key in ``state_dict_model``. All
    other entries are passed through.

    Args:
        state_dict_origin: Checkpoint mapping; it is not modified.
        state_dict_model: Mapping that provides the target lengths.
        temporal: Whether ``rel_pos_t`` tables are resized as well.
        verbose: Print one line per resized table.

    Returns:
        A shallow copy of ``state_dict_origin`` in which the matched tables
        are replaced by (possibly resized) clones.
    """
    rel_pos_embed_types = ['rel_pos_h', 'rel_pos_w']
    if temporal:
        rel_pos_embed_types.append('rel_pos_t')

    state_dict_inflated = state_dict_origin.copy()
    for k, v2d in state_dict_origin.items():
        if not any(x in k for x in rel_pos_embed_types):
            continue
        if k not in state_dict_model:
            # The model has no such table (e.g. it was built without that
            # kind of relative position). Keep the checkpoint entry as is
            # instead of raising KeyError; the state-dict loader decides
            # what to do with unexpected keys.
            continue
        target_len = state_dict_model[k].shape[0]
        if v2d.shape[0] != target_len:
            # F.interpolate works on (N, C, L): move the table length to the
            # last axis, resize, then move it back to the first axis.
            rel_pos_resized = F.interpolate(
                v2d.reshape(1, v2d.shape[0], -1).permute(0, 2, 1),
                size=target_len,
                mode='linear',
            )
            v3d = rel_pos_resized.reshape(-1, target_len).permute(1, 0)
            if verbose:
                print('Inflate {}: {} -> {}: {}'.format(
                    k, v2d.shape, k, v3d.shape))
        else:
            v3d = v2d
        state_dict_inflated[k] = v3d.clone()
    return state_dict_inflated
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Stochastic Depth per sample.

    Each sample (index along the first dimension) is either zeroed as a
    whole or kept and rescaled by ``1 / (1 - drop_prob)``.

    Args:
        x: Input tensor; the first dimension indexes the samples.
        drop_prob: Probability of dropping a sample. ``None`` and ``0`` both
            disable dropping.
        training: Dropping is only applied when this is true.

    Returns:
        ``x`` itself when nothing is dropped, otherwise a new tensor of the
        same shape.
    """
    # ``None`` is the default of DropPath(drop_prob=None); treat it like 0
    # instead of failing on ``1 - None`` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Everything is dropped. Dividing by keep_prob == 0 would turn the
        # output into NaN (inf * 0), so return explicit zeros.
        return torch.zeros_like(x)
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    mask = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize
    output = x.div(keep_prob) * mask
    return output
def get_rel_pos(rel_pos, d):
    """Return a relative-position table with ``d`` rows.

    Args:
        rel_pos: Table whose first dimension is the number of positions.
        d: Required number of positions. Any integer-like value is accepted
            (built-in ``int``, NumPy integer, ...).

    Returns:
        ``rel_pos`` itself when it already has ``d`` rows, otherwise a copy
        linearly interpolated along the first dimension to ``d`` rows.

    Raises:
        TypeError: If ``d`` is not integer-like.
    """
    import operator

    if not isinstance(d, int):
        # Previously any non-``int`` ``d`` fell through and returned None,
        # which only surfaced later as an obscure indexing error.
        d = operator.index(d)
    ori_d = rel_pos.shape[0]
    if ori_d == d:
        return rel_pos
    # Interpolate rel pos: F.interpolate expects (N, C, L), so the position
    # axis is moved last, resized and moved back to the front.
    new_pos_embed = F.interpolate(
        rel_pos.reshape(1, ori_d, -1).permute(0, 2, 1),
        size=d,
        mode='linear',
    )
    return new_pos_embed.reshape(-1, d).permute(1, 0)
def cal_rel_pos_temporal(attn, q, has_cls_embed, q_shape, k_shape, rel_pos_t):
    """
    Temporal Relative Positional Embeddings.

    Adds, in place, a bias that depends on the temporal offset between query
    and key positions to the non-class-token part of ``attn``.

    Args:
        attn: Attention logits; the last two axes index query and key tokens.
        q: Query tensor unpacked as ``(B, n_head, q_N, dim)``.
        has_cls_embed: Whether the first token is a class token to skip.
        q_shape: ``(t, h, w)`` grid of the query tokens.
        k_shape: ``(t, h, w)`` grid of the key tokens.
        rel_pos_t: Temporal relative-position table.

    Returns:
        The same ``attn`` tensor, modified in place.
    """
    offset = 1 if has_cls_embed else 0
    q_t, q_h, q_w = q_shape
    k_t, k_h, k_w = k_shape

    # Interpolate the table to the number of distinct offsets if needed.
    table = get_rel_pos(rel_pos_t, int(2 * max(q_t, k_t) - 1))

    # Scale up rel pos if the temporal lengths of q and k are different.
    q_scale = max(k_t / q_t, 1.0)
    k_scale = max(q_t / k_t, 1.0)
    q_pos = torch.arange(q_t)[:, None] * q_scale
    k_pos = torch.arange(k_t)[None, :] * k_scale
    dist_t = q_pos - k_pos
    dist_t += (k_t - 1) * k_scale  # shift so the smallest offset is index 0
    Rt = table[dist_t.long()]

    B, n_head, _, dim = q.shape
    n_rows = B * n_head * q_h * q_w

    # [B, H, q_t, q_h, q_w, dim] -> [q_t, B*H*q_h*q_w, dim]
    r_q = q[:, :, offset:].reshape(B, n_head, q_t, q_h, q_w, dim)
    r_q = r_q.permute(2, 0, 1, 3, 4, 5).reshape(q_t, n_rows, dim)

    # [q_t, rows, dim] @ [q_t, dim, k_t] -> [q_t, rows, k_t] -> [rows, q_t, k_t]
    rel = torch.matmul(r_q, Rt.transpose(1, 2)).transpose(0, 1)
    # [rows, q_t, k_t] -> [B, H, q_t, q_h, q_w, k_t]
    rel = rel.view(B, n_head, q_h, q_w, q_t, k_t).permute(0, 1, 4, 2, 3, 5)

    # Broadcast the temporal bias over the key's spatial axes.
    grid = attn[:, :, offset:, offset:].view(B, -1, q_t, q_h, q_w, k_t, k_h,
                                             k_w)
    biased = grid + rel[:, :, :, :, :, :, None, None]
    attn[:, :, offset:, offset:] = biased.view(B, -1, q_t * q_h * q_w,
                                               k_t * k_h * k_w)

    return attn
+ pool_first=False, + rel_pos_spatial=False, + rel_pos_temporal=False, + rel_pos_zero_init=False, + residual_pooling=True, + separate_qkv=False, + ): + super().__init__() + self.pool_first = pool_first + self.separate_qkv = separate_qkv + self.drop_rate = drop_rate + self.num_heads = num_heads + self.dim_out = dim_out + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + self.has_cls_embed = has_cls_embed + padding_q = [int(q // 2) for q in kernel_q] + padding_kv = [int(kv // 2) for kv in kernel_kv] + + if pool_first or separate_qkv: + self.q = nn.Linear(dim, dim_out, bias=qkv_bias) + self.k = nn.Linear(dim, dim_out, bias=qkv_bias) + self.v = nn.Linear(dim, dim_out, bias=qkv_bias) + else: + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + + self.proj = nn.Linear(dim_out, dim_out) + if drop_rate > 0.0: + self.proj_drop = nn.Dropout(drop_rate) + + # Skip pooling with kernel and stride size of (1, 1, 1). + if np.prod(kernel_q) == 1 and np.prod(stride_q) == 1: + kernel_q = () + if np.prod(kernel_kv) == 1 and np.prod(stride_kv) == 1: + kernel_kv = () + self.mode = mode + + if mode in ('avg', 'max'): + pool_op = nn.MaxPool3d if mode == 'max' else nn.AvgPool3d + self.pool_q = ( + pool_op(kernel_q, stride_q, padding_q, ceil_mode=False) + if len(kernel_q) > 0 else None) + self.pool_k = ( + pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False) + if len(kernel_kv) > 0 else None) + self.pool_v = ( + pool_op(kernel_kv, stride_kv, padding_kv, ceil_mode=False) + if len(kernel_kv) > 0 else None) + elif mode == 'conv' or mode == 'conv_unshared': + if pool_first: + dim_conv = dim // num_heads if mode == 'conv' else dim + else: + dim_conv = dim_out // num_heads if mode == 'conv' else dim_out + self.pool_q = ( + nn.Conv3d( + dim_conv, + dim_conv, + kernel_q, + stride=stride_q, + padding=padding_q, + groups=dim_conv, + bias=False, + ) if len(kernel_q) > 0 else None) + self.norm_q = norm_layer(dim_conv) if len(kernel_q) > 0 else None + self.pool_k = ( + 
nn.Conv3d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) if len(kernel_kv) > 0 else None) + self.norm_k = norm_layer(dim_conv) if len(kernel_kv) > 0 else None + self.pool_v = ( + nn.Conv3d( + dim_conv, + dim_conv, + kernel_kv, + stride=stride_kv, + padding=padding_kv, + groups=dim_conv, + bias=False, + ) if len(kernel_kv) > 0 else None) + self.norm_v = norm_layer(dim_conv) if len(kernel_kv) > 0 else None + else: + raise NotImplementedError(f'Unsupported model {mode}') + + self.rel_pos_spatial = rel_pos_spatial + self.rel_pos_temporal = rel_pos_temporal + if self.rel_pos_spatial: + assert input_size[1] == input_size[2] + size = input_size[1] + q_size = size // stride_q[1] if len(stride_q) > 0 else size + kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size + rel_sp_dim = 2 * max(q_size, kv_size) - 1 + + self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, head_dim)) + if not rel_pos_zero_init: + trunc_normal_(self.rel_pos_h, std=0.02) + trunc_normal_(self.rel_pos_w, std=0.02) + if self.rel_pos_temporal: + self.rel_pos_t = nn.Parameter( + torch.zeros(2 * input_size[0] - 1, head_dim)) + # if not rel_pos_zero_init: + # trunc_normal_(self.rel_pos_t, std=0.02) + + self.residual_pooling = residual_pooling + + def forward(self, x, thw_shape): + B, N, _ = x.shape + + if self.pool_first: + if self.mode == 'conv_unshared': + fold_dim = 1 + else: + fold_dim = self.num_heads + x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3) + q = k = v = x + else: + assert self.mode != 'conv_unshared' + if not self.separate_qkv: + qkv = ( + self.qkv(x).reshape(B, N, 3, self.num_heads, + -1).permute(2, 0, 3, 1, 4)) + q, k, v = qkv[0], qkv[1], qkv[2] + else: + q = k = v = x + q = ( + self.q(q).reshape(B, N, self.num_heads, + -1).permute(0, 2, 1, 3)) + k = ( + self.k(k).reshape(B, N, self.num_heads, + -1).permute(0, 2, 1, 3)) + v = ( + 
self.v(v).reshape(B, N, self.num_heads, + -1).permute(0, 2, 1, 3)) + + q, q_shape = attention_pool( + q, + self.pool_q, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_q if hasattr(self, 'norm_q') else None, + ) + k, k_shape = attention_pool( + k, + self.pool_k, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_k if hasattr(self, 'norm_k') else None, + ) + v, v_shape = attention_pool( + v, + self.pool_v, + thw_shape, + has_cls_embed=self.has_cls_embed, + norm=self.norm_v if hasattr(self, 'norm_v') else None, + ) + + if self.pool_first: + q_N = ( + np.prod(q_shape) + + 1 if self.has_cls_embed else np.prod(q_shape)) + k_N = ( + np.prod(k_shape) + + 1 if self.has_cls_embed else np.prod(k_shape)) + v_N = ( + np.prod(v_shape) + + 1 if self.has_cls_embed else np.prod(v_shape)) + + q = q.permute(0, 2, 1, 3).reshape(B, q_N, -1) + q = ( + self.q(q).reshape(B, q_N, self.num_heads, + -1).permute(0, 2, 1, 3)) + + v = v.permute(0, 2, 1, 3).reshape(B, v_N, -1) + v = ( + self.v(v).reshape(B, v_N, self.num_heads, + -1).permute(0, 2, 1, 3)) + + k = k.permute(0, 2, 1, 3).reshape(B, k_N, -1) + k = ( + self.k(k).reshape(B, k_N, self.num_heads, + -1).permute(0, 2, 1, 3)) + + N = q.shape[2] + attn = (q * self.scale) @ k.transpose(-2, -1) + if self.rel_pos_spatial: + attn = cal_rel_pos_spatial( + attn, + q, + k, + self.has_cls_embed, + q_shape, + k_shape, + self.rel_pos_h, + self.rel_pos_w, + ) + + if self.rel_pos_temporal: + attn = cal_rel_pos_temporal( + attn, + q, + self.has_cls_embed, + q_shape, + k_shape, + self.rel_pos_t, + ) + attn = attn.softmax(dim=-1) + + x = attn @ v + + if self.residual_pooling: + # Minor Difference + if self.has_cls_embed: + x[:, :, 1:, :] += q[:, :, 1:, :] + else: + x = x + q + + x = x.transpose(1, 2).reshape(B, -1, self.dim_out) + x = self.proj(x) + + if self.drop_rate > 0.0: + x = self.proj_drop(x) + return x, q_shape + + +class MultiScaleBlock(nn.Module): + + def __init__( + self, + dim, + dim_out, + num_heads, + 
input_size, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + drop_rate=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + up_rate=None, + kernel_q=(1, 1, 1), + kernel_kv=(1, 1, 1), + stride_q=(1, 1, 1), + stride_kv=(1, 1, 1), + mode='conv', + has_cls_embed=True, + pool_first=False, + rel_pos_spatial=False, + rel_pos_temporal=False, + rel_pos_zero_init=False, + residual_pooling=True, + dim_mul_in_att=False, + separate_qkv=False, + use_grad_checkpoint=False, + ): + super().__init__() + self.dim = dim + self.dim_out = dim_out + self.norm1 = norm_layer(dim) + self.dim_mul_in_att = dim_mul_in_att + kernel_skip = [s + 1 if s > 1 else s for s in stride_q] + stride_skip = stride_q + padding_skip = [int(skip // 2) for skip in kernel_skip] + att_dim = dim_out if dim_mul_in_att else dim + + self.use_grad_checkpoint = use_grad_checkpoint + + self.attn = MultiScaleAttention( + dim, + att_dim, + num_heads=num_heads, + input_size=input_size, + qkv_bias=qkv_bias, + drop_rate=drop_rate, + kernel_q=kernel_q, + kernel_kv=kernel_kv, + stride_q=stride_q, + stride_kv=stride_kv, + norm_layer=norm_layer, + has_cls_embed=has_cls_embed, + mode=mode, + pool_first=pool_first, + rel_pos_spatial=rel_pos_spatial, + rel_pos_temporal=rel_pos_temporal, + rel_pos_zero_init=rel_pos_zero_init, + residual_pooling=residual_pooling, + separate_qkv=separate_qkv, + ) + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else nn.Identity()) + self.norm2 = norm_layer(att_dim) + mlp_hidden_dim = int(att_dim * mlp_ratio) + self.has_cls_embed = has_cls_embed + # TODO: check the use case for up_rate, and merge the following lines + if up_rate is not None and up_rate > 1: + mlp_dim_out = dim * up_rate + else: + mlp_dim_out = dim_out + self.mlp = Mlp( + in_features=att_dim, + hidden_features=mlp_hidden_dim, + out_features=mlp_dim_out, + act_layer=act_layer, + drop_rate=drop_rate, + ) + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + self.pool_skip = ( + nn.MaxPool3d( 
class MViTv2(nn.Module):
    """
    MViTv2 video backbone: a patch embedding, a stack of MultiScaleBlock
    layers and a final LayerNorm. The head is an identity, so ``forward``
    returns the normalized token features.

    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526
    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
        Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227
    """

    def __init__(
        self,
        img_size=224,
        embed_dim=96,
        num_classes=1000,
        num_frames=4,
        num_heads=1,
        depth=24,
        patch_kernel=[3, 7, 7],
        patch_stride=[2, 4, 4],
        patch_padding=[1, 3, 3],
        config=None,
        dropout_rate=0.,
        drop_path_rate=0.,
        mlp_ratio=4.,
        qkv_bias=True,
        mode='conv',
        cls_embed_on=True,
        use_abs_pos=False,
        rel_pos_spatial=True,
        rel_pos_temporal=True,
        rel_pos_zero_init=False,
        residual_pooling=True,
        dim_mul_in_att=True,
        pool_first=False,
        zero_decay_pos_cls=False,
        separate_qkv=False,
        norm_stem=False,
        sep_pos_embed=False,
        use_grad_checkpoint=True,
    ):
        """Build the backbone.

        ``config`` is required; it is passed to ``_prepare_mvit_configs`` to
        obtain the per-block dimension/head multipliers and the pooling
        kernels/strides. The remaining arguments are forwarded to
        ``PatchEmbed`` / ``MultiScaleBlock`` or stored as flags on ``self``.
        The list defaults are only read, never mutated.
        """
        super().__init__()
        # Prepare input.
        in_chans = 3
        self.img_size = img_size
        # Prepare output.
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        # MViT params.
        self.num_heads = num_heads
        self.depth = depth
        self.cls_embed_on = cls_embed_on
        self.use_abs_pos = use_abs_pos
        self.zero_decay_pos_cls = zero_decay_pos_cls
        self.use_grad_checkpoint = use_grad_checkpoint
        self.sep_pos_embed = sep_pos_embed
        self.drop_rate = dropout_rate
        # Stored because no_weight_decay() reads them.
        self.rel_pos_spatial = rel_pos_spatial
        self.rel_pos_temporal = rel_pos_temporal

        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=patch_kernel,
            stride=patch_stride,
            padding=patch_padding,
        )
        if use_grad_checkpoint:
            patch_embed = checkpoint_wrapper(patch_embed)
        self.patch_embed = patch_embed

        # Token grid after patch embedding; kept on self because
        # _get_pos_embed() and forward_features() read it.
        self.patch_dims = [
            num_frames // patch_stride[0],
            img_size // patch_stride[1],
            img_size // patch_stride[2],
        ]
        num_patches = int(np.prod(self.patch_dims))

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
               ]  # stochastic depth decay rule

        if self.cls_embed_on:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            pos_embed_dim = num_patches + 1
        else:
            pos_embed_dim = num_patches

        # Absolute position embeddings: either one joint table or separate
        # spatial / temporal (/ class) tables -- never both.
        if self.use_abs_pos:
            if self.sep_pos_embed:
                self.pos_embed_spatial = nn.Parameter(
                    torch.zeros(1, self.patch_dims[1] * self.patch_dims[2],
                                embed_dim))
                self.pos_embed_temporal = nn.Parameter(
                    torch.zeros(1, self.patch_dims[0], embed_dim))
                if self.cls_embed_on:
                    self.pos_embed_class = nn.Parameter(
                        torch.zeros(1, 1, embed_dim))
            else:
                self.pos_embed = nn.Parameter(
                    torch.zeros(1, pos_embed_dim, embed_dim))

        # Dropout applied to the embedded tokens in forward_features().
        if self.drop_rate > 0.0:
            self.pos_drop = nn.Dropout(p=self.drop_rate)

        assert config is not None
        # MViT backbone configs
        dim_mul, head_mul, pool_q, pool_kv, stride_q, stride_kv = _prepare_mvit_configs(
            config)
        # Copy so later bookkeeping can never alias self.patch_dims.
        input_size = list(self.patch_dims)

        self.norm_stem = norm_layer(embed_dim) if norm_stem else None

        self.blocks = nn.ModuleList()
        for i in range(depth):
            num_heads = round_width(num_heads, head_mul[i])
            if dim_mul_in_att:
                dim_out = round_width(
                    embed_dim,
                    dim_mul[i],
                    divisor=round_width(num_heads, head_mul[i]),
                )
            else:
                dim_out = round_width(
                    embed_dim,
                    dim_mul[i + 1],
                    divisor=round_width(num_heads, head_mul[i + 1]),
                )
            attention_block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                input_size=input_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_rate=self.drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                kernel_q=pool_q[i] if len(pool_q) > i else [],
                kernel_kv=pool_kv[i] if len(pool_kv) > i else [],
                stride_q=stride_q[i] if len(stride_q) > i else [],
                stride_kv=stride_kv[i] if len(stride_kv) > i else [],
                mode=mode,
                has_cls_embed=self.cls_embed_on,
                pool_first=pool_first,
                rel_pos_spatial=rel_pos_spatial,
                rel_pos_temporal=rel_pos_temporal,
                rel_pos_zero_init=rel_pos_zero_init,
                residual_pooling=residual_pooling,
                dim_mul_in_att=dim_mul_in_att,
                separate_qkv=separate_qkv,
                use_grad_checkpoint=False)
            if use_grad_checkpoint:
                attention_block = checkpoint_wrapper(
                    attention_block, offload_to_cpu=False)
            self.blocks.append(attention_block)

            # Query pooling shrinks the token grid seen by the next block.
            if len(stride_q[i]) > 0:
                input_size = [
                    size // stride
                    for size, stride in zip(input_size, stride_q[i])
                ]
            embed_dim = dim_out

        self.norm = norm_layer(embed_dim)

        self.head = nn.Identity()

        if self.use_abs_pos:
            if self.sep_pos_embed:
                trunc_normal_(self.pos_embed_spatial, std=0.02)
                trunc_normal_(self.pos_embed_temporal, std=0.02)
                if self.cls_embed_on:
                    trunc_normal_(self.pos_embed_class, std=0.02)
            else:
                trunc_normal_(self.pos_embed, std=0.02)
        if self.cls_embed_on:
            trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear (truncated normal / zero bias) and LayerNorm."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names to exclude from weight decay.

        The list is empty unless ``zero_decay_pos_cls`` is set; every entry
        is a plain string.
        """
        names = []
        if self.zero_decay_pos_cls:
            if self.use_abs_pos:
                if self.sep_pos_embed:
                    names.extend([
                        'pos_embed_spatial',
                        'pos_embed_temporal',
                        'pos_embed_class',
                    ])
                else:
                    names.append('pos_embed')
            if self.rel_pos_spatial:
                names.extend(['rel_pos_h', 'rel_pos_w', 'rel_pos_hw'])
            if self.rel_pos_temporal:
                names.extend(['rel_pos_t'])
            if self.cls_embed_on:
                names.append('cls_token')

        return names

    def _get_pos_embed(self, pos_embed, bcthw):
        """Return ``pos_embed`` resized to the (t, h, w) grid in ``bcthw``.

        The last three entries of ``bcthw`` are the target grid. When they
        differ from ``self.patch_dims`` the patch part of the table is
        trilinearly interpolated; a leading class-token row (if enabled) is
        split off first and re-attached unchanged.
        """
        t, h, w = bcthw[-3], bcthw[-2], bcthw[-1]
        if self.cls_embed_on:
            cls_pos_embed = pos_embed[:, 0:1, :]
            pos_embed = pos_embed[:, 1:]
        txy_num = pos_embed.shape[1]
        p_t, p_h, p_w = self.patch_dims
        assert p_t * p_h * p_w == txy_num

        if (p_t, p_h, p_w) != (t, h, w):
            new_pos_embed = F.interpolate(
                pos_embed[:, :, :].reshape(1, p_t, p_h, p_w,
                                           -1).permute(0, 4, 1, 2, 3),
                size=(t, h, w),
                mode='trilinear',
            )
            pos_embed = new_pos_embed.reshape(1, -1,
                                              t * h * w).permute(0, 2, 1)

        if self.cls_embed_on:
            pos_embed = torch.cat((cls_pos_embed, pos_embed), dim=1)

        return pos_embed

    def forward_features(self, x):
        """Embed ``x`` into tokens, run all blocks and apply the final norm.

        ``x`` is a 5-D tensor whose dims 1 and 2 are swapped before patch
        embedding (presumably (B, T, C, H, W) -> (B, C, T, H, W); confirm
        against the caller).
        """
        x = x.permute(0, 2, 1, 3, 4)
        x, bcthw = self.patch_embed(x)

        T, H, W = bcthw[-3], bcthw[-2], bcthw[-1]
        B = x.shape[0]

        if self.cls_embed_on:
            cls_tokens = self.cls_token.expand(
                B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
            x = torch.cat((cls_tokens, x), dim=1)

        if self.use_abs_pos:
            if self.sep_pos_embed:
                pos_embed = self.pos_embed_spatial.repeat(
                    1, self.patch_dims[0], 1) + torch.repeat_interleave(
                        self.pos_embed_temporal,
                        self.patch_dims[1] * self.patch_dims[2],
                        dim=1)
                if self.cls_embed_on:
                    pos_embed = torch.cat([self.pos_embed_class, pos_embed], 1)
                pos_embed = self._get_pos_embed(pos_embed, bcthw)
                x = x + pos_embed
            else:
                pos_embed = self._get_pos_embed(self.pos_embed, bcthw)
                x = x + pos_embed

        if self.drop_rate > 0.0:
            x = self.pos_drop(x)

        if self.norm_stem is not None:
            x = self.norm_stem(x)

        thw = [T, H, W]
        for blk in self.blocks:
            x, thw = blk(x, thw)

        x = self.norm(x)
        return x

    def forward(self, x):
        """Return ``head(forward_features(x))``; the head is an identity."""
        x = self.forward_features(x)
        x = self.head(x)
        return x
def forward(self, input: Dict[str, Tensor]) -> Dict[str, Tensor]:
    """Run the HiTeA model for inference, training or evaluation.

    Args:
        input (Dict[str, Tensor]): the preprocessed data. In inference it
            carries ``video`` and ``question``; otherwise it carries
            ``video``, ``answer_input_ids``/``answer_attention_mask`` and
            either ``question_input_ids``/``question_attention_mask`` or
            ``index``.

    Returns:
        Dict[str, Tensor]:
            - inference: ``{OutputKeys.CAPTION: text}`` when the configured
              task is video captioning, else ``{OutputKeys.TEXT: text}``;
            - training: ``{OutputKeys.LOSS: output}``;
            - evaluation: ``{'sequences': [...]}`` with the first candidate
              of every entry in the model's first output.
    """

    # inference
    if not self.training and 'question' in input:
        # The task only selects the output key, so the configuration file
        # is read here rather than on every training/evaluation step.
        task = Config.from_file(
            osp.join(self.model_dir, ModelFile.CONFIGURATION)).task
        output = self.model(input['video'], input['question'], train=False)
        topk_ids, _ = output
        pred_string = self.tokenizer.decode(
            topk_ids[0][0], skip_special_tokens=True)
        output_key = OutputKeys.CAPTION \
            if task == Tasks.video_captioning else OutputKeys.TEXT
        return {output_key: pred_string}

    # train and evaluate
    import addict
    video = input['video']
    answer = addict.Dict(
        input_ids=input['answer_input_ids'],
        attention_mask=input['answer_attention_mask'])
    if 'index' not in input:
        question = addict.Dict(
            input_ids=input['question_input_ids'],
            attention_mask=input['question_attention_mask'])
        output = self.model(video, question, answer, train=self.training)
    else:
        index = input['index']
        output = self.model(video, answer, index, train=self.training)
    if self.training:
        return {OutputKeys.LOSS: output}

    # evaluate
    topk_ids, _ = output
    return {'sequences': [list_tensor[0] for list_tensor in topk_ids]}
# } Tasks.image_captioning: [OutputKeys.CAPTION], + + # video caption result for single sample + # { + # "caption": "this is an video caption text." + # } + Tasks.video_captioning: [OutputKeys.CAPTION], Tasks.ocr_recognition: [OutputKeys.TEXT], # visual grounding result for single sample @@ -769,6 +775,10 @@ TASK_OUTPUTS = { # {"text": "this is a text answser. "} Tasks.visual_question_answering: [OutputKeys.TEXT], + # VideoQA result for a sample + # {"text": "this is a text answser. "} + Tasks.video_question_answering: [OutputKeys.TEXT], + # auto_speech_recognition result for a single sample # { # "text": "每天都要快乐喔" diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index adb9fe23..57ca561c 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -1,5 +1,6 @@ # Copyright (c) Alibaba, Inc. and its affiliates. +import cv2 import numpy as np from PIL import Image @@ -222,6 +223,9 @@ TASK_INPUTS = { Tasks.image_captioning: [InputType.IMAGE, { 'image': InputType.IMAGE, }], + Tasks.video_captioning: [InputType.VIDEO, { + 'video': InputType.VIDEO, + }], Tasks.visual_grounding: { 'image': InputType.IMAGE, 'text': InputType.TEXT @@ -245,6 +249,10 @@ TASK_INPUTS = { 'image': InputType.IMAGE, 'text': InputType.TEXT }, + Tasks.video_question_answering: { + 'video': InputType.VIDEO, + 'text': InputType.TEXT + }, Tasks.visual_entailment: { 'image': InputType.IMAGE, 'text': InputType.TEXT, diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index 01e82156..66b68a83 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -80,6 +80,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/nlp_bart_text-error-correction_chinese'), Tasks.image_captioning: (Pipelines.image_captioning, 'damo/ofa_image-caption_coco_large_en'), + Tasks.video_captioning: + (Pipelines.video_captioning, + 'damo/multi-modal_hitea_video-captioning_base_en'), Tasks.image_portrait_stylization: 
(Pipelines.person_image_cartoon, 'damo/cv_unet_person-image-cartoon_compound-models'), @@ -114,6 +117,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.visual_question_answering: (Pipelines.visual_question_answering, 'damo/mplug_visual-question-answering_coco_large_en'), + Tasks.video_question_answering: + (Pipelines.video_question_answering, + 'damo/multi-modal_hitea_video-question-answering_base_en'), Tasks.video_embedding: (Pipelines.cmdssl_video_embedding, 'damo/cv_r2p1d_video_embedding'), Tasks.text_to_image_synthesis: diff --git a/modelscope/pipelines/multi_modal/__init__.py b/modelscope/pipelines/multi_modal/__init__.py index d5c171a3..b16eb360 100644 --- a/modelscope/pipelines/multi_modal/__init__.py +++ b/modelscope/pipelines/multi_modal/__init__.py @@ -14,7 +14,8 @@ if TYPE_CHECKING: VideoMultiModalEmbeddingPipeline from .visual_question_answering_pipeline import VisualQuestionAnsweringPipeline from .asr_pipeline import AutomaticSpeechRecognitionPipeline - + from .video_captioning_pipeline import VideoCaptioningPipeline + from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline else: _import_structure = { 'image_captioning_pipeline': ['ImageCaptioningPipeline'], @@ -29,6 +30,9 @@ else: 'generative_multi_modal_embedding_pipeline': ['GEMMMultiModalEmbeddingPipeline'], 'asr_pipeline': ['AutomaticSpeechRecognitionPipeline'], + 'video_captioning_pipeline': ['VideoCaptioningPipeline'], + 'video_question_answering_pipeline': + ['VideoQuestionAnsweringPipeline'] } import sys diff --git a/modelscope/pipelines/multi_modal/video_captioning_pipeline.py b/modelscope/pipelines/multi_modal/video_captioning_pipeline.py new file mode 100644 index 00000000..e13e1ae5 --- /dev/null +++ b/modelscope/pipelines/multi_modal/video_captioning_pipeline.py @@ -0,0 +1,56 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
@PIPELINES.register_module(
    Tasks.video_captioning, module_name=Pipelines.video_captioning)
class VideoCaptioningPipeline(Pipeline):
    """Pipeline that produces a caption for a video."""

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        """
        use `model` and `preprocessor` to create a video captioning pipeline for prediction
        Args:
            model: model id on modelscope hub, or a model instance.
            preprocessor: optional preprocessor; when omitted and the model
                is a `HiTeAForAllTasks`, a `HiTeAPreprocessor` is built from
                the model directory.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model.eval()
        if preprocessor is None:
            if isinstance(self.model, HiTeAForAllTasks):
                self.preprocessor = HiTeAPreprocessor(self.model.model_dir)

    def _batch(self, data):
        """Merge a list of preprocessed samples into one batch.

        For HiTeA models the ``video`` tensors and every field of the
        tokenized ``question`` are concatenated with ``torch.cat``. Other
        models fall back to the base-class collation.
        """
        if isinstance(self.model, HiTeAForAllTasks):
            from transformers.tokenization_utils_base import BatchEncoding
            batch_data = {}
            # HiTeAPreprocessor does not emit a 'train' flag; forward it
            # only when a custom preprocessor supplied one.
            if 'train' in data[0]:
                batch_data['train'] = data[0]['train']
            batch_data['video'] = torch.cat([d['video'] for d in data])
            question = {
                k: torch.cat([d['question'][k] for d in data])
                for k in data[0]['question'].keys()
            }
            batch_data['question'] = BatchEncoding(question)
            return batch_data
        else:
            return super()._collate_batch(data)

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the base-class forward with gradient tracking disabled."""
        with torch.no_grad():
            return super().forward(inputs, **forward_params)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Return the model output unchanged."""
        return inputs
@PIPELINES.register_module(
    Tasks.video_question_answering,
    module_name=Pipelines.video_question_answering)
class VideoQuestionAnsweringPipeline(Pipeline):
    """Pipeline that answers a text question about a video."""

    def __init__(self,
                 model: Union[Model, str],
                 preprocessor: Optional[Preprocessor] = None,
                 **kwargs):
        """Create a video question answering pipeline for prediction.

        Args:
            model (Union[Model, str]): a model instance or a model id.
            preprocessor (Optional[Preprocessor]): a preprocessor instance;
                when omitted and the model is a `HiTeAForAllTasks`, a
                `HiTeAPreprocessor` is built from the model directory.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        needs_default = preprocessor is None
        if needs_default and isinstance(self.model, HiTeAForAllTasks):
            self.preprocessor = HiTeAPreprocessor(self.model.model_dir)
        self.model.eval()

    def forward(self, inputs: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        """Run the base-class forward with gradient tracking disabled."""
        with torch.no_grad():
            outputs = super().forward(inputs, **forward_params)
        return outputs

    def postprocess(self, inputs: Dict[str, Tensor],
                    **postprocess_params) -> Dict[str, str]:
        """process the prediction results

        Args:
            inputs (Dict[str, Any]): the output of `forward`

        Returns:
            Dict[str, str]: the prediction results, returned unchanged
        """
        return inputs
a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index c538a580..ab658cc4 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -15,7 +15,8 @@ if TYPE_CHECKING: ImageDenoisePreprocessor) from .kws import WavToLists from .tts import KanttsDataPreprocessor - from .multi_modal import (OfaPreprocessor, MPlugPreprocessor) + from .multi_modal import (OfaPreprocessor, MPlugPreprocessor, + HiTeAPreprocessor) from .nlp import ( DocumentSegmentationTransformersPreprocessor, FaqQuestionAnsweringTransformersPreprocessor, @@ -52,7 +53,8 @@ else: ], 'kws': ['WavToLists'], 'tts': ['KanttsDataPreprocessor'], - 'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'], + 'multi_modal': + ['OfaPreprocessor', 'MPlugPreprocessor', 'HiTeAPreprocessor'], 'nlp': [ 'DocumentSegmentationTransformersPreprocessor', 'FaqQuestionAnsweringTransformersPreprocessor', diff --git a/modelscope/preprocessors/multi_modal.py b/modelscope/preprocessors/multi_modal.py index 6d326df3..85ef4cdd 100644 --- a/modelscope/preprocessors/multi_modal.py +++ b/modelscope/preprocessors/multi_modal.py @@ -3,7 +3,9 @@ import os.path as osp from io import BytesIO from typing import Any, Dict, List, Tuple, Union +import decord import json +import numpy as np import torch from PIL import Image from timm.data import create_transform @@ -12,6 +14,8 @@ from torchvision.transforms import Compose, Normalize, Resize, ToTensor from modelscope.hub.snapshot_download import snapshot_download from modelscope.metainfo import Preprocessors from modelscope.pipelines.base import Input +from modelscope.pipelines.cv.cmdssl_video_embedding_pipeline import ( + VCenterCrop, VCompose, VNormalize, VRescale, VToTensor) from modelscope.preprocessors import load_image from modelscope.utils.config import Config from modelscope.utils.constant import (Fields, Invoke, ModeKeys, ModelFile, @@ -22,10 +26,7 @@ from .ofa import * # noqa from .ofa.utils.collate import collate_fn from 
@PREPROCESSORS.register_module(
    Fields.multi_modal, module_name=Preprocessors.hitea_tasks_preprocessor)
class HiTeAPreprocessor(Preprocessor):
    """Turn a video (and optional question/answer text) into HiTeA inputs.

    The tokenizer, the frame transform and the frame count are created
    lazily from files in ``model_dir`` on first use.
    """

    def __init__(self,
                 model_dir: str,
                 mode: str = ModeKeys.INFERENCE,
                 tokenizer_max_length: int = 25,
                 *args,
                 **kwargs):
        """
        Args:
            model_dir: directory holding the tokenizer and config files.
            mode: ``ModeKeys.INFERENCE`` selects the inference output
                layout in ``__call__``; any other value selects the
                question/answer layout.
            tokenizer_max_length: padding/truncation length for all text.
        """
        super().__init__(*args, **kwargs)
        self.model_dir = model_dir
        self.mode = mode
        self.tokenizer_max_length = tokenizer_max_length

        self._tokenizer = None
        self._patch_resize_transform = None
        self._num_frames = None
        # path -> (VideoReader, insertion index); readers stay cached for
        # the lifetime of the preprocessor.
        self._video_map = {}

    @property
    def tokenizer(self):
        """BertTokenizer loaded from ``model_dir`` on first access."""
        from transformers import BertTokenizer

        if self._tokenizer is None:
            self._tokenizer = BertTokenizer.from_pretrained(self.model_dir)
        return self._tokenizer

    @property
    def patch_resize_transform(self):
        """Per-frame transform: bicubic resize to ``image_res``, to tensor, normalize."""
        if self._patch_resize_transform is None:
            from torchvision import transforms
            from modelscope.models.multi_modal.mplug import CONFIG_NAME, HiTeAConfig

            config = HiTeAConfig.from_yaml_file(
                osp.join(self.model_dir, CONFIG_NAME))

            mean = (0.48145466, 0.4578275, 0.40821073)
            std = (0.26862954, 0.26130258, 0.27577711)

            self._patch_resize_transform = transforms.Compose([
                transforms.Resize((config.image_res, config.image_res),
                                  interpolation=Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Normalize(mean=mean, std=std),
            ])
        return self._patch_resize_transform

    @property
    def num_frames(self):
        """Number of frames to sample, read from the model config on first access."""
        if self._num_frames is None:
            from modelscope.models.multi_modal.mplug import CONFIG_NAME, HiTeAConfig

            config = HiTeAConfig.from_yaml_file(
                osp.join(self.model_dir, CONFIG_NAME))

            self._num_frames = config.num_frames
        return self._num_frames

    def video_open(self, path: str) -> Tuple[decord.VideoReader, int]:
        """Return the cached ``(reader, index)`` for ``path``, opening it once."""
        if path not in self._video_map:
            index = len(self._video_map)
            vr = decord.VideoReader(path, ctx=decord.cpu(0))
            self._video_map[path] = (vr, index)
        return self._video_map[path]

    def sample_frames(self, num_frames: int, vlen: int) -> List[int]:
        """Pick ``num_frames`` frame indices spread evenly over ``vlen`` frames.

        The video is split into ``min(num_frames, vlen)`` intervals and the
        middle frame of each is taken. When the video is shorter than
        ``num_frames`` the last index is repeated as padding.

        Raises:
            ValueError: if frames are requested from a video without frames.
        """
        acc_samples = min(num_frames, vlen)
        # split the video into `acc_samples` intervals, and sample from each interval.
        intervals = np.linspace(
            start=0, stop=vlen, num=acc_samples + 1).astype(int)
        # Middle of [start, next_start - 1] for every interval, as plain ints.
        frame_indices = [
            int((start + stop - 1) // 2)
            for start, stop in zip(intervals[:-1], intervals[1:])
        ]

        if len(frame_indices) < num_frames:  # padded with last frame
            if not frame_indices:
                raise ValueError(
                    f'cannot sample {num_frames} frames from a video '
                    f'with {vlen} frames')
            padding = [frame_indices[-1]] * (num_frames - len(frame_indices))
            frame_indices = frame_indices + padding
        return frame_indices

    def __call__(
            self, data: Union[decord.VideoReader, tuple,
                              Dict[str, Any]]) -> Dict[str, Any]:
        """Preprocess one sample.

        ``data`` may be a video path or ``decord.VideoReader``, a tuple
        ``(video, question)``, or a mapping with ``video`` and ``text`` (or
        ``question``), plus ``answer`` outside inference mode. For the
        video-captioning task the question is the empty string.

        Returns:
            In inference mode ``{'video', 'question'}`` with a leading
            batch dimension added to the video; otherwise the video plus
            the squeezed question/answer ids and attention masks.
        """
        self.cfg = Config.from_file(
            osp.join(self.model_dir, ModelFile.CONFIGURATION))

        if isinstance(data, (decord.VideoReader, str)):
            video = data
        elif isinstance(data, tuple):
            video = data[0]
        else:
            video = data['video']
        if isinstance(video, str):
            video, _ = self.video_open(video)
        frame_indices = self.sample_frames(self.num_frames, len(video))
        video.seek(0)
        frames = video.get_batch(frame_indices).asnumpy()
        video = torch.stack(
            [self.patch_resize_transform(Image.fromarray(f)) for f in frames],
            dim=0)

        if self.cfg.task == Tasks.video_captioning:
            question = ''
        elif isinstance(data, tuple):
            question = data[1]
        elif isinstance(data, (str, decord.VideoReader)):
            # A bare video carries no question text.
            raise TypeError(
                'a question is required for this task: pass a '
                '(video, question) tuple or a dict with "video" and "text"')
        else:
            question = data['text' if 'text' in data else 'question']
        question = self.tokenizer(
            question.lower(),
            padding='max_length',
            truncation=True,
            max_length=self.tokenizer_max_length,
            return_tensors='pt')

        if self.mode == ModeKeys.INFERENCE:
            video = torch.stack([video], dim=0)
            return {'video': video, 'question': question}
        else:
            answer = data['answer']
            answer = self.tokenizer(
                answer,
                padding='max_length',
                truncation=True,
                max_length=self.tokenizer_max_length,
                return_tensors='pt')
            output = {
                'video': video,
                'question_input_ids': question.input_ids.squeeze(),
                'question_attention_mask': question.attention_mask.squeeze(),
                'answer_input_ids': answer.input_ids.squeeze(),
                'answer_attention_mask': answer.attention_mask.squeeze(),
            }
            return output
class HiTeATasksTest(unittest.TestCase, DemoCompatibilityCheck):
    """Smoke tests for the HiTeA video captioning and video QA pipelines."""

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_video_captioning_with_model(self):
        """Caption a video with a pipeline built from a loaded model."""
        hitea = Model.from_pretrained(
            'damo/multi-modal_hitea_video-captioning_base_en')
        captioner = pipeline(
            task=Tasks.video_captioning,
            model=hitea,
        )
        output = captioner('data/test/videos/video_caption_and_qa_test.mp4')
        print(output[OutputKeys.CAPTION])

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_video_captioning_with_name(self):
        """Caption a video with a pipeline built from a model id."""
        model_id = 'damo/multi-modal_hitea_video-captioning_base_en'
        captioner = pipeline(
            Tasks.video_captioning,
            model=model_id,
        )
        output = captioner('data/test/videos/video_caption_and_qa_test.mp4')
        print(output[OutputKeys.CAPTION])

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_video_question_answering_with_model(self):
        """Answer a question with a pipeline built from a loaded model."""
        hitea = Model.from_pretrained(
            'damo/multi-modal_hitea_video-question-answering_base_en')
        answerer = pipeline(Tasks.video_question_answering, model=hitea)
        sample = {
            'video': 'data/test/videos/video_caption_and_qa_test.mp4',
            'text': 'How many people are there?',
        }
        print(answerer(sample))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_video_question_answering_with_name(self):
        """Answer a question with a pipeline built from a model id."""
        model_id = 'damo/multi-modal_hitea_video-question-answering_base_en'
        answerer = pipeline(Tasks.video_question_answering, model=model_id)
        sample = {
            'video': 'data/test/videos/video_caption_and_qa_test.mp4',
            'text': 'Who teaches a girl how to paint eggs?',
        }
        print(answerer(sample))

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        """Run the shared demo compatibility check."""
        self.compatibility_check()