Add SOONet for video temporal grounding

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11810444
Author: yanwen.pyl
Date: 2023-03-10 09:02:39 +08:00
Committed by: wenmeng.zwm
Parent: 8a19e9645d
Commit: fc7daea9c2
13 changed files with 1909 additions and 0 deletions
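For quick reference, the new pipeline can be invoked as in the unit test added by this commit (model id, input tuple and bundled test video name are taken from that test; the output keys follow the new TASK_OUTPUTS entry):

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# text query + name of a video bundled in the model repo
grounding = pipeline(Tasks.video_temporal_grounding,
                     'damo/multi-modal_soonet_video-temporal-grounding')
result = grounding(('a man takes food out of the refrigerator.',
                    'soonet_video_temporal_grounding_test_video.mp4'))
print(result['scores'], result['tbounds'])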

View File

@@ -195,6 +195,7 @@ class Models(object):
mgeo = 'mgeo'
vldoc = 'vldoc'
hitea = 'hitea'
soonet = 'soonet'
# science models
unifold = 'unifold'
@@ -497,6 +498,7 @@ class Pipelines(object):
text_to_video_synthesis = 'latent-text-to-video-synthesis' # latent-text-to-video-synthesis
gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
soonet_video_temporal_grounding = 'soonet-video-temporal-grounding'
# science tasks
protein_structure = 'unifold-protein-structure'

View File

@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .tokenizer import SimpleTokenizer
from .model import SOONet
from .utils import decode_video
from .clip import load_clip
else:
_import_structure = {
'model': ['SOONet'],
'tokenizer': ['SimpleTokenizer'],
'utils': ['decode_video'],
'clip': ['load_clip']
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,287 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
class Q2VRankerStage1(nn.Module):
"""
Used to calculate the qv_ctx_score with the query embedding and the multi-scale anchor context embeddings as input.
The qv_ctx_score is used to pre-rank anchors and retain the top-k most related ones.
"""
def __init__(self, nscales, hidden_dim):
super().__init__()
self.fc = nn.Linear(hidden_dim, hidden_dim)
self.nscales = nscales
def forward(self, ctx_feats, qfeat):
qfeat = self.fc(qfeat)
qv_ctx_scores = list()
for i in range(self.nscales):
score = torch.einsum('bld,bd->bl',
F.normalize(ctx_feats[i], p=2, dim=2),
F.normalize(qfeat, p=2, dim=1))
qv_ctx_scores.append(score)
return qv_ctx_scores
class V2QRankerStage1(nn.Module):
"""
Used to calculate the vq_ctx_score with the anchor context embeddings and multiple query embeddings as input.
"""
def __init__(self, nscales, hidden_dim):
super().__init__()
self.fc = nn.Linear(hidden_dim, hidden_dim)
self.nscales = nscales
def forward(self, ctx_feats, qfeat):
vq_ctx_scores = list()
for i in range(self.nscales):
score = torch.einsum(
'bld,bd->bl', F.normalize(self.fc(ctx_feats[i]), p=2, dim=2),
F.normalize(qfeat, p=2, dim=1))
vq_ctx_scores.append(score)
return vq_ctx_scores
class Q2VRankerStage2(nn.Module):
"""
Used to calculate the qv_ctn_score with query embedding and video sequence embedding as input.
The qv_ctn_score is used to re-rank anchors.
"""
def __init__(self, nscales, hidden_dim, snippet_length=10):
super().__init__()
self.nscales = nscales
self.snippet_length = snippet_length
self.qfc = nn.Linear(hidden_dim, hidden_dim)
self.encoder = V2VAttention()
def forward(self, vfeats, qfeat, hit_indices, qv_ctx_scores):
qfeat = self.qfc(qfeat)
qv_ctn_scores = list()
qv_merge_scores = list()
_, L, D = vfeats.size()
ctn_feats = list()
for i in range(self.nscales):
anchor_length = self.snippet_length * 2**i
assert L // anchor_length == qv_ctx_scores[i].size(1)
qv_ctx_score = torch.index_select(qv_ctx_scores[i], 1,
hit_indices[i])
ctn_feat = vfeats.view(L // anchor_length, anchor_length,
D).detach()
ctn_feat = torch.index_select(ctn_feat, 0, hit_indices[i])
ctn_feat = self.encoder(
ctn_feat,
torch.ones(ctn_feat.size()[:2], device=ctn_feat.device))
ctn_feats.append(ctn_feat)
qv_ctn_score = torch.einsum(
'bkld,bd->bkl', F.normalize(ctn_feat.unsqueeze(0), p=2, dim=3),
F.normalize(qfeat, p=2, dim=1))
qv_ctn_score, _ = torch.max(qv_ctn_score, dim=2)
qv_ctn_scores.append(qv_ctn_score)
qv_merge_scores.append(qv_ctx_score + qv_ctn_score)
return qv_merge_scores, qv_ctn_scores, ctn_feats
class V2QRankerStage2(nn.Module):
"""
Used to calculate the vq_ctn_score with the anchor content embeddings and multiple query embeddings as input.
"""
def __init__(self, nscales, hidden_dim):
super().__init__()
self.fc = nn.Linear(hidden_dim, hidden_dim)
self.nscales = nscales
def forward(self, ctn_feats, qfeat):
vq_ctn_scores = list()
for i in range(self.nscales):
score = torch.einsum(
'bkld,bd->bkl',
F.normalize(self.fc(ctn_feats[i]).unsqueeze(0), p=2, dim=3),
F.normalize(qfeat, p=2, dim=1))
score = torch.mean(score, dim=2)
vq_ctn_scores.append(score)
return vq_ctn_scores
class V2VAttention(nn.Module):
"""
Self-attention encoder for anchor frame sequence to encode intra-anchor knowledge.
"""
def __init__(self):
super().__init__()
self.posemb = PositionEncoding(max_len=400, dim=512, dropout=0.0)
self.encoder = MultiHeadAttention(dim=512, n_heads=8, dropout=0.1)
self.dropout = nn.Dropout(0.0)
def forward(self, video_feats, video_masks):
mask = torch.einsum('bm,bn->bmn', video_masks,
video_masks).unsqueeze(1)
residual = video_feats
video_feats = video_feats + self.posemb(video_feats)
out = self.encoder(
query=video_feats, key=video_feats, value=video_feats, mask=mask)
video_feats = self.dropout(residual
+ out) * video_masks.unsqueeze(2).float()
return video_feats
class BboxRegressor(nn.Module):
"""
Predict the offset of bounding box for each candidate anchor.
"""
def __init__(self, hidden_dim, enable_stage2=False):
super().__init__()
self.fc_ctx = nn.Linear(hidden_dim, hidden_dim)
self.fc_q = nn.Linear(hidden_dim, hidden_dim)
if enable_stage2:
self.fc_ctn = nn.Linear(hidden_dim, hidden_dim)
self.attn = SelfAttention(hidden_dim)
self.predictor = nn.Sequential(
nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 2))
else:
self.predictor = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
nn.Linear(hidden_dim, 2))
self.enable_stage2 = enable_stage2
def forward(self, ctx_feats, ctn_feats, qfeat):
qfeat = self.fc_q(qfeat)
ctx_feats = torch.cat(ctx_feats, dim=1)
ctx_fuse_feats = F.relu(self.fc_ctx(ctx_feats)) * F.relu(
qfeat.unsqueeze(1))
if self.enable_stage2 and ctn_feats:
ctn_fuse_feats = list()
for i in range(len(ctn_feats)):
out = F.relu(self.fc_ctn(ctn_feats[i]).unsqueeze(0)) * F.relu(
qfeat.unsqueeze(1).unsqueeze(1))
out = self.attn(out)
ctn_fuse_feats.append(out)
ctn_fuse_feats = torch.cat(ctn_fuse_feats, dim=1)
fuse_feats = torch.cat([ctx_fuse_feats, ctn_fuse_feats], dim=-1)
else:
fuse_feats = ctx_fuse_feats
out = self.predictor(fuse_feats)
return out
class SelfAttention(nn.Module):
"""
Obtain pooled features by self-attentive pooling.
"""
def __init__(self, hidden_dim):
super().__init__()
self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_dim // 2, 1)
def forward(self, x):
att = self.fc2(self.relu(self.fc1(x))).squeeze(3)
att = F.softmax(att, dim=2).unsqueeze(3)
out = torch.sum(x * att, dim=2)
return out
class PositionEncoding(nn.Module):
"""
An implementation of trainable positional embedding which is added to
sequence features to inject time/position information.
Args:
max_len: The max number of trainable positional embeddings.
dim: the dimension of positional embedding.
"""
def __init__(self, max_len, dim, dropout=0.0):
super(PositionEncoding, self).__init__()
self.embed = nn.Embedding(max_len, dim)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
def forward(self, x):
batch_size, seq_len = x.shape[:2]
pos_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
pos_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1)
pos_emb = self.dropout(self.relu(self.embed(pos_ids)))
return pos_emb
class MultiHeadAttention(nn.Module):
"""
An implementation of multi-head attention module, as described in
'Attention Is All You Need <https://arxiv.org/abs/1706.03762>'
Args:
dim: the dimension of features of hidden layers.
n_heads: the number of attention heads.
"""
def __init__(self, dim, n_heads, dropout=0.0):
super(MultiHeadAttention, self).__init__()
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.to_q = nn.Linear(dim, dim)
self.to_k = nn.Linear(dim, dim)
self.to_v = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=-1)
def transpose_for_scores(self, x):
new_x_shape = x.size()[:-1] + (self.n_heads, self.head_dim)
x = x.view(*new_x_shape)
return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
def forward(self, query, key, value, mask):
q = self.to_q(query)
k = self.to_k(key)
v = self.to_v(value)
q_trans = self.transpose_for_scores(q)
k_trans = self.transpose_for_scores(k)
v_trans = self.transpose_for_scores(v)
att = torch.matmul(q_trans, k_trans.transpose(-1,
-2)) # (N, nh, Lq, L)
att = att / math.sqrt(self.head_dim)
att = mask_logits(att, mask)
att = self.softmax(att)
att = self.dropout(att)
ctx_v = torch.matmul(att, v_trans) # (N, nh, Lq, dh)
ctx_v = ctx_v.permute(0, 2, 1, 3).contiguous() # (N, Lq, nh, dh)
shape = ctx_v.size()[:-2] + (self.dim, )
ctx_v = ctx_v.view(*shape) # (N, Lq, D)
return ctx_v
def mask_logits(inputs, mask, mask_value=-1e30):
mask = mask.type(torch.float32)
return inputs + (1.0 - mask) * mask_value
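A minimal shape-check sketch of the stage-1 ranker defined above (the import path is inferred from this commit's package layout; hidden_dim=512, two scales and eight coarse anchors are illustrative values only):

import torch

from modelscope.models.multi_modal.soonet.blocks import Q2VRankerStage1

nscales, hidden_dim = 2, 512
ranker = Q2VRankerStage1(nscales, hidden_dim)
# one context embedding per anchor and per scale: (batch, num_anchors, hidden_dim)
ctx_feats = [torch.randn(1, 8 // 2**i, hidden_dim) for i in range(nscales)]
qfeat = torch.randn(1, hidden_dim)          # a single query embedding
qv_ctx_scores = ranker(ctx_feats, qfeat)    # cosine similarity per anchor
print([s.shape for s in qv_ctx_scores])     # [torch.Size([1, 8]), torch.Size([1, 4])]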

View File

@@ -0,0 +1,342 @@
# The implementation is adopted from CLIP, made publicly available
# under MIT License at https://github.com/openai/CLIP
import warnings
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
from torch import nn
class CLIP(nn.Module):
def __init__(
self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int):
super().__init__()
self.context_length = context_length
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask())
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(
torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(
torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
proj_std = (self.transformer.width**-0.5) * (
(2 * self.transformer.layers)**-0.5)
attn_std = self.transformer.width**-0.5
fc_std = (2 * self.transformer.width)**-0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(
self.text_projection, std=self.transformer.width**-0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float('-inf'))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(
self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(
dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self,
d_model: int,
n_head: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
('gelu', QuickGELU()),
('c_proj', nn.Linear(d_model * 4, d_model))]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(
dtype=x.dtype,
device=x.device) if self.attn_mask is not None else None
return self.attn(
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self,
width: int,
layers: int,
heads: int,
attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[
ResidualAttentionBlock(width, heads, attn_mask)
for _ in range(layers)
])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int,
layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(
in_channels=3,
out_channels=width,
kernel_size=patch_size,
stride=patch_size,
bias=False)
scale = width**-0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn(
(input_resolution // patch_size)**2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1],
-1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
class_token = self.class_embedding.to(x.dtype) + torch.zeros(
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
x = torch.cat([class_token, x], dim=1)
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
def build_model(state_dict: dict):
vision_width = state_dict['visual.conv1.weight'].shape[0]
vision_layers = len([
k for k in state_dict.keys()
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
])
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
grid_size = round(
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
image_resolution = vision_patch_size * grid_size
embed_dim = state_dict['text_projection'].shape[1]
context_length = state_dict['positional_embedding'].shape[0]
vocab_size = state_dict['token_embedding.weight'].shape[0]
transformer_width = state_dict['ln_final.weight'].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(
set(
k.split('.')[2] for k in state_dict
if k.startswith('transformer.resblocks')))
model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
vision_patch_size, context_length, vocab_size,
transformer_width, transformer_heads, transformer_layers)
for key in ['input_resolution', 'context_length', 'vocab_size']:
if key in state_dict:
del state_dict[key]
model.load_state_dict(state_dict)
return model.eval()
def load_clip(name: str,
device: Union[str, torch.device] = 'cuda'
if torch.cuda.is_available() else 'cpu',
jit=True):
jit = False
model_path = name
try:
model = torch.jit.load(
model_path, map_location=device if jit else 'cpu').eval()
state_dict = None
except RuntimeError:
if jit:
warnings.warn(
f'File {model_path} is not a JIT archive. Loading as a state dict instead'
)
jit = False
state_dict = torch.load(model_path, map_location='cpu')
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == 'cpu':
model.float()
return model
device_holder = torch.jit.trace(
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [
n for n in device_holder.graph.findAllNodes('prim::Constant')
if 'Device' in repr(n)
][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, 'graph') else []
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes('prim::Constant'):
if 'value' in node.attributeNames() and str(
node['value']).startswith('cuda'):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
if str(device) == 'cpu':
float_holder = torch.jit.trace(
lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, 'graph') else []
if hasattr(module, 'forward1'):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes('aten::to'):
inputs = list(node.inputs())
for i in [1, 2]:
if inputs[i].node()['value'] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model
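A usage sketch of the loader above, mirroring how the pipeline in this commit loads the bundled ViT-B/32 checkpoint (the directory is a placeholder; the 512-d output assumes that checkpoint):

import torch

from modelscope.models.multi_modal.soonet import load_clip

clip_model = load_clip('/path/to/model_dir/ViT-B-32.pt', device='cpu')
with torch.no_grad():
    # pooled image features, one vector per frame
    frame_feats = clip_model.encode_image(torch.randn(2, 3, 224, 224))
print(frame_feats.shape)  # torch.Size([2, 512]) for ViT-B/32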

View File

@@ -0,0 +1,156 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
import torch
import torch.nn as nn
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .blocks import (BboxRegressor, Q2VRankerStage1, Q2VRankerStage2,
V2QRankerStage1, V2QRankerStage2)
from .swin_transformer import SwinTransformerV2_1D
@MODELS.register_module(
Tasks.video_temporal_grounding, module_name=Models.soonet)
class SOONet(TorchModel):
"""
The implementation of 'Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding
in Long Videos'. The model is dynamically initialized with the following parts:
- q2v_stage1: calculate qv_ctx_score.
- v2q_stage1: calculate vq_ctx_score.
- q2v_stage2: calculate qv_ctn_score.
- v2q_stage2: calculate vq_ctn_score.
- regressor: predict the offset of bounding box for each candidate anchor.
"""
def __init__(self, model_dir: str, *args, **kwargs):
"""
Initialize SOONet Model
Args:
model_dir: the local model directory containing the configuration and weights.
"""
super().__init__()
config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
self.config = Config.from_file(config_path).hyperparams
nscales = self.config.nscales
hidden_dim = self.config.hidden_dim
snippet_length = self.config.snippet_length
self.enable_stage2 = self.config.enable_stage2
self.stage2_topk = self.config.stage2_topk
self.nscales = nscales
self.video_encoder = SwinTransformerV2_1D(
patch_size=snippet_length,
in_chans=hidden_dim,
embed_dim=hidden_dim,
depths=[2] * nscales,
num_heads=[8] * nscales,
window_size=[64] * nscales,
mlp_ratio=2.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
patch_norm=True,
use_checkpoint=False,
pretrained_window_sizes=[0] * nscales)
self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim)
self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim)
if self.enable_stage2:
self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim,
snippet_length)
self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim)
self.regressor = BboxRegressor(hidden_dim, self.enable_stage2)
# Load trained weights
model_path = os.path.join(model_dir,
'SOONet_MAD_VIT-B-32_4Scale_10C.pth')
state_dict = torch.load(model_path, map_location='cpu')['model']
self.load_state_dict(state_dict, strict=True)
def forward(self, **kwargs):
if self.training:
return self.forward_train(**kwargs)
else:
return self.forward_test(**kwargs)
def forward_train(self, **kwargs):
raise NotImplementedError
def forward_test(self,
query_feats=None,
video_feats=None,
start_ts=None,
end_ts=None,
scale_boundaries=None,
**kwargs):
"""
Obtain matching scores and bbox bias of the top-k candidate anchors, with
pre-extracted query features and video features as input.
Args:
query_feats: the pre-extracted text features.
video_feats: the pre-extracted video features.
start_ts: the start timestamps of pre-defined multi-scale anchors.
end_ts: the end timestamps of pre-defined multi-scale anchors.
scale_boundaries: the begin and end anchor index for each scale in start_ts and end_ts.
Returns:
[final_scores, bbox_bias, starts, ends]
"""
sent_feat = query_feats
ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1))
qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat)
if self.enable_stage2:
hit_indices = list()
starts = list()
ends = list()
filtered_ctx_feats = list()
for i in range(self.nscales):
_, indices = torch.sort(
qv_ctx_scores[i], dim=1, descending=True)
indices, _ = torch.sort(
torch.LongTensor(
list(
set(indices[:, :self.stage2_topk].flatten().cpu().
numpy().tolist()))))
indices = indices.to(video_feats.device)
hit_indices.append(indices)
filtered_ctx_feats.append(
torch.index_select(ctx_feats[i], 1, indices))
scale_first = scale_boundaries[i]
scale_last = scale_boundaries[i + 1]
filtered_start = torch.index_select(
start_ts[scale_first:scale_last], 0, indices)
filtered_end = torch.index_select(
end_ts[scale_first:scale_last], 0, indices)
starts.append(filtered_start)
ends.append(filtered_end)
starts = torch.cat(starts, dim=0)
ends = torch.cat(ends, dim=0)
qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2(
video_feats, sent_feat, hit_indices, qv_ctx_scores)
ctx_feats = filtered_ctx_feats
else:
ctn_feats = None
qv_merge_scores = qv_ctx_scores
starts = start_ts
ends = end_ts
bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat)
final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1))
return final_scores, bbox_bias, starts, ends
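As in the unit test added at the end of this commit, the registered model can also be instantiated directly from the ModelScope hub:

from modelscope.models import Model

# resolves the model id to a local directory containing the configuration
# file and 'SOONet_MAD_VIT-B-32_4Scale_10C.pth', then builds SOONet from it
model = Model.from_pretrained('damo/multi-modal_soonet_video-temporal-grounding')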

View File

@@ -0,0 +1,623 @@
# The implementation is adopted from Swin-Transformer-1D, made publicly available
# at https://github.com/meraks/Swin-Transformer-1D
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn.init import trunc_normal_
def drop_path(x,
drop_prob: float = 0.,
training: bool = False,
scale_by_keep: bool = True):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0 and scale_by_keep:
random_tensor.div_(keep_prob)
return x * random_tensor
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'
class Mlp(nn.Module):
def __init__(self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def window_partition(x, window_size):
"""
Args:
x: (B, L, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, C)
"""
B, L, C = x.shape
x = x.view(B, L // window_size, window_size, C)
windows = x.permute(0, 1, 2, 3).contiguous().view(-1, window_size, C)
return windows
def window_reverse(windows, window_size, L):
"""
Args:
windows: (num_windows*B, window_size, C)
window_size (int): Window size
L (int): sequence length
Returns:
x: (B, L, C)
"""
B = int(windows.shape[0] / (L / window_size))
x = windows.view(B, L // window_size, window_size, -1)
x = x.permute(0, 1, 2, 3).contiguous().view(B, L, -1)
return x
class WindowAttention_1D(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (int): The length of the 1D attention window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (int): The window length used in pre-training.
"""
def __init__(self,
dim,
window_size,
num_heads,
qkv_bias=True,
attn_drop=0.,
proj_drop=0.,
pretrained_window_size=0):
super().__init__()
self.dim = dim
self.window_size = window_size # Wl
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
self.logit_scale = nn.Parameter(
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(1, 512, bias=True), nn.ReLU(inplace=True),
nn.Linear(512, num_heads, bias=False))
# get relative_coords_table
relative_coords_l = torch.arange(
-(self.window_size - 1), self.window_size, dtype=torch.float32)
relative_coords_table = torch.stack(
torch.meshgrid([relative_coords_l], indexing='ij')).permute(
1, 0).contiguous().unsqueeze(0) # 1, 2*Wl-1, 1
if pretrained_window_size > 0:
relative_coords_table[:, :, :] /= (pretrained_window_size - 1)
else:
relative_coords_table[:, :, :] /= (self.window_size - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
self.register_buffer('relative_coords_table', relative_coords_table)
# get pair-wise relative position index for each token inside the window
coords_l = torch.arange(self.window_size)
coords = torch.stack(torch.meshgrid([coords_l],
indexing='ij')) # 1, Wl
coords_flatten = torch.flatten(coords, 1) # 1, Wl
relative_coords = coords_flatten[:, :,
None] - coords_flatten[:,
None, :] # 1, Wl, Wl
relative_coords = relative_coords.permute(1, 2,
0).contiguous() # Wl, Wl, 1
relative_coords[:, :,
0] += self.window_size - 1 # shift to start from 0
relative_position_index = relative_coords.sum(-1) # Wl, Wl
self.register_buffer('relative_position_index',
relative_position_index)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wl, Wl) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat(
(self.q_bias,
torch.zeros_like(self.v_bias,
requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[
2] # make torchscript happy (cannot use tensor as tuple)
# cosine attention
attn = (
F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
logit_scale = torch.clamp(
self.logit_scale,
max=torch.log(torch.tensor(1. / 0.01, device=attn.device))).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(
self.relative_coords_table).view(-1, self.num_heads)
relative_position_bias = relative_position_bias_table[
self.relative_position_index.view(-1)].view(
self.window_size, self.window_size, -1) # Wl, Wl, nH
relative_position_bias = relative_position_bias.permute(
2, 0, 1).contiguous() # nH, Wl, Wl
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N,
N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def compute_mask(L, window_size, shift_size):
Lp = int(np.ceil(L / window_size)) * window_size
img_mask = torch.zeros((1, Lp, 1)) # 1 Lp 1
pad_size = int(Lp - L)
if (pad_size == 0) or (pad_size + shift_size == window_size):
segs = (slice(-window_size), slice(-window_size, -shift_size),
slice(-shift_size, None))
elif pad_size + shift_size > window_size:
seg1 = int(window_size * 2 - L + shift_size)
segs = (slice(-seg1), slice(-seg1, -window_size),
slice(-window_size, -shift_size), slice(-shift_size, None))
elif pad_size + shift_size < window_size:
seg1 = int(window_size * 2 - L + shift_size)
segs = (slice(-window_size), slice(-window_size, -seg1),
slice(-seg1, -shift_size), slice(-shift_size, None))
cnt = 0
for d in segs:
img_mask[:, d, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size) # nW, ws, 1
mask_windows = mask_windows.squeeze(-1) # nW, ws
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0,
float(-100.0)).masked_fill(
attn_mask == 0, float(0.0))
return attn_mask
class SwinTransformerBlock_1D(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pre-training.
"""
def __init__(self,
dim,
num_heads,
window_size=7,
shift_size=0,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
pretrained_window_size=0):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.shift_size = shift_size
self.mlp_ratio = mlp_ratio
assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
self.norm1 = norm_layer(dim)
self.attn = WindowAttention_1D(
dim,
window_size=self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop,
pretrained_window_size=pretrained_window_size)
self.drop_path = DropPath(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
B, L, C = x.shape
attn_mask = compute_mask(L, self.window_size,
self.shift_size).to(x.device)
shortcut = x
# x = x.view(B, L, C)
# padding x
pad_r = (self.window_size - L % self.window_size) % self.window_size
x = F.pad(x, (0, 0, 0, pad_r))
_, Lp, _ = x.shape
# cyclic shift
if self.shift_size > 0:
shifted_x = torch.roll(x, shifts=(-self.shift_size), dims=(1))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x,
self.window_size) # nW*B, window_size, C
x_windows = x_windows.view(-1, self.window_size,
C) # nW*B, window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(
x_windows, mask=attn_mask) # nW*B, window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size, C)
shifted_x = window_reverse(attn_windows, self.window_size,
Lp) # B L' C
# reverse cyclic shift
if self.shift_size > 0:
x = torch.roll(shifted_x, shifts=(self.shift_size), dims=(1))
else:
x = shifted_x
x = x.view(B, Lp, C)
# reverse padding x
x = x[:, :L, :].contiguous()
x = shortcut + self.drop_path(self.norm1(x))
# FFN
x = x + self.drop_path(self.norm2(self.mlp(x)))
return x
class PatchMerging(nn.Module):
""" Patch Merging Layer
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
# self.reduction = nn.Linear(2 * dim, dim, bias=False)
# self.norm = norm_layer(2 * dim)
def forward(self, x):
""" Forward function.
Args:
x: Input feature, tensor size (B, L, C).
"""
B, L, C = x.shape
x = F.pad(x, (0, 0, 0, L % 2))
x0 = x[:, 0::2, :] # B L/2 C
x1 = x[:, 1::2, :] # B L/2 C
x = torch.maximum(x0, x1)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(self,
dim,
depth,
num_heads,
window_size,
mlp_ratio=4.,
qkv_bias=True,
drop=0.,
attn_drop=0.,
drop_path=0.,
norm_layer=nn.LayerNorm,
downsample=None,
use_checkpoint=False,
pretrained_window_size=0):
super().__init__()
self.dim = dim
self.depth = depth
self.use_checkpoint = use_checkpoint
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock_1D(
dim=dim,
num_heads=num_heads,
window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop,
attn_drop=attn_drop,
drop_path=drop_path[i]
if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size)
for i in range(depth)
])
# patch merging layer
if downsample is not None:
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
else:
self.downsample = None
def forward(self, x):
for blk in self.blocks:
if self.use_checkpoint:
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
proposal = x
if self.downsample is not None:
x = self.downsample(x)
return x, proposal
def _init_respostnorm(self):
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0)
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
class PatchEmbed1D(nn.Module):
""" Video to Patch Embedding.
Args:
patch_size (int): Patch token size. Default: 4.
in_chans (int): Number of input video channels. Default: 32.
embed_dim (int): Number of linear projection output channels. Default: 128.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(self,
patch_size=4,
in_chans=32,
embed_dim=128,
norm_layer=None):
super().__init__()
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
self.proj = nn.Conv1d(
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
if norm_layer is not None:
self.norm = norm_layer(embed_dim)
else:
self.norm = None
def forward(self, x):
"""Forward function."""
# padding
_, _, L = x.size()
pad_r = (self.patch_size - L % self.patch_size) % self.patch_size
x = F.pad(x, (0, pad_r))
x = self.proj(x) # B C Wl
if self.norm is not None:
# Wl = x.size(2)
x = x.transpose(1, 2)
x = self.norm(x)
# x = x.transpose(1, 2).view(-1, self.embed_dim, Wl)
return x
class SwinTransformerV2_1D(nn.Module):
""" A 1D variant of Swin Transformer V2, used by SOONet as the multi-scale
video encoder. forward() returns one feature sequence per stage.
"""
def __init__(self,
patch_size=4,
in_chans=32,
embed_dim=96,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
window_size=[7, 7, 7, 7],
mlp_ratio=4.,
qkv_bias=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=nn.LayerNorm,
patch_norm=True,
use_checkpoint=False,
pretrained_window_sizes=[0, 0, 0, 0],
**kwargs):
super().__init__()
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2**(self.num_layers - 1))
self.mlp_ratio = mlp_ratio
# split image into non-overlapping patches
self.patch_embed = PatchEmbed1D(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=embed_dim,
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size[i_layer],
mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if
(i_layer < self.num_layers - 1) else None,
use_checkpoint=use_checkpoint,
pretrained_window_size=pretrained_window_sizes[i_layer])
self.layers.append(layer)
self.apply(self._init_weights)
for bly in self.layers:
bly._init_respostnorm()
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'absolute_pos_embed'}
@torch.jit.ignore
def no_weight_decay_keywords(self):
return {'cpb_mlp', 'logit_scale', 'relative_position_bias_table'}
def forward_features(self, x):
x = self.patch_embed(x)
x = self.pos_drop(x)
proposals = list()
for layer in self.layers:
x, proposal = layer(x)
proposals.append(proposal)
return proposals
def forward(self, x):
return self.forward_features(x)
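A minimal sketch of the 1D encoder above, configured as in SOONet.__init__ (hidden_dim=512, snippet_length=10 and four scales are assumptions inferred from the checkpoint name and the CLIP ViT-B/32 feature size; the 1280-frame input is arbitrary; the import path follows this commit's package layout):

import torch

from modelscope.models.multi_modal.soonet.swin_transformer import \
    SwinTransformerV2_1D

encoder = SwinTransformerV2_1D(
    patch_size=10, in_chans=512, embed_dim=512,
    depths=[2] * 4, num_heads=[8] * 4, window_size=[64] * 4,
    mlp_ratio=2., pretrained_window_sizes=[0] * 4)
video_feats = torch.randn(1, 512, 1280)   # (B, C, L) frame-level features
ctx_feats = encoder(video_feats)          # one anchor-level sequence per scale
print([f.shape for f in ctx_feats])
# [torch.Size([1, 128, 512]), torch.Size([1, 64, 512]),
#  torch.Size([1, 32, 512]), torch.Size([1, 16, 512])]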

View File

@@ -0,0 +1,152 @@
# The implementation is adopted from CLIP, made publicly available
# under MIT License at https://github.com/openai/CLIP
import gzip
import html
from functools import lru_cache
import ftfy
import regex as re
import torch
@lru_cache()
def bytes_to_unicode():
bs = list(range(ord('!'),
ord('~') + 1)) + list(range(
ord('¡'),
ord('¬') + 1)) + list(range(ord('®'),
ord('ÿ') + 1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
merges = merges[1:49152 - 256 - 2 + 1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v + '</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {
'<|startoftext|>': '<|startoftext|>',
'<|endoftext|>': '<|endoftext|>'
}
self.pat = re.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + (token[-1] + '</w>', )
pairs = get_pairs(word)
if not pairs:
return token + '</w>'
while True:
bigram = min(
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word) - 1 and word[
i + 1] == second:
new_word.append(first + second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b]
for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token]
for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode(
'utf-8', errors='replace').replace('</w>', ' ')
return text
def tokenize(self, texts, context_length=77):
if isinstance(texts, str):
texts = [texts]
sot_token = self.encoder['<|startoftext|>']
eot_token = self.encoder['<|endoftext|>']
all_tokens = [[sot_token] + self.encode(text) + [eot_token]
for text in texts]
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
for i, tokens in enumerate(all_tokens):
if len(tokens) > context_length:
tokens = tokens[:context_length]
tokens[-1] = eot_token
result[i, :len(tokens)] = torch.tensor(tokens)
return result
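A usage sketch of the tokenizer above (the BPE vocab filename matches what the pipeline loads from the model directory; the path itself is a placeholder):

from modelscope.models.multi_modal.soonet import SimpleTokenizer

tokenizer = SimpleTokenizer('/path/to/model_dir/bpe_simple_vocab_16e6.txt.gz')
token_ids = tokenizer.tokenize('a man takes food out of the refrigerator.')
print(token_ids.shape)  # torch.Size([1, 77]), padded to the CLIP context length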

View File

@@ -0,0 +1,58 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import copy
import decord
import numpy as np
from decord import VideoReader, cpu
from decord._ffi.base import DECORDError
from tqdm import tqdm
def decode_video(video_path, target_fps=5):
"""
Decode video from 'video_path' and return the sampled frames based on target_fps.
The default value of target_fps is 5.
Args:
video_path: the absolute path of video.
target_fps: the number of sampled video frames per second.
Returns:
[imgs, duration]
"""
decord.bridge.set_bridge('torch')
vr = VideoReader(video_path, ctx=cpu(0))
cur_fps = vr.get_avg_fps()
if cur_fps > target_fps:
interval = float(cur_fps) / float(target_fps)
start = float(interval) / 2.
else:
interval = 1.0
start = 0.0
vid_length = len(vr)
duration = vid_length / cur_fps
sampled_idxs = np.clip(
np.round(np.arange(start, float(vid_length), step=interval)), 0,
vid_length - 1).astype(np.int32)
imgs = list()
for i in tqdm(sampled_idxs):
bias = 0
# avoid broken frames
while bias <= 10:
try:
img = vr[i - bias]
break
except DECORDError:
bias += 1
if bias > 10:
img = copy.deepcopy(imgs[-1])
imgs.append(img)
else:
img = img / 255.
img = img.permute(2, 0, 1)
imgs.append(img)
return imgs, duration
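A usage sketch of decode_video (the video path is a placeholder):

from modelscope.models.multi_modal.soonet import decode_video

# at the default 5 fps, a 60 s clip yields roughly 300 frames, each returned
# as a (3, H, W) float tensor scaled to [0, 1]
imgs, duration = decode_video('/path/to/video.mp4', target_fps=5)
print(len(imgs), duration)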

View File

@@ -57,6 +57,7 @@ class OutputKeys(object):
MATCHES = 'matches'
PCD12 = 'pcd12'
PCD12_ALIGN = 'pcd12_align'
TBOUNDS = 'tbounds'
TASK_OUTPUTS = {
@@ -1105,6 +1106,7 @@ TASK_OUTPUTS = {
Tasks.document_grounded_dialog_generate: [OutputKeys.TEXT],
Tasks.document_grounded_dialog_rerank: [OutputKeys.OUTPUT],
Tasks.document_grounded_dialog_retrieval: [OutputKeys.OUTPUT],
Tasks.video_temporal_grounding: [OutputKeys.SCORES, OutputKeys.TBOUNDS],
}

View File

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
from .video_captioning_pipeline import VideoCaptioningPipeline
from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
from .diffusers_wrapped import StableDiffusionWrapperPipeline, ChineseStableDiffusionPipeline
from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
else:
_import_structure = {
@@ -41,6 +42,8 @@ else:
['VideoQuestionAnsweringPipeline'],
'diffusers_wrapped':
['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'],
'soonet_video_temporal_grounding_pipeline':
['SOONetVideoTemporalGroundingPipeline'],
'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
}

View File

@@ -0,0 +1,222 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import os
from typing import Any, Dict
import numpy as np
import torch
from torchvision import transforms
from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal.soonet import (SimpleTokenizer,
decode_video, load_clip)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
Tasks.video_temporal_grounding,
module_name=Pipelines.soonet_video_temporal_grounding)
class SOONetVideoTemporalGroundingPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""
SOONet pipeline for video temporal grounding.
Examples:
>>> from modelscope.pipelines import pipeline
>>> soonet_pipeline = pipeline("video-temporal-grounding", "damo/multi-modal_soonet_video-temporal-grounding")
>>> soonet_pipeline(
('a man takes food out of the refrigerator.',
'soonet_video_temporal_grounding_test_video.mp4'))
>>> {
>>> "scores": [
>>> 0.80661213,
>>> 0.8060084,
>>> 0.8018835,
>>> 0.79837507,
>>> 0.7963626,
>>> 0.7949013,
>>> 0.79353744,
>>> 0.79287416,
>>> 0.79066336,
>>> 0.79027915
>>> ],
>>> "tbounds": [
>>> [
>>> 0,
>>> 2.9329566955566406
>>> ],
>>> [
>>> 1.0630402565002441,
>>> 4.9339457750320435
>>> ],
>>> [
>>> 300.96919429302216,
>>> 304.8546848297119
>>> ],
>>> [
>>> 302.96981167793274,
>>> 306.7714672088623
>>> ],
>>> [
>>> 0,
>>> 5.0421366691589355
>>> ],
>>> [
>>> 304.9119266271591,
>>> 308.7636929154396
>>> ],
>>> [
>>> 258.96133184432983,
>>> 262.805901825428
>>> ],
>>> [
>>> 122.9599289894104,
>>> 126.86622190475464
>>> ],
>>> [
>>> 126.94010400772095,
>>> 130.8090701699257
>>> ],
>>> [
>>> 121.04773849248886,
>>> 124.79261875152588
>>> ]
>>> ]
>>> }
"""
super().__init__(model=model, **kwargs)
self.model_dir = model
self.clip = load_clip(os.path.join(self.model_dir,
'ViT-B-32.pt')).to(self.device)
self.model = self.model.float().to(self.device)
self.model.eval()
# Load Configuration from File
config_path = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
self.config = Config.from_file(config_path).hyperparams
self.nscales = self.config.nscales
self.snippet_length = self.config.snippet_length
self.max_anchor_length = self.snippet_length * 2**(self.nscales - 1)
self.topk = 10
self.fps = 5
# Define image transform
self.img_transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])
logger.info('Init transform done')
# Init tokenizer
bpe_path = os.path.join(self.model_dir, 'bpe_simple_vocab_16e6.txt.gz')
self.tokenizer = SimpleTokenizer(bpe_path)
logger.info('Init tokenizer done')
def pad(self, arr, pad_len):
new_arr = np.zeros((pad_len, ), dtype=float)
new_arr[:len(arr)] = arr
return new_arr
def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
text, video_name = input
video_path = os.path.join(self.model_dir, video_name)
imgs, duration = decode_video(video_path, self.fps)
trans_imgs = list()
for i, img in enumerate(imgs):
trans_imgs.append(self.img_transform(img))
imgs = trans_imgs
token_ids = self.tokenizer.tokenize(text).to(
self.device, non_blocking=True)
# get the start and end timestamps of anchors
start_ts, end_ts, scale_boundaries = list(), list(), [0]
ori_video_length = len(imgs)
pad_video_length = int(
np.math.ceil(ori_video_length / self.max_anchor_length)
* self.max_anchor_length)
for i in range(self.config.nscales):
anchor_length = self.config.snippet_length * (2**i)
pad_feat_length = pad_video_length // anchor_length
nfeats = np.math.ceil(ori_video_length / anchor_length)
s_times = np.arange(0, nfeats).astype(np.float32) * (
anchor_length // self.fps)
e_times = np.arange(1, nfeats + 1).astype(np.float32) * (
anchor_length // self.fps)
e_times = np.minimum(duration, e_times)
start_ts.append(self.pad(s_times, pad_feat_length))
end_ts.append(self.pad(e_times, pad_feat_length))
scale_boundaries.append(scale_boundaries[-1] + pad_feat_length)
start_ts = torch.from_numpy(np.concatenate(start_ts, axis=0))
end_ts = torch.from_numpy(np.concatenate(end_ts, axis=0))
scale_boundaries = torch.LongTensor(scale_boundaries)
result = {
'token_ids': token_ids,
'imgs': torch.stack(imgs, dim=0),
'start_ts': start_ts,
'end_ts': end_ts,
'scale_boundaries': scale_boundaries
}
return result
def forward(self, input: Dict[str, Any],
**forward_params) -> Dict[str, Any]:
with torch.no_grad():
video_feats = self.clip.encode_image(input['imgs'].to(self.device))
query_feats = self.clip.encode_text(input['token_ids'].to(
self.device))
#
ori_video_length, feat_dim = video_feats.shape
pad_video_length = int(
np.math.ceil(ori_video_length / self.max_anchor_length)
* self.max_anchor_length)
pad_video_feats = torch.zeros((pad_video_length, feat_dim),
dtype=float)
pad_video_feats[:ori_video_length, :] = video_feats
final_scores, bbox_bias, starts, ends = self.model(
query_feats=query_feats.float().to(self.device),
video_feats=pad_video_feats.unsqueeze(0).float().to(
self.device),
start_ts=input['start_ts'].float().to(self.device),
end_ts=input['end_ts'].float().to(self.device),
scale_boundaries=input['scale_boundaries'])
#
final_scores = final_scores.cpu().numpy()
bbox_bias = bbox_bias.cpu().numpy()
starts = starts.cpu().numpy()
ends = ends.cpu().numpy()
pred_scores, pred_bboxes = list(), list()
rank_id = np.argsort(final_scores[0])[::-1]
for j in range(self.topk):
if j >= len(rank_id):
break
pred_scores.append(final_scores[0, rank_id[j]])
ori_end = float(ends[rank_id[j]])
ori_start = float(starts[rank_id[j]])
duration = ori_end - ori_start
sbias = bbox_bias[0, rank_id[j], 0]
ebias = bbox_bias[0, rank_id[j], 1]
pred_start = max(0, ori_start + sbias * duration)
pred_end = ori_end + ebias * duration
pred_bboxes.append([pred_start, pred_end])
return {
OutputKeys.SCORES: pred_scores,
OutputKeys.TBOUNDS: pred_bboxes
}
def postprocess(self, inputs: Dict[str, Any],
**post_params) -> Dict[str, Any]:
return inputs
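For clarity, a small worked example of the multi-scale anchor grid that preprocess builds (nscales=4, snippet_length=10 and fps=5 are assumed to match the released configuration; the 20-second video with 100 sampled frames is illustrative only):

import numpy as np

nscales, snippet_length, fps = 4, 10, 5
ori_len = 100                                              # sampled frames
max_anchor = snippet_length * 2**(nscales - 1)             # 80 frames
pad_len = int(np.ceil(ori_len / max_anchor) * max_anchor)  # 160 frames
for i in range(nscales):
    anchor = snippet_length * 2**i
    # anchor length in seconds, number of real anchors, padded slots per scale
    print(anchor // fps, int(np.ceil(ori_len / anchor)), pad_len // anchor)
# prints: 2 10 16 / 4 5 8 / 8 3 4 / 16 2 2
# hence scale_boundaries == [0, 16, 24, 28, 30]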

View File

@@ -239,6 +239,7 @@ class MultiModalTasks(object):
document_vl_embedding = 'document-vl-embedding'
video_captioning = 'video-captioning'
video_question_answering = 'video-question-answering'
video_temporal_grounding = 'video-temporal-grounding'
text_to_video_synthesis = 'text-to-video-synthesis'

View File

@@ -0,0 +1,34 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest
from modelscope.models import Model
from modelscope.models.multi_modal.soonet import SOONet
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class SOONetVideoTemporalGroundingTest(unittest.TestCase,
DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.video_temporal_grounding
self.model_id = 'damo/multi-modal_soonet_video-temporal-grounding'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_modelhub(self):
soonet_pipeline = pipeline(self.task, self.model_id)
result = soonet_pipeline(
('a man takes food out of the refrigerator.',
'soonet_video_temporal_grounding_test_video.mp4'))
print(f'soonet output: {result}.')
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_load_model_from_pretrained(self):
model = Model.from_pretrained(self.model_id)
self.assertTrue(model.__class__ == SOONet)
if __name__ == '__main__':
unittest.main()