Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
fix lint issue
@@ -808,7 +808,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
(Pipelines.panorama_depth_estimation,
'damo/cv_unifuse_panorama-depth-estimation'),
Tasks.image_local_feature_matching:
(Pipelines.image_local_feature_matching, 'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'),
(Pipelines.image_local_feature_matching,
'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'),
Tasks.image_style_transfer: (Pipelines.image_style_transfer,
'damo/cv_aams_style-transfer_damo'),
Tasks.face_image_generation: (Pipelines.face_image_generation,

@@ -832,9 +833,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.image_object_detection:
(Pipelines.image_object_detection_auto,
'damo/cv_yolox_image-object-detection-auto'),
Tasks.ocr_recognition:
(Pipelines.ocr_recognition,
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
Tasks.ocr_recognition: (
Pipelines.ocr_recognition,
'damo/cv_convnextTiny_ocr-recognition-general_damo'),
Tasks.skin_retouching: (Pipelines.skin_retouching,
'damo/cv_unet_skin-retouching'),
Tasks.faq_question_answering: (
@@ -8,10 +8,11 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
face_reconstruction, human3d_animation, human_reconstruction,
image_classification, image_color_enhance, image_colorization,
image_defrcn_fewshot, image_denoise, image_editing,
image_inpainting, image_instance_segmentation, image_matching,
image_mvs_depth_estimation, image_panoptic_segmentation,
image_portrait_enhancement, image_probing_model,
image_quality_assessment_degradation,
image_inpainting, image_instance_segmentation,
image_local_feature_matching, image_matching,
image_matching_fast, image_mvs_depth_estimation,
image_panoptic_segmentation, image_portrait_enhancement,
image_probing_model, image_quality_assessment_degradation,
image_quality_assessment_man, image_quality_assessment_mos,
image_reid_person, image_restoration,
image_semantic_segmentation, image_super_resolution_pasd,

@@ -29,6 +30,6 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
video_panoptic_segmentation, video_single_object_tracking,
video_stabilization, video_summarization,
video_super_resolution, vidt, virual_tryon, vision_middleware,
vop_retrieval, image_local_feature_matching,image_matching_fast)
vop_retrieval)

# yapf: enable

@@ -19,4 +19,4 @@ else:
_import_structure,
module_spec=__spec__,
extra_objects={},
)
)
@@ -1,21 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp

import io
import cv2
import torch
import numpy as np
import os.path as osp
from copy import deepcopy

import cv2
import matplotlib.cm as cm
import numpy as np
import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.image_local_feature_matching.src.loftr import \
LoFTR, default_cfg
from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import make_matching_figure
from modelscope.models.cv.image_local_feature_matching.src.loftr import (
LoFTR, default_cfg)
from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import \
make_matching_figure
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
import matplotlib.cm as cm


@MODELS.register_module(

@@ -51,15 +52,19 @@ class LocalFeatureMatching(TorchModel):
def postprocess(self, Inputs):
# Draw
color = cm.jet(Inputs['conf'].cpu().numpy())
img0, img1, mkpts0, mkpts1 = Inputs["image0"].squeeze().cpu().numpy(), Inputs["image1"].squeeze().cpu().numpy(), Inputs["kpts0"].cpu().numpy(), Inputs["kpts1"].cpu().numpy()
img0, img1, mkpts0, mkpts1 = Inputs['image0'].squeeze().cpu().numpy(
), Inputs['image1'].squeeze().cpu().numpy(), Inputs['kpts0'].cpu(
).numpy(), Inputs['kpts1'].cpu().numpy()
text = [
'LoFTR',
'Matches: {}'.format(len(Inputs['kpts0'])),
]
img0, img1 = (img0 * 255).astype(np.uint8), (img1 * 255).astype(np.uint8)
fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color, text=text)
img0, img1 = (img0 * 255).astype(np.uint8), (img1 * 255).astype(
np.uint8)
fig = make_matching_figure(
img0, img1, mkpts0, mkpts1, color, text=text)
io_buf = io.BytesIO()
fig.savefig(io_buf, format="png", dpi=75)
fig.savefig(io_buf, format='png', dpi=75)
io_buf.seek(0)
buf_data = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
io_buf.close()

@@ -71,4 +76,4 @@ class LocalFeatureMatching(TorchModel):
def inference(self, data):
results = self.forward(data)

return results
return results
@@ -8,4 +8,5 @@ def build_backbone(config):
elif config['resolution'] == (16, 4):
return ResNetFPN_16_4(config['resnetfpn'])
else:
raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
raise ValueError(
f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
@@ -4,15 +4,28 @@ import torch.nn.functional as F

def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution without padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=1,
stride=stride,
padding=0,
bias=False)


def conv3x3(in_planes, out_planes, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
return nn.Conv2d(
in_planes,
out_planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False)


class BasicBlock(nn.Module):

def __init__(self, in_planes, planes, stride=1):
super().__init__()
self.conv1 = conv3x3(in_planes, planes, stride)

@@ -26,8 +39,7 @@ class BasicBlock(nn.Module):
else:
self.downsample = nn.Sequential(
conv1x1(in_planes, planes, stride=stride),
nn.BatchNorm2d(planes)
)
nn.BatchNorm2d(planes))

def forward(self, x):
y = x

@@ -37,7 +49,7 @@ class BasicBlock(nn.Module):
if self.downsample is not None:
x = self.downsample(x)

return self.relu(x+y)
return self.relu(x + y)


class ResNetFPN_8_2(nn.Module):
@@ -57,7 +69,8 @@ class ResNetFPN_8_2(nn.Module):
self.in_planes = initial_dim

# Networks
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(initial_dim)
self.relu = nn.ReLU(inplace=True)

@@ -84,7 +97,8 @@ class ResNetFPN_8_2(nn.Module):

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

@@ -107,13 +121,15 @@ class ResNetFPN_8_2(nn.Module):
# FPN
x3_out = self.layer3_outconv(x3)

x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x3_out_2x = F.interpolate(
x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x2_out = self.layer2_outconv(x2)
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)

x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True)
x2_out_2x = F.interpolate(
x2_out, scale_factor=2., mode='bilinear', align_corners=True)
x1_out = self.layer1_outconv(x1)
x1_out = self.layer1_outconv2(x1_out+x2_out_2x)
x1_out = self.layer1_outconv2(x1_out + x2_out_2x)

return [x3_out, x1_out]
@@ -135,7 +151,8 @@ class ResNetFPN_16_4(nn.Module):
self.in_planes = initial_dim

# Networks
self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.conv1 = nn.Conv2d(
1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(initial_dim)
self.relu = nn.ReLU(inplace=True)

@@ -164,7 +181,8 @@ class ResNetFPN_16_4(nn.Module):

for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

@@ -188,12 +206,14 @@ class ResNetFPN_16_4(nn.Module):
# FPN
x4_out = self.layer4_outconv(x4)

x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True)
x4_out_2x = F.interpolate(
x4_out, scale_factor=2., mode='bilinear', align_corners=True)
x3_out = self.layer3_outconv(x3)
x3_out = self.layer3_outconv2(x3_out+x4_out_2x)
x3_out = self.layer3_outconv2(x3_out + x4_out_2x)

x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x3_out_2x = F.interpolate(
x3_out, scale_factor=2., mode='bilinear', align_corners=True)
x2_out = self.layer2_outconv(x2)
x2_out = self.layer2_outconv2(x2_out+x3_out_2x)
x2_out = self.layer2_outconv2(x2_out + x3_out_2x)

return [x4_out, x2_out]
@@ -3,13 +3,14 @@ import torch.nn as nn
from einops.einops import rearrange

from .backbone import build_backbone
from .utils.position_encoding import PositionEncodingSine
from .loftr_module import LocalFeatureTransformer, FinePreprocess
from .loftr_module import FinePreprocess, LocalFeatureTransformer
from .utils.coarse_matching import CoarseMatching
from .utils.fine_matching import FineMatching
from .utils.position_encoding import PositionEncodingSine


class LoFTR(nn.Module):

def __init__(self, config):
super().__init__()
# Misc

@@ -23,11 +24,11 @@ class LoFTR(nn.Module):
self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
self.coarse_matching = CoarseMatching(config['match_coarse'])
self.fine_preprocess = FinePreprocess(config)
self.loftr_fine = LocalFeatureTransformer(config["fine"])
self.loftr_fine = LocalFeatureTransformer(config['fine'])
self.fine_matching = FineMatching()

def forward(self, data):
"""
"""
Update:
data (dict): {
'image0': (torch.Tensor): (N, 1, H, W)
@@ -39,18 +40,24 @@ class LoFTR(nn.Module):
# 1. Local Feature CNN
data.update({
'bs': data['image0'].size(0),
'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
'hw0_i': data['image0'].shape[2:],
'hw1_i': data['image1'].shape[2:]
})

if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence
feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
feats_c, feats_f = self.backbone(
torch.cat([data['image0'], data['image1']], dim=0))
(feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(
data['bs']), feats_f.split(data['bs'])
else: # handle different input shapes
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
(feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(
data['image0']), self.backbone(data['image1'])

data.update({
'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
'hw0_c': feat_c0.shape[2:],
'hw1_c': feat_c1.shape[2:],
'hw0_f': feat_f0.shape[2:],
'hw1_f': feat_f1.shape[2:]
})

# 2. coarse-level loftr module

@@ -60,16 +67,21 @@ class LoFTR(nn.Module):

mask_c0 = mask_c1 = None # mask is useful in training
if 'mask0' in data:
mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
mask_c0, mask_c1 = data['mask0'].flatten(
-2), data['mask1'].flatten(-2)
feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0,
mask_c1)

# 3. match coarse-level
self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
self.coarse_matching(
feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)

# 4. fine-level refinement
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(
feat_f0, feat_f1, feat_c0, feat_c1, data)
if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
feat_f0_unfold, feat_f1_unfold = self.loftr_fine(
feat_f0_unfold, feat_f1_unfold)

# 5. match fine-level
self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
@@ -1,2 +1,2 @@
from .transformer import LocalFeatureTransformer
from .fine_preprocess import FinePreprocess
from .transformer import LocalFeatureTransformer
@@ -5,6 +5,7 @@ from einops.einops import rearrange, repeat


class FinePreprocess(nn.Module):

def __init__(self, config):
super().__init__()

@@ -17,14 +18,14 @@ class FinePreprocess(nn.Module):
self.d_model_f = d_model_f
if self.cat_c_feat:
self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
self.merge_feat = nn.Linear(2 * d_model_f, d_model_f, bias=True)

self._reset_parameters()

def _reset_parameters(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu")
nn.init.kaiming_normal_(p, mode='fan_out', nonlinearity='relu')

def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
W = self.W
@@ -32,28 +33,41 @@ class FinePreprocess(nn.Module):

data.update({'W': W})
if data['b_ids'].shape[0] == 0:
feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
feat0 = torch.empty(
0, self.W**2, self.d_model_f, device=feat_f0.device)
feat1 = torch.empty(
0, self.W**2, self.d_model_f, device=feat_f0.device)
return feat0, feat1

# 1. unfold(crop) all local windows
feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
feat_f0_unfold = F.unfold(
feat_f0, kernel_size=(W, W), stride=stride, padding=W // 2)
feat_f0_unfold = rearrange(
feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
feat_f1_unfold = F.unfold(
feat_f1, kernel_size=(W, W), stride=stride, padding=W // 2)
feat_f1_unfold = rearrange(
feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)

# 2. select only the predicted matches
feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf]
feat_f0_unfold = feat_f0_unfold[data['b_ids'],
data['i_ids']] # [n, ww, cf]
feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]

# option: use coarse-level loftr feature as context: concat and linear
if self.cat_c_feat:
feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c]
feat_cf_win = self.merge_feat(torch.cat([
torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf]
repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf]
], -1))
feat_c_win = self.down_proj(
torch.cat([
feat_c0[data['b_ids'], data['i_ids']],
feat_c1[data['b_ids'], data['j_ids']]
], 0)) # [2n, c]
feat_cf_win = self.merge_feat(
torch.cat(
[
torch.cat([feat_f0_unfold, feat_f1_unfold],
0), # [2n, ww, cf]
repeat(feat_c_win, 'n c -> n ww c', ww = W ** 2), # [2n, ww, cf]
], -1)) # yapf: disable
feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)

return feat_f0_unfold, feat_f1_unfold
@@ -4,7 +4,7 @@ Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_trans
"""

import torch
from torch.nn import Module, Dropout
from torch.nn import Dropout, Module


def elu_feature_map(x):

@@ -12,6 +12,7 @@ def elu_feature_map(x):


class LinearAttention(Module):

def __init__(self, eps=1e-6):
super().__init__()
self.feature_map = elu_feature_map

@@ -40,14 +41,16 @@ class LinearAttention(Module):

v_length = values.size(1)
values = values / v_length # prevent fp16 overflow
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
KV = torch.einsum('nshd,nshv->nhdv', K, values) # (S,D)' @ S,V
Z = 1 / (torch.einsum('nlhd,nhd->nlh', Q, K.sum(dim=1)) + self.eps)
queried_values = torch.einsum('nlhd,nhdv,nlh->nlhv', Q, KV,
Z) * v_length

return queried_values.contiguous()


class FullAttention(Module):

def __init__(self, use_dropout=False, attention_dropout=0.1):
super().__init__()
self.use_dropout = use_dropout

@@ -66,9 +69,11 @@ class FullAttention(Module):
"""

# Compute the unnormalized attention and apply the masks
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
QK = torch.einsum('nlhd,nshd->nlsh', queries, keys)
if kv_mask is not None:
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
QK.masked_fill_(
~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]),
float('-inf'))

# Compute the attention and the weighted average
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)

@@ -76,6 +81,6 @@ class FullAttention(Module):
if self.use_dropout:
A = self.dropout(A)

queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
queried_values = torch.einsum('nlsh,nshd->nlhd', A, values)

return queried_values.contiguous()
@@ -1,14 +1,14 @@
import copy

import torch
import torch.nn as nn
from .linear_attention import LinearAttention, FullAttention

from .linear_attention import FullAttention, LinearAttention


class LoFTREncoderLayer(nn.Module):
def __init__(self,
d_model,
nhead,
attention='linear'):

def __init__(self, d_model, nhead, attention='linear'):
super(LoFTREncoderLayer, self).__init__()

self.dim = d_model // nhead

@@ -18,14 +18,15 @@ class LoFTREncoderLayer(nn.Module):
self.q_proj = nn.Linear(d_model, d_model, bias=False)
self.k_proj = nn.Linear(d_model, d_model, bias=False)
self.v_proj = nn.Linear(d_model, d_model, bias=False)
self.attention = LinearAttention() if attention == 'linear' else FullAttention()
self.attention = LinearAttention(
) if attention == 'linear' else FullAttention()
self.merge = nn.Linear(d_model, d_model, bias=False)

# feed-forward network
self.mlp = nn.Sequential(
nn.Linear(d_model*2, d_model*2, bias=False),
nn.Linear(d_model * 2, d_model * 2, bias=False),
nn.ReLU(True),
nn.Linear(d_model*2, d_model, bias=False),
nn.Linear(d_model * 2, d_model, bias=False),
)

# norm and dropout

@@ -44,11 +45,16 @@ class LoFTREncoderLayer(nn.Module):
query, key, value = x, source, source

# multi-head attention
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
query = self.q_proj(query).view(bs, -1, self.nhead,
self.dim) # [N, L, (H, D)]
key = self.k_proj(key).view(bs, -1, self.nhead,
self.dim) # [N, S, (H, D)]
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
message = self.attention(
query, key, value, q_mask=x_mask,
kv_mask=source_mask) # [N, L, (H, D)]
message = self.merge(message.view(bs, -1,
self.nhead * self.dim)) # [N, L, C]
message = self.norm1(message)

# feed-forward network
@@ -68,8 +74,11 @@ class LocalFeatureTransformer(nn.Module):
self.d_model = config['d_model']
self.nhead = config['nhead']
self.layer_names = config['layer_names']
encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'],
config['attention'])
self.layers = nn.ModuleList([
copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))
])
self._reset_parameters()

def _reset_parameters(self):

@@ -86,7 +95,8 @@ class LocalFeatureTransformer(nn.Module):
mask1 (torch.Tensor): [N, S] (optional)
"""

assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
assert self.d_model == feat0.size(
2), 'the feature number of src and transformer must be equal'

for layer, name in zip(self.layers, self.layer_names):
if name == 'self':
@@ -5,6 +5,7 @@ from einops.einops import rearrange

INF = 1e9


def mask_border(m, b: int, v):
""" Mask borders with value
Args:

@@ -45,7 +46,7 @@ def mask_border_with_padding(m, bd, v, p_m0, p_m1):

def compute_max_candidates(p_m0, p_m1):
"""Compute the max candidates of all pairs within a batch


Args:
p_m0, p_m1 (torch.Tensor): padded masks
"""

@@ -57,6 +58,7 @@ def compute_max_candidates(p_m0, p_m1):


class CoarseMatching(nn.Module):

def __init__(self, config):
super().__init__()
self.config = config

@@ -75,7 +77,7 @@ class CoarseMatching(nn.Module):
try:
from .superglue import log_optimal_transport
except ImportError:
raise ImportError("download superglue.py first!")
raise ImportError('download superglue.py first!')
self.log_optimal_transport = log_optimal_transport
self.bin_score = nn.Parameter(
torch.tensor(config['skh_init_bin_score'], requires_grad=True))

@@ -103,28 +105,27 @@ class CoarseMatching(nn.Module):
'mconf' (torch.Tensor): [M]}
NOTE: M' != M during training.
"""
N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
_, L, S, _ = feat_c0.size(0), feat_c0.size(1), feat_c1.size(
1), feat_c0.size(2)

# normalize
feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
[feat_c0, feat_c1])

if self.match_type == 'dual_softmax':
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
sim_matrix = torch.einsum('nlc,nsc->nls', feat_c0,
feat_c1) / self.temperature
if mask_c0 is not None:
sim_matrix.masked_fill_(
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
-INF)
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)
conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)

elif self.match_type == 'sinkhorn':
# sinkhorn, dustbin included
sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
sim_matrix = torch.einsum('nlc,nsc->nls', feat_c0, feat_c1)
if mask_c0 is not None:
sim_matrix[:, :L, :S].masked_fill_(
~(mask_c0[..., None] * mask_c1[:, None]).bool(),
-INF)
~(mask_c0[..., None] * mask_c1[:, None]).bool(), -INF)

# build uniform prior & use sinkhorn
log_assign_matrix = self.log_optimal_transport(
@@ -207,10 +208,10 @@ class CoarseMatching(nn.Module):
else:
num_candidates_max = compute_max_candidates(
data['mask0'], data['mask1'])
num_matches_train = int(num_candidates_max *
self.train_coarse_percent)
num_matches_train = int(num_candidates_max
* self.train_coarse_percent)
num_matches_pred = len(b_ids)
assert self.train_pad_num_gt_min < num_matches_train, "min-num-gt-pad should be less than num-train-matches"
assert self.train_pad_num_gt_min < num_matches_train, 'min-num-gt-pad should be less than num-train-matches'

# pred_indices is to select from prediction
if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:

@@ -223,11 +224,13 @@ class CoarseMatching(nn.Module):

# gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
gt_pad_indices = torch.randint(
len(data['spv_b_ids']),
(max(num_matches_train - num_matches_pred,
self.train_pad_num_gt_min), ),
device=_device)
mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device) # set conf of gt paddings to all zero
len(data['spv_b_ids']),
(max(num_matches_train - num_matches_pred,
self.train_pad_num_gt_min), ),
device=_device)
mconf_gt = torch.zeros(
len(data['spv_b_ids']),
device=_device) # set conf of gt paddings to all zero

b_ids, i_ids, j_ids, mconf = map(
lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
@@ -1,4 +1,5 @@
import math

import torch
import torch.nn as nn

@@ -7,8 +8,8 @@ def create_meshgrid(
height: int,
width: int,
normalized_coordinates: bool = True,
device = None,
dtype = None,
device=None,
dtype=None,
):
"""Generate a coordinate grid for an image.

@@ -48,7 +49,8 @@ def create_meshgrid(
if normalized_coordinates:
xs = (xs / (width - 1) - 0.5) * 2
ys = (ys / (height - 1) - 0.5) * 2
base_grid = torch.stack(torch.meshgrid([xs, ys], indexing="ij"), dim=-1) # WxHx2
base_grid = torch.stack(
torch.meshgrid([xs, ys], indexing='ij'), dim=-1) # WxHx2
return base_grid.permute(1, 0, 2).unsqueeze(0) # 1xHxWx2
@@ -120,7 +122,7 @@ class FineMatching(nn.Module):

# corner case: if no coarse matches found
if M == 0:
assert self.training == False, "M is always >0, when training, see coarse_matching.py"
assert self.training is False, 'M is always >0, when training, see coarse_matching.py'
# logger.warning('No matches found in coarse-level.')
data.update({
'expec_f': torch.empty(0, 3, device=feat_f0.device),

@@ -129,35 +131,41 @@ class FineMatching(nn.Module):
})
return

feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :]
feat_f0_picked = feat_f0_picked = feat_f0[:, WW // 2, :]
sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
softmax_temp = 1. / C**.5
heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
heatmap = torch.softmax(
softmax_temp * sim_matrix, dim=1).view(-1, W, W)

# compute coordinates from heatmap
coords_normalized = spatial_expectation2d(heatmap[None], True)[0] # [M, 2]
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2]
coords_normalized = spatial_expectation2d(heatmap[None],
True)[0] # [M, 2]
grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(
1, -1, 2) # [1, WW, 2]

# compute std over <x, y>
var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2]
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability

var = torch.sum(
grid_normalized**2 * heatmap.view(-1, WW, 1),
dim=1) - coords_normalized**2 # [M, 2]
std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)),
-1) # [M] clamp needed for numerical stability

# for fine-level supervision
data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
data.update(
{'expec_f':
torch.cat([coords_normalized, std.unsqueeze(1)], -1)})

# compute absolute kpt coords
self.get_fine_match(coords_normalized, data)

@torch.no_grad()
def get_fine_match(self, coords_normed, data):
W, WW, C, scale = self.W, self.WW, self.C, self.scale
W, _, _, scale = self.W, self.WW, self.C, self.scale

# mkpts0_f and mkpts1_f
mkpts0_f = data['mkpts0_c']
scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
scale1 = scale * data['scale1'][
data['b_ids']] if 'scale0' in data else scale
mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] # yapf: disable

data.update({
"mkpts0_f": mkpts0_f,
"mkpts1_f": mkpts1_f
})
data.update({'mkpts0_f': mkpts0_f, 'mkpts1_f': mkpts1_f})
@@ -6,7 +6,7 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
""" Warp kpts0 from I0 to I1 with depth, K and Rt
Also check covisibility and depth consistency.
Depth is consistent if relative error < 0.2 (hard-coded).


Args:
kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
depth0 (torch.Tensor): [N, H, W],

@@ -21,34 +21,37 @@ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
kpts0_long = kpts0.round().long()

# Sample depth, get calculable_mask on depth != 0
kpts0_depth = torch.stack(
[depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
) # (N, L)
kpts0_depth = torch.stack([
depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]]
for i in range(kpts0.shape[0])
],
dim=0) # noqa E501
nonzero_mask = kpts0_depth != 0

# Unproject
kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])],
dim=-1) * kpts0_depth[..., None] # (N, L, 3)
kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)

# Rigid Transform
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3,
[3]] # (N, 3, L)
w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]

# Project
w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth
w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4
) # (N, L, 2), +1e-4 to avoid zero depth

# Covisible Check
h, w = depth1.shape[1:3]
covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
(w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w - 1) * (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h - 1) # noqa E501 yapf: disable
w_kpts0_long = w_kpts0.long()
w_kpts0_long[~covisible_mask, :] = 0

w_kpts0_depth = torch.stack(
[depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
) # (N, L)
consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
w_kpts0_depth = torch.stack([depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0) # noqa E501 yapf: disable
consistent_mask = (
(w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
valid_mask = nonzero_mask * covisible_mask * consistent_mask

return valid_mask, w_kpts0
@@ -1,4 +1,5 @@
import math

import torch
from torch import nn

@@ -23,16 +24,17 @@ class PositionEncodingSine(nn.Module):
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
if temp_bug_fix:
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
div_term = torch.exp(torch.arange(0, d_model // 2, 2).float() * (-math.log(10000.0) / (d_model // 2))) # noqa E501 yapf: disable
else: # a buggy implementation (for backward compatability only)
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
div_term = torch.exp(torch.arange(0, d_model // 2, 2).float() * (-math.log(10000.0) / d_model // 2)) # noqa E501 yapf: disable
div_term = div_term[:, None, None] # [C//4, 1, 1]
pe[0::4, :, :] = torch.sin(x_position * div_term)
pe[1::4, :, :] = torch.cos(x_position * div_term)
pe[2::4, :, :] = torch.sin(y_position * div_term)
pe[3::4, :, :] = torch.cos(y_position * div_term)

self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
self.register_buffer(
'pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]

def forward(self, x):
"""
@@ -1,13 +1,13 @@
from math import log
from loguru import logger

import torch
from einops import repeat
from kornia.utils import create_meshgrid
from loguru import logger

from .geometry import warp_kpts

############## ↓ Coarse-Level supervision ↓ ##############
# ↓ Coarse-Level supervision ↓ ##############


@torch.no_grad()

@@ -30,7 +30,7 @@ def spvs_coarse(data, config):
'spv_w_pt0_i': [N, hw0, 2], in original image resolution
'spv_pt1_i': [N, hw1, 2], in original image resolution
}


NOTE:
- for scannet dataset, there're 3 kinds of resolution {i, c, f}
- for megadepth dataset, there're 4 kinds of resolution {i, i_resize, c, f}

@@ -46,9 +46,14 @@ def spvs_coarse(data, config):

# 2. warp grids
# create kpts in meshgrid and resize them to image resolution
grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1) # [N, hw, 2]
grid_pt0_c = create_meshgrid(h0, w0, False,
device).reshape(1, h0 * w0,
2).repeat(N, 1,
1) # [N, hw, 2]
grid_pt0_i = scale0 * grid_pt0_c
grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
grid_pt1_c = create_meshgrid(h1, w1, False,
device).reshape(1, h1 * w1,
2).repeat(N, 1, 1)
grid_pt1_i = scale1 * grid_pt1_c

# mask padded region to (0, 0), so no need to manually mask conf_matrix_gt

@@ -59,8 +64,10 @@ def spvs_coarse(data, config):
# warp kpts bi-directionally and resize them to coarse-level resolution
# (no depth consistency check, since it leads to worse results experimentally)
# (unhandled edge case: points with 0-depth will be warped to the left-up corner)
_, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
_, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'],
data['T_0to1'], data['K0'], data['K1'])
_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'],
data['T_1to0'], data['K1'], data['K0'])
w_pt0_c = w_pt0_i / scale1
w_pt1_c = w_pt1_i / scale0
@@ -72,16 +79,21 @@ def spvs_coarse(data, config):

# corner case: out of boundary
def out_bound_mask(pt, w, h):
return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (
pt[..., 1] >= h)

nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0

loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
loop_back = torch.stack(
[nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)],
dim=0)
correct_0to1 = loop_back == torch.arange(
h0 * w0, device=device)[None].repeat(N, 1)
correct_0to1[:, 0] = False # ignore the top-left corner

# 4. construct a gt conf_matrix
conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
conf_matrix_gt = torch.zeros(N, h0 * w0, h1 * w1, device=device)
b_ids, i_ids = torch.where(correct_0to1 != 0)
j_ids = nearest_index1[b_ids, i_ids]

@@ -90,27 +102,22 @@ def spvs_coarse(data, config):

# 5. save coarse matches(gt) for training fine level
if len(b_ids) == 0:
logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
logger.warning(
f"No groundtruth coarse match found for: {data['pair_names']}")
# this won't affect fine-level loss calculation
b_ids = torch.tensor([0], device=device)
i_ids = torch.tensor([0], device=device)
j_ids = torch.tensor([0], device=device)

data.update({
'spv_b_ids': b_ids,
'spv_i_ids': i_ids,
'spv_j_ids': j_ids
})
data.update({'spv_b_ids': b_ids, 'spv_i_ids': i_ids, 'spv_j_ids': j_ids})

# 6. save intermediate results (for fast fine-level computation)
data.update({
'spv_w_pt0_i': w_pt0_i,
'spv_pt1_i': grid_pt1_i
})
data.update({'spv_w_pt0_i': w_pt0_i, 'spv_pt1_i': grid_pt1_i})


def compute_supervision_coarse(data, config):
assert len(set(data['dataset_name'])) == 1, "Do not support mixed datasets training!"
assert len(set(
data['dataset_name'])) == 1, 'Do not support mixed datasets training!'
data_source = data['dataset_name'][0]
if data_source.lower() in ['scannet', 'megadepth']:
spvs_coarse(data, config)

@@ -118,7 +125,8 @@ def compute_supervision_coarse(data, config):
raise ValueError(f'Unknown data source: {data_source}')


############## ↓ Fine-Level supervision ↓ ##############
# ↓ Fine-Level supervision ↓ ##############


@torch.no_grad()
def spvs_fine(data, config):

@@ -139,8 +147,9 @@ def spvs_fine(data, config):
# 3. compute gt
scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
# `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
data.update({"expec_f_gt": expec_f_gt})
expec_f_gt = (w_pt0_i[b_ids, i_ids]
- pt1_i[b_ids, j_ids]) / scale / radius # [M, 2]
data.update({'expec_f_gt': expec_f_gt})


def compute_supervision_fine(data, config):
@@ -1,7 +1,8 @@
import bisect
import numpy as np
import matplotlib.pyplot as plt

import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def _compute_conf_thresh(data):

@@ -17,21 +18,30 @@ def _compute_conf_thresh(data):

# --- VISUALIZATION --- #

def make_matching_figure(
img0, img1, mkpts0, mkpts1, color,
kpts0=None, kpts1=None, text=[], dpi=75, path=None):

def make_matching_figure(img0,
img1,
mkpts0,
mkpts1,
color,
kpts0=None,
kpts1=None,
text=[],
dpi=75,
path=None):
# draw image pair
assert mkpts0.shape[0] == mkpts1.shape[0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
assert mkpts0.shape[0] == mkpts1.shape[
0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
axes[0].imshow(img0, cmap='gray')
axes[1].imshow(img1, cmap='gray')
for i in range(2): # clear all frames
for i in range(2): # clear all frames
axes[i].get_yaxis().set_ticks([])
axes[i].get_xaxis().set_ticks([])
for spine in axes[i].spines.values():
spine.set_visible(False)
plt.tight_layout(pad=1)

if kpts0 is not None:
assert kpts1 is not None
axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)

@@ -43,19 +53,28 @@ def make_matching_figure(
transFigure = fig.transFigure.inverted()
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
(fkpts0[i, 1], fkpts1[i, 1]),
transform=fig.transFigure, c=color[i], linewidth=1)
for i in range(len(mkpts0))]

fig.lines = [
matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
(fkpts0[i, 1], fkpts1[i, 1]),
transform=fig.transFigure,
c=color[i],
linewidth=1) for i in range(len(mkpts0))
]

axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)

# put txts
txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
fig.text(
0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
fontsize=15, va='top', ha='left', color=txt_color)
0.01,
0.99,
'\n'.join(text),
transform=fig.axes[0].transAxes,
fontsize=15,
va='top',
ha='left',
color=txt_color)

# save or return figure
if path:
@@ -68,12 +87,14 @@ def make_matching_figure(
def _make_evaluation_figure(data, b_id, alpha='dynamic'):
b_mask = data['m_bids'] == b_id
conf_thr = _compute_conf_thresh(data)

img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)

img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(
np.int32)
img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(
np.int32)
kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()


# for megadepth, we visualize matches on the resized image
if 'scale0' in data:
kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]

@@ -92,18 +113,18 @@ def _make_evaluation_figure(data, b_id, alpha='dynamic'):
if alpha == 'dynamic':
alpha = dynamic_alpha(len(correct_mask))
color = error_colormap(epi_errs, conf_thr, alpha=alpha)


text = [
f'#Matches {len(kpts0)}',
f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
]


# make the figure
figure = make_matching_figure(img0, img1, kpts0, kpts1,
color, text=text)
figure = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
return figure


def _make_confidence_figure(data, b_id):
# TODO: Implement confidence figure
raise NotImplementedError()

@@ -111,7 +132,7 @@ def _make_confidence_figure(data, b_id):

def make_matching_figures(data, config, mode='evaluation'):
""" Make matching figures for a batch.


Args:
data (Dict): a batch updated by PL_LoFTR.
config (Dict): matcher config

@@ -123,8 +144,7 @@ def make_matching_figures(data, config, mode='evaluation'):
for b_id in range(data['image0'].size(0)):
if mode == 'evaluation':
fig = _make_evaluation_figure(
data, b_id,
alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
data, b_id, alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
elif mode == 'confidence':
fig = _make_confidence_figure(data, b_id)
else:

@@ -144,11 +164,14 @@ def dynamic_alpha(n_matches,
if _range[1] is None:
return _range[0]
return _range[1] + (milestones[loc + 1] - n_matches) / (
milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
milestones[loc + 1] - milestones[loc]) * (
_range[0] - _range[1])


def error_colormap(err, thr, alpha=1.0):
assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
assert alpha <= 1.0 and alpha > 0, f'Invaid alpha value: {alpha}'
x = 1 - np.clip(err / (thr * 2), 0, 1)
return np.clip(
np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
np.stack([2 - x * 2, x * 2,
np.zeros_like(x),
np.ones_like(x) * alpha], -1), 0, 1)
@@ -1 +1 @@
from .default import lightglue_default_conf
from .default import lightglue_default_conf

@@ -1,15 +1,15 @@
lightglue_default_conf = {
"features":"superpoint", # superpoint disk aliked sift
"name": "lightglue", # just for interfacing
"input_dim": 256, # input descriptor dimension (autoselected from weights)
"descriptor_dim": 256,
"add_scale_ori": False,
"n_layers": 9,
"num_heads": 4,
"flash": True, # enable FlashAttention if available.
"mp": False, # enable mixed precision
"depth_confidence": 0.95, # early stopping, disable with -1
"width_confidence": 0.99, # point pruning, disable with -1
"filter_threshold": 0.1, # match threshold
"weights": None,
'features': 'superpoint', # superpoint disk aliked sift
'name': 'lightglue', # just for interfacing
'input_dim': 256, # input descriptor dimension (autoselected from weights)
'descriptor_dim': 256,
'add_scale_ori': False,
'n_layers': 9,
'num_heads': 4,
'flash': True, # enable FlashAttention if available.
'mp': False, # enable mixed precision
'depth_confidence': 0.95, # early stopping, disable with -1
'width_confidence': 0.99, # point pruning, disable with -1
'filter_threshold': 0.1, # match threshold
'weights': None,
}
@@ -45,16 +45,15 @@ from torchvision.models import resnet
from .utils import Extractor


def get_patches(
tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
) -> torch.Tensor:
def get_patches(tensor: torch.Tensor, required_corners: torch.Tensor,
ps: int) -> torch.Tensor:
c, h, w = tensor.shape
corner = (required_corners - ps / 2 + 1).long()
corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
offset = torch.arange(0, ps)

kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
kw = {'indexing': 'ij'} if torch.__version__ >= '1.10' else {}
x, y = torch.meshgrid(offset, offset, **kw)
patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
patches = patches.to(corner) + corner[None, None]

@@ -70,8 +69,7 @@ def simple_nms(scores: torch.Tensor, nms_radius: int):

zeros = torch.zeros_like(scores)
max_mask = scores == torch.nn.functional.max_pool2d(
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)

for _ in range(2):
supp_mask = (

@@ -80,18 +78,19 @@ def simple_nms(scores: torch.Tensor, nms_radius: int):
kernel_size=nms_radius * 2 + 1,
stride=1,
padding=nms_radius,
)
> 0
)
) > 0)
supp_scores = torch.where(supp_mask, zeros, scores)
new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
supp_scores,
kernel_size=nms_radius * 2 + 1,
stride=1,
padding=nms_radius)
max_mask = max_mask | (new_max_mask & (~supp_mask))
return torch.where(max_mask, scores, zeros)


class DKD(nn.Module):

def __init__(
self,
radius: int = 2,
@@ -115,14 +114,15 @@ class DKD(nn.Module):
self.n_limit = n_limit
self.kernel_size = 2 * self.radius + 1
self.temperature = 0.1 # tuned temperature
self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
self.unfold = nn.Unfold(
kernel_size=self.kernel_size, padding=self.radius)
# local xy grid
x = torch.linspace(-self.radius, self.radius, self.kernel_size)
# (kernel_size*kernel_size) x 2 : (w,h)
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
kw = {'indexing': 'ij'} if torch.__version__ >= '1.10' else {}
self.hw_grid = (
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
)
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:,
[1, 0]])

def forward(
self,
@@ -141,29 +141,32 @@ class DKD(nn.Module):
nms_scores = simple_nms(scores_nograd, self.radius)

# remove border
nms_scores[:, :, : self.radius, :] = 0
nms_scores[:, :, :, : self.radius] = 0
nms_scores[:, :, :self.radius, :] = 0
nms_scores[:, :, :, :self.radius] = 0
if image_size is not None:
for i in range(scores_map.shape[0]):
w, h = image_size[i].long()
nms_scores[i, :, h.item() - self.radius :, :] = 0
nms_scores[i, :, :, w.item() - self.radius :] = 0
nms_scores[i, :, h.item() - self.radius:, :] = 0
nms_scores[i, :, :, w.item() - self.radius:] = 0
else:
nms_scores[:, :, -self.radius :, :] = 0
nms_scores[:, :, :, -self.radius :] = 0
nms_scores[:, :, -self.radius:, :] = 0
nms_scores[:, :, :, -self.radius:] = 0

# detect keypoints without grad
if self.top_k > 0:
topk = torch.topk(nms_scores.view(b, -1), self.top_k)
indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
indices_keypoints = [topk.indices[i]
for i in range(b)] # B x top_k
else:
if self.scores_th > 0:
masks = nms_scores > self.scores_th
if masks.sum() == 0:
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
th = scores_nograd.reshape(b, -1).mean(
dim=1) # th = self.scores_th
masks = nms_scores > th.reshape(b, 1, 1, 1)
else:
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
th = scores_nograd.reshape(b, -1).mean(
dim=1) # th = self.scores_th
masks = nms_scores > th.reshape(b, 1, 1, 1)
masks = masks.reshape(b, -1)

@@ -174,7 +177,7 @@ class DKD(nn.Module):
if len(indices) > self.n_limit:
kpts_sc = scores[indices]
sort_idx = kpts_sc.sort(descending=True)[1]
sel_idx = sort_idx[: self.n_limit]
sel_idx = sort_idx[:self.n_limit]
indices = indices[sel_idx]
indices_keypoints.append(indices)
@@ -190,34 +193,34 @@ class DKD(nn.Module):
for b_idx in range(b):
patch = patches[b_idx].t() # (H*W) x (kernel**2)
indices_kpt = indices_keypoints[
b_idx
] # one dimension vector, say its size is M
b_idx] # one dimension vector, say its size is M
patch_scores = patch[indices_kpt] # M x (kernel**2)
keypoints_xy_nms = torch.stack(
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
[
indices_kpt % w,
torch.div(indices_kpt, w, rounding_mode='trunc')
],
dim=1,
) # Mx2

# max is detached to prevent undesired backprop loops in the graph
max_v = patch_scores.max(dim=1).values.detach()[:, None]
x_exp = (
(patch_scores - max_v) / self.temperature
).exp() # M * (kernel**2), in [0, 1]
(patch_scores - max_v)
/ self.temperature).exp() # M * (kernel**2), in [0, 1]

# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
xy_residual = (
x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
) # Soft-argmax, Mx2
xy_residual = (x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
) # Soft-argmax, Mx2

hw_grid_dist2 = (
torch.norm(
(self.hw_grid[None, :, :] - xy_residual[:, None, :])
/ self.radius,
dim=-1,
)
** 2
)
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
)**2)
scoredispersity = (x_exp * hw_grid_dist2).sum(
dim=1) / x_exp.sum(dim=1)

# compute result keypoints
keypoints_xy = keypoints_xy_nms + xy_residual

@@ -226,11 +229,9 @@ class DKD(nn.Module):
kptscore = torch.nn.functional.grid_sample(
scores_map[b_idx].unsqueeze(0),
keypoints_xy.view(1, 1, -1, 2),
mode="bilinear",
mode='bilinear',
align_corners=True,
)[
0, 0, 0, :
] # CxN
)[0, 0, 0, :] # CxN

keypoints.append(keypoints_xy)
scoredispersitys.append(scoredispersity)
@@ -238,24 +239,25 @@ class DKD(nn.Module):
|
||||
else:
|
||||
for b_idx in range(b):
|
||||
indices_kpt = indices_keypoints[
|
||||
b_idx
|
||||
] # one dimension vector, say its size is M
|
||||
b_idx] # one dimension vector, say its size is M
|
||||
# To avoid warning: UserWarning: __floordiv__ is deprecated
|
||||
keypoints_xy_nms = torch.stack(
|
||||
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
|
||||
[
|
||||
indices_kpt % w,
|
||||
torch.div(indices_kpt, w, rounding_mode='trunc')
|
||||
],
|
||||
dim=1,
|
||||
) # Mx2
|
||||
keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
|
||||
kptscore = torch.nn.functional.grid_sample(
|
||||
scores_map[b_idx].unsqueeze(0),
|
||||
keypoints_xy.view(1, 1, -1, 2),
|
||||
mode="bilinear",
|
||||
mode='bilinear',
|
||||
align_corners=True,
|
||||
)[
|
||||
0, 0, 0, :
|
||||
] # CxN
|
||||
)[0, 0, 0, :] # CxN
|
||||
keypoints.append(keypoints_xy)
|
||||
scoredispersitys.append(kptscore) # for jit.script compatability
|
||||
scoredispersitys.append(
|
||||
kptscore) # for jit.script compatability
|
||||
kptscores.append(kptscore)
|
||||
|
||||
return keypoints, scoredispersitys, kptscores
|
||||
@@ -278,17 +280,18 @@ class InputPadder(object):

def pad(self, x: torch.Tensor):
assert x.ndim == 4
return F.pad(x, self._pad, mode="replicate")
return F.pad(x, self._pad, mode='replicate')

def unpad(self, x: torch.Tensor):
assert x.ndim == 4
ht = x.shape[-2]
wd = x.shape[-1]
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
return x[..., c[0] : c[1], c[2] : c[3]]
return x[..., c[0]:c[1], c[2]:c[3]]


class DeformableConv2d(nn.Module):

def __init__(
self,
in_channels,
@@ -304,9 +307,8 @@ class DeformableConv2d(nn.Module):
self.padding = padding
self.mask = mask

self.channel_num = (
3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
)
self.channel_num = (3 * kernel_size * kernel_size if mask else 2
* kernel_size * kernel_size)
self.offset_conv = nn.Conv2d(
in_channels,
self.channel_num,
@@ -356,10 +358,10 @@ def get_conv(
stride=1,
padding=1,
bias=False,
conv_type="conv",
conv_type='conv',
mask=False,
):
if conv_type == "conv":
if conv_type == 'conv':
conv = nn.Conv2d(
inplanes,
planes,
@@ -368,7 +370,7 @@ def get_conv(
padding=padding,
bias=bias,
)
elif conv_type == "dcn":
elif conv_type == 'dcn':
conv = DeformableConv2d(
inplanes,
planes,
@@ -384,13 +386,14 @@ def get_conv(


class ConvBlock(nn.Module):

def __init__(
self,
in_channels,
out_channels,
gate: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
conv_type: str = "conv",
conv_type: str = 'conv',
mask: bool = False,
):
super().__init__()
@@ -401,12 +404,18 @@ class ConvBlock(nn.Module):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self.conv1 = get_conv(
in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
)
in_channels,
out_channels,
kernel_size=3,
conv_type=conv_type,
mask=mask)
self.bn1 = norm_layer(out_channels)
self.conv2 = get_conv(
out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
)
out_channels,
out_channels,
kernel_size=3,
conv_type=conv_type,
mask=mask)
self.bn2 = norm_layer(out_channels)

def forward(self, x):
@@ -430,7 +439,7 @@ class ResBlock(nn.Module):
dilation: int = 1,
gate: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
conv_type: str = "conv",
conv_type: str = 'conv',
mask: bool = False,
) -> None:
super(ResBlock, self).__init__()
@@ -441,18 +450,17 @@ class ResBlock(nn.Module):
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError("ResBlock only supports groups=1 and base_width=64")
raise ValueError(
'ResBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
raise NotImplementedError('Dilation > 1 not supported in ResBlock')
# Both self.conv1 and self.downsample layers
# downsample the input when stride != 1
self.conv1 = get_conv(
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
)
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask)
self.bn1 = norm_layer(planes)
self.conv2 = get_conv(
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
)
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
@@ -477,14 +485,15 @@ class ResBlock(nn.Module):

class SDDH(nn.Module):

def __init__(
self,
dims: int,
kernel_size: int = 3,
n_pos: int = 8,
gate=nn.ReLU(),
conv2D=False,
mask=False,
self,
dims: int,
kernel_size: int = 3,
n_pos: int = 8,
gate=nn.ReLU(),
conv2D=False,
mask=False,
):
super(SDDH, self).__init__()
self.kernel_size = kernel_size
@@ -518,18 +527,21 @@ class SDDH(nn.Module):

# sampled feature conv
self.sf_conv = nn.Conv2d(
dims, dims, kernel_size=1, stride=1, padding=0, bias=False
)
dims, dims, kernel_size=1, stride=1, padding=0, bias=False)

# convM
if not conv2D:
# deformable desc weights
agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
self.register_parameter("agg_weights", agg_weights)
self.register_parameter('agg_weights', agg_weights)
else:
self.convM = nn.Conv2d(
dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
)
dims * n_pos,
dims,
kernel_size=1,
stride=1,
padding=0,
bias=False)

def forward(self, x, keypoints):
# x: [B,C,H,W]
@@ -548,29 +560,28 @@ class SDDH(nn.Module):

if self.kernel_size > 1:
patch = self.get_patches_func(
xi, kptsi_wh.long(), self.kernel_size
) # [N_kpts, C, K, K]
xi, kptsi_wh.long(), self.kernel_size) # [N_kpts, C, K, K]
else:
kptsi_wh_long = kptsi_wh.long()
patch = (
xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
.permute(1, 0)
.reshape(N_kpts, c, 1, 1)
)
xi[:, kptsi_wh_long[:, 1],
kptsi_wh_long[:,
0]].permute(1,
0).reshape(N_kpts, c, 1, 1))

offset = self.offset_conv(patch).clamp(
-max_offset, max_offset
) # [N_kpts, 2*n_pos, 1, 1]
-max_offset, max_offset) # [N_kpts, 2*n_pos, 1, 1]
if self.mask:
offset = (
offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 3]
offset = (offset[:, :, 0, 0].view(N_kpts, 3,
self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 3]
offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
mask_weight = torch.sigmoid(offset[:, :,
-1]) # [N_kpts, n_pos]
else:
offset = (
offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 2]
offset = (offset[:, :, 0, 0].view(N_kpts, 2,
self.n_pos).permute(0, 2, 1)
) # [N_kpts, n_pos, 2]
offsets.append(offset) # for visualization

# get sample positions
@@ -580,26 +591,23 @@ class SDDH(nn.Module):

# sample features
features = F.grid_sample(
xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
) # [1,C,(N_kpts*n_pos),1]
features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
1, 0, 2, 3
) # [N_kpts, C, n_pos, 1]
xi.unsqueeze(0), pos, mode='bilinear',
align_corners=True) # [1,C,(N_kpts*n_pos),1]
features = features.reshape(c, N_kpts, self.n_pos,
1).permute(1, 0, 2,
3) # [N_kpts, C, n_pos, 1]
if self.mask:
features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
features = torch.einsum('ncpo,np->ncpo', features, mask_weight)

features = torch.selu_(self.sf_conv(features)).squeeze(
-1
) # [N_kpts, C, n_pos]
-1) # [N_kpts, C, n_pos]
# convM
if not self.conv2D:
descs = torch.einsum(
"ncp,pcd->nd", features, self.agg_weights
) # [N_kpts, C]
descs = torch.einsum('ncp,pcd->nd', features,
self.agg_weights) # [N_kpts, C]
else:
features = features.reshape(N_kpts, -1)[
:, :, None, None
] # [N_kpts, C*n_pos, 1, 1]
features = features.reshape(
N_kpts, -1)[:, :, None, None] # [N_kpts, C*n_pos, 1, 1]
descs = self.convM(features).squeeze() # [N_kpts, C]

# normalize
@@ -611,34 +619,34 @@ class SDDH(nn.Module):

class ALIKED(Extractor):
default_conf = {
"model_name": "aliked-n16",
"max_num_keypoints": -1,
"detection_threshold": 0.2,
"nms_radius": 2,
'model_name': 'aliked-n16',
'max_num_keypoints': -1,
'detection_threshold': 0.2,
'nms_radius': 2,
}

checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
checkpoint_url = 'https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth'

n_limit_max = 20000

# c1, c2, c3, c4, dim, K, M
cfgs = {
"aliked-t16": [8, 16, 32, 64, 64, 3, 16],
"aliked-n16": [16, 32, 64, 128, 128, 3, 16],
"aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
"aliked-n32": [16, 32, 64, 128, 128, 3, 32],
'aliked-t16': [8, 16, 32, 64, 64, 3, 16],
'aliked-n16': [16, 32, 64, 128, 128, 3, 16],
'aliked-n16rot': [16, 32, 64, 128, 128, 3, 16],
'aliked-n32': [16, 32, 64, 128, 128, 3, 32],
}
preprocess_conf = {
"resize": 1024,
'resize': 1024,
}

required_data_keys = ["image"]
required_data_keys = ['image']

def __init__(self, **conf):
super().__init__(**conf) # Update with default configuration.
conf = self.conf
c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
conv_types = ["conv", "conv", "dcn", "dcn"]
conv_types = ['conv', 'conv', 'dcn', 'dcn']
conv2D = False
mask = False

@@ -647,7 +655,8 @@ class ALIKED(Extractor):
self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
self.norm = nn.BatchNorm2d
self.gate = nn.SELU(inplace=True)
self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
self.block1 = ConvBlock(
3, c1, self.gate, self.norm, conv_type=conv_types[0])
self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
@@ -657,17 +666,13 @@ class ALIKED(Extractor):
self.conv3 = resnet.conv1x1(c3, dim // 4)
self.conv4 = resnet.conv1x1(dim, dim // 4)
self.upsample2 = nn.Upsample(
scale_factor=2, mode="bilinear", align_corners=True
)
scale_factor=2, mode='bilinear', align_corners=True)
self.upsample4 = nn.Upsample(
scale_factor=4, mode="bilinear", align_corners=True
)
scale_factor=4, mode='bilinear', align_corners=True)
self.upsample8 = nn.Upsample(
scale_factor=8, mode="bilinear", align_corners=True
)
scale_factor=8, mode='bilinear', align_corners=True)
self.upsample32 = nn.Upsample(
scale_factor=32, mode="bilinear", align_corners=True
)
scale_factor=32, mode='bilinear', align_corners=True)
self.score_head = nn.Sequential(
resnet.conv1x1(dim, 8),
self.gate,
@@ -677,19 +682,19 @@ class ALIKED(Extractor):
self.gate,
resnet.conv3x3(4, 1),
)
self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
self.desc_head = SDDH(
dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
self.dkd = DKD(
radius=conf.nms_radius,
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
top_k=-1
if conf.detection_threshold > 0 else conf.max_num_keypoints,
scores_th=conf.detection_threshold,
n_limit=conf.max_num_keypoints
if conf.max_num_keypoints > 0
else self.n_limit_max,
if conf.max_num_keypoints > 0 else self.n_limit_max,
)

state_dict = torch.hub.load_state_dict_from_url(
self.checkpoint_url.format(conf.model_name), map_location="cpu"
)
self.checkpoint_url.format(conf.model_name), map_location='cpu')
self.load_state_dict(state_dict, strict=True)

def get_resblock(self, c_in, c_out, conv_type, mask):
@@ -738,13 +743,12 @@ class ALIKED(Extractor):
return feature_map, score_map

def forward(self, data: dict) -> dict:
image = data["image"]
image = data['image']
if image.shape[1] == 1:
image = grayscale_to_rgb(image)
feature_map, score_map = self.extract_dense_map(image)
keypoints, kptscores, scoredispersitys = self.dkd(
score_map, image_size=data.get("image_size")
)
score_map, image_size=data.get('image_size'))
descriptors, offsets = self.desc_head(feature_map, keypoints)

_, _, h, w = image.shape
@@ -752,7 +756,7 @@ class ALIKED(Extractor):
# no padding required
# we can set detection_threshold=-1 and conf.max_num_keypoints > 0
return {
"keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
"descriptors": torch.stack(descriptors), # B x N x D
"keypoint_scores": torch.stack(kptscores), # B x N
'keypoints': wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
'descriptors': torch.stack(descriptors), # B x N x D
'keypoint_scores': torch.stack(kptscores), # B x N
}

@@ -6,20 +6,20 @@ from .utils import Extractor

class DISK(Extractor):
default_conf = {
"weights": "depth",
"max_num_keypoints": None,
"desc_dim": 128,
"nms_window_size": 5,
"detection_threshold": 0.0,
"pad_if_not_divisible": True,
'weights': 'depth',
'max_num_keypoints': None,
'desc_dim': 128,
'nms_window_size': 5,
'detection_threshold': 0.0,
'pad_if_not_divisible': True,
}

preprocess_conf = {
"resize": 1024,
"grayscale": False,
'resize': 1024,
'grayscale': False,
}

required_data_keys = ["image"]
required_data_keys = ['image']

def __init__(self, **conf) -> None:
super().__init__(**conf) # Update with default configuration.
@@ -28,8 +28,8 @@ class DISK(Extractor):
def forward(self, data: dict) -> dict:
"""Compute keypoints, scores, descriptors for image"""
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
image = data["image"]
assert key in data, f'Missing key {key} in data'
image = data['image']
if image.shape[1] == 1:
image = kornia.color.grayscale_to_rgb(image)
features = self.model(
@@ -49,7 +49,7 @@ class DISK(Extractor):
descriptors = torch.stack(descriptors, 0)

return {
"keypoints": keypoints.to(image).contiguous(),
"keypoint_scores": scores.to(image).contiguous(),
"descriptors": descriptors.to(image).contiguous(),
'keypoints': keypoints.to(image).contiguous(),
'keypoint_scores': scores.to(image).contiguous(),
'descriptors': descriptors.to(image).contiguous(),
}

@@ -1,3 +1,4 @@
import os.path as osp
import warnings
from pathlib import Path
from types import SimpleNamespace
@@ -8,13 +9,12 @@ import torch
import torch.nn.functional as F
from torch import nn

import os.path as osp
try:
from flash_attn.modules.mha import FlashCrossAttention
except ModuleNotFoundError:
FlashCrossAttention = None

if FlashCrossAttention or hasattr(F, "scaled_dot_product_attention"):
if FlashCrossAttention or hasattr(F, 'scaled_dot_product_attention'):
FLASH_AVAILABLE = True
else:
FLASH_AVAILABLE = False
@@ -23,9 +23,8 @@ torch.backends.cudnn.deterministic = True


@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
def normalize_keypoints(kpts: torch.Tensor,
size: Optional[torch.Tensor] = None) -> torch.Tensor:
if size is None:
size = 1 + kpts.max(-2).values - kpts.min(-2).values
elif not isinstance(size, torch.Tensor):
@@ -41,11 +40,14 @@ def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
if length <= x.shape[-2]:
return x, torch.ones_like(x[..., :1], dtype=torch.bool)
pad = torch.ones(
*x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
)
*x.shape[:-2],
length - x.shape[-2],
x.shape[-1],
device=x.device,
dtype=x.dtype)
y = torch.cat([x, pad], dim=-2)
mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
mask[..., : x.shape[-2], :] = True
mask[..., :x.shape[-2], :] = True
return y, mask


@@ -55,12 +57,18 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor:
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)


def apply_cached_rotary_emb(freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
def apply_cached_rotary_emb(freqs: torch.Tensor,
t: torch.Tensor) -> torch.Tensor:
return (t * freqs[0]) + (rotate_half(t) * freqs[1])


class LearnableFourierPositionalEncoding(nn.Module):
def __init__(self, M: int, dim: int, F_dim: int = None, gamma: float = 1.0) -> None:

def __init__(self,
M: int,
dim: int,
F_dim: int = None,
gamma: float = 1.0) -> None:
super().__init__()
F_dim = F_dim if F_dim is not None else dim
self.gamma = gamma
@@ -76,6 +84,7 @@ class LearnableFourierPositionalEncoding(nn.Module):


class TokenConfidence(nn.Module):

def __init__(self, dim: int) -> None:
super().__init__()
self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())
@@ -89,27 +98,33 @@ class TokenConfidence(nn.Module):


class Attention(nn.Module):

def __init__(self, allow_flash: bool) -> None:
super().__init__()
if allow_flash and not FLASH_AVAILABLE:
warnings.warn(
"FlashAttention is not available. For optimal speed, "
"consider installing torch >= 2.0 or flash-attn.",
'FlashAttention is not available. For optimal speed, '
'consider installing torch >= 2.0 or flash-attn.',
stacklevel=2,
)
self.enable_flash = allow_flash and FLASH_AVAILABLE
self.has_sdp = hasattr(F, "scaled_dot_product_attention")
self.has_sdp = hasattr(F, 'scaled_dot_product_attention')
if allow_flash and FlashCrossAttention:
self.flash_ = FlashCrossAttention()
if self.has_sdp:
torch.backends.cuda.enable_flash_sdp(allow_flash)

def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.enable_flash and q.device.type == "cuda":
def forward(self,
q,
k,
v,
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
if self.enable_flash and q.device.type == 'cuda':
# use torch 2.0 scaled_dot_product_attention with flash
if self.has_sdp:
args = [x.half().contiguous() for x in [q, k, v]]
v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
v = F.scaled_dot_product_attention(
*args, attn_mask=mask).to(q.dtype)
return v if mask is None else v.nan_to_num()
else:
assert mask is None
@@ -121,18 +136,21 @@ class Attention(nn.Module):
v = F.scaled_dot_product_attention(*args, attn_mask=mask)
return v if mask is None else v.nan_to_num()
else:
s = q.shape[-1] ** -0.5
sim = torch.einsum("...id,...jd->...ij", q, k) * s
s = q.shape[-1]**-0.5
sim = torch.einsum('...id,...jd->...ij', q, k) * s
if mask is not None:
sim.masked_fill(~mask, -float("inf"))
sim.masked_fill(~mask, -float('inf'))
attn = F.softmax(sim, -1)
return torch.einsum("...ij,...jd->...id", attn, v)
return torch.einsum('...ij,...jd->...id', attn, v)


class SelfBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
) -> None:

def __init__(self,
embed_dim: int,
num_heads: int,
flash: bool = False,
bias: bool = True) -> None:
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
@@ -165,9 +183,12 @@ class SelfBlock(nn.Module):


class CrossBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
) -> None:

def __init__(self,
embed_dim: int,
num_heads: int,
flash: bool = False,
bias: bool = True) -> None:
super().__init__()
self.heads = num_heads
dim_head = embed_dim // num_heads
@@ -190,32 +211,35 @@ class CrossBlock(nn.Module):
def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
return func(x0), func(x1)

def forward(
self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
) -> List[torch.Tensor]:
def forward(self,
x0: torch.Tensor,
x1: torch.Tensor,
mask: Optional[torch.Tensor] = None) -> List[torch.Tensor]:
qk0, qk1 = self.map_(self.to_qk, x0, x1)
v0, v1 = self.map_(self.to_v, x0, x1)
qk0, qk1, v0, v1 = map(
lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
(qk0, qk1, v0, v1),
)
if self.flash is not None and qk0.device.type == "cuda":
if self.flash is not None and qk0.device.type == 'cuda':
m0 = self.flash(qk0, qk1, v1, mask)
m1 = self.flash(
qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
)
qk1, qk0, v0,
mask.transpose(-1, -2) if mask is not None else None)
else:
qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
sim = torch.einsum('bhid, bhjd -> bhij', qk0, qk1)
if mask is not None:
sim = sim.masked_fill(~mask, -float("inf"))
sim = sim.masked_fill(~mask, -float('inf'))
attn01 = F.softmax(sim, dim=-1)
attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
m0 = torch.einsum('bhij, bhjd -> bhid', attn01, v1)
m1 = torch.einsum('bhji, bhjd -> bhid', attn10.transpose(-2, -1),
v0)
if mask is not None:
m0, m1 = m0.nan_to_num(), m1.nan_to_num()
m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2),
m0, m1)
m0, m1 = self.map_(self.to_out, m0, m1)
x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
@@ -223,6 +247,7 @@ class CrossBlock(nn.Module):

class TransformerLayer(nn.Module):

def __init__(self, *args, **kwargs):
super().__init__()
self.self_attn = SelfBlock(*args, **kwargs)
@@ -238,7 +263,8 @@ class TransformerLayer(nn.Module):
mask1: Optional[torch.Tensor] = None,
):
if mask0 is not None and mask1 is not None:
return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
return self.masked_forward(desc0, desc1, encoding0, encoding1,
mask0, mask1)
else:
desc0 = self.self_attn(desc0, encoding0)
desc1 = self.self_attn(desc1, encoding1)
@@ -254,14 +280,14 @@ class TransformerLayer(nn.Module):
return self.cross_attn(desc0, desc1, mask)


def sigmoid_log_double_softmax(
sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
def sigmoid_log_double_softmax(sim: torch.Tensor, z0: torch.Tensor,
z1: torch.Tensor) -> torch.Tensor:
"""create the log assignment matrix from logits and similarity"""
b, m, n = sim.shape
certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
scores0 = F.log_softmax(sim, 2)
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
scores1 = F.log_softmax(sim.transpose(-1, -2).contiguous(),
2).transpose(-1, -2)
scores = sim.new_full((b, m + 1, n + 1), 0)
scores[:, :m, :n] = scores0 + scores1 + certainties
scores[:, :-1, -1] = F.logsigmoid(-z0.squeeze(-1))
@@ -270,6 +296,7 @@ def sigmoid_log_double_softmax(


class MatchAssignment(nn.Module):

def __init__(self, dim: int) -> None:
super().__init__()
self.dim = dim
@@ -281,7 +308,7 @@ class MatchAssignment(nn.Module):
mdesc0, mdesc1 = self.final_proj(desc0), self.final_proj(desc1)
_, _, d = mdesc0.shape
mdesc0, mdesc1 = mdesc0 / d**0.25, mdesc1 / d**0.25
sim = torch.einsum("bmd,bnd->bmn", mdesc0, mdesc1)
sim = torch.einsum('bmd,bnd->bmn', mdesc0, mdesc1)
z0 = self.matchability(desc0)
z1 = self.matchability(desc1)
scores = sigmoid_log_double_softmax(sim, z0, z1)
@@ -315,34 +342,34 @@ class LightGlue(nn.Module):
# Point pruning involves an overhead (gather).
# Therefore, we only activate it if there are enough keypoints.
pruning_keypoint_thresholds = {
"cpu": -1,
"mps": -1,
"cuda": 1024,
"flash": 1536,
'cpu': -1,
'mps': -1,
'cuda': 1024,
'flash': 1536,
}

required_data_keys = ["image0", "image1"]
required_data_keys = ['image0', 'image1']

version = "v0.1_arxiv"
weight_path = "{}_lightglue.pth"
version = 'v0.1_arxiv'
weight_path = '{}_lightglue.pth'

features = {
"superpoint": {
"weights": "superpoint_lightglue",
"input_dim": 256,
'superpoint': {
'weights': 'superpoint_lightglue',
'input_dim': 256,
},
"disk": {
"weights": "disk_lightglue",
"input_dim": 128,
'disk': {
'weights': 'disk_lightglue',
'input_dim': 128,
},
"aliked": {
"weights": "aliked_lightglue",
"input_dim": 128,
'aliked': {
'weights': 'aliked_lightglue',
'input_dim': 128,
},
"sift": {
"weights": "sift_lightglue",
"input_dim": 128,
"add_scale_ori": True,
'sift': {
'weights': 'sift_lightglue',
'input_dim': 128,
'add_scale_ori': True,
},
}

@@ -352,77 +379,78 @@ class LightGlue(nn.Module):
if conf.features is not None:
if conf.features not in self.features:
raise ValueError(
f"Unsupported features: {conf.features} not in "
f"{{{','.join(self.features)}}}"
)
f'Unsupported features: {conf.features} not in '
f"{{{','.join(self.features)}}}")
for k, v in self.features[conf.features].items():
setattr(conf, k, v)

if conf.input_dim != conf.descriptor_dim:
self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
self.input_proj = nn.Linear(
conf.input_dim, conf.descriptor_dim, bias=True)
else:
self.input_proj = nn.Identity()

head_dim = conf.descriptor_dim // conf.num_heads
self.posenc = LearnableFourierPositionalEncoding(
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
)
2 + 2 * self.conf.add_scale_ori, head_dim, head_dim)

h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

self.transformers = nn.ModuleList(
[TransformerLayer(d, h, conf.flash) for _ in range(n)]
)
[TransformerLayer(d, h, conf.flash) for _ in range(n)])

self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
self.log_assignment = nn.ModuleList(
[MatchAssignment(d) for _ in range(n)])
self.token_confidence = nn.ModuleList(
[TokenConfidence(d) for _ in range(n - 1)]
)
[TokenConfidence(d) for _ in range(n - 1)])
self.register_buffer(
"confidence_thresholds",
torch.Tensor(
[self.confidence_threshold(i) for i in range(self.conf.n_layers)]
),
'confidence_thresholds',
torch.Tensor([
self.confidence_threshold(i) for i in range(self.conf.n_layers)
]),
)

state_dict = None
if conf.features is not None:
fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
state_dict = torch.load(
osp.join(model_dir,
self.weight_path.format(conf.features)), map_location="cpu"
)
osp.join(model_dir, self.weight_path.format(conf.features)),
map_location='cpu')
self.load_state_dict(state_dict, strict=False)
elif conf.weights is not None:
path = Path(__file__).parent
path = path / "weights/{}.pth".format(self.conf.weights)
state_dict = torch.load(str(path), map_location="cpu")
path = path / 'weights/{}.pth'.format(self.conf.weights)
state_dict = torch.load(str(path), map_location='cpu')

if state_dict:
# rename old state dict entries
for i in range(self.conf.n_layers):
pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
pattern = f'self_attn.{i}', f'transformers.{i}.self_attn'
state_dict = {
k.replace(*pattern): v
for k, v in state_dict.items()
}
pattern = f'cross_attn.{i}', f'transformers.{i}.cross_attn'
state_dict = {
k.replace(*pattern): v
for k, v in state_dict.items()
}
self.load_state_dict(state_dict, strict=False)

# static lengths LightGlue is compiled for (only used with torch.compile)
self.static_lengths = None

def compile(
self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
):
def compile(self,
mode='reduce-overhead',
static_lengths=[256, 512, 768, 1024, 1280, 1536]):
if self.conf.width_confidence != -1:
warnings.warn(
"Point pruning is partially disabled for compiled forward.",
'Point pruning is partially disabled for compiled forward.',
stacklevel=2,
)

for i in range(self.conf.n_layers):
self.transformers[i].masked_forward = torch.compile(
self.transformers[i].masked_forward, mode=mode, fullgraph=True
)
self.transformers[i].masked_forward, mode=mode, fullgraph=True)

self.static_lengths = static_lengths

@@ -447,30 +475,30 @@ class LightGlue(nn.Module):
matching_scores1: [B x N]
matches: List[[Si x 2]], scores: List[[Si]]
"""
with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
with torch.autocast(enabled=self.conf.mp, device_type='cuda'):
return self._forward(data)

def _forward(self, data: dict) -> dict:
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
data0, data1 = data["image0"], data["image1"]
kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
assert key in data, f'Missing key {key} in data'
data0, data1 = data['image0'], data['image1']
kpts0, kpts1 = data0['keypoints'], data1['keypoints']
b, m, _ = kpts0.shape
b, n, _ = kpts1.shape
device = kpts0.device
size0, size1 = data0.get("image_size"), data1.get("image_size")
size0, size1 = data0.get('image_size'), data1.get('image_size')
kpts0 = normalize_keypoints(kpts0, size0).clone()
kpts1 = normalize_keypoints(kpts1, size1).clone()

if self.conf.add_scale_ori:
kpts0 = torch.cat(
[kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
)
[kpts0] + [data0[k].unsqueeze(-1) for k in ('scales', 'oris')],
-1)
kpts1 = torch.cat(
[kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
)
desc0 = data0["descriptors"].detach().contiguous()
desc1 = data1["descriptors"].detach().contiguous()
[kpts1] + [data1[k].unsqueeze(-1) for k in ('scales', 'oris')],
-1)
desc0 = data0['descriptors'].detach().contiguous()
desc1 = data1['descriptors'].detach().contiguous()

assert desc0.shape[-1] == self.conf.input_dim
assert desc1.shape[-1] == self.conf.input_dim
@@ -507,14 +535,14 @@ class LightGlue(nn.Module):
token0, token1 = None, None
for i in range(self.conf.n_layers):
desc0, desc1 = self.transformers[i](
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
)
desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1)
if i == self.conf.n_layers - 1:
continue # no early stopping or adaptive width at last layer

if do_early_stop:
token0, token1 = self.token_confidence[i](desc0, desc1)
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :], i, m + n):
if self.check_if_stop(token0[..., :m, :], token1[..., :n, :],
i, m + n):
break
if do_point_pruning and desc0.shape[-2] > pruning_th:
scores0 = self.log_assignment[i].get_matchability(desc0)
@@ -535,7 +563,8 @@ class LightGlue(nn.Module):

desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]
scores, _ = self.log_assignment[i](desc0, desc1)
m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
m0, m1, mscores0, mscores1 = filter_matches(scores,
self.conf.filter_threshold)
matches, mscores = [], []
for k in range(b):
valid = m0[k] > -1
@@ -551,8 +580,10 @@ class LightGlue(nn.Module):
if do_point_pruning:
m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
m0_[:, ind0] = torch.where(m0 == -1, -1,
ind1.gather(1, m0.clamp(min=0)))
m1_[:, ind1] = torch.where(m1 == -1, -1,
ind0.gather(1, m1.clamp(min=0)))
mscores0_ = torch.zeros((b, m), device=mscores0.device)
mscores1_ = torch.zeros((b, n), device=mscores1.device)
mscores0_[:, ind0] = mscores0
@@ -563,15 +594,15 @@ class LightGlue(nn.Module):
prune1 = torch.ones_like(mscores1) * self.conf.n_layers

pred = {
"matches0": m0,
"matches1": m1,
"matching_scores0": mscores0,
"matching_scores1": mscores1,
"stop": i + 1,
"matches": matches,
"scores": mscores,
"prune0": prune0,
"prune1": prune1,
'matches0': m0,
'matches1': m1,
'matching_scores0': mscores0,
'matching_scores1': mscores1,
'stop': i + 1,
'matches': matches,
'scores': mscores,
'prune0': prune0,
'prune1': prune1,
}

return pred
@@ -581,9 +612,8 @@ class LightGlue(nn.Module):
threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
return np.clip(threshold, 0, 1)

def get_pruning_mask(
self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
) -> torch.Tensor:
def get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor,
layer_index: int) -> torch.Tensor:
"""mask points which should be removed"""
keep = scores > (1 - self.conf.width_confidence)
if confidences is not None: # Low-confidence points are never pruned.
@@ -600,11 +630,12 @@ class LightGlue(nn.Module):
"""evaluate stopping condition"""
confidences = torch.cat([confidences0, confidences1], -1)
threshold = self.confidence_thresholds[layer_index]
ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
ratio_confident = 1.0 - (
confidences < threshold).float().sum() / num_points # noqa E501
return ratio_confident > self.conf.depth_confidence

def pruning_min_kpts(self, device: torch.device):
if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
return self.pruning_keypoint_thresholds["flash"]
if self.conf.flash and FLASH_AVAILABLE and device.type == 'cuda':
return self.pruning_keypoint_thresholds['flash']
else:
return self.pruning_keypoint_thresholds[device.type]

@@ -6,15 +6,20 @@ import torch
from kornia.color import rgb_to_grayscale
from packaging import version

from .utils import Extractor

try:
import pycolmap
except ImportError:
pycolmap = None

from .utils import Extractor


def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
def filter_dog_point(points,
scales,
angles,
image_shape,
nms_radius,
scores=None):
h, w = image_shape
ij = np.round(points - 0.5).astype(int).T[::-1]

@@ -72,59 +77,59 @@ def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
points = np.array([k.pt for k in detections], dtype=np.float32)
scores = np.array([k.response for k in detections], dtype=np.float32)
scales = np.array([k.size for k in detections], dtype=np.float32)
angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
angles = np.deg2rad(
np.array([k.angle for k in detections], dtype=np.float32))
return points, scores, scales, angles, descriptors


class SIFT(Extractor):
default_conf = {
"rootsift": True,
"nms_radius": 0, # None to disable filtering entirely.
"max_num_keypoints": 4096,
"backend": "opencv", # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
"detection_threshold": 0.0066667, # from COLMAP
"edge_threshold": 10,
"first_octave": -1, # only used by pycolmap, the default of COLMAP
"num_octaves": 4,
'rootsift': True,
'nms_radius': 0, # None to disable filtering entirely.
'max_num_keypoints': 4096,
'backend':
'opencv', # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
'detection_threshold': 0.0066667, # from COLMAP
'edge_threshold': 10,
'first_octave': -1, # only used by pycolmap, the default of COLMAP
'num_octaves': 4,
}

preprocess_conf = {
"resize": 1024,
'resize': 1024,
}

required_data_keys = ["image"]
required_data_keys = ['image']

def __init__(self, **conf):
super().__init__(**conf) # Update with default configuration.
backend = self.conf.backend
if backend.startswith("pycolmap"):
if backend.startswith('pycolmap'):
if pycolmap is None:
raise ImportError(
"Cannot find module pycolmap: install it with pip"
"or use backend=opencv."
)
'Cannot find module pycolmap: install it with pip'
'or use backend=opencv.')
options = {
"peak_threshold": self.conf.detection_threshold,
"edge_threshold": self.conf.edge_threshold,
"first_octave": self.conf.first_octave,
"num_octaves": self.conf.num_octaves,
"normalization": pycolmap.Normalization.L2, # L1_ROOT is buggy.
'peak_threshold': self.conf.detection_threshold,
'edge_threshold': self.conf.edge_threshold,
'first_octave': self.conf.first_octave,
'num_octaves': self.conf.num_octaves,
'normalization':
pycolmap.Normalization.L2, # L1_ROOT is buggy.
}
device = (
"auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
)
if (
backend == "pycolmap_cpu" or not pycolmap.has_cuda
) and pycolmap.__version__ < "0.5.0":
device = ('auto' if backend == 'pycolmap' else backend.replace(
'pycolmap_', ''))
if (backend == 'pycolmap_cpu' or not pycolmap.has_cuda
) and pycolmap.__version__ < '0.5.0': # noqa E501
warnings.warn(
"The pycolmap CPU SIFT is buggy in version < 0.5.0, "
"consider upgrading pycolmap or use the CUDA version.",
'The pycolmap CPU SIFT is buggy in version < 0.5.0, '
'consider upgrading pycolmap or use the CUDA version.',
stacklevel=1,
)
else:
options["max_num_features"] = self.conf.max_num_keypoints
options['max_num_features'] = self.conf.max_num_keypoints
self.sift = pycolmap.Sift(options=options, device=device)
elif backend == "opencv":
elif backend == 'opencv':
self.sift = cv2.SIFT_create(
contrastThreshold=self.conf.detection_threshold,
nfeatures=self.conf.max_num_keypoints,
@@ -132,56 +137,52 @@ class SIFT(Extractor):
nOctaveLayers=self.conf.num_octaves,
)
else:
backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
raise ValueError(
f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
)
backends = {'opencv', 'pycolmap', 'pycolmap_cpu', 'pycolmap_cuda'}
raise ValueError(f'Unknown backend: {backend} not in '
f"{{{','.join(backends)}}}.")

def extract_single_image(self, image: torch.Tensor):
image_np = image.cpu().numpy().squeeze(0)

if self.conf.backend.startswith("pycolmap"):
if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
if self.conf.backend.startswith('pycolmap'):
if version.parse(pycolmap.__version__) >= version.parse('0.5.0'):
detections, descriptors = self.sift.extract(image_np)
scores = None # Scores are not exposed by COLMAP anymore.
else:
detections, scores, descriptors = self.sift.extract(image_np)
keypoints = detections[:, :2] # Keep only (x, y).
scales, angles = detections[:, -2:].T
if scores is not None and (
self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
):
if scores is not None and (self.conf.backend == 'pycolmap_cpu'
or not pycolmap.has_cuda):
# Set the scores as a combination of abs. response and scale.
scores = np.abs(scores) * scales
elif self.conf.backend == "opencv":
elif self.conf.backend == 'opencv':
# TODO: Check if opencv keypoints are already in corner convention
keypoints, scores, scales, angles, descriptors = run_opencv_sift(
self.sift, (image_np * 255.0).astype(np.uint8)
)
self.sift, (image_np * 255.0).astype(np.uint8))
pred = {
"keypoints": keypoints,
"scales": scales,
"oris": angles,
"descriptors": descriptors,
'keypoints': keypoints,
'scales': scales,
'oris': angles,
'descriptors': descriptors,
}
if scores is not None:
pred["keypoint_scores"] = scores
pred['keypoint_scores'] = scores

# sometimes pycolmap returns points outside the image. We remove them
if self.conf.backend.startswith("pycolmap"):
is_inside = (
pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
).all(-1)
if self.conf.backend.startswith('pycolmap'):
is_inside = (pred['keypoints'] + 0.5 < np.array(
[image_np.shape[-2:][::-1]])).all(-1)
pred = {k: v[is_inside] for k, v in pred.items()}

if self.conf.nms_radius is not None:
keep = filter_dog_point(
pred["keypoints"],
pred["scales"],
pred["oris"],
pred['keypoints'],
pred['scales'],
pred['oris'],
image_np.shape,
self.conf.nms_radius,
scores=pred.get("keypoint_scores"),
scores=pred.get('keypoint_scores'),
)
pred = {k: v[keep] for k, v in pred.items()}

@@ -189,14 +190,15 @@ class SIFT(Extractor):
if scores is not None:
# Keep the k keypoints with highest score
num_points = self.conf.max_num_keypoints
if num_points is not None and len(pred["keypoints"]) > num_points:
indices = torch.topk(pred["keypoint_scores"], num_points).indices
if num_points is not None and len(pred['keypoints']) > num_points:
indices = torch.topk(pred['keypoint_scores'],
num_points).indices
pred = {k: v[indices] for k, v in pred.items()}

return pred

def forward(self, data: dict) -> dict:
image = data["image"]
image = data['image']
if image.shape[1] == 3:
image = rgb_to_grayscale(image)
device = image.device
@@ -204,13 +206,16 @@ class SIFT(Extractor):
pred = []
for k in range(len(image)):
img = image[k]
if "image_size" in data.keys():
if 'image_size' in data.keys():
# avoid extracting points in padded areas
w, h = data["image_size"][k]
w, h = data['image_size'][k]
img = img[:, :h, :w]
p = self.extract_single_image(img)
pred.append(p)
pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
pred = {
k: torch.stack([p[k] for p in pred], 0).to(device)
for k in pred[0]
}
if self.conf.rootsift:
pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
pred['descriptors'] = sift_to_rootsift(pred['descriptors'])
return pred

@@ -42,12 +42,13 @@

# Adapted by Remi Pautrat, Philipp Lindenberger

import os.path as osp

import torch
from kornia.color import rgb_to_grayscale
from torch import nn

from .utils import Extractor
import os.path as osp


def simple_nms(scores, nms_radius: int):
@@ -56,8 +57,7 @@ def simple_nms(scores, nms_radius: int):

def max_pool(x):
return torch.nn.functional.max_pool2d(
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
)
x, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius)

zeros = torch.zeros_like(scores)
max_mask = scores == max_pool(scores)
@@ -80,19 +80,14 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
"""Interpolate descriptors at keypoint locations"""
b, c, h, w = descriptors.shape
keypoints = keypoints - s / 2 + 0.5
keypoints /= torch.tensor(
[(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
).to(
keypoints
)[None]
keypoints /= torch.tensor([(w * s - s / 2 - 0.5),
(h * s - s / 2 - 0.5)], ).to(keypoints)[None]
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
args = {'align_corners': True} if torch.__version__ >= '1.3' else {}
descriptors = torch.nn.functional.grid_sample(
descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
)
descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
descriptors = torch.nn.functional.normalize(
descriptors.reshape(b, c, -1), p=2, dim=1
)
descriptors.reshape(b, c, -1), p=2, dim=1)
return descriptors


@@ -106,20 +101,20 @@ class SuperPoint(Extractor):
"""

default_conf = {
"descriptor_dim": 256,
"nms_radius": 4,
"max_num_keypoints": None,
"detection_threshold": 0.0005,
"remove_borders": 4,
'descriptor_dim': 256,
'nms_radius': 4,
'max_num_keypoints': None,
'detection_threshold': 0.0005,
'remove_borders': 4,
}

preprocess_conf = {
"resize": 1024,
'resize': 1024,
}

required_data_keys = ["image"]
required_data_keys = ['image']

def __init__(self,model_dir, **conf):
def __init__(self, model_dir, **conf):
super().__init__(**conf) # Update with default configuration.
self.relu = nn.ReLU(inplace=True)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
@@ -139,21 +134,19 @@ class SuperPoint(Extractor):

self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
self.convDb = nn.Conv2d(
c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
)
c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0)


weights_path = osp.join(model_dir,"superpoint_v1.pth")
self.load_state_dict(torch.load(weights_path, map_location="cpu"))
weights_path = osp.join(model_dir, 'superpoint_v1.pth')
self.load_state_dict(torch.load(weights_path, map_location='cpu'))

if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
raise ValueError("max_num_keypoints must be positive or None")
raise ValueError('max_num_keypoints must be positive or None')

def forward(self, data: dict) -> dict:
"""Compute keypoints, scores, descriptors for image"""
for key in self.required_data_keys:
assert key in data, f"Missing key {key} in data"
image = data["image"]
assert key in data, f'Missing key {key} in data'
image = data['image']
if image.shape[1] == 3:
image = rgb_to_grayscale(image)

@@ -193,20 +186,18 @@ class SuperPoint(Extractor):

# Separate into batches
keypoints = [
torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i] for i in range(b)
torch.stack(best_kp[1:3], dim=-1)[best_kp[0] == i]
for i in range(b)
]
scores = [scores[best_kp[0] == i] for i in range(b)]

# Keep the k keypoints with highest score
if self.conf.max_num_keypoints is not None:
keypoints, scores = list(
zip(
*[
top_k_keypoints(k, s, self.conf.max_num_keypoints)
for k, s in zip(keypoints, scores)
]
)
)
zip(*[
top_k_keypoints(k, s, self.conf.max_num_keypoints)
for k, s in zip(keypoints, scores)
]))

# Convert (h, w) to (x, y)
keypoints = [torch.flip(k, [1]).float() for k in keypoints]
@@ -223,7 +214,10 @@ class SuperPoint(Extractor):
]

return {
"keypoints": torch.stack(keypoints, 0),
"keypoint_scores": torch.stack(scores, 0),
"descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
'keypoints':
torch.stack(keypoints, 0),
'keypoint_scores':
torch.stack(scores, 0),
'descriptors':
torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
}

@@ -11,11 +11,11 @@ import torch
|
||||
|
||||
class ImagePreprocessor:
|
||||
default_conf = {
|
||||
"resize": None, # target edge length, None for no resizing
|
||||
"side": "long",
|
||||
"interpolation": "bilinear",
|
||||
"align_corners": None,
|
||||
"antialias": True,
|
||||
'resize': None, # target edge length, None for no resizing
|
||||
'side': 'long',
|
||||
'interpolation': 'bilinear',
|
||||
'align_corners': None,
|
||||
'antialias': True,
|
||||
}
|
||||
|
||||
def __init__(self, **conf) -> None:
|
||||
@@ -52,7 +52,9 @@ def map_tensor(input_, func: Callable):
|
||||
return input_
|
||||
|
||||
|
||||
def batch_to_device(batch: dict, device: str = "cpu", non_blocking: bool = True):
|
||||
def batch_to_device(batch: dict,
|
||||
device: str = 'cpu',
|
||||
non_blocking: bool = True):
|
||||
"""Move batch (dict) to device"""
|
||||
|
||||
def _func(tensor):
|
||||
@@ -72,11 +74,11 @@ def rbd(data: dict) -> dict:
|
||||
def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
|
||||
"""Read an image from path as RGB or grayscale"""
|
||||
if not Path(path).exists():
|
||||
raise FileNotFoundError(f"No image at path {path}.")
|
||||
raise FileNotFoundError(f'No image at path {path}.')
|
||||
mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
|
||||
image = cv2.imread(str(path), mode)
|
||||
if image is None:
|
||||
raise IOError(f"Could not read image at {path}.")
|
||||
raise IOError(f'Could not read image at {path}.')
|
||||
if not grayscale:
|
||||
image = image[..., ::-1]
|
||||
return image
|
||||
@@ -89,20 +91,20 @@ def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
|
||||
elif image.ndim == 2:
|
||||
image = image[None] # add channel axis
|
||||
else:
|
||||
raise ValueError(f"Not an image: {image.shape}")
|
||||
raise ValueError(f'Not an image: {image.shape}')
|
||||
return torch.tensor(image / 255.0, dtype=torch.float)
|
||||
|
||||
|
||||
def resize_image(
|
||||
image: np.ndarray,
|
||||
size: Union[List[int], int],
|
||||
fn: str = "max",
|
||||
interp: Optional[str] = "area",
|
||||
fn: str = 'max',
|
||||
interp: Optional[str] = 'area',
|
||||
) -> np.ndarray:
|
||||
"""Resize an image to a fixed size, or according to max or min edge."""
|
||||
h, w = image.shape[:2]
|
||||
|
||||
fn = {"max": max, "min": min}[fn]
|
||||
fn = {'max': max, 'min': min}[fn]
|
||||
if isinstance(size, int):
|
||||
scale = size / fn(h, w)
|
||||
h_new, w_new = int(round(h * scale)), int(round(w * scale))
|
||||
@@ -111,12 +113,12 @@ def resize_image(
|
||||
h_new, w_new = size
|
||||
scale = (w_new / w, h_new / h)
|
||||
else:
|
||||
raise ValueError(f"Incorrect new size: {size}")
|
||||
raise ValueError(f'Incorrect new size: {size}')
|
||||
mode = {
|
||||
"linear": cv2.INTER_LINEAR,
|
||||
"cubic": cv2.INTER_CUBIC,
|
||||
"nearest": cv2.INTER_NEAREST,
|
||||
"area": cv2.INTER_AREA,
|
||||
'linear': cv2.INTER_LINEAR,
|
||||
'cubic': cv2.INTER_CUBIC,
|
||||
'nearest': cv2.INTER_NEAREST,
|
||||
'area': cv2.INTER_AREA,
|
||||
}[interp]
|
||||
return cv2.resize(image, (w_new, h_new), interpolation=mode), scale
|
||||
|
||||
@@ -129,6 +131,7 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:


class Extractor(torch.nn.Module):

def __init__(self, **conf):
super().__init__()
self.conf = SimpleNamespace(**{**self.default_conf, **conf})
@@ -140,10 +143,14 @@ class Extractor(torch.nn.Module):
img = img[None] # add batch dim
assert img.dim() == 4 and img.shape[0] == 1
shape = img.shape[-2:][::-1]
img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
feats = self.forward({"image": img})
feats["image_size"] = torch.tensor(shape)[None].to(img).float()
feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
img, scales = ImagePreprocessor(**{
**self.preprocess_conf,
**conf
})(
img)
feats = self.forward({'image': img})
feats['image_size'] = torch.tensor(shape)[None].to(img).float()
feats['keypoints'] = (feats['keypoints'] + 0.5) / scales[None] - 0.5
return feats

@@ -152,13 +159,13 @@ def match_pair(
matcher,
image0: torch.Tensor,
image1: torch.Tensor,
device: str = "cpu",
device: str = 'cpu',
**preprocess,
):
"""Match a pair of images (image0, image1) with an extractor and matcher"""
feats0 = extractor.extract(image0, **preprocess)
feats1 = extractor.extract(image1, **preprocess)
matches01 = matcher({"image0": feats0, "image1": feats1})
matches01 = matcher({'image0': feats0, 'image1': feats1})
data = [feats0, feats1, matches01]
# remove batch dim and move to target device
feats0, feats1, matches01 = [batch_to_device(rbd(x), device) for x in data]

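match_pair is the one-call convenience around the Extractor and matcher shown above. A minimal sketch of how it would be called, assuming an already constructed extractor and matcher, two image tensors, and a return value inferred from the de-batched dicts computed in the body:

# Illustrative call only; extractor, matcher, image0 and image1 are assumed to exist.
feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1, device='cpu')
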
@@ -22,10 +22,12 @@ def cm_RdGn(x):
def cm_BlRdGn(x_):
"""Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
x = np.clip(x_, 0, 1)[..., None] * 2
c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array(
[[1.0, 0, 0, 1.0]])

xn = -np.clip(x_, -1, 0)[..., None] * 2
cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array(
[[1.0, 0, 0, 1.0]])
out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
return out

@@ -39,7 +41,12 @@ def cm_prune(x_):
return cm_BlRdGn(norm_x)


def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
def plot_images(imgs,
titles=None,
cmaps='gray',
dpi=100,
pad=0.5,
adaptive=True):
"""Plot a set of images horizontally.
Args:
imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
@@ -49,9 +56,8 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
"""
# conversion to (H, W, 3) for torch.Tensor
imgs = [
img.permute(1, 2, 0).cpu().numpy()
if (isinstance(img, torch.Tensor) and img.dim() == 3)
else img
img.permute(1, 2, 0).cpu().numpy() if
(isinstance(img, torch.Tensor) and img.dim() == 3) else img
for img in imgs
]

@@ -65,8 +71,7 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
ratios = [4 / 3] * n
figsize = [sum(ratios) * 4.5, 4.5]
fig, ax = plt.subplots(
1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
)
1, n, figsize=figsize, dpi=dpi, gridspec_kw={'width_ratios': ratios})
if n == 1:
ax = [ax]
for i in range(n):
@@ -81,7 +86,7 @@ def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True
fig.tight_layout(pad=pad)


def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
def plot_keypoints(kpts, colors='lime', ps=4, axes=None, a=1.0):
"""Plot keypoints for existing images.
Args:
kpts: list of ndarrays of size (N, 2).
@@ -100,7 +105,14 @@ def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)


def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
def plot_matches(kpts0,
kpts1,
color=None,
lw=1.5,
ps=4,
a=1.0,
labels=None,
axes=None):
"""Plot matches for a pair of existing images.
Args:
kpts0, kpts1: corresponding keypoints of size (N, 2).
@@ -160,25 +172,28 @@ def add_text(
text,
pos=(0.01, 0.99),
fs=15,
color="w",
lcolor="k",
color='w',
lcolor='k',
lwidth=2,
ha="left",
va="top",
ha='left',
va='top',
):
ax = plt.gcf().axes[idx]
t = ax.text(
*pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
)
*pos,
text,
fontsize=fs,
ha=ha,
va=va,
color=color,
transform=ax.transAxes)
if lcolor is not None:
t.set_path_effects(
[
path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
path_effects.Normal(),
]
)
t.set_path_effects([
path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
path_effects.Normal(),
])


def save_plot(path, **kw):
"""Save the current figure without any white margin."""
plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
plt.savefig(path, bbox_inches='tight', pad_inches=0, **kw)

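In practice these helpers are chained: plot the image pair, overlay the matches, then write the figure to disk. A minimal sketch, assuming two RGB images and matched keypoint arrays already exist (none of this is part of the commit):

# Illustrative usage of the plotting helpers above (inputs assumed).
plot_images([image0, image1], titles=['image0', 'image1'])
plot_matches(m_kpts0, m_kpts1, color='lime', lw=0.2)
save_plot('matches.png')
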
@@ -13,9 +13,9 @@ from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.outputs import OutputKeys
from modelscope.utils.constant import ModelFile, Tasks
from .lightglue import LightGlue, SuperPoint, DISK, ALIKED, SIFT
from .lightglue.utils import rbd, numpy_image_to_torch
from .config.default import lightglue_default_conf
from .lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
from .lightglue.utils import numpy_image_to_torch, rbd


@MODELS.register_module(
@@ -30,20 +30,28 @@ class LightGlueImageMatching(TorchModel):

super().__init__(model_dir, **kwargs)

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # 'mps', 'cpu'
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu') # 'mps', 'cpu'

features = lightglue_default_conf.get('features', 'superpoint')

features = lightglue_default_conf.get('features','superpoint')

if features == 'disk':
self.extractor = DISK(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = DISK(
max_num_keypoints=max_num_keypoints).eval().to(self.device)
elif features == 'aliked':
self.extractor = ALIKED(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = ALIKED(
max_num_keypoints=max_num_keypoints).eval().to(self.device)
elif features == 'sift':
self.extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(self.device)
self.extractor = SIFT(
max_num_keypoints=max_num_keypoints).eval().to(self.device)
else:
self.extractor = SuperPoint(model_dir=model_dir, max_num_keypoints=max_num_keypoints).eval().to(self.device)

self.matcher = LightGlue(model_dir=model_dir, default_conf=lightglue_default_conf).eval().to(self.device)
self.extractor = SuperPoint(
model_dir=model_dir,
max_num_keypoints=max_num_keypoints).eval().to(self.device)

self.matcher = LightGlue(
model_dir=model_dir,
default_conf=lightglue_default_conf).eval().to(self.device)

def forward(self, inputs):
'''
@@ -51,9 +59,11 @@ class LightGlueImageMatching(TorchModel):
inputs: a dict with keys 'image0', 'image1'
'''

feats0 = self.extractor.extract(numpy_image_to_torch(inputs['image0']).to(self.device))
feats1 = self.extractor.extract(numpy_image_to_torch(inputs['image1']).to(self.device))
matches01 = self.matcher({"image0": feats0, "image1": feats1})
feats0 = self.extractor.extract(
numpy_image_to_torch(inputs['image0']).to(self.device))
feats1 = self.extractor.extract(
numpy_image_to_torch(inputs['image1']).to(self.device))
matches01 = self.matcher({'image0': feats0, 'image1': feats1})

return [feats0, feats1, matches01]

@@ -63,17 +73,21 @@ class LightGlueImageMatching(TorchModel):
inputs: a list of feats0, feats1, matches01
'''
matching_result = inputs
feats0, feats1, matches01 = [
rbd(x) for x in matching_result
] # remove batch dimension
feats0, feats1, matches01 = [rbd(x) for x in matching_result
] # remove batch dimension

kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
kpts0, kpts1, matches = feats0['keypoints'], feats1[
'keypoints'], matches01['matches']
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

# match confidence
confidence = matches01["scores"]
confidence = matches01['scores']

matches_result = {'kpts0': m_kpts0,'kpts1': m_kpts1,'confidence': confidence}
matches_result = {
'kpts0': m_kpts0,
'kpts1': m_kpts1,
'confidence': confidence
}

results = {OutputKeys.MATCHES: matches_result}
return results

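The forward/postprocess pair above amounts to: extract features from both images, run the LightGlue matcher, drop the batch dimension, then index the keypoints by the match indices. A sketch of the same flow outside the TorchModel wrapper (extractor, matcher, device and the two images are assumed to exist):

# Illustrative restatement of forward + postprocess (names assumed, not part of this commit).
feats0 = extractor.extract(numpy_image_to_torch(image0).to(device))
feats1 = extractor.extract(numpy_image_to_torch(image1).to(device))
matches01 = matcher({'image0': feats0, 'image1': feats1})
feats0, feats1, matches01 = [rbd(x) for x in (feats0, feats1, matches01)]  # remove batch dim
matches = matches01['matches']  # (N, 2) indices into the two keypoint sets
m_kpts0 = feats0['keypoints'][matches[..., 0]]
m_kpts1 = feats1['keypoints'][matches[..., 1]]
confidence = matches01['scores']
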
@@ -296,7 +296,9 @@ else:
],
'human3d_render_pipeline': ['Human3DRenderPipeline'],
'human3d_animation_pipeline': ['Human3DAnimationPipeline'],
'image_local_feature_matching_pipeline': ['ImageLocalFeatureMatchingPipeline'],
'image_local_feature_matching_pipeline': [
'ImageLocalFeatureMatchingPipeline'
],
'rife_video_frame_interpolation_pipeline': [
'RIFEVideoFrameInterpolationPipeline'
],

@@ -27,8 +27,10 @@ class ImageLocalFeatureMatchingPipeline(Pipeline):

>>> from modelscope.pipelines import pipeline

>>> matcher = pipeline(Tasks.image_local_feature_matching, model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
>>> matcher([['https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching1.jpg','https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching2.jpg']])
>>> matcher = pipeline(Tasks.image_local_feature_matching,
>>> model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
>>> matcher([['https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching1.jpg',
>>> 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching2.jpg']])
>>> [{
>>> 'matches': [array([[720.5 , 187.8 ],
>>> [707.4 , 198.23334],
@@ -69,7 +71,6 @@ class ImageLocalFeatureMatchingPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)


def load_image(self, img_name):
img = LoadImage.convert_to_ndarray(img_name).astype(np.float32)
img = img / 255.

@@ -67,10 +67,7 @@ class ImageMatchingFastPipeline(Pipeline):
img1 = self.load_image(input[0])
img2 = self.load_image(input[1])

return {
'image0':img1,
'image1':img2
}
return {'image0': img1, 'image1': img2}

def forward(self, input: Dict[str, Any]) -> list:
results = self.model.inference(input)

@@ -26,11 +26,12 @@ class ImageLocalFeatureMatchingTest(unittest.TestCase):
'data/test/images/image_matching1.jpg',
'data/test/images/image_matching2.jpg'
]]
estimator = pipeline(Tasks.image_local_feature_matching, model=self.model_id)
estimator = pipeline(
Tasks.image_local_feature_matching, model=self.model_id)
result = estimator(input_location)
kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
vis_img = result[0][OutputKeys.OUTPUT_IMG]
cv2.imwrite("vis_demo.jpg", vis_img)
cv2.imwrite('vis_demo.jpg', vis_img)

print('test_image_local_feature_matching DONE')

@@ -32,7 +32,7 @@ class ImageMatchingFastTest(unittest.TestCase):
kpts1,
confidence,
output_filename='lightglue-matches.png',
method="lightglue")
method='lightglue')

print('test_image_matching DONE')