YannisK committed
Commit 9651aac • 1 Parent(s): 2f370fd

temp state

Files changed (4)
  1. README.md +11 -1
  2. app.py +36 -2
  3. fire_network.py +130 -0
  4. lit.py +92 -0
README.md CHANGED
@@ -1 +1,11 @@
- TBD
+ ---
+ title: SuperFeatures
+ emoji: 📚
+ colorFrom: red
+ colorTo: yellow
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ ---
+ # Learning Super-Features for Image Retrieval
+ A demo for the ICLR 2022 paper "Learning Super-Features for Image Retrieval". [[Paper](https://openreview.net/pdf?id=wogsFPHwftY)] [[Official GitHub Repo](https://github.com/naver/fire)]
app.py CHANGED
@@ -3,6 +3,40 @@ import gradio as gr
 def greet(name):
     return "Hello " + name + "!!"
 
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
 
+ # Model to use
+ net_path = 'fire.pth'
+
+ # CPU / GPU
+ device = 'cpu'
+
+ # Images will be downscaled to this size prior to processing with the network
+ image_size = 1024
+
+ # Wrapper
+ def generate_matching_superfeatures(im1, im2, scale=6):
+
+     # Possible scales for multiscale inference
+     scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
+
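+     # Placeholder body (a hedged sketch, not the paper's matching logic):
+     # the Super-feature extraction/matching is not implemented yet in this
+     # "temp state" commit, so return a simple side-by-side figure so the
+     # "plot" output below receives a matplotlib figure instead of None.
+     import matplotlib.pyplot as plt
+     fig, axes = plt.subplots(1, 2)
+     axes[0].imshow(im1); axes[0].axis('off')
+     axes[1].imshow(im2); axes[1].axis('off')
+     fig.suptitle("scale factor: {}".format(scales[scale - 1]))
+     return fig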
+
+ # GRADIO APP
+ title = "Visualizing Super-features"
+ description = "TBD"
+ article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
+
+ iface = gr.Interface(
+     fn=generate_matching_superfeatures,
+     inputs=[
+         gr.inputs.Image(shape=(240, 240), type="pil"),
+         gr.inputs.Image(shape=(240, 240), type="pil"),
+         gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale")],
+     outputs="plot",
+     enable_queue=True,
+     title=title,
+     description=description,
+     article=article,
+     examples=[["chateau_1.png", "chateau_2.png", 6]],
+ )
+ iface.launch()
fire_network.py ADDED
@@ -0,0 +1,130 @@
+ # Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ import os
+ import torch
+ from torch import nn
+ import torchvision
+
+ from cirtorch.networks import imageretrievalnet
+
+ from how import layers
+ from how.layers import functional as HF
+
+ from lit import LocalfeatureIntegrationTransformer
+
+ from how.networks.how_net import HOWNet, CORERCF_SIZE
+
+ class FIReNet(HOWNet):
+
+     def __init__(self, features, attention, lit, dim_reduction, meta, runtime):
+         super().__init__(features, attention, None, dim_reduction, meta, runtime)
+         self.lit = lit
+         self.return_global = False
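+         # return_global toggles forward() between the global descriptor and the raw SuperFeatures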
+
+     def copy_excluding_dim_reduction(self):
+         """Return a copy of this network without the dim_reduction layer"""
+         meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
+         return self.__class__(self.features, self.attention, self.lit, None, meta, self.runtime)
+
+     def copy_with_runtime(self, runtime):
+         """Return a copy of this network with a different runtime dict"""
+         return self.__class__(self.features, self.attention, self.lit, self.dim_reduction, self.meta, runtime)
+
+     def parameter_groups(self):
+         """Return torch parameter groups"""
+         layers = [self.features, self.attention, self.smoothing, self.lit]
+         parameters = [{'params': x.parameters()} for x in layers if x is not None]
+         if self.dim_reduction:
+             # Do not update dimensionality reduction layer
+             parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
+         return parameters
+
+     def get_superfeatures(self, x, *, scales):
+         """
+         Return three lists (features, attention maps, strengths), each with one entry per requested scale:
+         each features entry is a tensor BxDxNx1
+         each attention maps entry is a tensor BxNxHxW
+         each strengths entry is the attention strength of each SuperFeature
+         """
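+         # Per scale: resize the input, run the CNN backbone, map the local
+         # features to N SuperFeatures with the LIT module, and score each
+         # SuperFeature with the L2 attention layer.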
+         feats = []
+         attns = []
+         strengths = []
+         for s in scales:
+             xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
+             o = self.features(xs)
+             o, attn = self.lit(o)
+             strength = self.attention(o)
+             if self.smoothing:
+                 o = self.smoothing(o)
+             if self.dim_reduction:
+                 o = self.dim_reduction(o)
+             feats.append(o)
+             attns.append(attn)
+             strengths.append(strength)
+         return feats, attns, strengths
+
+     def forward(self, x):
+         if self.return_global:
+             return self.forward_global(x, scales=self.runtime['training_scales'])
+         return self.get_superfeatures(x, scales=self.runtime['training_scales'])
+
+     def forward_global(self, x, *, scales):
+         """Return global descriptor"""
+         feats, _, strengths = self.get_superfeatures(x, scales=scales)
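+         # Aggregate across scales with weighted sum-pooling (SPoC), using the
+         # attention strengths as weights.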
+         return HF.weighted_spoc(feats, strengths)
+
+     def forward_local(self, x, *, features_num, scales):
+         """Return selected super features"""
+         feats, _, strengths = self.get_superfeatures(x, scales=scales)
+         return HF.how_select_local(feats, strengths, scales=scales, features_num=features_num)
+
+ def init_network(architecture, pretrained, skip_layer, dim_reduction, lit, runtime):
+     """Initialize FIRe network
+     :param str architecture: Network backbone architecture (e.g. resnet18)
+     :param str pretrained: path to a pretrained checkpoint (None for random initialization)
+     :param int skip_layer: How many layers of blocks should be skipped (from the end)
+     :param dict dim_reduction: Options for the dimensionality reduction layer
+     :param dict lit: Options for the lit layer
+     :param dict runtime: Runtime options to be stored in the network
+     :return FIRe: Initialized network
+     """
+     # Take the convolutional layers as features, always ending with ReLU so the last activations are non-negative
+     net_in = getattr(torchvision.models, architecture)(pretrained=False)  # torchvision weights are not used; the trained weights (including the LIT module) are loaded from the checkpoint below
+     if architecture.startswith('alexnet') or architecture.startswith('vgg'):
+         features = list(net_in.features.children())[:-1]
+     elif architecture.startswith('resnet'):
+         features = list(net_in.children())[:-2]
+     elif architecture.startswith('densenet'):
+         features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
+     elif architecture.startswith('squeezenet'):
+         features = list(net_in.features.children())
+     else:
+         raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
+
+     if skip_layer > 0:
+         features = features[:-skip_layer]
+     backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)
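+     # each skipped block from the end halves the backbone's channel width, hence the 2**skip_layer division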
+
+     att_layer = layers.attention.L2Attention()
+
+     lit_layer = LocalfeatureIntegrationTransformer(**lit, input_dim=backbone_dim)
+
+     reduction_layer = None
+     if dim_reduction:
+         reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=lit['dim'])
+
+     meta = {
+         "architecture": architecture,
+         "backbone_dim": lit['dim'],
+         "outputdim": reduction_layer.out_channels if dim_reduction else lit['dim'],
+         "corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
+     }
+     net = FIReNet(nn.Sequential(*features), att_layer, lit_layer, reduction_layer, meta, runtime)
+
+     if pretrained is not None:
+         assert os.path.isfile(pretrained), pretrained
+         ckpt = torch.load(pretrained, map_location='cpu')
+         missing, unexpected = net.load_state_dict(ckpt['state_dict'], strict=False)
+         assert all(['dim_reduction' in a for a in missing]), "Loading did not go well"
+         assert all(['fc' in a for a in unexpected]), "Loading did not go well"
+     return net
lit.py ADDED
@@ -0,0 +1,92 @@
+ # Copyright (C) 2021-2022 Naver Corporation. All rights reserved.
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
+
+ import torch
+ from torch import nn
+
+ class LocalfeatureIntegrationTransformer(nn.Module):
+     """Map a set of local features to a fixed number of SuperFeatures"""
+
+     def __init__(self, T, N, input_dim, dim):
+         """
+         T: number of iterations
+         N: number of SuperFeatures
+         input_dim: dimension of input local features
+         dim: dimension of SuperFeatures
+         """
+         super().__init__()
+         self.T = T
+         self.N = N
+         self.input_dim = input_dim
+         self.dim = dim
+         # learnable initialization
+         self.templates_init = nn.Parameter(torch.randn(1, self.N, dim))
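+         # the N templates act as learned queries, shared across images and refined over T iterations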
+         # qkv projections
+         self.project_q = nn.Linear(dim, dim, bias=False)
+         self.project_k = nn.Linear(input_dim, dim, bias=False)
+         self.project_v = nn.Linear(input_dim, dim, bias=False)
+         # layer norms
+         self.norm_inputs = nn.LayerNorm(input_dim)
+         self.norm_templates = nn.LayerNorm(dim)
+         # for the attention normalization
+         self.softmax = nn.Softmax(dim=-1)
+         self.scale = dim ** -0.5
+         # mlp
+         self.norm_mlp = nn.LayerNorm(dim)
+         mlp_dim = dim // 2
+         self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.ReLU(), nn.Linear(mlp_dim, dim))
+
+     def forward(self, x):
+         """
+         input:
+             x has shape BxCxHxW
+         output:
+             templates (output SuperFeatures): tensor of shape BxDxNx1, with D = dim
+             attn (attention over local features at the last iteration): tensor of shape BxNxHxW
+         """
+         # reshape inputs from BxCxHxW to Bx(H*W)xC
+         B, C, H, W = x.size()
+         x = x.reshape(B, C, H * W).permute(0, 2, 1)
+
+         # k and v projections
+         x = self.norm_inputs(x)
+         k = self.project_k(x)
+         v = self.project_v(x)
+
+         # template initialization
+         templates = torch.repeat_interleave(self.templates_init, B, dim=0)
+         attn = None
+
+         # main iteration loop
+         for _ in range(self.T):
+             templates_prev = templates
+
+             # q projection
+             templates = self.norm_templates(templates)
+             q = self.project_q(templates)
+
+             # attention
+             q = q * self.scale  # normalization
+             attn_logits = torch.einsum('bnd,bld->bln', q, k)
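+             # attn_logits has shape BxLxN (L = H*W locations, N = templates):
+             # the softmax assigns each location across the templates, then the
+             # L1 normalization below makes each template's attention sum to 1
+             # over locations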
+             attn = self.softmax(attn_logits)
+             attn = attn + 1e-8  # avoid exact zeros before the L1 normalization below
+             attn = attn / attn.sum(dim=-2, keepdim=True)
+
+             # update templates
+             templates = templates_prev + torch.einsum('bld,bln->bnd', v, attn)
+
+             # mlp
+             templates = templates + self.mlp(self.norm_mlp(templates))
+
+         # reshape templates to BxDxNx1
+         templates = templates.permute(0, 2, 1)[:, :, :, None]
+         attn = attn.permute(0, 2, 1).view(B, self.N, H, W)
+
+         return templates, attn
+
+     def __repr__(self):
+         s = str(self.__class__.__name__)
+         for k in ["T", "N", "input_dim", "dim"]:
+             s += "\n {:s}: {:d}".format(k, getattr(self, k))
+         return s
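
A minimal sketch of exercising the new `LocalfeatureIntegrationTransformer`, with hypothetical hyper-parameters (the real values ship with the `fire.pth` checkpoint config, not this commit):

```python
import torch
from lit import LocalfeatureIntegrationTransformer

# Hypothetical hyper-parameters, for illustration only
lit = LocalfeatureIntegrationTransformer(T=3, N=256, input_dim=1024, dim=1024)

x = torch.randn(2, 1024, 32, 32)   # BxCxHxW local feature map from the backbone
superfeatures, attn = lit(x)

print(superfeatures.shape)  # torch.Size([2, 1024, 256, 1])  -> BxDxNx1
print(attn.shape)           # torch.Size([2, 256, 32, 32])   -> BxNxHxW
```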