Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
add quadtree_image_matching
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11199857

* init quadtree code
* quadtree attention image matching can run in modelscope pipeline
* jit build quadtree attention
* update license info
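For context, a minimal sketch of how the new task could be driven through the ModelScope pipeline API; the model ID and the exact input format below are placeholders, not values confirmed by this commit.

# Hypothetical usage sketch; the model ID and input format are placeholders,
# not values confirmed by this commit.
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

matcher = pipeline(Tasks.image_matching, model='<quadtree-attention-image-matching-model-id>')

# Assumed input: a pair of image paths; the pipeline's preprocessor defines the exact format.
result = matcher([['data/test/images/image_matching1.jpg', 'data/test/images/image_matching2.jpg']])
print(result[OutputKeys.MATCHES])  # matched keypoints and confidences, see postprocess() below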
@@ -1 +1 @@
recursive-include modelscope/configs *.py
recursive-include modelscope/configs *.py *.cu *.h *.cpp
3
data/test/images/image_matching1.jpg
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:05ad1e66d7fee2f9e11766160522ad823f1fcc0ab8a5740a6c89b1765228ea32
size 334048
3
data/test/images/image_matching2.jpg
Normal file
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8ed3a68939b922bc2362b1d8051c24d2ca03be6a431fcc7c423e157012debd5a
size 424584
@@ -68,6 +68,7 @@ class Models(object):
    video_human_matting = 'video-human-matting'
    video_frame_interpolation = 'video-frame-interpolation'
    video_object_segmentation = 'video-object-segmentation'
    quadtree_attention_image_matching = 'quadtree-attention-image-matching'
    vision_middleware = 'vision-middleware'
    video_stabilization = 'video-stabilization'
    real_basicvsr = 'real-basicvsr'
@@ -287,6 +288,7 @@ class Pipelines(object):
    vision_middleware_multi_task = 'vision-middleware-multi-task'
    video_frame_interpolation = 'video-frame-interpolation'
    video_object_segmentation = 'video-object-segmentation'
    image_matching = 'image-matching'
    video_stabilization = 'video-stabilization'
    video_super_resolution = 'realbasicvsr-video-super-resolution'
    pointcloud_sceneflow_estimation = 'pointcloud-sceneflow-estimation'
@@ -97,11 +97,9 @@ class Model(ABC):
        prefetched = kwargs.get('model_prefetched')
        if prefetched is not None:
            kwargs.pop('model_prefetched')

        invoked_by = kwargs.get(Invoke.KEY)
        if invoked_by is not None:
            kwargs.pop(Invoke.KEY)
        else:
            invoked_by = Invoke.PRETRAINED

        if osp.exists(model_name_or_path):
@@ -6,7 +6,7 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
               crowd_counting, face_2d_keypoints, face_detection,
               face_generation, human_wholebody_keypoint, image_classification,
               image_color_enhance, image_colorization, image_denoise,
               image_inpainting, image_instance_segmentation,
               image_inpainting, image_instance_segmentation, image_matching,
               image_mvs_depth_estimation, image_panoptic_segmentation,
               image_portrait_enhancement, image_reid_person,
               image_semantic_segmentation, image_to_image_generation,
22
modelscope/models/cv/image_matching/__init__.py
Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .quadtree_attention_model import QuadTreeAttentionForImageMatching

else:
    _import_structure = {
        'quadtree_attention_model': ['QuadTreeAttentionForImageMatching'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
173
modelscope/models/cv/image_matching/config/default.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
_CN = CN()
|
||||
|
||||
_CN.LOFTR = CN()
|
||||
_CN.LOFTR.BACKBONE_TYPE = 'ResNetFPN'
|
||||
_CN.LOFTR.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
||||
_CN.LOFTR.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
||||
_CN.LOFTR.FINE_CONCAT_COARSE_FEAT = True
|
||||
|
||||
# 1. LoFTR-backbone (local feature CNN) config
|
||||
_CN.LOFTR.RESNETFPN = CN()
|
||||
_CN.LOFTR.RESNETFPN.INITIAL_DIM = 128
|
||||
_CN.LOFTR.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
|
||||
|
||||
# 2. LoFTR-coarse module config
|
||||
_CN.LOFTR.COARSE = CN()
|
||||
_CN.LOFTR.COARSE.D_MODEL = 256
|
||||
_CN.LOFTR.COARSE.D_FFN = 256
|
||||
_CN.LOFTR.COARSE.NHEAD = 8
|
||||
_CN.LOFTR.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
|
||||
_CN.LOFTR.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
|
||||
_CN.LOFTR.COARSE.TEMP_BUG_FIX = True
|
||||
_CN.LOFTR.COARSE.BLOCK_TYPE = 'quadtree'
|
||||
_CN.LOFTR.COARSE.ATTN_TYPE = 'B'
|
||||
_CN.LOFTR.COARSE.TOPKS = [16, 8, 8]
|
||||
|
||||
# 3. Coarse-Matching config
|
||||
_CN.LOFTR.MATCH_COARSE = CN()
|
||||
_CN.LOFTR.MATCH_COARSE.THR = 0.2
|
||||
_CN.LOFTR.MATCH_COARSE.BORDER_RM = 2
|
||||
_CN.LOFTR.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'  # options: ['dual_softmax', 'sinkhorn']
|
||||
_CN.LOFTR.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
|
||||
_CN.LOFTR.MATCH_COARSE.SKH_ITERS = 3
|
||||
_CN.LOFTR.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
||||
_CN.LOFTR.MATCH_COARSE.SKH_PREFILTER = False
|
||||
_CN.LOFTR.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.3 # training tricks: save GPU memory
|
||||
_CN.LOFTR.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
|
||||
_CN.LOFTR.MATCH_COARSE.SPARSE_SPVS = False
|
||||
|
||||
# 4. LoFTR-fine module config
|
||||
_CN.LOFTR.FINE = CN()
|
||||
_CN.LOFTR.FINE.D_MODEL = 128
|
||||
_CN.LOFTR.FINE.D_FFN = 128
|
||||
_CN.LOFTR.FINE.NHEAD = 8
|
||||
_CN.LOFTR.FINE.LAYER_NAMES = ['self', 'cross'] * 1
|
||||
_CN.LOFTR.FINE.ATTENTION = 'linear'
|
||||
_CN.LOFTR.FINE.BLOCK_TYPE = 'loftr'
|
||||
|
||||
# 5. LoFTR Losses
|
||||
# -- # coarse-level
|
||||
_CN.LOFTR.LOSS = CN()
|
||||
_CN.LOFTR.LOSS.COARSE_TYPE = 'focal' # ['focal', 'cross_entropy']
|
||||
_CN.LOFTR.LOSS.COARSE_WEIGHT = 1.0
|
||||
# -- - -- # focal loss (coarse)
|
||||
_CN.LOFTR.LOSS.FOCAL_ALPHA = 0.25
|
||||
_CN.LOFTR.LOSS.FOCAL_GAMMA = 2.0
|
||||
_CN.LOFTR.LOSS.POS_WEIGHT = 1.0
|
||||
_CN.LOFTR.LOSS.NEG_WEIGHT = 1.0
|
||||
|
||||
# -- # fine-level
|
||||
_CN.LOFTR.LOSS.FINE_TYPE = 'l2_with_std' # ['l2_with_std', 'l2']
|
||||
_CN.LOFTR.LOSS.FINE_WEIGHT = 1.0
|
||||
_CN.LOFTR.LOSS.FINE_CORRECT_THR = 1.0
|
||||
|
||||
_CN.DATASET = CN()
|
||||
# 1. data config
|
||||
# training and validating
|
||||
_CN.DATASET.TRAINVAL_DATA_SOURCE = None # options: ['ScanNet', 'MegaDepth']
|
||||
_CN.DATASET.TRAIN_DATA_ROOT = None
|
||||
_CN.DATASET.TRAIN_POSE_ROOT = None # (optional directory for poses)
|
||||
_CN.DATASET.TRAIN_NPZ_ROOT = None
|
||||
_CN.DATASET.TRAIN_LIST_PATH = None
|
||||
_CN.DATASET.TRAIN_INTRINSIC_PATH = None
|
||||
_CN.DATASET.VAL_DATA_ROOT = None
|
||||
_CN.DATASET.VAL_POSE_ROOT = None # (optional directory for poses)
|
||||
_CN.DATASET.VAL_NPZ_ROOT = None
|
||||
_CN.DATASET.VAL_LIST_PATH = None # None if val data from all scenes are bundled into a single npz file
|
||||
_CN.DATASET.VAL_INTRINSIC_PATH = None
|
||||
# testing
|
||||
_CN.DATASET.TEST_DATA_SOURCE = None
|
||||
_CN.DATASET.TEST_DATA_ROOT = None
|
||||
_CN.DATASET.TEST_POSE_ROOT = None # (optional directory for poses)
|
||||
_CN.DATASET.TEST_NPZ_ROOT = None
|
||||
_CN.DATASET.TEST_LIST_PATH = None # None if test data from all scenes are bundled into a single npz file
|
||||
_CN.DATASET.TEST_INTRINSIC_PATH = None
|
||||
|
||||
# 2. dataset config
|
||||
# general options
|
||||
_CN.DATASET.MIN_OVERLAP_SCORE_TRAIN = 0.4 # discard data with overlap_score < min_overlap_score
|
||||
_CN.DATASET.MIN_OVERLAP_SCORE_TEST = 0.0
|
||||
_CN.DATASET.AUGMENTATION_TYPE = None # options: [None, 'dark', 'mobile']
|
||||
|
||||
# MegaDepth options
|
||||
_CN.DATASET.MGDPT_IMG_RESIZE = 640 # resize the longer side, zero-pad bottom-right to square.
|
||||
_CN.DATASET.MGDPT_IMG_PAD = True # pad img to square with size = MGDPT_IMG_RESIZE
|
||||
_CN.DATASET.MGDPT_DEPTH_PAD = True # pad depthmap to square with size = 2000
|
||||
_CN.DATASET.MGDPT_DF = 8
|
||||
|
||||
_CN.TRAINER = CN()
|
||||
_CN.TRAINER.WORLD_SIZE = 1
|
||||
_CN.TRAINER.CANONICAL_BS = 64
|
||||
_CN.TRAINER.CANONICAL_LR = 8e-3
|
||||
_CN.TRAINER.SCALING = None # this will be calculated automatically
|
||||
_CN.TRAINER.FIND_LR = False # use learning rate finder from pytorch-lightning
|
||||
|
||||
# optimizer
|
||||
_CN.TRAINER.OPTIMIZER = 'adamw' # [adam, adamw]
|
||||
_CN.TRAINER.TRUE_LR = None # this will be calculated automatically at runtime
|
||||
_CN.TRAINER.ADAM_DECAY = 0. # ADAM: for adam
|
||||
_CN.TRAINER.ADAMW_DECAY = 0.1
|
||||
|
||||
# step-based warm-up
|
||||
_CN.TRAINER.WARMUP_TYPE = 'linear' # [linear, constant]
|
||||
_CN.TRAINER.WARMUP_RATIO = 0.1
|
||||
_CN.TRAINER.WARMUP_STEP = 1875
|
||||
|
||||
# learning rate scheduler
|
||||
_CN.TRAINER.SCHEDULER = 'MultiStepLR' # [MultiStepLR, CosineAnnealing, ExponentialLR]
|
||||
_CN.TRAINER.SCHEDULER_INTERVAL = 'epoch' # [epoch, step]
|
||||
_CN.TRAINER.MSLR_MILESTONES = [8, 12, 16, 20, 24] # MSLR: MultiStepLR
|
||||
_CN.TRAINER.MSLR_GAMMA = 0.5
|
||||
_CN.TRAINER.COSA_TMAX = 30 # COSA: CosineAnnealing
|
||||
_CN.TRAINER.ELR_GAMMA = 0.999992 # ELR: ExponentialLR, this value for 'step' interval
|
||||
|
||||
# plotting related
|
||||
_CN.TRAINER.ENABLE_PLOTTING = True
|
||||
_CN.TRAINER.N_VAL_PAIRS_TO_PLOT = 32  # number of val/test pairs for plotting
|
||||
_CN.TRAINER.PLOT_MODE = 'evaluation' # ['evaluation', 'confidence']
|
||||
_CN.TRAINER.PLOT_MATCHES_ALPHA = 'dynamic'
|
||||
|
||||
# geometric metrics and pose solver
|
||||
_CN.TRAINER.EPI_ERR_THR = 5e-4 # recommendation: 5e-4 for ScanNet, 1e-4 for MegaDepth (from SuperGlue)
|
||||
_CN.TRAINER.POSE_GEO_MODEL = 'E' # ['E', 'F', 'H']
|
||||
_CN.TRAINER.POSE_ESTIMATION_METHOD = 'RANSAC' # [RANSAC, DEGENSAC, MAGSAC]
|
||||
_CN.TRAINER.RANSAC_PIXEL_THR = 0.5
|
||||
_CN.TRAINER.RANSAC_CONF = 0.99999
|
||||
_CN.TRAINER.RANSAC_MAX_ITERS = 10000
|
||||
_CN.TRAINER.USE_MAGSACPP = False
|
||||
|
||||
# data sampler for train_dataloader
|
||||
_CN.TRAINER.DATA_SAMPLER = 'scene_balance' # options: ['scene_balance', 'random', 'normal']
|
||||
# 'scene_balance' config
|
||||
_CN.TRAINER.N_SAMPLES_PER_SUBSET = 200
|
||||
_CN.TRAINER.SB_SUBSET_SAMPLE_REPLACEMENT = True # whether sample each scene with replacement or not
|
||||
_CN.TRAINER.SB_SUBSET_SHUFFLE = True # after sampling from scenes, whether shuffle within the epoch or not
|
||||
_CN.TRAINER.SB_REPEAT = 1 # repeat N times for training the sampled data
|
||||
# 'random' config
|
||||
_CN.TRAINER.RDM_REPLACEMENT = True
|
||||
_CN.TRAINER.RDM_NUM_SAMPLES = None
|
||||
|
||||
# gradient clipping
|
||||
_CN.TRAINER.GRADIENT_CLIPPING = 0.5
|
||||
|
||||
# reproducibility
|
||||
# This seed affects the data sampling. With the same seed, the data sampling is guaranteed
# to be the same. When resuming training from a checkpoint, it's better to use a different
# seed; otherwise the sampled data will be exactly the same as before resuming, which would
# result in fewer unique data items being sampled during the entire training run.
# Using different seed values might affect the final training result, since not all data items
# are used during training on ScanNet. (60M image pairs are sampled during training from 230M pairs in total.)
|
||||
_CN.TRAINER.SEED = 66
|
||||
|
||||
|
||||
def get_cfg_defaults():
|
||||
"""Get a yacs CfgNode object with default values for my_project."""
|
||||
# Return a clone so that the defaults will not be altered
|
||||
# This is for the "local variable" use pattern
|
||||
return _CN.clone()
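Since the defaults above are a yacs CfgNode, downstream code can clone and tweak them without editing this file. A small sketch; the import path simply mirrors where this file lives in the commit, and the override values are illustrative only.

# Illustrative sketch: clone the defaults and override a few fields via yacs.
# The import path below mirrors this file's location in the repo.
from modelscope.models.cv.image_matching.config.default import get_cfg_defaults

cfg = get_cfg_defaults()                     # a clone, so the module-level defaults stay intact
cfg.merge_from_list([
    'LOFTR.MATCH_COARSE.THR', 0.25,          # raise the coarse confidence threshold
    'LOFTR.COARSE.TOPKS', [8, 8, 8],         # fewer quadtree candidates per level
])
cfg.freeze()                                 # guard against accidental mutation later on
print(cfg.LOFTR.MATCH_COARSE.THR)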
@@ -0,0 +1,5 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
from .loftr import LoFTR
|
||||
@@ -0,0 +1,16 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
|
||||
|
||||
|
||||
def build_backbone(config):
|
||||
if config['backbone_type'] == 'ResNetFPN':
|
||||
if config['resolution'] == (8, 2):
|
||||
return ResNetFPN_8_2(config['resnetfpn'])
|
||||
elif config['resolution'] == (16, 4):
|
||||
return ResNetFPN_16_4(config['resnetfpn'])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
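For reference, a sketch of the (lower-cased) config dict that build_backbone expects; the dimensions are the defaults from config/default.py, and the import path is assumed from this commit's file layout.

# Sketch: building the backbone from a lower-cased LoFTR config.
# Import path assumed from this commit's file layout; dims are the defaults in config/default.py.
from modelscope.models.cv.image_matching.loftr_quadtree.backbone import build_backbone

backbone_cfg = {
    'backbone_type': 'ResNetFPN',
    'resolution': (8, 2),                                             # coarse 1/8, fine 1/2
    'resnetfpn': {'initial_dim': 128, 'block_dims': [128, 196, 256]},
}
backbone = build_backbone(backbone_cfg)                               # returns a ResNetFPN_8_2 instance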
@@ -0,0 +1,223 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution without padding"""
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
bias=False)
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(
|
||||
in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super().__init__()
|
||||
self.conv1 = conv3x3(in_planes, planes, stride)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
conv1x1(in_planes, planes, stride=stride),
|
||||
nn.BatchNorm2d(planes))
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.bn1(self.conv1(y)))
|
||||
y = self.bn2(self.conv2(y))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class ResNetFPN_8_2(nn.Module):
|
||||
"""
|
||||
    ResNet+FPN, output resolutions are 1/8 and 1/2.
|
||||
Each block has 2 layers.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Config
|
||||
block = BasicBlock
|
||||
initial_dim = config['initial_dim']
|
||||
block_dims = config['block_dims']
|
||||
|
||||
# Class Variable
|
||||
self.in_planes = initial_dim
|
||||
|
||||
# Networks
|
||||
self.conv1 = nn.Conv2d(
|
||||
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(initial_dim)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
|
||||
self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
|
||||
self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
|
||||
|
||||
# 3. FPN upsample
|
||||
self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
|
||||
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
||||
self.layer2_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[2], block_dims[2]),
|
||||
nn.BatchNorm2d(block_dims[2]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[2], block_dims[1]),
|
||||
)
|
||||
self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
|
||||
self.layer1_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[1], block_dims[1]),
|
||||
nn.BatchNorm2d(block_dims[1]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[1], block_dims[0]),
|
||||
)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, dim, stride=1):
|
||||
layer1 = block(self.in_planes, dim, stride=stride)
|
||||
layer2 = block(dim, dim, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# ResNet Backbone
|
||||
x0 = self.relu(self.bn1(self.conv1(x)))
|
||||
x1 = self.layer1(x0) # 1/2
|
||||
x2 = self.layer2(x1) # 1/4
|
||||
x3 = self.layer3(x2) # 1/8
|
||||
|
||||
# FPN
|
||||
x3_out = self.layer3_outconv(x3)
|
||||
|
||||
x3_out_2x = F.interpolate(
|
||||
x3_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x2_out = self.layer2_outconv(x2)
|
||||
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
|
||||
|
||||
x2_out_2x = F.interpolate(
|
||||
x2_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x1_out = self.layer1_outconv(x1)
|
||||
x1_out = self.layer1_outconv2(x1_out + x2_out_2x)
|
||||
|
||||
return [x3_out, x1_out]
|
||||
|
||||
|
||||
class ResNetFPN_16_4(nn.Module):
|
||||
"""
|
||||
    ResNet+FPN, output resolutions are 1/16 and 1/4.
|
||||
Each block has 2 layers.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Config
|
||||
block = BasicBlock
|
||||
initial_dim = config['initial_dim']
|
||||
block_dims = config['block_dims']
|
||||
|
||||
# Class Variable
|
||||
self.in_planes = initial_dim
|
||||
|
||||
# Networks
|
||||
self.conv1 = nn.Conv2d(
|
||||
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(initial_dim)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
|
||||
self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
|
||||
self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
|
||||
self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
|
||||
|
||||
# 3. FPN upsample
|
||||
self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
|
||||
self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
|
||||
self.layer3_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[3], block_dims[3]),
|
||||
nn.BatchNorm2d(block_dims[3]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[3], block_dims[2]),
|
||||
)
|
||||
|
||||
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
||||
self.layer2_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[2], block_dims[2]),
|
||||
nn.BatchNorm2d(block_dims[2]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[2], block_dims[1]),
|
||||
)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(
|
||||
m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, dim, stride=1):
|
||||
layer1 = block(self.in_planes, dim, stride=stride)
|
||||
layer2 = block(dim, dim, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# ResNet Backbone
|
||||
x0 = self.relu(self.bn1(self.conv1(x)))
|
||||
x1 = self.layer1(x0) # 1/2
|
||||
x2 = self.layer2(x1) # 1/4
|
||||
x3 = self.layer3(x2) # 1/8
|
||||
x4 = self.layer4(x3) # 1/16
|
||||
|
||||
# FPN
|
||||
x4_out = self.layer4_outconv(x4)
|
||||
|
||||
x4_out_2x = F.interpolate(
|
||||
x4_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x3_out = self.layer3_outconv(x3)
|
||||
x3_out = self.layer3_outconv2(x3_out + x4_out_2x)
|
||||
|
||||
x3_out_2x = F.interpolate(
|
||||
x3_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x2_out = self.layer2_outconv(x2)
|
||||
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
|
||||
|
||||
return [x4_out, x2_out]
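A quick, illustrative shape check of the 1/8 and 1/2 variant defined above on a grayscale input, using the default dimensions from the config:

# Illustrative shape check of ResNetFPN_8_2 (defined above) with the default dims.
import torch

net = ResNetFPN_8_2({'initial_dim': 128, 'block_dims': [128, 196, 256]})
x = torch.randn(1, 1, 480, 640)        # [N, 1, H, W] grayscale input
coarse, fine = net(x)
print(coarse.shape)                    # torch.Size([1, 256, 60, 80])   -> 1/8 resolution
print(fine.shape)                      # torch.Size([1, 128, 240, 320]) -> 1/2 resolution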
98
modelscope/models/cv/image_matching/loftr_quadtree/loftr.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.einops import rearrange
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .loftr_module import FinePreprocess, LocalFeatureTransformer
|
||||
from .utils.coarse_matching import CoarseMatching
|
||||
from .utils.fine_matching import FineMatching
|
||||
from .utils.position_encoding import PositionEncodingSine
|
||||
|
||||
|
||||
class LoFTR(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Misc
|
||||
self.config = config
|
||||
|
||||
# Modules
|
||||
self.backbone = build_backbone(config)
|
||||
self.pos_encoding = PositionEncodingSine(
|
||||
config['coarse']['d_model'],
|
||||
temp_bug_fix=config['coarse']['temp_bug_fix'])
|
||||
self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
|
||||
self.coarse_matching = CoarseMatching(config['match_coarse'])
|
||||
self.fine_preprocess = FinePreprocess(config)
|
||||
self.loftr_fine = LocalFeatureTransformer(config['fine'])
|
||||
self.fine_matching = FineMatching()
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Update:
|
||||
data (dict): {
|
||||
'image0': (torch.Tensor): (N, 1, H, W)
|
||||
'image1': (torch.Tensor): (N, 1, H, W)
|
||||
'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
|
||||
'mask1'(optional) : (torch.Tensor): (N, H, W)
|
||||
}
|
||||
"""
|
||||
# 1. Local Feature CNN
|
||||
data.update({
|
||||
'bs': data['image0'].size(0),
|
||||
'hw0_i': data['image0'].shape[2:],
|
||||
'hw1_i': data['image1'].shape[2:]
|
||||
})
|
||||
|
||||
if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
|
||||
feats_c, feats_f = self.backbone(
|
||||
torch.cat([data['image0'], data['image1']], dim=0))
|
||||
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
|
||||
data['bs']), feats_f.split(data['bs'])
|
||||
else: # handle different input shapes
|
||||
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
|
||||
data['image0']), self.backbone(data['image1'])
|
||||
|
||||
data.update({
|
||||
'hw0_c': feat_c0.shape[2:],
|
||||
'hw1_c': feat_c1.shape[2:],
|
||||
'hw0_f': feat_f0.shape[2:],
|
||||
'hw1_f': feat_f1.shape[2:]
|
||||
})
|
||||
|
||||
# 2. coarse-level loftr module
|
||||
# add featmap with positional encoding, then flatten it to sequence [N, HW, C]
|
||||
|
||||
feat_c0 = self.pos_encoding(feat_c0)
|
||||
feat_c1 = self.pos_encoding(feat_c1)
|
||||
|
||||
mask_c0 = mask_c1 = None # mask is useful in training
|
||||
if 'mask0' in data:
|
||||
mask_c0, mask_c1 = data['mask0'].flatten(
|
||||
-2), data['mask1'].flatten(-2)
|
||||
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0,
|
||||
mask_c1)
|
||||
|
||||
# 3. match coarse-level
|
||||
self.coarse_matching(
|
||||
feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
|
||||
|
||||
# 4. fine-level refinement
|
||||
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
|
||||
feat_f0, feat_f1, feat_c0, feat_c1, data)
|
||||
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
|
||||
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
|
||||
feat_f0_unfold, feat_f1_unfold)
|
||||
|
||||
# 5. match fine-level
|
||||
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
|
||||
|
||||
def load_state_dict(self, state_dict, *args, **kwargs):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith('matcher.'):
|
||||
state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
|
||||
return super().load_state_dict(state_dict, *args, **kwargs)
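Note that forward() mutates the input dict in place instead of returning matches. A minimal sketch of driving the matcher directly; here lowered_cfg is assumed to be lower_config(get_cfg_defaults()), as done in quadtree_attention_model.py below, and the fused quadtree attention ops need a CUDA build.

# Minimal sketch: LoFTR fills `data` in place; matches are read back from the dict.
# `lowered_cfg` is assumed to be lower_config(get_cfg_defaults()), as in quadtree_attention_model.py.
import torch

matcher = LoFTR(config=lowered_cfg['loftr']).eval().cuda()   # quadtree ops are built for GPU

data = {
    'image0': torch.randn(1, 1, 480, 640).cuda(),   # grayscale, [N, 1, H, W]
    'image1': torch.randn(1, 1, 480, 640).cuda(),
}
with torch.no_grad():
    matcher(data)

# Keys populated by CoarseMatching / FineMatching:
print(data['mkpts0_f'].shape, data['mkpts1_f'].shape, data['mconf'].shape)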
@@ -0,0 +1,6 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
from .fine_preprocess import FinePreprocess
|
||||
from .transformer import LocalFeatureTransformer
|
||||
@@ -0,0 +1,77 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.einops import rearrange, repeat
|
||||
|
||||
|
||||
class FinePreprocess(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.cat_c_feat = config['fine_concat_coarse_feat']
|
||||
self.W = self.config['fine_window_size']
|
||||
|
||||
d_model_c = self.config['coarse']['d_model']
|
||||
d_model_f = self.config['fine']['d_model']
|
||||
self.d_model_f = d_model_f
|
||||
if self.cat_c_feat:
|
||||
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
|
||||
self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.kaiming_normal_(p, mode='fan_out', nonlinearity='relu')
|
||||
|
||||
def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
|
||||
W = self.W
|
||||
stride = data['hw0_f'][0] // data['hw0_c'][0]
|
||||
|
||||
data.update({'W': W})
|
||||
if data['b_ids'].shape[0] == 0:
|
||||
feat0 = torch.empty(
|
||||
0, self.W**2, self.d_model_f, device=feat_f0.device)
|
||||
feat1 = torch.empty(
|
||||
0, self.W**2, self.d_model_f, device=feat_f0.device)
|
||||
return feat0, feat1
|
||||
|
||||
# 1. unfold(crop) all local windows
|
||||
feat_f0_unfold = F.unfold(
|
||||
feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2)
|
||||
feat_f0_unfold = rearrange(
|
||||
feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
||||
feat_f1_unfold = F.unfold(
|
||||
feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2)
|
||||
feat_f1_unfold = rearrange(
|
||||
feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
||||
|
||||
# 2. select only the predicted matches
|
||||
feat_f0_unfold = feat_f0_unfold[data['b_ids'],
|
||||
data['i_ids']] # [n, ww, cf]
|
||||
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
|
||||
|
||||
# option: use coarse-level loftr feature as context: concat and linear
|
||||
if self.cat_c_feat:
|
||||
feat_c_win = self.down_proj(
|
||||
torch.cat([
|
||||
feat_c0[data['b_ids'], data['i_ids']],
|
||||
feat_c1[data['b_ids'], data['j_ids']]
|
||||
], 0)) # [2n, c]
|
||||
feat_cf_win = self.merge_feat(
|
||||
torch.cat(
|
||||
[
|
||||
torch.cat([feat_f0_unfold, feat_f1_unfold],
|
||||
0), # [2n, ww, cf]
|
||||
repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # noqa
|
||||
],
|
||||
-1)) # noqa
|
||||
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
|
||||
|
||||
return feat_f0_unfold, feat_f1_unfold
|
||||
@@ -0,0 +1,89 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
"""
|
||||
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
|
||||
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn import Dropout, Module
|
||||
|
||||
|
||||
def elu_feature_map(x):
|
||||
return torch.nn.functional.elu(x) + 1
|
||||
|
||||
|
||||
class LinearAttention(Module):
|
||||
|
||||
def __init__(self, eps=1e-6):
|
||||
super().__init__()
|
||||
self.feature_map = elu_feature_map
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
""" Multi-Head linear attention proposed in "Transformers are RNNs"
|
||||
Args:
|
||||
queries: [N, L, H, D]
|
||||
keys: [N, S, H, D]
|
||||
values: [N, S, H, D]
|
||||
q_mask: [N, L]
|
||||
kv_mask: [N, S]
|
||||
Returns:
|
||||
queried_values: (N, L, H, D)
|
||||
"""
|
||||
Q = self.feature_map(queries)
|
||||
K = self.feature_map(keys)
|
||||
|
||||
# set padded position to zero
|
||||
if q_mask is not None:
|
||||
Q = Q * q_mask[:, :, None, None]
|
||||
if kv_mask is not None:
|
||||
K = K * kv_mask[:, :, None, None]
|
||||
values = values * kv_mask[:, :, None, None]
|
||||
|
||||
v_length = values.size(1)
|
||||
values = values / v_length # prevent fp16 overflow
|
||||
KV = torch.einsum('nshd,nshv->nhdv', K, values) # (S,D)' @ S,V
|
||||
Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + self.eps)
|
||||
queried_values = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV,
|
||||
Z) * v_length
|
||||
|
||||
return queried_values.contiguous()
|
||||
|
||||
|
||||
class FullAttention(Module):
|
||||
|
||||
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
||||
super().__init__()
|
||||
self.use_dropout = use_dropout
|
||||
self.dropout = Dropout(attention_dropout)
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
||||
Args:
|
||||
queries: [N, L, H, D]
|
||||
keys: [N, S, H, D]
|
||||
values: [N, S, H, D]
|
||||
q_mask: [N, L]
|
||||
kv_mask: [N, S]
|
||||
Returns:
|
||||
queried_values: (N, L, H, D)
|
||||
"""
|
||||
|
||||
# Compute the unnormalized attention and apply the masks
|
||||
QK = torch.einsum('nlhd,nshd->nlsh', queries, keys)
|
||||
if kv_mask is not None:
|
||||
QK.masked_fill_(
|
||||
~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]),
|
||||
float('-inf'))
|
||||
|
||||
# Compute the attention and the weighted average
|
||||
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
||||
A = torch.softmax(softmax_temp * QK, dim=2)
|
||||
if self.use_dropout:
|
||||
A = self.dropout(A)
|
||||
|
||||
queried_values = torch.einsum('nlsh,nshd->nlhd', A, values)
|
||||
|
||||
return queried_values.contiguous()
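The einsums above implement the kernel trick from "Transformers are RNNs": K^T V is aggregated once, so the cost grows linearly with sequence length rather than quadratically. A small, purely illustrative shape check:

# Illustrative shape check: cost is roughly O(L * D^2) per head instead of O(L * S * D).
import torch

N, L, S, H, D = 1, 4800, 4800, 8, 32   # e.g. a 60x80 coarse grid flattened, 8 heads of dim 32
q = torch.randn(N, L, H, D)
k = torch.randn(N, S, H, D)
v = torch.randn(N, S, H, D)

attn = LinearAttention()               # defined above
out = attn(q, k, v)
print(out.shape)                       # torch.Size([1, 4800, 8, 32])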
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
from modelscope.ops.quadtree_attention import QTAttA, QTAttB
|
||||
|
||||
|
||||
class QuadtreeAttention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
topks,
|
||||
value_branch=False,
|
||||
act=nn.GELU(),
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
scale=1,
|
||||
attn_type='B',
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, f'dim {dim} should be divided by num_heads {num_heads}.'
|
||||
|
||||
self.dim = dim
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.q_proj = nn.Conv2d(
|
||||
dim, dim, kernel_size=1, stride=1, bias=qkv_bias)
|
||||
self.k_proj = nn.Conv2d(
|
||||
dim, dim, kernel_size=1, stride=1, bias=qkv_bias)
|
||||
self.v_proj = nn.Conv2d(
|
||||
dim, dim, kernel_size=1, stride=1, bias=qkv_bias)
|
||||
if attn_type == 'A':
|
||||
self.py_att = QTAttA(
|
||||
num_heads, dim // num_heads, scale=scale, topks=topks)
|
||||
else:
|
||||
self.py_att = QTAttB(
|
||||
num_heads, dim // num_heads, scale=scale, topks=topks)
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
self.scale = scale
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
m.init = True
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, target, H, W, msg=None):
|
||||
|
||||
B, N, C = x.shape
|
||||
x = x.permute(0, 2, 1).reshape(B, C, H, W)
|
||||
target = target.permute(0, 2, 1).reshape(B, C, H, W)
|
||||
keys = []
|
||||
values = []
|
||||
queries = []
|
||||
|
||||
q = self.q_proj(x)
|
||||
k = self.k_proj(target)
|
||||
v = self.v_proj(target)
|
||||
for i in range(self.scale):
|
||||
keys.append(k)
|
||||
values.append(v)
|
||||
queries.append(q)
|
||||
|
||||
if i != self.scale - 1:
|
||||
k = F.avg_pool2d(k, kernel_size=2, stride=2)
|
||||
q = F.avg_pool2d(q, kernel_size=2, stride=2)
|
||||
v = F.avg_pool2d(v, kernel_size=2, stride=2)
|
||||
|
||||
msg = self.py_att(queries, keys, values).view(B, -1, C)
|
||||
|
||||
x = self.proj(msg)
|
||||
x = self.proj_drop(x)
|
||||
|
||||
return x
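The loop above builds a scale-level pyramid of q/k/v maps before handing them to the fused QTAttA/QTAttB ops. A plain-PyTorch sketch of just that pyramid step; the quadtree top-k attention itself is not reproduced here.

# Sketch of the q/k/v pyramid built in QuadtreeAttention.forward (pyramid only;
# the coarse-to-fine top-k attention lives in the fused QTAttA/QTAttB CUDA ops).
import torch
import torch.nn.functional as F

scale = 3                                      # as used by QuadtreeBlock in transformer.py
k = torch.randn(1, 256, 60, 80)                # [B, C, H, W] coarse feature map
pyramid = []
for i in range(scale):
    pyramid.append(k)
    if i != scale - 1:
        k = F.avg_pool2d(k, kernel_size=2, stride=2)

print([tuple(t.shape[-2:]) for t in pyramid])  # [(60, 80), (30, 40), (15, 20)]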
@@ -0,0 +1,287 @@
|
||||
# Part of the implementation is borrowed and modified from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
import copy
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.einops import rearrange
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
from .linear_attention import FullAttention, LinearAttention
|
||||
from .quadtree_attention import QuadtreeAttention
|
||||
|
||||
|
||||
class DWConv(nn.Module):
|
||||
|
||||
def __init__(self, dim=768):
|
||||
super(DWConv, self).__init__()
|
||||
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
||||
|
||||
def forward(self, x, H, W):
|
||||
B, N, C = x.shape
|
||||
x = x.transpose(1, 2).view(B, C, H, W)
|
||||
x = self.dwconv(x)
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.dwconv = DWConv(hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, H, W):
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.dwconv(x, H, W)
|
||||
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LoFTREncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, d_model, nhead, attention='linear'):
|
||||
super(LoFTREncoderLayer, self).__init__()
|
||||
|
||||
self.dim = d_model // nhead
|
||||
self.nhead = nhead
|
||||
|
||||
# multi-head attention
|
||||
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.attention = LinearAttention(
|
||||
) if attention == 'linear' else FullAttention()
|
||||
self.merge = nn.Linear(d_model, d_model, bias=False)
|
||||
|
||||
# feed-forward network
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(d_model * 2, d_model * 2, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(d_model * 2, d_model, bias=False),
|
||||
)
|
||||
|
||||
# norm and dropout
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x, source, x_mask=None, source_mask=None):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): [N, L, C]
|
||||
source (torch.Tensor): [N, S, C]
|
||||
x_mask (torch.Tensor): [N, L] (optional)
|
||||
source_mask (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
bs = x.size(0)
|
||||
query, key, value = x, source, source
|
||||
|
||||
# multi-head attention
|
||||
query = self.q_proj(query).view(bs, -1, self.nhead,
|
||||
self.dim) # [N, L, (H, D)]
|
||||
key = self.k_proj(key).view(bs, -1, self.nhead,
|
||||
self.dim) # [N, S, (H, D)]
|
||||
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
||||
message = self.attention(
|
||||
query, key, value, q_mask=x_mask,
|
||||
kv_mask=source_mask) # [N, L, (H, D)]
|
||||
message = self.merge(message.view(bs, -1,
|
||||
self.nhead * self.dim)) # [N, L, C]
|
||||
message = self.norm1(message)
|
||||
|
||||
# feed-forward network
|
||||
message = self.mlp(torch.cat([x, message], dim=2))
|
||||
message = self.norm2(message)
|
||||
|
||||
return x + message
|
||||
|
||||
|
||||
class QuadtreeBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
num_heads,
|
||||
topks,
|
||||
mlp_ratio=4.,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.,
|
||||
attn_drop=0.,
|
||||
drop_path=0.,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
scale=1,
|
||||
attn_type='B'):
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = norm_layer(dim)
|
||||
|
||||
self.attn = QuadtreeAttention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
topks=topks,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
scale=scale,
|
||||
attn_type=attn_type)
|
||||
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop)
|
||||
self.drop_path = DropPath(
|
||||
drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if hasattr(m, 'init'):
|
||||
return
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
elif isinstance(m, nn.Conv2d):
|
||||
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
fan_out //= m.groups
|
||||
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
||||
if m.bias is not None:
|
||||
m.bias.data.zero_()
|
||||
|
||||
def forward(self, x, target, H, W):
|
||||
|
||||
x = x + self.drop_path(
|
||||
self.attn(self.norm1(x), self.norm1(target), H, W))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LocalFeatureTransformer(nn.Module):
|
||||
"""A Local Feature Transformer (LoFTR) module."""
|
||||
|
||||
def __init__(self, config):
|
||||
super(LocalFeatureTransformer, self).__init__()
|
||||
self.block_type = config['block_type']
|
||||
self.config = config
|
||||
self.d_model = config['d_model']
|
||||
self.nhead = config['nhead']
|
||||
self.layer_names = config['layer_names']
|
||||
|
||||
if config['block_type'] == 'loftr':
|
||||
encoder_layer = LoFTREncoderLayer(config['d_model'],
|
||||
config['nhead'],
|
||||
config['attention'])
|
||||
self.layers = nn.ModuleList([
|
||||
copy.deepcopy(encoder_layer)
|
||||
for _ in range(len(self.layer_names))
|
||||
])
|
||||
self._reset_parameters()
|
||||
elif config['block_type'] == 'quadtree':
|
||||
encoder_layer = QuadtreeBlock(
|
||||
config['d_model'],
|
||||
config['nhead'],
|
||||
attn_type=config['attn_type'],
|
||||
topks=config['topks'],
|
||||
scale=3)
|
||||
self.layers = nn.ModuleList([
|
||||
copy.deepcopy(encoder_layer)
|
||||
for _ in range(len(self.layer_names))
|
||||
])
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [N, L, C]
|
||||
feat1 (torch.Tensor): [N, S, C]
|
||||
mask0 (torch.Tensor): [N, L] (optional)
|
||||
mask1 (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
|
||||
if len(feat0.shape) == 4:
|
||||
B, C, H, W = feat0.shape
|
||||
feat0 = rearrange(feat0, 'b c h w -> b (h w) c')
|
||||
feat1 = rearrange(feat1, 'b c h w -> b (h w) c')
|
||||
|
||||
if self.block_type == 'loftr':
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
elif name == 'cross':
|
||||
feat0 = layer(feat0, feat1, mask0, mask1)
|
||||
feat1 = layer(feat1, feat0, mask1, mask0)
|
||||
else:
|
||||
raise KeyError
|
||||
else:
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, H, W)
|
||||
feat1 = layer(feat1, feat1, H, W)
|
||||
elif name == 'cross':
|
||||
if self.config['block_type'] == 'quadtree':
|
||||
feat0, feat1 = layer(feat0, feat1, H,
|
||||
W), layer(feat1, feat0, H, W)
|
||||
else:
|
||||
feat0 = layer(feat0, feat1, H, W)
|
||||
feat1 = layer(feat1, feat0, H, W)
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
return feat0, feat1
|
||||
@@ -0,0 +1,268 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.einops import rearrange
|
||||
|
||||
INF = 1e9
|
||||
|
||||
|
||||
def mask_border(m, b: int, v):
|
||||
""" Mask borders with value
|
||||
Args:
|
||||
m (torch.Tensor): [N, H0, W0, H1, W1]
|
||||
b (int)
|
||||
v (m.dtype)
|
||||
"""
|
||||
if b <= 0:
|
||||
return
|
||||
|
||||
m[:, :b] = v
|
||||
m[:, :, :b] = v
|
||||
m[:, :, :, :b] = v
|
||||
m[:, :, :, :, :b] = v
|
||||
m[:, -b:] = v
|
||||
m[:, :, -b:] = v
|
||||
m[:, :, :, -b:] = v
|
||||
m[:, :, :, :, -b:] = v
|
||||
|
||||
|
||||
def mask_border_with_padding(m, bd, v, p_m0, p_m1):
|
||||
if bd <= 0:
|
||||
return
|
||||
|
||||
m[:, :bd] = v
|
||||
m[:, :, :bd] = v
|
||||
m[:, :, :, :bd] = v
|
||||
m[:, :, :, :, :bd] = v
|
||||
|
||||
h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
|
||||
h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
|
||||
for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
|
||||
m[b_idx, h0 - bd:] = v
|
||||
m[b_idx, :, w0 - bd:] = v
|
||||
m[b_idx, :, :, h1 - bd:] = v
|
||||
m[b_idx, :, :, :, w1 - bd:] = v
|
||||
|
||||
|
||||
def compute_max_candidates(p_m0, p_m1):
|
||||
"""Compute the max candidates of all pairs within a batch
|
||||
|
||||
Args:
|
||||
p_m0, p_m1 (torch.Tensor): padded masks
|
||||
"""
|
||||
h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
|
||||
h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
|
||||
max_cand = torch.sum(
|
||||
torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
|
||||
return max_cand
|
||||
|
||||
|
||||
class CoarseMatching(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# general config
|
||||
self.thr = config['thr']
|
||||
self.border_rm = config['border_rm']
|
||||
        # -- # for training fine-level LoFTR
|
||||
self.train_coarse_percent = config['train_coarse_percent']
|
||||
self.train_pad_num_gt_min = config['train_pad_num_gt_min']
|
||||
|
||||
# we provide 2 options for differentiable matching
|
||||
self.match_type = config['match_type']
|
||||
if self.match_type == 'dual_softmax':
|
||||
self.temperature = config['dsmax_temperature']
|
||||
elif self.match_type == 'sinkhorn':
|
||||
try:
|
||||
from .superglue import log_optimal_transport
|
||||
except ImportError:
|
||||
raise ImportError('download superglue.py first!')
|
||||
self.log_optimal_transport = log_optimal_transport
|
||||
self.bin_score = nn.Parameter(
|
||||
torch.tensor(config['skh_init_bin_score'], requires_grad=True))
|
||||
self.skh_iters = config['skh_iters']
|
||||
self.skh_prefilter = config['skh_prefilter']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [N, L, C]
|
||||
feat1 (torch.Tensor): [N, S, C]
|
||||
data (dict)
|
||||
mask_c0 (torch.Tensor): [N, L] (optional)
|
||||
mask_c1 (torch.Tensor): [N, S] (optional)
|
||||
Update:
|
||||
data (dict): {
|
||||
'b_ids' (torch.Tensor): [M'],
|
||||
'i_ids' (torch.Tensor): [M'],
|
||||
'j_ids' (torch.Tensor): [M'],
|
||||
'gt_mask' (torch.Tensor): [M'],
|
||||
'mkpts0_c' (torch.Tensor): [M, 2],
|
||||
'mkpts1_c' (torch.Tensor): [M, 2],
|
||||
'mconf' (torch.Tensor): [M]}
|
||||
NOTE: M' != M during training.
|
||||
"""
|
||||
_, L, S, _ = feat_c0.size(0), feat_c0.size(1), feat_c1.size(
|
||||
1), feat_c0.size(2)
|
||||
|
||||
# normalize
|
||||
feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
|
||||
[feat_c0, feat_c1])
|
||||
|
||||
if self.match_type == 'dual_softmax':
|
||||
sim_matrix = torch.einsum('nlc,nsc->nls', feat_c0,
|
||||
feat_c1) / self.temperature
|
||||
if mask_c0 is not None:
|
||||
sim_matrix.masked_fill_(
|
||||
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)
|
||||
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
|
||||
|
||||
elif self.match_type == 'sinkhorn':
|
||||
# sinkhorn, dustbin included
|
||||
sim_matrix = torch.einsum('nlc,nsc->nls', feat_c0, feat_c1)
|
||||
if mask_c0 is not None:
|
||||
sim_matrix[:, :L, :S].masked_fill_(
|
||||
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)
|
||||
|
||||
# build uniform prior & use sinkhorn
|
||||
log_assign_matrix = self.log_optimal_transport(
|
||||
sim_matrix, self.bin_score, self.skh_iters)
|
||||
assign_matrix = log_assign_matrix.exp()
|
||||
conf_matrix = assign_matrix[:, :-1, :-1]
|
||||
|
||||
# filter prediction with dustbin score (only in evaluation mode)
|
||||
if not self.training and self.skh_prefilter:
|
||||
filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
|
||||
filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
|
||||
conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
|
||||
conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
|
||||
|
||||
if self.config['sparse_spvs']:
|
||||
data.update({'conf_matrix_with_bin': assign_matrix.clone()})
|
||||
|
||||
data.update({'conf_matrix': conf_matrix})
|
||||
|
||||
# predict coarse matches from conf_matrix
|
||||
data.update(**self.get_coarse_match(conf_matrix, data))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_coarse_match(self, conf_matrix, data):
|
||||
"""
|
||||
Args:
|
||||
conf_matrix (torch.Tensor): [N, L, S]
|
||||
data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
|
||||
Returns:
|
||||
coarse_matches (dict): {
|
||||
'b_ids' (torch.Tensor): [M'],
|
||||
'i_ids' (torch.Tensor): [M'],
|
||||
'j_ids' (torch.Tensor): [M'],
|
||||
'gt_mask' (torch.Tensor): [M'],
|
||||
'm_bids' (torch.Tensor): [M],
|
||||
'mkpts0_c' (torch.Tensor): [M, 2],
|
||||
'mkpts1_c' (torch.Tensor): [M, 2],
|
||||
'mconf' (torch.Tensor): [M]}
|
||||
"""
|
||||
axes_lengths = {
|
||||
'h0c': data['hw0_c'][0],
|
||||
'w0c': data['hw0_c'][1],
|
||||
'h1c': data['hw1_c'][0],
|
||||
'w1c': data['hw1_c'][1]
|
||||
}
|
||||
_device = conf_matrix.device
|
||||
# 1. confidence thresholding
|
||||
mask = conf_matrix > self.thr
|
||||
mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
|
||||
**axes_lengths)
|
||||
if 'mask0' not in data:
|
||||
mask_border(mask, self.border_rm, False)
|
||||
else:
|
||||
mask_border_with_padding(mask, self.border_rm, False,
|
||||
data['mask0'], data['mask1'])
|
||||
mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
|
||||
**axes_lengths)
|
||||
|
||||
# 2. mutual nearest
|
||||
mask = mask \
|
||||
* (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
|
||||
* (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
|
||||
|
||||
# 3. find all valid coarse matches
|
||||
# this only works when at most one `True` in each row
|
||||
mask_v, all_j_ids = mask.max(dim=2)
|
||||
b_ids, i_ids = torch.where(mask_v)
|
||||
j_ids = all_j_ids[b_ids, i_ids]
|
||||
mconf = conf_matrix[b_ids, i_ids, j_ids]
|
||||
|
||||
# 4. Random sampling of training samples for fine-level LoFTR
|
||||
# (optional) pad samples with gt coarse-level matches
|
||||
if self.training:
|
||||
# NOTE:
|
||||
# The sampling is performed across all pairs in a batch without manually balancing
|
||||
# #samples for fine-level increases w.r.t. batch_size
|
||||
if 'mask0' not in data:
|
||||
num_candidates_max = mask.size(0) * max(
|
||||
mask.size(1), mask.size(2))
|
||||
else:
|
||||
num_candidates_max = compute_max_candidates(
|
||||
data['mask0'], data['mask1'])
|
||||
num_matches_train = int(num_candidates_max
|
||||
* self.train_coarse_percent)
|
||||
num_matches_pred = len(b_ids)
|
||||
assert self.train_pad_num_gt_min < num_matches_train, 'min-num-gt-pad should be less than num-train-matches'
|
||||
|
||||
# pred_indices is to select from prediction
|
||||
if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
|
||||
pred_indices = torch.arange(num_matches_pred, device=_device)
|
||||
else:
|
||||
pred_indices = torch.randint(
|
||||
num_matches_pred,
|
||||
(num_matches_train - self.train_pad_num_gt_min, ),
|
||||
device=_device)
|
||||
|
||||
# gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
|
||||
gt_pad_indices = torch.randint(
|
||||
len(data['spv_b_ids']),
|
||||
(max(num_matches_train - num_matches_pred,
|
||||
self.train_pad_num_gt_min), ),
|
||||
device=_device)
|
||||
mconf_gt = torch.zeros(
|
||||
len(data['spv_b_ids']),
|
||||
device=_device) # set conf of gt paddings to all zero
|
||||
|
||||
b_ids, i_ids, j_ids, mconf = map(
|
||||
lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
|
||||
dim=0),
|
||||
*zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
|
||||
[j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
|
||||
|
||||
# These matches select patches that feed into fine-level network
|
||||
coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
|
||||
|
||||
# 4. Update with matches in original image resolution
|
||||
scale = data['hw0_i'][0] / data['hw0_c'][0]
|
||||
scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
|
||||
scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
|
||||
mkpts0_c = torch.stack(
|
||||
[i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
|
||||
dim=1) * scale0
|
||||
mkpts1_c = torch.stack(
|
||||
[j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
|
||||
dim=1) * scale1
|
||||
|
||||
# These matches is the current prediction (for visualization)
|
||||
coarse_matches.update({
|
||||
'gt_mask': mconf == 0,
|
||||
'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
|
||||
'mkpts0_c': mkpts0_c[mconf != 0],
|
||||
'mkpts1_c': mkpts1_c[mconf != 0],
|
||||
'mconf': mconf[mconf != 0]
|
||||
})
|
||||
|
||||
return coarse_matches
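Steps 1 to 3 above boil down to confidence thresholding plus a mutual-nearest-neighbour check; a toy illustration on a 1x3x3 confidence matrix (border masking and the training-time padding are omitted):

# Toy illustration of the threshold + mutual-nearest-neighbour selection in get_coarse_match.
import torch

conf = torch.tensor([[[0.70, 0.10, 0.05],
                      [0.20, 0.60, 0.10],
                      [0.05, 0.55, 0.50]]])   # [N, L, S]
thr = 0.2

mask = conf > thr
mask = mask \
    * (conf == conf.max(dim=2, keepdim=True)[0]) \
    * (conf == conf.max(dim=1, keepdim=True)[0])

mask_v, all_j_ids = mask.max(dim=2)
b_ids, i_ids = torch.where(mask_v)
j_ids = all_j_ids[b_ids, i_ids]
print(i_ids.tolist(), j_ids.tolist())         # [0, 1] [0, 1]; row 2 is dropped by the mutual check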
@@ -0,0 +1,86 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from kornia.geometry.subpix import dsnt
|
||||
from kornia.utils.grid import create_meshgrid
|
||||
|
||||
|
||||
class FineMatching(nn.Module):
|
||||
"""FineMatching with s2d paradigm"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, feat_f0, feat_f1, data):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [M, WW, C]
|
||||
feat1 (torch.Tensor): [M, WW, C]
|
||||
data (dict)
|
||||
Update:
|
||||
data (dict):{
|
||||
'expec_f' (torch.Tensor): [M, 3],
|
||||
'mkpts0_f' (torch.Tensor): [M, 2],
|
||||
'mkpts1_f' (torch.Tensor): [M, 2]}
|
||||
"""
|
||||
M, WW, C = feat_f0.shape
|
||||
W = int(math.sqrt(WW))
|
||||
scale = data['hw0_i'][0] / data['hw0_f'][0]
|
||||
self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
|
||||
|
||||
# corner case: if no coarse matches found
|
||||
if M == 0:
|
||||
assert self.training is False, 'M is always >0, when training, see coarse_matching.py'
|
||||
# logger.warning('No matches found in coarse-level.')
|
||||
data.update({
|
||||
'expec_f': torch.empty(0, 3, device=feat_f0.device),
|
||||
'mkpts0_f': data['mkpts0_c'],
|
||||
'mkpts1_f': data['mkpts1_c'],
|
||||
})
|
||||
return
|
||||
|
||||
        feat_f0_picked = feat_f0[:, WW // 2, :]
|
||||
sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
|
||||
softmax_temp = 1. / C**.5
|
||||
heatmap = torch.softmax(
|
||||
softmax_temp * sim_matrix, dim=1).view(-1, W, W)
|
||||
|
||||
# compute coordinates from heatmap
|
||||
coords_normalized = dsnt.spatial_expectation2d(heatmap[None],
|
||||
True)[0] # [M, 2]
|
||||
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
|
||||
1, -1, 2) # [1, WW, 2]
|
||||
|
||||
# compute std over <x, y>
|
||||
var = torch.sum(
|
||||
grid_normalized**2 * heatmap.view(-1, WW, 1),
|
||||
dim=1) - coords_normalized**2 # [M, 2]
|
||||
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)),
|
||||
-1) # [M] clamp needed for numerical stability
|
||||
|
||||
# for fine-level supervision
|
||||
data.update(
|
||||
{'expec_f':
|
||||
torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
|
||||
|
||||
# compute absolute kpt coords
|
||||
self.get_fine_match(coords_normalized, data)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_fine_match(self, coords_normed, data):
|
||||
W, scale = self.W, self.scale
|
||||
|
||||
# mkpts0_f and mkpts1_f
|
||||
mkpts0_f = data['mkpts0_c']
|
||||
scale1 = scale * data['scale1'][
|
||||
data['b_ids']] if 'scale0' in data else scale
|
||||
mkpts1_f = data['mkpts1_c'] + (
|
||||
coords_normed * # noqa
|
||||
(W // 2) * scale1)[:len(data['mconf'])] # noqa
|
||||
|
||||
data.update({'mkpts0_f': mkpts0_f, 'mkpts1_f': mkpts1_f})
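The refinement above is a soft-argmax: the softmax heatmap over the fine window is reduced to a 2D expectation in normalized coordinates (plus a standard deviation for supervision). A small, illustrative sketch with the same kornia helper:

# Illustrative: sub-pixel refinement as a 2D expectation over a softmax heatmap,
# mirroring the dsnt usage in FineMatching above.
import torch
from kornia.geometry.subpix import dsnt

W = 5
logits = torch.zeros(1, W, W)
logits[0, 1, 3] = 8.0                          # a sharp peak at row 1, col 3 of the window
heatmap = torch.softmax(logits.view(1, -1), dim=1).view(1, W, W)

coords = dsnt.spatial_expectation2d(heatmap[None], True)[0]   # [1, 2], (x, y) in [-1, 1]
print(coords)                                  # approximately [[0.5, -0.5]]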
@@ -0,0 +1,52 @@
|
||||
# This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PositionEncodingSine(nn.Module):
|
||||
"""
|
||||
This is a sinusoidal position encoding that generalized to 2-dimensional images
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
|
||||
"""
|
||||
Args:
|
||||
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
|
||||
temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
|
||||
the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
|
||||
                on the final performance. For now, we keep both impls for backward compatibility.
|
||||
We will remove the buggy impl after re-training all variants of our released models.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
pe = torch.zeros((d_model, *max_shape))
|
||||
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
|
||||
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
|
||||
if temp_bug_fix:
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, d_model // 2, 2).float() * # noqa
|
||||
(-math.log(10000.0) / (d_model // 2)))
|
||||
        else:  # a buggy implementation (for backward compatibility only)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, d_model // 2, 2).float() * # noqa
|
||||
(-math.log(10000.0) / d_model // 2))
|
||||
div_term = div_term[:, None, None] # [C//4, 1, 1]
|
||||
pe[0::4, :, :] = torch.sin(x_position * div_term)
|
||||
pe[1::4, :, :] = torch.cos(x_position * div_term)
|
||||
pe[2::4, :, :] = torch.sin(y_position * div_term)
|
||||
pe[3::4, :, :] = torch.cos(y_position * div_term)
|
||||
|
||||
self.register_buffer(
|
||||
'pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: [N, C, H, W]
|
||||
"""
|
||||
return x + self.pe[:, :, :x.size(2), :x.size(3)]
|
||||
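For reference, a minimal usage sketch of the encoding above: the precomputed buffer is cropped to the input's spatial size and added to the feature map. The feature dimensions here are illustrative only (any `d_model` divisible by 4 and H, W within `max_shape` work):

```python
import torch

# Illustrative only: d_model=256 and a 60x80 1/8-resolution feature map are assumptions.
pos_enc = PositionEncodingSine(d_model=256, max_shape=(256, 256))
feat = torch.randn(1, 256, 60, 80)   # [N, C, H, W]
feat = pos_enc(feat)                 # same shape, now position-aware
print(feat.shape)                    # torch.Size([1, 256, 60, 80])
```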
@@ -0,0 +1,78 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os.path as osp
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .config.default import get_cfg_defaults
|
||||
from .loftr_quadtree.loftr import LoFTR
|
||||
from .utils.misc import lower_config
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_matching, module_name=Models.quadtree_attention_image_matching)
|
||||
class QuadTreeAttentionForImageMatching(TorchModel):
|
||||
'''
|
||||
Image matching with quadtree attention. This model is trained on outdoor images.
|
||||
For more details, please refer to https://arxiv.org/abs/2201.02767
|
||||
'''
|
||||
|
||||
def __init__(self, model_dir: str, model_type='outdoor', **kwargs):
|
||||
'''
|
||||
Args:
|
||||
model_dir: model directory
|
||||
model_type: model type, 'outdoor' or 'indoor'. Only the outdoor model is currently supported in ModelScope.
|
||||
'''
|
||||
assert model_type == 'outdoor', 'Only the outdoor model is supported in ModelScope'
|
||||
# Note: for the indoor model, max_image_size should be 640 because the ScanNet training images are 640,
# and that model is currently overfitted to ScanNet. For the outdoor model, a larger image size works better.
|
||||
|
||||
super().__init__(model_dir, **kwargs)
|
||||
config = get_cfg_defaults()
|
||||
_config = lower_config(config)
|
||||
|
||||
matcher = LoFTR(config=_config['loftr'])
|
||||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
|
||||
state_dict = torch.load(
|
||||
str(model_path), map_location='cpu')['state_dict']
|
||||
|
||||
matcher.load_state_dict(state_dict, strict=True)
|
||||
self.matcher = matcher
|
||||
|
||||
self.matcher.eval()
|
||||
self.matcher.to('cuda')
|
||||
|
||||
def forward(self, Inputs):
|
||||
'''
|
||||
Args:
|
||||
Inputs: a dict with keys 'image0', 'image1' and 'preprocess_info'.
|
||||
'image0' and 'image1' are torch tensor with shape [1, 1, H1, W1]
|
||||
and [1, 1, H2, W2]. 'preprocess_info' contains the information of
|
||||
resizing, which will be used for postprocessing.
|
||||
'''
|
||||
self.matcher(Inputs)
|
||||
return {
|
||||
'kpts0': Inputs['mkpts0_f'],
|
||||
'kpts1': Inputs['mkpts1_f'],
|
||||
'conf': Inputs['mconf'],
|
||||
'preprocess_info': Inputs['preprocess_info']
|
||||
}
|
||||
|
||||
def postprocess(self, Inputs):
|
||||
matching_result = Inputs
|
||||
|
||||
results = {OutputKeys.MATCHES: matching_result}
|
||||
return results
|
||||
|
||||
def inference(self, data):
|
||||
results = self.forward(data)
|
||||
|
||||
return results
|
||||
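A minimal sketch of driving the model directly, outside the pipeline. The model directory path is a placeholder for a downloaded checkpoint directory, the random grayscale tensors only stand in for preprocessed images (normalized to [0, 1], padded to a common size divisible by 32), and a CUDA device is required:

```python
import torch

from modelscope.models.cv.image_matching import QuadTreeAttentionForImageMatching

# 'path/to/model_dir' is a placeholder; it must contain the torch checkpoint file.
model = QuadTreeAttentionForImageMatching('path/to/model_dir')
inputs = {
    'image0': torch.rand(1, 1, 480, 640).cuda(),  # grayscale, [1, 1, H, W], in [0, 1]
    'image1': torch.rand(1, 1, 480, 640).cuda(),
    'preprocess_info': [1.0, 1.0, 480, 640, 480, 640],  # scales and resized sizes
}
with torch.no_grad():
    out = model.inference(inputs)
print(out['kpts0'].shape, out['kpts1'].shape, out['conf'].shape)
```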
11
modelscope/models/cv/image_matching/utils/misc.py
Normal file
@@ -0,0 +1,11 @@
|
||||
# This implementation is adopted from LoFTR,
|
||||
# made public available under the Apache License, Version 2.0,
|
||||
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
|
||||
def lower_config(yacs_cfg):
|
||||
if not isinstance(yacs_cfg, CN):
|
||||
return yacs_cfg
|
||||
return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
|
||||
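As a quick illustration, `lower_config` recursively lower-cases the keys of a yacs tree and turns CfgNodes into plain dicts, which is how the LoFTR config is consumed by the model above; the key names in this sketch are made up:

```python
from yacs.config import CfgNode as CN

from modelscope.models.cv.image_matching.utils.misc import lower_config

# Hypothetical config tree; the keys are illustrative, not taken from the model config.
cfg = CN()
cfg.LOFTR = CN()
cfg.LOFTR.FINE_WINDOW_SIZE = 5

print(lower_config(cfg))
# {'loftr': {'fine_window_size': 5}}
```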
0
modelscope/ops/__init__.py
Normal file
3
modelscope/ops/quadtree_attention/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .modules.quadtree_attention import QTAttA, QTAttB
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from einops.einops import rearrange
|
||||
from torch.autograd import Function
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
cur_dir = Path(__file__).parent.resolve()
|
||||
score_computation_cuda = \
|
||||
load(name='score_computation_cuda', # noqa
|
||||
sources=[str(cur_dir / '../src/score_computation.cpp'), # noqa
|
||||
str(cur_dir / '../src/score_computation_kernal.cu')], # noqa
|
||||
extra_cflags=['-g'], extra_cuda_cflags=['-O2']) # noqa
|
||||
|
||||
value_aggregation_cuda = \
|
||||
load(name='value_aggregation_cuda', # noqa
|
||||
sources=[str(cur_dir / '../src/value_aggregation.cpp'), # noqa
|
||||
str(cur_dir / '../src/value_aggregation_kernel.cu')], # noqa
|
||||
extra_cflags=['-g'], extra_cuda_cflags=['-O2']) # noqa
|
||||
|
||||
|
||||
class ScoreComputation(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, query, key, index):
|
||||
x = score_computation_cuda.score_forward(query, key, index)
|
||||
ctx.save_for_backward(query, key, index)
|
||||
return x[0]
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input1, input2, index = ctx.saved_tensors
|
||||
grad_output = grad_output.contiguous()
|
||||
x = score_computation_cuda.score_backward(grad_output, input1, input2,
|
||||
index)
|
||||
return x[0], x[1], None
|
||||
|
||||
|
||||
score_computation_op = ScoreComputation.apply
|
||||
|
||||
|
||||
class value_aggregation(Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, score, value, index):
|
||||
ctx.save_for_backward(score, value, index)
|
||||
f = score.shape[2]
|
||||
score = rearrange(
|
||||
score,
|
||||
'b n f K h -> b (n f) K h') # [b, N, 4, 4K, H] -> [b, 4N, 4K, H]
|
||||
index = rearrange(
|
||||
index,
|
||||
'b n f K h -> b (n f) K h') # [b, N, 4, 4K, H] -> [b, 4N, 4K, H]
|
||||
b, N, _, H = score.shape
|
||||
D = value.shape[-1]
|
||||
# value [b, M, H, D]
|
||||
output = score.new_zeros([b, N, H, D]).contiguous() # b, 4N, H, D
|
||||
value_aggregation_cuda.value_aggregation_forward(
|
||||
score, value, index, output)
|
||||
output = rearrange(output, 'b (n f) h d -> b n f h d', f=f)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
score, value, index = ctx.saved_tensors
|
||||
f = score.shape[2]
|
||||
score = rearrange(score, 'b n f K h -> b (n f) K h')
|
||||
index = rearrange(index, 'b n f K h -> b (n f) K h')
|
||||
|
||||
grad_output = grad_output.contiguous()
|
||||
|
||||
grad_score = score.new_zeros(score.shape).contiguous()
|
||||
grad_value = value.new_zeros(value.shape).contiguous()
|
||||
|
||||
value_aggregation_cuda.value_aggregation_backward(
|
||||
grad_output, score, value, index, grad_score, grad_value)
|
||||
grad_score = rearrange(grad_score, 'b (n f) K h -> b n f K h', f=f)
|
||||
return grad_score, grad_value, None
|
||||
|
||||
|
||||
value_aggregation_op = value_aggregation.apply
|
||||
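A small smoke-test sketch for the JIT-built op above, using the shape convention documented in the kernels (query [B, N, 4, H, D], key [B, N2, H, D], index [B, N, K, H] with values in [0, N2)). The sizes are arbitrary, and a CUDA device is required since the extension is CUDA-only; first use also triggers the JIT build:

```python
import torch

from modelscope.ops.quadtree_attention.functions.quadtree_attention import \
    score_computation_op

# Arbitrary sizes for a quick forward check of the score computation op.
B, N, N2, K, H, D = 1, 16, 64, 8, 4, 32
query = torch.randn(B, N, 4, H, D, device='cuda')
key = torch.randn(B, N2, H, D, device='cuda')
index = torch.randint(0, N2, (B, N, K, H), device='cuda')  # int64 indices into N2

score = score_computation_op(query, key, index)
print(score.shape)  # expected: torch.Size([1, 16, 4, 8, 4]) -> [B, N, 4, K, H]
```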
370
modelscope/ops/quadtree_attention/modules/quadtree_attention.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.einops import rearrange
|
||||
|
||||
from modelscope.ops.quadtree_attention.functions.quadtree_attention import (
|
||||
score_computation_op, value_aggregation_op)
|
||||
|
||||
|
||||
class QTAttA(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nhead,
|
||||
dim,
|
||||
topks=[32, 32, 32, 32],
|
||||
scale=None,
|
||||
use_dropout=False,
|
||||
attention_dropout=0.1,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_dropout = use_dropout
|
||||
self.topks = topks
|
||||
self.nhead = nhead
|
||||
self.dim = dim
|
||||
|
||||
def process_coarse_level(self, query, key, value, topk):
|
||||
bs, c, h, w = key.shape
|
||||
cur_dim = key.shape[1] // self.nhead
|
||||
|
||||
key = rearrange(key,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
value = rearrange(value,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
query = rearrange(query,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim)
|
||||
|
||||
QK = torch.einsum('nlhd,nshd->nlsh', query, key)
|
||||
softmax_temp = 1.0 / cur_dim**0.5 # sqrt(D)
|
||||
A = torch.softmax(softmax_temp * QK, dim=-2)
|
||||
|
||||
# mask out top K tokens
|
||||
topk_score, topk_idx = torch.topk(A, dim=-2, k=topk, largest=True)
|
||||
mask = torch.ones_like(A)
|
||||
mask = mask.scatter(
|
||||
dim=-2, index=topk_idx, src=torch.zeros_like(topk_idx).float())
|
||||
|
||||
# message is only computed within the unmasked
|
||||
message = torch.einsum(
|
||||
'nlsh,nshd->nlhd', A * mask,
|
||||
value) # .reshape(bs, h, w, self.nhead, cur_dim)
|
||||
|
||||
return A, message, topk_score, topk_idx
|
||||
|
||||
def process_fine_level(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
topk_score,
|
||||
topk_pos,
|
||||
topk_prev,
|
||||
topk,
|
||||
final=False):
|
||||
bs, c, h, w = key.shape
|
||||
|
||||
cur_dim = key.shape[1] // self.nhead
|
||||
key = rearrange(key,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
value = rearrange(value,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
|
||||
query = query.view(bs, c, h // 2, 2, w // 2, 2)
|
||||
query = rearrange(query, 'b c h t1 w t2-> b (h w) (t1 t2) c ').view(
|
||||
bs, -1, 4, self.nhead, cur_dim)
|
||||
|
||||
# convert 2d coordinates to 1d index
|
||||
idx_gather = []
|
||||
topk_pos = topk_pos * 2
|
||||
for x in [0, 1]:
|
||||
for y in [0, 1]:
|
||||
idx = (topk_pos[0]
|
||||
+ x) * w + topk_pos[1] + y # convert to index
|
||||
idx_gather.append(idx)
|
||||
|
||||
idx = torch.stack(idx_gather, dim=3) # [N, L, K, 4, H, D]
|
||||
|
||||
# Compute score
|
||||
# query: [b, N, 4, H, D]
|
||||
# key: [b, 4N, H, D]
|
||||
# idx: [b, N, K, 4, H]
|
||||
# QK: [b, N, 4, 4K, H]
|
||||
QK = score_computation_op(query, key.contiguous(),
|
||||
idx.view(bs, -1, topk_prev * 4, self.nhead))
|
||||
QK = rearrange(QK, 'n l w (k f) h -> n l w k f h', k=topk_prev, f=4)
|
||||
softmax_temp = 1.0 / cur_dim**0.5 # sqrt(D)
|
||||
A = torch.softmax(
|
||||
softmax_temp * QK, dim=-2) # [N, L//scale**i, K, 4, H]
|
||||
# Score redistribution
|
||||
topk_score = topk_score.unsqueeze(-2).unsqueeze(2)
|
||||
A = (A * topk_score).reshape(bs, -1, 4, topk_prev * 4, self.nhead)
|
||||
idx = idx.view(bs, -1, 1, topk_prev * 4,
|
||||
self.nhead).repeat(1, 1, 4, 1, 1) # [N, L,4, K*4, H]
|
||||
topk_score, topk_idx = torch.topk(A, dim=-2, k=topk, largest=True)
|
||||
|
||||
if not final:
|
||||
mask = torch.ones_like(A)
|
||||
mask = mask.scatter(
|
||||
dim=-2, index=topk_idx, src=torch.zeros_like(topk_idx).float())
|
||||
message = value_aggregation_op(A * mask, value.contiguous(), idx)
|
||||
else:
|
||||
message = value_aggregation_op(A, value.contiguous(), idx)
|
||||
|
||||
if not final:
|
||||
topk_idx = torch.gather(idx, index=topk_idx, dim=-2)
|
||||
topk_idx = rearrange(
|
||||
topk_idx,
|
||||
'b (h w) (t1 t2) k nh -> b (h t1 w t2) k nh',
|
||||
h=h // 2,
|
||||
t1=2) # reshape back
|
||||
topk_score = rearrange(
|
||||
topk_score,
|
||||
'b (h w) (t1 t2) k nh -> b (h t1 w t2) k nh',
|
||||
h=h // 2,
|
||||
t1=2) # reshape back
|
||||
|
||||
return A, message, topk_score, topk_idx
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
"""Multi-head quadtree attention
|
||||
Args:
|
||||
queries: Query pyramid [N, C, H, W]
|
||||
keys: Key pyramid [N, C, H, W]
|
||||
values: Value pyramid [N, C, H, W]
|
||||
Returns:
|
||||
message: (N, C, H, W)
|
||||
"""
|
||||
|
||||
bs = queries[0].shape[0]
|
||||
messages = []
|
||||
topk = self.topks[0]
|
||||
|
||||
for i, (query, key, value) in enumerate(
|
||||
zip(reversed(queries), reversed(keys), reversed(values))):
|
||||
bs, c, h, w = key.shape
|
||||
if i == 0:
|
||||
A, message, topk_score, topk_idx = self.process_coarse_level(
|
||||
query, key, value,
|
||||
topk) # Full attention for the coarsest level
|
||||
else:
|
||||
topk_prev = topk
|
||||
topk = self.topks[i]
|
||||
final = True if i == len(queries) - 1 else False
|
||||
A, message, topk_score, topk_idx = self.process_fine_level(
|
||||
query, key, value, topk_score, topk_pos, topk_prev, topk,
|
||||
final) # Quadtree attention
|
||||
|
||||
messages.append(message)
|
||||
if topk_idx is not None:
|
||||
topk_pos = torch.stack([ # noqa
|
||||
topk_idx // w, topk_idx % w
|
||||
]) # convert to coordinate
|
||||
|
||||
final_message = 0
|
||||
for i, m in enumerate(messages):
|
||||
if i == 0:
|
||||
final_message = m
|
||||
else:
|
||||
final_message = final_message.unsqueeze(2) + m
|
||||
final_message = rearrange(
|
||||
final_message,
|
||||
'b (H W) (t1 t2) h d -> b (H t1 W t2) h d',
|
||||
t1=2,
|
||||
t2=2,
|
||||
H=queries[-i].shape[2])
|
||||
|
||||
return final_message
|
||||
|
||||
|
||||
class QTAttB(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
nhead,
|
||||
dim,
|
||||
scale,
|
||||
topks=[32, 32, 32, 32],
|
||||
use_dropout=False,
|
||||
attention_dropout=0.1,
|
||||
lepe=False):
|
||||
super().__init__()
|
||||
self.use_dropout = use_dropout
|
||||
self.topks = topks
|
||||
self.nhead = nhead
|
||||
self.dim = dim
|
||||
self.lepe = lepe
|
||||
if lepe: # locally enhanced position encoding
|
||||
self.get_vs = nn.ModuleList([
|
||||
nn.Conv2d(
|
||||
dim * nhead,
|
||||
dim * nhead,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=dim * nhead) for _ in range(scale)
|
||||
])
|
||||
self.register_parameter('weight', nn.Parameter(torch.randn(scale)))
|
||||
|
||||
def process_coarse_level(self, query, key, value, topk):
|
||||
bs, c, h, w = key.shape
|
||||
|
||||
cur_dim = key.shape[1] // self.nhead
|
||||
key = rearrange(key,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
value = rearrange(value,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
query = rearrange(query,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim)
|
||||
QK = torch.einsum('nlhd,nshd->nlsh', query, key)
|
||||
softmax_temp = 1.0 / cur_dim**0.5 # sqrt(D)
|
||||
|
||||
A = torch.softmax(softmax_temp * QK, dim=-2)
|
||||
topk_score, topk_idx = torch.topk(A, dim=-2, k=topk, largest=True)
|
||||
|
||||
message = torch.einsum(
|
||||
'nlsh,nshd->nlhd', A,
|
||||
value) # .reshape(bs, h, w, self.nhead, cur_dim)
|
||||
|
||||
return A, message, topk_score, topk_idx
|
||||
|
||||
def process_fine_level(self,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
topk_score,
|
||||
topk_pos,
|
||||
topk_prev,
|
||||
topk,
|
||||
final=False):
|
||||
bs, c, h, w = key.shape
|
||||
|
||||
cur_dim = key.shape[1] // self.nhead
|
||||
key = rearrange(key,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
value = rearrange(value,
|
||||
'b c h w -> b (h w) c').view(bs, -1, self.nhead,
|
||||
cur_dim) # [N, S, H, D]
|
||||
|
||||
query = query.view(bs, c, h // 2, 2, w // 2, 2)
|
||||
query = rearrange(query, 'b c h t1 w t2-> b (h w) (t1 t2) c ').view(
|
||||
bs, -1, 4, self.nhead, cur_dim)
|
||||
|
||||
# convert 2D coordinates to 1D index
|
||||
topk_pos = topk_pos * 2
|
||||
idx_gather = []
|
||||
for x in [0, 1]:
|
||||
for y in [0, 1]:
|
||||
idx = (topk_pos[0]
|
||||
+ x) * w + topk_pos[1] + y # convert to index
|
||||
idx_gather.append(idx)
|
||||
idx = torch.stack(idx_gather, dim=3) # [N, L, K, 4, H, D]
|
||||
|
||||
# score computation
|
||||
# query: [b, N, 4, H, D]
|
||||
# key: [b, 4N, H, D]
|
||||
# idx: [b, N, K, 4, H]
|
||||
# QK: [b, N, 4, 4K, H]
|
||||
QK = score_computation_op(query, key.contiguous(),
|
||||
idx.view(bs, -1, topk_prev * 4, self.nhead))
|
||||
softmax_temp = 1.0 / cur_dim**0.5 # sqrt(D)
|
||||
A = torch.softmax(
|
||||
softmax_temp * QK, dim=-2) # [N, L//scale**i, K, 4, H]
|
||||
A = A.reshape(bs, -1, 4, topk_prev * 4, self.nhead)
|
||||
idx = idx.view(bs, -1, 1, topk_prev * 4,
|
||||
self.nhead).repeat(1, 1, 4, 1, 1) # [N, L,4, K*4, H]
|
||||
|
||||
topk_score, topk_idx = torch.topk(A, dim=-2, k=topk, largest=True)
|
||||
message = value_aggregation_op(A, value.contiguous(), idx)
|
||||
topk_idx = torch.gather(idx, index=topk_idx, dim=-2)
|
||||
topk_idx = rearrange(
|
||||
topk_idx,
|
||||
'b (h w) (t1 t2) k nh -> b (h t1 w t2) k nh',
|
||||
h=h // 2,
|
||||
t1=2) # reshape back
|
||||
topk_score = rearrange(
|
||||
topk_score,
|
||||
'b (h w) (t1 t2) k nh -> b (h t1 w t2) k nh',
|
||||
h=h // 2,
|
||||
t1=2) # reshape back
|
||||
|
||||
return A, message, topk_score, topk_idx
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
"""Multi-head quadtree attention
|
||||
Args:
|
||||
queries: Query pyramid [N, C, H, W]
|
||||
keys: Key pyramid [N, C, H, W]
|
||||
values: Value pyramid [N, C, H, W]
|
||||
Returns:
|
||||
message: (N, C, H, W)
|
||||
"""
|
||||
|
||||
bs = queries[0].shape[0]
|
||||
|
||||
messages = []
|
||||
topk = self.topks[0]
|
||||
for i, (query, key, value) in enumerate(
|
||||
zip(reversed(queries), reversed(keys), reversed(values))):
|
||||
bs, c, h, w = key.shape
|
||||
if i == 0: # Full attention for the coarsest level
|
||||
A, message, topk_score, topk_idx = self.process_coarse_level(
|
||||
query, key, value, topk)
|
||||
else:
|
||||
topk_prev = topk
|
||||
topk = self.topks[i]
|
||||
final = True if i == len(queries) - 1 else False
|
||||
A, message, topk_score, topk_idx = self.process_fine_level(
|
||||
query, key, value, topk_score, topk_pos, topk_prev, topk,
|
||||
final)
|
||||
|
||||
messages.append(message)
|
||||
topk_pos = torch.stack([ # noqa
|
||||
topk_idx // w, topk_idx % w
|
||||
]) # convert to coordinate
|
||||
|
||||
# Merge messages of different layers
|
||||
final_message = 0
|
||||
|
||||
weight = torch.softmax(self.weight, dim=0)
|
||||
for i, m in enumerate(messages):
|
||||
if self.lepe:
|
||||
H, W = values[-(i + 1)].shape[-2:]
|
||||
lepe = self.get_vs[i](values[-(i + 1)])
|
||||
|
||||
if i == 0:
|
||||
if self.lepe:
|
||||
lepe = rearrange(
|
||||
lepe, 'b (hd d) H W -> b (H W) hd d', hd=self.nhead)
|
||||
final_message = (m + lepe) * weight[i]
|
||||
else:
|
||||
final_message = m * weight[i]
|
||||
else:
|
||||
if self.lepe:
|
||||
lepe = rearrange(
|
||||
lepe,
|
||||
'b (hd d) (H t1) (W t2) -> b (H W) (t1 t2) hd d',
|
||||
hd=self.nhead,
|
||||
t1=2,
|
||||
t2=2)
|
||||
final_message = final_message.unsqueeze(
|
||||
2) + (m + lepe) * weight[i]
|
||||
else:
|
||||
final_message = final_message.unsqueeze(2) + m * weight[i]
|
||||
|
||||
final_message = rearrange(
|
||||
final_message,
|
||||
'b (H W) (t1 t2) h d -> b (H t1 W t2) h d',
|
||||
t1=2,
|
||||
t2=2,
|
||||
H=queries[-i].shape[2])
|
||||
return final_message
|
||||
0
modelscope/ops/quadtree_attention/src/__init__.py
Normal file
40
modelscope/ops/quadtree_attention/src/score_computation.cpp
Normal file
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#include "score_computation.h"
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
#include<iostream>
|
||||
#include<stdio.h>
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
// == Forward
|
||||
std::vector<torch::Tensor> score_cuda_forward(torch::Tensor input1, //query: B, N1, 4, H, D
                                              torch::Tensor input2, //key: B, N2, H, D
                                              torch::Tensor index) //index: B, N1, K*4, H
|
||||
{
|
||||
CHECK_INPUT(input1);
|
||||
CHECK_INPUT(input2);
|
||||
CHECK_INPUT(index);
|
||||
return ScoreData_ongpu(input1, input2, index);
|
||||
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> score_cuda_backward(torch::Tensor grad_output1, //grad of scores: B, N1, 4, K*4, H
                                               torch::Tensor input1, //query: B, N1, 4, H, D
                                               torch::Tensor input2, //key: B, N2, H, D
                                               torch::Tensor index) //index: B, N1, K*4, H
|
||||
{
|
||||
CHECK_INPUT(grad_output1);
|
||||
CHECK_INPUT(input1);
|
||||
CHECK_INPUT(input2);
|
||||
CHECK_INPUT(index);
|
||||
return ScoreData_backward_ongpu(grad_output1, input1, input2, index);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("score_forward", &score_cuda_forward, "score forward (CUDA)");
|
||||
m.def("score_backward", &score_cuda_backward, "score forward (CUDA)");
|
||||
}
|
||||
24
modelscope/ops/quadtree_attention/src/score_computation.h
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#ifndef _Score_CUDA
|
||||
#define _Score_CUDA
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
std::vector<torch::Tensor> score_cuda_forward(torch::Tensor input1, //query: B, N1, 4, H, D
                                              torch::Tensor input2, //key: B, N2, H, D
                                              torch::Tensor index); //index: B, N1, K*4, H


std::vector<at::Tensor> ScoreData_ongpu(at::Tensor input1, //query: B, N1, 4, H, D
                                        at::Tensor input2, //key: B, N2, H, D
                                        at::Tensor index); //index: B, N1, K*4, H


std::vector<torch::Tensor> ScoreData_backward_ongpu(torch::Tensor grad_output1, //grad of scores: B, N1, 4, K*4, H
                                                    torch::Tensor input1, //query: B, N1, 4, H, D
                                                    torch::Tensor input2, //key: B, N2, H, D
                                                    torch::Tensor index); //index: B, N1, K*4, H
|
||||
|
||||
#endif
|
||||
@@ -0,0 +1,186 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#include <vector>
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
#include <vector>
|
||||
#include "score_computation.h"
|
||||
#include <stdio.h>
|
||||
|
||||
#define ROUND_OFF 50000
|
||||
|
||||
#define CUDA_NUM_THREADS 1024
|
||||
#define WARPS_PER_BLOCK 1
|
||||
#define THREADS_PER_WARP 32
|
||||
#define MAX_H 8
|
||||
|
||||
#define CUDA_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
|
||||
|
||||
#define GET_BLOCKS(n, t) (n+t-1) / t
|
||||
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ScoreData(
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> query, // B, N1, 4, H, dim
|
||||
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> key, //B, N2, H, dim
|
||||
torch::PackedTensorAccessor32<long,4,torch::RestrictPtrTraits> index, //B, N1, K*4, H
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> output //B, N1, 4, K*4, H
|
||||
){
|
||||
extern __shared__ char patch_data_char[];
|
||||
|
||||
scalar_t *feat1_data = (scalar_t *)patch_data_char;
|
||||
|
||||
|
||||
int b = blockIdx.x;
|
||||
int n1 = blockIdx.y;
|
||||
int f = blockIdx.z;
|
||||
|
||||
int ch_off = threadIdx.x;
|
||||
|
||||
int D=query.size(4);
|
||||
int HD=query.size(3)*D;
|
||||
int K=index.size(2);
|
||||
for(int ch = ch_off; ch < HD; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
|
||||
feat1_data[ch] = query[b][n1][f][ch/D][ch%D];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
__shared__ scalar_t score[THREADS_PER_WARP*MAX_H];
|
||||
for(int k = ch_off; k < K; k += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
|
||||
|
||||
for(int h=0;h<query.size(3);h++){
|
||||
int score_idx=ch_off*query.size(3)+h;
|
||||
score[score_idx]=0;
|
||||
int idx=index[b][n1][k][h];
|
||||
for(int d=0;d<query.size(4);d++){
|
||||
score[score_idx]+=feat1_data[h*D+d]*key[b][idx][h][d];
|
||||
}
|
||||
output[b][n1][f][k][h]=score[score_idx];
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
std::vector<torch::Tensor> ScoreData_ongpu(torch::Tensor query, // B, N1, 4, H, dim
|
||||
torch::Tensor key, // B, N2, H, dim
|
||||
torch::Tensor index) // B, N1, K, 4, H
|
||||
{
|
||||
|
||||
const auto B = query.size(0);
|
||||
const auto N1 = query.size(1);
|
||||
const auto H = query.size(3);
|
||||
const auto D = query.size(4);
|
||||
const auto K = index.size(-2);
|
||||
|
||||
|
||||
auto output = torch::zeros({B, N1, 4, K, H},torch::device(torch::kCUDA));
|
||||
|
||||
int shared_memory_per_block = H*D;
|
||||
|
||||
dim3 totalBlocks(B, N1, 4);
|
||||
dim3 threadsPerBlock(THREADS_PER_WARP);
|
||||
AT_DISPATCH_FLOATING_TYPES(query.type(), "ScoreData_ongpu", ([&] {
|
||||
ScoreData<scalar_t><<<totalBlocks, threadsPerBlock, shared_memory_per_block * sizeof(scalar_t)>>>(
|
||||
query.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
|
||||
key.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
|
||||
index.packed_accessor32<long,4,torch::RestrictPtrTraits>(),
|
||||
output.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
|
||||
}));
|
||||
return {output};
|
||||
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void ScoreDataBackward(
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> grad, //B, N1, 4, K*4, H
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> query, //B, N1, 4, H, dim
|
||||
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> key, // B, N2, H, dim
|
||||
torch::PackedTensorAccessor32<long,4,torch::RestrictPtrTraits> index,// B, N1, K*4, H
|
||||
torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> query_grad, //B, N1, 4, H, D
|
||||
torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> key_grad //B, N2, H, D
|
||||
){
|
||||
int b = blockIdx.x;
|
||||
int n1 = blockIdx.y;
|
||||
int f = blockIdx.z;
|
||||
|
||||
extern __shared__ char patch_data_char[];
|
||||
|
||||
|
||||
int ch_off = threadIdx.x;
|
||||
|
||||
int D=query.size(4);
|
||||
int H=query.size(3);
|
||||
int HD=H*D;
|
||||
int K=index.size(2);
|
||||
|
||||
scalar_t *query_data = (scalar_t *)patch_data_char;
|
||||
|
||||
scalar_t *grad_data = (scalar_t *) (HD*sizeof(scalar_t)+patch_data_char);
|
||||
|
||||
|
||||
for(int ch = ch_off; ch <HD; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
|
||||
query_data[ch] = query[b][n1][f][ch/D][ch%D];
|
||||
}
|
||||
for(int ch = ch_off; ch <K*H; ch += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
|
||||
grad_data[ch] = grad[b][n1][f][ch/H][ch%H];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for(int k = ch_off; k < K; k += (WARPS_PER_BLOCK*THREADS_PER_WARP)) { // CHANNELS
|
||||
|
||||
for(int h=0;h<H;h++){
|
||||
int idx=index[b][n1][k][h];
|
||||
for(int d=0;d<D;d++){
|
||||
|
||||
atomicAdd(&query_grad[b][n1][f][h][d], grad_data[k*H+h]*key[b][idx][h][d]);
|
||||
atomicAdd(&key_grad[b][idx][h][d],grad_data[k*H+h]*query_data[h*D+d]);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> ScoreData_backward_ongpu(torch::Tensor grad_output1, //B, N1, 4, K*4, H
|
||||
torch::Tensor query, //B, N1, 4, H, dim
|
||||
torch::Tensor key, //B, N2, H, dim
|
||||
torch::Tensor index) //B, N1, K*4, H
|
||||
|
||||
{
|
||||
|
||||
const auto B = grad_output1.size(0);
|
||||
const auto N1 = grad_output1.size(1);
|
||||
const auto N2 = key.size(1);
|
||||
const auto K = grad_output1.size(3);
|
||||
const auto H = key.size(2);
|
||||
const auto D = key.size(3);
|
||||
|
||||
|
||||
auto query_grad = torch::zeros({B, N1, 4, H, D},torch::device(torch::kCUDA));
|
||||
|
||||
auto key_grad = torch::zeros({B, N2, H, D},torch::device(torch::kCUDA));
|
||||
|
||||
|
||||
int shared_memory_per_block = H*D+K*H;
|
||||
|
||||
dim3 totalBlocks(B, N1, 4);
|
||||
dim3 threadsPerBlock(THREADS_PER_WARP);
|
||||
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(key.type(), "ScoreDatabackward_ongpu", ([&] {
|
||||
ScoreDataBackward<scalar_t><<<totalBlocks, threadsPerBlock, shared_memory_per_block * sizeof(scalar_t)>>>(
|
||||
grad_output1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
|
||||
query.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
|
||||
key.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
|
||||
index.packed_accessor32<long,4,torch::RestrictPtrTraits>(),
|
||||
query_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
|
||||
key_grad.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>()
|
||||
);
|
||||
}));
|
||||
|
||||
return {query_grad, key_grad};
|
||||
|
||||
}
|
||||
28
modelscope/ops/quadtree_attention/src/utils.h
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
|
||||
class Formatter {
|
||||
public:
|
||||
Formatter() {}
|
||||
~Formatter() {}
|
||||
|
||||
template <typename Type> Formatter &operator<<(const Type &value) {
|
||||
stream_ << value;
|
||||
return *this;
|
||||
}
|
||||
|
||||
std::string str() const { return stream_.str(); }
|
||||
operator std::string() const { return stream_.str(); }
|
||||
|
||||
enum ConvertToString { to_str };
|
||||
|
||||
std::string operator>>(ConvertToString) { return stream_.str(); }
|
||||
|
||||
private:
|
||||
std::stringstream stream_;
|
||||
Formatter(const Formatter &);
|
||||
Formatter &operator=(Formatter &);
|
||||
};
|
||||
67
modelscope/ops/quadtree_attention/src/value_aggregation.cpp
Normal file
@@ -0,0 +1,67 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include "value_aggregation.h"
|
||||
//extern THCState *state;
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
void value_aggregation_cuda_forward(
|
||||
at::Tensor score, // B, N, K, H
|
||||
at::Tensor value, // B, M, H, D
|
||||
at::Tensor index, // B, N, K, H
|
||||
at::Tensor output)// B, N, H, D
|
||||
{
|
||||
CHECK_INPUT(score);
|
||||
CHECK_INPUT(value);
|
||||
CHECK_INPUT(index);
|
||||
auto score_size = score.sizes();
|
||||
auto value_size = value.sizes();
|
||||
int B = score_size[0];
|
||||
int N = score_size[1];
|
||||
int K = score_size[2];
|
||||
int H = score_size[3];
|
||||
int M = value_size[1];
|
||||
int D = value_size[3];
|
||||
|
||||
|
||||
value_aggregation_forward_kernel(score.data<float>(), value.data<float>(),
|
||||
index.data<long>(), output.data<float>(), B, N, K, H, M, D,
|
||||
at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
|
||||
void value_aggregation_cuda_backward(
|
||||
at::Tensor grad_output, // B, N, H, D
|
||||
at::Tensor score, // B, N, K, H
|
||||
at::Tensor value, // B, M, H, D
|
||||
at::Tensor index, // B, N, K, H
|
||||
at::Tensor grad_score, // B, N, K, H
|
||||
at::Tensor grad_value // B, M, H, D
|
||||
)
|
||||
{
|
||||
CHECK_INPUT(score);
|
||||
CHECK_INPUT(value);
|
||||
CHECK_INPUT(index);
|
||||
CHECK_INPUT(grad_output);
|
||||
|
||||
auto score_size = score.sizes();
|
||||
auto value_size = value.sizes();
|
||||
int B = score_size[0];
|
||||
int N = score_size[1];
|
||||
int K = score_size[2];
|
||||
int H = score_size[3];
|
||||
int M = value_size[1];
|
||||
int D = value_size[3];
|
||||
|
||||
|
||||
value_aggregation_backward_kernel(grad_output.data<float>(), score.data<float>(),
|
||||
value.data<float>(), index.data<long>(), grad_score.data<float>(), grad_value.data<float>(),
|
||||
B, N, K, H, M, D, at::cuda::getCurrentCUDAStream());
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("value_aggregation_forward", &value_aggregation_cuda_forward, "value forward (CUDA)");
|
||||
m.def("value_aggregation_backward", &value_aggregation_cuda_backward, "value backward (CUDA)");
|
||||
}
|
||||
19
modelscope/ops/quadtree_attention/src/value_aggregation.h
Normal file
@@ -0,0 +1,19 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#ifndef _VALUE_AGGREGATION_
|
||||
#define _VALUE_AGGREGATION_
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
void value_aggregation_forward_kernel(float* score, // B, N, K, H
|
||||
float* value, // B, M, H, D
|
||||
long* index, // B, N, K, H
|
||||
float* output, // B, N, H, D
|
||||
int B, int N, int K, int H, int M, int D, cudaStream_t stream
|
||||
);
|
||||
|
||||
void value_aggregation_cuda_forward(at::Tensor score, at::Tensor value, at::Tensor index, at::Tensor output);
|
||||
|
||||
void value_aggregation_backward_kernel(float* grad_output, float* score, float* value,long* index, float* grad_score, float* grad_value, int B, int N, int K, int H, int M, int D, cudaStream_t stream);
|
||||
|
||||
#endif // _VALUE_AGGREGATION_
|
||||
@@ -0,0 +1,88 @@
|
||||
// Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
#include <vector>
|
||||
#include <stdio.h>
|
||||
#include <math.h>
|
||||
#include <float.h>
|
||||
#include <vector>
|
||||
#include "value_aggregation.h"
|
||||
#include "THC/THCAtomics.cuh"
|
||||
#include <stdio.h>
|
||||
#include "utils.h"
|
||||
|
||||
#define ROUND_OFF 50000
|
||||
|
||||
#define CUDA_NUM_THREADS 1024
|
||||
#define WARPS_PER_BLOCK 1
|
||||
#define THREADS_PER_WARP 32
|
||||
|
||||
#define CUDA_KERNEL_LOOP(i, n) for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
|
||||
|
||||
#define GET_BLOCKS(n, t) (n+t-1) / t
|
||||
|
||||
__global__ void ValueAggregationForwardFunc(float* score, float* value, long* index, float* output, int B, int N, int K, int H, int M, int D) {
|
||||
///*
|
||||
long LENGTH = B*N*H*D;
|
||||
CUDA_KERNEL_LOOP(cur_idx, LENGTH){
|
||||
long d_idx = cur_idx % D;
|
||||
long h_idx = (cur_idx - d_idx) / D % H;
|
||||
long n_idx = (cur_idx - d_idx - h_idx * D) / D / H % N;
|
||||
long b_idx = (cur_idx - d_idx - h_idx * D - n_idx * H * D) / D / H / N;
|
||||
if (cur_idx < LENGTH) {
|
||||
long score_start_idx = b_idx * N * K * H + n_idx * K * H + h_idx;
|
||||
long value_start_idx = b_idx * M * H * D + h_idx * D + d_idx;
|
||||
|
||||
float out_val = 0;
|
||||
for(int k_idx = 0; k_idx < K; k_idx++){
|
||||
int score_idx = score_start_idx + k_idx * H;
|
||||
int value_idx = value_start_idx + index[score_idx] * H * D;
|
||||
out_val += score[score_idx] * value[value_idx];
|
||||
}
|
||||
output[cur_idx] = out_val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void value_aggregation_forward_kernel(float* score, float* value, long* index, float* output, int B, int N, int K, int H, int M, int D, cudaStream_t stream){
|
||||
ValueAggregationForwardFunc
|
||||
<<<GET_BLOCKS(B*N*H*D, CUDA_NUM_THREADS), CUDA_NUM_THREADS, 0, stream>>>(score, value, index, output, B, N, K, H, M, D);
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (cudaSuccess != err)
|
||||
throw std::runtime_error(Formatter()
|
||||
<< "CUDA kernel failed : " << std::to_string(err));
|
||||
}
|
||||
|
||||
__global__ void ValueAggregationBackwardFunc(float* grad_output, float* score, float* value, long* index, float* grad_score,
|
||||
float* grad_value, int B, int N, int K, int H, int M, int D) {
|
||||
long LENGTH = B*N*K*H;
|
||||
CUDA_KERNEL_LOOP(cur_idx, LENGTH){
|
||||
long h_idx = cur_idx % H;
|
||||
long k_idx = (cur_idx - h_idx) / H % K;
|
||||
long n_idx = (cur_idx - h_idx - k_idx * H) / H / K % N;
|
||||
long b_idx = (cur_idx - h_idx - k_idx * H - n_idx * H * K) / H / K / N;
|
||||
|
||||
if (cur_idx < LENGTH) {
|
||||
long output_start_idx = b_idx * N * H * D + n_idx * H * D + h_idx * D;
|
||||
long value_start_idx = b_idx * M * H * D + h_idx * D;
|
||||
for (int d_idx = 0; d_idx < D; d_idx ++){
|
||||
long output_idx = output_start_idx + d_idx;
|
||||
long value_idx = value_start_idx + index[cur_idx] * H * D + d_idx;
|
||||
auto grad_output_val = grad_output[output_idx];
|
||||
grad_score[cur_idx] += grad_output_val * value[value_idx];
|
||||
gpuAtomicAdd(&grad_value[value_idx], grad_output_val * score[cur_idx]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void value_aggregation_backward_kernel(float* grad_output, float* score, float* value, long* index, float* grad_score, float* grad_value, int B, int N, int K, int H, int M, int D, cudaStream_t stream){
|
||||
ValueAggregationBackwardFunc
|
||||
<<<GET_BLOCKS(B*N*K*H, CUDA_NUM_THREADS), CUDA_NUM_THREADS, 0, stream>>>(grad_output, score, value, index, grad_score, grad_value, B, N, K, H, M, D);
|
||||
|
||||
cudaError_t err = cudaGetLastError();
|
||||
if (cudaSuccess != err)
|
||||
throw std::runtime_error(Formatter()
|
||||
<< "CUDA kernel failed : " << std::to_string(err));
|
||||
}
|
||||
@@ -52,6 +52,7 @@ class OutputKeys(object):
|
||||
SCENE_NUM = 'scene_num'
|
||||
SCENE_META_LIST = 'scene_meta_list'
|
||||
SHOT_META_LIST = 'shot_meta_list'
|
||||
MATCHES = 'matches'
|
||||
PCD12 = 'pcd12'
|
||||
PCD12_ALIGN = 'pcd12_align'
|
||||
|
||||
|
||||
@@ -271,6 +271,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_flow-based-body-reshaping_damo'),
|
||||
Tasks.image_face_fusion: (Pipelines.image_face_fusion,
|
||||
'damo/cv_unet-image-face-fusion_damo'),
|
||||
Tasks.image_matching: (
|
||||
Pipelines.image_matching,
|
||||
'damo/cv_quadtree_attention_image-matching_outdoor'),
|
||||
}
|
||||
|
||||
|
||||
|
||||
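With this default registered, the task name alone is enough to build the pipeline; passing the model id explicitly is equivalent. A short sketch:

```python
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to 'damo/cv_quadtree_attention_image-matching_outdoor' via the default above.
matcher = pipeline(Tasks.image_matching)
result = matcher([['data/test/images/image_matching1.jpg',
                   'data/test/images/image_matching2.jpg']])
kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
```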
@@ -74,6 +74,7 @@ if TYPE_CHECKING:
|
||||
from .image_skychange_pipeline import ImageSkychangePipeline
|
||||
from .vop_retrieval_pipeline import VopRetrievalPipeline
|
||||
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
|
||||
from .image_matching_pipeline import ImageMatchingPipeline
|
||||
from .video_stabilization_pipeline import VideoStabilizationPipeline
|
||||
from .video_super_resolution_pipeline import VideoSuperResolutionPipeline
|
||||
from .pointcloud_sceneflow_estimation_pipeline import PointCloudSceneFlowEstimationPipeline
|
||||
@@ -180,6 +181,7 @@ else:
|
||||
'video_object_segmentation_pipeline': [
|
||||
'VideoObjectSegmentationPipeline'
|
||||
],
|
||||
'image_matching_pipeline': ['ImageMatchingPipeline'],
|
||||
'video_stabilization_pipeline': ['VideoStabilizationPipeline'],
|
||||
'video_super_resolution_pipeline': ['VideoSuperResolutionPipeline'],
|
||||
'pointcloud_sceneflow_estimation_pipeline': [
|
||||
|
||||
175
modelscope/pipelines/cv/image_matching_pipeline.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_matching, module_name=Pipelines.image_matching)
|
||||
class ImageMatchingPipeline(Pipeline):
|
||||
""" Image Matching Pipeline.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
task = 'image-matching'
|
||||
model_id = 'damo/cv_quadtree_attention_image-matching_outdoor'
|
||||
|
||||
input_location = [
|
||||
['data/test/images/image_matching1.jpg',
|
||||
'data/test/images/image_matching2.jpg']
|
||||
]
|
||||
estimator = pipeline(Tasks.image_matching, model=model_id)
|
||||
result = estimator(input_location)
|
||||
kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
|
||||
print(f'Found {len(kpts0)} matches')
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create an image matching pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
|
||||
# check if cuda is available
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError(
|
||||
'Cuda is not available. Image matching model only supports cuda.'
|
||||
)
|
||||
|
||||
logger.info('image matching model, pipeline init')
|
||||
|
||||
def resize_image(self, img, max_image_size):
|
||||
h, w = img.shape[:2]
|
||||
scale = 1
|
||||
if max(h, w) > max_image_size:
|
||||
scale = max_image_size / max(h, w)
|
||||
new_w, new_h = int(w * scale), int(h * scale)
|
||||
img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
|
||||
return img, scale
|
||||
|
||||
    def compute_padded_size(self, size, div):
        return int(np.ceil(size / div) * div)

    def pad_image(self, img, h=None, w=None, div=32):
        cur_h, cur_w = img.shape[:2]
        if h is None and w is None:
            h, w = cur_h, cur_w
        h_pad = self.compute_padded_size(h, div)
        w_pad = self.compute_padded_size(w, div)
|
||||
img = cv2.copyMakeBorder(
|
||||
img,
|
||||
0,
|
||||
h_pad - cur_h,
|
||||
0,
|
||||
w_pad - cur_w,
|
||||
cv2.BORDER_CONSTANT,
|
||||
value=0)
|
||||
return img
|
||||
|
||||
def load_image(self, img_name):
|
||||
img = LoadImage.convert_to_ndarray(img_name).astype(np.float32)
|
||||
img = img / 255.
|
||||
# convert rgb to gray
|
||||
if len(img.shape) == 3:
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
||||
return img
|
||||
|
||||
def preprocess(self, input: Input, max_image_size=1024):
|
||||
assert len(input) == 2, 'input should be a list of two images'
|
||||
|
||||
img1 = self.load_image(input[0])
|
||||
img1, scale1 = self.resize_image(img1, max_image_size)
|
||||
scaled_h1, scaled_w1 = img1.shape[:2]
|
||||
|
||||
img2 = self.load_image(input[1])
|
||||
img2, scale2 = self.resize_image(img2, max_image_size)
|
||||
scaled_h2, scaled_w2 = img2.shape[:2]
|
||||
|
||||
h_max, w_max = max(scaled_h1, scaled_h2), max(scaled_w1, scaled_w2)
|
||||
img1 = self.pad_image(img1, h_max, w_max)
|
||||
img2 = self.pad_image(img2, h_max, w_max)
|
||||
|
||||
img1 = torch.from_numpy(img1)[None][None].cuda().float()
|
||||
img2 = torch.from_numpy(img2)[None][None].cuda().float()
|
||||
return {
|
||||
'image0':
|
||||
img1,
|
||||
'image1':
|
||||
img2,
|
||||
'preprocess_info':
|
||||
[scale1, scale2, scaled_h1, scaled_w1, scaled_h2, scaled_w2]
|
||||
}
|
||||
|
||||
def postprocess_match(self, kpt1, kpt2, conf, scale1, scale2, scaled_h1,
|
||||
scaled_w1, scaled_h2, scaled_w2):
|
||||
# filter out points outside the image
|
||||
valid_match = (kpt1[:, 0] < scaled_w1) & (kpt1[:, 1] < scaled_h1) & (
|
||||
kpt2[:, 0] < scaled_w2) & (
|
||||
kpt2[:, 1] < scaled_h2)
|
||||
kpt1, kpt2 = kpt1[valid_match], kpt2[valid_match]
|
||||
kpt1 = kpt1 / scale1
|
||||
kpt2 = kpt2 / scale2
|
||||
conf = conf[valid_match]
|
||||
|
||||
return kpt1, kpt2, conf
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
results = self.model.inference(input)
|
||||
return results
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
results = self.model.postprocess(inputs)
|
||||
matches = results[OutputKeys.MATCHES]
|
||||
|
||||
kpts0 = matches['kpts0'].cpu().numpy()
|
||||
kpts1 = matches['kpts1'].cpu().numpy()
|
||||
conf = matches['conf'].cpu().numpy()
|
||||
preprocess_info = [v.cpu().numpy() for v in inputs['preprocess_info']]
|
||||
kpts0, kpts1, conf = self.postprocess_match(kpts0, kpts1, conf,
|
||||
*preprocess_info)
|
||||
|
||||
outputs = {
|
||||
OutputKeys.MATCHES: [kpts0, kpts1, conf],
|
||||
}
|
||||
|
||||
return outputs
|
||||
|
||||
def __call__(self, input, **kwargs):
|
||||
"""
|
||||
Match two images and return the matched keypoints and confidence.
|
||||
|
||||
Args:
|
||||
input (`List[List[str]]`): A list of two image paths.
|
||||
|
||||
Return:
|
||||
A list of result.
|
||||
The list contain the following values:
|
||||
|
||||
- kpts0 -- Matched keypoints in the first image
|
||||
- kpts1 -- Matched keypoints in the second image
|
||||
- conf -- Confidence of the match
|
||||
"""
|
||||
return super().__call__(input, **kwargs)
|
||||
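The preprocessing in this pipeline resizes both images so the longer side is at most 1024, then pads each side up to the next multiple of 32 so the pair shares one tensor shape; the keypoints are later divided by the resize scale to map them back to the original resolution. A worked example of the padding arithmetic (the 750x1000 input size is arbitrary):

```python
import numpy as np


def padded(size, div=32):
    # same rounding as the pipeline's pad-to-multiple-of-div helper
    return int(np.ceil(size / div) * div)


print(padded(750), padded(1000))  # 768 1024
```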
@@ -60,6 +60,7 @@ class CVTasks(object):
|
||||
face_human_hand_detection = 'face-human-hand-detection'
|
||||
face_emotion = 'face-emotion'
|
||||
product_segmentation = 'product-segmentation'
|
||||
image_matching = 'image-matching'
|
||||
|
||||
crowd_counting = 'crowd-counting'
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import matplotlib
|
||||
import matplotlib.cm as cm
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
@@ -499,3 +501,99 @@ def masks_visualization(masks, palette):
|
||||
img_E.putpalette(palette)
|
||||
vis_masks.append(img_E)
|
||||
return vis_masks
|
||||
|
||||
|
||||
# This implementation is adopted from LoFTR,
|
||||
# made public available under the Apache License, Version 2.0,
|
||||
# at https://github.com/zju3dv/LoFTR
|
||||
|
||||
|
||||
def make_matching_figure(img0,
|
||||
img1,
|
||||
mkpts0,
|
||||
mkpts1,
|
||||
color,
|
||||
kpts0=None,
|
||||
kpts1=None,
|
||||
text=[],
|
||||
dpi=75,
|
||||
path=None):
|
||||
# draw image pair
|
||||
assert mkpts0.shape[0] == mkpts1.shape[
|
||||
0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
||||
axes[0].imshow(img0, cmap='gray')
|
||||
axes[1].imshow(img1, cmap='gray')
|
||||
for i in range(2): # clear all frames
|
||||
axes[i].get_yaxis().set_ticks([])
|
||||
axes[i].get_xaxis().set_ticks([])
|
||||
for spine in axes[i].spines.values():
|
||||
spine.set_visible(False)
|
||||
plt.tight_layout(pad=1)
|
||||
|
||||
if kpts0 is not None:
|
||||
assert kpts1 is not None
|
||||
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
|
||||
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
|
||||
|
||||
# draw matches
|
||||
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
||||
fig.canvas.draw()
|
||||
transFigure = fig.transFigure.inverted()
|
||||
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
||||
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
||||
fig.lines = [
|
||||
matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
|
||||
(fkpts0[i, 1], fkpts1[i, 1]),
|
||||
transform=fig.transFigure,
|
||||
c=color[i],
|
||||
linewidth=1) for i in range(len(mkpts0))
|
||||
]
|
||||
|
||||
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
||||
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
||||
|
||||
# put txts
|
||||
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
|
||||
fig.text(
|
||||
0.01,
|
||||
0.99,
|
||||
'\n'.join(text),
|
||||
transform=fig.axes[0].transAxes,
|
||||
fontsize=15,
|
||||
va='top',
|
||||
ha='left',
|
||||
color=txt_color)
|
||||
|
||||
# save or return figure
|
||||
if path:
|
||||
plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
|
||||
plt.close()
|
||||
else:
|
||||
return fig
|
||||
|
||||
|
||||
def match_pair_visualization(img_name0,
|
||||
img_name1,
|
||||
kpts0,
|
||||
kpts1,
|
||||
conf,
|
||||
output_filename='quadtree_match.png',
|
||||
method='QuadTreeAttention'):
|
||||
|
||||
print(f'Found {len(kpts0)} matches')
|
||||
|
||||
# visualize the matches
|
||||
img0 = cv2.imread(str(img_name0))
|
||||
img1 = cv2.imread(str(img_name1))
|
||||
|
||||
# Draw
|
||||
color = cm.jet(conf)
|
||||
text = [
|
||||
method,
|
||||
'Matches: {}'.format(len(kpts0)),
|
||||
]
|
||||
fig = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
|
||||
|
||||
# save the figure
|
||||
fig.savefig(str(output_filename), dpi=300, bbox_inches='tight')
|
||||
|
||||
3
setup.py
@@ -201,6 +201,9 @@ if __name__ == '__main__':
|
||||
url='https://github.com/modelscope/modelscope',
|
||||
packages=find_packages(exclude=('configs', 'tools', 'demo')),
|
||||
include_package_data=True,
|
||||
package_data={
|
||||
'': ['*.h', '*.cpp', '*.cu'],
|
||||
},
|
||||
classifiers=[
|
||||
'Development Status :: 4 - Beta',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
|
||||
46
tests/pipelines/test_image_matching.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.cm as cm
|
||||
import numpy as np
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.cv.image_utils import match_pair_visualization
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class ImageMatchingTest(unittest.TestCase, DemoCompatibilityCheck):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = 'image-matching'
|
||||
self.model_id = 'damo/cv_quadtree_attention_image-matching_outdoor'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_image_matching(self):
|
||||
input_location = [[
|
||||
'data/test/images/image_matching1.jpg',
|
||||
'data/test/images/image_matching2.jpg'
|
||||
]]
|
||||
estimator = pipeline(Tasks.image_matching, model=self.model_id)
|
||||
result = estimator(input_location)
|
||||
kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
|
||||
|
||||
match_pair_visualization(
|
||||
input_location[0][0],
|
||||
input_location[0][1],
|
||||
kpts0,
|
||||
kpts1,
|
||||
conf,
|
||||
output_filename='quadtree_match.png')
|
||||
|
||||
print('test_image_matching DONE')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||