File size: 2,578 Bytes
a80d6bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from fvcore.nn import FlopCountAnalysis
from einops.einops import rearrange

from src import get_model_cfg
from src.models.backbone import FPN as topicfm_featnet
from src.models.modules import TopicFormer
from src.utils.dataset import read_scannet_gray

from third_party.loftr.src.loftr.utils.cvpr_ds_config import default_cfg
from third_party.loftr.src.loftr.backbone import ResNetFPN_8_2 as loftr_featnet
from third_party.loftr.src.loftr.loftr_module import LocalFeatureTransformer


def feat_net_flops(feat_net, config, input):
    """Build a feature network, run it once, and report its forward-pass cost.

    Args:
        feat_net: Callable (e.g. a model class) invoked as ``feat_net(config)``
            to construct the network.
        config: Configuration object handed to ``feat_net`` unchanged.
        input: Value passed to the network's forward call and to
            ``FlopCountAnalysis``.  The name shadows the builtin ``input`` but
            is kept so existing keyword callers keep working.

    Returns:
        A ``(feat_c, gflops)`` tuple.  ``feat_c`` is the first of the two
        values produced by the network's forward call; ``gflops`` is the total
        reported by ``FlopCountAnalysis`` divided by 1e9.
    """
    model = feat_net(config)
    model.eval()
    # This is a measurement-only pass: disabling autograd avoids recording a
    # graph for the forward call and for the analysis run, and the returned
    # features carry no gradient history into later analyses.
    with torch.no_grad():
        flops = FlopCountAnalysis(model, input)
        feat_c, _ = model(input)
        gflops = flops.total() / 1e9
    return feat_c, gflops


def coarse_model_flops(coarse_model, config, inputs):
    """Build a coarse matching model and report its forward-pass cost.

    Args:
        coarse_model: Callable (e.g. a model class) invoked as
            ``coarse_model(config)`` to construct the model.
        config: Configuration object handed to ``coarse_model`` unchanged.
        inputs: Inputs forwarded to ``FlopCountAnalysis`` together with the
            model (the caller in this file passes a tuple of two tensors).

    Returns:
        The total reported by ``FlopCountAnalysis`` divided by 1e9.
    """
    model = coarse_model(config)
    model.eval()
    # Measurement only: no gradients are needed, so skip autograd bookkeeping
    # while the analysis runs the model.
    with torch.no_grad():
        flops = FlopCountAnalysis(model, inputs)
        return flops.total() / 1e9


if __name__ == '__main__':
    # Two sample ScanNet frames.  Each is loaded with read_scannet_gray and
    # given a leading axis via unsqueeze(0) before being fed to the networks.
    sample_paths = (
        "assets/scannet_sample_images/scene0711_00_frame-001680.jpg",
        "assets/scannet_sample_images/scene0711_00_frame-001995.jpg",
    )
    img0, img1 = (read_scannet_gray(p).unsqueeze(0) for p in sample_paths)

    # Pattern used to flatten the two spatial axes of a feature map into one
    # axis and move the channel axis last, as expected by the coarse models.
    token_layout = 'n c h w -> n (h w) c'

    # ---- LoFTR ----
    loftr_conf = dict(default_cfg)
    (feat_c0, gflops0), (feat_c1, gflops1) = (
        feat_net_flops(loftr_featnet, loftr_conf["resnetfpn"], img)
        for img in (img0, img1)
    )
    # Report the mean cost over the two images.
    mean_gflops = (gflops0 + gflops1) / 2
    print(f"FLOPs of feature extraction in LoFTR: {mean_gflops} GFLOPs")
    tokens0, tokens1 = (rearrange(f, token_layout) for f in (feat_c0, feat_c1))
    coarse_gflops = coarse_model_flops(
        LocalFeatureTransformer, loftr_conf["coarse"], (tokens0, tokens1)
    )
    print(f"FLOPs of coarse matching model in LoFTR: {coarse_gflops} GFLOPs")

    # ---- TopicFM ----
    topicfm_conf = get_model_cfg()
    (feat_c0, gflops0), (feat_c1, gflops1) = (
        feat_net_flops(topicfm_featnet, topicfm_conf["fpn"], img)
        for img in (img0, img1)
    )
    mean_gflops = (gflops0 + gflops1) / 2
    print(f"FLOPs of feature extraction in TopicFM: {mean_gflops} GFLOPs")
    tokens0, tokens1 = (rearrange(f, token_layout) for f in (feat_c0, feat_c1))
    coarse_gflops = coarse_model_flops(
        TopicFormer, topicfm_conf["coarse"], (tokens0, tokens1)
    )
    print(f"FLOPs of coarse matching model in TopicFM: {coarse_gflops} GFLOPs")