YannisK commited on
Commit
b3806d2
1 Parent(s): 371f74f
Files changed (2) hide show
  1. app.py +15 -7
  2. fire_imagenet.pth +3 -0
app.py CHANGED
@@ -27,9 +27,13 @@ device = 'cpu'
27
  # Load net
28
  state = torch.load('fire.pth', map_location='cpu')
29
  state['net_params']['pretrained'] = None # no need for imagenet pretrained model
30
- net = fire_network.init_network(**state['net_params']).to(device)
31
- net.load_state_dict(state['state_dict'])
32
 
 
 
 
 
33
 
34
  # ---------------------------------------
35
  transform = transforms.Compose([
@@ -86,10 +90,13 @@ def match(query_feat, pos_feat, LoweRatioTh=0.9):
86
 
87
  col = plt.get_cmap('tab10')
88
 
89
- def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='', only_matching=True):
90
  print('im1:', im1.size)
91
  print('im2:', im2.size)
92
 
 
 
 
93
 
94
  # dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
95
  # loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
@@ -256,8 +263,9 @@ iface = gr.Interface(
256
  # gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
257
  gr.inputs.Image(type="pil", label="First Image"),
258
  gr.inputs.Image(type="pil", label="Second Image"),
 
259
  gr.inputs.Slider(minimum=0, maximum=6, step=1, default=2, label="Scale"),
260
- gr.inputs.Slider(minimum=0, maximum=255, step=25, default=100, label="Binarization Threshold"),
261
  gr.inputs.Textbox(lines=1, default="", label="SF IDs to show (comma separated numbers from 0-255; typing 'rX' will return X random SFs", optional=True),
262
  gr.inputs.Checkbox(default=True, label="Show only matching SFs", optional=False),
263
  ],
@@ -273,9 +281,9 @@ iface = gr.Interface(
273
  article=article,
274
  css=css,
275
  examples=[
276
- ["chateau_1.png", "chateau_2.png", 2, 100, '55,14,5,4,52,57,40,9', True],
277
- ["anafi1.jpeg", "anafi2.jpeg", 4, 150, '51,141,185,99,', True],
278
- ["areopoli1.jpeg", "areopoli2.jpeg", 4, 50, '72,44,142,213,236', True],
279
  ]
280
  )
281
  iface.launch(enable_queue=True)
27
  # Load net
28
  state = torch.load('fire.pth', map_location='cpu')
29
  state['net_params']['pretrained'] = None # no need for imagenet pretrained model
30
+ net_sfm = fire_network.init_network(**state['net_params']).to(device)
31
+ net_sfm.load_state_dict(state['state_dict'])
32
 
33
+ state2 = torch.load('fire_imagenet.pth', map_location='cpu')
34
+ state2['net_params']['pretrained'] = None # no need for imagenet pretrained model
35
+ net_imagenet = fire_network.init_network(**state2['net_params']).to(device)
36
+ net_imagenet.load_state_dict(state2['state_dict'])
37
 
38
  # ---------------------------------------
39
  transform = transforms.Compose([
90
 
91
  col = plt.get_cmap('tab10')
92
 
93
+ def generate_matching_superfeatures(im1, im2, model_fn, scale_id=6, threshold=50, sf_ids='', only_matching=True):
94
  print('im1:', im1.size)
95
  print('im2:', im2.size)
96
 
97
+ net = net_sfm
98
+ if model_fn == "ImageNet":
99
+ net = net_imagenet
100
 
101
  # dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
102
  # loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
263
  # gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
264
  gr.inputs.Image(type="pil", label="First Image"),
265
  gr.inputs.Image(type="pil", label="Second Image"),
266
+ gr.inputs.Radio(["ImageNet", "SfM-120k (landmarks)"], label="Model", optional=False),
267
  gr.inputs.Slider(minimum=0, maximum=6, step=1, default=2, label="Scale"),
268
+ gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
269
  gr.inputs.Textbox(lines=1, default="", label="SF IDs to show (comma separated numbers from 0-255; typing 'rX' will return X random SFs", optional=True),
270
  gr.inputs.Checkbox(default=True, label="Show only matching SFs", optional=False),
271
  ],
281
  article=article,
282
  css=css,
283
  examples=[
284
+ ["chateau_1.png", "chateau_2.png", 3, 150, 'r8', True],
285
+ ["anafi1.jpeg", "anafi2.jpeg", 4, 150, 'r8', True],
286
+ ["areopoli1.jpeg", "areopoli2.jpeg", 4, 150, 'r8', True],
287
  ]
288
  )
289
  iface.launch(enable_queue=True)
fire_imagenet.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e05eb75bd4155cd33d3d6fc063fb658b412fff62b73484f86842fc5df5f17b52
3
+ size 60557200