add soonet for video temporal grounding

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11810444
@@ -195,6 +195,7 @@ class Models(object):
    mgeo = 'mgeo'
    vldoc = 'vldoc'
    hitea = 'hitea'
    soonet = 'soonet'

    # science models
    unifold = 'unifold'
@@ -497,6 +498,7 @@ class Pipelines(object):
    text_to_video_synthesis = 'latent-text-to-video-synthesis'  # latent-text-to-video-synthesis
    gridvlp_multi_modal_classification = 'gridvlp-multi-modal-classification'
    gridvlp_multi_modal_embedding = 'gridvlp-multi-modal-embedding'
    soonet_video_temporal_grounding = 'soonet-video-temporal-grounding'

    # science tasks
    protein_structure = 'unifold-protein-structure'

27
modelscope/models/multi_modal/soonet/__init__.py
Normal file
@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .tokenizer import SimpleTokenizer
    from .model import SOONet
    from .utils import decode_video
    from .clip import load_clip
else:
    _import_structure = {
        'model': ['SOONet'],
        'tokenizer': ['SimpleTokenizer'],
        'utils': ['decode_video'],
        'clip': ['load_clip']
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
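For orientation, the lazy-import table above means callers import the public symbols from the package and the heavy submodules are only loaded on first access. A minimal usage sketch (assuming modelscope and optional dependencies such as decord are installed):

# Resolved lazily through LazyImportModule; nothing is imported until first use.
from modelscope.models.multi_modal.soonet import (SOONet, SimpleTokenizer,
                                                  decode_video, load_clip)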
287
modelscope/models/multi_modal/soonet/blocks.py
Normal file
@@ -0,0 +1,287 @@
|
||||
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Q2VRankerStage1(nn.Module):
|
||||
"""
|
||||
Used to calculate the qv_ctx_score with the query embedding and multi-scale anchor context embeddings as input.
|
||||
The qv_ctx_score is used to pre-rank and retain top-k related anchors.
|
||||
"""
|
||||
|
||||
def __init__(self, nscales, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.nscales = nscales
|
||||
|
||||
def forward(self, ctx_feats, qfeat):
|
||||
qfeat = self.fc(qfeat)
|
||||
qv_ctx_scores = list()
|
||||
for i in range(self.nscales):
|
||||
score = torch.einsum('bld,bd->bl',
|
||||
F.normalize(ctx_feats[i], p=2, dim=2),
|
||||
F.normalize(qfeat, p=2, dim=1))
|
||||
qv_ctx_scores.append(score)
|
||||
|
||||
return qv_ctx_scores
|
||||
|
||||
|
||||
class V2QRankerStage1(nn.Module):
|
||||
"""
|
||||
Used to calculate the vq_ctx_score with anchor context embeddings and multiple query embeddings as input.
|
||||
"""
|
||||
|
||||
def __init__(self, nscales, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.nscales = nscales
|
||||
|
||||
def forward(self, ctx_feats, qfeat):
|
||||
vq_ctx_scores = list()
|
||||
for i in range(self.nscales):
|
||||
score = torch.einsum(
|
||||
'bld,bd->bl', F.normalize(self.fc(ctx_feats[i]), p=2, dim=2),
|
||||
F.normalize(qfeat, p=2, dim=1))
|
||||
vq_ctx_scores.append(score)
|
||||
|
||||
return vq_ctx_scores
|
||||
|
||||
|
||||
class Q2VRankerStage2(nn.Module):
|
||||
"""
|
||||
Used to calculate the qv_ctn_score with query embedding and video sequence embedding as input.
|
||||
The qv_ctn_score is used to re-rank anchors.
|
||||
"""
|
||||
|
||||
def __init__(self, nscales, hidden_dim, snippet_length=10):
|
||||
super().__init__()
|
||||
self.nscales = nscales
|
||||
self.snippet_length = snippet_length
|
||||
self.qfc = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.encoder = V2VAttention()
|
||||
|
||||
def forward(self, vfeats, qfeat, hit_indices, qv_ctx_scores):
|
||||
qfeat = self.qfc(qfeat)
|
||||
|
||||
qv_ctn_scores = list()
|
||||
qv_merge_scores = list()
|
||||
|
||||
_, L, D = vfeats.size()
|
||||
ctn_feats = list()
|
||||
for i in range(self.nscales):
|
||||
anchor_length = self.snippet_length * 2**i
|
||||
assert L // anchor_length == qv_ctx_scores[i].size(1)
|
||||
qv_ctx_score = torch.index_select(qv_ctx_scores[i], 1,
|
||||
hit_indices[i])
|
||||
|
||||
ctn_feat = vfeats.view(L // anchor_length, anchor_length,
|
||||
D).detach()
|
||||
ctn_feat = torch.index_select(ctn_feat, 0, hit_indices[i])
|
||||
ctn_feat = self.encoder(
|
||||
ctn_feat,
|
||||
torch.ones(ctn_feat.size()[:2], device=ctn_feat.device))
|
||||
ctn_feats.append(ctn_feat)
|
||||
|
||||
qv_ctn_score = torch.einsum(
|
||||
'bkld,bd->bkl', F.normalize(ctn_feat.unsqueeze(0), p=2, dim=3),
|
||||
F.normalize(qfeat, p=2, dim=1))
|
||||
qv_ctn_score, _ = torch.max(qv_ctn_score, dim=2)
|
||||
qv_ctn_scores.append(qv_ctn_score)
|
||||
qv_merge_scores.append(qv_ctx_score + qv_ctn_score)
|
||||
|
||||
return qv_merge_scores, qv_ctn_scores, ctn_feats
|
||||
|
||||
|
||||
class V2QRankerStage2(nn.Module):
|
||||
"""
|
||||
Used to calculate the vq_ctn_score with anchor content embeddings and multiple query embeddings as input.
|
||||
"""
|
||||
|
||||
def __init__(self, nscales, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.nscales = nscales
|
||||
|
||||
def forward(self, ctn_feats, qfeat):
|
||||
vq_ctn_scores = list()
|
||||
for i in range(self.nscales):
|
||||
score = torch.einsum(
|
||||
'bkld,bd->bkl',
|
||||
F.normalize(self.fc(ctn_feats[i]).unsqueeze(0), p=2, dim=3),
|
||||
F.normalize(qfeat, p=2, dim=1))
|
||||
score = torch.mean(score, dim=2)
|
||||
vq_ctn_scores.append(score)
|
||||
|
||||
return vq_ctn_scores
|
||||
|
||||
|
||||
class V2VAttention(nn.Module):
|
||||
"""
|
||||
Self-attention encoder for anchor frame sequence to encode intra-anchor knowledge.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.posemb = PositionEncoding(max_len=400, dim=512, dropout=0.0)
|
||||
self.encoder = MultiHeadAttention(dim=512, n_heads=8, dropout=0.1)
|
||||
self.dropout = nn.Dropout(0.0)
|
||||
|
||||
def forward(self, video_feats, video_masks):
|
||||
mask = torch.einsum('bm,bn->bmn', video_masks,
|
||||
video_masks).unsqueeze(1)
|
||||
residual = video_feats
|
||||
video_feats = video_feats + self.posemb(video_feats)
|
||||
out = self.encoder(
|
||||
query=video_feats, key=video_feats, value=video_feats, mask=mask)
|
||||
video_feats = self.dropout(residual
|
||||
+ out) * video_masks.unsqueeze(2).float()
|
||||
return video_feats
|
||||
|
||||
|
||||
class BboxRegressor(nn.Module):
|
||||
"""
|
||||
Predict the offset of bounding box for each candidate anchor.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim, enable_stage2=False):
|
||||
super().__init__()
|
||||
self.fc_ctx = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.fc_q = nn.Linear(hidden_dim, hidden_dim)
|
||||
|
||||
if enable_stage2:
|
||||
self.fc_ctn = nn.Linear(hidden_dim, hidden_dim)
|
||||
self.attn = SelfAttention(hidden_dim)
|
||||
self.predictor = nn.Sequential(
|
||||
nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU(),
|
||||
nn.Linear(hidden_dim, 2))
|
||||
else:
|
||||
self.predictor = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
|
||||
nn.Linear(hidden_dim, 2))
|
||||
self.enable_stage2 = enable_stage2
|
||||
|
||||
def forward(self, ctx_feats, ctn_feats, qfeat):
|
||||
qfeat = self.fc_q(qfeat)
|
||||
|
||||
ctx_feats = torch.cat(ctx_feats, dim=1)
|
||||
ctx_fuse_feats = F.relu(self.fc_ctx(ctx_feats)) * F.relu(
|
||||
qfeat.unsqueeze(1))
|
||||
|
||||
if self.enable_stage2 and ctn_feats:
|
||||
ctn_fuse_feats = list()
|
||||
for i in range(len(ctn_feats)):
|
||||
out = F.relu(self.fc_ctn(ctn_feats[i]).unsqueeze(0)) * F.relu(
|
||||
qfeat.unsqueeze(1).unsqueeze(1))
|
||||
out = self.attn(out)
|
||||
ctn_fuse_feats.append(out)
|
||||
ctn_fuse_feats = torch.cat(ctn_fuse_feats, dim=1)
|
||||
fuse_feats = torch.cat([ctx_fuse_feats, ctn_fuse_feats], dim=-1)
|
||||
else:
|
||||
fuse_feats = ctx_fuse_feats
|
||||
|
||||
out = self.predictor(fuse_feats)
|
||||
return out
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
"""
|
||||
Obtain pooled features by self-attentive pooling.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim):
|
||||
super().__init__()
|
||||
self.fc1 = nn.Linear(hidden_dim, hidden_dim // 2)
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(hidden_dim // 2, 1)
|
||||
|
||||
def forward(self, x):
|
||||
att = self.fc2(self.relu(self.fc1(x))).squeeze(3)
|
||||
att = F.softmax(att, dim=2).unsqueeze(3)
|
||||
out = torch.sum(x * att, dim=2)
|
||||
return out
|
||||
|
||||
|
||||
class PositionEncoding(nn.Module):
|
||||
"""
|
||||
An implementation of trainable positional embedding which is added to
|
||||
sequence features to inject time/position information.
|
||||
|
||||
Args:
|
||||
max_len: The max number of trainable positional embeddings.
|
||||
dim: the dimension of positional embedding.
|
||||
"""
|
||||
|
||||
def __init__(self, max_len, dim, dropout=0.0):
|
||||
super(PositionEncoding, self).__init__()
|
||||
|
||||
self.embed = nn.Embedding(max_len, dim)
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, seq_len = x.shape[:2]
|
||||
pos_ids = torch.arange(seq_len, dtype=torch.long, device=x.device)
|
||||
pos_ids = pos_ids.unsqueeze(0).repeat(batch_size, 1)
|
||||
pos_emb = self.dropout(self.relu(self.embed(pos_ids)))
|
||||
|
||||
return pos_emb
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
An implementation of multi-head attention module, as described in
|
||||
'Attention Is All You Need <https://arxiv.org/abs/1706.03762>'
|
||||
|
||||
Args:
|
||||
dim: the dimension of features of hidden layers.
|
||||
n_heads: the number of attention heads.
|
||||
"""
|
||||
|
||||
def __init__(self, dim, n_heads, dropout=0.0):
|
||||
super(MultiHeadAttention, self).__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
self.to_q = nn.Linear(dim, dim)
|
||||
self.to_k = nn.Linear(dim, dim)
|
||||
self.to_v = nn.Linear(dim, dim)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def transpose_for_scores(self, x):
|
||||
new_x_shape = x.size()[:-1] + (self.n_heads, self.head_dim)
|
||||
x = x.view(*new_x_shape)
|
||||
return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
q = self.to_q(query)
|
||||
k = self.to_k(key)
|
||||
v = self.to_v(value)
|
||||
|
||||
q_trans = self.transpose_for_scores(q)
|
||||
k_trans = self.transpose_for_scores(k)
|
||||
v_trans = self.transpose_for_scores(v)
|
||||
|
||||
att = torch.matmul(q_trans, k_trans.transpose(-1,
|
||||
-2)) # (N, nh, Lq, L)
|
||||
att = att / math.sqrt(self.head_dim)
|
||||
att = mask_logits(att, mask)
|
||||
att = self.softmax(att)
|
||||
att = self.dropout(att)
|
||||
|
||||
ctx_v = torch.matmul(att, v_trans) # (N, nh, Lq, dh)
|
||||
ctx_v = ctx_v.permute(0, 2, 1, 3).contiguous() # (N, Lq, nh, dh)
|
||||
shape = ctx_v.size()[:-2] + (self.dim, )
|
||||
ctx_v = ctx_v.view(*shape) # (N, Lq, D)
|
||||
return ctx_v
|
||||
|
||||
|
||||
def mask_logits(inputs, mask, mask_value=-1e30):
|
||||
mask = mask.type(torch.float32)
|
||||
return inputs + (1.0 - mask) * mask_value
|
||||
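To make the stage-1 ranking in Q2VRankerStage1 above concrete, here is a small self-contained sketch of the cosine scoring and pre-ranking step; all tensor sizes are illustrative, not values from this commit.

import torch
import torch.nn.functional as F

B, L, D, topk = 1, 32, 512, 4
ctx_feats = torch.randn(B, L, D)  # context embeddings of L anchors at one scale
qfeat = torch.randn(B, D)         # a single query embedding

# Cosine similarity between the query and every anchor: shape (B, L).
qv_ctx_score = torch.einsum('bld,bd->bl',
                            F.normalize(ctx_feats, p=2, dim=2),
                            F.normalize(qfeat, p=2, dim=1))
# Keep only the top-k anchors of this scale for the stage-2 re-ranking.
hit_indices = qv_ctx_score.topk(topk, dim=1).indices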
342
modelscope/models/multi_modal/soonet/clip.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# The implementation is adopted from CLIP, made publicly available
|
||||
# under MIT License at https://github.com/openai/CLIP
|
||||
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
# vision
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
# text
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int):
|
||||
super().__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
|
||||
vision_heads = vision_width // 64
|
||||
self.visual = VisionTransformer(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
attn_mask=self.build_attention_mask())
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(
|
||||
torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(
|
||||
torch.empty(transformer_width, embed_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
def initialize_parameters(self):
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
proj_std = (self.transformer.width**-0.5) * (
|
||||
(2 * self.transformer.layers)**-0.5)
|
||||
attn_std = self.transformer.width**-0.5
|
||||
fc_std = (2 * self.transformer.width)**-0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(
|
||||
self.text_projection, std=self.transformer.width**-0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float('-inf'))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.visual.conv1.weight.dtype
|
||||
|
||||
def encode_image(self, image):
|
||||
return self.visual(image.type(self.dtype))
|
||||
|
||||
def encode_text(self, text):
|
||||
x = self.token_embedding(text).type(
|
||||
self.dtype) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding.type(self.dtype)
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x).type(self.dtype)
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
# x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, image, text):
|
||||
image_features = self.encode_image(image)
|
||||
text_features = self.encode_text(text)
|
||||
|
||||
# normalized features
|
||||
image_features = image_features / image_features.norm(
|
||||
dim=1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logits_per_image.t()
|
||||
|
||||
# shape = [global_batch_size, global_batch_size]
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
n_head: int,
|
||||
attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(
|
||||
OrderedDict([('c_fc', nn.Linear(d_model, d_model * 4)),
|
||||
('gelu', QuickGELU()),
|
||||
('c_proj', nn.Linear(d_model * 4, d_model))]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(
|
||||
dtype=x.dtype,
|
||||
device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(
|
||||
x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
width: int,
|
||||
layers: int,
|
||||
heads: int,
|
||||
attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.Sequential(*[
|
||||
ResidualAttentionBlock(width, heads, attn_mask)
|
||||
for _ in range(layers)
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.resblocks(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
|
||||
def __init__(self, input_resolution: int, patch_size: int, width: int,
|
||||
layers: int, heads: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=width,
|
||||
kernel_size=patch_size,
|
||||
stride=patch_size,
|
||||
bias=False)
|
||||
|
||||
scale = width**-0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn(
|
||||
(input_resolution // patch_size)**2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
|
||||
self.transformer = Transformer(width, layers, heads)
|
||||
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1],
|
||||
-1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
class_token = self.class_embedding.to(x.dtype) + torch.zeros(
|
||||
x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
|
||||
x = torch.cat([class_token, x], dim=1)
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
x = self.ln_post(x[:, 0, :])
|
||||
|
||||
if self.proj is not None:
|
||||
x = x @ self.proj
|
||||
|
||||
return x
|
||||
|
||||
|
||||
def build_model(state_dict: dict):
|
||||
vision_width = state_dict['visual.conv1.weight'].shape[0]
|
||||
vision_layers = len([
|
||||
k for k in state_dict.keys()
|
||||
if k.startswith('visual.') and k.endswith('.attn.in_proj_weight')
|
||||
])
|
||||
vision_patch_size = state_dict['visual.conv1.weight'].shape[-1]
|
||||
grid_size = round(
|
||||
(state_dict['visual.positional_embedding'].shape[0] - 1)**0.5)
|
||||
image_resolution = vision_patch_size * grid_size
|
||||
|
||||
embed_dim = state_dict['text_projection'].shape[1]
|
||||
context_length = state_dict['positional_embedding'].shape[0]
|
||||
vocab_size = state_dict['token_embedding.weight'].shape[0]
|
||||
transformer_width = state_dict['ln_final.weight'].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(
|
||||
set(
|
||||
k.split('.')[2] for k in state_dict
|
||||
if k.startswith('transformer.resblocks')))
|
||||
|
||||
model = CLIP(embed_dim, image_resolution, vision_layers, vision_width,
|
||||
vision_patch_size, context_length, vocab_size,
|
||||
transformer_width, transformer_heads, transformer_layers)
|
||||
|
||||
for key in ['input_resolution', 'context_length', 'vocab_size']:
|
||||
if key in state_dict:
|
||||
del state_dict[key]
|
||||
|
||||
model.load_state_dict(state_dict)
|
||||
return model.eval()
|
||||
|
||||
|
||||
def load_clip(name: str,
|
||||
device: Union[str, torch.device] = 'cuda'
|
||||
if torch.cuda.is_available() else 'cpu',
|
||||
jit=True):
|
||||
jit = False
|
||||
model_path = name
|
||||
try:
|
||||
model = torch.jit.load(
|
||||
model_path, map_location=device if jit else 'cpu').eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
if jit:
|
||||
warnings.warn(
|
||||
f'File {model_path} is not a JIT archive. Loading as a state dict instead'
|
||||
)
|
||||
jit = False
|
||||
state_dict = torch.load(model_path, map_location='cpu')
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
if str(device) == 'cpu':
|
||||
model.float()
|
||||
return model
|
||||
|
||||
device_holder = torch.jit.trace(
|
||||
lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_node = [
|
||||
n for n in device_holder.graph.findAllNodes('prim::Constant')
|
||||
if 'Device' in repr(n)
|
||||
][-1]
|
||||
|
||||
def patch_device(module):
|
||||
graphs = [module.graph] if hasattr(module, 'graph') else []
|
||||
if hasattr(module, 'forward1'):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes('prim::Constant'):
|
||||
if 'value' in node.attributeNames() and str(
|
||||
node['value']).startswith('cuda'):
|
||||
node.copyAttributes(device_node)
|
||||
|
||||
model.apply(patch_device)
|
||||
patch_device(model.encode_image)
|
||||
patch_device(model.encode_text)
|
||||
|
||||
if str(device) == 'cpu':
|
||||
float_holder = torch.jit.trace(
|
||||
lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode('aten::to').inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
||||
def patch_float(module):
|
||||
graphs = [module.graph] if hasattr(module, 'graph') else []
|
||||
if hasattr(module, 'forward1'):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes('aten::to'):
|
||||
inputs = list(node.inputs())
|
||||
for i in [1, 2]:
|
||||
if inputs[i].node()['value'] == 5:
|
||||
inputs[i].node().copyAttributes(float_node)
|
||||
|
||||
model.apply(patch_float)
|
||||
patch_float(model.encode_image)
|
||||
patch_float(model.encode_text)
|
||||
|
||||
model.float()
|
||||
|
||||
return model
|
||||
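A hedged usage sketch for load_clip above; the checkpoint path is a placeholder for a standard CLIP ViT-B/32 weight file, and the feature size in the comment assumes ViT-B/32.

import torch

clip_model = load_clip('/path/to/ViT-B-32.pt', device='cpu', jit=False)
with torch.no_grad():
    image_feat = clip_model.encode_image(torch.randn(1, 3, 224, 224))  # (1, 512) for ViT-B/32
# Note: encode_text in this adapted copy returns the eot-token features before the
# text projection, since the projection line is commented out above.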
156
modelscope/models/multi_modal/soonet/model.py
Normal file
@@ -0,0 +1,156 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import os

import torch
import torch.nn as nn

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from .blocks import (BboxRegressor, Q2VRankerStage1, Q2VRankerStage2,
                     V2QRankerStage1, V2QRankerStage2)
from .swin_transformer import SwinTransformerV2_1D


@MODELS.register_module(
    Tasks.video_temporal_grounding, module_name=Models.soonet)
class SOONet(TorchModel):
    """
    The implementation of 'Scanning Only Once: An End-to-end Framework for Fast Temporal Grounding
    in Long Videos'. The model is dynamically initialized with the following parts:
        - q2v_stage1: calculate qv_ctx_score.
        - v2q_stage1: calculate vq_ctx_score.
        - q2v_stage2: calculate qv_ctn_score.
        - v2q_stage2: calculate vq_ctn_score.
        - regressor: predict the offset of bounding box for each candidate anchor.
    """

    def __init__(self, model_dir: str, *args, **kwargs):
        """
        Initialize the SOONet model.

        Args:
            model_dir: model id or path
        """
        super().__init__()
        config_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
        self.config = Config.from_file(config_path).hyperparams
        nscales = self.config.nscales
        hidden_dim = self.config.hidden_dim
        snippet_length = self.config.snippet_length
        self.enable_stage2 = self.config.enable_stage2
        self.stage2_topk = self.config.stage2_topk
        self.nscales = nscales

        self.video_encoder = SwinTransformerV2_1D(
            patch_size=snippet_length,
            in_chans=hidden_dim,
            embed_dim=hidden_dim,
            depths=[2] * nscales,
            num_heads=[8] * nscales,
            window_size=[64] * nscales,
            mlp_ratio=2.,
            qkv_bias=True,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=nn.LayerNorm,
            patch_norm=True,
            use_checkpoint=False,
            pretrained_window_sizes=[0] * nscales)

        self.q2v_stage1 = Q2VRankerStage1(nscales, hidden_dim)
        self.v2q_stage1 = V2QRankerStage1(nscales, hidden_dim)
        if self.enable_stage2:
            self.q2v_stage2 = Q2VRankerStage2(nscales, hidden_dim,
                                              snippet_length)
            self.v2q_stage2 = V2QRankerStage2(nscales, hidden_dim)
        self.regressor = BboxRegressor(hidden_dim, self.enable_stage2)

        # Load trained weights
        model_path = os.path.join(model_dir,
                                  'SOONet_MAD_VIT-B-32_4Scale_10C.pth')
        state_dict = torch.load(model_path, map_location='cpu')['model']
        self.load_state_dict(state_dict, strict=True)

    def forward(self, **kwargs):
        if self.training:
            return self.forward_train(**kwargs)
        else:
            return self.forward_test(**kwargs)

    def forward_train(self, **kwargs):
        raise NotImplementedError

    def forward_test(self,
                     query_feats=None,
                     video_feats=None,
                     start_ts=None,
                     end_ts=None,
                     scale_boundaries=None,
                     **kwargs):
        """
        Obtain matching scores and bbox bias of the top-k candidate anchors, with
        pre-extracted query features and video features as input.

        Args:
            query_feats: the pre-extracted text features.
            video_feats: the pre-extracted video features.
            start_ts: the start timestamps of pre-defined multi-scale anchors.
            end_ts: the end timestamps of pre-defined multi-scale anchors.
            scale_boundaries: the begin and end anchor index for each scale in start_ts and end_ts.

        Returns:
            [final_scores, bbox_bias, starts, ends]
        """
        sent_feat = query_feats
        ctx_feats = self.video_encoder(video_feats.permute(0, 2, 1))
        qv_ctx_scores = self.q2v_stage1(ctx_feats, sent_feat)
        if self.enable_stage2:
            hit_indices = list()
            starts = list()
            ends = list()
            filtered_ctx_feats = list()
            for i in range(self.nscales):
                _, indices = torch.sort(
                    qv_ctx_scores[i], dim=1, descending=True)
                indices, _ = torch.sort(
                    torch.LongTensor(
                        list(
                            set(indices[:, :self.stage2_topk].flatten().cpu().
                                numpy().tolist()))))
                indices = indices.to(video_feats.device)
                hit_indices.append(indices)

                filtered_ctx_feats.append(
                    torch.index_select(ctx_feats[i], 1, indices))

                scale_first = scale_boundaries[i]
                scale_last = scale_boundaries[i + 1]

                filtered_start = torch.index_select(
                    start_ts[scale_first:scale_last], 0, indices)
                filtered_end = torch.index_select(
                    end_ts[scale_first:scale_last], 0, indices)
                starts.append(filtered_start)
                ends.append(filtered_end)

            starts = torch.cat(starts, dim=0)
            ends = torch.cat(ends, dim=0)

            qv_merge_scores, qv_ctn_scores, ctn_feats = self.q2v_stage2(
                video_feats, sent_feat, hit_indices, qv_ctx_scores)
            ctx_feats = filtered_ctx_feats
        else:
            ctn_feats = None
            qv_merge_scores = qv_ctx_scores
            starts = start_ts
            ends = end_ts

        bbox_bias = self.regressor(ctx_feats, ctn_feats, sent_feat)
        final_scores = torch.sigmoid(torch.cat(qv_merge_scores, dim=1))

        return final_scores, bbox_bias, starts, ends
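forward_test above consumes pre-computed multi-scale anchors (start_ts, end_ts, scale_boundaries). As a hedged illustration of how such anchors line up with the snippet_length * 2**i anchor lengths used in Q2VRankerStage2, here is a possible helper; the function name and the frames-per-second value are assumptions made for the sketch, not part of this commit.

import torch

def build_anchors(num_frames, nscales=4, snippet_length=10, fps=5.0):
    # Tile scale i with non-overlapping anchors of snippet_length * 2**i frames.
    starts, ends, boundaries = [], [], [0]
    for i in range(nscales):
        anchor_len = snippet_length * (2 ** i)
        n_anchors = num_frames // anchor_len
        s = torch.arange(n_anchors, dtype=torch.float32) * anchor_len / fps
        starts.append(s)
        ends.append(s + anchor_len / fps)
        boundaries.append(boundaries[-1] + n_anchors)
    return torch.cat(starts), torch.cat(ends), torch.tensor(boundaries)

start_ts, end_ts, scale_boundaries = build_anchors(num_frames=800)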
623
modelscope/models/multi_modal/soonet/swin_transformer.py
Normal file
@@ -0,0 +1,623 @@
|
||||
# The implementation is adopted from Swin-Transformer-1D, made publicly available
|
||||
# at https://github.com/meraks/Swin-Transformer-1D
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
|
||||
def drop_path(x,
|
||||
drop_prob: float = 0.,
|
||||
training: bool = False,
|
||||
scale_by_keep: bool = True):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
"""
|
||||
if drop_prob == 0. or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0], ) + (1, ) * (
|
||||
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
"""
|
||||
|
||||
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
self.scale_by_keep = scale_by_keep
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
|
||||
|
||||
def extra_repr(self):
|
||||
return f'drop_prob={round(self.drop_prob,3):0.3f}'
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def window_partition(x, window_size):
|
||||
"""
|
||||
Args:
|
||||
x: (B, L, C)
|
||||
window_size (int): window size
|
||||
Returns:
|
||||
windows: (num_windows*B, window_size, C)
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
x = x.view(B, L // window_size, window_size, C)
|
||||
windows = x.permute(0, 1, 2, 3).contiguous().view(-1, window_size, C)
|
||||
return windows
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, L):
|
||||
"""
|
||||
Args:
|
||||
windows: (num_windows*B, window_size, window_size, C)
|
||||
window_size (int): Window size
|
||||
L (int): sequence length
|
||||
Returns:
|
||||
x: (B, L, C)
|
||||
"""
|
||||
B = int(windows.shape[0] / (L / window_size))
|
||||
x = windows.view(B, L // window_size, window_size, -1)
|
||||
x = x.permute(0, 1, 2, 3).contiguous().view(B, L, -1)
|
||||
return x
|
||||
|
||||
|
||||
class WindowAttention_1D(nn.Module):
|
||||
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
|
||||
It supports both of shifted and non-shifted window.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
window_size (int): The length of the attention window.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
|
||||
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
|
||||
pretrained_window_size (int): The window length used in pre-training.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
window_size,
|
||||
num_heads,
|
||||
qkv_bias=True,
|
||||
attn_drop=0.,
|
||||
proj_drop=0.,
|
||||
pretrained_window_size=0):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.window_size = window_size # Wl
|
||||
self.pretrained_window_size = pretrained_window_size
|
||||
self.num_heads = num_heads
|
||||
|
||||
self.logit_scale = nn.Parameter(
|
||||
torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
|
||||
|
||||
# mlp to generate continuous relative position bias
|
||||
self.cpb_mlp = nn.Sequential(
|
||||
nn.Linear(1, 512, bias=True), nn.ReLU(inplace=True),
|
||||
nn.Linear(512, num_heads, bias=False))
|
||||
|
||||
# get relative_coords_table
|
||||
relative_coords_l = torch.arange(
|
||||
-(self.window_size - 1), self.window_size, dtype=torch.float32)
|
||||
relative_coords_table = torch.stack(
|
||||
torch.meshgrid([relative_coords_l], indexing='ij')).permute(
|
||||
1, 0).contiguous().unsqueeze(0) # 1, 2*Wl-1, 1
|
||||
if pretrained_window_size > 0:
|
||||
relative_coords_table[:, :, :] /= (pretrained_window_size - 1)
|
||||
else:
|
||||
relative_coords_table[:, :, :] /= (self.window_size - 1)
|
||||
relative_coords_table *= 8 # normalize to -8, 8
|
||||
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
|
||||
torch.abs(relative_coords_table) + 1.0) / np.log2(8)
|
||||
|
||||
self.register_buffer('relative_coords_table', relative_coords_table)
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_l = torch.arange(self.window_size)
|
||||
coords = torch.stack(torch.meshgrid([coords_l],
|
||||
indexing='ij')) # 1, Wl
|
||||
coords_flatten = torch.flatten(coords, 1) # 1, Wl
|
||||
relative_coords = coords_flatten[:, :,
|
||||
None] - coords_flatten[:,
|
||||
None, :] # 1, Wl, Wl
|
||||
relative_coords = relative_coords.permute(1, 2,
|
||||
0).contiguous() # Wl, Wl, 1
|
||||
relative_coords[:, :,
|
||||
0] += self.window_size - 1 # shift to start from 0
|
||||
relative_position_index = relative_coords.sum(-1) # Wl, Wl
|
||||
self.register_buffer('relative_position_index',
|
||||
relative_position_index)
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
self.softmax = nn.Softmax(dim=-1)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
"""
|
||||
Args:
|
||||
x: input features with shape of (num_windows*B, N, C)
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wl, Wl) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat(
|
||||
(self.q_bias,
|
||||
torch.zeros_like(self.v_bias,
|
||||
requires_grad=False), self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[
|
||||
2] # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# cosine attention
|
||||
attn = (
|
||||
F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
||||
logit_scale = torch.clamp(
|
||||
self.logit_scale,
|
||||
max=torch.log(torch.tensor(1. / 0.01, device=attn.device))).exp()
|
||||
attn = attn * logit_scale
|
||||
|
||||
relative_position_bias_table = self.cpb_mlp(
|
||||
self.relative_coords_table).view(-1, self.num_heads)
|
||||
relative_position_bias = relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)].view(
|
||||
self.window_size, self.window_size, -1)  # Wl, Wl, nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1).contiguous() # nH, Wl, Wl
|
||||
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
nW = mask.shape[0]
|
||||
attn = attn.view(B_ // nW, nW, self.num_heads, N,
|
||||
N) + mask.unsqueeze(1).unsqueeze(0)
|
||||
attn = attn.view(-1, self.num_heads, N, N)
|
||||
attn = self.softmax(attn)
|
||||
else:
|
||||
attn = self.softmax(attn)
|
||||
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
def compute_mask(L, window_size, shift_size):
|
||||
Lp = int(np.ceil(L / window_size)) * window_size
|
||||
img_mask = torch.zeros((1, Lp, 1)) # 1 Lp 1
|
||||
pad_size = int(Lp - L)
|
||||
if (pad_size == 0) or (pad_size + shift_size == window_size):
|
||||
segs = (slice(-window_size), slice(-window_size, -shift_size),
|
||||
slice(-shift_size, None))
|
||||
elif pad_size + shift_size > window_size:
|
||||
seg1 = int(window_size * 2 - L + shift_size)
|
||||
segs = (slice(-seg1), slice(-seg1, -window_size),
|
||||
slice(-window_size, -shift_size), slice(-shift_size, None))
|
||||
elif pad_size + shift_size < window_size:
|
||||
seg1 = int(window_size * 2 - L + shift_size)
|
||||
segs = (slice(-window_size), slice(-window_size, -seg1),
|
||||
slice(-seg1, -shift_size), slice(-shift_size, None))
|
||||
cnt = 0
|
||||
for d in segs:
|
||||
img_mask[:, d, :] = cnt
|
||||
cnt += 1
|
||||
mask_windows = window_partition(img_mask, window_size) # nW, ws, 1
|
||||
mask_windows = mask_windows.squeeze(-1) # nW, ws
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
return attn_mask
|
||||
|
||||
|
||||
class SwinTransformerBlock_1D(nn.Module):
|
||||
r""" Swin Transformer Block.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Window size.
|
||||
shift_size (int): Shift size for SW-MSA.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
||||
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
pretrained_window_size (int): Window size in pre-training.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
window_size=7,
|
||||
shift_size=0,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
pretrained_window_size=0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
self.window_size = window_size
|
||||
self.shift_size = shift_size
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
assert 0 <= self.shift_size < self.window_size, 'shift_size must be in [0, window_size)'
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = WindowAttention_1D(
|
||||
dim,
|
||||
window_size=self.window_size,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
pretrained_window_size=pretrained_window_size)
|
||||
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
|
||||
def forward(self, x):
|
||||
B, L, C = x.shape
|
||||
|
||||
attn_mask = compute_mask(L, self.window_size,
|
||||
self.shift_size).to(x.device)
|
||||
|
||||
shortcut = x
|
||||
# x = x.view(B, L, C)
|
||||
|
||||
# padding x
|
||||
pad_r = (self.window_size - L % self.window_size) % self.window_size
|
||||
x = F.pad(x, (0, 0, 0, pad_r))
|
||||
_, Lp, _ = x.shape
|
||||
|
||||
# cyclic shift
|
||||
if self.shift_size > 0:
|
||||
shifted_x = torch.roll(x, shifts=(-self.shift_size), dims=(1))
|
||||
else:
|
||||
shifted_x = x
|
||||
|
||||
# partition windows
|
||||
x_windows = window_partition(shifted_x,
|
||||
self.window_size) # nW*B, window_size, C
|
||||
x_windows = x_windows.view(-1, self.window_size,
|
||||
C)  # nW*B, window_size, C
|
||||
|
||||
# W-MSA/SW-MSA
|
||||
attn_windows = self.attn(
|
||||
x_windows, mask=attn_mask) # nW*B, window_size, C
|
||||
|
||||
# merge windows
|
||||
attn_windows = attn_windows.view(-1, self.window_size, C)
|
||||
shifted_x = window_reverse(attn_windows, self.window_size,
|
||||
Lp) # B L' C
|
||||
|
||||
# reverse cyclic shift
|
||||
if self.shift_size > 0:
|
||||
x = torch.roll(shifted_x, shifts=(self.shift_size), dims=(1))
|
||||
else:
|
||||
x = shifted_x
|
||||
x = x.view(B, Lp, C)
|
||||
# reverse padding x
|
||||
x = x[:, :L, :].contiguous()
|
||||
x = shortcut + self.drop_path(self.norm1(x))
|
||||
|
||||
# FFN
|
||||
x = x + self.drop_path(self.norm2(self.mlp(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
|
||||
""" Patch Merging Layer
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
"""
|
||||
|
||||
def __init__(self, dim, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
# self.reduction = nn.Linear(2 * dim, dim, bias=False)
|
||||
# self.norm = norm_layer(2 * dim)
|
||||
|
||||
def forward(self, x):
|
||||
""" Forward function.
|
||||
Args:
|
||||
x: Input feature, tensor size (B, L, C).
|
||||
"""
|
||||
B, L, C = x.shape
|
||||
x = F.pad(x, (0, 0, 0, L % 2))
|
||||
|
||||
x0 = x[:, 0::2, :] # B L/2 C
|
||||
x1 = x[:, 1::2, :] # B L/2 C
|
||||
|
||||
x = torch.maximum(x0, x1)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
|
||||
""" A basic Swin Transformer layer for one stage.
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
depth (int): Number of blocks.
|
||||
num_heads (int): Number of attention heads.
|
||||
window_size (int): Local window size.
|
||||
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
||||
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
||||
drop (float, optional): Dropout rate. Default: 0.0
|
||||
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
||||
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
||||
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
||||
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
||||
pretrained_window_size (int): Local window size in pre-training.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
depth,
|
||||
num_heads,
|
||||
window_size,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
norm_layer=nn.LayerNorm,
|
||||
downsample=None,
|
||||
use_checkpoint=False,
|
||||
pretrained_window_size=0):
|
||||
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
# build blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
SwinTransformerBlock_1D(
|
||||
dim=dim,
|
||||
num_heads=num_heads,
|
||||
window_size=window_size,
|
||||
shift_size=0 if (i % 2 == 0) else window_size // 2,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop,
|
||||
attn_drop=attn_drop,
|
||||
drop_path=drop_path[i]
|
||||
if isinstance(drop_path, list) else drop_path,
|
||||
norm_layer=norm_layer,
|
||||
pretrained_window_size=pretrained_window_size)
|
||||
for i in range(depth)
|
||||
])
|
||||
|
||||
# patch merging layer
|
||||
if downsample is not None:
|
||||
self.downsample = downsample(dim=dim, norm_layer=norm_layer)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
for blk in self.blocks:
|
||||
if self.use_checkpoint:
|
||||
x = checkpoint.checkpoint(blk, x)
|
||||
else:
|
||||
x = blk(x)
|
||||
|
||||
proposal = x
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
return x, proposal
|
||||
|
||||
def _init_respostnorm(self):
|
||||
for blk in self.blocks:
|
||||
nn.init.constant_(blk.norm1.bias, 0)
|
||||
nn.init.constant_(blk.norm1.weight, 0)
|
||||
nn.init.constant_(blk.norm2.bias, 0)
|
||||
nn.init.constant_(blk.norm2.weight, 0)
|
||||
|
||||
|
||||
class PatchEmbed1D(nn.Module):
|
||||
""" Video to Patch Embedding.
|
||||
Args:
|
||||
patch_size (int): Patch token size. Default: 4.
|
||||
in_chans (int): Number of input video channels. Default: 3.
|
||||
embed_dim (int): Number of linear projection output channels. Default: 96.
|
||||
norm_layer (nn.Module, optional): Normalization layer. Default: None
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
patch_size=4,
|
||||
in_chans=32,
|
||||
embed_dim=128,
|
||||
norm_layer=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.proj = nn.Conv1d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
if norm_layer is not None:
|
||||
self.norm = norm_layer(embed_dim)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
# padding
|
||||
_, _, L = x.size()
|
||||
pad_r = (self.patch_size - L % self.patch_size) % self.patch_size
|
||||
x = F.pad(x, (0, pad_r))
|
||||
x = self.proj(x) # B C Wl
|
||||
if self.norm is not None:
|
||||
# Wl = x.size(2)
|
||||
x = x.transpose(1, 2)
|
||||
x = self.norm(x)
|
||||
# x = x.transpose(1, 2).view(-1, self.embed_dim, Wl)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class SwinTransformerV2_1D(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
patch_size=4,
|
||||
in_chans=32,
|
||||
embed_dim=96,
|
||||
depths=[2, 2, 6, 2],
|
||||
num_heads=[3, 6, 12, 24],
|
||||
window_size=[7, 7, 7, 7],
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=True,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.1,
|
||||
norm_layer=nn.LayerNorm,
|
||||
patch_norm=True,
|
||||
use_checkpoint=False,
|
||||
pretrained_window_sizes=[0, 0, 0, 0],
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
|
||||
self.num_layers = len(depths)
|
||||
self.embed_dim = embed_dim
|
||||
self.patch_norm = patch_norm
|
||||
self.num_features = int(embed_dim * 2**(self.num_layers - 1))
|
||||
self.mlp_ratio = mlp_ratio
|
||||
|
||||
# split image into non-overlapping patches
|
||||
self.patch_embed = PatchEmbed1D(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
norm_layer=norm_layer if self.patch_norm else None)
|
||||
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
# stochastic depth
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
|
||||
] # stochastic depth decay rule
|
||||
|
||||
# build layers
|
||||
self.layers = nn.ModuleList()
|
||||
for i_layer in range(self.num_layers):
|
||||
layer = BasicLayer(
|
||||
dim=embed_dim,
|
||||
depth=depths[i_layer],
|
||||
num_heads=num_heads[i_layer],
|
||||
window_size=window_size[i_layer],
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
|
||||
norm_layer=norm_layer,
|
||||
downsample=PatchMerging if
|
||||
(i_layer < self.num_layers - 1) else None,
|
||||
use_checkpoint=use_checkpoint,
|
||||
pretrained_window_size=pretrained_window_sizes[i_layer])
|
||||
self.layers.append(layer)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
for bly in self.layers:
|
||||
bly._init_respostnorm()
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay(self):
|
||||
return {'absolute_pos_embed'}
|
||||
|
||||
@torch.jit.ignore
|
||||
def no_weight_decay_keywords(self):
|
||||
return {'cpb_mlp', 'logit_scale', 'relative_position_bias_table'}
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
x = self.pos_drop(x)
|
||||
|
||||
proposals = list()
|
||||
for layer in self.layers:
|
||||
x, proposal = layer(x)
|
||||
proposals.append(proposal)
|
||||
|
||||
return proposals
|
||||
|
||||
def forward(self, x):
|
||||
return self.forward_features(x)
|
||||
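A short smoke-test sketch for the 1D Swin encoder above, configured roughly the way SOONet does in model.py; the batch size, feature dimension and frame count are illustrative.

import torch
import torch.nn as nn

encoder = SwinTransformerV2_1D(
    patch_size=10, in_chans=512, embed_dim=512,
    depths=[2, 2], num_heads=[8, 8], window_size=[64, 64],
    mlp_ratio=2., norm_layer=nn.LayerNorm, pretrained_window_sizes=[0, 0])

frames = torch.randn(1, 512, 400)  # (batch, feature_dim, num_frames)
ctx_feats = encoder(frames)        # list with one feature tensor per scale
print([tuple(f.shape) for f in ctx_feats])  # e.g. [(1, 40, 512), (1, 20, 512)]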
152
modelscope/models/multi_modal/soonet/tokenizer.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# The implementation is adopted from CLIP, made publicly available
|
||||
# under MIT License at https://github.com/openai/CLIP
|
||||
|
||||
import gzip
|
||||
import html
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
bs = list(range(ord('!'),
|
||||
ord('~') + 1)) + list(range(
|
||||
ord('¡'),
|
||||
ord('¬') + 1)) + list(range(ord('®'),
|
||||
ord('ÿ') + 1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8 + n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
|
||||
def __init__(self, bpe_path):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode('utf-8').split('\n')
|
||||
merges = merges[1:49152 - 256 - 2 + 1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v + '</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {
|
||||
'<|startoftext|>': '<|startoftext|>',
|
||||
'<|endoftext|>': '<|endoftext|>'
|
||||
}
|
||||
self.pat = re.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + (token[-1] + '</w>', )
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token + '</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(
|
||||
pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except ValueError:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word) - 1 and word[
|
||||
i + 1] == second:
|
||||
new_word.append(first + second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b]
|
||||
for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token]
|
||||
for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode(
|
||||
'utf-8', errors='replace').replace('</w>', ' ')
|
||||
return text
|
||||
|
||||
def tokenize(self, texts, context_length=77):
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
sot_token = self.encoder['<|startoftext|>']
|
||||
eot_token = self.encoder['<|endoftext|>']
|
||||
all_tokens = [[sot_token] + self.encode(text) + [eot_token]
|
||||
for text in texts]
|
||||
result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
|
||||
|
||||
for i, tokens in enumerate(all_tokens):
|
||||
if len(tokens) > context_length:
|
||||
tokens = tokens[:context_length]
|
||||
tokens[-1] = eot_token
|
||||
|
||||
result[i, :len(tokens)] = torch.tensor(tokens)
|
||||
|
||||
return result
|
||||
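A usage sketch for SimpleTokenizer above; the BPE vocabulary path is a placeholder (upstream CLIP distributes it as bpe_simple_vocab_16e6.txt.gz).

tokenizer = SimpleTokenizer('/path/to/bpe_simple_vocab_16e6.txt.gz')
tokens = tokenizer.tokenize(['a person opens the door'], context_length=77)
print(tokens.shape)  # torch.Size([1, 77]); <|startoftext|> ... <|endoftext|>, zero-padded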
58
modelscope/models/multi_modal/soonet/utils.py
Normal file
@@ -0,0 +1,58 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import copy

import decord
import numpy as np
from decord import VideoReader, cpu
from decord._ffi.base import DECORDError
from tqdm import tqdm


def decode_video(video_path, target_fps=5):
    """
    Decode the video at 'video_path' and return frames sampled at target_fps.
    The default value of target_fps is 5.

    Args:
        video_path: the absolute path of the video.
        target_fps: the number of sampled video frames per second.

    Returns:
        [imgs, duration]
    """
    decord.bridge.set_bridge('torch')
    vr = VideoReader(video_path, ctx=cpu(0))
    cur_fps = vr.get_avg_fps()
    if cur_fps > target_fps:
        interval = float(cur_fps) / float(target_fps)
        start = float(interval) / 2.
    else:
        interval = 1.0
        start = 0.0

    vid_length = len(vr)
    duration = vid_length / cur_fps
    sampled_idxs = np.clip(
        np.round(np.arange(start, float(vid_length), step=interval)), 0,
        vid_length - 1).astype(np.int32)

    imgs = list()
    for i in tqdm(sampled_idxs):
        bias = 0
        # avoid broken frames
        while bias <= 10:
            try:
                img = vr[i - bias]
                break
            except DECORDError:
                bias += 1
        if bias > 10:
            img = copy.deepcopy(imgs[-1])
            imgs.append(img)
        else:
            img = img / 255.
            img = img.permute(2, 0, 1)
            imgs.append(img)

    return imgs, duration

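As a worked example of the sampling above: with a 30 fps source and the default target_fps=5, interval is 6.0 and start is 3.0, so roughly every sixth frame is decoded. A minimal usage sketch, assuming a placeholder video path:

# Illustrative sketch only; the video path is a placeholder.
from modelscope.models.multi_modal.soonet import decode_video

imgs, duration = decode_video('/path/to/video.mp4', target_fps=5)
# imgs: list of CHW float tensors scaled to [0, 1]; duration: video length in seconds.
print(len(imgs), duration)
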
@@ -57,6 +57,7 @@ class OutputKeys(object):
    MATCHES = 'matches'
    PCD12 = 'pcd12'
    PCD12_ALIGN = 'pcd12_align'
    TBOUNDS = 'tbounds'


TASK_OUTPUTS = {
@@ -1105,6 +1106,7 @@ TASK_OUTPUTS = {
    Tasks.document_grounded_dialog_generate: [OutputKeys.TEXT],
    Tasks.document_grounded_dialog_rerank: [OutputKeys.OUTPUT],
    Tasks.document_grounded_dialog_retrieval: [OutputKeys.OUTPUT],
    Tasks.video_temporal_grounding: [OutputKeys.SCORES, OutputKeys.TBOUNDS],
}

@@ -19,6 +19,7 @@ if TYPE_CHECKING:
    from .video_captioning_pipeline import VideoCaptioningPipeline
    from .video_question_answering_pipeline import VideoQuestionAnsweringPipeline
    from .diffusers_wrapped import StableDiffusionWrapperPipeline, ChineseStableDiffusionPipeline
    from .soonet_video_temporal_grounding_pipeline import SOONetVideoTemporalGroundingPipeline
    from .text_to_video_synthesis_pipeline import TextToVideoSynthesisPipeline
else:
    _import_structure = {
@@ -41,6 +42,8 @@ else:
        ['VideoQuestionAnsweringPipeline'],
        'diffusers_wrapped':
        ['StableDiffusionWrapperPipeline', 'ChineseStableDiffusionPipeline'],
        'soonet_video_temporal_grounding_pipeline':
        ['SOONetVideoTemporalGroundingPipeline'],
        'text_to_video_synthesis_pipeline': ['TextToVideoSynthesisPipeline'],
    }

@@ -0,0 +1,222 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.

import os
from typing import Any, Dict

import numpy as np
import torch
from torchvision import transforms

from modelscope.metainfo import Pipelines
from modelscope.models.multi_modal.soonet import (SimpleTokenizer,
                                                  decode_video, load_clip)
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.video_temporal_grounding,
    module_name=Pipelines.soonet_video_temporal_grounding)
class SOONetVideoTemporalGroundingPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """
        SOONet pipeline for video temporal grounding.

        Examples:

        >>> from modelscope.pipelines import pipeline

        >>> soonet_pipeline = pipeline("video-temporal-grounding", "damo/multi-modal_soonet_video-temporal-grounding")
        >>> soonet_pipeline(
                ('a man takes food out of the refrigerator.',
                 'soonet_video_temporal_grounding_test_video.mp4'))

        >>> {
        >>>     "scores": [
        >>>         0.80661213,
        >>>         0.8060084,
        >>>         0.8018835,
        >>>         0.79837507,
        >>>         0.7963626,
        >>>         0.7949013,
        >>>         0.79353744,
        >>>         0.79287416,
        >>>         0.79066336,
        >>>         0.79027915
        >>>     ],
        >>>     "tbounds": [
        >>>         [0, 2.9329566955566406],
        >>>         [1.0630402565002441, 4.9339457750320435],
        >>>         [300.96919429302216, 304.8546848297119],
        >>>         [302.96981167793274, 306.7714672088623],
        >>>         [0, 5.0421366691589355],
        >>>         [304.9119266271591, 308.7636929154396],
        >>>         [258.96133184432983, 262.805901825428],
        >>>         [122.9599289894104, 126.86622190475464],
        >>>         [126.94010400772095, 130.8090701699257],
        >>>         [121.04773849248886, 124.79261875152588]
        >>>     ]
        >>> }
        """
        super().__init__(model=model, **kwargs)

        self.model_dir = model
        self.clip = load_clip(os.path.join(self.model_dir,
                                           'ViT-B-32.pt')).to(self.device)
        self.model = self.model.float().to(self.device)
        self.model.eval()

        # Load configuration from file
        config_path = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
        self.config = Config.from_file(config_path).hyperparams
        self.nscales = self.config.nscales
        self.snippet_length = self.config.snippet_length
        self.max_anchor_length = self.snippet_length * 2**(self.nscales - 1)
        self.topk = 10
        self.fps = 5
        # Define image transform
        self.img_transform = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                 (0.26862954, 0.26130258, 0.27577711)),
        ])
        logger.info('Init transform done')

        # Init tokenizer
        bpe_path = os.path.join(self.model_dir, 'bpe_simple_vocab_16e6.txt.gz')
        self.tokenizer = SimpleTokenizer(bpe_path)
        logger.info('Init tokenizer done')

    def pad(self, arr, pad_len):
        new_arr = np.zeros((pad_len, ), dtype=float)
        new_arr[:len(arr)] = arr
        return new_arr

    def preprocess(self, input: Input, **preprocess_params) -> Dict[str, Any]:
        text, video_name = input
        video_path = os.path.join(self.model_dir, video_name)
        imgs, duration = decode_video(video_path, self.fps)
        trans_imgs = list()
        for i, img in enumerate(imgs):
            trans_imgs.append(self.img_transform(img))
        imgs = trans_imgs
        token_ids = self.tokenizer.tokenize(text).to(
            self.device, non_blocking=True)
        # Get the start and end timestamps of the anchors at each scale
        start_ts, end_ts, scale_boundaries = list(), list(), [0]
        ori_video_length = len(imgs)
        pad_video_length = int(
            np.math.ceil(ori_video_length / self.max_anchor_length)
            * self.max_anchor_length)
        for i in range(self.config.nscales):
            anchor_length = self.config.snippet_length * (2**i)
            pad_feat_length = pad_video_length // anchor_length
            nfeats = np.math.ceil(ori_video_length / anchor_length)
            s_times = np.arange(0, nfeats).astype(np.float32) * (
                anchor_length // self.fps)
            e_times = np.arange(1, nfeats + 1).astype(np.float32) * (
                anchor_length // self.fps)
            e_times = np.minimum(duration, e_times)
            start_ts.append(self.pad(s_times, pad_feat_length))
            end_ts.append(self.pad(e_times, pad_feat_length))
            scale_boundaries.append(scale_boundaries[-1] + pad_feat_length)

        start_ts = torch.from_numpy(np.concatenate(start_ts, axis=0))
        end_ts = torch.from_numpy(np.concatenate(end_ts, axis=0))
        scale_boundaries = torch.LongTensor(scale_boundaries)
        result = {
            'token_ids': token_ids,
            'imgs': torch.stack(imgs, dim=0),
            'start_ts': start_ts,
            'end_ts': end_ts,
            'scale_boundaries': scale_boundaries
        }
        return result

    def forward(self, input: Dict[str, Any],
                **forward_params) -> Dict[str, Any]:
        with torch.no_grad():
            video_feats = self.clip.encode_image(input['imgs'].to(self.device))
            query_feats = self.clip.encode_text(input['token_ids'].to(
                self.device))
            # Pad the frame features to a multiple of the longest anchor,
            # then score all anchors with the SOONet model
            ori_video_length, feat_dim = video_feats.shape
            pad_video_length = int(
                np.math.ceil(ori_video_length / self.max_anchor_length)
                * self.max_anchor_length)
            pad_video_feats = torch.zeros((pad_video_length, feat_dim),
                                          dtype=float)
            pad_video_feats[:ori_video_length, :] = video_feats
            final_scores, bbox_bias, starts, ends = self.model(
                query_feats=query_feats.float().to(self.device),
                video_feats=pad_video_feats.unsqueeze(0).float().to(
                    self.device),
                start_ts=input['start_ts'].float().to(self.device),
                end_ts=input['end_ts'].float().to(self.device),
                scale_boundaries=input['scale_boundaries'])
            # Keep the top-k anchors and refine their boundaries with the
            # predicted start/end biases
            final_scores = final_scores.cpu().numpy()
            bbox_bias = bbox_bias.cpu().numpy()
            starts = starts.cpu().numpy()
            ends = ends.cpu().numpy()
            pred_scores, pred_bboxes = list(), list()
            rank_id = np.argsort(final_scores[0])[::-1]
            for j in range(self.topk):
                if j >= len(rank_id):
                    break
                pred_scores.append(final_scores[0, rank_id[j]])
                ori_end = float(ends[rank_id[j]])
                ori_start = float(starts[rank_id[j]])
                duration = ori_end - ori_start
                sbias = bbox_bias[0, rank_id[j], 0]
                ebias = bbox_bias[0, rank_id[j], 1]
                pred_start = max(0, ori_start + sbias * duration)
                pred_end = ori_end + ebias * duration
                pred_bboxes.append([pred_start, pred_end])

        return {
            OutputKeys.SCORES: pred_scores,
            OutputKeys.TBOUNDS: pred_bboxes
        }

    def postprocess(self, inputs: Dict[str, Any],
                    **post_params) -> Dict[str, Any]:
        return inputs

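To make the anchor construction in preprocess concrete, below is a small standalone sketch of the per-scale start/end timestamps it produces. The values snippet_length=10, nscales=3, fps=5 and a 60-frame (12 s) clip are illustrative assumptions; the real values come from the model's configuration file.

# Illustrative sketch only; snippet_length, nscales and the clip length are assumed values.
import numpy as np

snippet_length, nscales, fps = 10, 3, 5   # assumed hyperparameters
ori_video_length, duration = 60, 12.0     # 60 sampled frames at 5 fps
max_anchor_length = snippet_length * 2**(nscales - 1)
# Frame features are padded to a multiple of the longest anchor before scoring.
pad_video_length = int(np.ceil(ori_video_length / max_anchor_length) * max_anchor_length)
print(f'frames padded from {ori_video_length} to {pad_video_length}')

for i in range(nscales):
    anchor_length = snippet_length * (2**i)
    nfeats = int(np.ceil(ori_video_length / anchor_length))
    s_times = np.arange(0, nfeats) * (anchor_length // fps)
    e_times = np.minimum(duration, np.arange(1, nfeats + 1) * (anchor_length // fps))
    print(f'scale {i}: {nfeats} anchors, starts={s_times.tolist()}, ends={e_times.tolist()}')

At inference the pipeline keeps only the top-k scored anchors and shifts each boundary by the predicted start/end bias times the anchor duration, as in forward above.
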
@@ -239,6 +239,7 @@ class MultiModalTasks(object):
    document_vl_embedding = 'document-vl-embedding'
    video_captioning = 'video-captioning'
    video_question_answering = 'video-question-answering'
    video_temporal_grounding = 'video-temporal-grounding'
    text_to_video_synthesis = 'text-to-video-synthesis'

34
tests/pipelines/test_soonet_video_temporal_grounding.py
Normal file
@@ -0,0 +1,34 @@
# Copyright 2022-2023 The Alibaba Fundamental Vision Team Authors. All rights reserved.
import unittest

from modelscope.models import Model
from modelscope.models.multi_modal.soonet import SOONet
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class SOONetVideoTemporalGroundingTest(unittest.TestCase,
                                       DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.video_temporal_grounding
        self.model_id = 'damo/multi-modal_soonet_video-temporal-grounding'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        soonet_pipeline = pipeline(self.task, self.model_id)
        result = soonet_pipeline(
            ('a man takes food out of the refrigerator.',
             'soonet_video_temporal_grounding_test_video.mp4'))
        print(f'soonet output: {result}.')

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_load_model_from_pretrained(self):
        model = Model.from_pretrained(self.model_id)
        self.assertTrue(model.__class__ == SOONet)


if __name__ == '__main__':
    unittest.main()