File size: 2,097 Bytes
bb3ea39
 
f4b82b2
bb3ea39
f4b82b2
9651aac
c9dadbf
9651aac
f4b82b2
c9dadbf
f4b82b2
 
 
 
 
 
 
c9dadbf
f4b82b2
 
 
 
 
9651aac
c9dadbf
 
 
 
 
 
4e8ced7
9651aac
c9dadbf
 
f4b82b2
c9dadbf
 
f4b82b2
c9dadbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
1a2db09
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
 
1a2db09
 
 
 
 
f4b82b2
1a2db09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

import fire_network

# Possible scales for multiscale inference
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load net
# NOTE: torch.load unpickles the checkpoint, so 'fire.pth' must come from a
# trusted source.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# The network is only used for inference (under torch.no_grad()), so switch
# any train-time layers (e.g. batch norm, dropout) to evaluation behaviour.
net.eval()

# Preprocessing applied to every input image: resize, convert to a tensor and
# normalize with the (mean, std) pair stored on the network.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std']))),
])


# NOTE(review): presumably the indices of the super-features to visualize;
# not referenced anywhere else in this file yet.
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]


# Matplotlib 'tab10' colormap, kept at module level for the plotting code.
col = plt.get_cmap('tab10')

def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
    """Extract FIRE super-features for two images at a single scale.

    Work in progress: the features, attention maps and strengths are
    extracted and their shapes printed, but no matching or plotting is
    implemented yet, so the function currently returns ``None``.

    Args:
        im1: First image, as a PIL image (the Gradio inputs use ``type="pil"``).
        im2: Second image, same format as ``im1``.
        scale_id: Zero-based index into the module-level ``scales`` list.
        threshold: Binarization threshold coming from the UI; not used yet.

    Raises:
        ValueError: If ``scale_id`` is not a valid index into ``scales``.
    """
    # Sliders may deliver floats; list indexing needs an int.
    scale_idx = int(scale_id)
    if not 0 <= scale_idx < len(scales):
        raise ValueError(
            f"scale_id must be in [0, {len(scales) - 1}], got {scale_id!r}"
        )
    scale = scales[scale_idx]

    # The network must be fed the preprocessed tensors, not the raw PIL
    # images. transform() yields a single unbatched image, so a leading batch
    # dimension is added (assumes the network expects batched input, as
    # convolutional networks usually do -- TODO confirm).
    im1_tensor = transform(im1).unsqueeze(0).to(device)
    im2_tensor = transform(im2).unsqueeze(0).to(device)

    # NOTE: the inputs are PIL images, not file paths, so cv2.imread() cannot
    # be used on them. When OpenCV arrays are needed for the visualization,
    # convert with cv2.cvtColor(numpy.asarray(im), cv2.COLOR_RGB2BGR).

    # extract features
    with torch.no_grad():
        # Per the FIRE repository, get_superfeatures returns three lists
        # (features, attention maps, strengths) with one entry per requested
        # scale; a single scale is requested, hence the [0].
        output1 = net.get_superfeatures(im1_tensor, scales=[scale])
        feats1 = output1[0][0]
        attns1 = output1[1][0]
        strengths1 = output1[2][0]

        output2 = net.get_superfeatures(im2_tensor, scales=[scale])
        feats2 = output2[0][0]
        attns2 = output2[1][0]
        strengths2 = output2[2][0]

    # Debug output while the matching / plotting code is being written.
    print(feats1.shape)
    print(attns1.shape)
    print(strengths1.shape)

    # TODO: match feats1/feats2, binarize the attention maps with `threshold`
    # and return a matplotlib figure -- the Gradio interface declares a
    # "plot" output but nothing is returned yet.



# GRADIO APP
# Static texts shown on the demo page.
title = "Visualizing Super-features"
description = "TBD"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"


# Two image inputs plus two sliders, mapped positionally onto the parameters
# of generate_matching_superfeatures(im1, im2, scale_id, threshold).
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        # The slider value is an index into `scales`, so its range follows
        # the list (zero-based) instead of a hard-coded 1..7.
        gr.inputs.Slider(minimum=0, maximum=len(scales) - 1, step=1, default=2, label="Scale"),
        # NOTE(review): with minimum=1 and step=25 the default of 50 is not
        # on the slider grid -- confirm the intended range.
        gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarization Threshold"),
    ],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()