temp state
README.md
CHANGED
@@ -1 +1,11 @@
-
+---
+title: SuperFeatures
+emoji: π
+colorFrom: red
+colorTo: yellow
+sdk: gradio
+app_file: app.py
+pinned: false
+---
+# Learning Super-Features for Image Retrieval
+A demo for the ICLR 22 paper "Learning Super-Features for Image Retrieval". [[Paper](https://openreview.net/pdf?id=wogsFPHwftY)] [[Official Github Repo](https://github.com/naver/fire)]
app.py
CHANGED
@@ -3,6 +3,40 @@ import gradio as gr
 def greet(name):
     return "Hello " + name + "!!"
 
-iface = gr.Interface(fn=greet, inputs="text", outputs="text")
-iface.launch()
 
+# Model to use
+net_path = 'fire.pth'
+
+# CPU / GPU
+device = 'cpu'
+
+# Images will be downscaled to this size prior to processing with the network
+image_size = 1024
+
+# Wrapper
+def generate_matching_superfeatures(im1, im2, scale=6):
+
+    # Possible scales for multiscale inference
+    scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
+
+
+# GRADIO APP
+title = "Visualizing Super-features"
+description = "TBD"
+article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
+
+
+iface = gr.Interface(
+    fn=generate_matching_superfeatures,
+    inputs=[
+        gr.inputs.Image(shape=(240, 240), type="pil"),
+        gr.inputs.Image(shape=(240, 240), type="pil"),
+        gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale")],
+    outputs="plot",
+    enable_queue=True,
+    title=title,
+    description=description,
+    article=article,
+    examples=[["chateau_1.png", "chateau_2.png", 6]],
+)
+iface.launch()
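In this temporary state the wrapper body only lists the candidate scales and the model is never loaded. Below is a minimal sketch of how the wrapper might be completed: it assumes a `FIReNet` instance `net` already built with `fire_network.init_network` (see the sketch after `fire_network.py`) and moved to `device`, it only visualizes the LIT attention maps rather than the final super-feature matching, and its preprocessing choices are placeholders, not the Space's eventual code.

```python
# Hedged sketch only: assumes `net`, `device`, and `image_size` from app.py,
# and that `net` is a FIReNet in eval mode. Preprocessing is a placeholder.
import matplotlib.pyplot as plt
import torch
import torchvision.transforms as T

preprocess = T.Compose([T.Resize(image_size), T.ToTensor()])  # how image_size is applied is an assumption

def generate_matching_superfeatures(im1, im2, scale=6):
    """Plot the mean LIT attention map at the selected scale for both input images."""
    scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
    fig, axes = plt.subplots(1, 2)
    with torch.no_grad():
        for ax, im in zip(axes, (im1, im2)):
            x = preprocess(im).unsqueeze(0).to(device)   # inputs arrive as PIL images (type="pil")
            # get_superfeatures returns one entry per requested scale; here a single
            # scale picked by the slider (1..7 indexes the `scales` list)
            _, attns, _ = net.get_superfeatures(x, scales=[scales[scale - 1]])
            ax.imshow(attns[0][0].mean(dim=0).cpu().numpy())
            ax.axis('off')
    return fig   # matches outputs="plot" in gr.Interface
```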
fire_network.py
ADDED
@@ -0,0 +1,130 @@
+# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import os
+import torch
+from torch import nn
+import torchvision
+
+from cirtorch.networks import imageretrievalnet
+
+from how import layers
+from how.layers import functional as HF
+
+from lit import LocalfeatureIntegrationTransformer
+
+from how.networks.how_net import HOWNet, CORERCF_SIZE
+
+class FIReNet(HOWNet):
+
+    def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
+        super().__init__(features, attention, None, dim_reduction, meta, runtime)
+        self.lit = lit
+        self.return_global = False
+
+    def copy_excluding_dim_reduction(self):
+        """Return a copy of this network without the dim_reduction layer"""
+        meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
+        return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)
+
+    def copy_with_runtime(self, runtime):
+        """Return a copy of this network with a different runtime dict"""
+        return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)
+
+    def parameter_groups(self):
+        """Return torch parameter groups"""
+        layers = [self.features, self.attention, self.smoothing, self.lit]
+        parameters = [{'params': x.parameters()} for x in layers if x is not None]
+        if self.dim_reduction:
+            # Do not update the dimensionality reduction layer
+            parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
+        return parameters
+
+    def get_superfeatures(self, x, *, scales):
+        """
+        Return a tuple of lists (features, attention_maps, strengths), one entry per requested scale:
+        features is a tensor BxDxNx1
+        attention_maps is a tensor BxNxHxW
+        """
+        feats = []
+        attns = []
+        strengths = []
+        for s in scales:
+            xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
+            o = self.features(xs)
+            o, attn = self.lit(o)
+            strength = self.attention(o)
+            if self.smoothing:
+                o = self.smoothing(o)
+            if self.dim_reduction:
+                o = self.dim_reduction(o)
+            feats.append(o)
+            attns.append(attn)
+            strengths.append(strength)
+        return feats, attns, strengths
+
+    def forward(self, x):
+        if self.return_global:
+            return self.forward_global(x, scales=self.runtime['training_scales'])
+        return self.get_superfeatures(x, scales=self.runtime['training_scales'])
+
+    def forward_global(self, x, *, scales):
+        """Return global descriptor"""
+        feats, _, strengths = self.get_superfeatures(x, scales=scales)
+        return HF.weighted_spoc(feats, strengths)
+
+    def forward_local(self, x, *, features_num, scales):
+        """Return selected super features"""
+        feats, _, strengths = self.get_superfeatures(x, scales=scales)
+        return HF.how_select_local(feats, strengths, scales=scales, features_num=features_num)
+
+def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
+    """Initialize the FIRe network
+    :param str architecture: Network backbone architecture (e.g. resnet18)
+    :param str pretrained: Path to the pretrained checkpoint (None for random initialization)
+    :param int skip_layer: How many layers of blocks should be skipped (from the end)
+    :param dict dim_reduction: Options for the dimensionality reduction layer
+    :param dict lit: Options for the LIT layer
+    :param dict runtime: Runtime options to be stored in the network
+    :return FIReNet: Initialized network
+    """
+    # Take convolutional layers as features, always ending with ReLU so the last activations are non-negative
+    net_in = getattr(torchvision.models, architecture)(pretrained=False)  # trained weights (including the LIT module) are loaded below instead
+    if architecture.startswith('alexnet') or architecture.startswith('vgg'):
+        features = list(net_in.features.children())[:-1]
+    elif architecture.startswith('resnet'):
+        features = list(net_in.children())[:-2]
+    elif architecture.startswith('densenet'):
+        features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
+    elif architecture.startswith('squeezenet'):
+        features = list(net_in.features.children())
+    else:
+        raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
+
+    if skip_layer > 0:
+        features = features[:-skip_layer]
+    backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)
+
+    att_layer = layers.attention.L2Attention()
+
+    lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)
+
+    reduction_layer = None
+    if dim_reduction:
+        reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])
+
+    meta = {
+        "architecture": architecture,
+        "backbone_dim": lit['dim'],
+        "outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
+        "corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
+    }
+    net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)
+
+    if pretrained is not None:
+        assert os.path.isfile(pretrained), pretrained
+        ckpt = torch.load(pretrained, map_location='cpu')
+        missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
+        assert all(['dim_reduction' in a for a in missing]), "Loading did not go well"
+        assert all(['fc' in a for a in unexpected]), "Loading did not go well"
+    return net
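The Space only sets `net_path = 'fire.pth'` so far; wiring the checkpoint into a network is left for a later commit. A hedged sketch of how `init_network` could be called is below. The backbone and LIT option values are illustrative assumptions, they must match whatever `fire.pth` was trained with, and the `how`/`cirtorch` dependencies imported by `fire_network.py` are assumed to be installed.

```python
# Hypothetical configuration; the real option values have to match the fire.pth checkpoint.
import torch
from fire_network import init_network

net = init_network(
    architecture='resnet50',                 # backbone (assumed)
    pretrained='fire.pth',                   # checkpoint shipped with the Space
    skip_layer=0,
    dim_reduction=None,                      # whether the checkpoint uses a reduction layer is an assumption
    lit={'T': 4, 'N': 256, 'dim': 1024},     # LIT options (assumed)
    runtime={'training_scales': [1.0]},
)
net.eval()

x = torch.rand(1, 3, 512, 512)  # dummy image batch
with torch.no_grad():
    # Multiscale super-features: three lists with one entry per requested scale
    feats, attns, strengths = net.get_superfeatures(x, scales=[1.0, 0.707])
    # Single global descriptor obtained by weighted pooling of the super-features
    g = net.forward_global(x, scales=[1.0, 0.707])
```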
lit.py
ADDED
@@ -0,0 +1,92 @@
+# Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
+# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+import torch
+from torch import nn
+
+class LocalfeatureIntegrationTransformer(nn.Module):
+    """Map a set of local features to a fixed number of SuperFeatures"""
+
+    def __init__(self, T, N, input_dim, dim):
+        """
+        T: number of iterations
+        N: number of SuperFeatures
+        input_dim: dimension of input local features
+        dim: dimension of SuperFeatures
+        """
+        super().__init__()
+        self.T = T
+        self.N = N
+        self.input_dim = input_dim
+        self.dim = dim
+        # learnable initialization
+        self.templates_init = nn.Parameter(torch.randn(1, self.N, dim))
+        # qkv
+        self.project_q = nn.Linear(dim, dim, bias=False)
+        self.project_k = nn.Linear(input_dim, dim, bias=False)
+        self.project_v = nn.Linear(input_dim, dim, bias=False)
+        # layer norms
+        self.norm_inputs = nn.LayerNorm(input_dim)
+        self.norm_templates = nn.LayerNorm(dim)
+        # for the normalization
+        self.softmax = nn.Softmax(dim=-1)
+        self.scale = dim ** -0.5
+        # mlp
+        self.norm_mlp = nn.LayerNorm(dim)
+        mlp_dim = dim // 2
+        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
+
+
+    def forward(self, x):
+        """
+        input:
+            x has shape BxCxHxW
+        output:
+            templates (output SuperFeatures): tensor of shape BxDxNx1
+            attn (attention over local features at the last iteration): tensor of shape BxNxHxW
+        """
+        # reshape inputs from BxCxHxW to Bx(H*W)xC
+        B, C, H, W = x.size()
+        x = x.reshape(B, C, H * W).permute(0, 2, 1)
+
+        # k and v projections
+        x = self.norm_inputs(x)
+        k = self.project_k(x)
+        v = self.project_v(x)
+
+        # template initialization
+        templates = torch.repeat_interleave(self.templates_init, B, dim=0)
+        attn = None
+
+        # main iteration loop
+        for _ in range(self.T):
+            templates_prev = templates
+
+            # q projection
+            templates = self.norm_templates(templates)
+            q = self.project_q(templates)
+
+            # attention
+            q = q * self.scale  # normalization
+            attn_logits = torch.einsum('bnd,bld->bln', q, k)
+            attn = self.softmax(attn_logits)
+            attn = attn + 1e-8  # to avoid zeros with the L1 normalization below
+            attn = attn / attn.sum(dim=-2, keepdim=True)
+
+            # update templates
+            templates = templates_prev + torch.einsum('bld,bln->bnd', v, attn)
+
+            # mlp
+            templates = templates + self.mlp(self.norm_mlp(templates))
+
+        # reshape templates to BxDxNx1
+        templates = templates.permute(0, 2, 1)[:, :, :, None]
+        attn = attn.permute(0, 2, 1).view(B, self.N, H, W)
+
+        return templates, attn
+
+    def __repr__(self):
+        s = str(self.__class__.__name__)
+        for k in ["T", "N", "input_dim", "dim"]:
+            s += "\n {:s}: {:d}".format(k, getattr(self, k))
+        return s
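To make the LIT shape flow concrete, here is a small standalone sanity check of the module; the sizes (`T=4`, `N=256`, `input_dim=2048` for a ResNet50 trunk, `dim=1024`) are illustrative assumptions, not the values stored in the released checkpoint.

```python
# Illustrative shape check for LocalfeatureIntegrationTransformer.
# The chosen sizes are assumptions; only the shape relations are the point.
import torch
from lit import LocalfeatureIntegrationTransformer

lit = LocalfeatureIntegrationTransformer(T=4, N=256, input_dim=2048, dim=1024)

# A batch of 2 feature maps, e.g. a ResNet50 trunk output on ~512px images
x = torch.rand(2, 2048, 16, 16)

templates, attn = lit(x)
print(templates.shape)  # torch.Size([2, 1024, 256, 1])  -> B x dim x N x 1
print(attn.shape)       # torch.Size([2, 256, 16, 16])   -> B x N x H x W

# Each of the N attention maps is L1-normalized over the H*W locations,
# so the attention for one SuperFeature sums to ~1.
print(attn[0, 0].sum())
```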