Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 08:17:45 +01:00.
Feature/LoFTR_image_local_feature_matching (#687)
* Add LoFTR image local feature matching.
* Add pipeline docstring and remove the example data, since examples already exist in data/test.
* Update pipeline docstring.

Co-authored-by: 翼生 <heyisheng.hys@alibaba-inc.com>
Co-authored-by: wenmeng zhou <wenmeng.zwm@alibaba-inc.com>
Committed by GitHub. Parent: a9d5b88407. Commit: 94ce1ebd7a.
@@ -88,6 +88,7 @@ class Models(object):
    video_object_segmentation = 'video-object-segmentation'
    video_deinterlace = 'video-deinterlace'
    quadtree_attention_image_matching = 'quadtree-attention-image-matching'
    loftr_image_local_feature_matching = 'loftr-image-local-feature-matching'
    lightglue_image_matching = 'lightglue-image-matching'
    vision_middleware = 'vision-middleware'
    vidt = 'vidt'
@@ -395,6 +396,7 @@ class Pipelines(object):
    image_depth_estimation = 'image-depth-estimation'
    image_normal_estimation = 'image-normal-estimation'
    indoor_layout_estimation = 'indoor-layout-estimation'
    image_local_feature_matching = 'image-local-feature-matching'
    video_depth_estimation = 'video-depth-estimation'
    panorama_depth_estimation = 'panorama-depth-estimation'
    panorama_depth_estimation_s2net = 'panorama-depth-estimation-s2net'
@@ -805,6 +807,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
    Tasks.panorama_depth_estimation:
    (Pipelines.panorama_depth_estimation,
     'damo/cv_unifuse_panorama-depth-estimation'),
    Tasks.image_local_feature_matching:
    (Pipelines.image_local_feature_matching, 'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'),
    Tasks.image_style_transfer: (Pipelines.image_style_transfer,
                                 'damo/cv_aams_style-transfer_damo'),
    Tasks.face_image_generation: (Pipelines.face_image_generation,
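With the task-to-default-model mapping above in place, the matcher can be constructed through the standard pipeline factory. A minimal usage sketch follows; the exact input format is an assumption here, documented in the pipeline docstring and exercised by the examples under data/test mentioned in the commit message:

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# Resolves to the image-local-feature-matching pipeline with the default model
# 'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'.
matcher = pipeline(Tasks.image_local_feature_matching)

# Hypothetical input pair of image files; see the pipeline docstring for the
# supported input formats.
result = matcher(['img0.png', 'img1.png'])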
@@ -29,6 +29,6 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
                video_panoptic_segmentation, video_single_object_tracking,
                video_stabilization, video_summarization,
                video_super_resolution, vidt, virual_tryon, vision_middleware,
-               vop_retrieval,image_matching_fast)
+               vop_retrieval, image_local_feature_matching,image_matching_fast)

# yapf: enable
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .loftr_model import LocalFeatureMatching

else:
    _import_structure = {
        'loftr_image_local_feature_matching': ['LocalFeatureMatching'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1,74 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import io
import cv2
import torch
import numpy as np
from copy import deepcopy

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.image_local_feature_matching.src.loftr import \
    LoFTR, default_cfg
from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import make_matching_figure
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
import matplotlib.cm as cm


@MODELS.register_module(
    Tasks.image_local_feature_matching,
    module_name=Models.loftr_image_local_feature_matching)
class LocalFeatureMatching(TorchModel):

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the LoFTR local feature matching model.

        Args:
            model_dir (str): the model file root directory.
        """
        super().__init__(model_dir, **kwargs)

        # build model
        # Initialize LoFTR
        _default_cfg = deepcopy(default_cfg)
        self.model = LoFTR(config=_default_cfg)

        # load model
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        checkpoint = torch.load(model_path, map_location='cpu')
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

    def forward(self, Inputs):
        # LoFTR writes its predictions into the input dict in place.
        self.model(Inputs)
        result = {
            'kpts0': Inputs['mkpts0_f'],
            'kpts1': Inputs['mkpts1_f'],
            'conf': Inputs['mconf'],
        }
        Inputs.update(result)
        return Inputs

    def postprocess(self, Inputs):
        # Draw the matches on the input image pair.
        color = cm.jet(Inputs['conf'].cpu().numpy())
        img0 = Inputs['image0'].squeeze().cpu().numpy()
        img1 = Inputs['image1'].squeeze().cpu().numpy()
        mkpts0 = Inputs['kpts0'].cpu().numpy()
        mkpts1 = Inputs['kpts1'].cpu().numpy()
        text = [
            'LoFTR',
            'Matches: {}'.format(len(Inputs['kpts0'])),
        ]
        img0, img1 = (img0 * 255).astype(np.uint8), (img1 * 255).astype(np.uint8)
        fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)
        io_buf = io.BytesIO()
        fig.savefig(io_buf, format="png", dpi=75)
        io_buf.seek(0)
        buf_data = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
        io_buf.close()
        vis_img = cv2.imdecode(buf_data, 1)

        results = {OutputKeys.MATCHES: Inputs, OutputKeys.OUTPUT_IMG: vis_img}
        return results

    def inference(self, data):
        results = self.forward(data)

        return results
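For reference, a sketch of exercising LocalFeatureMatching directly rather than through the pipeline. The input layout is inferred from LoFTR's forward docstring later in this commit (normalized single-channel tensors); the module path and the model directory are assumptions:

import torch
from modelscope.models.cv.image_local_feature_matching.loftr_model import LocalFeatureMatching

# model_dir must contain ModelFile.TORCH_MODEL_FILE with a 'state_dict' entry.
model = LocalFeatureMatching('/path/to/model_dir')

# Grayscale images in [0, 1]; sides divisible by 8 so the 1/8 coarse grid is exact.
batch = {
    'image0': torch.rand(1, 1, 480, 640),
    'image1': torch.rand(1, 1, 480, 640),
}
with torch.no_grad():
    batch = model.forward(batch)        # adds 'kpts0', 'kpts1', 'conf' in place
    outputs = model.postprocess(batch)  # OutputKeys.MATCHES plus a visualization image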
@@ -0,0 +1,2 @@
from .loftr import LoFTR
from .utils.cvpr_ds_config import default_cfg
@@ -0,0 +1,11 @@
from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4


def build_backbone(config):
    if config['backbone_type'] == 'ResNetFPN':
        if config['resolution'] == (8, 2):
            return ResNetFPN_8_2(config['resnetfpn'])
        elif config['resolution'] == (16, 4):
            return ResNetFPN_16_4(config['resnetfpn'])
    else:
        raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
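A quick sketch of the config this factory expects; the values mirror the defaults from cvpr_ds_config later in this commit (keys are lowercased by lower_config):

config = {
    'backbone_type': 'ResNetFPN',
    'resolution': (8, 2),
    'resnetfpn': {'initial_dim': 128, 'block_dims': [128, 196, 256]},
}
backbone = build_backbone(config)  # returns a ResNetFPN_8_2 instance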
@@ -0,0 +1,199 @@
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution without padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, stride=1):
|
||||
super().__init__()
|
||||
self.conv1 = conv3x3(in_planes, planes, stride)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
conv1x1(in_planes, planes, stride=stride),
|
||||
nn.BatchNorm2d(planes)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.bn1(self.conv1(y)))
|
||||
y = self.bn2(self.conv2(y))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x+y)
|
||||
|
||||
|
||||
class ResNetFPN_8_2(nn.Module):
|
||||
"""
|
||||
ResNet+FPN, output resolutions are 1/8 and 1/2.
|
||||
Each block has 2 layers.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Config
|
||||
block = BasicBlock
|
||||
initial_dim = config['initial_dim']
|
||||
block_dims = config['block_dims']
|
||||
|
||||
# Class Variable
|
||||
self.in_planes = initial_dim
|
||||
|
||||
# Networks
|
||||
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(initial_dim)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
|
||||
self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
|
||||
self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
|
||||
|
||||
# 3. FPN upsample
|
||||
self.layer3_outconv = conv1x1(block_dims[2], block_dims[2])
|
||||
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
||||
self.layer2_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[2], block_dims[2]),
|
||||
nn.BatchNorm2d(block_dims[2]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[2], block_dims[1]),
|
||||
)
|
||||
self.layer1_outconv = conv1x1(block_dims[0], block_dims[1])
|
||||
self.layer1_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[1], block_dims[1]),
|
||||
nn.BatchNorm2d(block_dims[1]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[1], block_dims[0]),
|
||||
)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, dim, stride=1):
|
||||
layer1 = block(self.in_planes, dim, stride=stride)
|
||||
layer2 = block(dim, dim, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# ResNet Backbone
|
||||
x0 = self.relu(self.bn1(self.conv1(x)))
|
||||
x1 = self.layer1(x0) # 1/2
|
||||
x2 = self.layer2(x1) # 1/4
|
||||
x3 = self.layer3(x2) # 1/8
|
||||
|
||||
# FPN
|
||||
x3_out = self.layer3_outconv(x3)
|
||||
|
||||
x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x2_out = self.layer2_outconv(x2)
|
||||
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
|
||||
|
||||
x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x1_out = self.layer1_outconv(x1)
|
||||
x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
|
||||
|
||||
return [x3_out, x1_out]
|
||||
|
||||
|
||||
class ResNetFPN_16_4(nn.Module):
|
||||
"""
|
||||
ResNet+FPN, output resolutions are 1/16 and 1/4.
|
||||
Each block has 2 layers.
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Config
|
||||
block = BasicBlock
|
||||
initial_dim = config['initial_dim']
|
||||
block_dims = config['block_dims']
|
||||
|
||||
# Class Variable
|
||||
self.in_planes = initial_dim
|
||||
|
||||
# Networks
|
||||
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(initial_dim)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2
|
||||
self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4
|
||||
self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8
|
||||
self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16
|
||||
|
||||
# 3. FPN upsample
|
||||
self.layer4_outconv = conv1x1(block_dims[3], block_dims[3])
|
||||
self.layer3_outconv = conv1x1(block_dims[2], block_dims[3])
|
||||
self.layer3_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[3], block_dims[3]),
|
||||
nn.BatchNorm2d(block_dims[3]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[3], block_dims[2]),
|
||||
)
|
||||
|
||||
self.layer2_outconv = conv1x1(block_dims[1], block_dims[2])
|
||||
self.layer2_outconv2 = nn.Sequential(
|
||||
conv3x3(block_dims[2], block_dims[2]),
|
||||
nn.BatchNorm2d(block_dims[2]),
|
||||
nn.LeakyReLU(),
|
||||
conv3x3(block_dims[2], block_dims[1]),
|
||||
)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, block, dim, stride=1):
|
||||
layer1 = block(self.in_planes, dim, stride=stride)
|
||||
layer2 = block(dim, dim, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
# ResNet Backbone
|
||||
x0 = self.relu(self.bn1(self.conv1(x)))
|
||||
x1 = self.layer1(x0) # 1/2
|
||||
x2 = self.layer2(x1) # 1/4
|
||||
x3 = self.layer3(x2) # 1/8
|
||||
x4 = self.layer4(x3) # 1/16
|
||||
|
||||
# FPN
|
||||
x4_out = self.layer4_outconv(x4)
|
||||
|
||||
x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x3_out = self.layer3_outconv(x3)
|
||||
x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
|
||||
|
||||
x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
|
||||
x2_out = self.layer2_outconv(x2)
|
||||
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
|
||||
|
||||
return [x4_out, x2_out]
|
||||
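A shape sanity check for the 8_2 variant above (editorial sketch; channel counts follow the default block_dims [128, 196, 256] from cvpr_ds_config):

import torch

fpn = ResNetFPN_8_2({'initial_dim': 128, 'block_dims': [128, 196, 256]})
coarse, fine = fpn(torch.rand(1, 1, 480, 640))
# coarse: [1, 256, 60, 80]   -> 1/8 resolution, block_dims[2] channels
# fine:   [1, 128, 240, 320] -> 1/2 resolution, block_dims[0] channels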
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops.einops import rearrange
|
||||
|
||||
from .backbone import build_backbone
|
||||
from .utils.position_encoding import PositionEncodingSine
|
||||
from .loftr_module import LocalFeatureTransformer, FinePreprocess
|
||||
from .utils.coarse_matching import CoarseMatching
|
||||
from .utils.fine_matching import FineMatching
|
||||
|
||||
|
||||
class LoFTR(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
# Misc
|
||||
self.config = config
|
||||
|
||||
# Modules
|
||||
self.backbone = build_backbone(config)
|
||||
self.pos_encoding = PositionEncodingSine(
|
||||
config['coarse']['d_model'],
|
||||
temp_bug_fix=config['coarse']['temp_bug_fix'])
|
||||
self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
|
||||
self.coarse_matching = CoarseMatching(config['match_coarse'])
|
||||
self.fine_preprocess = FinePreprocess(config)
|
||||
self.loftr_fine = LocalFeatureTransformer(config["fine"])
|
||||
self.fine_matching = FineMatching()
|
||||
|
||||
def forward(self, data):
|
||||
"""
|
||||
Update:
|
||||
data (dict): {
|
||||
'image0': (torch.Tensor): (N, 1, H, W)
|
||||
'image1': (torch.Tensor): (N, 1, H, W)
|
||||
'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position
|
||||
'mask1'(optional) : (torch.Tensor): (N, H, W)
|
||||
}
|
||||
"""
|
||||
# 1. Local Feature CNN
|
||||
data.update({
|
||||
'bs': data['image0'].size(0),
|
||||
'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
|
||||
})
|
||||
|
||||
if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
|
||||
feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
|
||||
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
|
||||
else: # handle different input shapes
|
||||
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
|
||||
|
||||
data.update({
|
||||
'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
|
||||
'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
|
||||
})
|
||||
|
||||
# 2. coarse-level loftr module
|
||||
# add featmap with positional encoding, then flatten it to sequence [N, HW, C]
|
||||
feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
|
||||
feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
|
||||
|
||||
mask_c0 = mask_c1 = None # mask is useful in training
|
||||
if 'mask0' in data:
|
||||
mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
|
||||
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
|
||||
|
||||
# 3. match coarse-level
|
||||
self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
|
||||
|
||||
# 4. fine-level refinement
|
||||
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
|
||||
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
|
||||
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
|
||||
|
||||
# 5. match fine-level
|
||||
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
|
||||
|
||||
def load_state_dict(self, state_dict, *args, **kwargs):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.startswith('matcher.'):
|
||||
state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
|
||||
return super().load_state_dict(state_dict, *args, **kwargs)
|
||||
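A minimal end-to-end sketch of the matcher with the shipped default_cfg (an editorial example, not part of the commit; with untrained weights the match set may be empty but the output keys are still populated):

import torch
from copy import deepcopy
from modelscope.models.cv.image_local_feature_matching.src.loftr import LoFTR, default_cfg

matcher = LoFTR(config=deepcopy(default_cfg))
matcher.eval()

data = {
    'image0': torch.rand(1, 1, 480, 640),
    'image1': torch.rand(1, 1, 480, 640),
}
with torch.no_grad():
    matcher(data)  # predictions are written into `data` in place

print(data['mkpts0_f'].shape, data['mkpts1_f'].shape, data['mconf'].shape)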
@@ -0,0 +1,2 @@
|
||||
from .transformer import LocalFeatureTransformer
|
||||
from .fine_preprocess import FinePreprocess
|
||||
@@ -0,0 +1,59 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.einops import rearrange, repeat
|
||||
|
||||
|
||||
class FinePreprocess(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.cat_c_feat = config['fine_concat_coarse_feat']
|
||||
self.W = self.config['fine_window_size']
|
||||
|
||||
d_model_c = self.config['coarse']['d_model']
|
||||
d_model_f = self.config['fine']['d_model']
|
||||
self.d_model_f = d_model_f
|
||||
if self.cat_c_feat:
|
||||
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
|
||||
self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
|
||||
|
||||
def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
|
||||
W = self.W
|
||||
stride = data['hw0_f'][0] // data['hw0_c'][0]
|
||||
|
||||
data.update({'W': W})
|
||||
if data['b_ids'].shape[0] == 0:
|
||||
feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
|
||||
feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
|
||||
return feat0, feat1
|
||||
|
||||
# 1. unfold(crop) all local windows
|
||||
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
|
||||
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
||||
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
|
||||
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
|
||||
|
||||
# 2. select only the predicted matches
|
||||
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
|
||||
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
|
||||
|
||||
# option: use coarse-level loftr feature as context: concat and linear
|
||||
if self.cat_c_feat:
|
||||
feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
|
||||
feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
|
||||
feat_cf_win = self.merge_feat(torch.cat([
|
||||
torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
|
||||
repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
|
||||
], -1))
|
||||
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
|
||||
|
||||
return feat_f0_unfold, feat_f1_unfold
|
||||
@@ -0,0 +1,81 @@
|
||||
"""
|
||||
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
|
||||
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn import Module, Dropout
|
||||
|
||||
|
||||
def elu_feature_map(x):
|
||||
return torch.nn.functional.elu(x) + 1
|
||||
|
||||
|
||||
class LinearAttention(Module):
|
||||
def __init__(self, eps=1e-6):
|
||||
super().__init__()
|
||||
self.feature_map = elu_feature_map
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
""" Multi-Head linear attention proposed in "Transformers are RNNs"
|
||||
Args:
|
||||
queries: [N, L, H, D]
|
||||
keys: [N, S, H, D]
|
||||
values: [N, S, H, D]
|
||||
q_mask: [N, L]
|
||||
kv_mask: [N, S]
|
||||
Returns:
|
||||
queried_values: (N, L, H, D)
|
||||
"""
|
||||
Q = self.feature_map(queries)
|
||||
K = self.feature_map(keys)
|
||||
|
||||
# set padded position to zero
|
||||
if q_mask is not None:
|
||||
Q = Q * q_mask[:, :, None, None]
|
||||
if kv_mask is not None:
|
||||
K = K * kv_mask[:, :, None, None]
|
||||
values = values * kv_mask[:, :, None, None]
|
||||
|
||||
v_length = values.size(1)
|
||||
values = values / v_length # prevent fp16 overflow
|
||||
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
|
||||
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
||||
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
||||
|
||||
return queried_values.contiguous()
|
||||
|
||||
|
||||
class FullAttention(Module):
|
||||
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
||||
super().__init__()
|
||||
self.use_dropout = use_dropout
|
||||
self.dropout = Dropout(attention_dropout)
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
||||
Args:
|
||||
queries: [N, L, H, D]
|
||||
keys: [N, S, H, D]
|
||||
values: [N, S, H, D]
|
||||
q_mask: [N, L]
|
||||
kv_mask: [N, S]
|
||||
Returns:
|
||||
queried_values: (N, L, H, D)
|
||||
"""
|
||||
|
||||
# Compute the unnormalized attention and apply the masks
|
||||
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
||||
if kv_mask is not None:
|
||||
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
|
||||
|
||||
# Compute the attention and the weighted average
|
||||
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
||||
A = torch.softmax(softmax_temp * QK, dim=2)
|
||||
if self.use_dropout:
|
||||
A = self.dropout(A)
|
||||
|
||||
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
|
||||
|
||||
return queried_values.contiguous()
|
||||
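Both attention variants take [N, L, H, D] projections and return the same shape; a tiny shape check (editorial sketch):

import torch

q = torch.rand(2, 100, 8, 32)  # [N, L, H, D]
k = torch.rand(2, 120, 8, 32)  # [N, S, H, D]
v = torch.rand(2, 120, 8, 32)  # [N, S, H, D]

out_linear = LinearAttention()(q, k, v)  # memory linear in L and S
out_full = FullAttention()(q, k, v)      # materializes the full L x S attention
assert out_linear.shape == out_full.shape == (2, 100, 8, 32)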
@@ -0,0 +1,101 @@
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .linear_attention import LinearAttention, FullAttention
|
||||
|
||||
|
||||
class LoFTREncoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
d_model,
|
||||
nhead,
|
||||
attention='linear'):
|
||||
super(LoFTREncoderLayer, self).__init__()
|
||||
|
||||
self.dim = d_model // nhead
|
||||
self.nhead = nhead
|
||||
|
||||
# multi-head attention
|
||||
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.attention = LinearAttention() if attention == 'linear' else FullAttention()
|
||||
self.merge = nn.Linear(d_model, d_model, bias=False)
|
||||
|
||||
# feed-forward network
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(d_model*2, d_model*2, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(d_model*2, d_model, bias=False),
|
||||
)
|
||||
|
||||
# norm and dropout
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x, source, x_mask=None, source_mask=None):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): [N, L, C]
|
||||
source (torch.Tensor): [N, S, C]
|
||||
x_mask (torch.Tensor): [N, L] (optional)
|
||||
source_mask (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
bs = x.size(0)
|
||||
query, key, value = x, source, source
|
||||
|
||||
# multi-head attention
|
||||
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
|
||||
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
|
||||
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
||||
message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
|
||||
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
|
||||
message = self.norm1(message)
|
||||
|
||||
# feed-forward network
|
||||
message = self.mlp(torch.cat([x, message], dim=2))
|
||||
message = self.norm2(message)
|
||||
|
||||
return x + message
|
||||
|
||||
|
||||
class LocalFeatureTransformer(nn.Module):
|
||||
"""A Local Feature Transformer (LoFTR) module."""
|
||||
|
||||
def __init__(self, config):
|
||||
super(LocalFeatureTransformer, self).__init__()
|
||||
|
||||
self.config = config
|
||||
self.d_model = config['d_model']
|
||||
self.nhead = config['nhead']
|
||||
self.layer_names = config['layer_names']
|
||||
encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
|
||||
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [N, L, C]
|
||||
feat1 (torch.Tensor): [N, S, C]
|
||||
mask0 (torch.Tensor): [N, L] (optional)
|
||||
mask1 (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
|
||||
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
||||
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
elif name == 'cross':
|
||||
feat0 = layer(feat0, feat1, mask0, mask1)
|
||||
feat1 = layer(feat1, feat0, mask1, mask0)
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
return feat0, feat1
|
||||
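The coarse-level transformer in this commit is built from the COARSE block of cvpr_ds_config; a wiring sketch (the 60x80 grid size is just an example for a 480x640 input):

import torch

coarse_cfg = {
    'd_model': 256,
    'nhead': 8,
    'layer_names': ['self', 'cross'] * 4,
    'attention': 'linear',
}
loftr_coarse = LocalFeatureTransformer(coarse_cfg)

feat0 = torch.rand(1, 60 * 80, 256)  # [N, L, C], a flattened coarse feature grid
feat1 = torch.rand(1, 60 * 80, 256)  # [N, S, C]
feat0, feat1 = loftr_coarse(feat0, feat1)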
@@ -0,0 +1,261 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops.einops import rearrange
|
||||
|
||||
INF = 1e9
|
||||
|
||||
def mask_border(m, b: int, v):
|
||||
""" Mask borders with value
|
||||
Args:
|
||||
m (torch.Tensor): [N, H0, W0, H1, W1]
|
||||
b (int)
|
||||
v (m.dtype)
|
||||
"""
|
||||
if b <= 0:
|
||||
return
|
||||
|
||||
m[:, :b] = v
|
||||
m[:, :, :b] = v
|
||||
m[:, :, :, :b] = v
|
||||
m[:, :, :, :, :b] = v
|
||||
m[:, -b:] = v
|
||||
m[:, :, -b:] = v
|
||||
m[:, :, :, -b:] = v
|
||||
m[:, :, :, :, -b:] = v
|
||||
|
||||
|
||||
def mask_border_with_padding(m, bd, v, p_m0, p_m1):
|
||||
if bd <= 0:
|
||||
return
|
||||
|
||||
m[:, :bd] = v
|
||||
m[:, :, :bd] = v
|
||||
m[:, :, :, :bd] = v
|
||||
m[:, :, :, :, :bd] = v
|
||||
|
||||
h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
|
||||
h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
|
||||
for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
|
||||
m[b_idx, h0 - bd:] = v
|
||||
m[b_idx, :, w0 - bd:] = v
|
||||
m[b_idx, :, :, h1 - bd:] = v
|
||||
m[b_idx, :, :, :, w1 - bd:] = v
|
||||
|
||||
|
||||
def compute_max_candidates(p_m0, p_m1):
|
||||
"""Compute the max candidates of all pairs within a batch
|
||||
|
||||
Args:
|
||||
p_m0, p_m1 (torch.Tensor): padded masks
|
||||
"""
|
||||
h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
|
||||
h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
|
||||
max_cand = torch.sum(
|
||||
torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
|
||||
return max_cand
|
||||
|
||||
|
||||
class CoarseMatching(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# general config
|
||||
self.thr = config['thr']
|
||||
self.border_rm = config['border_rm']
|
||||
# -- # for training fine-level LoFTR
|
||||
self.train_coarse_percent = config['train_coarse_percent']
|
||||
self.train_pad_num_gt_min = config['train_pad_num_gt_min']
|
||||
|
||||
# we provide 2 options for differentiable matching
|
||||
self.match_type = config['match_type']
|
||||
if self.match_type == 'dual_softmax':
|
||||
self.temperature = config['dsmax_temperature']
|
||||
elif self.match_type == 'sinkhorn':
|
||||
try:
|
||||
from .superglue import log_optimal_transport
|
||||
except ImportError:
|
||||
raise ImportError("download superglue.py first!")
|
||||
self.log_optimal_transport = log_optimal_transport
|
||||
self.bin_score = nn.Parameter(
|
||||
torch.tensor(config['skh_init_bin_score'], requires_grad=True))
|
||||
self.skh_iters = config['skh_iters']
|
||||
self.skh_prefilter = config['skh_prefilter']
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [N, L, C]
|
||||
feat1 (torch.Tensor): [N, S, C]
|
||||
data (dict)
|
||||
mask_c0 (torch.Tensor): [N, L] (optional)
|
||||
mask_c1 (torch.Tensor): [N, S] (optional)
|
||||
Update:
|
||||
data (dict): {
|
||||
'b_ids' (torch.Tensor): [M'],
|
||||
'i_ids' (torch.Tensor): [M'],
|
||||
'j_ids' (torch.Tensor): [M'],
|
||||
'gt_mask' (torch.Tensor): [M'],
|
||||
'mkpts0_c' (torch.Tensor): [M, 2],
|
||||
'mkpts1_c' (torch.Tensor): [M, 2],
|
||||
'mconf' (torch.Tensor): [M]}
|
||||
NOTE: M' != M during training.
|
||||
"""
|
||||
N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
|
||||
|
||||
# normalize
|
||||
feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
|
||||
[feat_c0, feat_c1])
|
||||
|
||||
if self.match_type == 'dual_softmax':
|
||||
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
|
||||
feat_c1) / self.temperature
|
||||
if mask_c0 is not None:
|
||||
sim_matrix.masked_fill_(
|
||||
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
|
||||
-INF)
|
||||
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
|
||||
|
||||
elif self.match_type == 'sinkhorn':
|
||||
# sinkhorn, dustbin included
|
||||
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
|
||||
if mask_c0 is not None:
|
||||
sim_matrix[:, :L, :S].masked_fill_(
|
||||
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
|
||||
-INF)
|
||||
|
||||
# build uniform prior & use sinkhorn
|
||||
log_assign_matrix = self.log_optimal_transport(
|
||||
sim_matrix, self.bin_score, self.skh_iters)
|
||||
assign_matrix = log_assign_matrix.exp()
|
||||
conf_matrix = assign_matrix[:, :-1, :-1]
|
||||
|
||||
# filter prediction with dustbin score (only in evaluation mode)
|
||||
if not self.training and self.skh_prefilter:
|
||||
filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1] # [N, L]
|
||||
filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1] # [N, S]
|
||||
conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
|
||||
conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
|
||||
|
||||
if self.config['sparse_spvs']:
|
||||
data.update({'conf_matrix_with_bin': assign_matrix.clone()})
|
||||
|
||||
data.update({'conf_matrix': conf_matrix})
|
||||
|
||||
# predict coarse matches from conf_matrix
|
||||
data.update(**self.get_coarse_match(conf_matrix, data))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_coarse_match(self, conf_matrix, data):
|
||||
"""
|
||||
Args:
|
||||
conf_matrix (torch.Tensor): [N, L, S]
|
||||
data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
|
||||
Returns:
|
||||
coarse_matches (dict): {
|
||||
'b_ids' (torch.Tensor): [M'],
|
||||
'i_ids' (torch.Tensor): [M'],
|
||||
'j_ids' (torch.Tensor): [M'],
|
||||
'gt_mask' (torch.Tensor): [M'],
|
||||
'm_bids' (torch.Tensor): [M],
|
||||
'mkpts0_c' (torch.Tensor): [M, 2],
|
||||
'mkpts1_c' (torch.Tensor): [M, 2],
|
||||
'mconf' (torch.Tensor): [M]}
|
||||
"""
|
||||
axes_lengths = {
|
||||
'h0c': data['hw0_c'][0],
|
||||
'w0c': data['hw0_c'][1],
|
||||
'h1c': data['hw1_c'][0],
|
||||
'w1c': data['hw1_c'][1]
|
||||
}
|
||||
_device = conf_matrix.device
|
||||
# 1. confidence thresholding
|
||||
mask = conf_matrix > self.thr
|
||||
mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
|
||||
**axes_lengths)
|
||||
if 'mask0' not in data:
|
||||
mask_border(mask, self.border_rm, False)
|
||||
else:
|
||||
mask_border_with_padding(mask, self.border_rm, False,
|
||||
data['mask0'], data['mask1'])
|
||||
mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
|
||||
**axes_lengths)
|
||||
|
||||
# 2. mutual nearest
|
||||
mask = mask \
|
||||
* (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
|
||||
* (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
|
||||
|
||||
# 3. find all valid coarse matches
|
||||
# this only works when at most one `True` in each row
|
||||
mask_v, all_j_ids = mask.max(dim=2)
|
||||
b_ids, i_ids = torch.where(mask_v)
|
||||
j_ids = all_j_ids[b_ids, i_ids]
|
||||
mconf = conf_matrix[b_ids, i_ids, j_ids]
|
||||
|
||||
# 4. Random sampling of training samples for fine-level LoFTR
|
||||
# (optional) pad samples with gt coarse-level matches
|
||||
if self.training:
|
||||
# NOTE:
|
||||
# The sampling is performed across all pairs in a batch without manually balancing
|
||||
# #samples for fine-level increases w.r.t. batch_size
|
||||
if 'mask0' not in data:
|
||||
num_candidates_max = mask.size(0) * max(
|
||||
mask.size(1), mask.size(2))
|
||||
else:
|
||||
num_candidates_max = compute_max_candidates(
|
||||
data['mask0'], data['mask1'])
|
||||
num_matches_train = int(num_candidates_max *
|
||||
self.train_coarse_percent)
|
||||
num_matches_pred = len(b_ids)
|
||||
assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
|
||||
|
||||
# pred_indices is to select from prediction
|
||||
if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
|
||||
pred_indices = torch.arange(num_matches_pred, device=_device)
|
||||
else:
|
||||
pred_indices = torch.randint(
|
||||
num_matches_pred,
|
||||
(num_matches_train - self.train_pad_num_gt_min, ),
|
||||
device=_device)
|
||||
|
||||
# gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
|
||||
gt_pad_indices = torch.randint(
|
||||
len(data['spv_b_ids']),
|
||||
(max(num_matches_train - num_matches_pred,
|
||||
self.train_pad_num_gt_min), ),
|
||||
device=_device)
|
||||
mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
|
||||
|
||||
b_ids, i_ids, j_ids, mconf = map(
|
||||
lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
|
||||
dim=0),
|
||||
*zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
|
||||
[j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
|
||||
|
||||
# These matches select patches that feed into fine-level network
|
||||
coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
|
||||
|
||||
# 5. Update with matches in original image resolution
|
||||
scale = data['hw0_i'][0] / data['hw0_c'][0]
|
||||
scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
|
||||
scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
|
||||
mkpts0_c = torch.stack(
|
||||
[i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
|
||||
dim=1) * scale0
|
||||
mkpts1_c = torch.stack(
|
||||
[j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
|
||||
dim=1) * scale1
|
||||
|
||||
# These matches are the current prediction (for visualization)
|
||||
coarse_matches.update({
|
||||
'gt_mask': mconf == 0,
|
||||
'm_bids': b_ids[mconf != 0], # mconf == 0 => gt matches
|
||||
'mkpts0_c': mkpts0_c[mconf != 0],
|
||||
'mkpts1_c': mkpts1_c[mconf != 0],
|
||||
'mconf': mconf[mconf != 0]
|
||||
})
|
||||
|
||||
return coarse_matches
|
||||
@@ -0,0 +1,50 @@
|
||||
from yacs.config import CfgNode as CN
|
||||
|
||||
|
||||
def lower_config(yacs_cfg):
|
||||
if not isinstance(yacs_cfg, CN):
|
||||
return yacs_cfg
|
||||
return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
|
||||
|
||||
|
||||
_CN = CN()
|
||||
_CN.BACKBONE_TYPE = 'ResNetFPN'
|
||||
_CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)]
|
||||
_CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd
|
||||
_CN.FINE_CONCAT_COARSE_FEAT = True
|
||||
|
||||
# 1. LoFTR-backbone (local feature CNN) config
|
||||
_CN.RESNETFPN = CN()
|
||||
_CN.RESNETFPN.INITIAL_DIM = 128
|
||||
_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3
|
||||
|
||||
# 2. LoFTR-coarse module config
|
||||
_CN.COARSE = CN()
|
||||
_CN.COARSE.D_MODEL = 256
|
||||
_CN.COARSE.D_FFN = 256
|
||||
_CN.COARSE.NHEAD = 8
|
||||
_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
|
||||
_CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full']
|
||||
_CN.COARSE.TEMP_BUG_FIX = False
|
||||
|
||||
# 3. Coarse-Matching config
|
||||
_CN.MATCH_COARSE = CN()
|
||||
_CN.MATCH_COARSE.THR = 0.2
|
||||
_CN.MATCH_COARSE.BORDER_RM = 2
|
||||
_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn']
|
||||
_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
|
||||
_CN.MATCH_COARSE.SKH_ITERS = 3
|
||||
_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
|
||||
_CN.MATCH_COARSE.SKH_PREFILTER = True
|
||||
_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory
|
||||
_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock
|
||||
|
||||
# 4. LoFTR-fine module config
|
||||
_CN.FINE = CN()
|
||||
_CN.FINE.D_MODEL = 128
|
||||
_CN.FINE.D_FFN = 128
|
||||
_CN.FINE.NHEAD = 8
|
||||
_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
|
||||
_CN.FINE.ATTENTION = 'linear'
|
||||
|
||||
default_cfg = lower_config(_CN)
|
||||
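lower_config recursively turns the yacs node into plain nested dicts with lowercase keys, which is exactly what the modules above index into. For instance:

cfg = lower_config(_CN)
assert cfg['match_coarse']['thr'] == 0.2
assert cfg['coarse']['layer_names'] == ['self', 'cross'] * 4
assert cfg['resnetfpn']['block_dims'] == [128, 196, 256]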
@@ -0,0 +1,163 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def create_meshgrid(
|
||||
height: int,
|
||||
width: int,
|
||||
normalized_coordinates: bool = True,
|
||||
device = None,
|
||||
dtype = None,
|
||||
):
|
||||
"""Generate a coordinate grid for an image.
|
||||
|
||||
When the flag ``normalized_coordinates`` is set to True, the grid is
|
||||
normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
|
||||
function :py:func:`torch.nn.functional.grid_sample`.
|
||||
|
||||
Args:
|
||||
height: the image height (rows).
|
||||
width: the image width (cols).
|
||||
normalized_coordinates: whether to normalize
|
||||
coordinates in the range :math:`[-1,1]` in order to be consistent with the
|
||||
PyTorch function :py:func:`torch.nn.functional.grid_sample`.
|
||||
device: the device on which the grid will be generated.
|
||||
dtype: the data type of the generated grid.
|
||||
|
||||
Return:
|
||||
grid tensor with shape :math:`(1, H, W, 2)`.
|
||||
|
||||
Example:
|
||||
>>> create_meshgrid(2, 2)
|
||||
tensor([[[[-1., -1.],
|
||||
[ 1., -1.]],
|
||||
<BLANKLINE>
|
||||
[[-1., 1.],
|
||||
[ 1., 1.]]]])
|
||||
|
||||
>>> create_meshgrid(2, 2, normalized_coordinates=False)
|
||||
tensor([[[[0., 0.],
|
||||
[1., 0.]],
|
||||
<BLANKLINE>
|
||||
[[0., 1.],
|
||||
[1., 1.]]]])
|
||||
"""
|
||||
xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
|
||||
ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
|
||||
if normalized_coordinates:
|
||||
xs = (xs / (width - 1) - 0.5) * 2
|
||||
ys = (ys / (height - 1) - 0.5) * 2
|
||||
base_grid = torch.stack(torch.meshgrid([xs, ys], indexing="ij"), dim=-1) # WxHx2
|
||||
return base_grid.permute(1, 0, 2).unsqueeze(0) # 1xHxWx2
|
||||
|
||||
|
||||
def spatial_expectation2d(input, normalized_coordinates: bool = True):
|
||||
r"""Compute the expectation of coordinate values using spatial probabilities.
|
||||
|
||||
The input heatmap is assumed to represent a valid spatial probability distribution,
|
||||
which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.
|
||||
|
||||
Args:
|
||||
input: the input tensor representing dense spatial probabilities with shape :math:`(B, N, H, W)`.
|
||||
normalized_coordinates: whether to return the coordinates normalized in the range
|
||||
of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.
|
||||
|
||||
Returns:
|
||||
expected value of the 2D coordinates with shape :math:`(B, N, 2)`. Output order of the coordinates is (x, y).
|
||||
|
||||
Examples:
|
||||
>>> heatmaps = torch.tensor([[[
|
||||
... [0., 0., 0.],
|
||||
... [0., 0., 0.],
|
||||
... [0., 1., 0.]]]])
|
||||
>>> spatial_expectation2d(heatmaps, False)
|
||||
tensor([[[1., 2.]]])
|
||||
"""
|
||||
|
||||
batch_size, channels, height, width = input.shape
|
||||
|
||||
# Create coordinates grid.
|
||||
grid = create_meshgrid(height, width, normalized_coordinates, input.device)
|
||||
grid = grid.to(input.dtype)
|
||||
|
||||
pos_x = grid[..., 0].reshape(-1)
|
||||
pos_y = grid[..., 1].reshape(-1)
|
||||
|
||||
input_flat = input.view(batch_size, channels, -1)
|
||||
|
||||
# Compute the expectation of the coordinates.
|
||||
expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
|
||||
expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)
|
||||
|
||||
output = torch.cat([expected_x, expected_y], -1)
|
||||
|
||||
return output.view(batch_size, channels, 2) # BxNx2
|
||||
|
||||
|
||||
class FineMatching(nn.Module):
|
||||
"""FineMatching with s2d paradigm"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, feat_f0, feat_f1, data):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [M, WW, C]
|
||||
feat1 (torch.Tensor): [M, WW, C]
|
||||
data (dict)
|
||||
Update:
|
||||
data (dict):{
|
||||
'expec_f' (torch.Tensor): [M, 3],
|
||||
'mkpts0_f' (torch.Tensor): [M, 2],
|
||||
'mkpts1_f' (torch.Tensor): [M, 2]}
|
||||
"""
|
||||
M, WW, C = feat_f0.shape
|
||||
W = int(math.sqrt(WW))
|
||||
scale = data['hw0_i'][0] / data['hw0_f'][0]
|
||||
self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
|
||||
|
||||
# corner case: if no coarse matches found
|
||||
if M == 0:
|
||||
assert not self.training, 'M is always > 0 when training, see coarse_matching.py'
|
||||
# logger.warning('No matches found in coarse-level.')
|
||||
data.update({
|
||||
'expec_f': torch.empty(0, 3, device=feat_f0.device),
|
||||
'mkpts0_f': data['mkpts0_c'],
|
||||
'mkpts1_f': data['mkpts1_c'],
|
||||
})
|
||||
return
|
||||
|
||||
feat_f0_picked = feat_f0[:, WW//2, :]
|
||||
sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
|
||||
softmax_temp = 1. / C**.5
|
||||
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
|
||||
|
||||
# compute coordinates from heatmap
|
||||
coords_normalized = spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
|
||||
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
|
||||
|
||||
# compute std over <x, y>
|
||||
var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
|
||||
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
|
||||
|
||||
# for fine-level supervision
|
||||
data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
|
||||
|
||||
# compute absolute kpt coords
|
||||
self.get_fine_match(coords_normalized, data)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_fine_match(self, coords_normed, data):
|
||||
W, WW, C, scale = self.W, self.WW, self.C, self.scale
|
||||
|
||||
# mkpts0_f and mkpts1_f
|
||||
mkpts0_f = data['mkpts0_c']
|
||||
scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
|
||||
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
|
||||
|
||||
data.update({
|
||||
"mkpts0_f": mkpts0_f,
|
||||
"mkpts1_f": mkpts1_f
|
||||
})
|
||||
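The fine refinement above shifts each coarse keypoint in image 1 by the expected sub-window offset: with the default FINE_WINDOW_SIZE of 5 and the (8, 2) resolution setting, scale = hw0_i / hw0_f = 2, so the correction is at most (W // 2) * scale = 4 pixels. A worked one-match sketch with hypothetical numbers:

import torch

coords_normed = torch.tensor([[0.5, -0.25]])  # expectation over the 5x5 window, in [-1, 1]
mkpts1_c = torch.tensor([[100.0, 200.0]])
W, scale = 5, 2
mkpts1_f = mkpts1_c + coords_normed * (W // 2) * scale
# -> tensor([[102., 199.]])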
@@ -0,0 +1,54 @@
|
||||
import torch
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
|
||||
""" Warp kpts0 from I0 to I1 with depth, K and Rt
|
||||
Also check covisibility and depth consistency.
|
||||
Depth is consistent if relative error < 0.2 (hard-coded).
|
||||
|
||||
Args:
|
||||
kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
|
||||
depth0 (torch.Tensor): [N, H, W],
|
||||
depth1 (torch.Tensor): [N, H, W],
|
||||
T_0to1 (torch.Tensor): [N, 3, 4],
|
||||
K0 (torch.Tensor): [N, 3, 3],
|
||||
K1 (torch.Tensor): [N, 3, 3],
|
||||
Returns:
|
||||
calculable_mask (torch.Tensor): [N, L]
|
||||
warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
|
||||
"""
|
||||
kpts0_long = kpts0.round().long()
|
||||
|
||||
# Sample depth, get calculable_mask on depth != 0
|
||||
kpts0_depth = torch.stack(
|
||||
[depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
|
||||
) # (N, L)
|
||||
nonzero_mask = kpts0_depth != 0
|
||||
|
||||
# Unproject
|
||||
kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
|
||||
kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
|
||||
|
||||
# Rigid Transform
|
||||
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
|
||||
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
|
||||
|
||||
# Project
|
||||
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
|
||||
w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
|
||||
|
||||
# Covisible Check
|
||||
h, w = depth1.shape[1:3]
|
||||
covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
|
||||
(w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
|
||||
w_kpts0_long = w_kpts0.long()
|
||||
w_kpts0_long[~covisible_mask, :] = 0
|
||||
|
||||
w_kpts0_depth = torch.stack(
|
||||
[depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
|
||||
) # (N, L)
|
||||
consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
|
||||
valid_mask = nonzero_mask * covisible_mask * consistent_mask
|
||||
|
||||
return valid_mask, w_kpts0
|
||||
@@ -0,0 +1,42 @@
import math
import torch
from torch import nn


class PositionEncodingSine(nn.Module):
    """
    This is a sinusoidal position encoding that is generalized to 2-dimensional images
    """

    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
        """
        Args:
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatibility.
                We will remove the buggy impl after re-training all variants of our released models.
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        if temp_bug_fix:
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
        else:  # a buggy implementation (for backward compatibility only)
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        """
        return x + self.pe[:, :, :x.size(2), :x.size(3)]
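Usage sketch for the 2-D sine encoding, matching how loftr.py applies it to the coarse feature map (note that the shipped default_cfg sets TEMP_BUG_FIX = False):

import torch

pos_enc = PositionEncodingSine(256, temp_bug_fix=False)
feat_c = torch.rand(1, 256, 60, 80)  # [N, C, H, W] coarse features
feat_c = pos_enc(feat_c)             # same shape, encoding added channel-wise
assert feat_c.shape == (1, 256, 60, 80)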
@@ -0,0 +1,151 @@
|
||||
from math import log
|
||||
from loguru import logger
|
||||
|
||||
import torch
|
||||
from einops import repeat
|
||||
from kornia.utils import create_meshgrid
|
||||
|
||||
from .geometry import warp_kpts
|
||||
|
||||
############## ↓ Coarse-Level supervision ↓ ##############
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def mask_pts_at_padded_regions(grid_pt, mask):
|
||||
"""For megadepth dataset, zero-padding exists in images"""
|
||||
mask = repeat(mask, 'n h w -> n (h w) c', c=2)
|
||||
grid_pt[~mask.bool()] = 0
|
||||
return grid_pt
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def spvs_coarse(data, config):
|
||||
"""
|
||||
Update:
|
||||
data (dict): {
|
||||
"conf_matrix_gt": [N, hw0, hw1],
|
||||
'spv_b_ids': [M]
|
||||
'spv_i_ids': [M]
|
||||
'spv_j_ids': [M]
|
||||
'spv_w_pt0_i': [N, hw0, 2], in original image resolution
|
||||
'spv_pt1_i': [N, hw1, 2], in original image resolution
|
||||
}
|
||||
|
||||
NOTE:
|
||||
- for scannet dataset, there are 3 kinds of resolution {i, c, f}
- for megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
|
||||
"""
|
||||
# 1. misc
|
||||
device = data['image0'].device
|
||||
N, _, H0, W0 = data['image0'].shape
|
||||
_, _, H1, W1 = data['image1'].shape
|
||||
scale = config['LOFTR']['RESOLUTION'][0]
|
||||
scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
|
||||
scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
|
||||
h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
|
||||
|
||||
# 2. warp grids
|
||||
# create kpts in meshgrid and resize them to image resolution
|
||||
grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
|
||||
grid_pt0_i = scale0 * grid_pt0_c
|
||||
grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
|
||||
grid_pt1_i = scale1 * grid_pt1_c
|
||||
|
||||
# mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
|
||||
if 'mask0' in data:
|
||||
grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
|
||||
grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
|
||||
|
||||
# warp kpts bi-directionally and resize them to coarse-level resolution
|
||||
# (no depth consistency check, since it leads to worse results experimentally)
|
||||
# (unhandled edge case: points with 0-depth will be warped to the left-up corner)
|
||||
_, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
|
||||
_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
|
||||
w_pt0_c = w_pt0_i / scale1
|
||||
w_pt1_c = w_pt1_i / scale0
|
||||
|
||||
# 3. check if mutual nearest neighbor
|
||||
w_pt0_c_round = w_pt0_c[:, :, :].round().long()
|
||||
nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
|
||||
w_pt1_c_round = w_pt1_c[:, :, :].round().long()
|
||||
nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
|
||||
|
||||
# corner case: out of boundary
|
||||
def out_bound_mask(pt, w, h):
|
||||
return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
|
||||
nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
|
||||
nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
|
||||
|
||||
loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
|
||||
correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
|
||||
correct_0to1[:, 0] = False # ignore the top-left corner
|
||||
|
||||
# 4. construct a gt conf_matrix
|
||||
conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
|
||||
b_ids, i_ids = torch.where(correct_0to1 != 0)
|
||||
j_ids = nearest_index1[b_ids, i_ids]
|
||||
|
||||
conf_matrix_gt[b_ids, i_ids, j_ids] = 1
|
||||
data.update({'conf_matrix_gt': conf_matrix_gt})
|
||||
|
||||
# 5. save coarse matches(gt) for training fine level
|
||||
if len(b_ids) == 0:
|
||||
logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
|
||||
# this won't affect fine-level loss calculation
|
||||
b_ids = torch.tensor([0], device=device)
|
||||
i_ids = torch.tensor([0], device=device)
|
||||
j_ids = torch.tensor([0], device=device)
|
||||
|
||||
data.update({
|
||||
'spv_b_ids': b_ids,
|
||||
'spv_i_ids': i_ids,
|
||||
'spv_j_ids': j_ids
|
||||
})
|
||||
|
||||
# 6. save intermediate results (for fast fine-level computation)
|
||||
data.update({
|
||||
'spv_w_pt0_i': w_pt0_i,
|
||||
'spv_pt1_i': grid_pt1_i
|
||||
})
|
||||
|
||||
|
||||
def compute_supervision_coarse(data, config):
|
||||
assert len(set(data['dataset_name'])) == 1, "Mixed-dataset training is not supported!"
|
||||
data_source = data['dataset_name'][0]
|
||||
if data_source.lower() in ['scannet', 'megadepth']:
|
||||
spvs_coarse(data, config)
|
||||
else:
|
||||
raise ValueError(f'Unknown data source: {data_source}')
|
||||
|
||||
|
||||
############## ↓ Fine-Level supervision ↓ ##############
|
||||
|
||||
@torch.no_grad()
|
||||
def spvs_fine(data, config):
|
||||
"""
|
||||
Update:
|
||||
data (dict):{
|
||||
"expec_f_gt": [M, 2]}
|
||||
"""
|
||||
# 1. misc
|
||||
# w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
|
||||
w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
|
||||
scale = config['LOFTR']['RESOLUTION'][1]
|
||||
radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
|
||||
|
||||
# 2. get coarse prediction
|
||||
b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
|
||||
|
||||
# 3. compute gt
|
||||
scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
|
||||
# `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
|
||||
expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
|
||||
data.update({"expec_f_gt": expec_f_gt})
|
||||
|
||||
|
||||
def compute_supervision_fine(data, config):
|
||||
data_source = data['dataset_name'][0]
|
||||
if data_source.lower() in ['scannet', 'megadepth']:
|
||||
spvs_fine(data, config)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,154 @@
|
||||
import bisect
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
|
||||
|
||||
def _compute_conf_thresh(data):
|
||||
dataset_name = data['dataset_name'][0].lower()
|
||||
if dataset_name == 'scannet':
|
||||
thr = 5e-4
|
||||
elif dataset_name == 'megadepth':
|
||||
thr = 1e-4
|
||||
else:
|
||||
raise ValueError(f'Unknown dataset: {dataset_name}')
|
||||
return thr
|
||||
|
||||
|
||||
# --- VISUALIZATION --- #
|
||||
|
||||
def make_matching_figure(
|
||||
img0, img1, mkpts0, mkpts1, color,
|
||||
kpts0=None, kpts1=None, text=[], dpi=75, path=None):
|
||||
# draw image pair
|
||||
assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
|
||||
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
|
||||
axes[0].imshow(img0, cmap='gray')
|
||||
axes[1].imshow(img1, cmap='gray')
|
||||
for i in range(2): # clear all frames
|
||||
axes[i].get_yaxis().set_ticks([])
|
||||
axes[i].get_xaxis().set_ticks([])
|
||||
for spine in axes[i].spines.values():
|
||||
spine.set_visible(False)
|
||||
plt.tight_layout(pad=1)
|
||||
|
||||
if kpts0 is not None:
|
||||
assert kpts1 is not None
|
||||
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
|
||||
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
|
||||
|
||||
# draw matches
|
||||
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
||||
fig.canvas.draw()
|
||||
transFigure = fig.transFigure.inverted()
|
||||
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
||||
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
|
||||
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
|
||||
(fkpts0[i, 1], fkpts1[i, 1]),
|
||||
transform=fig.transFigure, c=color[i], linewidth=1)
|
||||
for i in range(len(mkpts0))]
|
||||
|
||||
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
|
||||
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
|
||||
|
||||
# put txts
|
||||
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
|
||||
fig.text(
|
||||
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
|
||||
fontsize=15, va='top', ha='left', color=txt_color)
|
||||
|
||||
# save or return figure
|
||||
if path:
|
||||
plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
|
||||
plt.close()
|
||||
else:
|
||||
return fig
|
||||
|
||||
|
||||
def _make_evaluation_figure(data, b_id, alpha='dynamic'):
    b_mask = data['m_bids'] == b_id
    conf_thr = _compute_conf_thresh(data)

    img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
    img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()

    # for megadepth, we visualize matches on the resized image
    if 'scale0' in data:
        kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
        kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]

    epi_errs = data['epi_errs'][b_mask].cpu().numpy()
    correct_mask = epi_errs < conf_thr
    precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
    n_correct = np.sum(correct_mask)
    n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
    recall = 0 if n_gt_matches == 0 else n_correct / n_gt_matches
    # recall might be larger than 1, since the calculation of conf_matrix_gt
    # uses groundtruth depths and camera poses, but epipolar distance is used here.

    # matching info
    if alpha == 'dynamic':
        alpha = dynamic_alpha(len(correct_mask))
    color = error_colormap(epi_errs, conf_thr, alpha=alpha)

    text = [
        f'#Matches {len(kpts0)}',
        f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
        f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
    ]

    # make the figure
    figure = make_matching_figure(img0, img1, kpts0, kpts1,
                                  color, text=text)
    return figure

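The precision and recall strings in the figure come straight from the per-match epipolar errors; a small worked sketch with hypothetical values:

import numpy as np

epi_errs = np.array([1e-4, 3e-4, 8e-4, 2e-4, 9e-4])   # hypothetical errors for 5 predicted matches
conf_thr = 5e-4                                        # ScanNet threshold
correct_mask = epi_errs < conf_thr                     # [True, True, False, True, False]
precision = correct_mask.mean()                        # 0.6
n_gt_matches = 4                                       # assumed value of conf_matrix_gt[b_id].sum()
recall = correct_mask.sum() / n_gt_matches             # 0.75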
def _make_confidence_figure(data, b_id):
    # TODO: Implement confidence figure
    raise NotImplementedError()

def make_matching_figures(data, config, mode='evaluation'):
    """ Make matching figures for a batch.

    Args:
        data (Dict): a batch updated by PL_LoFTR.
        config (Dict): matcher config
    Returns:
        figures (Dict[str, List[plt.figure]])
    """
    assert mode in ['evaluation', 'confidence']  # 'confidence'
    figures = {mode: []}
    for b_id in range(data['image0'].size(0)):
        if mode == 'evaluation':
            fig = _make_evaluation_figure(
                data, b_id,
                alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
        elif mode == 'confidence':
            fig = _make_confidence_figure(data, b_id)
        else:
            raise ValueError(f'Unknown plot mode: {mode}')
        figures[mode].append(fig)
    return figures

def dynamic_alpha(n_matches,
                  milestones=[0, 300, 1000, 2000],
                  alphas=[1.0, 0.8, 0.4, 0.2]):
    if n_matches == 0:
        return 1.0
    ranges = list(zip(alphas, alphas[1:] + [None]))
    loc = bisect.bisect_right(milestones, n_matches) - 1
    _range = ranges[loc]
    if _range[1] is None:
        return _range[0]
    return _range[1] + (milestones[loc + 1] - n_matches) / (
        milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])

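`dynamic_alpha` linearly interpolates the line opacity between the milestone values, so dense match sets are drawn more transparently; a few illustrative calls:

print(dynamic_alpha(0))      # 1.0  (no matches: full opacity)
print(dynamic_alpha(650))    # 0.6  (interpolated between 0.8 at 300 and 0.4 at 1000 matches)
print(dynamic_alpha(2500))   # 0.2  (beyond the last milestone)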
def error_colormap(err, thr, alpha=1.0):
    assert alpha <= 1.0 and alpha > 0, f"Invalid alpha value: {alpha}"
    x = 1 - np.clip(err / (thr * 2), 0, 1)
    return np.clip(
        np.stack([2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1), 0, 1)
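`error_colormap` maps each epipolar error to an RGBA colour that shades from green (zero error) through yellow (at the threshold) to red (at twice the threshold or more), for example:

import numpy as np

thr = 5e-4
errs = np.array([0.0, 5e-4, 1e-3])
print(error_colormap(errs, thr))
# [[0. 1. 0. 1.]
#  [1. 1. 0. 1.]
#  [1. 0. 0. 1.]]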
@@ -118,6 +118,7 @@ if TYPE_CHECKING:
    from .text_to_360panorama_image_pipeline import Text2360PanoramaImagePipeline
    from .human3d_render_pipeline import Human3DRenderPipeline
    from .human3d_animation_pipeline import Human3DAnimationPipeline
    from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline
    from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline
    from .anydoor_pipeline import AnydoorPipeline
else:
@@ -295,6 +296,7 @@ else:
        ],
        'human3d_render_pipeline': ['Human3DRenderPipeline'],
        'human3d_animation_pipeline': ['Human3DAnimationPipeline'],
        'image_local_feature_matching_pipeline': ['ImageLocalFeatureMatchingPipeline'],
        'rife_video_frame_interpolation_pipeline': [
            'RIFEVideoFrameInterpolationPipeline'
        ],
modelscope/pipelines/cv/image_local_feature_matching_pipeline.py (new file, 121 lines)
@@ -0,0 +1,121 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import cv2
import numpy as np
import PIL
import torch

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_local_feature_matching,
    module_name=Pipelines.image_local_feature_matching)
class ImageLocalFeatureMatchingPipeline(Pipeline):
    r""" Image Local Feature Matching Pipeline.

    Examples:

    >>> from modelscope.pipelines import pipeline

    >>> matcher = pipeline(Tasks.image_local_feature_matching, model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
    >>> matcher([['https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching1.jpg', 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching2.jpg']])
    >>> [{
    >>>     'matches': [array([[720.5    , 187.8    ],
    >>>            [707.4    , 198.23334],
    >>>            ...,
    >>>            [746.7    , 594.7    ],
    >>>            [759.8    , 594.7    ]], dtype=float32),
    >>>     array([[652.49744 ,  29.599142],
    >>>            [639.25287 ,  45.90798 ],
    >>>            [653.041   ,  43.399014],
    >>>            ...,
    >>>            [670.8787  , 547.8298  ],
    >>>            [608.5573  , 548.97815 ],
    >>>            [617.82574 , 548.601   ]], dtype=float32),
    >>>     array([0.25541496, 0.2781789 , 0.20776041, ..., 0.39656195, 0.7202848 ,
    >>>            0.37208357], dtype=float32)],
    >>>     'output_img': array([[[255, 255, 255],
    >>>             [255, 255, 255],
    >>>             [255, 255, 255],
    >>>             ...,
    >>>             [255, 255, 255],
    >>>             [255, 255, 255],
    >>>             [255, 255, 255]],
    >>>            ...,
    >>>            [[255, 255, 255],
    >>>             [255, 255, 255],
    >>>             [255, 255, 255],
    >>>             ...,
    >>>             [255, 255, 255],
    >>>             [255, 255, 255],
    >>>             [255, 255, 255]]], dtype=uint8)}]
    """
    def __init__(self, model: str, **kwargs):
        """
        use `model` to create an image local feature matching pipeline for prediction
        Args:
            model: model id on modelscope hub.
        """
        super().__init__(model=model, **kwargs)

    def load_image(self, img_name):
        img = LoadImage.convert_to_ndarray(img_name).astype(np.float32)
        img = img / 255.
        # convert rgb to gray
        if len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        H, W = 480, 640
        h_scale, w_scale = H / img.shape[0], W / img.shape[1]
        img = cv2.resize(img, (W, H))
        return img, h_scale, w_scale

    def preprocess(self, input: Input):
        assert len(input) == 2, 'input should be a list of two images'

        img1, h_scale1, w_scale1 = self.load_image(input[0])

        img2, h_scale2, w_scale2 = self.load_image(input[1])

        img1 = torch.from_numpy(img1)[None][None].cuda().float()
        img2 = torch.from_numpy(img2)[None][None].cuda().float()
        return {
            'image0': img1,
            'image1': img2,
            'scale_info': [h_scale1, w_scale1, h_scale2, w_scale2]
        }

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        results = self.model.inference(input)
        return results

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        results = self.model.postprocess(inputs)
        matches = results[OutputKeys.MATCHES]

        kpts0 = matches['kpts0'].cpu().numpy()
        kpts1 = matches['kpts1'].cpu().numpy()
        conf = matches['conf'].cpu().numpy()
        scale_info = [v.cpu().numpy() for v in inputs['scale_info']]
        kpts0[:, 0] = kpts0[:, 0] / scale_info[1]
        kpts0[:, 1] = kpts0[:, 1] / scale_info[0]
        kpts1[:, 0] = kpts1[:, 0] / scale_info[3]
        kpts1[:, 1] = kpts1[:, 1] / scale_info[2]

        outputs = {
            OutputKeys.MATCHES: [kpts0, kpts1, conf],
            OutputKeys.OUTPUT_IMG: results[OutputKeys.OUTPUT_IMG]
        }

        return outputs
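The postprocess step maps keypoints from the fixed 640x480 inference resolution back to the original image coordinates by dividing out the resize factors computed in `load_image`; a small sketch with an assumed source image size:

orig_h, orig_w = 960, 1280                    # assumed original image size
h_scale, w_scale = 480 / orig_h, 640 / orig_w
x_resized, y_resized = 320.0, 120.0           # a keypoint predicted on the resized image
x_orig = x_resized / w_scale                  # 640.0
y_orig = y_resized / h_scale                  # 240.0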
@@ -70,6 +70,7 @@ class CVTasks(object):
    face_emotion = 'face-emotion'
    product_segmentation = 'product-segmentation'
    image_matching = 'image-matching'
    image_local_feature_matching = 'image-local-feature-matching'
    image_quality_assessment_degradation = 'image-quality-assessment-degradation'

    crowd_counting = 'crowd-counting'
@@ -1281,6 +1281,13 @@
            "type": "object"
        }
    },
    "image-local-feature-matching": {
        "input": {},
        "parameters": {},
        "output": {
            "type": "object"
        }
    },
    "image-multi-view-depth-estimation": {
        "input": {},
        "parameters": {},
tests/pipelines/test_image_local_feature_matching.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import unittest
from pathlib import Path

import cv2
import matplotlib.cm as cm
import numpy as np

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import match_pair_visualization
from modelscope.utils.test_utils import test_level


class ImageLocalFeatureMatchingTest(unittest.TestCase):

    def setUp(self) -> None:
        self.task = 'image-local-feature-matching'
        self.model_id = 'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_image_local_feature_matching(self):
        input_location = [[
            'data/test/images/image_matching1.jpg',
            'data/test/images/image_matching2.jpg'
        ]]
        estimator = pipeline(Tasks.image_local_feature_matching, model=self.model_id)
        result = estimator(input_location)
        kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
        vis_img = result[0][OutputKeys.OUTPUT_IMG]
        cv2.imwrite("vis_demo.jpg", vis_img)

        print('test_image_local_feature_matching DONE')


if __name__ == '__main__':
    unittest.main()