diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py
index 15e990f5..d3ccffd1 100644
--- a/modelscope/metainfo.py
+++ b/modelscope/metainfo.py
@@ -88,6 +88,7 @@ class Models(object):
     video_object_segmentation = 'video-object-segmentation'
     video_deinterlace = 'video-deinterlace'
     quadtree_attention_image_matching = 'quadtree-attention-image-matching'
+    loftr_image_local_feature_matching = 'loftr-image-local-feature-matching'
     lightglue_image_matching = 'lightglue-image-matching'
     vision_middleware = 'vision-middleware'
     vidt = 'vidt'
@@ -395,6 +396,7 @@ class Pipelines(object):
     image_depth_estimation = 'image-depth-estimation'
     image_normal_estimation = 'image-normal-estimation'
     indoor_layout_estimation = 'indoor-layout-estimation'
+    image_local_feature_matching = 'image-local-feature-matching'
     video_depth_estimation = 'video-depth-estimation'
    panorama_depth_estimation = 'panorama-depth-estimation'
     panorama_depth_estimation_s2net = 'panorama-depth-estimation-s2net'
@@ -805,6 +807,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.panorama_depth_estimation:
     (Pipelines.panorama_depth_estimation,
      'damo/cv_unifuse_panorama-depth-estimation'),
+    Tasks.image_local_feature_matching:
+    (Pipelines.image_local_feature_matching,
+     'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'),
     Tasks.image_style_transfer: (Pipelines.image_style_transfer,
                                  'damo/cv_aams_style-transfer_damo'),
     Tasks.face_image_generation: (Pipelines.face_image_generation,
diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py
index fa10868b..39f46f5d 100644
--- a/modelscope/models/cv/__init__.py
+++ b/modelscope/models/cv/__init__.py
@@ -29,6 +29,7 @@ from . import (action_recognition, animal_recognition, bad_image_detecting,
                video_panoptic_segmentation, video_single_object_tracking,
                video_stabilization, video_summarization,
                video_super_resolution, vidt, virual_tryon, vision_middleware,
-               vop_retrieval,image_matching_fast)
+               vop_retrieval, image_local_feature_matching,
+               image_matching_fast)
 
 # yapf: enable
diff --git a/modelscope/models/cv/image_local_feature_matching/__init__.py b/modelscope/models/cv/image_local_feature_matching/__init__.py
new file mode 100644
index 00000000..256843b8
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import TYPE_CHECKING
+
+from modelscope.utils.import_utils import LazyImportModule
+
+if TYPE_CHECKING:
+    from .loftr_model import LocalFeatureMatching
+
+else:
+    _import_structure = {
+        'loftr_model': ['LocalFeatureMatching'],
+    }
+
+    import sys
+
+    sys.modules[__name__] = LazyImportModule(
+        __name__,
+        globals()['__file__'],
+        _import_structure,
+        module_spec=__spec__,
+        extra_objects={},
+    )
diff --git a/modelscope/models/cv/image_local_feature_matching/loftr_model.py b/modelscope/models/cv/image_local_feature_matching/loftr_model.py
new file mode 100644
index 00000000..157dfa28
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/loftr_model.py
@@ -0,0 +1,82 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import io
+import os.path as osp
+from copy import deepcopy
+
+import cv2
+import matplotlib.cm as cm
+import numpy as np
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base.base_torch_model import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.models.cv.image_local_feature_matching.src.loftr import \
+    LoFTR, default_cfg
+from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import \
+    make_matching_figure
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import ModelFile, Tasks
+
+
+@MODELS.register_module(
+    Tasks.image_local_feature_matching,
+    module_name=Models.loftr_image_local_feature_matching)
+class LocalFeatureMatching(TorchModel):
+
+    def __init__(self, model_dir: str, **kwargs):
+        """Initialize a LoFTR-based local feature matching model.
+
+        Args:
+            model_dir: the model file root directory.
+        """
+        super().__init__(model_dir, **kwargs)
+
+        # build model: initialize LoFTR with the default config
+        _default_cfg = deepcopy(default_cfg)
+        self.model = LoFTR(config=_default_cfg)
+
+        # load the pretrained weights
+        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
+        checkpoint = torch.load(model_path, map_location='cpu')
+        self.model.load_state_dict(checkpoint['state_dict'])
+        self.model.eval()
+
+    def forward(self, inputs):
+        self.model(inputs)
+        result = {
+            'kpts0': inputs['mkpts0_f'],
+            'kpts1': inputs['mkpts1_f'],
+            'conf': inputs['mconf'],
+        }
+        inputs.update(result)
+        return inputs
+
+    def postprocess(self, inputs):
+        # draw matches, colored by confidence
+        color = cm.jet(inputs['conf'].cpu().numpy())
+        img0 = inputs['image0'].squeeze().cpu().numpy()
+        img1 = inputs['image1'].squeeze().cpu().numpy()
+        mkpts0 = inputs['kpts0'].cpu().numpy()
+        mkpts1 = inputs['kpts1'].cpu().numpy()
+        text = [
+            'LoFTR',
+            'Matches: {}'.format(len(inputs['kpts0'])),
+        ]
+        img0, img1 = (img0 * 255).astype(np.uint8), (img1 * 255).astype(
+            np.uint8)
+        fig = make_matching_figure(
+            img0, img1, mkpts0, mkpts1, color, text=text)
+        io_buf = io.BytesIO()
+        fig.savefig(io_buf, format='png', dpi=75)
+        io_buf.seek(0)
+        buf_data = np.frombuffer(io_buf.getvalue(), dtype=np.uint8)
+        io_buf.close()
+        vis_img = cv2.imdecode(buf_data, 1)
+
+        results = {OutputKeys.MATCHES: inputs, OutputKeys.OUTPUT_IMG: vis_img}
+        return results
+
+    def inference(self, data):
+        results = self.forward(data)
+        return results
diff --git a/modelscope/models/cv/image_local_feature_matching/src/__init__.py b/modelscope/models/cv/image_local_feature_matching/src/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/__init__.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/__init__.py
new file mode 100644
index 00000000..0d69b9c1
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/__init__.py
@@ -0,0 +1,2 @@
+from .loftr import LoFTR
+from .utils.cvpr_ds_config import default_cfg
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/backbone/__init__.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/backbone/__init__.py
new file mode 100644
index 00000000..b6e731b3
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/backbone/__init__.py
@@ -0,0 +1,13 @@
+from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4
+
+
+def build_backbone(config):
+    if config['backbone_type'] == 'ResNetFPN':
+        if config['resolution'] == (8, 2):
+            return ResNetFPN_8_2(config['resnetfpn'])
+        elif config['resolution'] == (16, 4):
+            return ResNetFPN_16_4(config['resnetfpn'])
+        else:
+            raise ValueError(f"LOFTR.RESOLUTION {config['resolution']} not supported.")
+    else:
+        raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.")
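Reviewer note: the backbone factory can be smoke-tested in isolation. A minimal sketch (the config dict mirrors the lowercased `default_cfg` defined in `cvpr_ds_config.py` later in this diff; the printed shapes assume a 480x640 grayscale input):

```python
import torch

from modelscope.models.cv.image_local_feature_matching.src.loftr.backbone import \
    build_backbone

config = {
    'backbone_type': 'ResNetFPN',
    'resolution': (8, 2),
    'resnetfpn': {
        'initial_dim': 128,
        'block_dims': [128, 196, 256],  # channel dims at strides 2, 4, 8
    },
}
backbone = build_backbone(config)
coarse, fine = backbone(torch.randn(1, 1, 480, 640))
print(coarse.shape)  # torch.Size([1, 256, 60, 80])   -> 1/8 resolution
print(fine.shape)    # torch.Size([1, 128, 240, 320]) -> 1/2 resolution
```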
ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/backbone/resnet_fpn.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/backbone/resnet_fpn.py new file mode 100644 index 00000000..985e5b3f --- /dev/null +++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/backbone/resnet_fpn.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn.functional as F + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution without padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class BasicBlock(nn.Module): + def __init__(self, in_planes, planes, stride=1): + super().__init__() + self.conv1 = conv3x3(in_planes, planes, stride) + self.conv2 = conv3x3(planes, planes) + self.bn1 = nn.BatchNorm2d(planes) + self.bn2 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + conv1x1(in_planes, planes, stride=stride), + nn.BatchNorm2d(planes) + ) + + def forward(self, x): + y = x + y = self.relu(self.bn1(self.conv1(y))) + y = self.bn2(self.conv2(y)) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResNetFPN_8_2(nn.Module): + """ + ResNet+FPN, output resolution are 1/8 and 1/2. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + + # 3. 
FPN upsample + self.layer3_outconv = conv1x1(block_dims[2], block_dims[2]) + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + self.layer1_outconv = conv1x1(block_dims[0], block_dims[1]) + self.layer1_outconv2 = nn.Sequential( + conv3x3(block_dims[1], block_dims[1]), + nn.BatchNorm2d(block_dims[1]), + nn.LeakyReLU(), + conv3x3(block_dims[1], block_dims[0]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + + # FPN + x3_out = self.layer3_outconv(x3) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + x2_out_2x = F.interpolate(x2_out, scale_factor=2., mode='bilinear', align_corners=True) + x1_out = self.layer1_outconv(x1) + x1_out = self.layer1_outconv2(x1_out+x2_out_2x) + + return [x3_out, x1_out] + + +class ResNetFPN_16_4(nn.Module): + """ + ResNet+FPN, output resolution are 1/16 and 1/4. + Each block has 2 layers. + """ + + def __init__(self, config): + super().__init__() + # Config + block = BasicBlock + initial_dim = config['initial_dim'] + block_dims = config['block_dims'] + + # Class Variable + self.in_planes = initial_dim + + # Networks + self.conv1 = nn.Conv2d(1, initial_dim, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = nn.BatchNorm2d(initial_dim) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(block, block_dims[0], stride=1) # 1/2 + self.layer2 = self._make_layer(block, block_dims[1], stride=2) # 1/4 + self.layer3 = self._make_layer(block, block_dims[2], stride=2) # 1/8 + self.layer4 = self._make_layer(block, block_dims[3], stride=2) # 1/16 + + # 3. 
FPN upsample + self.layer4_outconv = conv1x1(block_dims[3], block_dims[3]) + self.layer3_outconv = conv1x1(block_dims[2], block_dims[3]) + self.layer3_outconv2 = nn.Sequential( + conv3x3(block_dims[3], block_dims[3]), + nn.BatchNorm2d(block_dims[3]), + nn.LeakyReLU(), + conv3x3(block_dims[3], block_dims[2]), + ) + + self.layer2_outconv = conv1x1(block_dims[1], block_dims[2]) + self.layer2_outconv2 = nn.Sequential( + conv3x3(block_dims[2], block_dims[2]), + nn.BatchNorm2d(block_dims[2]), + nn.LeakyReLU(), + conv3x3(block_dims[2], block_dims[1]), + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, dim, stride=1): + layer1 = block(self.in_planes, dim, stride=stride) + layer2 = block(dim, dim, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + # ResNet Backbone + x0 = self.relu(self.bn1(self.conv1(x))) + x1 = self.layer1(x0) # 1/2 + x2 = self.layer2(x1) # 1/4 + x3 = self.layer3(x2) # 1/8 + x4 = self.layer4(x3) # 1/16 + + # FPN + x4_out = self.layer4_outconv(x4) + + x4_out_2x = F.interpolate(x4_out, scale_factor=2., mode='bilinear', align_corners=True) + x3_out = self.layer3_outconv(x3) + x3_out = self.layer3_outconv2(x3_out+x4_out_2x) + + x3_out_2x = F.interpolate(x3_out, scale_factor=2., mode='bilinear', align_corners=True) + x2_out = self.layer2_outconv(x2) + x2_out = self.layer2_outconv2(x2_out+x3_out_2x) + + return [x4_out, x2_out] diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr.py new file mode 100644 index 00000000..79c491ee --- /dev/null +++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr.py @@ -0,0 +1,81 @@ +import torch +import torch.nn as nn +from einops.einops import rearrange + +from .backbone import build_backbone +from .utils.position_encoding import PositionEncodingSine +from .loftr_module import LocalFeatureTransformer, FinePreprocess +from .utils.coarse_matching import CoarseMatching +from .utils.fine_matching import FineMatching + + +class LoFTR(nn.Module): + def __init__(self, config): + super().__init__() + # Misc + self.config = config + + # Modules + self.backbone = build_backbone(config) + self.pos_encoding = PositionEncodingSine( + config['coarse']['d_model'], + temp_bug_fix=config['coarse']['temp_bug_fix']) + self.loftr_coarse = LocalFeatureTransformer(config['coarse']) + self.coarse_matching = CoarseMatching(config['match_coarse']) + self.fine_preprocess = FinePreprocess(config) + self.loftr_fine = LocalFeatureTransformer(config["fine"]) + self.fine_matching = FineMatching() + + def forward(self, data): + """ + Update: + data (dict): { + 'image0': (torch.Tensor): (N, 1, H, W) + 'image1': (torch.Tensor): (N, 1, H, W) + 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position + 'mask1'(optional) : (torch.Tensor): (N, H, W) + } + """ + # 1. 
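The (16, 4) variant can be checked the same way; note that it reads `block_dims[3]`, which the shipped `default_cfg` does not define, so the fourth dim below is a made-up placeholder, not part of this PR's config:

```python
import torch

from modelscope.models.cv.image_local_feature_matching.src.loftr.backbone.resnet_fpn import \
    ResNetFPN_16_4

# hypothetical config: block_dims[3]=512 is illustrative only
net = ResNetFPN_16_4({'initial_dim': 128, 'block_dims': [128, 196, 256, 512]})
coarse, fine = net(torch.randn(1, 1, 480, 640))
print(coarse.shape)  # torch.Size([1, 512, 30, 40])   -> 1/16 resolution
print(fine.shape)    # torch.Size([1, 196, 120, 160]) -> 1/4 resolution
```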
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr.py
new file mode 100644
index 00000000..79c491ee
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr.py
@@ -0,0 +1,81 @@
+import torch
+import torch.nn as nn
+from einops.einops import rearrange
+
+from .backbone import build_backbone
+from .utils.position_encoding import PositionEncodingSine
+from .loftr_module import LocalFeatureTransformer, FinePreprocess
+from .utils.coarse_matching import CoarseMatching
+from .utils.fine_matching import FineMatching
+
+
+class LoFTR(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        # Misc
+        self.config = config
+
+        # Modules
+        self.backbone = build_backbone(config)
+        self.pos_encoding = PositionEncodingSine(
+            config['coarse']['d_model'],
+            temp_bug_fix=config['coarse']['temp_bug_fix'])
+        self.loftr_coarse = LocalFeatureTransformer(config['coarse'])
+        self.coarse_matching = CoarseMatching(config['match_coarse'])
+        self.fine_preprocess = FinePreprocess(config)
+        self.loftr_fine = LocalFeatureTransformer(config['fine'])
+        self.fine_matching = FineMatching()
+
+    def forward(self, data):
+        """
+        Update:
+            data (dict): {
+                'image0': (torch.Tensor): (N, 1, H, W)
+                'image1': (torch.Tensor): (N, 1, H, W)
+                'mask0' (optional): (torch.Tensor): (N, H, W) '0' indicates a padded position
+                'mask1' (optional): (torch.Tensor): (N, H, W)
+            }
+        """
+        # 1. Local Feature CNN
+        data.update({
+            'bs': data['image0'].size(0),
+            'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:]
+        })
+
+        if data['hw0_i'] == data['hw1_i']:  # faster & better BN convergence
+            feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0))
+            (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs'])
+        else:  # handle different input shapes
+            (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1'])
+
+        data.update({
+            'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:],
+            'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:]
+        })
+
+        # 2. coarse-level loftr module
+        # add featmap with positional encoding, then flatten it to sequence [N, HW, C]
+        feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c')
+        feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c')
+
+        mask_c0 = mask_c1 = None  # mask is useful in training
+        if 'mask0' in data:
+            mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2)
+        feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1)
+
+        # 3. match coarse-level
+        self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1)
+
+        # 4. fine-level refinement
+        feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data)
+        if feat_f0_unfold.size(0) != 0:  # at least one coarse-level match predicted
+            feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold)
+
+        # 5. match fine-level
+        self.fine_matching(feat_f0_unfold, feat_f1_unfold, data)
+
+    def load_state_dict(self, state_dict, *args, **kwargs):
+        for k in list(state_dict.keys()):
+            if k.startswith('matcher.'):
+                state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k)
+        return super().load_state_dict(state_dict, *args, **kwargs)
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/__init__.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/__init__.py
new file mode 100644
index 00000000..ca51db4f
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/__init__.py
@@ -0,0 +1,2 @@
+from .transformer import LocalFeatureTransformer
+from .fine_preprocess import FinePreprocess
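With the pieces above in place, the matcher can be exercised end to end on random tensors. A minimal sketch (random weights, so the match list will typically be empty; the point is that the in-place `data` update and the output keys behave as documented):

```python
from copy import deepcopy

import torch

from modelscope.models.cv.image_local_feature_matching.src.loftr import \
    LoFTR, default_cfg

matcher = LoFTR(config=deepcopy(default_cfg)).eval()
data = {
    'image0': torch.rand(1, 1, 480, 640),  # grayscale pair, same shape
    'image1': torch.rand(1, 1, 480, 640),
}
with torch.no_grad():
    matcher(data)  # updates `data` in place
# with untrained weights these are usually empty, e.g. [0, 2], [0, 2], [0]
print(data['mkpts0_f'].shape, data['mkpts1_f'].shape, data['mconf'].shape)
```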
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/fine_preprocess.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/fine_preprocess.py
new file mode 100644
index 00000000..5bb8eefd
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/fine_preprocess.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange, repeat
+
+
+class FinePreprocess(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+        self.cat_c_feat = config['fine_concat_coarse_feat']
+        self.W = self.config['fine_window_size']
+
+        d_model_c = self.config['coarse']['d_model']
+        d_model_f = self.config['fine']['d_model']
+        self.d_model_f = d_model_f
+        if self.cat_c_feat:
+            self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True)
+            self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.kaiming_normal_(p, mode='fan_out', nonlinearity='relu')
+
+    def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data):
+        W = self.W
+        stride = data['hw0_f'][0] // data['hw0_c'][0]
+
+        data.update({'W': W})
+        if data['b_ids'].shape[0] == 0:
+            feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+            feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device)
+            return feat0, feat1
+
+        # 1. unfold (crop) all local windows
+        feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2)
+        feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+        feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2)
+        feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2)
+
+        # 2. select only the predicted matches
+        feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']]  # [n, ww, cf]
+        feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']]
+
+        # option: use coarse-level loftr feature as context: concat and linear
+        if self.cat_c_feat:
+            feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']],
+                                                   feat_c1[data['b_ids'], data['j_ids']]], 0))  # [2n, c]
+            feat_cf_win = self.merge_feat(torch.cat([
+                torch.cat([feat_f0_unfold, feat_f1_unfold], 0),  # [2n, ww, cf]
+                repeat(feat_c_win, 'n c -> n ww c', ww=W**2),  # [2n, ww, cf]
+            ], -1))
+            feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0)
+
+        return feat_f0_unfold, feat_f1_unfold
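A sketch of the shape bookkeeping behind the window cropping above, using the default window size 5 and the stride of 4 between the 1/2 and 1/8 feature maps (sizes assume a 480x640 input):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

W, stride = 5, 4                          # fine window size, hw_f // hw_c
feat_f = torch.randn(1, 128, 240, 320)    # [N, C, H_f, W_f] at 1/2 resolution
patches = F.unfold(feat_f, kernel_size=(W, W), stride=stride, padding=W // 2)
patches = rearrange(patches, 'n (c ww) l -> n l ww c', ww=W**2)
# one 5x5 window per 1/8-resolution coarse cell: 60 * 80 = 4800 windows
print(patches.shape)  # torch.Size([1, 4800, 25, 128])
```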
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/linear_attention.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/linear_attention.py
new file mode 100644
index 00000000..b73c5a6a
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/linear_attention.py
@@ -0,0 +1,81 @@
+"""
+Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
+Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
+"""
+
+import torch
+from torch.nn import Module, Dropout
+
+
+def elu_feature_map(x):
+    return torch.nn.functional.elu(x) + 1
+
+
+class LinearAttention(Module):
+    def __init__(self, eps=1e-6):
+        super().__init__()
+        self.feature_map = elu_feature_map
+        self.eps = eps
+
+    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+        """ Multi-Head linear attention proposed in "Transformers are RNNs"
+        Args:
+            queries: [N, L, H, D]
+            keys: [N, S, H, D]
+            values: [N, S, H, D]
+            q_mask: [N, L]
+            kv_mask: [N, S]
+        Returns:
+            queried_values: (N, L, H, D)
+        """
+        Q = self.feature_map(queries)
+        K = self.feature_map(keys)
+
+        # set padded position to zero
+        if q_mask is not None:
+            Q = Q * q_mask[:, :, None, None]
+        if kv_mask is not None:
+            K = K * kv_mask[:, :, None, None]
+            values = values * kv_mask[:, :, None, None]
+
+        v_length = values.size(1)
+        values = values / v_length  # prevent fp16 overflow
+        KV = torch.einsum("nshd,nshv->nhdv", K, values)  # (S,D)' @ S,V
+        Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
+        queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
+
+        return queried_values.contiguous()
+
+
+class FullAttention(Module):
+    def __init__(self, use_dropout=False, attention_dropout=0.1):
+        super().__init__()
+        self.use_dropout = use_dropout
+        self.dropout = Dropout(attention_dropout)
+
+    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
+        """ Multi-head scaled dot-product attention, a.k.a full attention.
+        Args:
+            queries: [N, L, H, D]
+            keys: [N, S, H, D]
+            values: [N, S, H, D]
+            q_mask: [N, L]
+            kv_mask: [N, S]
+        Returns:
+            queried_values: (N, L, H, D)
+        """
+
+        # Compute the unnormalized attention and apply the masks
+        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
+        if kv_mask is not None:
+            QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
+
+        # Compute the attention and the weighted average
+        softmax_temp = 1. / queries.size(3)**.5  # sqrt(D)
+        A = torch.softmax(softmax_temp * QK, dim=2)
+        if self.use_dropout:
+            A = self.dropout(A)
+
+        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
+
+        return queried_values.contiguous()
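Both variants are interchangeable at the shape level; a quick sketch (small sequence lengths on purpose, since `FullAttention` materializes the full LxS attention tensor, which is exactly what `LinearAttention` avoids on the 4800-token coarse grid):

```python
import torch

from modelscope.models.cv.image_local_feature_matching.src.loftr.loftr_module.linear_attention import \
    FullAttention, LinearAttention

q = torch.randn(2, 256, 8, 32)  # [N, L, H, D]
k = torch.randn(2, 256, 8, 32)  # [N, S, H, D]
v = torch.randn(2, 256, 8, 32)
print(LinearAttention()(q, k, v).shape)  # torch.Size([2, 256, 8, 32])
print(FullAttention()(q, k, v).shape)    # torch.Size([2, 256, 8, 32])
```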
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/transformer.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/transformer.py
new file mode 100644
index 00000000..d79390ca
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/loftr_module/transformer.py
@@ -0,0 +1,101 @@
+import copy
+import torch
+import torch.nn as nn
+from .linear_attention import LinearAttention, FullAttention
+
+
+class LoFTREncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model,
+                 nhead,
+                 attention='linear'):
+        super(LoFTREncoderLayer, self).__init__()
+
+        self.dim = d_model // nhead
+        self.nhead = nhead
+
+        # multi-head attention
+        self.q_proj = nn.Linear(d_model, d_model, bias=False)
+        self.k_proj = nn.Linear(d_model, d_model, bias=False)
+        self.v_proj = nn.Linear(d_model, d_model, bias=False)
+        self.attention = LinearAttention() if attention == 'linear' else FullAttention()
+        self.merge = nn.Linear(d_model, d_model, bias=False)
+
+        # feed-forward network
+        self.mlp = nn.Sequential(
+            nn.Linear(d_model*2, d_model*2, bias=False),
+            nn.ReLU(True),
+            nn.Linear(d_model*2, d_model, bias=False),
+        )
+
+        # norm and dropout
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+
+    def forward(self, x, source, x_mask=None, source_mask=None):
+        """
+        Args:
+            x (torch.Tensor): [N, L, C]
+            source (torch.Tensor): [N, S, C]
+            x_mask (torch.Tensor): [N, L] (optional)
+            source_mask (torch.Tensor): [N, S] (optional)
+        """
+        bs = x.size(0)
+        query, key, value = x, source, source
+
+        # multi-head attention
+        query = self.q_proj(query).view(bs, -1, self.nhead, self.dim)  # [N, L, (H, D)]
+        key = self.k_proj(key).view(bs, -1, self.nhead, self.dim)  # [N, S, (H, D)]
+        value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
+        message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
+        message = self.merge(message.view(bs, -1, self.nhead*self.dim))  # [N, L, C]
+        message = self.norm1(message)
+
+        # feed-forward network
+        message = self.mlp(torch.cat([x, message], dim=2))
+        message = self.norm2(message)
+
+        return x + message
+
+
+class LocalFeatureTransformer(nn.Module):
+    """A Local Feature Transformer (LoFTR) module."""
+
+    def __init__(self, config):
+        super(LocalFeatureTransformer, self).__init__()
+
+        self.config = config
+        self.d_model = config['d_model']
+        self.nhead = config['nhead']
+        self.layer_names = config['layer_names']
+        encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention'])
+        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def forward(self, feat0, feat1, mask0=None, mask1=None):
+        """
+        Args:
+            feat0 (torch.Tensor): [N, L, C]
+            feat1 (torch.Tensor): [N, S, C]
+            mask0 (torch.Tensor): [N, L] (optional)
+            mask1 (torch.Tensor): [N, S] (optional)
+        """
+
+        assert self.d_model == feat0.size(2), 'the feature number of src and transformer must be equal'
+
+        for layer, name in zip(self.layers, self.layer_names):
+            if name == 'self':
+                feat0 = layer(feat0, feat0, mask0, mask0)
+                feat1 = layer(feat1, feat1, mask1, mask1)
+            elif name == 'cross':
+                feat0 = layer(feat0, feat1, mask0, mask1)
+                feat1 = layer(feat1, feat0, mask1, mask0)
+            else:
+                raise KeyError
+
+        return feat0, feat1
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/__init__.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
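A sketch of driving the transformer with the coarse-level settings from `default_cfg`; the interleaved 'self'/'cross' schedule leaves both sequence shapes unchanged:

```python
import torch

from modelscope.models.cv.image_local_feature_matching.src.loftr.loftr_module import \
    LocalFeatureTransformer

config = {
    'd_model': 256,
    'nhead': 8,
    'layer_names': ['self', 'cross'] * 4,  # four interleaved rounds
    'attention': 'linear',
}
loftr_coarse = LocalFeatureTransformer(config)
feat0, feat1 = torch.randn(1, 1200, 256), torch.randn(1, 1500, 256)
feat0, feat1 = loftr_coarse(feat0, feat1)
print(feat0.shape, feat1.shape)  # [1, 1200, 256] [1, 1500, 256]
```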
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/coarse_matching.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/coarse_matching.py
new file mode 100644
index 00000000..a9726333
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/coarse_matching.py
@@ -0,0 +1,261 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops.einops import rearrange
+
+INF = 1e9
+
+
+def mask_border(m, b: int, v):
+    """ Mask borders with value
+    Args:
+        m (torch.Tensor): [N, H0, W0, H1, W1]
+        b (int)
+        v (m.dtype)
+    """
+    if b <= 0:
+        return
+
+    m[:, :b] = v
+    m[:, :, :b] = v
+    m[:, :, :, :b] = v
+    m[:, :, :, :, :b] = v
+    m[:, -b:] = v
+    m[:, :, -b:] = v
+    m[:, :, :, -b:] = v
+    m[:, :, :, :, -b:] = v
+
+
+def mask_border_with_padding(m, bd, v, p_m0, p_m1):
+    if bd <= 0:
+        return
+
+    m[:, :bd] = v
+    m[:, :, :bd] = v
+    m[:, :, :, :bd] = v
+    m[:, :, :, :, :bd] = v
+
+    h0s, w0s = p_m0.sum(1).max(-1)[0].int(), p_m0.sum(-1).max(-1)[0].int()
+    h1s, w1s = p_m1.sum(1).max(-1)[0].int(), p_m1.sum(-1).max(-1)[0].int()
+    for b_idx, (h0, w0, h1, w1) in enumerate(zip(h0s, w0s, h1s, w1s)):
+        m[b_idx, h0 - bd:] = v
+        m[b_idx, :, w0 - bd:] = v
+        m[b_idx, :, :, h1 - bd:] = v
+        m[b_idx, :, :, :, w1 - bd:] = v
+
+
+def compute_max_candidates(p_m0, p_m1):
+    """Compute the max candidates of all pairs within a batch
+
+    Args:
+        p_m0, p_m1 (torch.Tensor): padded masks
+    """
+    h0s, w0s = p_m0.sum(1).max(-1)[0], p_m0.sum(-1).max(-1)[0]
+    h1s, w1s = p_m1.sum(1).max(-1)[0], p_m1.sum(-1).max(-1)[0]
+    max_cand = torch.sum(
+        torch.min(torch.stack([h0s * w0s, h1s * w1s], -1), -1)[0])
+    return max_cand
+
+
+class CoarseMatching(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        # general config
+        self.thr = config['thr']
+        self.border_rm = config['border_rm']
+        # -- # for training fine-level LoFTR
+        self.train_coarse_percent = config['train_coarse_percent']
+        self.train_pad_num_gt_min = config['train_pad_num_gt_min']
+
+        # we provide 2 options for differentiable matching
+        self.match_type = config['match_type']
+        if self.match_type == 'dual_softmax':
+            self.temperature = config['dsmax_temperature']
+        elif self.match_type == 'sinkhorn':
+            try:
+                from .superglue import log_optimal_transport
+            except ImportError:
+                raise ImportError('download superglue.py first!')
+            self.log_optimal_transport = log_optimal_transport
+            self.bin_score = nn.Parameter(
+                torch.tensor(config['skh_init_bin_score'], requires_grad=True))
+            self.skh_iters = config['skh_iters']
+            self.skh_prefilter = config['skh_prefilter']
+        else:
+            raise NotImplementedError()
+
+    def forward(self, feat_c0, feat_c1, data, mask_c0=None, mask_c1=None):
+        """
+        Args:
+            feat_c0 (torch.Tensor): [N, L, C]
+            feat_c1 (torch.Tensor): [N, S, C]
+            data (dict)
+            mask_c0 (torch.Tensor): [N, L] (optional)
+            mask_c1 (torch.Tensor): [N, S] (optional)
+        Update:
+            data (dict): {
+                'b_ids' (torch.Tensor): [M'],
+                'i_ids' (torch.Tensor): [M'],
+                'j_ids' (torch.Tensor): [M'],
+                'gt_mask' (torch.Tensor): [M'],
+                'mkpts0_c' (torch.Tensor): [M, 2],
+                'mkpts1_c' (torch.Tensor): [M, 2],
+                'mconf' (torch.Tensor): [M]}
+            NOTE: M' != M during training.
+        """
+        N, L, S, C = feat_c0.size(0), feat_c0.size(1), feat_c1.size(1), feat_c0.size(2)
+
+        # normalize
+        feat_c0, feat_c1 = map(lambda feat: feat / feat.shape[-1]**.5,
+                               [feat_c0, feat_c1])
+
+        if self.match_type == 'dual_softmax':
+            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0,
+                                      feat_c1) / self.temperature
+            if mask_c0 is not None:
+                sim_matrix.masked_fill_(
+                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+                    -INF)
+            conf_matrix = F.softmax(sim_matrix, 1) * F.softmax(sim_matrix, 2)
+
+        elif self.match_type == 'sinkhorn':
+            # sinkhorn, dustbin included
+            sim_matrix = torch.einsum("nlc,nsc->nls", feat_c0, feat_c1)
+            if mask_c0 is not None:
+                sim_matrix[:, :L, :S].masked_fill_(
+                    ~(mask_c0[..., None] * mask_c1[:, None]).bool(),
+                    -INF)
+
+            # build uniform prior & use sinkhorn
+            log_assign_matrix = self.log_optimal_transport(
+                sim_matrix, self.bin_score, self.skh_iters)
+            assign_matrix = log_assign_matrix.exp()
+            conf_matrix = assign_matrix[:, :-1, :-1]
+
+            # filter prediction with dustbin score (only in evaluation mode)
+            if not self.training and self.skh_prefilter:
+                filter0 = (assign_matrix.max(dim=2)[1] == S)[:, :-1]  # [N, L]
+                filter1 = (assign_matrix.max(dim=1)[1] == L)[:, :-1]  # [N, S]
+                conf_matrix[filter0[..., None].repeat(1, 1, S)] = 0
+                conf_matrix[filter1[:, None].repeat(1, L, 1)] = 0
+
+            if self.config['sparse_spvs']:
+                data.update({'conf_matrix_with_bin': assign_matrix.clone()})
+
+        data.update({'conf_matrix': conf_matrix})
+
+        # predict coarse matches from conf_matrix
+        data.update(**self.get_coarse_match(conf_matrix, data))
+
+    @torch.no_grad()
+    def get_coarse_match(self, conf_matrix, data):
+        """
+        Args:
+            conf_matrix (torch.Tensor): [N, L, S]
+            data (dict): with keys ['hw0_i', 'hw1_i', 'hw0_c', 'hw1_c']
+        Returns:
+            coarse_matches (dict): {
+                'b_ids' (torch.Tensor): [M'],
+                'i_ids' (torch.Tensor): [M'],
+                'j_ids' (torch.Tensor): [M'],
+                'gt_mask' (torch.Tensor): [M'],
+                'm_bids' (torch.Tensor): [M],
+                'mkpts0_c' (torch.Tensor): [M, 2],
+                'mkpts1_c' (torch.Tensor): [M, 2],
+                'mconf' (torch.Tensor): [M]}
+        """
+        axes_lengths = {
+            'h0c': data['hw0_c'][0],
+            'w0c': data['hw0_c'][1],
+            'h1c': data['hw1_c'][0],
+            'w1c': data['hw1_c'][1]
+        }
+        _device = conf_matrix.device
+        # 1. confidence thresholding
+        mask = conf_matrix > self.thr
+        mask = rearrange(mask, 'b (h0c w0c) (h1c w1c) -> b h0c w0c h1c w1c',
+                         **axes_lengths)
+        if 'mask0' not in data:
+            mask_border(mask, self.border_rm, False)
+        else:
+            mask_border_with_padding(mask, self.border_rm, False,
+                                     data['mask0'], data['mask1'])
+        mask = rearrange(mask, 'b h0c w0c h1c w1c -> b (h0c w0c) (h1c w1c)',
+                         **axes_lengths)
+
+        # 2. mutual nearest
+        mask = mask \
+            * (conf_matrix == conf_matrix.max(dim=2, keepdim=True)[0]) \
+            * (conf_matrix == conf_matrix.max(dim=1, keepdim=True)[0])
+
+        # 3. find all valid coarse matches
+        # this only works when at most one `True` in each row
+        mask_v, all_j_ids = mask.max(dim=2)
+        b_ids, i_ids = torch.where(mask_v)
+        j_ids = all_j_ids[b_ids, i_ids]
+        mconf = conf_matrix[b_ids, i_ids, j_ids]
+
+        # 4. random sampling of training samples for fine-level LoFTR
+        # (optional) pad samples with gt coarse-level matches
+        if self.training:
+            # NOTE:
+            # The sampling is performed across all pairs in a batch without manually balancing
+            # #samples for fine-level increases w.r.t. batch_size
+            if 'mask0' not in data:
+                num_candidates_max = mask.size(0) * max(
+                    mask.size(1), mask.size(2))
+            else:
+                num_candidates_max = compute_max_candidates(
+                    data['mask0'], data['mask1'])
+            num_matches_train = int(num_candidates_max *
+                                    self.train_coarse_percent)
+            num_matches_pred = len(b_ids)
+            assert self.train_pad_num_gt_min < num_matches_train, 'min-num-gt-pad should be less than num-train-matches'
+
+            # pred_indices is to select from prediction
+            if num_matches_pred <= num_matches_train - self.train_pad_num_gt_min:
+                pred_indices = torch.arange(num_matches_pred, device=_device)
+            else:
+                pred_indices = torch.randint(
+                    num_matches_pred,
+                    (num_matches_train - self.train_pad_num_gt_min, ),
+                    device=_device)
+
+            # gt_pad_indices is to select from gt padding. e.g. max(3787-4800, 200)
+            gt_pad_indices = torch.randint(
+                len(data['spv_b_ids']),
+                (max(num_matches_train - num_matches_pred,
+                     self.train_pad_num_gt_min), ),
+                device=_device)
+            mconf_gt = torch.zeros(len(data['spv_b_ids']), device=_device)  # set conf of gt paddings to all zero
+
+            b_ids, i_ids, j_ids, mconf = map(
+                lambda x, y: torch.cat([x[pred_indices], y[gt_pad_indices]],
+                                       dim=0),
+                *zip([b_ids, data['spv_b_ids']], [i_ids, data['spv_i_ids']],
+                     [j_ids, data['spv_j_ids']], [mconf, mconf_gt]))
+
+        # These matches select patches that feed into the fine-level network
+        coarse_matches = {'b_ids': b_ids, 'i_ids': i_ids, 'j_ids': j_ids}
+
+        # 5. update with matches in original image resolution
+        scale = data['hw0_i'][0] / data['hw0_c'][0]
+        scale0 = scale * data['scale0'][b_ids] if 'scale0' in data else scale
+        scale1 = scale * data['scale1'][b_ids] if 'scale1' in data else scale
+        mkpts0_c = torch.stack(
+            [i_ids % data['hw0_c'][1], i_ids // data['hw0_c'][1]],
+            dim=1) * scale0
+        mkpts1_c = torch.stack(
+            [j_ids % data['hw1_c'][1], j_ids // data['hw1_c'][1]],
+            dim=1) * scale1
+
+        # These matches are the current predictions (for visualization)
+        coarse_matches.update({
+            'gt_mask': mconf == 0,
+            'm_bids': b_ids[mconf != 0],  # mconf == 0 => gt matches
+            'mkpts0_c': mkpts0_c[mconf != 0],
+            'mkpts1_c': mkpts1_c[mconf != 0],
+            'mconf': mconf[mconf != 0]
+        })
+
+        return coarse_matches
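The dual-softmax scoring plus mutual-nearest-neighbour selection in `get_coarse_match` can be illustrated on a toy confidence matrix; this standalone sketch mirrors steps 1-3 above, without the border masking:

```python
import torch
import torch.nn.functional as F

sim = torch.randn(1, 6, 5) / 0.1               # [N, L, S], scaled by temperature
conf = F.softmax(sim, 1) * F.softmax(sim, 2)   # dual-softmax confidence
mask = (conf > 0.2) \
    * (conf == conf.max(dim=2, keepdim=True)[0]) \
    * (conf == conf.max(dim=1, keepdim=True)[0])
mask_v, all_j_ids = mask.max(dim=2)
b_ids, i_ids = torch.where(mask_v)
j_ids = all_j_ids[b_ids, i_ids]
# mutual matches surviving the threshold, as (row, column) index pairs
print(list(zip(i_ids.tolist(), j_ids.tolist())))
```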
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/cvpr_ds_config.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/cvpr_ds_config.py
new file mode 100644
index 00000000..1c9ce701
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/cvpr_ds_config.py
@@ -0,0 +1,50 @@
+from yacs.config import CfgNode as CN
+
+
+def lower_config(yacs_cfg):
+    if not isinstance(yacs_cfg, CN):
+        return yacs_cfg
+    return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}
+
+
+_CN = CN()
+_CN.BACKBONE_TYPE = 'ResNetFPN'
+_CN.RESOLUTION = (8, 2)  # options: [(8, 2), (16, 4)]
+_CN.FINE_WINDOW_SIZE = 5  # window_size in fine_level, must be odd
+_CN.FINE_CONCAT_COARSE_FEAT = True
+
+# 1. LoFTR-backbone (local feature CNN) config
+_CN.RESNETFPN = CN()
+_CN.RESNETFPN.INITIAL_DIM = 128
+_CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256]  # s1, s2, s3
+
+# 2. LoFTR-coarse module config
+_CN.COARSE = CN()
+_CN.COARSE.D_MODEL = 256
+_CN.COARSE.D_FFN = 256
+_CN.COARSE.NHEAD = 8
+_CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4
+_CN.COARSE.ATTENTION = 'linear'  # options: ['linear', 'full']
+_CN.COARSE.TEMP_BUG_FIX = False
+
+# 3. Coarse-Matching config
+_CN.MATCH_COARSE = CN()
+_CN.MATCH_COARSE.THR = 0.2
+_CN.MATCH_COARSE.BORDER_RM = 2
+_CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax'  # options: ['dual_softmax', 'sinkhorn']
+_CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1
+_CN.MATCH_COARSE.SKH_ITERS = 3
+_CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0
+_CN.MATCH_COARSE.SKH_PREFILTER = True
+_CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4  # training trick: save GPU memory
+_CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200  # training trick: avoid DDP deadlock
+
+# 4. LoFTR-fine module config
+_CN.FINE = CN()
+_CN.FINE.D_MODEL = 128
+_CN.FINE.D_FFN = 128
+_CN.FINE.NHEAD = 8
+_CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1
+_CN.FINE.ATTENTION = 'linear'
+
+default_cfg = lower_config(_CN)
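`default_cfg` at the bottom is the plain-dict, lowercased form of the yacs node, which is what the model code indexes; a sketch of inspecting and overriding it:

```python
from copy import deepcopy

from modelscope.models.cv.image_local_feature_matching.src.loftr import \
    default_cfg

cfg = deepcopy(default_cfg)
cfg['match_coarse']['thr'] = 0.5     # stricter confidence threshold
cfg['coarse']['attention'] = 'full'  # options: ['linear', 'full']
print(cfg['resolution'], cfg['resnetfpn']['block_dims'])  # (8, 2) [128, 196, 256]
```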
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/fine_matching.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/fine_matching.py
new file mode 100644
index 00000000..689518d9
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/fine_matching.py
@@ -0,0 +1,163 @@
+import math
+import torch
+import torch.nn as nn
+
+
+def create_meshgrid(
+    height: int,
+    width: int,
+    normalized_coordinates: bool = True,
+    device=None,
+    dtype=None,
+):
+    """Generate a coordinate grid for an image.
+
+    When the flag ``normalized_coordinates`` is set to True, the grid is
+    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
+    function :py:func:`torch.nn.functional.grid_sample`.
+
+    Args:
+        height: the image height (rows).
+        width: the image width (cols).
+        normalized_coordinates: whether to normalize
+          coordinates in the range :math:`[-1,1]` in order to be consistent with the
+          PyTorch function :py:func:`torch.nn.functional.grid_sample`.
+        device: the device on which the grid will be generated.
+        dtype: the data type of the generated grid.
+
+    Return:
+        grid tensor with shape :math:`(1, H, W, 2)`.
+
+    Example:
+        >>> create_meshgrid(2, 2)
+        tensor([[[[-1., -1.],
+                  [ 1., -1.]],
+
+                 [[-1.,  1.],
+                  [ 1.,  1.]]]])
+
+        >>> create_meshgrid(2, 2, normalized_coordinates=False)
+        tensor([[[[0., 0.],
+                  [1., 0.]],
+
+                 [[0., 1.],
+                  [1., 1.]]]])
+    """
+    xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
+    ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
+    if normalized_coordinates:
+        xs = (xs / (width - 1) - 0.5) * 2
+        ys = (ys / (height - 1) - 0.5) * 2
+    base_grid = torch.stack(torch.meshgrid([xs, ys], indexing='ij'), dim=-1)  # WxHx2
+    return base_grid.permute(1, 0, 2).unsqueeze(0)  # 1xHxWx2
+
+
+def spatial_expectation2d(input, normalized_coordinates: bool = True):
+    r"""Compute the expectation of coordinate values using spatial probabilities.
+
+    The input heatmap is assumed to represent a valid spatial probability distribution,
+    which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.
+
+    Args:
+        input: the input tensor representing dense spatial probabilities with shape :math:`(B, N, H, W)`.
+        normalized_coordinates: whether to return the coordinates normalized in the range
+          of :math:`[-1, 1]`. Otherwise, it will return the coordinates in the range of the input shape.
+
+    Returns:
+        expected value of the 2D coordinates with shape :math:`(B, N, 2)`. Output order of the coordinates is (x, y).
+
+    Examples:
+        >>> heatmaps = torch.tensor([[[
+        ...     [0., 0., 0.],
+        ...     [0., 0., 0.],
+        ...     [0., 1., 0.]]]])
+        >>> spatial_expectation2d(heatmaps, False)
+        tensor([[[1., 2.]]])
+    """
+
+    batch_size, channels, height, width = input.shape
+
+    # Create coordinates grid.
+    grid = create_meshgrid(height, width, normalized_coordinates, input.device)
+    grid = grid.to(input.dtype)
+
+    pos_x = grid[..., 0].reshape(-1)
+    pos_y = grid[..., 1].reshape(-1)
+
+    input_flat = input.view(batch_size, channels, -1)
+
+    # Compute the expectation of the coordinates.
+    expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
+    expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)
+
+    output = torch.cat([expected_x, expected_y], -1)
+
+    return output.view(batch_size, channels, 2)  # BxNx2
+
+
+class FineMatching(nn.Module):
+    """FineMatching with s2d paradigm"""
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, feat_f0, feat_f1, data):
+        """
+        Args:
+            feat_f0 (torch.Tensor): [M, WW, C]
+            feat_f1 (torch.Tensor): [M, WW, C]
+            data (dict)
+        Update:
+            data (dict): {
+                'expec_f' (torch.Tensor): [M, 3],
+                'mkpts0_f' (torch.Tensor): [M, 2],
+                'mkpts1_f' (torch.Tensor): [M, 2]}
+        """
+        M, WW, C = feat_f0.shape
+        W = int(math.sqrt(WW))
+        scale = data['hw0_i'][0] / data['hw0_f'][0]
+        self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale
+
+        # corner case: if no coarse matches found
+        if M == 0:
+            assert not self.training, 'M is always > 0 while training, see coarse_matching.py'
+            # logger.warning('No matches found in coarse-level.')
+            data.update({
+                'expec_f': torch.empty(0, 3, device=feat_f0.device),
+                'mkpts0_f': data['mkpts0_c'],
+                'mkpts1_f': data['mkpts1_c'],
+            })
+            return
+
+        feat_f0_picked = feat_f0[:, WW//2, :]
+        sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1)
+        softmax_temp = 1. / C**.5
+        heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W)
+
+        # compute coordinates from heatmap
+        coords_normalized = spatial_expectation2d(heatmap[None], True)[0]  # [M, 2]
+        grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2)  # [1, WW, 2]
+
+        # compute std over <x, y>
+        var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2  # [M, 2]
+        std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1)  # [M]  clamp needed for numerical stability
+
+        # for fine-level supervision
+        data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)})
+
+        # compute absolute kpt coords
+        self.get_fine_match(coords_normalized, data)
+
+    @torch.no_grad()
+    def get_fine_match(self, coords_normed, data):
+        W, WW, C, scale = self.W, self.WW, self.C, self.scale
+
+        # mkpts0_f and mkpts1_f
+        mkpts0_f = data['mkpts0_c']
+        scale1 = scale * data['scale1'][data['b_ids']] if 'scale1' in data else scale
+        mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])]
+
+        data.update({
+            'mkpts0_f': mkpts0_f,
+            'mkpts1_f': mkpts1_f
+        })
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/geometry.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/geometry.py
new file mode 100644
index 00000000..f95cdb65
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/geometry.py
@@ -0,0 +1,54 @@
+import torch
+
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+    """ Warp kpts0 from I0 to I1 with depth, K and Rt.
+    Also check covisibility and depth consistency.
+    Depth is consistent if relative error < 0.2 (hard-coded).
+
+    Args:
+        kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+        depth0 (torch.Tensor): [N, H, W],
+        depth1 (torch.Tensor): [N, H, W],
+        T_0to1 (torch.Tensor): [N, 3, 4],
+        K0 (torch.Tensor): [N, 3, 3],
+        K1 (torch.Tensor): [N, 3, 3],
+    Returns:
+        calculable_mask (torch.Tensor): [N, L]
+        warped_keypoints0 (torch.Tensor): [N, L, 2]
+    """
+    kpts0_long = kpts0.round().long()
+
+    # Sample depth, get calculable_mask on depth != 0
+    kpts0_depth = torch.stack(
+        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+    )  # (N, L)
+    nonzero_mask = kpts0_depth != 0
+
+    # Unproject
+    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
+    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
+
+    # Rigid Transform
+    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
+    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+    # Project
+    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
+    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4)  # (N, L, 2), +1e-4 to avoid zero depth
+
+    # Covisible Check
+    h, w = depth1.shape[1:3]
+    covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+        (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+    w_kpts0_long = w_kpts0.long()
+    w_kpts0_long[~covisible_mask, :] = 0
+
+    w_kpts0_depth = torch.stack(
+        [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+    )  # (N, L)
+    consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+    valid_mask = nonzero_mask * covisible_mask * consistent_mask
+
+    return valid_mask, w_kpts0
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/position_encoding.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/position_encoding.py
new file mode 100644
index 00000000..732d28c8
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/position_encoding.py
@@ -0,0 +1,42 @@
+import math
+import torch
+from torch import nn
+
+
+class PositionEncodingSine(nn.Module):
+    """
+    This is a sinusoidal position encoding that is generalized to 2-dimensional images
+    """
+
+    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
+        """
+        Args:
+            max_shape (tuple): for a 1/8 featmap, the max length of 256 corresponds to 2048 pixels
+            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
+                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
+                on the final performance. For now, we keep both impls for backward compatibility.
+                We will remove the buggy impl after re-training all variants of our released models.
+        """
+        super().__init__()
+
+        pe = torch.zeros((d_model, *max_shape))
+        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
+        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
+        if temp_bug_fix:
+            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
+        else:  # a buggy implementation (for backward compatibility only)
+            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
+        div_term = div_term[:, None, None]  # [C//4, 1, 1]
+        pe[0::4, :, :] = torch.sin(x_position * div_term)
+        pe[1::4, :, :] = torch.cos(x_position * div_term)
+        pe[2::4, :, :] = torch.sin(y_position * div_term)
+        pe[3::4, :, :] = torch.cos(y_position * div_term)
+
+        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, C, H, W]
+
+    def forward(self, x):
+        """
+        Args:
+            x: [N, C, H, W]
+        """
+        return x + self.pe[:, :, :x.size(2), :x.size(3)]
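A sketch of the encoding in use: it is additive and shape-preserving, and `max_shape=(256, 256)` bounds the coarse grid, i.e. inputs up to 2048 pixels at 1/8 resolution (`temp_bug_fix=False` matches the shipped config):

```python
import torch

from modelscope.models.cv.image_local_feature_matching.src.loftr.utils.position_encoding import \
    PositionEncodingSine

pos_enc = PositionEncodingSine(256, temp_bug_fix=False)
feat = torch.randn(1, 256, 60, 80)  # coarse features of a 480x640 image
print(pos_enc(feat).shape)          # torch.Size([1, 256, 60, 80])
```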
diff --git a/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/supervision.py b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/supervision.py
new file mode 100644
index 00000000..4749e24a
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/loftr/utils/supervision.py
@@ -0,0 +1,151 @@
+from math import log
+from loguru import logger
+
+import torch
+from einops import repeat
+from kornia.utils import create_meshgrid
+
+from .geometry import warp_kpts
+
+##############  ↓  Coarse-Level supervision  ↓  ##############
+
+
+@torch.no_grad()
+def mask_pts_at_padded_regions(grid_pt, mask):
+    """For the megadepth dataset, zero-padding exists in images"""
+    mask = repeat(mask, 'n h w -> n (h w) c', c=2)
+    grid_pt[~mask.bool()] = 0
+    return grid_pt
+
+
+@torch.no_grad()
+def spvs_coarse(data, config):
+    """
+    Update:
+        data (dict): {
+            "conf_matrix_gt": [N, hw0, hw1],
+            'spv_b_ids': [M]
+            'spv_i_ids': [M]
+            'spv_j_ids': [M]
+            'spv_w_pt0_i': [N, hw0, 2], in original image resolution
+            'spv_pt1_i': [N, hw1, 2], in original image resolution
+        }
+
+    NOTE:
+        - for the scannet dataset, there are 3 kinds of resolution {i, c, f}
+        - for the megadepth dataset, there are 4 kinds of resolution {i, i_resize, c, f}
+    """
+    # 1. misc
+    device = data['image0'].device
+    N, _, H0, W0 = data['image0'].shape
+    _, _, H1, W1 = data['image1'].shape
+    scale = config['LOFTR']['RESOLUTION'][0]
+    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+    scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
+    h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+    # 2. warp grids
+    # create kpts in meshgrid and resize them to image resolution
+    grid_pt0_c = create_meshgrid(h0, w0, False, device).reshape(1, h0*w0, 2).repeat(N, 1, 1)  # [N, hw, 2]
+    grid_pt0_i = scale0 * grid_pt0_c
+    grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)
+    grid_pt1_i = scale1 * grid_pt1_c
+
+    # mask padded region to (0, 0), so no need to manually mask conf_matrix_gt
+    if 'mask0' in data:
+        grid_pt0_i = mask_pts_at_padded_regions(grid_pt0_i, data['mask0'])
+        grid_pt1_i = mask_pts_at_padded_regions(grid_pt1_i, data['mask1'])
+
+    # warp kpts bi-directionally and resize them to coarse-level resolution
+    # (no depth consistency check, since it leads to worse results experimentally)
+    # (unhandled edge case: points with 0-depth will be warped to the left-up corner)
+    _, w_pt0_i = warp_kpts(grid_pt0_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+    _, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+    w_pt0_c = w_pt0_i / scale1
+    w_pt1_c = w_pt1_i / scale0
+
+    # 3. check if mutual nearest neighbor
+    w_pt0_c_round = w_pt0_c[:, :, :].round().long()
+    nearest_index1 = w_pt0_c_round[..., 0] + w_pt0_c_round[..., 1] * w1
+    w_pt1_c_round = w_pt1_c[:, :, :].round().long()
+    nearest_index0 = w_pt1_c_round[..., 0] + w_pt1_c_round[..., 1] * w0
+
+    # corner case: out of boundary
+    def out_bound_mask(pt, w, h):
+        return (pt[..., 0] < 0) + (pt[..., 0] >= w) + (pt[..., 1] < 0) + (pt[..., 1] >= h)
+    nearest_index1[out_bound_mask(w_pt0_c_round, w1, h1)] = 0
+    nearest_index0[out_bound_mask(w_pt1_c_round, w0, h0)] = 0
+
+    loop_back = torch.stack([nearest_index0[_b][_i] for _b, _i in enumerate(nearest_index1)], dim=0)
+    correct_0to1 = loop_back == torch.arange(h0*w0, device=device)[None].repeat(N, 1)
+    correct_0to1[:, 0] = False  # ignore the top-left corner
+
+    # 4. construct a gt conf_matrix
+    conf_matrix_gt = torch.zeros(N, h0*w0, h1*w1, device=device)
+    b_ids, i_ids = torch.where(correct_0to1 != 0)
+    j_ids = nearest_index1[b_ids, i_ids]
+
+    conf_matrix_gt[b_ids, i_ids, j_ids] = 1
+    data.update({'conf_matrix_gt': conf_matrix_gt})
+
+    # 5. save coarse matches (gt) for training the fine level
+    if len(b_ids) == 0:
+        logger.warning(f"No groundtruth coarse match found for: {data['pair_names']}")
+        # this won't affect fine-level loss calculation
+        b_ids = torch.tensor([0], device=device)
+        i_ids = torch.tensor([0], device=device)
+        j_ids = torch.tensor([0], device=device)
+
+    data.update({
+        'spv_b_ids': b_ids,
+        'spv_i_ids': i_ids,
+        'spv_j_ids': j_ids
+    })
+
+    # 6. save intermediate results (for fast fine-level computation)
+    data.update({
+        'spv_w_pt0_i': w_pt0_i,
+        'spv_pt1_i': grid_pt1_i
+    })
+
+
+def compute_supervision_coarse(data, config):
+    assert len(set(data['dataset_name'])) == 1, 'Mixed datasets training is not supported!'
+    data_source = data['dataset_name'][0]
+    if data_source.lower() in ['scannet', 'megadepth']:
+        spvs_coarse(data, config)
+    else:
+        raise ValueError(f'Unknown data source: {data_source}')
+
+
+##############  ↓  Fine-Level supervision  ↓  ##############
+
+
+@torch.no_grad()
+def spvs_fine(data, config):
+    """
+    Update:
+        data (dict): {
+            "expec_f_gt": [M, 2]}
+    """
+    # 1. misc
+    # w_pt0_i, pt1_i = data.pop('spv_w_pt0_i'), data.pop('spv_pt1_i')
+    w_pt0_i, pt1_i = data['spv_w_pt0_i'], data['spv_pt1_i']
+    scale = config['LOFTR']['RESOLUTION'][1]
+    radius = config['LOFTR']['FINE_WINDOW_SIZE'] // 2
+
+    # 2. get coarse prediction
+    b_ids, i_ids, j_ids = data['b_ids'], data['i_ids'], data['j_ids']
+
+    # 3. compute gt
+    scale = scale * data['scale1'][b_ids] if 'scale0' in data else scale
+    # `expec_f_gt` might exceed the window, i.e. abs(*) > 1, which would be filtered later
+    expec_f_gt = (w_pt0_i[b_ids, i_ids] - pt1_i[b_ids, j_ids]) / scale / radius  # [M, 2]
+    data.update({'expec_f_gt': expec_f_gt})
+
+
+def compute_supervision_fine(data, config):
+    data_source = data['dataset_name'][0]
+    if data_source.lower() in ['scannet', 'megadepth']:
+        spvs_fine(data, config)
+    else:
+        raise NotImplementedError
diff --git a/modelscope/models/cv/image_local_feature_matching/src/utils/__init__.py b/modelscope/models/cv/image_local_feature_matching/src/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/modelscope/models/cv/image_local_feature_matching/src/utils/plotting.py b/modelscope/models/cv/image_local_feature_matching/src/utils/plotting.py
new file mode 100644
index 00000000..3d4c5ca5
--- /dev/null
+++ b/modelscope/models/cv/image_local_feature_matching/src/utils/plotting.py
@@ -0,0 +1,154 @@
+import bisect
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib
+
+
+def _compute_conf_thresh(data):
+    dataset_name = data['dataset_name'][0].lower()
+    if dataset_name == 'scannet':
+        thr = 5e-4
+    elif dataset_name == 'megadepth':
+        thr = 1e-4
+    else:
+        raise ValueError(f'Unknown dataset: {dataset_name}')
+    return thr
+
+
+# --- VISUALIZATION --- #
+
+
+def make_matching_figure(
+        img0, img1, mkpts0, mkpts1, color,
+        kpts0=None, kpts1=None, text=[], dpi=75, path=None):
+    # draw image pair
+    assert mkpts0.shape[0] == mkpts1.shape[0], \
+        f'mkpts0: {mkpts0.shape[0]} vs. mkpts1: {mkpts1.shape[0]}'
+    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
+    axes[0].imshow(img0, cmap='gray')
+    axes[1].imshow(img1, cmap='gray')
+    for i in range(2):  # clear all frames
+        axes[i].get_yaxis().set_ticks([])
+        axes[i].get_xaxis().set_ticks([])
+        for spine in axes[i].spines.values():
+            spine.set_visible(False)
+    plt.tight_layout(pad=1)
+
+    if kpts0 is not None:
+        assert kpts1 is not None
+        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
+        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
+
+    # draw matches
+    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
+        fig.canvas.draw()
+        transFigure = fig.transFigure.inverted()
+        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
+        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
+        fig.lines = [matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
+                                             (fkpts0[i, 1], fkpts1[i, 1]),
+                                             transform=fig.transFigure, c=color[i], linewidth=1)
+                     for i in range(len(mkpts0))]
+
+        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
+        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
+
+    # put txts
+    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
+    fig.text(
+        0.01, 0.99, '\n'.join(text), transform=fig.axes[0].transAxes,
+        fontsize=15, va='top', ha='left', color=txt_color)
+
+    # save or return figure
+    if path:
+        plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
+        plt.close()
+    else:
+        return fig
+
+
+def _make_evaluation_figure(data, b_id, alpha='dynamic'):
+    b_mask = data['m_bids'] == b_id
+    conf_thr = _compute_conf_thresh(data)
+
+    img0 = (data['image0'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    img1 = (data['image1'][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
+    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
+    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()
+
+    # for megadepth, we visualize matches on the resized image
+    if 'scale0' in data:
+        kpts0 = kpts0 / data['scale0'][b_id].cpu().numpy()[[1, 0]]
+        kpts1 = kpts1 / data['scale1'][b_id].cpu().numpy()[[1, 0]]
+
+    epi_errs = data['epi_errs'][b_mask].cpu().numpy()
+    correct_mask = epi_errs < conf_thr
+    precision = np.mean(correct_mask) if len(correct_mask) > 0 else 0
+    n_correct = np.sum(correct_mask)
+    n_gt_matches = int(data['conf_matrix_gt'][b_id].sum().cpu())
+    recall = 0 if n_gt_matches == 0 else n_correct / (n_gt_matches)
+    # recall might be larger than 1, since the calculation of conf_matrix_gt
+    # uses groundtruth depths and camera poses, but epipolar distance is used here.
+
+    # matching info
+    if alpha == 'dynamic':
+        alpha = dynamic_alpha(len(correct_mask))
+    color = error_colormap(epi_errs, conf_thr, alpha=alpha)
+
+    text = [
+        f'#Matches {len(kpts0)}',
+        f'Precision({conf_thr:.2e}) ({100 * precision:.1f}%): {n_correct}/{len(kpts0)}',
+        f'Recall({conf_thr:.2e}) ({100 * recall:.1f}%): {n_correct}/{n_gt_matches}'
+    ]
+
+    # make the figure
+    figure = make_matching_figure(img0, img1, kpts0, kpts1,
+                                  color, text=text)
+    return figure
+
+
+def _make_confidence_figure(data, b_id):
+    # TODO: Implement confidence figure
+    raise NotImplementedError()
+
+
+def make_matching_figures(data, config, mode='evaluation'):
+    """ Make matching figures for a batch.
+
+    Args:
+        data (Dict): a batch updated by PL_LoFTR.
+        config (Dict): matcher config
+    Returns:
+        figures (Dict[str, List[plt.figure]])
+    """
+    assert mode in ['evaluation', 'confidence']  # 'confidence'
+    figures = {mode: []}
+    for b_id in range(data['image0'].size(0)):
+        if mode == 'evaluation':
+            fig = _make_evaluation_figure(
+                data, b_id,
+                alpha=config.TRAINER.PLOT_MATCHES_ALPHA)
+        elif mode == 'confidence':
+            fig = _make_confidence_figure(data, b_id)
+        else:
+            raise ValueError(f'Unknown plot mode: {mode}')
+        figures[mode].append(fig)
+    return figures
+
+
+def dynamic_alpha(n_matches,
+                  milestones=[0, 300, 1000, 2000],
+                  alphas=[1.0, 0.8, 0.4, 0.2]):
+    if n_matches == 0:
+        return 1.0
+    ranges = list(zip(alphas, alphas[1:] + [None]))
+    loc = bisect.bisect_right(milestones, n_matches) - 1
+    _range = ranges[loc]
+    if _range[1] is None:
+        return _range[0]
+    return _range[1] + (milestones[loc + 1] - n_matches) / (
+        milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1])
+
+
+def error_colormap(err, thr, alpha=1.0):
+    assert alpha <= 1.0 and alpha > 0, f'Invalid alpha value: {alpha}'
+    x = 1 - np.clip(err / (thr * 2), 0, 1)
+    return np.clip(
+        np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1)
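The visualization helper can be exercised standalone with synthetic data; a sketch (assumes a non-interactive matplotlib backend such as Agg):

```python
import matplotlib.cm as cm
import numpy as np

from modelscope.models.cv.image_local_feature_matching.src.utils.plotting import \
    make_matching_figure

img0 = np.random.randint(0, 255, (480, 640), dtype=np.uint8)
img1 = np.random.randint(0, 255, (480, 640), dtype=np.uint8)
mkpts0 = np.random.rand(50, 2) * [640, 480]   # 50 fake matches in pixel coords
mkpts1 = np.random.rand(50, 2) * [640, 480]
color = cm.jet(np.random.rand(50))            # one RGBA color per match
fig = make_matching_figure(img0, img1, mkpts0, mkpts1, color,
                           text=['LoFTR', 'Matches: 50'])
fig.savefig('matches.png', dpi=75)
```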
+ config (Dict): matcher config + Returns: + figures (Dict[str, List[plt.figure]] + """ + assert mode in ['evaluation', 'confidence'] # 'confidence' + figures = {mode: []} + for b_id in range(data['image0'].size(0)): + if mode == 'evaluation': + fig = _make_evaluation_figure( + data, b_id, + alpha=config.TRAINER.PLOT_MATCHES_ALPHA) + elif mode == 'confidence': + fig = _make_confidence_figure(data, b_id) + else: + raise ValueError(f'Unknown plot mode: {mode}') + figures[mode].append(fig) + return figures + + +def dynamic_alpha(n_matches, + milestones=[0, 300, 1000, 2000], + alphas=[1.0, 0.8, 0.4, 0.2]): + if n_matches == 0: + return 1.0 + ranges = list(zip(alphas, alphas[1:] + [None])) + loc = bisect.bisect_right(milestones, n_matches) - 1 + _range = ranges[loc] + if _range[1] is None: + return _range[0] + return _range[1] + (milestones[loc + 1] - n_matches) / ( + milestones[loc + 1] - milestones[loc]) * (_range[0] - _range[1]) + + +def error_colormap(err, thr, alpha=1.0): + assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}" + x = 1 - np.clip(err / (thr * 2), 0, 1) + return np.clip( + np.stack([2-x*2, x*2, np.zeros_like(x), np.ones_like(x)*alpha], -1), 0, 1) diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 989af094..4d74e965 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -118,6 +118,7 @@ if TYPE_CHECKING: from .text_to_360panorama_image_pipeline import Text2360PanoramaImagePipeline from .human3d_render_pipeline import Human3DRenderPipeline from .human3d_animation_pipeline import Human3DAnimationPipeline + from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline from .anydoor_pipeline import AnydoorPipeline else: @@ -295,6 +296,7 @@ else: ], 'human3d_render_pipeline': ['Human3DRenderPipeline'], 'human3d_animation_pipeline': ['Human3DAnimationPipeline'], + 'image_local_feature_matching_pipeline': ['ImageLocalFeatureMatchingPipeline'], 'rife_video_frame_interpolation_pipeline': [ 'RIFEVideoFrameInterpolationPipeline' ], diff --git a/modelscope/pipelines/cv/image_local_feature_matching_pipeline.py b/modelscope/pipelines/cv/image_local_feature_matching_pipeline.py new file mode 100644 index 00000000..81fc60d0 --- /dev/null +++ b/modelscope/pipelines/cv/image_local_feature_matching_pipeline.py @@ -0,0 +1,121 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL +import torch + +from modelscope.metainfo import Pipelines +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_local_feature_matching, + module_name=Pipelines.image_local_feature_matching) +class ImageLocalFeatureMatchingPipeline(Pipeline): + r""" Image Local Feature Matching Pipeline. 
diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py
index 989af094..4d74e965 100644
--- a/modelscope/pipelines/cv/__init__.py
+++ b/modelscope/pipelines/cv/__init__.py
@@ -118,6 +118,7 @@ if TYPE_CHECKING:
     from .text_to_360panorama_image_pipeline import Text2360PanoramaImagePipeline
     from .human3d_render_pipeline import Human3DRenderPipeline
     from .human3d_animation_pipeline import Human3DAnimationPipeline
+    from .image_local_feature_matching_pipeline import ImageLocalFeatureMatchingPipeline
     from .rife_video_frame_interpolation_pipeline import RIFEVideoFrameInterpolationPipeline
     from .anydoor_pipeline import AnydoorPipeline
 else:
@@ -295,6 +296,7 @@ else:
         ],
         'human3d_render_pipeline': ['Human3DRenderPipeline'],
         'human3d_animation_pipeline': ['Human3DAnimationPipeline'],
+        'image_local_feature_matching_pipeline': ['ImageLocalFeatureMatchingPipeline'],
         'rife_video_frame_interpolation_pipeline': [
             'RIFEVideoFrameInterpolationPipeline'
         ],
diff --git a/modelscope/pipelines/cv/image_local_feature_matching_pipeline.py b/modelscope/pipelines/cv/image_local_feature_matching_pipeline.py
new file mode 100644
index 00000000..81fc60d0
--- /dev/null
+++ b/modelscope/pipelines/cv/image_local_feature_matching_pipeline.py
@@ -0,0 +1,121 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+from typing import Any, Dict, Union
+
+import cv2
+import numpy as np
+import PIL
+import torch
+
+from modelscope.metainfo import Pipelines
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines.base import Input, Model, Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import LoadImage
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+
+@PIPELINES.register_module(
+    Tasks.image_local_feature_matching,
+    module_name=Pipelines.image_local_feature_matching)
+class ImageLocalFeatureMatchingPipeline(Pipeline):
+    r""" Image Local Feature Matching Pipeline.
+
+    Examples:
+
+    >>> from modelscope.pipelines import pipeline
+
+    >>> matcher = pipeline(Tasks.image_local_feature_matching, model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')
+    >>> matcher([['https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching1.jpg','https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matching2.jpg']])
+    >>> [{
+    >>>     'matches': [array([[720.5    , 187.8    ],
+    >>>        [707.4    , 198.23334],
+    >>>        ...,
+    >>>        [746.7    , 594.7    ],
+    >>>        [759.8    , 594.7    ]], dtype=float32),
+    >>>     array([[652.49744 ,  29.599142],
+    >>>        [639.25287 ,  45.90798 ],
+    >>>        [653.041   ,  43.399014],
+    >>>        ...,
+    >>>        [670.8787  , 547.8298  ],
+    >>>        [608.5573  , 548.97815 ],
+    >>>        [617.82574 , 548.601   ]], dtype=float32),
+    >>>     array([0.25541496, 0.2781789 , 0.20776041, ..., 0.39656195, 0.7202848 ,
+    >>>        0.37208357], dtype=float32)],
+    >>>     'output_img': array([[[255, 255, 255],
+    >>>         [255, 255, 255],
+    >>>         [255, 255, 255],
+    >>>         ...,
+    >>>         [255, 255, 255],
+    >>>         [255, 255, 255],
+    >>>         [255, 255, 255]],
+    >>>         ...,
+    >>>        [[255, 255, 255],
+    >>>         [255, 255, 255],
+    >>>         [255, 255, 255],
+    >>>         ...,
+    >>>         [255, 255, 255],
+    >>>         [255, 255, 255],
+    >>>         [255, 255, 255]]], dtype=uint8)}]
+    """
+
+    def __init__(self, model: str, **kwargs):
+        """
+        use `model` to create an image local feature matching pipeline for prediction
+        Args:
+            model: model id on modelscope hub.
+        """
+        super().__init__(model=model, **kwargs)
+
+
+    def load_image(self, img_name):
+        img = LoadImage.convert_to_ndarray(img_name).astype(np.float32)
+        img = img / 255.
+        # convert rgb to gray
+        if len(img.shape) == 3:
+            img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+        H, W = 480, 640
+        h_scale, w_scale = H / img.shape[0], W / img.shape[1]
+        img = cv2.resize(img, (W, H))
+        return img, h_scale, w_scale
+
+    def preprocess(self, input: Input):
+        assert len(input) == 2, 'input should be a list of two images'
+
+        img1, h_scale1, w_scale1 = self.load_image(input[0])
+
+        img2, h_scale2, w_scale2 = self.load_image(input[1])
+
+        img1 = torch.from_numpy(img1)[None][None].cuda().float()
+        img2 = torch.from_numpy(img2)[None][None].cuda().float()
+        return {
+            'image0': img1,
+            'image1': img2,
+            'scale_info': [h_scale1, w_scale1, h_scale2, w_scale2]
+        }
+
+    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
+        results = self.model.inference(input)
+        return results
+
+    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        results = self.model.postprocess(inputs)
+        matches = results[OutputKeys.MATCHES]
+
+        kpts0 = matches['kpts0'].cpu().numpy()
+        kpts1 = matches['kpts1'].cpu().numpy()
+        conf = matches['conf'].cpu().numpy()
+        scale_info = [v.cpu().numpy() for v in inputs['scale_info']]
+        kpts0[:, 0] = kpts0[:, 0] / scale_info[1]
+        kpts0[:, 1] = kpts0[:, 1] / scale_info[0]
+        kpts1[:, 0] = kpts1[:, 0] / scale_info[3]
+        kpts1[:, 1] = kpts1[:, 1] / scale_info[2]
+
+        outputs = {
+            OutputKeys.MATCHES: [kpts0, kpts1, conf],
+            OutputKeys.OUTPUT_IMG: results[OutputKeys.OUTPUT_IMG]
+        }
+
+        return outputs
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 562d0105..2d0030ab 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -70,6 +70,7 @@ class CVTasks(object):
     face_emotion = 'face-emotion'
     product_segmentation = 'product-segmentation'
     image_matching = 'image-matching'
+    image_local_feature_matching = 'image-local-feature-matching'
     image_quality_assessment_degradation = 'image-quality-assessment-degradation'
     crowd_counting = 'crowd-counting'
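A note on the coordinate convention in the new pipeline's preprocess/postprocess pair above: each image is resized to a fixed 480x640 network input, the per-axis scale factors are recorded, and the matched keypoints are later divided by those factors to map them back to original-image pixels. The arithmetic below is a small sanity-check sketch with made-up numbers, not code from the patch.

    import numpy as np

    # Hypothetical original image size and the fixed input size used above.
    orig_h, orig_w = 960, 1280
    H, W = 480, 640
    h_scale, w_scale = H / orig_h, W / orig_w   # 0.5, 0.5

    # A keypoint predicted on the resized 480x640 image (x, y)...
    kpt_resized = np.array([320.0, 240.0])

    # ...maps back to the original image the same way postprocess does:
    # x / w_scale, y / h_scale.
    kpt_original = np.array([kpt_resized[0] / w_scale, kpt_resized[1] / h_scale])
    print(kpt_original)  # [640. 480.]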
diff --git a/modelscope/utils/pipeline_schema.json b/modelscope/utils/pipeline_schema.json
index c1fe8c0b..b8e80ef0 100644
--- a/modelscope/utils/pipeline_schema.json
+++ b/modelscope/utils/pipeline_schema.json
@@ -1281,6 +1281,13 @@
             "type": "object"
         }
     },
+    "image-local-feature-matching": {
+        "input": {},
+        "parameters": {},
+        "output": {
+            "type": "object"
+        }
+    },
     "image-multi-view-depth-estimation": {
         "input": {},
         "parameters": {},
diff --git a/tests/pipelines/test_image_local_feature_matching.py b/tests/pipelines/test_image_local_feature_matching.py
new file mode 100644
index 00000000..84c99d01
--- /dev/null
+++ b/tests/pipelines/test_image_local_feature_matching.py
@@ -0,0 +1,39 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import unittest
+from pathlib import Path
+
+import cv2
+import matplotlib.cm as cm
+import numpy as np
+
+from modelscope.outputs import OutputKeys
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+from modelscope.utils.cv.image_utils import match_pair_visualization
+from modelscope.utils.test_utils import test_level
+
+
+class ImageLocalFeatureMatchingTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        self.task = 'image-local-feature-matching'
+        self.model_id = 'Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data'
+
+    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
+    def test_image_local_feature_matching(self):
+        input_location = [[
+            'data/test/images/image_matching1.jpg',
+            'data/test/images/image_matching2.jpg'
+        ]]
+        estimator = pipeline(Tasks.image_local_feature_matching, model=self.model_id)
+        result = estimator(input_location)
+        kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
+        vis_img = result[0][OutputKeys.OUTPUT_IMG]
+        cv2.imwrite('vis_demo.jpg', vis_img)
+
+        print('test_image_local_feature_matching DONE')
+
+
+if __name__ == '__main__':
+    unittest.main()
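For completeness, a usage sketch that mirrors the unit test above. The two local image paths and the output filename are placeholders; only the task name, model id, and output keys come from this patch.

    import cv2

    from modelscope.outputs import OutputKeys
    from modelscope.pipelines import pipeline
    from modelscope.utils.constant import Tasks

    matcher = pipeline(
        Tasks.image_local_feature_matching,
        model='Damo_XR_Lab/cv_resnet-transformer_local-feature-matching_outdoor-data')

    # The pipeline takes a list of image pairs; each pair is matched independently.
    result = matcher([['path/to/query.jpg', 'path/to/reference.jpg']])

    # Matched keypoints in original-image coordinates plus per-match confidences.
    kpts0, kpts1, conf = result[0][OutputKeys.MATCHES]
    print(f'{len(kpts0)} matches, mean confidence {conf.mean():.3f}')

    # Side-by-side visualization rendered by the model's postprocess step.
    cv2.imwrite('matches_vis.jpg', result[0][OutputKeys.OUTPUT_IMG])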