fix lint issue

mulin.lyh
2024-01-22 15:52:30 +08:00
parent 672c32e7bd
commit 588e41c787
32 changed files with 868 additions and 680 deletions

View File

@@ -808,7 +808,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
(Pipelines.panorama_depth_estimation,
'damo/cv_unifuse_panorama-depth-estimation'),
Tasks.image_local_feature_matching:
(Pipelines.image_local_feature_matching, 'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'),
(Pipelines.image_local_feature_matching,
'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'),
Tasks.image_style_transfer: (Pipelines.image_style_transfer,
'damo/cv_aams_style-transfer_damo'),
Tasks.face_image_generation: (Pipelines.face_image_generation,
@@ -832,9 +833,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_object_detection:
(Pipelines.image_object_detection_auto,
'damo/cv_yolox_image-object-detection-auto'),
Tasks.ocr_recognition:
(Pipelines.ocr_recognition,
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
Tasks.ocr_recognition: (
Pipelines.ocr_recognition,
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
Tasks.skin_retouching: (Pipelines.skin_retouching,
'damo/cv_unet_skin-retouching'),
Tasks.faq_question_answering: (
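
The hunk above registers a default checkpoint for the new image_local_feature_matching task. A minimal usage sketch, assuming the registered model card is reachable and that the pipeline accepts a pair of image paths ('left.jpg' / 'right.jpg' are placeholders and the pair input format is an assumption, not shown in this diff):

# Hedged sketch: build the pipeline for the task registered above.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

matcher = pipeline(
    Tasks.image_local_feature_matching,
    model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
result = matcher(('left.jpg', 'right.jpg'))  # input pair format is an assumption
print(result.keys())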

View File

@@ -8,10 +8,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
face_reconstruction, human3d_animation, human_reconstruction,
image_classification, image_color_enhance, image_colorization,
image_defrcn_fewshot, image_denoise, image_editing,
image_inpainting, image_instance_segmentation, image_matching,
image_mvs_depth_estimation, image_panoptic_segmentation,
image_portrait_enhancement, image_probing_model,
image_quality_assessment_degradation,
image_inpainting, image_instance_segmentation,
image_local_feature_matching, image_matching,
image_matching_fast, image_mvs_depth_estimation,
image_panoptic_segmentation, image_portrait_enhancement,
image_probing_model, image_quality_assessment_degradation,
image_quality_assessment_man, image_quality_assessment_mos,
image_reid_person, image_restoration,
image_semantic_segmentation, image_super_resolution_pasd,
@@ -29,6 +30,6 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
video_panoptic_segmentation, video_single_object_tracking,
video_stabilization, video_summarization,
video_super_resolution, vidt, virual_tryon, vision_middleware,
vop_retrieval, image_local_feature_matching,image_matching_fast)
vop_retrieval)
# yapf: enable

View File

@@ -19,4 +19,4 @@ else:
_import_structure,
module_spec=__spec__,
extra_objects={},
)
)

View File

@@ -1,21 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import io
import cv2
import torch
import numpy as np
import os.path as osp
from copy import deepcopy
import cv2
import matplotlib.cm as cm
import numpy as np
import torch
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.image_local_feature_matching.src.loftr import \
LoFTR, default_cfg
from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import make_matching_figure
from modelscope.models.cv.image_local_feature_matching.src.loftr import (
LoFTR, default_cfg)
from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import \
make_matching_figure
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
import matplotlib.cm as cm
@MODELS.register_module(
@@ -51,15 +52,19 @@ class LocalFeatureMatching(TorchModel):
def postprocess(self, Inputs):
# Draw
color = cm.jet(Inputs['conf'].cpu().numpy())
img0, img1, mkpts0, mkpts1 = Inputs["image0"].squeeze().cpu().numpy(), Inputs["image1"].squeeze().cpu().numpy(), Inputs["kpts0"].cpu().numpy(), Inputs["kpts1"].cpu().numpy()
img0, img1, mkpts0, mkpts1 = Inputs['image0'].squeeze().cpu().numpy(
), Inputs['image1'].squeeze().cpu().numpy(), Inputs['kpts0'].cpu(
).numpy(), Inputs['kpts1'].cpu().numpy()
text = [
'LoFTR',
'Matches: {}'.format(len(Inputs['kpts0'])),
]
img0, img1 = (img0 * 255).astype(np.uint8), (img1 * 255).astype(np.uint8)
fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)
img0, img1 = (img0 * 255).astype(np.uint8), (img1 * 255).astype(
np.uint8)
fig = make_matching_figure(
img0, img1, mkpts0, mkpts1, color, text=text)
io_buf = io.BytesIO()
fig.savefig(io_buf, format="png", dpi=75)
fig.savefig(io_buf, format='png', dpi=75)
io_buf.seek(0)
buf_data = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
io_buf.close()
@@ -71,4 +76,4 @@ class LocalFeatureMatching(TorchModel):
def inference(self, data):
results = self.forward(data)
return results
return results
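
postprocess above encodes the match visualization as a flat uint8 array of PNG bytes (buf_data). A hypothetical consumer, not part of this commit, could turn that buffer back into an image with OpenCV:

# Hypothetical helper: decode the PNG byte buffer built in postprocess above.
import cv2
import numpy as np

def decode_png_buffer(buf_data: np.ndarray) -> np.ndarray:
    """Return an H x W x 3 BGR image decoded from a 1-D uint8 PNG byte array."""
    return cv2.imdecode(buf_data, cv2.IMREAD_COLOR)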

View File

@@ -8,4 +8,5 @@ def build_backbone(config):
elif config['resolution'] == (16, 4):
return ResNetFPN_16_4(config['resnetfpn'])
else:
raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
raise ValueError(
f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")

View File

@@ -4,15 +4,28 @@ import torch.nn.functional as F
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution without padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=1,
stride=stride,
padding=0,
bias=False)
def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
@@ -26,8 +39,7 @@ class BasicBlock(nn.Module):
else:
self.downsample = nn.Sequential(
conv1x1(in_planes, planes, stride=stride),
nn.BatchNorm2d(planes)
)
nn.BatchNorm2d(planes))
def forward(self, x):
y = x
@@ -37,7 +49,7 @@ class BasicBlock(nn.Module):
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x+y)
return self.relu(x + y)
class ResNetFPN_8_2(nn.Module):
@@ -57,7 +69,8 @@ class ResNetFPN_8_2(nn.Module):
self.in_planes = initial_dim
# Networks
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(initial_dim)
self.relu = nn.ReLU(inplace=True)
@@ -84,7 +97,8 @@ class ResNetFPN_8_2(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@@ -107,13 +121,15 @@ class ResNetFPN_8_2(nn.Module):
# FPN
x3_out = self.layer3_outconv(x3)
x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x3_out_2x = F.interpolate(
x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x2_out = self.layer2_outconv(x2)
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
x2_out_2x = F.interpolate(
x2_out, scale_factor=2., mode='bilinear', align_corners=True)
x1_out = self.layer1_outconv(x1)
x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
x1_out = self.layer1_outconv2(x1_out + x2_out_2x)
return [x3_out, x1_out]
@@ -135,7 +151,8 @@ class ResNetFPN_16_4(nn.Module):
self.in_planes = initial_dim
# Networks
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(initial_dim)
self.relu = nn.ReLU(inplace=True)
@@ -164,7 +181,8 @@ class ResNetFPN_16_4(nn.Module):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
@@ -188,12 +206,14 @@ class ResNetFPN_16_4(nn.Module):
# FPN
x4_out = self.layer4_outconv(x4)
x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
x4_out_2x = F.interpolate(
x4_out, scale_factor=2., mode='bilinear', align_corners=True)
x3_out = self.layer3_outconv(x3)
x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
x3_out = self.layer3_outconv2(x3_out + x4_out_2x)
x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x3_out_2x = F.interpolate(
x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x2_out = self.layer2_outconv(x2)
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)
return [x4_out, x2_out]

View File

@@ -3,13 +3,14 @@ import torch.nn as nn
from einops.einops import rearrange
from .backbone import build_backbone
from .utils.position_encoding import PositionEncodingSine
from .loftr_module import LocalFeatureTransformer, FinePreprocess
from .loftr_module import FinePreprocess, LocalFeatureTransformer
from .utils.coarse_matching import CoarseMatching
from .utils.fine_matching import FineMatching
from .utils.position_encoding import PositionEncodingSine
class LoFTR(nn.Module):
def __init__(self, config):
super().__init__()
# Misc
@@ -23,11 +24,11 @@ class LoFTR(nn.Module):
self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
self.coarse_matching = CoarseMatching(config['match_coarse'])
self.fine_preprocess = FinePreprocess(config)
self.loftr_fine = LocalFeatureTransformer(config["fine"])
self.loftr_fine = LocalFeatureTransformer(config['fine'])
self.fine_matching = FineMatching()
def forward(self, data):
"""
"""
Update:
data (dict): {
'image0': (torch.Tensor): (N, 1, H, W)
@@ -39,18 +40,24 @@ class LoFTR(nn.Module):
# 1. Local Feature CNN
data.update({
'bs': data['image0'].size(0),
'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
'hw0_i': data['image0'].shape[2:],
'hw1_i': data['image1'].shape[2:]
})
if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
feats_c, feats_f = self.backbone(
torch.cat([data['image0'], data['image1']], dim=0))
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
data['bs']), feats_f.split(data['bs'])
else: # handle different input shapes
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
data['image0']), self.backbone(data['image1'])
data.update({
'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
'hw0_c': feat_c0.shape[2:],
'hw1_c': feat_c1.shape[2:],
'hw0_f': feat_f0.shape[2:],
'hw1_f': feat_f1.shape[2:]
})
# 2. coarse-level loftr module
@@ -60,16 +67,21 @@ class LoFTR(nn.Module):
mask_c0 = mask_c1 = None # mask is useful in training
if 'mask0' in data:
mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
mask_c0, mask_c1 = data['mask0'].flatten(
-2), data['mask1'].flatten(-2)
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0,
mask_c1)
# 3. match coarse-level
self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
self.coarse_matching(
feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
# 4. fine-level refinement
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
feat_f0, feat_f1, feat_c0, feat_c1, data)
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
feat_f0_unfold, feat_f1_unfold)
# 5. match fine-level
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
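
The forward pass above runs the full coarse-to-fine pipeline (backbone, coarse transformer, coarse matching, fine preprocessing, fine matching) and writes its results back into the input dict. A minimal sketch of driving it, assuming default_cfg is a complete configuration and that pretrained weights are loaded elsewhere (both assumptions; random weights will simply yield few or no matches):

# Hedged sketch: call LoFTR.forward with the input layout from the docstring above.
import torch

matcher = LoFTR(config=default_cfg)
matcher.eval()
data = {
    'image0': torch.rand(1, 1, 480, 640),  # (N, 1, H, W) grayscale
    'image1': torch.rand(1, 1, 480, 640),
}
with torch.no_grad():
    matcher(data)  # matches are written into `data`, e.g. 'mkpts0_c', 'mconf'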

View File

@@ -1,2 +1,2 @@
from .transformer import LocalFeatureTransformer
from .fine_preprocess import FinePreprocess
from .transformer import LocalFeatureTransformer

View File

@@ -5,6 +5,7 @@ from einops.einops import rearrange, repeat
class FinePreprocess(nn.Module):
def __init__(self, config):
super().__init__()
@@ -17,14 +18,14 @@ class FinePreprocess(nn.Module):
self.d_model_f = d_model_f
if self.cat_c_feat:
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)
self._reset_parameters()
def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
nn.init.kaiming_normal_(p, mode='fan_out', nonlinearity='relu')
def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
W = self.W
@@ -32,28 +33,41 @@ class FinePreprocess(nn.Module):
data.update({'W': W})
if data['b_ids'].shape[0] == 0:
feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
feat0 = torch.empty(
0, self.W**2, self.d_model_f, device=feat_f0.device)
feat1 = torch.empty(
0, self.W**2, self.d_model_f, device=feat_f0.device)
return feat0, feat1
# 1. unfold(crop) all local windows
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
feat_f0_unfold = F.unfold(
feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2)
feat_f0_unfold = rearrange(
feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
feat_f1_unfold = F.unfold(
feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2)
feat_f1_unfold = rearrange(
feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
# 2. select only the predicted matches
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
feat_f0_unfold = feat_f0_unfold[data['b_ids'],
data['i_ids']] # [n, ww, cf]
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
# option: use coarse-level loftr feature as context: concat and linear
if self.cat_c_feat:
feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
feat_cf_win = self.merge_feat(torch.cat([
torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
], -1))
feat_c_win = self.down_proj(
torch.cat([
feat_c0[data['b_ids'], data['i_ids']],
feat_c1[data['b_ids'], data['j_ids']]
], 0)) # [2n, c]
feat_cf_win = self.merge_feat(
torch.cat(
[
torch.cat([feat_f0_unfold, feat_f1_unfold],
0), # [2n, ww, cf]
repeat(feat_c_win, 'n c -> n ww c', ww = W ** 2), # [2n, ww, cf]
], -1)) # yapf: disable
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
return feat_f0_unfold, feat_f1_unfold
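
The unfold step above crops a W x W window of the fine feature map around every coarse-grid location before selecting only the predicted matches. A self-contained toy run of that step (shapes chosen arbitrarily for illustration):

# Toy sketch of the window-unfold used above: one 5x5 patch per coarse location.
import torch
import torch.nn.functional as F
from einops.einops import rearrange

W, stride = 5, 4                                   # fine window size, coarse-to-fine stride
feat = torch.rand(1, 8, 32, 32)                    # [n, c, h, w] fine-level feature map
unfolded = F.unfold(feat, kernel_size=(W, W), stride=stride, padding=W // 2)
unfolded = rearrange(unfolded, 'n (c ww) l -> n l ww c', ww=W**2)
print(unfolded.shape)                              # torch.Size([1, 64, 25, 8])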

View File

@@ -4,7 +4,7 @@ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_trans
"""
import torch
from torch.nn import Module, Dropout
from torch.nn import Dropout, Module
def elu_feature_map(x):
@@ -12,6 +12,7 @@ def elu_feature_map(x):
class LinearAttention(Module):
def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map
@@ -40,14 +41,16 @@ class LinearAttention(Module):
v_length = values.size(1)
values = values / v_length # prevent fp16 overflow
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
KV = torch.einsum('nshd,nshv->nhdv', K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV,
Z) * v_length
return queried_values.contiguous()
class FullAttention(Module):
def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout
@@ -66,9 +69,11 @@ class FullAttention(Module):
"""
# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
QK = torch.einsum('nlhd,nshd->nlsh', queries, keys)
if kv_mask is not None:
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
QK.masked_fill_(
~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]),
float('-inf'))
# Compute the attention and the weighted average
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
@@ -76,6 +81,6 @@ class FullAttention(Module):
if self.use_dropout:
A = self.dropout(A)
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
queried_values = torch.einsum('nlsh,nshd->nlhd', A, values)
return queried_values.contiguous()
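
Both attention variants above are pure einsum contractions. Below is a standalone sketch of the linear variant with the shapes spelled out; elu(x) + 1 is assumed as the feature map, since the body of elu_feature_map is not shown in this hunk:

# Standalone sketch of the linear-attention contraction used above.
import torch

n, l, s, h, d = 2, 100, 120, 8, 32                 # batch, query len, key len, heads, head dim
Q = torch.nn.functional.elu(torch.rand(n, l, h, d)) + 1
K = torch.nn.functional.elu(torch.rand(n, s, h, d)) + 1
values = torch.rand(n, s, h, d) / s                # divide by length, as in the module
KV = torch.einsum('nshd,nshv->nhdv', K, values)    # (S,D)' @ S,V per head
Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + 1e-6)
out = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV, Z) * s
print(out.shape)                                   # torch.Size([2, 100, 8, 32])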

View File

@@ -1,14 +1,14 @@
import copy
import torch
import torch.nn as nn
from .linear_attention import LinearAttention, FullAttention
from .linear_attention import FullAttention, LinearAttention
class LoFTREncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
attention='linear'):
def __init__(self, d_model, nhead, attention='linear'):
super(LoFTREncoderLayer, self).__init__()
self.dim = d_model // nhead
@@ -18,14 +18,15 @@ class LoFTREncoderLayer(nn.Module):
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.attention = LinearAttention() if attention == 'linear' else FullAttention()
self.attention = LinearAttention(
) if attention == 'linear' else FullAttention()
self.merge = nn.Linear(d_model, d_model, bias=False)
# feed-forward network
self.mlp = nn.Sequential(
nn.Linear(d_model*2, d_model*2, bias=False),
nn.Linear(d_model * 2, d_model * 2, bias=False),
nn.ReLU(True),
nn.Linear(d_model*2, d_model, bias=False),
nn.Linear(d_model * 2, d_model, bias=False),
)
# norm and dropout
@@ -44,11 +45,16 @@ class LoFTREncoderLayer(nn.Module):
query, key, value = x, source, source
# multi-head attention
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
query = self.q_proj(query).view(bs, -1, self.nhead,
self.dim) # [N, L, (H, D)]
key = self.k_proj(key).view(bs, -1, self.nhead,
self.dim) # [N, S, (H, D)]
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
message = self.attention(
query, key, value, q_mask=x_mask,
kv_mask=source_mask) # [N, L, (H, D)]
message = self.merge(message.view(bs, -1,
self.nhead * self.dim)) # [N, L, C]
message = self.norm1(message)
# feed-forward network
@@ -68,8 +74,11 @@ class LocalFeatureTransformer(nn.Module):
self.d_model = config['d_model']
self.nhead = config['nhead']
self.layer_names = config['layer_names']
encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'],
config['attention'])
self.layers = nn.ModuleList([
copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))
])
self._reset_parameters()
def _reset_parameters(self):
@@ -86,7 +95,8 @@ class LocalFeatureTransformer(nn.Module):
mask1 (torch.Tensor): [N, S] (optional)
"""
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
assert self.d_model == feat0.size(
2), 'the feature number of src and transformer must be equal'
for layer, name in zip(self.layers, self.layer_names):
if name == 'self':

View File

@@ -5,6 +5,7 @@ from einops.einops import rearrange
INF = 1e9
def mask_border(m, b: int, v):
""" Mask borders with value
Args:
@@ -45,7 +46,7 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):
def compute_max_candidates(p_m0, p_m1):
"""Compute the max candidates of all pairs within a batch
Args:
p_m0, p_m1 (torch.Tensor): padded masks
"""
@@ -57,6 +58,7 @@ def compute_max_candidates(p_m0, p_m1):
class CoarseMatching(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
@@ -75,7 +77,7 @@ class CoarseMatching(nn.Module):
try:
from .superglue import log_optimal_transport
except ImportError:
raise ImportError("download superglue.py first!")
raise ImportError('download superglue.py first!')
self.log_optimal_transport = log_optimal_transport
self.bin_score = nn.Parameter(
torch.tensor(config['skh_init_bin_score'], requires_grad=True))
@@ -103,28 +105,27 @@ class CoarseMatching(nn.Module):
'mconf' (torch.Tensor): [M]}
NOTE: M' != M during training.
"""
N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
_, L, S, _ = feat_c0.size(0), feat_c0.size(1), feat_c1.size(
1), feat_c0.size(2)
# normalize
feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
[feat_c0, feat_c1])
if self.match_type == 'dual_softmax':
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
sim_matrix = torch.einsum('nlc,nsc->nls', feat_c0,
feat_c1) / self.temperature
if mask_c0 is not None:
sim_matrix.masked_fill_(
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
-INF)
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
elif self.match_type == 'sinkhorn':
# sinkhorn, dustbin included
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
sim_matrix = torch.einsum('nlc,nsc->nls', feat_c0, feat_c1)
if mask_c0 is not None:
sim_matrix[:, :L, :S].masked_fill_(
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
-INF)
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)
# build uniform prior & use sinkhorn
log_assign_matrix = self.log_optimal_transport(
@@ -207,10 +208,10 @@ class CoarseMatching(nn.Module):
else:
num_candidates_max = compute_max_candidates(
data['mask0'], data['mask1'])
num_matches_train = int(num_candidates_max *
self.train_coarse_percent)
num_matches_train = int(num_candidates_max
* self.train_coarse_percent)
num_matches_pred = len(b_ids)
assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
assert self.train_pad_num_gt_min < num_matches_train, 'min-num-gt-pad should be less than num-train-matches'
# pred_indices is to select from prediction
if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
@@ -223,11 +224,13 @@ class CoarseMatching(nn.Module):
# gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
gt_pad_indices = torch.randint(
len(data['spv_b_ids']),
(max(num_matches_train - num_matches_pred,
self.train_pad_num_gt_min), ),
device=_device)
mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
len(data['spv_b_ids']),
(max(num_matches_train - num_matches_pred,
self.train_pad_num_gt_min), ),
device=_device)
mconf_gt = torch.zeros(
len(data['spv_b_ids']),
device=_device) # set conf of gt paddings to all zero
b_ids, i_ids, j_ids, mconf = map(
lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],

View File

@@ -1,4 +1,5 @@
import math
import torch
import torch.nn as nn
@@ -7,8 +8,8 @@ def create_meshgrid(
height: int,
width: int,
normalized_coordinates: bool = True,
device = None,
dtype = None,
device=None,
dtype=None,
):
"""Generate a coordinate grid for an image.
@@ -48,7 +49,8 @@ def create_meshgrid(
if normalized_coordinates:
xs = (xs / (width - 1) - 0.5) * 2
ys = (ys / (height - 1) - 0.5) * 2
base_grid = torch.stack(torch.meshgrid([xs, ys], indexing="ij"), dim=-1) # WxHx2
base_grid = torch.stack(
torch.meshgrid([xs, ys], indexing='ij'), dim=-1) # WxHx2
return base_grid.permute(1, 0, 2).unsqueeze(0) # 1xHxWx2
@@ -120,7 +122,7 @@ class FineMatching(nn.Module):
# corner case: if no coarse matches found
if M == 0:
assert self.training == False, "M is always >0, when training, see coarse_matching.py"
assert self.training is False, 'M is always >0, when training, see coarse_matching.py'
# logger.warning('No matches found in coarse-level.')
data.update({
'expec_f': torch.empty(0, 3, device=feat_f0.device),
@@ -129,35 +131,41 @@ class FineMatching(nn.Module):
})
return
feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
feat_f0_picked = feat_f0_picked = feat_f0[:, WW // 2, :]
sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
softmax_temp = 1. / C**.5
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
heatmap = torch.softmax(
softmax_temp * sim_matrix, dim=1).view(-1, W, W)
# compute coordinates from heatmap
coords_normalized = spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
coords_normalized = spatial_expectation2d(heatmap[None],
True)[0] # [M, 2]
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
1, -1, 2) # [1, WW, 2]
# compute std over <x, y>
var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability
var = torch.sum(
grid_normalized**2 * heatmap.view(-1, WW, 1),
dim=1) - coords_normalized**2 # [M, 2]
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)),
-1) # [M] clamp needed for numerical stability
# for fine-level supervision
data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
data.update(
{'expec_f':
torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
# compute absolute kpt coords
self.get_fine_match(coords_normalized, data)
@torch.no_grad()
def get_fine_match(self, coords_normed, data):
W, WW, C, scale = self.W, self.WW, self.C, self.scale
W, _, _, scale = self.W, self.WW, self.C, self.scale
# mkpts0_f and mkpts1_f
mkpts0_f = data['mkpts0_c']
scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
scale1 = scale * data['scale1'][
data['b_ids']] if 'scale0' in data else scale
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] # yapf: disable
data.update({
"mkpts0_f": mkpts0_f,
"mkpts1_f": mkpts1_f
})
data.update({'mkpts0_f': mkpts0_f, 'mkpts1_f': mkpts1_f})
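
The block above is a soft-argmax: the expected coordinate and its standard deviation are read off a W x W heatmap over the normalized grid. A toy recomputation using create_meshgrid as defined earlier in this file (spatial_expectation2d is assumed to reduce to the same grid-weighted mean):

# Toy sketch of the soft-argmax statistics computed above, for a single match.
import torch

W = 5
heatmap = torch.softmax(torch.rand(1, W * W), dim=1).view(-1, W, W)        # [M, W, W]
grid = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)       # [1, WW, 2]
coords = torch.sum(grid * heatmap.view(-1, W * W, 1), dim=1)               # [M, 2] expectation
var = torch.sum(grid**2 * heatmap.view(-1, W * W, 1), dim=1) - coords**2   # [M, 2]
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)               # [M]
print(coords, std)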

View File

@@ -6,7 +6,7 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
""" Warp kpts0 from I0 to I1 with depth, K and Rt
Also check covisibility and depth consistency.
Depth is consistent if relative error < 0.2 (hard-coded).
Args:
kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
depth0 (torch.Tensor): [N, H, W],
@@ -21,34 +21,37 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
kpts0_long = kpts0.round().long()
# Sample depth, get calculable_mask on depth != 0
kpts0_depth = torch.stack(
[depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
) # (N, L)
kpts0_depth = torch.stack([
depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]]
for i in range(kpts0.shape[0])
],
dim=0) # noqa E501
nonzero_mask = kpts0_depth != 0
# Unproject
kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])],
dim=-1) * kpts0_depth[..., None] # (N, L, 3)
kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
# Rigid Transform
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3,
[3]] # (N, 3, L)
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
# Project
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4
) # (N, L, 2), +1e-4 to avoid zero depth
# Covisible Check
h, w = depth1.shape[1:3]
covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
(w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w - 1) * (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h - 1) # noqa E501 yapf: disable
w_kpts0_long = w_kpts0.long()
w_kpts0_long[~covisible_mask, :] = 0
w_kpts0_depth = torch.stack(
[depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
) # (N, L)
consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
w_kpts0_depth = torch.stack([depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0) # noqa E501 yapf: disable
consistent_mask = (
(w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
valid_mask = nonzero_mask * covisible_mask * consistent_mask
return valid_mask, w_kpts0
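
In equation form, the warp implemented above takes a pixel (u, v) in image 0 with sampled depth d, intrinsics K0, K1 and relative pose T_0to1 = [R | t], and computes the warped point (w_kpts0 in the code; notation mine):

\hat{p}_1 = \pi\big(K_1\,(R\,(d \cdot K_0^{-1}[u,\, v,\, 1]^\top) + t)\big), \qquad \pi([x,\, y,\, z]^\top) = [x/z,\; y/z]^\top

with the small +1e-4 term in the code guarding the division against zero depth.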

View File

@@ -1,4 +1,5 @@
import math
import torch
from torch import nn
@@ -23,16 +24,17 @@ class PositionEncodingSine(nn.Module):
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
if temp_bug_fix:
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
div_term = torch.exp(torch.arange(0, d_model // 2, 2).float() * (-math.log(10000.0) / (d_model // 2))) # noqa E501 yapf: disable
else: # a buggy implementation (for backward compatability only)
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
div_term = torch.exp(torch.arange(0, d_model // 2, 2).float() * (-math.log(10000.0) / d_model // 2)) # noqa E501 yapf: disable
div_term = div_term[:, None, None] # [C//4, 1, 1]
pe[0::4, :, :] = torch.sin(x_position * div_term)
pe[1::4, :, :] = torch.cos(x_position * div_term)
pe[2::4, :, :] = torch.sin(y_position * div_term)
pe[3::4, :, :] = torch.cos(y_position * div_term)
self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
self.register_buffer(
'pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
def forward(self, x):
"""

View File

@@ -1,13 +1,13 @@
from math import log
from loguru import logger
import torch
from einops import repeat
from kornia.utils import create_meshgrid
from loguru import logger
from .geometry import warp_kpts
############## ↓ Coarse-Level supervision ↓ ##############
# ↓ Coarse-Level supervision ↓ ##############
@torch.no_grad()
@@ -30,7 +30,7 @@ def spvs_coarse(data, config):
'spv_w_pt0_i': [N, hw0, 2], in original image resolution
'spv_pt1_i': [N, hw1, 2], in original image resolution
}
NOTE:
- for scannet dataset, there're 3 kinds of resolution {i, c, f}
- for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}
@@ -46,9 +46,14 @@ def spvs_coarse(data, config):
# 2. warp grids
# create kpts in meshgrid and resize them to image resolution
grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
grid_pt0_c = create_meshgrid(h0, w0, False,
device).reshape(1, h0 * w0,
2).repeat(N, 1,
1) # [N, hw, 2]
grid_pt0_i = scale0 * grid_pt0_c
grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
grid_pt1_c = create_meshgrid(h1, w1, False,
device).reshape(1, h1 * w1,
2).repeat(N, 1, 1)
grid_pt1_i = scale1 * grid_pt1_c
# mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
@@ -59,8 +64,10 @@ def spvs_coarse(data, config):
# warp kpts bi-directionally and resize them to coarse-level resolution
# (no depth consistency check, since it leads to worse results experimentally)
# (unhandled edge case: points with 0-depth will be warped to the left-up corner)
_, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
_, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'],
data['T_0to1'], data['K0'], data['K1'])
_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'],
data['T_1to0'], data['K1'], data['K0'])
w_pt0_c = w_pt0_i / scale1
w_pt1_c = w_pt1_i / scale0
@@ -72,16 +79,21 @@ def spvs_coarse(data, config):
# corner case: out of boundary
def out_bound_mask(pt, w, h):
return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (
pt[..., 1] >= h)
nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
loop_back = torch.stack(
[nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)],
dim=0)
correct_0to1 = loop_back == torch.arange(
h0 * w0, device=device)[None].repeat(N, 1)
correct_0to1[:, 0] = False # ignore the top-left corner
# 4. construct a gt conf_matrix
conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
b_ids, i_ids = torch.where(correct_0to1 != 0)
j_ids = nearest_index1[b_ids, i_ids]
@@ -90,27 +102,22 @@ def spvs_coarse(data, config):
# 5. save coarse matches(gt) for training fine level
if len(b_ids) == 0:
logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
logger.warning(
f"No groundtruth coarse match found for: {data['pair_names']}")
# this won't affect fine-level loss calculation
b_ids = torch.tensor([0], device=device)
i_ids = torch.tensor([0], device=device)
j_ids = torch.tensor([0], device=device)
data.update({
'spv_b_ids': b_ids,
'spv_i_ids': i_ids,
'spv_j_ids': j_ids
})
data.update({'spv_b_ids': b_ids, 'spv_i_ids': i_ids, 'spv_j_ids': j_ids})
# 6. save intermediate results (for fast fine-level computation)
data.update({
'spv_w_pt0_i': w_pt0_i,
'spv_pt1_i': grid_pt1_i
})
data.update({'spv_w_pt0_i': w_pt0_i, 'spv_pt1_i': grid_pt1_i})
def compute_supervision_coarse(data, config):
assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
assert len(set(
data['dataset_name'])) == 1, 'Do not support mixed datasets training!'
data_source = data['dataset_name'][0]
if data_source.lower() in ['scannet', 'megadepth']:
spvs_coarse(data, config)
@@ -118,7 +125,8 @@ def compute_supervision_coarse(data, config):
raise ValueError(f'Unknown data source: {data_source}')
############## ↓ Fine-Level supervision ↓ ##############
# ↓ Fine-Level supervision ↓ ##############
@torch.no_grad()
def spvs_fine(data, config):
@@ -139,8 +147,9 @@ def spvs_fine(data, config):
# 3. compute gt
scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
# `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
data.update({"expec_f_gt": expec_f_gt})
expec_f_gt = (w_pt0_i[b_ids, i_ids]
- pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
data.update({'expec_f_gt': expec_f_gt})
def compute_supervision_fine(data, config):

View File

@@ -1,7 +1,8 @@
import bisect
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
def _compute_conf_thresh(data):
@@ -17,21 +18,30 @@ def _compute_conf_thresh(data):
# --- VISUALIZATION --- #
def make_matching_figure(
img0, img1, mkpts0, mkpts1, color,
kpts0=None, kpts1=None, text=[], dpi=75, path=None):
def make_matching_figure(img0,
img1,
mkpts0,
mkpts1,
color,
kpts0=None,
kpts1=None,
text=[],
dpi=75,
path=None):
# draw image pair
assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
assert mkpts0.shape[0] == mkpts1.shape[
0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
axes[0].imshow(img0, cmap='gray')
axes[1].imshow(img1, cmap='gray')
for i in range(2): # clear all frames
for i in range(2): # clear all frames
axes[i].get_yaxis().set_ticks([])
axes[i].get_xaxis().set_ticks([])
for spine in axes[i].spines.values():
spine.set_visible(False)
plt.tight_layout(pad=1)
if kpts0 is not None:
assert kpts1 is not None
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
@@ -43,19 +53,28 @@ def make_matching_figure(
transFigure = fig.transFigure.inverted()
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
(fkpts0[i, 1], fkpts1[i, 1]),
transform=fig.transFigure, c=color[i], linewidth=1)
for i in range(len(mkpts0))]
fig.lines = [
matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
(fkpts0[i, 1], fkpts1[i, 1]),
transform=fig.transFigure,
c=color[i],
linewidth=1) for i in range(len(mkpts0))
]
axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
# put txts
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
fig.text(
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
fontsize=15, va='top', ha='left', color=txt_color)
0.01,
0.99,
'\n'.join(text),
transform=fig.axes[0].transAxes,
fontsize=15,
va='top',
ha='left',
color=txt_color)
# save or return figure
if path:
@@ -68,12 +87,14 @@ def make_matching_figure(
def _make_evaluation_figure(data, b_id, alpha='dynamic'):
b_mask = data['m_bids'] == b_id
conf_thr = _compute_conf_thresh(data)
img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(
np.int32)
img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(
np.int32)
kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
# for megadepth, we visualize matches on the resized image
if 'scale0' in data:
kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
@@ -92,18 +113,18 @@ def _make_evaluation_figure(data, b_id, alpha='dynamic'):
if alpha == 'dynamic':
alpha = dynamic_alpha(len(correct_mask))
color = error_colormap(epi_errs, conf_thr, alpha=alpha)
text = [
f'#Matches {len(kpts0)}',
f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
]
# make the figure
figure = make_matching_figure(img0, img1, kpts0, kpts1,
color, text=text)
figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
return figure
def _make_confidence_figure(data, b_id):
# TODO: Implement confidence figure
raise NotImplementedError()
@@ -111,7 +132,7 @@ def _make_confidence_figure(data, b_id):
def make_matching_figures(data, config, mode='evaluation'):
""" Make matching figures for a batch.
Args:
data (Dict): a batch updated by PL_LoFTR.
config (Dict): matcher config
@@ -123,8 +144,7 @@ def make_matching_figures(data, config, mode='evaluation'):
for b_id in range(data['image0'].size(0)):
if mode == 'evaluation':
fig = _make_evaluation_figure(
data, b_id,
alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
elif mode == 'confidence':
fig = _make_confidence_figure(data, b_id)
else:
@@ -144,11 +164,14 @@ def dynamic_alpha(n_matches,
if _range[1] is None:
return _range[0]
return _range[1] + (milestones[loc + 1] - n_matches) / (
milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
milestones[loc + 1] - milestones[loc]) * (
_range[0] - _range[1])
def error_colormap(err, thr, alpha=1.0):
assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
assert alpha <= 1.0 and alpha > 0, f'Invaid alpha value: {alpha}'
x = 1 - np.clip(err / (thr * 2), 0, 1)
return np.clip(
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
np.stack([2 - x * 2, x * 2,
np.zeros_like(x),
np.ones_like(x) * alpha], -1), 0, 1)

View File

@@ -1 +1 @@
from .default import lightglue_default_conf
from .default import lightglue_default_conf

View File

@@ -1,15 +1,15 @@
lightglue_default_conf = {
"features":"superpoint", # superpoint disk aliked sift
"name": "lightglue", # just for interfacing
"input_dim": 256, # input descriptor dimension (autoselected from weights)
"descriptor_dim": 256,
"add_scale_ori": False,
"n_layers": 9,
"num_heads": 4,
"flash": True, # enable FlashAttention if available.
"mp": False, # enable mixed precision
"depth_confidence": 0.95, # early stopping, disable with -1
"width_confidence": 0.99, # point pruning, disable with -1
"filter_threshold": 0.1, # match threshold
"weights": None,
'features': 'superpoint', # superpoint disk aliked sift
'name': 'lightglue', # just for interfacing
'input_dim': 256, # input descriptor dimension (autoselected from weights)
'descriptor_dim': 256,
'add_scale_ori': False,
'n_layers': 9,
'num_heads': 4,
'flash': True, # enable FlashAttention if available.
'mp': False, # enable mixed precision
'depth_confidence': 0.95, # early stopping, disable with -1
'width_confidence': 0.99, # point pruning, disable with -1
'filter_threshold': 0.1, # match threshold
'weights': None,
}
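
These defaults are plain dict entries, so a caller can derive a run-time configuration by overriding individual keys; the override pattern below is an assumption about downstream usage, not something shown in this commit:

# Hypothetical override of the defaults above (keys taken from the dict itself).
conf = {**lightglue_default_conf, 'filter_threshold': 0.2, 'flash': False}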

View File

@@ -45,16 +45,15 @@ from torchvision.models import resnet
from .utils import Extractor
def get_patches(
tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
) -> torch.Tensor:
def get_patches(tensor: torch.Tensor, required_corners: torch.Tensor,
ps: int) -> torch.Tensor:
c, h, w = tensor.shape
corner = (required_corners - ps / 2 + 1).long()
corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
offset = torch.arange(0, ps)
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
kw = {'indexing': 'ij'} if torch.__version__ >= '1.10' else {}
x, y = torch.meshgrid(offset, offset, **kw)
patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
patches = patches.to(corner) + corner[None, None]
@@ -70,8 +69,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int):
zeros = torch.zeros_like(scores)
max_mask = scores == torch.nn.functional.max_pool2d(
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
for _ in range(2):
supp_mask = (
@@ -80,18 +78,19 @@ def simple_nms(scores: torch.Tensor, nms_radius: int):
kernel_size=nms_radius * 2 + 1,
stride=1,
padding=nms_radius,
)
> 0
)
) > 0)
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
supp_scores,
kernel_size=nms_radius * 2 + 1,
stride=1,
padding=nms_radius)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)
class DKD(nn.Module):
def __init__(
self,
radius: int = 2,
@@ -115,14 +114,15 @@ class DKD(nn.Module):
self.n_limit = n_limit
self.kernel_size = 2 * self.radius + 1
self.temperature = 0.1 # tuned temperature
self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
self.unfold = nn.Unfold(
kernel_size=self.kernel_size, padding=self.radius)
# local xy grid
x = torch.linspace(-self.radius, self.radius, self.kernel_size)
# (kernel_size*kernel_size) x 2 : (w,h)
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
kw = {'indexing': 'ij'} if torch.__version__ >= '1.10' else {}
self.hw_grid = (
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
)
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:,
[1, 0]])
def forward(
self,
@@ -141,29 +141,32 @@ class DKD(nn.Module):
nms_scores = simple_nms(scores_nograd, self.radius)
# remove border
nms_scores[:, :, : self.radius, :] = 0
nms_scores[:, :, :, : self.radius] = 0
nms_scores[:, :, :self.radius, :] = 0
nms_scores[:, :, :, :self.radius] = 0
if image_size is not None:
for i in range(scores_map.shape[0]):
w, h = image_size[i].long()
nms_scores[i, :, h.item() - self.radius :, :] = 0
nms_scores[i, :, :, w.item() - self.radius :] = 0
nms_scores[i, :, h.item() - self.radius:, :] = 0
nms_scores[i, :, :, w.item() - self.radius:] = 0
else:
nms_scores[:, :, -self.radius :, :] = 0
nms_scores[:, :, :, -self.radius :] = 0
nms_scores[:, :, -self.radius:, :] = 0
nms_scores[:, :, :, -self.radius:] = 0
# detect keypoints without grad
if self.top_k > 0:
topk = torch.topk(nms_scores.view(b, -1), self.top_k)
indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
indices_keypoints = [topk.indices[i]
for i in range(b)] # B x top_k
else:
if self.scores_th > 0:
masks = nms_scores > self.scores_th
if masks.sum() == 0:
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
th = scores_nograd.reshape(b, -1).mean(
dim=1) # th = self.scores_th
masks = nms_scores > th.reshape(b, 1, 1, 1)
else:
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
th = scores_nograd.reshape(b, -1).mean(
dim=1) # th = self.scores_th
masks = nms_scores > th.reshape(b, 1, 1, 1)
masks = masks.reshape(b, -1)
@@ -174,7 +177,7 @@ class DKD(nn.Module):
if len(indices) > self.n_limit:
kpts_sc = scores[indices]
sort_idx = kpts_sc.sort(descending=True)[1]
sel_idx = sort_idx[: self.n_limit]
sel_idx = sort_idx[:self.n_limit]
indices = indices[sel_idx]
indices_keypoints.append(indices)
@@ -190,34 +193,34 @@ class DKD(nn.Module):
for b_idx in range(b):
patch = patches[b_idx].t() # (H*W) x (kernel**2)
indices_kpt = indices_keypoints[
b_idx
] # one dimension vector, say its size is M
b_idx] # one dimension vector, say its size is M
patch_scores = patch[indices_kpt] # M x (kernel**2)
keypoints_xy_nms = torch.stack(
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
[
indices_kpt % w,
torch.div(indices_kpt, w, rounding_mode='trunc')
],
dim=1,
) # Mx2
# max is detached to prevent undesired backprop loops in the graph
max_v = patch_scores.max(dim=1).values.detach()[:, None]
x_exp = (
(patch_scores - max_v) / self.temperature
).exp() # M * (kernel**2), in [0, 1]
(patch_scores - max_v)
/ self.temperature).exp() # M * (kernel**2), in [0, 1]
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
xy_residual = (
x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
) # Soft-argmax, Mx2
xy_residual = (x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
) # Soft-argmax, Mx2
hw_grid_dist2 = (
torch.norm(
(self.hw_grid[None, :, :] - xy_residual[:, None, :])
/ self.radius,
dim=-1,
)
** 2
)
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
)**2)
scoredispersity = (x_exp * hw_grid_dist2).sum(
dim=1) / x_exp.sum(dim=1)
# compute result keypoints
keypoints_xy = keypoints_xy_nms + xy_residual
@@ -226,11 +229,9 @@ class DKD(nn.Module):
kptscore = torch.nn.functional.grid_sample(
scores_map[b_idx].unsqueeze(0),
keypoints_xy.view(1, 1, -1, 2),
mode="bilinear",
mode='bilinear',
align_corners=True,
)[
0, 0, 0, :
] # CxN
)[0, 0, 0, :] # CxN
keypoints.append(keypoints_xy)
scoredispersitys.append(scoredispersity)
@@ -238,24 +239,25 @@ class DKD(nn.Module):
else:
for b_idx in range(b):
indices_kpt = indices_keypoints[
b_idx
] # one dimension vector, say its size is M
b_idx] # one dimension vector, say its size is M
# To avoid warning: UserWarning: __floordiv__ is deprecated
keypoints_xy_nms = torch.stack(
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
[
indices_kpt % w,
torch.div(indices_kpt, w, rounding_mode='trunc')
],
dim=1,
) # Mx2
keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
kptscore = torch.nn.functional.grid_sample(
scores_map[b_idx].unsqueeze(0),
keypoints_xy.view(1, 1, -1, 2),
mode="bilinear",
mode='bilinear',
align_corners=True,
)[
0, 0, 0, :
] # CxN
)[0, 0, 0, :] # CxN
keypoints.append(keypoints_xy)
scoredispersitys.append(kptscore) # for jit.script compatability
scoredispersitys.append(
kptscore) # for jit.script compatability
kptscores.append(kptscore)
return keypoints, scoredispersitys, kptscores
@@ -278,17 +280,18 @@ class InputPadder(object):
def pad(self, x: torch.Tensor):
assert x.ndim == 4
return F.pad(x, self._pad, mode="replicate")
return F.pad(x, self._pad, mode='replicate')
def unpad(self, x: torch.Tensor):
assert x.ndim == 4
ht = x.shape[-2]
wd = x.shape[-1]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
return x[..., c[0]:c[1], c[2]:c[3]]
class DeformableConv2d(nn.Module):
def __init__(
self,
in_channels,
@@ -304,9 +307,8 @@ class DeformableConv2d(nn.Module):
self.padding = padding
self.mask = mask
self.channel_num = (
3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
)
self.channel_num = (3 * kernel_size * kernel_size if mask else 2
* kernel_size * kernel_size)
self.offset_conv = nn.Conv2d(
in_channels,
self.channel_num,
@@ -356,10 +358,10 @@ def get_conv(
stride=1,
padding=1,
bias=False,
conv_type="conv",
conv_type='conv',
mask=False,
):
if conv_type == "conv":
if conv_type == 'conv':
conv = nn.Conv2d(
inplanes,
planes,
@@ -368,7 +370,7 @@ def get_conv(
padding=padding,
bias=bias,
)
elif conv_type == "dcn":
elif conv_type == 'dcn':
conv = DeformableConv2d(
inplanes,
planes,
@@ -384,13 +386,14 @@ def get_conv(
class ConvBlock(nn.Module):
def __init__(
self,
in_channels,
out_channels,
gate: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
conv_type: str = "conv",
conv_type: str = 'conv',
mask: bool = False,
):
super().__init__()
@@ -401,12 +404,18 @@ class ConvBlock(nn.Module):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.conv1 = get_conv(
in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
)
in_channels,
out_channels,
kernel_size=3,
conv_type=conv_type,
mask=mask)
self.bn1 = norm_layer(out_channels)
self.conv2 = get_conv(
out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
)
out_channels,
out_channels,
kernel_size=3,
conv_type=conv_type,
mask=mask)
self.bn2 = norm_layer(out_channels)
def forward(self, x):
@@ -430,7 +439,7 @@ class ResBlock(nn.Module):
dilation: int = 1,
gate: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
conv_type: str = "conv",
conv_type: str = 'conv',
mask: bool = False,
) -> None:
super(ResBlock, self).__init__()
@@ -441,18 +450,17 @@ class ResBlock(nn.Module):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("ResBlock only supports groups=1 and base_width=64")
raise ValueError(
'ResBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
raise NotImplementedError('Dilation > 1 not supported in ResBlock')
# Both self.conv1 and self.downsample layers
# downsample the input when stride != 1
self.conv1 = get_conv(
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
)
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask)
self.bn1 = norm_layer(planes)
self.conv2 = get_conv(
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
)
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
@@ -477,14 +485,15 @@ class ResBlock(nn.Module):
class SDDH(nn.Module):
def __init__(
self,
dims: int,
kernel_size: int = 3,
n_pos: int = 8,
gate=nn.ReLU(),
conv2D=False,
mask=False,
self,
dims: int,
kernel_size: int = 3,
n_pos: int = 8,
gate=nn.ReLU(),
conv2D=False,
mask=False,
):
super(SDDH, self).__init__()
self.kernel_size = kernel_size
@@ -518,18 +527,21 @@ class SDDH(nn.Module):
# sampled feature conv
self.sf_conv = nn.Conv2d(
dims, dims, kernel_size=1, stride=1, padding=0, bias=False
)
dims, dims, kernel_size=1, stride=1, padding=0, bias=False)
# convM
if not conv2D:
# deformable desc weights
agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
self.register_parameter("agg_weights", agg_weights)
self.register_parameter('agg_weights', agg_weights)
else:
self.convM = nn.Conv2d(
dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
)
dims * n_pos,
dims,
kernel_size=1,
stride=1,
padding=0,
bias=False)
def forward(self, x, keypoints):
# x: [B,C,H,W]
@@ -548,29 +560,28 @@ class SDDH(nn.Module):
if self.kernel_size > 1:
patch = self.get_patches_func(
xi, kptsi_wh.long(), self.kernel_size
) # [N_kpts, C, K, K]
xi, kptsi_wh.long(), self.kernel_size) # [N_kpts, C, K, K]
else:
kptsi_wh_long = kptsi_wh.long()
patch = (
xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
.permute(1, 0)
.reshape(N_kpts, c, 1, 1)
)
xi[:, kptsi_wh_long[:, 1],
kptsi_wh_long[:,
0]].permute(1,
0).reshape(N_kpts, c, 1, 1))
offset = self.offset_conv(patch).clamp(
-max_offset, max_offset
) # [N_kpts, 2*n_pos, 1, 1]
-max_offset, max_offset) # [N_kpts, 2*n_pos, 1, 1]
if self.mask:
offset = (
offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 3]
offset = (offset[:, :, 0, 0].view(N_kpts, 3,
self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 3]
offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
mask_weight = torch.sigmoid(offset[:, :,
-1]) # [N_kpts, n_pos]
else:
offset = (
offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 2]
offset = (offset[:, :, 0, 0].view(N_kpts, 2,
self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 2]
offsets.append(offset) # for visualization
# get sample positions
@@ -580,26 +591,23 @@ class SDDH(nn.Module):
# sample features
features = F.grid_sample(
xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
) # [1,C,(N_kpts*n_pos),1]
features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
1, 0, 2, 3
) # [N_kpts, C, n_pos, 1]
xi.unsqueeze(0), pos, mode='bilinear',
align_corners=True) # [1,C,(N_kpts*n_pos),1]
features = features.reshape(c, N_kpts, self.n_pos,
1).permute(1, 0, 2,
3) # [N_kpts, C, n_pos, 1]
if self.mask:
features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
features = torch.einsum('ncpo,np->ncpo', features, mask_weight)
features = torch.selu_(self.sf_conv(features)).squeeze(
-1
) # [N_kpts, C, n_pos]
-1) # [N_kpts, C, n_pos]
# convM
if not self.conv2D:
descs = torch.einsum(
"ncp,pcd->nd", features, self.agg_weights
) # [N_kpts, C]
descs = torch.einsum('ncp,pcd->nd', features,
self.agg_weights) # [N_kpts, C]
else:
features = features.reshape(N_kpts, -1)[
:, :, None, None
] # [N_kpts, C*n_pos, 1, 1]
features = features.reshape(
N_kpts, -1)[:, :, None, None] # [N_kpts, C*n_pos, 1, 1]
descs = self.convM(features).squeeze() # [N_kpts, C]
# normalize
@@ -611,34 +619,34 @@ class SDDH(nn.Module):
class ALIKED(Extractor):
default_conf = {
"model_name": "aliked-n16",
"max_num_keypoints": -1,
"detection_threshold": 0.2,
"nms_radius": 2,
'model_name': 'aliked-n16',
'max_num_keypoints': -1,
'detection_threshold': 0.2,
'nms_radius': 2,
}
checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
checkpoint_url = 'https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth'
n_limit_max = 20000
# c1, c2, c3, c4, dim, K, M
cfgs = {
"aliked-t16": [8, 16, 32, 64, 64, 3, 16],
"aliked-n16": [16, 32, 64, 128, 128, 3, 16],
"aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
"aliked-n32": [16, 32, 64, 128, 128, 3, 32],
'aliked-t16': [8, 16, 32, 64, 64, 3, 16],
'aliked-n16': [16, 32, 64, 128, 128, 3, 16],
'aliked-n16rot': [16, 32, 64, 128, 128, 3, 16],
'aliked-n32': [16, 32, 64, 128, 128, 3, 32],
}
preprocess_conf = {
"resize": 1024,
'resize': 1024,
}
required_data_keys = ["image"]
required_data_keys = ['image']
def __init__(self, **conf):
super().__init__(**conf) # Update with default configuration.
conf = self.conf
c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
conv_types = ["conv", "conv", "dcn", "dcn"]
conv_types = ['conv', 'conv', 'dcn', 'dcn']
conv2D = False
mask = False
@@ -647,7 +655,8 @@ class ALIKED(Extractor):
self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
self.norm = nn.BatchNorm2d
self.gate = nn.SELU(inplace=True)
self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
self.block1 = ConvBlock(
3, c1, self.gate, self.norm, conv_type=conv_types[0])
self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
@@ -657,17 +666,13 @@ class ALIKED(Extractor):
self.conv3 = resnet.conv1x1(c3, dim // 4)
self.conv4 = resnet.conv1x1(dim, dim // 4)
self.upsample2 = nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=True
)
scale_factor=2, mode='bilinear', align_corners=True)
self.upsample4 = nn.Upsample(
scale_factor=4, mode="bilinear", align_corners=True
)
scale_factor=4, mode='bilinear', align_corners=True)
self.upsample8 = nn.Upsample(
scale_factor=8, mode="bilinear", align_corners=True
)
scale_factor=8, mode='bilinear', align_corners=True)
self.upsample32 = nn.Upsample(
scale_factor=32, mode="bilinear", align_corners=True
)
scale_factor=32, mode='bilinear', align_corners=True)
self.score_head = nn.Sequential(
resnet.conv1x1(dim, 8),
self.gate,
@@ -677,19 +682,19 @@ class ALIKED(Extractor):
self.gate,
resnet.conv3x3(4, 1),
)
self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
self.desc_head = SDDH(
dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
self.dkd = DKD(
radius=conf.nms_radius,
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
top_k=-1
if conf.detection_threshold > 0 else conf.max_num_keypoints,
scores_th=conf.detection_threshold,
n_limit=conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max,
if conf.max_num_keypoints > 0 else self.n_limit_max,
)
state_dict = torch.hub.load_state_dict_from_url(
self.checkpoint_url.format(conf.model_name), map_location="cpu"
)
self.checkpoint_url.format(conf.model_name), map_location='cpu')
self.load_state_dict(state_dict, strict=True)
def get_resblock(self, c_in, c_out, conv_type, mask):
@@ -738,13 +743,12 @@ class ALIKED(Extractor):
return feature_map, score_map
def forward(self, data: dict) -> dict:
image = data["image"]
image = data['image']
if image.shape[1] == 1:
image = grayscale_to_rgb(image)
feature_map, score_map = self.extract_dense_map(image)
keypoints, kptscores, scoredispersitys = self.dkd(
score_map, image_size=data.get("image_size")
)
score_map, image_size=data.get('image_size'))
descriptors, offsets = self.desc_head(feature_map, keypoints)
_, _, h, w = image.shape
@@ -752,7 +756,7 @@ class ALIKED(Extractor):
# no padding required
# we can set detection_threshold=-1 and conf.max_num_keypoints > 0
return {
"keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
"descriptors": torch.stack(descriptors), # B x N x D
"keypoint_scores": torch.stack(kptscores), # B x N
'keypoints': wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
'descriptors': torch.stack(descriptors), # B x N x D
'keypoint_scores': torch.stack(kptscores), # B x N
}
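As a reading aid for the return format above, here is a minimal usage sketch (not part of the commit): it assumes ALIKED is importable from this lightglue package and that the pretrained weights referenced by checkpoint_url can be downloaded.

import torch

extractor = ALIKED(max_num_keypoints=1024).eval()
img = torch.rand(3, 480, 640)            # stand-in for a real RGB image in [0, 1]
with torch.no_grad():
    feats = extractor.extract(img)       # Extractor.extract adds the batch dim
print(feats['keypoints'].shape)          # [1, N, 2], pixel coordinates
print(feats['descriptors'].shape)        # [1, N, 128] for the default aliked-n16
print(feats['keypoint_scores'].shape)    # [1, N]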

View File

@@ -6,20 +6,20 @@ from .utils import Extractor
class DISK(Extractor):
default_conf = {
"weights": "depth",
"max_num_keypoints": None,
"desc_dim": 128,
"nms_window_size": 5,
"detection_threshold": 0.0,
"pad_if_not_divisible": True,
'weights': 'depth',
'max_num_keypoints': None,
'desc_dim': 128,
'nms_window_size': 5,
'detection_threshold': 0.0,
'pad_if_not_divisible': True,
}
preprocess_conf = {
"resize": 1024,
"grayscale": False,
'resize': 1024,
'grayscale': False,
}
required_data_keys = ["image"]
required_data_keys = ['image']
def __init__(self, **conf) -> None:
super().__init__(**conf) # Update with default configuration.
@@ -28,8 +28,8 @@ class DISK(Extractor):
def forward(self, data: dict) -> dict:
"""Compute keypoints, scores, descriptors for image"""
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
image = data["image"]
assert key in data, f'Missing key {key} in data'
image = data['image']
if image.shape[1] == 1:
image = kornia.color.grayscale_to_rgb(image)
features = self.model(
@@ -49,7 +49,7 @@ class DISK(Extractor):
descriptors = torch.stack(descriptors, 0)
return {
"keypoints": keypoints.to(image).contiguous(),
"keypoint_scores": scores.to(image).contiguous(),
"descriptors": descriptors.to(image).contiguous(),
'keypoints': keypoints.to(image).contiguous(),
'keypoint_scores': scores.to(image).contiguous(),
'descriptors': descriptors.to(image).contiguous(),
}

View File

@@ -1,3 +1,4 @@
import os.path as osp
import warnings
from pathlib import Path
from types import SimpleNamespace
@@ -8,13 +9,12 @@ import torch
import torch.nn.functional as F
from torch import nn
import os.path as osp
try:
from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
FlashCrossAttention = None
if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
if FlashCrossAttention or hasattr(F, 'scaled_dot_product_attention'):
FLASH_AVAILABLE = True
else:
FLASH_AVAILABLE = False
@@ -23,9 +23,8 @@ torch.backends.cudnn.deterministic = True
@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
def normalize_keypoints(kpts: torch.Tensor,
size: Optional[torch.Tensor] = None) -> torch.Tensor:
if size is None:
size = 1 + kpts.max(-2).values - kpts.min(-2).values
elif not isinstance(size, torch.Tensor):
@@ -41,11 +40,14 @@ def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
if length <= x.shape[-2]:
return x, torch.ones_like(x[..., :1], dtype=torch.bool)
pad = torch.ones(
*x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
)
*x.shape[:-2],
length - x.shape[-2],
x.shape[-1],
device=x.device,
dtype=x.dtype)
y = torch.cat([x, pad], dim=-2)
mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
mask[..., : x.shape[-2], :] = True
mask[..., :x.shape[-2], :] = True
return y, mask
@@ -55,12 +57,18 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
def apply_cached_rotary_emb(freqs: torch.Tensor,
t: torch.Tensor) -> torch.Tensor:
return (t * freqs[0]) + (rotate_half(t) * freqs[1])
class LearnableFourierPositionalEncoding(nn.Module):
def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:
def __init__(self,
M: int,
dim: int,
F_dim: int = None,
gamma: float = 1.0) -> None:
super().__init__()
F_dim = F_dim if F_dim is not None else dim
self.gamma = gamma
@@ -76,6 +84,7 @@ class LearnableFourierPositionalEncoding(nn.Module):
class TokenConfidence(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
@@ -89,27 +98,33 @@ class TokenConfidence(nn.Module):
class Attention(nn.Module):
def __init__(self, allow_flash: bool) -> None:
super().__init__()
if allow_flash and not FLASH_AVAILABLE:
warnings.warn(
"FlashAttention is not available. For optimal speed, "
"consider installing torch >= 2.0 or flash-attn.",
'FlashAttention is not available. For optimal speed, '
'consider installing torch >= 2.0 or flash-attn.',
stacklevel=2,
)
self.enable_flash = allow_flash and FLASH_AVAILABLE
self.has_sdp = hasattr(F, "scaled_dot_product_attention")
self.has_sdp = hasattr(F, 'scaled_dot_product_attention')
if allow_flash and FlashCrossAttention:
self.flash_ = FlashCrossAttention()
if self.has_sdp:
torch.backends.cuda.enable_flash_sdp(allow_flash)
def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.enable_flash and q.device.type == "cuda":
def forward(self,
q,
k,
v,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.enable_flash and q.device.type == 'cuda':
# use torch 2.0 scaled_dot_product_attention with flash
if self.has_sdp:
args = [x.half().contiguous() for x in [q, k, v]]
v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
v = F.scaled_dot_product_attention(
*args, attn_mask=mask).to(q.dtype)
return v if mask is None else v.nan_to_num()
else:
assert mask is None
@@ -121,18 +136,21 @@ class Attention(nn.Module):
v = F.scaled_dot_product_attention(*args, attn_mask=mask)
return v if mask is None else v.nan_to_num()
else:
s = q.shape[-1] ** -0.5
sim = torch.einsum("...id,...jd->...ij", q, k) * s
s = q.shape[-1]**-0.5
sim = torch.einsum('...id,...jd->...ij', q, k) * s
if mask is not None:
sim.masked_fill(~mask, -float("inf"))
sim.masked_fill(~mask, -float('inf'))
attn = F.softmax(sim, -1)
return torch.einsum("...ij,...jd->...id", attn, v)
return torch.einsum('...ij,...jd->...id', attn, v)
class SelfBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
) -> None:
def __init__(self,
embed_dim: int,
num_heads: int,
flash: bool = False,
bias: bool = True) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -165,9 +183,12 @@ class SelfBlock(nn.Module):
class CrossBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
) -> None:
def __init__(self,
embed_dim: int,
num_heads: int,
flash: bool = False,
bias: bool = True) -> None:
super().__init__()
self.heads = num_heads
dim_head = embed_dim // num_heads
@@ -190,32 +211,35 @@ class CrossBlock(nn.Module):
def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
return func(x0), func(x1)
def forward(
self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> List[torch.Tensor]:
def forward(self,
x0: torch.Tensor,
x1: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> List[torch.Tensor]:
qk0, qk1 = self.map_(self.to_qk, x0, x1)
v0, v1 = self.map_(self.to_v, x0, x1)
qk0, qk1, v0, v1 = map(
lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
(qk0, qk1, v0, v1),
)
if self.flash is not None and qk0.device.type == "cuda":
if self.flash is not None and qk0.device.type == 'cuda':
m0 = self.flash(qk0, qk1, v1, mask)
m1 = self.flash(
qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
)
qk1, qk0, v0,
mask.transpose(-1, -2) if mask is not None else None)
else:
qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
sim = torch.einsum('bhid, bhjd -> bhij', qk0, qk1)
if mask is not None:
sim = sim.masked_fill(~mask, -float("inf"))
sim = sim.masked_fill(~mask, -float('inf'))
attn01 = F.softmax(sim, dim=-1)
attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1),
v0)
if mask is not None:
m0, m1 = m0.nan_to_num(), m1.nan_to_num()
m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
m0, m1)
m0, m1 = self.map_(self.to_out, m0, m1)
x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
@@ -223,6 +247,7 @@ class CrossBlock(nn.Module):
class TransformerLayer(nn.Module):
def __init__(self, *args, **kwargs):
super().__init__()
self.self_attn = SelfBlock(*args, **kwargs)
@@ -238,7 +263,8 @@ class TransformerLayer(nn.Module):
mask1: Optional[torch.Tensor] = None,
):
if mask0 is not None and mask1 is not None:
return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
return self.masked_forward(desc0, desc1, encoding0, encoding1,
mask0, mask1)
else:
desc0 = self.self_attn(desc0, encoding0)
desc1 = self.self_attn(desc1, encoding1)
@@ -254,14 +280,14 @@ class TransformerLayer(nn.Module):
return self.cross_attn(desc0, desc1, mask)
def sigmoid_log_double_softmax(
sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
def sigmoid_log_double_softmax(sim: torch.Tensor, z0: torch.Tensor,
z1: torch.Tensor) -> torch.Tensor:
"""create the log assignment matrix from logits and similarity"""
b, m, n = sim.shape
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
scores0 = F.log_softmax(sim, 2)
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(),
2).transpose(-1, -2)
scores = sim.new_full((b, m + 1, n + 1), 0)
scores[:, :m, :n] = scores0 + scores1 + certainties
scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
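For readability, the assignment this hunk builds can be written out (my paraphrase of the code above, not text from the commit; indices i < m, j < n, with the extra row/column acting as "unmatched" bins):

\log P_{ij} = \log\sigma(z^0_i) + \log\sigma(z^1_j) + \log\mathrm{softmax}_k(S_{ik})\big|_{k=j} + \log\mathrm{softmax}_k(S_{kj})\big|_{k=i}
\log P_{i,n} = \log\sigma(-z^0_i)

The symmetric last row (driven by -z^1_j) presumably follows immediately after, but falls outside the lines shown in this hunk.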
@@ -270,6 +296,7 @@ def sigmoid_log_double_softmax(
class MatchAssignment(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
@@ -281,7 +308,7 @@ class MatchAssignment(nn.Module):
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
_, _, d = mdesc0.shape
mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
sim = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
z0 = self.matchability(desc0)
z1 = self.matchability(desc1)
scores = sigmoid_log_double_softmax(sim, z0, z1)
@@ -315,34 +342,34 @@ class LightGlue(nn.Module):
# Point pruning involves an overhead (gather).
# Therefore, we only activate it if there are enough keypoints.
pruning_keypoint_thresholds = {
"cpu": -1,
"mps": -1,
"cuda": 1024,
"flash": 1536,
'cpu': -1,
'mps': -1,
'cuda': 1024,
'flash': 1536,
}
required_data_keys = ["image0", "image1"]
required_data_keys = ['image0', 'image1']
version = "v0.1_arxiv"
weight_path = "{}_lightglue.pth"
version = 'v0.1_arxiv'
weight_path = '{}_lightglue.pth'
features = {
"superpoint": {
"weights": "superpoint_lightglue",
"input_dim": 256,
'superpoint': {
'weights': 'superpoint_lightglue',
'input_dim': 256,
},
"disk": {
"weights": "disk_lightglue",
"input_dim": 128,
'disk': {
'weights': 'disk_lightglue',
'input_dim': 128,
},
"aliked": {
"weights": "aliked_lightglue",
"input_dim": 128,
'aliked': {
'weights': 'aliked_lightglue',
'input_dim': 128,
},
"sift": {
"weights": "sift_lightglue",
"input_dim": 128,
"add_scale_ori": True,
'sift': {
'weights': 'sift_lightglue',
'input_dim': 128,
'add_scale_ori': True,
},
}
@@ -352,77 +379,78 @@ class LightGlue(nn.Module):
if conf.features is not None:
if conf.features not in self.features:
raise ValueError(
f"Unsupported features: {conf.features} not in "
f"{{{','.join(self.features)}}}"
)
f'Unsupported features: {conf.features} not in '
f"{{{','.join(self.features)}}}")
for k, v in self.features[conf.features].items():
setattr(conf, k, v)
if conf.input_dim != conf.descriptor_dim:
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
self.input_proj = nn.Linear(
conf.input_dim, conf.descriptor_dim, bias=True)
else:
self.input_proj = nn.Identity()
head_dim = conf.descriptor_dim // conf.num_heads
self.posenc = LearnableFourierPositionalEncoding(
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
)
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim)
h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
self.transformers = nn.ModuleList(
[TransformerLayer(d, h, conf.flash) for _ in range(n)]
)
[TransformerLayer(d, h, conf.flash) for _ in range(n)])
self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
self.log_assignment = nn.ModuleList(
[MatchAssignment(d) for _ in range(n)])
self.token_confidence = nn.ModuleList(
[TokenConfidence(d) for _ in range(n - 1)]
)
[TokenConfidence(d) for _ in range(n - 1)])
self.register_buffer(
"confidence_thresholds",
torch.Tensor(
[self.confidence_threshold(i) for i in range(self.conf.n_layers)]
),
'confidence_thresholds',
torch.Tensor([
self.confidence_threshold(i) for i in range(self.conf.n_layers)
]),
)
state_dict = None
if conf.features is not None:
fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
state_dict = torch.load(
osp.join(model_dir,
self.weight_path.format(conf.features)), map_location="cpu"
)
osp.join(model_dir, self.weight_path.format(conf.features)),
map_location='cpu')
self.load_state_dict(state_dict, strict=False)
elif conf.weights is not None:
path = Path(__file__).parent
path = path / "weights/{}.pth".format(self.conf.weights)
state_dict = torch.load(str(path), map_location="cpu")
path = path / 'weights/{}.pth'.format(self.conf.weights)
state_dict = torch.load(str(path), map_location='cpu')
if state_dict:
# rename old state dict entries
for i in range(self.conf.n_layers):
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
pattern = f'self_attn.{i}', f'transformers.{i}.self_attn'
state_dict = {
k.replace(*pattern): v
for k, v in state_dict.items()
}
pattern = f'cross_attn.{i}', f'transformers.{i}.cross_attn'
state_dict = {
k.replace(*pattern): v
for k, v in state_dict.items()
}
self.load_state_dict(state_dict, strict=False)
# static lengths LightGlue is compiled for (only used with torch.compile)
self.static_lengths = None
def compile(
self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
):
def compile(self,
mode='reduce-overhead',
static_lengths=[256, 512, 768, 1024, 1280, 1536]):
if self.conf.width_confidence != -1:
warnings.warn(
"Point pruning is partially disabled for compiled forward.",
'Point pruning is partially disabled for compiled forward.',
stacklevel=2,
)
for i in range(self.conf.n_layers):
self.transformers[i].masked_forward = torch.compile(
self.transformers[i].masked_forward, mode=mode, fullgraph=True
)
self.transformers[i].masked_forward, mode=mode, fullgraph=True)
self.static_lengths = static_lengths
@@ -447,30 +475,30 @@ class LightGlue(nn.Module):
matching_scores1: [B x N]
matches: List[[Si x 2]], scores: List[[Si]]
"""
with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
with torch.autocast(enabled=self.conf.mp, device_type='cuda'):
return self._forward(data)
def _forward(self, data: dict) -> dict:
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
data0, data1 = data["image0"], data["image1"]
kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
assert key in data, f'Missing key {key} in data'
data0, data1 = data['image0'], data['image1']
kpts0, kpts1 = data0['keypoints'], data1['keypoints']
b, m, _ = kpts0.shape
b, n, _ = kpts1.shape
device = kpts0.device
size0, size1 = data0.get("image_size"), data1.get("image_size")
size0, size1 = data0.get('image_size'), data1.get('image_size')
kpts0 = normalize_keypoints(kpts0, size0).clone()
kpts1 = normalize_keypoints(kpts1, size1).clone()
if self.conf.add_scale_ori:
kpts0 = torch.cat(
[kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
)
[kpts0] + [data0[k].unsqueeze(-1) for k in ('scales', 'oris')],
-1)
kpts1 = torch.cat(
[kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
)
desc0 = data0["descriptors"].detach().contiguous()
desc1 = data1["descriptors"].detach().contiguous()
[kpts1] + [data1[k].unsqueeze(-1) for k in ('scales', 'oris')],
-1)
desc0 = data0['descriptors'].detach().contiguous()
desc1 = data1['descriptors'].detach().contiguous()
assert desc0.shape[-1] == self.conf.input_dim
assert desc1.shape[-1] == self.conf.input_dim
@@ -507,14 +535,14 @@ class LightGlue(nn.Module):
token0, token1 = None, None
for i in range(self.conf.n_layers):
desc0, desc1 = self.transformers[i](
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
)
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1)
if i == self.conf.n_layers - 1:
continue # no early stopping or adaptive width at last layer
if do_early_stop:
token0, token1 = self.token_confidence[i](desc0, desc1)
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :],
i, m + n):
break
if do_point_pruning and desc0.shape[-2] > pruning_th:
scores0 = self.log_assignment[i].get_matchability(desc0)
@@ -535,7 +563,8 @@ class LightGlue(nn.Module):
desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
scores, _ = self.log_assignment[i](desc0, desc1)
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
m0, m1, mscores0, mscores1 = filter_matches(scores,
self.conf.filter_threshold)
matches, mscores = [], []
for k in range(b):
valid = m0[k] > -1
@@ -551,8 +580,10 @@ class LightGlue(nn.Module):
if do_point_pruning:
m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
m0_[:, ind0] = torch.where(m0 == -1, -1,
ind1.gather(1, m0.clamp(min=0)))
m1_[:, ind1] = torch.where(m1 == -1, -1,
ind0.gather(1, m1.clamp(min=0)))
mscores0_ = torch.zeros((b, m), device=mscores0.device)
mscores1_ = torch.zeros((b, n), device=mscores1.device)
mscores0_[:, ind0] = mscores0
@@ -563,15 +594,15 @@ class LightGlue(nn.Module):
prune1 = torch.ones_like(mscores1) * self.conf.n_layers
pred = {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
"matching_scores1": mscores1,
"stop": i + 1,
"matches": matches,
"scores": mscores,
"prune0": prune0,
"prune1": prune1,
'matches0': m0,
'matches1': m1,
'matching_scores0': mscores0,
'matching_scores1': mscores1,
'stop': i + 1,
'matches': matches,
'scores': mscores,
'prune0': prune0,
'prune1': prune1,
}
return pred
@@ -581,9 +612,8 @@ class LightGlue(nn.Module):
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
return np.clip(threshold, 0, 1)
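A quick numeric check of the schedule above (assuming the usual LightGlue default of n_layers = 9, which is not visible in this hunk):

import numpy as np

n_layers = 9  # assumed default; set elsewhere in the config
for i in range(n_layers):
    t = np.clip(0.8 + 0.1 * np.exp(-4.0 * i / n_layers), 0, 1)
    print(i, round(float(t), 3))
# layer 0 -> 0.9, decaying towards ~0.803 at the last layer, i.e. early layers
# require more confident tokens before the matcher is allowed to stop.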
def get_pruning_mask(
self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
) -> torch.Tensor:
def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
layer_index: int) -> torch.Tensor:
"""mask points which should be removed"""
keep = scores > (1 - self.conf.width_confidence)
if confidences is not None: # Low-confidence points are never pruned.
@@ -600,11 +630,12 @@ class LightGlue(nn.Module):
"""evaluate stopping condition"""
confidences = torch.cat([confidences0, confidences1], -1)
threshold = self.confidence_thresholds[layer_index]
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
ratio_confident = 1.0 - (
confidences < threshold).float().sum() / num_points # noqa E501
return ratio_confident > self.conf.depth_confidence
def pruning_min_kpts(self, device: torch.device):
if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
return self.pruning_keypoint_thresholds["flash"]
if self.conf.flash and FLASH_AVAILABLE and device.type == 'cuda':
return self.pruning_keypoint_thresholds['flash']
else:
return self.pruning_keypoint_thresholds[device.type]

View File

@@ -6,15 +6,20 @@ import torch
from kornia.color import rgb_to_grayscale
from packaging import version
from .utils import Extractor
try:
import pycolmap
except ImportError:
pycolmap = None
from .utils import Extractor
def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
def filter_dog_point(points,
scales,
angles,
image_shape,
nms_radius,
scores=None):
h, w = image_shape
ij = np.round(points - 0.5).astype(int).T[::-1]
@@ -72,59 +77,59 @@ def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
points = np.array([k.pt for k in detections], dtype=np.float32)
scores = np.array([k.response for k in detections], dtype=np.float32)
scales = np.array([k.size for k in detections], dtype=np.float32)
angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
angles = np.deg2rad(
np.array([k.angle for k in detections], dtype=np.float32))
return points, scores, scales, angles, descriptors
class SIFT(Extractor):
default_conf = {
"rootsift": True,
"nms_radius": 0, # None to disable filtering entirely.
"max_num_keypoints": 4096,
"backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
"detection_threshold": 0.0066667, # from COLMAP
"edge_threshold": 10,
"first_octave": -1, # only used by pycolmap, the default of COLMAP
"num_octaves": 4,
'rootsift': True,
'nms_radius': 0, # None to disable filtering entirely.
'max_num_keypoints': 4096,
'backend':
'opencv', # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
'detection_threshold': 0.0066667, # from COLMAP
'edge_threshold': 10,
'first_octave': -1, # only used by pycolmap, the default of COLMAP
'num_octaves': 4,
}
preprocess_conf = {
"resize": 1024,
'resize': 1024,
}
required_data_keys = ["image"]
required_data_keys = ['image']
def __init__(self, **conf):
super().__init__(**conf) # Update with default configuration.
backend = self.conf.backend
if backend.startswith("pycolmap"):
if backend.startswith('pycolmap'):
if pycolmap is None:
raise ImportError(
"Cannot find module pycolmap: install it with pip"
"or use backend=opencv."
)
'Cannot find module pycolmap: install it with pip'
'or use backend=opencv.')
options = {
"peak_threshold": self.conf.detection_threshold,
"edge_threshold": self.conf.edge_threshold,
"first_octave": self.conf.first_octave,
"num_octaves": self.conf.num_octaves,
"normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
'peak_threshold': self.conf.detection_threshold,
'edge_threshold': self.conf.edge_threshold,
'first_octave': self.conf.first_octave,
'num_octaves': self.conf.num_octaves,
'normalization':
pycolmap.Normalization.L2, # L1_ROOT is buggy.
}
device = (
"auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
)
if (
backend == "pycolmap_cpu" or not pycolmap.has_cuda
) and pycolmap.__version__ < "0.5.0":
device = ('auto' if backend == 'pycolmap' else backend.replace(
'pycolmap_', ''))
if (backend == 'pycolmap_cpu' or not pycolmap.has_cuda
) and pycolmap.__version__ < '0.5.0': # noqa E501
warnings.warn(
"The pycolmap CPU SIFT is buggy in version < 0.5.0, "
"consider upgrading pycolmap or use the CUDA version.",
'The pycolmap CPU SIFT is buggy in version < 0.5.0, '
'consider upgrading pycolmap or use the CUDA version.',
stacklevel=1,
)
else:
options["max_num_features"] = self.conf.max_num_keypoints
options['max_num_features'] = self.conf.max_num_keypoints
self.sift = pycolmap.Sift(options=options, device=device)
elif backend == "opencv":
elif backend == 'opencv':
self.sift = cv2.SIFT_create(
contrastThreshold=self.conf.detection_threshold,
nfeatures=self.conf.max_num_keypoints,
@@ -132,56 +137,52 @@ class SIFT(Extractor):
nOctaveLayers=self.conf.num_octaves,
)
else:
backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
raise ValueError(
f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
)
backends = {'opencv', 'pycolmap', 'pycolmap_cpu', 'pycolmap_cuda'}
raise ValueError(f'Unknown backend: {backend} not in '
f"{{{','.join(backends)}}}.")
def extract_single_image(self, image: torch.Tensor):
image_np = image.cpu().numpy().squeeze(0)
if self.conf.backend.startswith("pycolmap"):
if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
if self.conf.backend.startswith('pycolmap'):
if version.parse(pycolmap.__version__) >= version.parse('0.5.0'):
detections, descriptors = self.sift.extract(image_np)
scores = None # Scores are not exposed by COLMAP anymore.
else:
detections, scores, descriptors = self.sift.extract(image_np)
keypoints = detections[:, :2] # Keep only (x, y).
scales, angles = detections[:, -2:].T
if scores is not None and (
self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
):
if scores is not None and (self.conf.backend == 'pycolmap_cpu'
or not pycolmap.has_cuda):
# Set the scores as a combination of abs. response and scale.
scores = np.abs(scores) * scales
elif self.conf.backend == "opencv":
elif self.conf.backend == 'opencv':
# TODO: Check if opencv keypoints are already in corner convention
keypoints, scores, scales, angles, descriptors = run_opencv_sift(
self.sift, (image_np * 255.0).astype(np.uint8)
)
self.sift, (image_np * 255.0).astype(np.uint8))
pred = {
"keypoints": keypoints,
"scales": scales,
"oris": angles,
"descriptors": descriptors,
'keypoints': keypoints,
'scales': scales,
'oris': angles,
'descriptors': descriptors,
}
if scores is not None:
pred["keypoint_scores"] = scores
pred['keypoint_scores'] = scores
# sometimes pycolmap returns points outside the image. We remove them
if self.conf.backend.startswith("pycolmap"):
is_inside = (
pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
).all(-1)
if self.conf.backend.startswith('pycolmap'):
is_inside = (pred['keypoints'] + 0.5 < np.array(
[image_np.shape[-2:][::-1]])).all(-1)
pred = {k: v[is_inside] for k, v in pred.items()}
if self.conf.nms_radius is not None:
keep = filter_dog_point(
pred["keypoints"],
pred["scales"],
pred["oris"],
pred['keypoints'],
pred['scales'],
pred['oris'],
image_np.shape,
self.conf.nms_radius,
scores=pred.get("keypoint_scores"),
scores=pred.get('keypoint_scores'),
)
pred = {k: v[keep] for k, v in pred.items()}
@@ -189,14 +190,15 @@ class SIFT(Extractor):
if scores is not None:
# Keep the k keypoints with highest score
num_points = self.conf.max_num_keypoints
if num_points is not None and len(pred["keypoints"]) > num_points:
indices = torch.topk(pred["keypoint_scores"], num_points).indices
if num_points is not None and len(pred['keypoints']) > num_points:
indices = torch.topk(pred['keypoint_scores'],
num_points).indices
pred = {k: v[indices] for k, v in pred.items()}
return pred
def forward(self, data: dict) -> dict:
image = data["image"]
image = data['image']
if image.shape[1] == 3:
image = rgb_to_grayscale(image)
device = image.device
@@ -204,13 +206,16 @@ class SIFT(Extractor):
pred = []
for k in range(len(image)):
img = image[k]
if "image_size" in data.keys():
if 'image_size' in data.keys():
# avoid extracting points in padded areas
w, h = data["image_size"][k]
w, h = data['image_size'][k]
img = img[:, :h, :w]
p = self.extract_single_image(img)
pred.append(p)
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
pred = {
k: torch.stack([p[k] for p in pred], 0).to(device)
for k in pred[0]
}
if self.conf.rootsift:
pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
pred['descriptors'] = sift_to_rootsift(pred['descriptors'])
return pred

View File

@@ -42,12 +42,13 @@
# Adapted by Remi Pautrat, Philipp Lindenberger
import os.path as osp
import torch
from kornia.color import rgb_to_grayscale
from torch import nn
from .utils import Extractor
import os.path as osp
def simple_nms(scores, nms_radius: int):
@@ -56,8 +57,7 @@ def simple_nms(scores, nms_radius: int):
def max_pool(x):
return torch.nn.functional.max_pool2d(
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)
zeros = torch.zeros_like(scores)
max_mask = scores == max_pool(scores)
@@ -80,19 +80,14 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
"""Interpolate descriptors at keypoint locations"""
b, c, h, w = descriptors.shape
keypoints = keypoints - s / 2 + 0.5
keypoints /= torch.tensor(
[(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
).to(
keypoints
)[None]
keypoints /= torch.tensor([(w * s - s / 2 - 0.5),
(h * s - s / 2 - 0.5)], ).to(keypoints)[None]
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
descriptors = torch.nn.functional.grid_sample(
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
)
descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1
)
descriptors.reshape(b, c, -1), p=2, dim=1)
return descriptors
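A small numeric check of the keypoint normalisation above (s = 8 is the SuperPoint cell size; the image size here is chosen purely for illustration):

import torch

s, h, w = 8, 60, 80                       # descriptor map is h x w cells
kpt = torch.tensor([[319.5, 239.5]])      # pixel centre of a 640 x 480 image
k = kpt - s / 2 + 0.5
k = k / torch.tensor([(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)])
k = k * 2 - 1
print(k)    # close to (0, 0): the centre of grid_sample's (-1, 1) range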
@@ -106,20 +101,20 @@ class SuperPoint(Extractor):
"""
default_conf = {
"descriptor_dim": 256,
"nms_radius": 4,
"max_num_keypoints": None,
"detection_threshold": 0.0005,
"remove_borders": 4,
'descriptor_dim': 256,
'nms_radius': 4,
'max_num_keypoints': None,
'detection_threshold': 0.0005,
'remove_borders': 4,
}
preprocess_conf = {
"resize": 1024,
'resize': 1024,
}
required_data_keys = ["image"]
required_data_keys = ['image']
def __init__(self,model_dir, **conf):
def __init__(self, model_dir, **conf):
super().__init__(**conf) # Update with default configuration.
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
@@ -139,21 +134,19 @@ class SuperPoint(Extractor):
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
self.convDb = nn.Conv2d(
c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
)
c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0)
weights_path = osp.join(model_dir,"superpoint_v1.pth")
self.load_state_dict(torch.load(weights_path, map_location="cpu"))
weights_path = osp.join(model_dir, 'superpoint_v1.pth')
self.load_state_dict(torch.load(weights_path, map_location='cpu'))
if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
raise ValueError("max_num_keypoints must be positive or None")
raise ValueError('max_num_keypoints must be positive or None')
def forward(self, data: dict) -> dict:
"""Compute keypoints, scores, descriptors for image"""
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
image = data["image"]
assert key in data, f'Missing key {key} in data'
image = data['image']
if image.shape[1] == 3:
image = rgb_to_grayscale(image)
@@ -193,20 +186,18 @@ class SuperPoint(Extractor):
# Separate into batches
keypoints = [
torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i]
for i in range(b)
]
scores = [scores[best_kp[0] == i] for i in range(b)]
# Keep the k keypoints with highest score
if self.conf.max_num_keypoints is not None:
keypoints, scores = list(
zip(
*[
top_k_keypoints(k, s, self.conf.max_num_keypoints)
for k, s in zip(keypoints, scores)
]
)
)
zip(*[
top_k_keypoints(k, s, self.conf.max_num_keypoints)
for k, s in zip(keypoints, scores)
]))
# Convert (h, w) to (x, y)
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
@@ -223,7 +214,10 @@ class SuperPoint(Extractor):
]
return {
"keypoints": torch.stack(keypoints, 0),
"keypoint_scores": torch.stack(scores, 0),
"descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
'keypoints':
torch.stack(keypoints, 0),
'keypoint_scores':
torch.stack(scores, 0),
'descriptors':
torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
}

View File

@@ -11,11 +11,11 @@ import torch
class ImagePreprocessor:
default_conf = {
"resize": None, # target edge length, None for no resizing
"side": "long",
"interpolation": "bilinear",
"align_corners": None,
"antialias": True,
'resize': None, # target edge length, None for no resizing
'side': 'long',
'interpolation': 'bilinear',
'align_corners': None,
'antialias': True,
}
def __init__(self, **conf) -> None:
@@ -52,7 +52,9 @@ def map_tensor(input_, func: Callable):
return input_
def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
def batch_to_device(batch: dict,
device: str = 'cpu',
non_blocking: bool = True):
"""Move batch (dict) to device"""
def _func(tensor):
@@ -72,11 +74,11 @@ def rbd(data: dict) -> dict:
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
"""Read an image from path as RGB or grayscale"""
if not Path(path).exists():
raise FileNotFoundError(f"No image at path {path}.")
raise FileNotFoundError(f'No image at path {path}.')
mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
image = cv2.imread(str(path), mode)
if image is None:
raise IOError(f"Could not read image at {path}.")
raise IOError(f'Could not read image at {path}.')
if not grayscale:
image = image[..., ::-1]
return image
@@ -89,20 +91,20 @@ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
elif image.ndim == 2:
image = image[None] # add channel axis
else:
raise ValueError(f"Not an image: {image.shape}")
raise ValueError(f'Not an image: {image.shape}')
return torch.tensor(image / 255.0, dtype=torch.float)
def resize_image(
image: np.ndarray,
size: Union[List[int], int],
fn: str = "max",
interp: Optional[str] = "area",
fn: str = 'max',
interp: Optional[str] = 'area',
) -> np.ndarray:
"""Resize an image to a fixed size, or according to max or min edge."""
h, w = image.shape[:2]
fn = {"max": max, "min": min}[fn]
fn = {'max': max, 'min': min}[fn]
if isinstance(size, int):
scale = size / fn(h, w)
h_new, w_new = int(round(h * scale)), int(round(w * scale))
@@ -111,12 +113,12 @@ def resize_image(
h_new, w_new = size
scale = (w_new / w, h_new / h)
else:
raise ValueError(f"Incorrect new size: {size}")
raise ValueError(f'Incorrect new size: {size}')
mode = {
"linear": cv2.INTER_LINEAR,
"cubic": cv2.INTER_CUBIC,
"nearest": cv2.INTER_NEAREST,
"area": cv2.INTER_AREA,
'linear': cv2.INTER_LINEAR,
'cubic': cv2.INTER_CUBIC,
'nearest': cv2.INTER_NEAREST,
'area': cv2.INTER_AREA,
}[interp]
return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
@@ -129,6 +131,7 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
class Extractor(torch.nn.Module):
def __init__(self, **conf):
super().__init__()
self.conf = SimpleNamespace(**{**self.default_conf, **conf})
@@ -140,10 +143,14 @@ class Extractor(torch.nn.Module):
img = img[None] # add batch dim
assert img.dim() == 4 and img.shape[0] == 1
shape = img.shape[-2:][::-1]
img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
feats = self.forward({"image": img})
feats["image_size"] = torch.tensor(shape)[None].to(img).float()
feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
img, scales = ImagePreprocessor(**{
**self.preprocess_conf,
**conf
})(
img)
feats = self.forward({'image': img})
feats['image_size'] = torch.tensor(shape)[None].to(img).float()
feats['keypoints'] = (feats['keypoints'] + 0.5) / scales[None] - 0.5
return feats
@@ -152,13 +159,13 @@ def match_pair(
matcher,
image0: torch.Tensor,
image1: torch.Tensor,
device: str = "cpu",
device: str = 'cpu',
**preprocess,
):
"""Match a pair of images (image0, image1) with an extractor and matcher"""
feats0 = extractor.extract(image0, **preprocess)
feats1 = extractor.extract(image1, **preprocess)
matches01 = matcher({"image0": feats0, "image1": feats1})
matches01 = matcher({'image0': feats0, 'image1': feats1})
data = [feats0, feats1, matches01]
# remove batch dim and move to target device
feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]
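A minimal sketch of driving match_pair (not from the commit): `extractor` and `matcher` are assumed to be an already-constructed SuperPoint and LightGlue, built as in the ModelScope wrapper later in this diff, and the function is assumed to return the three de-batched dicts it prepares in its final line.

image0 = load_image('assets/img0.jpg')   # hypothetical paths
image1 = load_image('assets/img1.jpg')
feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1, device='cpu')
matches = matches01['matches']           # [S, 2] indices into the two keypoint sets
mkpts0 = feats0['keypoints'][matches[:, 0]]
mkpts1 = feats1['keypoints'][matches[:, 1]]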

View File

@@ -22,10 +22,12 @@ def cm_RdGn(x):
def cm_BlRdGn(x_):
"""Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
x = np.clip(x_, 0, 1)[..., None] * 2
c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array(
[[1.0, 0, 0, 1.0]])
xn = -np.clip(x_, -1, 0)[..., None] * 2
cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array(
[[1.0, 0, 0, 1.0]])
out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
return out
@@ -39,7 +41,12 @@ def cm_prune(x_):
return cm_BlRdGn(norm_x)
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
def plot_images(imgs,
titles=None,
cmaps='gray',
dpi=100,
pad=0.5,
adaptive=True):
"""Plot a set of images horizontally.
Args:
imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
@@ -49,9 +56,8 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
"""
# conversion to (H, W, 3) for torch.Tensor
imgs = [
img.permute(1, 2, 0).cpu().numpy()
if (isinstance(img, torch.Tensor) and img.dim() == 3)
else img
img.permute(1, 2, 0).cpu().numpy() if
(isinstance(img, torch.Tensor) and img.dim() == 3) else img
for img in imgs
]
@@ -65,8 +71,7 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
ratios = [4 / 3] * n
figsize = [sum(ratios) * 4.5, 4.5]
fig, ax = plt.subplots(
1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
)
1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
if n == 1:
ax = [ax]
for i in range(n):
@@ -81,7 +86,7 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
fig.tight_layout(pad=pad)
def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0):
"""Plot keypoints for existing images.
Args:
kpts: list of ndarrays of size (N, 2).
@@ -100,7 +105,14 @@ def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
def plot_matches(kpts0,
kpts1,
color=None,
lw=1.5,
ps=4,
a=1.0,
labels=None,
axes=None):
"""Plot matches for a pair of existing images.
Args:
kpts0, kpts1: corresponding keypoints of size (N, 2).
@@ -160,25 +172,28 @@ def add_text(
text,
pos=(0.01, 0.99),
fs=15,
color="w",
lcolor="k",
color='w',
lcolor='k',
lwidth=2,
ha="left",
va="top",
ha='left',
va='top',
):
ax = plt.gcf().axes[idx]
t = ax.text(
*pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
)
*pos,
text,
fontsize=fs,
ha=ha,
va=va,
color=color,
transform=ax.transAxes)
if lcolor is not None:
t.set_path_effects(
[
path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
path_effects.Normal(),
]
)
t.set_path_effects([
path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
path_effects.Normal(),
])
def save_plot(path, **kw):
"""Save the current figure without any white margin."""
plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw)
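A short sketch of the plotting helpers above used together (the arrays are illustrative stand-ins for real images and matched keypoints):

import numpy as np

img0 = np.random.rand(240, 320, 3)
img1 = np.random.rand(240, 320, 3)
mkpts0 = np.random.rand(50, 2) * [320, 240]
mkpts1 = np.random.rand(50, 2) * [320, 240]

plot_images([img0, img1], titles=['image0', 'image1'])
plot_matches(mkpts0, mkpts1, color='lime', lw=0.5)
add_text(0, '50 matches')
save_plot('matches.png')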

View File

@@ -13,9 +13,9 @@ from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .lightglue import LightGlue, SuperPoint, DISK, ALIKED, SIFT
from .lightglue.utils import rbd, numpy_image_to_torch
from .config.default import lightglue_default_conf
from .lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
from .lightglue.utils import numpy_image_to_torch, rbd
@MODELS.register_module(
@@ -30,20 +30,28 @@ class LightGlueImageMatching(TorchModel):
super().__init__(model_dir, **kwargs)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 'mps', 'cpu'
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu') # 'mps', 'cpu'
features = lightglue_default_conf.get('features', 'superpoint')
features = lightglue_default_conf.get('features','superpoint')
if features == 'disk':
self.extractor = DISK(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = DISK(
max_num_keypoints=max_num_keypoints).eval().to(self.device)
elif features == 'aliked':
self.extractor = ALIKED(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = ALIKED(
max_num_keypoints=max_num_keypoints).eval().to(self.device)
elif features == 'sift':
self.extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = SIFT(
max_num_keypoints=max_num_keypoints).eval().to(self.device)
else:
self.extractor = SuperPoint(model_dir=model_dir, max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.matcher = LightGlue(model_dir=model_dir, default_conf=lightglue_default_conf).eval().to(self.device)
self.extractor = SuperPoint(
model_dir=model_dir,
max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.matcher = LightGlue(
model_dir=model_dir,
default_conf=lightglue_default_conf).eval().to(self.device)
def forward(self, inputs):
'''
@@ -51,9 +59,11 @@ class LightGlueImageMatching(TorchModel):
inputs: a dict with keys 'image0', 'image1'
'''
feats0 = self.extractor.extract(numpy_image_to_torch(inputs['image0']).to(self.device))
feats1 = self.extractor.extract(numpy_image_to_torch(inputs['image1']).to(self.device))
matches01 = self.matcher({"image0": feats0, "image1": feats1})
feats0 = self.extractor.extract(
numpy_image_to_torch(inputs['image0']).to(self.device))
feats1 = self.extractor.extract(
numpy_image_to_torch(inputs['image1']).to(self.device))
matches01 = self.matcher({'image0': feats0, 'image1': feats1})
return [feats0, feats1, matches01]
@@ -63,17 +73,21 @@ class LightGlueImageMatching(TorchModel):
inputs: a list of feats0, feats1, matches01
'''
matching_result = inputs
feats0, feats1, matches01 = [
rbd(x) for x in matching_result
] # remove batch dimension
feats0, feats1, matches01 = [rbd(x) for x in matching_result
] # remove batch dimension
kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
kpts0, kpts1, matches = feats0['keypoints'], feats1[
'keypoints'], matches01['matches']
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]
# match confidence
confidence = matches01["scores"]
confidence = matches01['scores']
matches_result = {'kpts0': m_kpts0,'kpts1': m_kpts1,'confidence': confidence}
matches_result = {
'kpts0': m_kpts0,
'kpts1': m_kpts1,
'confidence': confidence
}
results = {OutputKeys.MATCHES: matches_result}
return results
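A rough end-to-end sketch of this wrapper (not from the commit): the model directory and image paths are placeholders, and any constructor arguments beyond model_dir are assumed to keep the defaults set earlier in __init__.

import cv2
from modelscope.outputs import OutputKeys

model = LightGlueImageMatching('<dir-with-superpoint_v1.pth-and-lightglue-weights>')
inputs = {
    'image0': cv2.imread('a.jpg')[..., ::-1],   # BGR -> RGB numpy arrays
    'image1': cv2.imread('b.jpg')[..., ::-1],
}
outputs = model.postprocess(model.forward(inputs))
matches = outputs[OutputKeys.MATCHES]
print(matches['kpts0'].shape, matches['kpts1'].shape, matches['confidence'].shape)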

View File

@@ -296,7 +296,9 @@ else:
],
'human3d_render_pipeline': ['Human3DRenderPipeline'],
'human3d_animation_pipeline': ['Human3DAnimationPipeline'],
'image_local_feature_matching_pipeline': ['ImageLocalFeatureMatchingPipeline'],
'image_local_feature_matching_pipeline': [
'ImageLocalFeatureMatchingPipeline'
],
'rife_video_frame_interpolation_pipeline': [
'RIFEVideoFrameInterpolationPipeline'
],

View File

@@ -27,8 +27,10 @@ class ImageLocalFeatureMatchingPipeline(Pipeline):
>>> from modelscope.pipelines import pipeline
>>> matcher = pipeline(Tasks.image_local_feature_matching, model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
>>> matcher([['https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching1.jpg','https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching2.jpg']])
>>> matcher = pipeline(Tasks.image_local_feature_matching,
>>> model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
>>> matcher([['https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching1.jpg',
>>> 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching2.jpg']])
>>> [{
>>> 'matches': [array([[720.5 , 187.8 ],
>>> [707.4 , 198.23334],
@@ -69,7 +71,6 @@ class ImageLocalFeatureMatchingPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
def load_image(self, img_name):
img = LoadImage.convert_to_ndarray(img_name).astype(np.float32)
img = img / 255.

View File

@@ -67,10 +67,7 @@ class ImageMatchingFastPipeline(Pipeline):
img1 = self.load_image(input[0])
img2 = self.load_image(input[1])
return {
'image0':img1,
'image1':img2
}
return {'image0': img1, 'image1': img2}
def forward(self, input: Dict[str, Any]) -> list:
results = self.model.inference(input)

View File

@@ -26,11 +26,12 @@ class ImageLocalFeatureMatchingTest(unittest.TestCase):
'data/test/images/image_matching1.jpg',
'data/test/images/image_matching2.jpg'
]]
estimator = pipeline(Tasks.image_local_feature_matching, model=self.model_id)
estimator = pipeline(
Tasks.image_local_feature_matching, model=self.model_id)
result = estimator(input_location)
kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
vis_img = result[0][OutputKeys.OUTPUT_IMG]
cv2.imwrite("vis_demo.jpg", vis_img)
cv2.imwrite('vis_demo.jpg', vis_img)
print('test_image_local_feature_matching DONE')

View File

@@ -32,7 +32,7 @@ class ImageMatchingFastTest(unittest.TestCase):
kpts1,
confidence,
output_filename='lightglue-matches.png',
method="lightglue")
method='lightglue')
print('test_image_matching DONE')