Spaces:
Build error
Build error
temp state
Browse files- LICENSE +85 -0
- README.md +1 -13
- chateau_1.png +0 -0
- chateau_2.png +0 -0
- fire.pth +3 -0
- how/__init__.py +4 -0
- how/layers/__init__.py +5 -0
- how/layers/__pycache__/__init__.cpython-37.pyc +0 -0
- how/layers/__pycache__/attention.cpython-37.pyc +0 -0
- how/layers/__pycache__/dim_reduction.cpython-37.pyc +0 -0
- how/layers/__pycache__/functional.cpython-37.pyc +0 -0
- how/layers/__pycache__/pooling.cpython-37.pyc +0 -0
- how/layers/attention.py +10 -0
- how/layers/dim_reduction.py +29 -0
- how/layers/functional.py +73 -0
- how/layers/pooling.py +19 -0
- how/networks/__init__.py +5 -0
- how/networks/__pycache__/__init__.cpython-37.pyc +0 -0
- how/networks/__pycache__/how_net.cpython-37.pyc +0 -0
- how/networks/how_net.py +221 -0
- how/stages/__init__.py +5 -0
- how/stages/evaluate.py +314 -0
- how/stages/train.py +241 -0
- how/utils/__init__.py +3 -0
- how/utils/__pycache__/__init__.cpython-37.pyc +0 -0
- how/utils/__pycache__/data_helpers.cpython-37.pyc +0 -0
- how/utils/__pycache__/download.cpython-37.pyc +0 -0
- how/utils/__pycache__/html.cpython-37.pyc +0 -0
- how/utils/__pycache__/io_helpers.cpython-37.pyc +0 -0
- how/utils/__pycache__/score_helpers.cpython-37.pyc +0 -0
- how/utils/__pycache__/visualize.cpython-37.pyc +0 -0
- how/utils/__pycache__/whitening.cpython-37.pyc +0 -0
- how/utils/data_helpers.py +90 -0
- how/utils/download.py +44 -0
- how/utils/html.py +252 -0
- how/utils/io_helpers.py +105 -0
- how/utils/logging.py +63 -0
- how/utils/plots.py +37 -0
- how/utils/score_helpers.py +59 -0
- how/utils/visualize.py +99 -0
- how/utils/whitening.py +36 -0
- requirements.txt +5 -0
LICENSE
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FIRe, Copyright (c) 2021-2022 Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
|
2 |
+
|
3 |
+
A summary of the CC BY-NC-SA 4.0 license is located here:
|
4 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/
|
5 |
+
|
6 |
+
The CC BY-NC-SA 4.0 license is located here:
|
7 |
+
https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
|
8 |
+
|
9 |
+
|
10 |
+
SEE NOTICE BELOW WITH RESPECT TO THE FILE: train.py and evaluate.py
|
11 |
+
SEE NOTICE BELOW WITH RESPECT TO THE FILES in folder how/
|
12 |
+
|
13 |
+
**********************************
|
14 |
+
|
15 |
+
|
16 |
+
NOTICE WITH RESPECT TO THE FILE: train.py and evaluate.py
|
17 |
+
|
18 |
+
|
19 |
+
This software is being redistributed in a modifiled form. The original form is available here:
|
20 |
+
|
21 |
+
https://github.com/gtolias/how
|
22 |
+
|
23 |
+
|
24 |
+
ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE IS REPRODUCE BELOW:
|
25 |
+
|
26 |
+
https://github.com/gtolias/how/blob/master/LICENSE
|
27 |
+
|
28 |
+
|
29 |
+
MIT License
|
30 |
+
|
31 |
+
Copyright (c) 2020 Giorgos Tolias, Tomas Jenicek
|
32 |
+
|
33 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
34 |
+
of this software and associated documentation files (the "Software"), to deal
|
35 |
+
in the Software without restriction, including without limitation the rights
|
36 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
37 |
+
copies of the Software, and to permit persons to whom the Software is
|
38 |
+
furnished to do so, subject to the following conditions:
|
39 |
+
The above copyright notice and this permission notice shall be included in all
|
40 |
+
copies or substantial portions of the Software.
|
41 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
42 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
43 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
44 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
45 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
46 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
47 |
+
SOFTWARE.
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
**********************************
|
52 |
+
|
53 |
+
SEE NOTICE BELOW WITH RESPECT TO THE FILES in folder how/
|
54 |
+
|
55 |
+
This project contains subcomponents with separate copyright notices and license terms.
|
56 |
+
Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
|
57 |
+
|
58 |
+
====
|
59 |
+
|
60 |
+
gtolias/how
|
61 |
+
https://github.com/gtolias/how
|
62 |
+
|
63 |
+
MIT License
|
64 |
+
|
65 |
+
Copyright (c) 2020 Giorgos Tolias, Tomas Jenicek
|
66 |
+
|
67 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
68 |
+
of this software and associated documentation files (the "Software"), to deal
|
69 |
+
in the Software without restriction, including without limitation the rights
|
70 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
71 |
+
copies of the Software, and to permit persons to whom the Software is
|
72 |
+
furnished to do so, subject to the following conditions:
|
73 |
+
|
74 |
+
The above copyright notice and this permission notice shall be included in all
|
75 |
+
copies or substantial portions of the Software.
|
76 |
+
|
77 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
78 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
79 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
80 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
81 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
82 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
83 |
+
SOFTWARE.
|
84 |
+
|
85 |
+
====
|
README.md
CHANGED
@@ -1,13 +1 @@
|
|
1 |
-
|
2 |
-
title: Superfeatures
|
3 |
-
emoji: 🏢
|
4 |
-
colorFrom: blue
|
5 |
-
colorTo: indigo
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 2.9.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: cc-by-nc-sa-4.0
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
|
|
|
1 |
+
TBD
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
chateau_1.png
ADDED
chateau_2.png
ADDED
fire.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7ddeb04ebdd5ca3e7a9d86ce6a5dec5dabfbb23a70a6f3d0907b17e484474202
|
3 |
+
size 52765649
|
how/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Official Python implementation of HOW method for ECCV 2020 paper "Learning and aggregating deep
|
3 |
+
local descriptors for instance-level recognition"
|
4 |
+
"""
|
how/layers/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Modules implementing layers in pytorch by inheriting from torch.nn.Module
|
3 |
+
"""
|
4 |
+
|
5 |
+
from . import attention, dim_reduction, pooling
|
how/layers/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (281 Bytes). View file
|
|
how/layers/__pycache__/attention.cpython-37.pyc
ADDED
Binary file (643 Bytes). View file
|
|
how/layers/__pycache__/dim_reduction.cpython-37.pyc
ADDED
Binary file (1.46 kB). View file
|
|
how/layers/__pycache__/functional.cpython-37.pyc
ADDED
Binary file (2.89 kB). View file
|
|
how/layers/__pycache__/pooling.cpython-37.pyc
ADDED
Binary file (928 Bytes). View file
|
|
how/layers/attention.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Layers producing a 2D attention map from a feature map"""
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class L2Attention(nn.Module):
|
7 |
+
"""Compute the attention as L2-norm of local descriptors"""
|
8 |
+
|
9 |
+
def forward(self, x):
|
10 |
+
return (x.pow(2.0).sum(1) + 1e-10).sqrt().squeeze(0)
|
how/layers/dim_reduction.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Layers implementing dimensionality reduction of a feature map"""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
from ..utils import whitening
|
7 |
+
|
8 |
+
|
9 |
+
class ConvDimReduction(nn.Conv2d):
|
10 |
+
"""Dimensionality reduction as a convolutional layer
|
11 |
+
|
12 |
+
:param int input_dim: Network out_channels
|
13 |
+
:param in dim: Whitening out_channels, for dimensionality reduction
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, input_dim, dim):
|
17 |
+
super().__init__(input_dim, dim, (1, 1), padding=0, bias=True)
|
18 |
+
|
19 |
+
def initialize_pca_whitening(self, des):
|
20 |
+
"""Initialize PCA whitening from given descriptors. Return tuple of shift and projection."""
|
21 |
+
m, P = whitening.pcawhitenlearn_shrinkage(des)
|
22 |
+
m, P = m.T, P.T
|
23 |
+
|
24 |
+
projection = torch.Tensor(P[:self.weight.shape[0], :]).unsqueeze(-1).unsqueeze(-1)
|
25 |
+
self.weight.data = projection.to(self.weight.device)
|
26 |
+
|
27 |
+
projected_shift = -torch.mm(torch.FloatTensor(P), torch.FloatTensor(m)).squeeze()
|
28 |
+
self.bias.data = projected_shift[:self.weight.shape[0]].to(self.bias.device)
|
29 |
+
return m.T, P.T
|
how/layers/functional.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Layer functions"""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
import cirtorch.layers.functional as CF
|
7 |
+
|
8 |
+
|
9 |
+
def smoothing_avg_pooling(feats, kernel_size):
|
10 |
+
"""Smoothing average pooling
|
11 |
+
|
12 |
+
:param torch.Tensor feats: Feature map
|
13 |
+
:param int kernel_size: kernel size of pooling
|
14 |
+
:return torch.Tensor: Smoothend feature map
|
15 |
+
"""
|
16 |
+
pad = kernel_size // 2
|
17 |
+
return F.avg_pool2d(feats, (kernel_size, kernel_size), stride=1, padding=pad,
|
18 |
+
count_include_pad=False)
|
19 |
+
|
20 |
+
|
21 |
+
def weighted_spoc(ms_feats, ms_weights):
|
22 |
+
"""Weighted SPoC pooling, summed over scales.
|
23 |
+
|
24 |
+
:param list ms_feats: A list of feature maps, each at a different scale
|
25 |
+
:param list ms_weights: A list of weights, each at a different scale
|
26 |
+
:return torch.Tensor: L2-normalized global descriptor
|
27 |
+
"""
|
28 |
+
desc = torch.zeros((1, ms_feats[0].shape[1]), dtype=torch.float32, device=ms_feats[0].device)
|
29 |
+
for feats, weights in zip(ms_feats, ms_weights):
|
30 |
+
desc += (feats * weights).sum((-2, -1)).squeeze()
|
31 |
+
return CF.l2n(desc)
|
32 |
+
|
33 |
+
|
34 |
+
def how_select_local(ms_feats, ms_masks, *, scales, features_num):
|
35 |
+
"""Convert multi-scale feature maps with attentions to a list of local descriptors
|
36 |
+
|
37 |
+
:param list ms_feats: A list of feature maps, each at a different scale
|
38 |
+
:param list ms_masks: A list of attentions, each at a different scale
|
39 |
+
:param list scales: A list of scales (floats)
|
40 |
+
:param int features_num: Number of features to be returned (sorted by attenions)
|
41 |
+
:return tuple: A list of descriptors, attentions, locations (x_coor, y_coor) and scales where
|
42 |
+
elements from each list correspond to each other
|
43 |
+
"""
|
44 |
+
device = ms_feats[0].device
|
45 |
+
size = sum(x.shape[0] * x.shape[1] for x in ms_masks)
|
46 |
+
|
47 |
+
desc = torch.zeros(size, ms_feats[0].shape[1], dtype=torch.float32, device=device)
|
48 |
+
atts = torch.zeros(size, dtype=torch.float32, device=device)
|
49 |
+
locs = torch.zeros(size, 2, dtype=torch.int16, device=device)
|
50 |
+
scls = torch.zeros(size, dtype=torch.float16, device=device)
|
51 |
+
|
52 |
+
pointer = 0
|
53 |
+
for sc, vs, ms in zip(scales, ms_feats, ms_masks):
|
54 |
+
if len(ms.shape) == 0:
|
55 |
+
continue
|
56 |
+
|
57 |
+
height, width = ms.shape
|
58 |
+
numel = torch.numel(ms)
|
59 |
+
slc = slice(pointer, pointer+numel)
|
60 |
+
pointer += numel
|
61 |
+
|
62 |
+
desc[slc] = vs.squeeze(0).reshape(vs.shape[1], -1).T
|
63 |
+
atts[slc] = ms.reshape(-1)
|
64 |
+
width_arr = torch.arange(width, dtype=torch.int16)
|
65 |
+
locs[slc, 0] = width_arr.repeat(height).to(device) # x axis
|
66 |
+
height_arr = torch.arange(height, dtype=torch.int16)
|
67 |
+
locs[slc, 1] = height_arr.view(-1, 1).repeat(1, width).reshape(-1).to(device) # y axis
|
68 |
+
scls[slc] = sc
|
69 |
+
|
70 |
+
keep_n = min(features_num, atts.shape[0]) if features_num is not None else atts.shape[0]
|
71 |
+
idx = atts.sort(descending=True)[1][:keep_n]
|
72 |
+
|
73 |
+
return desc[idx], atts[idx], locs[idx], scls[idx]
|
how/layers/pooling.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Spatial pooling layers"""
|
2 |
+
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
from . import functional as LF
|
6 |
+
|
7 |
+
|
8 |
+
class SmoothingAvgPooling(nn.Module):
|
9 |
+
"""Average pooling that smoothens the feature map, keeping its size
|
10 |
+
|
11 |
+
:param int kernel_size: Kernel size of given pooling (e.g. 3)
|
12 |
+
"""
|
13 |
+
|
14 |
+
def __init__(self, kernel_size):
|
15 |
+
super().__init__()
|
16 |
+
self.kernel_size = kernel_size
|
17 |
+
|
18 |
+
def forward(self, x):
|
19 |
+
return LF.smoothing_avg_pooling(x, kernel_size=self.kernel_size)
|
how/networks/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Pytorch networks
|
3 |
+
"""
|
4 |
+
|
5 |
+
from . import how_net
|
how/networks/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (182 Bytes). View file
|
|
how/networks/__pycache__/how_net.cpython-37.pyc
ADDED
Binary file (8.49 kB). View file
|
|
how/networks/how_net.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Module of the HOW method"""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision
|
7 |
+
|
8 |
+
from cirtorch.networks import imageretrievalnet
|
9 |
+
|
10 |
+
from .. import layers
|
11 |
+
from ..layers import functional as HF
|
12 |
+
from ..utils import io_helpers
|
13 |
+
|
14 |
+
NUM_WORKERS = 6
|
15 |
+
|
16 |
+
CORERCF_SIZE = {
|
17 |
+
'resnet18': 32,
|
18 |
+
'resnet50': 32,
|
19 |
+
'resnet101': 32,
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
class HOWNet(nn.Module):
|
24 |
+
"""Network for the HOW method
|
25 |
+
|
26 |
+
:param list features: A list of torch.nn.Module which act as feature extractor
|
27 |
+
:param torch.nn.Module attention: Attention layer
|
28 |
+
:param torch.nn.Module smoothing: Smoothing layer
|
29 |
+
:param torch.nn.Module dim_reduction: Dimensionality reduction layer
|
30 |
+
:param dict meta: Metadata that are stored with the network
|
31 |
+
:param dict runtime: Runtime options that can be used as default for e.g. inference
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, features, attention, smoothing, dim_reduction, meta, runtime):
|
35 |
+
super().__init__()
|
36 |
+
|
37 |
+
self.features = features
|
38 |
+
self.attention = attention
|
39 |
+
self.smoothing = smoothing
|
40 |
+
self.dim_reduction = dim_reduction
|
41 |
+
|
42 |
+
self.meta = meta
|
43 |
+
self.runtime = runtime
|
44 |
+
|
45 |
+
|
46 |
+
def copy_excluding_dim_reduction(self):
|
47 |
+
"""Return a copy of this network without the dim_reduction layer"""
|
48 |
+
meta = {**self.meta, "outputdim": self.meta['backbone_dim']}
|
49 |
+
return self.__class__(self.features, self.attention, self.smoothing, None, meta, self.runtime)
|
50 |
+
|
51 |
+
def copy_with_runtime(self, runtime):
|
52 |
+
"""Return a copy of this network with a different runtime dict"""
|
53 |
+
return self.__class__(self.features, self.attention, self.smoothing, self.dim_reduction, self.meta, runtime)
|
54 |
+
|
55 |
+
|
56 |
+
# Methods of nn.Module
|
57 |
+
|
58 |
+
@staticmethod
|
59 |
+
def _set_batchnorm_eval(mod):
|
60 |
+
if mod.__class__.__name__.find('BatchNorm') != -1:
|
61 |
+
# freeze running mean and std
|
62 |
+
mod.eval()
|
63 |
+
|
64 |
+
def train(self, mode=True):
|
65 |
+
res = super().train(mode)
|
66 |
+
if mode:
|
67 |
+
self.apply(HOWNet._set_batchnorm_eval)
|
68 |
+
return res
|
69 |
+
|
70 |
+
def parameter_groups(self, optimizer_opts):
|
71 |
+
"""Return torch parameter groups"""
|
72 |
+
layers = [self.features, self.attention, self.smoothing]
|
73 |
+
parameters = [{'params': x.parameters()} for x in layers if x is not None]
|
74 |
+
if self.dim_reduction:
|
75 |
+
# Do not update dimensionality reduction layer
|
76 |
+
parameters.append({'params': self.dim_reduction.parameters(), 'lr': 0.0})
|
77 |
+
return parameters
|
78 |
+
|
79 |
+
|
80 |
+
# Forward
|
81 |
+
|
82 |
+
def features_attentions(self, x, *, scales):
|
83 |
+
"""Return a tuple (features, attentions) where each is a list containing requested scales"""
|
84 |
+
feats = []
|
85 |
+
masks = []
|
86 |
+
for s in scales:
|
87 |
+
xs = nn.functional.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
|
88 |
+
o = self.features(xs)
|
89 |
+
m = self.attention(o)
|
90 |
+
if self.smoothing:
|
91 |
+
o = self.smoothing(o)
|
92 |
+
if self.dim_reduction:
|
93 |
+
o = self.dim_reduction(o)
|
94 |
+
feats.append(o)
|
95 |
+
masks.append(m)
|
96 |
+
|
97 |
+
# Normalize max weight to 1
|
98 |
+
mx = max(x.max() for x in masks)
|
99 |
+
masks = [x/mx for x in masks]
|
100 |
+
|
101 |
+
return feats, masks
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
return self.forward_global(x, scales=self.runtime['training_scales'])
|
105 |
+
|
106 |
+
def forward_global(self, x, *, scales):
|
107 |
+
"""Return global descriptor"""
|
108 |
+
feats, masks = self.features_attentions(x, scales=scales)
|
109 |
+
return HF.weighted_spoc(feats, masks)
|
110 |
+
|
111 |
+
def forward_local(self, x, *, features_num, scales):
|
112 |
+
"""Return local descriptors"""
|
113 |
+
feats, masks = self.features_attentions(x, scales=scales)
|
114 |
+
return HF.how_select_local(feats, masks, scales=scales, features_num=features_num)
|
115 |
+
|
116 |
+
|
117 |
+
# String conversion
|
118 |
+
|
119 |
+
def __repr__(self):
|
120 |
+
meta_str = "\n".join(" %s: %s" % x for x in self.meta.items())
|
121 |
+
return "%s(meta={\n%s\n})" % (self.__class__.__name__, meta_str)
|
122 |
+
|
123 |
+
def meta_repr(self):
|
124 |
+
"""Return meta representation"""
|
125 |
+
return str(self)
|
126 |
+
|
127 |
+
|
128 |
+
def init_network(architecture, pretrained, skip_layer, dim_reduction, smoothing, runtime):
|
129 |
+
"""Initialize HOW network
|
130 |
+
|
131 |
+
:param str architecture: Network backbone architecture (e.g. resnet18)
|
132 |
+
:param bool pretrained: Whether to start with a network pretrained on ImageNet
|
133 |
+
:param int skip_layer: How many layers of blocks should be skipped (from the end)
|
134 |
+
:param dict dim_reduction: Options for the dimensionality reduction layer
|
135 |
+
:param dict smoothing: Options for the smoothing layer
|
136 |
+
:param dict runtime: Runtime options to be stored in the network
|
137 |
+
:return HOWNet: Initialized network
|
138 |
+
"""
|
139 |
+
# Take convolutional layers as features, always ends with ReLU to make last activations non-negative
|
140 |
+
net_in = getattr(torchvision.models, architecture)(pretrained=pretrained)
|
141 |
+
if architecture.startswith('alexnet') or architecture.startswith('vgg'):
|
142 |
+
features = list(net_in.features.children())[:-1]
|
143 |
+
elif architecture.startswith('resnet'):
|
144 |
+
features = list(net_in.children())[:-2]
|
145 |
+
elif architecture.startswith('densenet'):
|
146 |
+
features = list(net_in.features.children()) + [nn.ReLU(inplace=True)]
|
147 |
+
elif architecture.startswith('squeezenet'):
|
148 |
+
features = list(net_in.features.children())
|
149 |
+
else:
|
150 |
+
raise ValueError('Unsupported or unknown architecture: {}!'.format(architecture))
|
151 |
+
|
152 |
+
if skip_layer > 0:
|
153 |
+
features = features[:-skip_layer]
|
154 |
+
backbone_dim = imageretrievalnet.OUTPUT_DIM[architecture] // (2 ** skip_layer)
|
155 |
+
|
156 |
+
att_layer = layers.attention.L2Attention()
|
157 |
+
smooth_layer = None
|
158 |
+
if smoothing:
|
159 |
+
smooth_layer = layers.pooling.SmoothingAvgPooling(**smoothing)
|
160 |
+
reduction_layer = None
|
161 |
+
if dim_reduction:
|
162 |
+
reduction_layer = layers.dim_reduction.ConvDimReduction(**dim_reduction, input_dim=backbone_dim)
|
163 |
+
|
164 |
+
meta = {
|
165 |
+
"architecture": architecture,
|
166 |
+
"backbone_dim": backbone_dim,
|
167 |
+
"outputdim": reduction_layer.out_channels if dim_reduction else backbone_dim,
|
168 |
+
"corercf_size": CORERCF_SIZE[architecture] // (2 ** skip_layer),
|
169 |
+
}
|
170 |
+
return HOWNet(nn.Sequential(*features), att_layer, smooth_layer, reduction_layer, meta, runtime)
|
171 |
+
|
172 |
+
|
173 |
+
def extract_vectors(net, dataset, device, *, scales):
|
174 |
+
"""Return global descriptors in torch.Tensor"""
|
175 |
+
net.eval()
|
176 |
+
loader = torch.utils.data.DataLoader(dataset, shuffle=False, pin_memory=True, num_workers=NUM_WORKERS)
|
177 |
+
|
178 |
+
with torch.no_grad():
|
179 |
+
vecs = torch.zeros(len(loader), net.meta['outputdim'])
|
180 |
+
for i, inp in io_helpers.progress(enumerate(loader), size=len(loader), print_freq=100):
|
181 |
+
vecs[i] = net.forward_global(inp.to(device), scales=scales).cpu().squeeze()
|
182 |
+
|
183 |
+
return vecs
|
184 |
+
|
185 |
+
|
186 |
+
def extract_vectors_local(net, dataset, device, *, features_num, scales):
|
187 |
+
"""Return tuple (local descriptors, image ids, strenghts, locations and scales) where locations
|
188 |
+
consists of (coor_x, coor_y, scale) and elements of each list correspond to each other"""
|
189 |
+
net.eval()
|
190 |
+
loader = torch.utils.data.DataLoader(dataset, shuffle=False, pin_memory=True, num_workers=NUM_WORKERS)
|
191 |
+
|
192 |
+
with torch.no_grad():
|
193 |
+
vecs, strengths, locs, scls, imids = [], [], [], [], []
|
194 |
+
for imid, inp in io_helpers.progress(enumerate(loader), size=len(loader), print_freq=100):
|
195 |
+
output = net.forward_local(inp.to(device), features_num=features_num, scales=scales)
|
196 |
+
|
197 |
+
vecs.append(output[0].cpu().numpy())
|
198 |
+
strengths.append(output[1].cpu().numpy())
|
199 |
+
locs.append(output[2].cpu().numpy())
|
200 |
+
scls.append(output[3].cpu().numpy())
|
201 |
+
imids.append(np.full((output[0].shape[0],), imid))
|
202 |
+
|
203 |
+
return np.vstack(vecs), np.hstack(imids), np.hstack(strengths), np.vstack(locs), np.hstack(scls)
|
204 |
+
|
205 |
+
|
206 |
+
|
207 |
+
def extract_vectors_all(net, dataset, device, *, features_num, scales):
|
208 |
+
"""Return tuple (local descriptors, image ids, strenghts, locations and scales) where locations
|
209 |
+
consists of (coor_x, coor_y, scale) and elements of each list correspond to each other"""
|
210 |
+
net.eval()
|
211 |
+
loader = torch.utils.data.DataLoader(dataset, shuffle=False, pin_memory=True, num_workers=NUM_WORKERS)
|
212 |
+
|
213 |
+
with torch.no_grad():
|
214 |
+
feats, attns, strenghts = [], [], []
|
215 |
+
for imid, inp in io_helpers.progress(enumerate(loader), size=len(loader), print_freq=100):
|
216 |
+
output = net.get_superfeatures(inp.to(device), scales=scales)
|
217 |
+
feats.append(output[0])
|
218 |
+
attns.append(output[1])
|
219 |
+
strenghts.append(output[2])
|
220 |
+
|
221 |
+
return feats, attns, strenghts
|
how/stages/__init__.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Implementation of different network stages, such as training and evaluation
|
3 |
+
"""
|
4 |
+
|
5 |
+
from . import evaluate, train
|
how/stages/evaluate.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implements evaluation of trained models"""
|
2 |
+
|
3 |
+
import time
|
4 |
+
import warnings
|
5 |
+
from pathlib import Path
|
6 |
+
import pickle
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torchvision import transforms
|
10 |
+
from PIL import ImageFile
|
11 |
+
|
12 |
+
from cirtorch.datasets.genericdataset import ImagesFromList
|
13 |
+
|
14 |
+
from asmk import asmk_method, kernel as kern_pkg
|
15 |
+
from ..networks import how_net
|
16 |
+
from ..utils import score_helpers, data_helpers, logging
|
17 |
+
|
18 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
19 |
+
warnings.filterwarnings("ignore", r"^Possibly corrupt EXIF data", category=UserWarning)
|
20 |
+
|
21 |
+
|
22 |
+
def evaluate_demo(demo_eval, evaluation, globals):
|
23 |
+
"""Demo evaluating a trained network
|
24 |
+
|
25 |
+
:param dict demo_eval: Demo-related options
|
26 |
+
:param dict evaluation: Evaluation-related options
|
27 |
+
:param dict globals: Global options
|
28 |
+
"""
|
29 |
+
globals["device"] = torch.device("cpu")
|
30 |
+
if demo_eval['gpu_id'] is not None:
|
31 |
+
globals["device"] = torch.device(("cuda:%s" % demo_eval['gpu_id']))
|
32 |
+
|
33 |
+
# Handle net_path when directory
|
34 |
+
net_path = Path(demo_eval['exp_folder']) / demo_eval['net_path']
|
35 |
+
if net_path.is_dir() and (net_path / "epochs/model_best.pth").exists():
|
36 |
+
net_path = net_path / "epochs/model_best.pth"
|
37 |
+
|
38 |
+
# Load net
|
39 |
+
state = _convert_checkpoint(torch.load(net_path, map_location='cpu'))
|
40 |
+
net = how_net.init_network(**state['net_params']).to(globals['device'])
|
41 |
+
net.load_state_dict(state['state_dict'])
|
42 |
+
globals["transform"] = transforms.Compose([transforms.ToTensor(), \
|
43 |
+
transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
|
44 |
+
|
45 |
+
# Eval
|
46 |
+
if evaluation['global_descriptor']['datasets']:
|
47 |
+
eval_global(net, evaluation['inference'], globals, **evaluation['global_descriptor'])
|
48 |
+
|
49 |
+
if evaluation['multistep']:
|
50 |
+
eval_asmk_multistep(net, evaluation['inference'], evaluation['multistep'], globals, **evaluation['local_descriptor'])
|
51 |
+
elif evaluation['local_descriptor']['datasets']:
|
52 |
+
eval_asmk(net, evaluation['inference'], globals, **evaluation['local_descriptor'])
|
53 |
+
|
54 |
+
|
55 |
+
def eval_global(net, inference, globals, *, datasets):
|
56 |
+
"""Evaluate global descriptors"""
|
57 |
+
net.eval()
|
58 |
+
time0 = time.time()
|
59 |
+
logger = globals["logger"]
|
60 |
+
logger.info("Starting global evaluation")
|
61 |
+
|
62 |
+
results = {}
|
63 |
+
for dataset in datasets:
|
64 |
+
images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root=globals['root_path'])
|
65 |
+
logger.info(f"Evaluating {dataset}")
|
66 |
+
|
67 |
+
with logging.LoggingStopwatch("extracting database images", logger.info, logger.debug):
|
68 |
+
dset = ImagesFromList(root='', images=images, imsize=inference['image_size'], bbxs=None,
|
69 |
+
transform=globals['transform'])
|
70 |
+
vecs = how_net.extract_vectors(net, dset, globals["device"], scales=inference['scales'])
|
71 |
+
with logging.LoggingStopwatch("extracting query images", logger.info, logger.debug):
|
72 |
+
qdset = ImagesFromList(root='', images=qimages, imsize=inference['image_size'], bbxs=bbxs,
|
73 |
+
transform=globals['transform'])
|
74 |
+
qvecs = how_net.extract_vectors(net, qdset, globals["device"], scales=inference['scales'])
|
75 |
+
|
76 |
+
vecs, qvecs = vecs.numpy(), qvecs.numpy()
|
77 |
+
ranks = np.argsort(-np.dot(vecs, qvecs.T), axis=0)
|
78 |
+
results[dataset] = score_helpers.compute_map_and_log(dataset, ranks, gnd, logger=logger)
|
79 |
+
|
80 |
+
logger.info(f"Finished global evaluation in {int(time.time()-time0) // 60} min")
|
81 |
+
return results
|
82 |
+
|
83 |
+
|
84 |
+
def eval_asmk(net, inference, globals, *, datasets, codebook_training, asmk):
|
85 |
+
"""Evaluate local descriptors with ASMK"""
|
86 |
+
net.eval()
|
87 |
+
time0 = time.time()
|
88 |
+
logger = globals["logger"]
|
89 |
+
logger.info("Starting asmk evaluation")
|
90 |
+
|
91 |
+
asmk = asmk_method.ASMKMethod.initialize_untrained(asmk)
|
92 |
+
asmk = asmk_train_codebook(net, inference, globals, logger, codebook_training=codebook_training,
|
93 |
+
asmk=asmk, cache_path=None)
|
94 |
+
|
95 |
+
results = {}
|
96 |
+
for dataset in datasets:
|
97 |
+
dataset_name = dataset if isinstance(dataset, str) else dataset['name']
|
98 |
+
images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root=globals['root_path'])
|
99 |
+
logger.info(f"Evaluating '{dataset_name}'")
|
100 |
+
|
101 |
+
asmk_dataset = asmk_index_database(net, inference, globals, logger, asmk=asmk, images=images)
|
102 |
+
asmk_query_ivf(net, inference, globals, logger, dataset=dataset, asmk_dataset=asmk_dataset,
|
103 |
+
qimages=qimages, bbxs=bbxs, gnd=gnd, results=results,
|
104 |
+
cache_path=globals["exp_path"] / "query_results.pkl")
|
105 |
+
|
106 |
+
logger.info(f"Finished asmk evaluation in {int(time.time()-time0) // 60} min")
|
107 |
+
return results
|
108 |
+
|
109 |
+
|
110 |
+
def eval_asmk_multistep(net, inference, multistep, globals, *, datasets, codebook_training, asmk):
|
111 |
+
"""Evaluate local descriptors with ASMK"""
|
112 |
+
valid_steps = ["train_codebook", "aggregate_database", "build_ivf", "query_ivf", "aggregate_build_query"]
|
113 |
+
assert multistep['step'] in valid_steps, multistep['step']
|
114 |
+
|
115 |
+
net.eval()
|
116 |
+
time0 = time.time()
|
117 |
+
logger = globals["logger"]
|
118 |
+
(globals["exp_path"] / "eval").mkdir(exist_ok=True)
|
119 |
+
logger.info(f"Starting asmk evaluation step '{multistep['step']}'")
|
120 |
+
|
121 |
+
# Handle partitioning
|
122 |
+
partition = {"suffix": "", "norm_start": 0, "norm_end": 1}
|
123 |
+
if multistep.get("partition"):
|
124 |
+
total, index = multistep['partition']
|
125 |
+
partition = {"suffix": f":{total}_{str(index).zfill(len(str(total-1)))}",
|
126 |
+
"norm_start": index / total,
|
127 |
+
"norm_end": (index+1) / total}
|
128 |
+
if multistep['step'] == "aggregate_database" or multistep['step'] == "query_ivf":
|
129 |
+
logger.info(f"Processing partition '{total}_{index}'")
|
130 |
+
|
131 |
+
# Handle distractors
|
132 |
+
distractors_path = None
|
133 |
+
distractors = multistep.get("distractors")
|
134 |
+
if distractors:
|
135 |
+
distractors_path = globals["exp_path"] / f"eval/{distractors}.ivf.pkl"
|
136 |
+
|
137 |
+
# Train codebook
|
138 |
+
asmk = asmk_method.ASMKMethod.initialize_untrained(asmk)
|
139 |
+
cdb_path = globals["exp_path"] / "eval/codebook.pkl"
|
140 |
+
if multistep['step'] == "train_codebook":
|
141 |
+
asmk_train_codebook(net, inference, globals, logger, codebook_training=codebook_training,
|
142 |
+
asmk=asmk, cache_path=cdb_path)
|
143 |
+
return None
|
144 |
+
|
145 |
+
asmk = asmk.train_codebook(None, cache_path=cdb_path)
|
146 |
+
|
147 |
+
results = {}
|
148 |
+
for dataset in datasets:
|
149 |
+
dataset_name = database_name = dataset if isinstance(dataset, str) else dataset['name']
|
150 |
+
if distractors and multistep['step'] != "aggregate_database":
|
151 |
+
dataset_name = f"{distractors}_{database_name}"
|
152 |
+
images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root=globals['root_path'])
|
153 |
+
logger.info(f"Processing dataset '{dataset_name}'")
|
154 |
+
|
155 |
+
# Infer database
|
156 |
+
if multistep['step'] == "aggregate_database":
|
157 |
+
agg_path = globals["exp_path"] / f"eval/{database_name}.agg{partition['suffix']}.pkl"
|
158 |
+
asmk_aggregate_database(net, inference, globals, logger, asmk=asmk, images=images,
|
159 |
+
partition=partition, cache_path=agg_path)
|
160 |
+
|
161 |
+
# Build ivf
|
162 |
+
elif multistep['step'] == "build_ivf":
|
163 |
+
ivf_path = globals["exp_path"] / f"eval/{dataset_name}.ivf.pkl"
|
164 |
+
asmk_build_ivf(globals, logger, asmk=asmk, cache_path=ivf_path, database_name=database_name,
|
165 |
+
distractors=distractors, distractors_path=distractors_path)
|
166 |
+
|
167 |
+
# Query ivf
|
168 |
+
elif multistep['step'] == "query_ivf":
|
169 |
+
asmk_dataset = asmk.build_ivf(None, None, cache_path=globals["exp_path"] / f"eval/{dataset_name}.ivf.pkl")
|
170 |
+
start, end = int(len(qimages)*partition['norm_start']), int(len(qimages)*partition['norm_end'])
|
171 |
+
bbxs = bbxs[start:end] if bbxs is not None else None
|
172 |
+
results_path = globals["exp_path"] / f"eval/{dataset_name}.results{partition['suffix']}.pkl"
|
173 |
+
asmk_query_ivf(net, inference, globals, logger, dataset=dataset, asmk_dataset=asmk_dataset,
|
174 |
+
qimages=qimages[start:end], bbxs=bbxs, gnd=gnd, results=results,
|
175 |
+
cache_path=results_path, imid_offset=start)
|
176 |
+
|
177 |
+
# All 3 dataset steps
|
178 |
+
elif multistep['step'] == "aggregate_build_query":
|
179 |
+
if multistep.get("partition"):
|
180 |
+
raise NotImplementedError("Partitions within step 'aggregate_build_query' are not" \
|
181 |
+
" supported, use separate steps")
|
182 |
+
results_path = globals["exp_path"] / "query_results.pkl"
|
183 |
+
if gnd is None and results_path.exists():
|
184 |
+
logger.debug("Step results already exist")
|
185 |
+
continue
|
186 |
+
asmk_dataset = asmk_index_database(net, inference, globals, logger, asmk=asmk, images=images,
|
187 |
+
distractors_path=distractors_path)
|
188 |
+
asmk_query_ivf(net, inference, globals, logger, dataset=dataset, asmk_dataset=asmk_dataset,
|
189 |
+
qimages=qimages, bbxs=bbxs, gnd=gnd, results=results, cache_path=results_path)
|
190 |
+
|
191 |
+
logger.info(f"Finished asmk evaluation step '{multistep['step']}' in {int(time.time()-time0) // 60} min")
|
192 |
+
return results
|
193 |
+
|
194 |
+
#
|
195 |
+
# Separate steps
|
196 |
+
#
|
197 |
+
|
198 |
+
def asmk_train_codebook(net, inference, globals, logger, *, codebook_training, asmk, cache_path):
|
199 |
+
"""Asmk evaluation step 'train_codebook'"""
|
200 |
+
if cache_path and cache_path.exists():
|
201 |
+
return asmk.train_codebook(None, cache_path=cache_path)
|
202 |
+
|
203 |
+
images = data_helpers.load_dataset('train', data_root=globals['root_path'])[0]
|
204 |
+
images = images[:codebook_training['images']]
|
205 |
+
dset = ImagesFromList(root='', images=images, imsize=inference['image_size'], bbxs=None,
|
206 |
+
transform=globals['transform'])
|
207 |
+
infer_opts = {"scales": codebook_training['scales'], "features_num": inference['features_num']}
|
208 |
+
des_train = how_net.extract_vectors_local(net, dset, globals["device"], **infer_opts)[0]
|
209 |
+
asmk = asmk.train_codebook(des_train, cache_path=cache_path)
|
210 |
+
logger.info(f"Codebook trained in {asmk.metadata['train_codebook']['train_time']:.1f}s")
|
211 |
+
return asmk
|
212 |
+
|
213 |
+
def asmk_aggregate_database(net, inference, globals, logger, *, asmk, images, partition, cache_path):
|
214 |
+
"""Asmk evaluation step 'aggregate_database'"""
|
215 |
+
if cache_path.exists():
|
216 |
+
logger.debug("Step results already exist")
|
217 |
+
return
|
218 |
+
codebook = asmk.codebook
|
219 |
+
kernel = kern_pkg.ASMKKernel(codebook, **asmk.params['build_ivf']['kernel'])
|
220 |
+
start, end = int(len(images)*partition['norm_start']), int(len(images)*partition['norm_end'])
|
221 |
+
data_opts = {"imsize": inference['image_size'], "transform": globals['transform']}
|
222 |
+
infer_opts = {"scales": inference['scales'], "features_num": inference['features_num']}
|
223 |
+
# Aggregate database
|
224 |
+
dset = ImagesFromList(root='', images=images[start:end], bbxs=None, **data_opts)
|
225 |
+
vecs, imids, *_ = how_net.extract_vectors_local(net, dset, globals["device"], **infer_opts)
|
226 |
+
imids += start
|
227 |
+
quantized = codebook.quantize(vecs, imids, **asmk.params["build_ivf"]["quantize"])
|
228 |
+
aggregated = kernel.aggregate(*quantized, **asmk.params["build_ivf"]["aggregate"])
|
229 |
+
with cache_path.open("wb") as handle:
|
230 |
+
pickle.dump(dict(zip(["des", "word_ids", "image_ids"], aggregated)), handle)
|
231 |
+
|
232 |
+
def asmk_build_ivf(globals, logger, *, asmk, cache_path, database_name, distractors, distractors_path):
|
233 |
+
"""Asmk evaluation step 'build_ivf'"""
|
234 |
+
if cache_path.exists():
|
235 |
+
logger.debug("Step results already exist")
|
236 |
+
return asmk.build_ivf(None, None, cache_path=cache_path)
|
237 |
+
builder = asmk.create_ivf_builder(cache_path=cache_path)
|
238 |
+
# Build ivf
|
239 |
+
if not builder.loaded_from_cache:
|
240 |
+
if distractors:
|
241 |
+
builder.initialize_with_distractors(distractors_path)
|
242 |
+
logger.debug(f"Loaded ivf with distractors '{distractors}'")
|
243 |
+
for path in sorted(globals["exp_path"].glob(f"eval/{database_name}.agg*.pkl")):
|
244 |
+
with path.open("rb") as handle:
|
245 |
+
des = pickle.load(handle)
|
246 |
+
builder.ivf.add(des['des'], des['word_ids'], des['image_ids'])
|
247 |
+
logger.info(f"Indexed '{path.name}'")
|
248 |
+
asmk_dataset = asmk.add_ivf_builder(builder)
|
249 |
+
logger.debug(f"IVF stats: {asmk_dataset.metadata['build_ivf']['ivf_stats']}")
|
250 |
+
return asmk_dataset
|
251 |
+
|
252 |
+
def asmk_index_database(net, inference, globals, logger, *, asmk, images, distractors_path=None):
|
253 |
+
"""Asmk evaluation step 'aggregate_database' and 'build_ivf'"""
|
254 |
+
data_opts = {"imsize": inference['image_size'], "transform": globals['transform']}
|
255 |
+
infer_opts = {"scales": inference['scales'], "features_num": inference['features_num']}
|
256 |
+
# Index database vectors
|
257 |
+
dset = ImagesFromList(root='', images=images, bbxs=None, **data_opts)
|
258 |
+
vecs, imids, *_ = how_net.extract_vectors_local(net, dset, globals["device"], **infer_opts)
|
259 |
+
asmk_dataset = asmk.build_ivf(vecs, imids, distractors_path=distractors_path)
|
260 |
+
logger.info(f"Indexed images in {asmk_dataset.metadata['build_ivf']['index_time']:.2f}s")
|
261 |
+
logger.debug(f"IVF stats: {asmk_dataset.metadata['build_ivf']['ivf_stats']}")
|
262 |
+
return asmk_dataset
|
263 |
+
|
264 |
+
def asmk_query_ivf(net, inference, globals, logger, *, dataset, asmk_dataset, qimages, bbxs, gnd,
|
265 |
+
results, cache_path, imid_offset=0):
|
266 |
+
"""Asmk evaluation step 'query_ivf'"""
|
267 |
+
if gnd is None and cache_path and cache_path.exists():
|
268 |
+
logger.debug("Step results already exist")
|
269 |
+
return
|
270 |
+
data_opts = {"imsize": inference['image_size'], "transform": globals['transform']}
|
271 |
+
infer_opts = {"scales": inference['scales'], "features_num": inference['features_num']}
|
272 |
+
# Query vectors
|
273 |
+
qdset = ImagesFromList(root='', images=qimages, bbxs=bbxs, **data_opts)
|
274 |
+
qvecs, qimids, *_ = how_net.extract_vectors_local(net, qdset, globals["device"], **infer_opts)
|
275 |
+
qimids += imid_offset
|
276 |
+
metadata, query_ids, ranks, scores = asmk_dataset.query_ivf(qvecs, qimids)
|
277 |
+
logger.debug(f"Average query time (quant+aggr+search) is {metadata['query_avg_time']:.3f}s")
|
278 |
+
# Evaluate
|
279 |
+
if gnd is not None:
|
280 |
+
results[dataset] = score_helpers.compute_map_and_log(dataset, ranks.T, gnd, logger=logger)
|
281 |
+
with cache_path.open("wb") as handle:
|
282 |
+
pickle.dump({"metadata": metadata, "query_ids": query_ids, "ranks": ranks, "scores": scores}, handle)
|
283 |
+
|
284 |
+
#
|
285 |
+
# Helpers
|
286 |
+
#
|
287 |
+
|
288 |
+
def _convert_checkpoint(state):
|
289 |
+
"""Enable loading checkpoints in the old format"""
|
290 |
+
if "_version" not in state:
|
291 |
+
# Old checkpoint format
|
292 |
+
meta = state['meta']
|
293 |
+
state['net_params'] = {
|
294 |
+
"architecture": meta['architecture'],
|
295 |
+
"pretrained": True,
|
296 |
+
"skip_layer": meta['skip_layer'],
|
297 |
+
"dim_reduction": {"dim": meta["dim"]},
|
298 |
+
"smoothing": {"kernel_size": meta["feat_pool_k"]},
|
299 |
+
"runtime": {
|
300 |
+
"mean_std": [meta['mean'], meta['std']],
|
301 |
+
"image_size": 1024,
|
302 |
+
"features_num": 1000,
|
303 |
+
"scales": [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25],
|
304 |
+
"training_scales": [1],
|
305 |
+
},
|
306 |
+
}
|
307 |
+
|
308 |
+
state_dict = state['state_dict']
|
309 |
+
state_dict['dim_reduction.weight'] = state_dict.pop("whiten.weight")
|
310 |
+
state_dict['dim_reduction.bias'] = state_dict.pop("whiten.bias")
|
311 |
+
|
312 |
+
state['_version'] = "how/2020"
|
313 |
+
|
314 |
+
return state
|
how/stages/train.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Implements training new models"""
|
2 |
+
|
3 |
+
import time
|
4 |
+
import copy
|
5 |
+
from collections import defaultdict
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as transforms
|
9 |
+
|
10 |
+
from cirtorch.layers.loss import ContrastiveLoss
|
11 |
+
from cirtorch.datasets.datahelpers import collate_tuples
|
12 |
+
from cirtorch.datasets.traindataset import TuplesDataset
|
13 |
+
from cirtorch.datasets.genericdataset import ImagesFromList
|
14 |
+
|
15 |
+
from ..networks import how_net
|
16 |
+
from ..utils import data_helpers, io_helpers, logging, plots
|
17 |
+
from . import evaluate
|
18 |
+
|
19 |
+
|
20 |
+
def train(demo_train, training, validation, model, globals):
|
21 |
+
"""Demo training a network
|
22 |
+
|
23 |
+
:param dict demo_train: Demo-related options
|
24 |
+
:param dict training: Training options
|
25 |
+
:param dict validation: Validation options
|
26 |
+
:param dict model: Model options
|
27 |
+
:param dict globals: Global options
|
28 |
+
"""
|
29 |
+
logger = globals["logger"]
|
30 |
+
(globals["exp_path"] / "epochs").mkdir(exist_ok=True)
|
31 |
+
if (globals["exp_path"] / f"epochs/model_epoch{training['epochs']}.pth").exists():
|
32 |
+
logger.info("Skipping network training, already trained")
|
33 |
+
return
|
34 |
+
|
35 |
+
# Global setup
|
36 |
+
set_seed(0)
|
37 |
+
globals["device"] = torch.device("cpu")
|
38 |
+
if demo_train['gpu_id'] is not None:
|
39 |
+
globals["device"] = torch.device(("cuda:%s" % demo_train['gpu_id']))
|
40 |
+
|
41 |
+
# Initialize network
|
42 |
+
net = how_net.init_network(**model).to(globals["device"])
|
43 |
+
globals["transform"] = transforms.Compose([transforms.ToTensor(), \
|
44 |
+
transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
|
45 |
+
with logging.LoggingStopwatch("initializing network whitening", logger.info, logger.debug):
|
46 |
+
initialize_dim_reduction(net, globals, **training['initialize_dim_reduction'])
|
47 |
+
|
48 |
+
# Initialize training
|
49 |
+
optimizer, scheduler, criterion, train_loader = \
|
50 |
+
initialize_training(net.parameter_groups(training["optimizer"]), training, globals)
|
51 |
+
validation = Validation(validation, globals)
|
52 |
+
|
53 |
+
for epoch in range(training['epochs']):
|
54 |
+
epoch1 = epoch + 1
|
55 |
+
set_seed(epoch1)
|
56 |
+
|
57 |
+
time0 = time.time()
|
58 |
+
train_loss = train_epoch(train_loader, net, globals, criterion, optimizer, epoch1)
|
59 |
+
|
60 |
+
validation.add_train_loss(train_loss, epoch1)
|
61 |
+
validation.validate(net, epoch1)
|
62 |
+
|
63 |
+
scheduler.step()
|
64 |
+
|
65 |
+
io_helpers.save_checkpoint({
|
66 |
+
'epoch': epoch1, 'meta': net.meta, 'state_dict': net.state_dict(),
|
67 |
+
'optimizer' : optimizer.state_dict(), 'best_score': validation.best_score[1],
|
68 |
+
'scores': validation.scores, 'net_params': model, '_version': 'how/2020',
|
69 |
+
}, validation.best_score[0] == epoch1, epoch1 == training['epochs'], globals["exp_path"] / "epochs")
|
70 |
+
|
71 |
+
logger.info(f"Epoch {epoch1} finished in {time.time() - time0:.1f}s")
|
72 |
+
|
73 |
+
|
74 |
+
def train_epoch(train_loader, net, globals, criterion, optimizer, epoch1):
|
75 |
+
"""Train for one epoch"""
|
76 |
+
logger = globals['logger']
|
77 |
+
batch_time = data_helpers.AverageMeter()
|
78 |
+
data_time = data_helpers.AverageMeter()
|
79 |
+
losses = data_helpers.AverageMeter()
|
80 |
+
|
81 |
+
# Prepare epoch
|
82 |
+
train_loader.dataset.create_epoch_tuples(net)
|
83 |
+
net.train()
|
84 |
+
|
85 |
+
end = time.time()
|
86 |
+
for i, (input, target) in enumerate(train_loader):
|
87 |
+
data_time.update(time.time() - end)
|
88 |
+
optimizer.zero_grad()
|
89 |
+
|
90 |
+
num_images = len(input[0]) # number of images per tuple
|
91 |
+
for inp, trg in zip(input, target):
|
92 |
+
output = torch.zeros(net.meta['outputdim'], num_images).to(globals["device"])
|
93 |
+
for imi in range(num_images):
|
94 |
+
output[:, imi] = net(inp[imi].to(globals["device"])).squeeze()
|
95 |
+
loss = criterion(output, trg.to(globals["device"]))
|
96 |
+
loss.backward()
|
97 |
+
losses.update(loss.item())
|
98 |
+
|
99 |
+
optimizer.step()
|
100 |
+
batch_time.update(time.time() - end)
|
101 |
+
end = time.time()
|
102 |
+
|
103 |
+
if (i+1) % 20 == 0 or i == 0 or (i+1) == len(train_loader):
|
104 |
+
logger.info(f'>> Train: [{epoch1}][{i+1}/{len(train_loader)}]\t' \
|
105 |
+
f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
|
106 |
+
f'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' \
|
107 |
+
f'Loss {losses.val:.4f} ({losses.avg:.4f})')
|
108 |
+
|
109 |
+
return losses.avg
|
110 |
+
|
111 |
+
|
112 |
+
def set_seed(seed):
|
113 |
+
"""Sets given seed globally in used libraries"""
|
114 |
+
torch.manual_seed(seed)
|
115 |
+
if torch.cuda.is_available():
|
116 |
+
torch.cuda.manual_seed_all(seed)
|
117 |
+
np.random.seed(seed)
|
118 |
+
|
119 |
+
|
120 |
+
def initialize_training(net_parameters, training, globals):
|
121 |
+
"""Initialize classes necessary for training"""
|
122 |
+
# Need to check for keys because of defaults
|
123 |
+
assert training['optimizer'].keys() == {"lr", "weight_decay"}
|
124 |
+
assert training['lr_scheduler'].keys() == {"gamma"}
|
125 |
+
assert training['loss'].keys() == {"margin"}
|
126 |
+
assert training['dataset'].keys() == {"name", "mode", "imsize", "nnum", "qsize", "poolsize"}
|
127 |
+
assert training['loader'].keys() == {"batch_size"}
|
128 |
+
|
129 |
+
optimizer = torch.optim.Adam(net_parameters, **training["optimizer"])
|
130 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, **training["lr_scheduler"])
|
131 |
+
criterion = ContrastiveLoss(**training["loss"]).to(globals["device"])
|
132 |
+
train_dataset = TuplesDataset(**training['dataset'], transform=globals["transform"])
|
133 |
+
train_loader = torch.utils.data.DataLoader(train_dataset, **training['loader'], \
|
134 |
+
pin_memory=True, drop_last=True, shuffle=True, collate_fn=collate_tuples, \
|
135 |
+
num_workers=how_net.NUM_WORKERS)
|
136 |
+
return optimizer, scheduler, criterion, train_loader
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
def extract_train_descriptors(net, globals, *, images, features_num):
|
141 |
+
"""Extract descriptors for a given number of images from the train set"""
|
142 |
+
if features_num is None:
|
143 |
+
features_num = net.runtime['features_num']
|
144 |
+
|
145 |
+
images = data_helpers.load_dataset('train', data_root=globals['root_path'])[0][:images]
|
146 |
+
dataset = ImagesFromList(root='', images=images, imsize=net.runtime['image_size'], bbxs=None,
|
147 |
+
transform=globals["transform"])
|
148 |
+
des_train = how_net.extract_vectors_local(net, dataset, globals["device"],
|
149 |
+
scales=net.runtime['training_scales'],
|
150 |
+
features_num=features_num)[0]
|
151 |
+
return des_train
|
152 |
+
|
153 |
+
|
154 |
+
def initialize_dim_reduction(net, globals, **kwargs):
|
155 |
+
"""Initialize dimensionality reduction by PCA whitening from 'images' number of descriptors"""
|
156 |
+
if not net.dim_reduction:
|
157 |
+
return
|
158 |
+
|
159 |
+
print(">> Initializing dim reduction")
|
160 |
+
des_train = extract_train_descriptors(net.copy_excluding_dim_reduction(), globals, **kwargs)
|
161 |
+
net.dim_reduction.initialize_pca_whitening(des_train)
|
162 |
+
|
163 |
+
|
164 |
+
class Validation:
|
165 |
+
"""A convenient interface to validation, keeping historical values and plotting continuously
|
166 |
+
|
167 |
+
:param dict validations: Options for each validation type (e.g. local_descriptor)
|
168 |
+
:param dict globals: Global options
|
169 |
+
"""
|
170 |
+
|
171 |
+
methods = {
|
172 |
+
"global_descriptor": evaluate.eval_global,
|
173 |
+
"local_descriptor": evaluate.eval_asmk,
|
174 |
+
}
|
175 |
+
|
176 |
+
def __init__(self, validations, globals):
|
177 |
+
validations = copy.deepcopy(validations)
|
178 |
+
self.frequencies = {x: y.pop("frequency") for x, y in validations.items()}
|
179 |
+
self.validations = validations
|
180 |
+
self.globals = globals
|
181 |
+
self.scores = {x: defaultdict(list) for x in validations}
|
182 |
+
self.scores["train_loss"] = []
|
183 |
+
|
184 |
+
def add_train_loss(self, loss, epoch):
|
185 |
+
"""Store training loss for given epoch"""
|
186 |
+
self.scores['train_loss'].append((epoch, loss))
|
187 |
+
|
188 |
+
fig = plots.EpochFigure("train set", ylabel="loss")
|
189 |
+
fig.plot(*list(zip(*self.scores["train_loss"])), 'o-', label='train')
|
190 |
+
fig.save(self.globals['exp_path'] / "fig_train.jpg")
|
191 |
+
|
192 |
+
def validate(self, net, epoch):
|
193 |
+
"""Perform validation of the network and store the resulting score for given epoch"""
|
194 |
+
for name, frequency in self.frequencies.items():
|
195 |
+
if frequency and epoch % frequency == 0:
|
196 |
+
scores = self.methods[name](net, net.runtime, self.globals, **self.validations[name])
|
197 |
+
for dataset, values in scores.items():
|
198 |
+
value = values['map_medium'] if "map_medium" in values else values['map']
|
199 |
+
self.scores[name][dataset].append((epoch, value))
|
200 |
+
|
201 |
+
if "val_eccv20" in scores:
|
202 |
+
fig = plots.EpochFigure(f"val set - {name}", ylabel="mAP")
|
203 |
+
fig.plot(*list(zip(*self.scores[name]['val_eccv20'])), 'o-', label='val')
|
204 |
+
fig.save(self.globals['exp_path'] / f"fig_val_{name}.jpg")
|
205 |
+
|
206 |
+
if scores.keys() - {"val_eccv20"}:
|
207 |
+
fig = plots.EpochFigure(f"test set - {name}", ylabel="mAP")
|
208 |
+
for dataset, value in self.scores[name].items():
|
209 |
+
if dataset != "val_eccv20":
|
210 |
+
fig.plot(*list(zip(*value)), 'o-', label=dataset)
|
211 |
+
fig.save(self.globals['exp_path'] / f"fig_test_{name}.jpg")
|
212 |
+
|
213 |
+
@property
|
214 |
+
def decisive_scores(self):
|
215 |
+
"""List of pairs (epoch, score) where score is decisive for comparing epochs"""
|
216 |
+
for name in ["local_descriptor", "global_descriptor"]:
|
217 |
+
if self.frequencies[name] and "val_eccv20" in self.scores[name]:
|
218 |
+
return self.scores[name]['val_eccv20']
|
219 |
+
return self.scores["train_loss"]
|
220 |
+
|
221 |
+
@property
|
222 |
+
def last_epoch(self):
|
223 |
+
"""Tuple (last epoch, last score) or (None, None) before decisive score is computed"""
|
224 |
+
decisive_scores = self.decisive_scores
|
225 |
+
if not decisive_scores:
|
226 |
+
return None, None
|
227 |
+
|
228 |
+
return decisive_scores[-1]
|
229 |
+
|
230 |
+
@property
|
231 |
+
def best_score(self):
|
232 |
+
"""Tuple (best epoch, best score) or (None, None) before decisive score is computed"""
|
233 |
+
decisive_scores = self.decisive_scores
|
234 |
+
if not decisive_scores:
|
235 |
+
return None, None
|
236 |
+
|
237 |
+
aggr = min
|
238 |
+
for name in ["local_descriptor", "global_descriptor"]:
|
239 |
+
if self.frequencies[name] and "val_eccv20" in self.scores[name]:
|
240 |
+
aggr = max
|
241 |
+
return aggr(decisive_scores, key=lambda x: x[1])
|
how/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Standalone utilities, mainly helper functions
|
3 |
+
"""
|
how/utils/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (171 Bytes). View file
|
|
how/utils/__pycache__/data_helpers.cpython-37.pyc
ADDED
Binary file (3.63 kB). View file
|
|
how/utils/__pycache__/download.cpython-37.pyc
ADDED
Binary file (1.6 kB). View file
|
|
how/utils/__pycache__/html.cpython-37.pyc
ADDED
Binary file (11.1 kB). View file
|
|
how/utils/__pycache__/io_helpers.cpython-37.pyc
ADDED
Binary file (3.39 kB). View file
|
|
how/utils/__pycache__/score_helpers.cpython-37.pyc
ADDED
Binary file (2.27 kB). View file
|
|
how/utils/__pycache__/visualize.cpython-37.pyc
ADDED
Binary file (4.33 kB). View file
|
|
how/utils/__pycache__/whitening.cpython-37.pyc
ADDED
Binary file (1.24 kB). View file
|
|
how/utils/data_helpers.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Data manipulation helpers"""
|
2 |
+
|
3 |
+
import os.path
|
4 |
+
import pickle
|
5 |
+
|
6 |
+
from cirtorch.datasets.datahelpers import cid2filename
|
7 |
+
from cirtorch.datasets.testdataset import configdataset
|
8 |
+
|
9 |
+
|
10 |
+
def load_dataset(dataset, data_root=''):
|
11 |
+
"""Return tuple (image list, query list, bounding boxes, gnd dictionary)"""
|
12 |
+
|
13 |
+
if isinstance(dataset, dict):
|
14 |
+
root = os.path.join(data_root, dataset['image_root'])
|
15 |
+
images, qimages = None, None
|
16 |
+
if dataset['database_list'] is not None:
|
17 |
+
images = [path_join(root, x.strip("\n")) for x in open(dataset['database_list']).readlines()]
|
18 |
+
if dataset['query_list'] is not None:
|
19 |
+
qimages = [path_join(root, x.strip("\n")) for x in open(dataset['query_list']).readlines()]
|
20 |
+
bbxs = None
|
21 |
+
gnd = None
|
22 |
+
|
23 |
+
elif dataset == 'train':
|
24 |
+
training_set = 'retrieval-SfM-120k'
|
25 |
+
db_root = os.path.join(data_root, 'train', training_set)
|
26 |
+
ims_root = os.path.join(db_root, 'ims')
|
27 |
+
db_fn = os.path.join(db_root, '{}.pkl'.format(training_set))
|
28 |
+
with open(db_fn, 'rb') as f:
|
29 |
+
db = pickle.load(f)['train']
|
30 |
+
images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]
|
31 |
+
qimages = []
|
32 |
+
bbxs = None
|
33 |
+
gnd = None
|
34 |
+
|
35 |
+
elif dataset == 'val_eccv20':
|
36 |
+
db_root = os.path.join(data_root, 'train', 'retrieval-SfM-120k')
|
37 |
+
fn_val_proper = db_root+'/retrieval-SfM-120k-val-eccv2020.pkl' # pos are all with #inl >=3 & <= 10
|
38 |
+
with open(fn_val_proper, 'rb') as f:
|
39 |
+
db = pickle.load(f)
|
40 |
+
ims_root = os.path.join(db_root, 'ims')
|
41 |
+
images = [cid2filename(db['cids'][i], ims_root) for i in range(len(db['cids']))]
|
42 |
+
gnd = db['gnd']
|
43 |
+
qidx = db['qidx']
|
44 |
+
qimages = [images[x] for x in qidx]
|
45 |
+
bbxs = None
|
46 |
+
|
47 |
+
elif "/" in dataset:
|
48 |
+
with open(dataset, 'rb') as handle:
|
49 |
+
db = pickle.load(handle)
|
50 |
+
images, qimages, bbxs, gnd = db['imlist'], db['qimlist'], None, db['gnd']
|
51 |
+
|
52 |
+
else:
|
53 |
+
cfg = configdataset(dataset, os.path.join(data_root, 'test'))
|
54 |
+
images = [cfg['im_fname'](cfg, i) for i in range(cfg['n'])]
|
55 |
+
qimages = [cfg['qim_fname'](cfg, i) for i in range(cfg['nq'])]
|
56 |
+
if 'bbx' in cfg['gnd'][0].keys():
|
57 |
+
bbxs = [tuple(cfg['gnd'][i]['bbx']) for i in range(cfg['nq'])]
|
58 |
+
else:
|
59 |
+
bbxs = None
|
60 |
+
gnd = cfg['gnd']
|
61 |
+
|
62 |
+
return images, qimages, bbxs, gnd
|
63 |
+
|
64 |
+
|
65 |
+
def path_join(root, name):
|
66 |
+
"""Perform os.path.join by default; if asterisk is present in root, substitute with the name.
|
67 |
+
|
68 |
+
>>> path_join('/data/img_*.jpg', '001')
|
69 |
+
'/data/img_001.jpg'
|
70 |
+
"""
|
71 |
+
if "*" in root.rsplit("/", 1)[-1]:
|
72 |
+
return root.replace("*", name)
|
73 |
+
return os.path.join(root, name)
|
74 |
+
|
75 |
+
|
76 |
+
class AverageMeter:
|
77 |
+
"""Compute and store the average and last value"""
|
78 |
+
|
79 |
+
def __init__(self):
|
80 |
+
self.val = 0
|
81 |
+
self.avg = 0
|
82 |
+
self.sum = 0
|
83 |
+
self.count = 0
|
84 |
+
|
85 |
+
def update(self, val, n=1):
|
86 |
+
"""Update the counter by a new value"""
|
87 |
+
self.val = val
|
88 |
+
self.sum += val * n
|
89 |
+
self.count += n
|
90 |
+
self.avg = self.sum / self.count
|
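For orientation, a minimal usage sketch of the two generic helpers above (path_join and AverageMeter); the paths and numbers are made up:

from how.utils.data_helpers import path_join, AverageMeter

# path_join substitutes an asterisk in the last path component, otherwise joins normally
assert path_join('/data/img_*.jpg', '001') == '/data/img_001.jpg'
assert path_join('/data', 'img_001.jpg') == '/data/img_001.jpg'

# AverageMeter keeps the last value and a count-weighted running average, e.g. of a batch loss
meter = AverageMeter()
for loss, batch_size in [(0.9, 32), (0.7, 32), (0.5, 16)]:
    meter.update(loss, n=batch_size)
print(meter.val, meter.avg)  # last value 0.5 and the weighted mean, about 0.74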
how/utils/download.py
ADDED
@@ -0,0 +1,44 @@
1 |
+
"""Functions for downloading files necessary for training and evaluation"""
|
2 |
+
|
3 |
+
import os.path
|
4 |
+
from cirtorch.utils.download import download_train, download_test
|
5 |
+
from . import io_helpers
|
6 |
+
|
7 |
+
|
8 |
+
def download_for_eval(evaluation, demo_eval, dataset_url, globals):
|
9 |
+
"""Download datasets for evaluation and network if given by url"""
|
10 |
+
# Datasets
|
11 |
+
datasets = evaluation['global_descriptor']['datasets'] \
|
12 |
+
+ evaluation['local_descriptor']['datasets']
|
13 |
+
download_datasets(datasets, dataset_url, globals)
|
14 |
+
# Network
|
15 |
+
if demo_eval and (demo_eval['net_path'].startswith("http://") \
|
16 |
+
or demo_eval['net_path'].startswith("https://")):
|
17 |
+
net_name = os.path.basename(demo_eval['net_path'])
|
18 |
+
io_helpers.download_files([net_name], globals['root_path'] / "models",
|
19 |
+
os.path.dirname(demo_eval['net_path']) + "/",
|
20 |
+
logfunc=globals["logger"].info)
|
21 |
+
demo_eval['net_path'] = globals['root_path'] / "models" / net_name
|
22 |
+
|
23 |
+
|
24 |
+
def download_for_train(validation, dataset_url, globals):
|
25 |
+
"""Download datasets for training"""
|
26 |
+
|
27 |
+
datasets = ["train"] + validation['global_descriptor']['datasets'] \
|
28 |
+
+ validation['local_descriptor']['datasets']
|
29 |
+
download_datasets(datasets, dataset_url, globals)
|
30 |
+
|
31 |
+
|
32 |
+
def download_datasets(datasets, dataset_url, globals):
|
33 |
+
"""Download data associated with each required dataset"""
|
34 |
+
|
35 |
+
if "val_eccv20" in datasets:
|
36 |
+
download_train(globals['root_path'])
|
37 |
+
io_helpers.download_files(["retrieval-SfM-120k-val-eccv2020.pkl"],
|
38 |
+
globals['root_path'] / "train/retrieval-SfM-120k",
|
39 |
+
dataset_url, logfunc=globals["logger"].info)
|
40 |
+
elif "train" in datasets:
|
41 |
+
download_train(globals['root_path'])
|
42 |
+
|
43 |
+
if "roxford5k" in datasets or "rparis6k" in datasets:
|
44 |
+
download_test(globals['root_path'])
|
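A sketch of the configuration shapes these download helpers expect; the URL below is a placeholder, not a real endpoint:

from pathlib import Path
import logging

globals_ = {"root_path": Path("data"), "logger": logging.getLogger("HOW")}
evaluation = {"global_descriptor": {"datasets": ["roxford5k"]},
              "local_descriptor": {"datasets": []}}
demo_eval = {"net_path": "https://example.com/models/fire.pth"}  # hypothetical URL

# Would fetch roxford5k via cirtorch and place the network file under root_path/models:
# download_for_eval(evaluation, demo_eval, "https://example.com/datasets/", globals_)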
how/utils/html.py
ADDED
@@ -0,0 +1,252 @@
1 |
+
from __future__ import print_function
|
2 |
+
import os
|
3 |
+
from PIL import Image
|
4 |
+
|
5 |
+
# see help for common HTML tags at http://www.mountaindragon.com/html/text.htm
|
6 |
+
|
7 |
+
|
8 |
+
class Node:
|
9 |
+
def __init__(self,tag,text='',props=dict()):
|
10 |
+
self.children = []
|
11 |
+
self.tag=tag
|
12 |
+
self.text=text
|
13 |
+
self.props=props
|
14 |
+
def add(self, node):
|
15 |
+
self.children.append(node)
|
16 |
+
return node
|
17 |
+
def tostr(self):
|
18 |
+
s = ""
|
19 |
+
if not self.props:
|
20 |
+
s+= "<%s>%s"%(self.tag,self.text)
|
21 |
+
else:
|
22 |
+
s+= "<%s %s>%s"%(self.tag,' '.join(["%s='%s'"%(k,v) for k,v in self.props.items() if v!=None]),self.text)
|
23 |
+
for child in self.children:
|
24 |
+
s += child.tostr()
|
25 |
+
s += "</%s>"%self.tag
|
26 |
+
return s
|
27 |
+
def write(self,fout):
|
28 |
+
if not self.props:
|
29 |
+
print("<%s>%s"%(self.tag,self.text), file=fout)
|
30 |
+
else:
|
31 |
+
print("<%s %s>%s"%(self.tag,' '.join(["%s='%s'"%(k,v) for k,v in self.props.items() if v!=None]),self.text), file=fout)
|
32 |
+
for child in self.children:
|
33 |
+
child.write(fout)
|
34 |
+
print("</%s>"%self.tag, file=fout)
|
35 |
+
def first(self,tag,order=1):
|
36 |
+
if self.tag==tag: return self
|
37 |
+
for c in self.children[::order]:
|
38 |
+
res = c.first(tag,order)
|
39 |
+
if res: return res
|
40 |
+
return None
|
41 |
+
def last(self,tag):
|
42 |
+
return self.first(tag,-1)
|
43 |
+
|
44 |
+
|
45 |
+
|
46 |
+
class HTML (Node):
|
47 |
+
def __init__(self):
|
48 |
+
Node.__init__(self,'html')
|
49 |
+
def header(self,**kw):
|
50 |
+
return self.add(Header(**kw))
|
51 |
+
def body(self,**kw):
|
52 |
+
return self.add(BodyNode('body',props=kw))
|
53 |
+
def save(self,fname):
|
54 |
+
fout = open(fname,'w') if type(fname)==str else fname
|
55 |
+
for e in self.children:
|
56 |
+
e.write(fout)
|
57 |
+
def show(self,fname=''):
|
58 |
+
if not fname: import tempfile; fname = tempfile.mktemp(suffix='.html')  # os has no tmpname(); use tempfile instead
|
59 |
+
self.save(fname)
|
60 |
+
os.system('/opt/google/chrome/google-chrome '+fname)
|
61 |
+
|
62 |
+
|
63 |
+
|
64 |
+
class Header (HTML):
|
65 |
+
def __init__(self, **kw):
|
66 |
+
Node.__init__(self,'header',props=kw)
|
67 |
+
def title(self,text):
|
68 |
+
return self.add(Node('title',text=text))
|
69 |
+
def script(self, text="", **kw):
|
70 |
+
return self.add(Node('script',text=text, props=kw))
|
71 |
+
def link(self, **kw):
|
72 |
+
return self.add(Node('link', props=kw))
|
73 |
+
def meta(self):
|
74 |
+
return self.add(Node('meta', props={"http-equiv":"Content-Type", "content": "charset=iso-8859-1"}))
|
75 |
+
|
76 |
+
class BodyNode (Node):
|
77 |
+
# title of section
|
78 |
+
def h(self, strength, text='', **kw):
|
79 |
+
return self.add(BodyNode('h%d'%strength, text=text, props=kw))
|
80 |
+
# paragraph
|
81 |
+
def p(self, text='', **kw):
|
82 |
+
return self.add(BodyNode('p',text=text, props=kw))
|
83 |
+
# bold
|
84 |
+
def bold(self, text='', **kw):
|
85 |
+
return self.add(BodyNode('b',text=text, props=kw))
|
86 |
+
def b(self, text='', **kw):
|
87 |
+
return self.add(BodyNode('b',text=text, props=kw))
|
88 |
+
# italic
|
89 |
+
def italic(self, text='', **kw):
|
90 |
+
return self.add(BodyNode('i',text=text, props=kw))
|
91 |
+
def i(self, text='', **kw):
|
92 |
+
return self.add(BodyNode('i',text=text, props=kw))
|
93 |
+
# span/text
|
94 |
+
def span(self, text='', **kw):
|
95 |
+
return self.add(BodyNode('span',text=text, props=kw))
|
96 |
+
# font
|
97 |
+
def font(self,text='',color=None,face=None,size=None):
|
98 |
+
return self.add(BodyNode('font',text=text, props={'color':color,'face':face,'size':size}))
|
99 |
+
# small
|
100 |
+
def small(self, text='', **kw):
|
101 |
+
return self.add(BodyNode('small',text=text, props=kw))
|
102 |
+
def big(self, text='', **kw):
|
103 |
+
return self.add(BodyNode('big',text=text, props=kw))
|
104 |
+
# centered
|
105 |
+
def center(self, text='', **kw):
|
106 |
+
return self.add(BodyNode('center',text=text, props=kw))
|
107 |
+
# div
|
108 |
+
def div(self, text='', **kw):
|
109 |
+
return self.add(BodyNode('div', text=text, props=kw))
|
110 |
+
# unordered list
|
111 |
+
def unordlist(self, text='', **kw):
|
112 |
+
return self.add(BodyNode('ul', text=text, props=kw))
|
113 |
+
# ordered list
|
114 |
+
def ordlist(self, text='', **kw):
|
115 |
+
return self.add(BodyNode('ol', text=text, props=kw))
|
116 |
+
def item(self, text='', type=None, **kw):
|
117 |
+
kw['type'] = type # non-ord {'circle', 'square', 'disc'}, ord {'1', 'A', 'a', 'I', 'i'}
|
118 |
+
return self.add(BodyNode('li', text=text, props=kw))
|
119 |
+
# line break
|
120 |
+
def br(self):
|
121 |
+
self.add(Node('br'))
|
122 |
+
# horizontal line
|
123 |
+
def hr(self):
|
124 |
+
self.add(Node('hr'))
|
125 |
+
# table
|
126 |
+
def table(self, **kw):
|
127 |
+
return self.add(Table(**kw))
|
128 |
+
# image
|
129 |
+
def image(self, img, **kw):
|
130 |
+
return self.add(Image(img,**kw))
|
131 |
+
# link
|
132 |
+
def a(self, href, text='', **kw):
|
133 |
+
kw['href'] = href
|
134 |
+
return self.add(BodyNode('a', text=text, props=kw))
|
135 |
+
def hidden(self, text, **kw):
|
136 |
+
kw['type'] = 'hidden'
|
137 |
+
kw['value'] = text
|
138 |
+
return self.add(BodyNode('input',props=kw))
|
139 |
+
def imagelink(self, img, **kw):
|
140 |
+
return self.add( BodyNode('a', text=Image(img,**kw).tostr(), props={"href":img}) )
|
141 |
+
|
142 |
+
class Table (Node):
|
143 |
+
def __init__(self,**kw):
|
144 |
+
Node.__init__(self,'table',props=kw)
|
145 |
+
def row(self,elems=[],header=False,**kw):
|
146 |
+
r=TableRow(header, **kw)
|
147 |
+
for e in elems:
|
148 |
+
if issubclass(e.__class__,Node):
|
149 |
+
r.add(e)
|
150 |
+
else:
|
151 |
+
r.cell(str(e))
|
152 |
+
return self.add(r)
|
153 |
+
def fromlist(self, elems, header=None):
|
154 |
+
if header and type(header)!=bool: elems=[header]+elems; header=True
|
155 |
+
for row in elems:
|
156 |
+
self.row(row,header=header)
|
157 |
+
header=False # only once
|
158 |
+
|
159 |
+
|
160 |
+
class TableRow (Node):
|
161 |
+
def __init__(self, isheader=False, **kw):
|
162 |
+
Node.__init__(self,'tr',props=kw)
|
163 |
+
self.isheader=isheader
|
164 |
+
def cell(self, text='', **kw):
|
165 |
+
return self.add(BodyNode(self.isheader and 'th' or 'td',text=text,props=kw))
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
class Image (Node):
|
170 |
+
def __init__(self, img, dir='', name='', width=None, height=None, alt=None):
|
171 |
+
if type(img)==str: loc = img
|
172 |
+
else:
|
173 |
+
if name:
|
174 |
+
img.save(os.path.join(dir,name))
|
175 |
+
loc = name
|
176 |
+
else:
|
177 |
+
import tempfile; loc = tempfile.mktemp(suffix=".png")  # os.tmpnam() was removed in Python 3
|
178 |
+
img.save(loc)
|
179 |
+
Node.__init__(self,'img',props={'src':loc,'width':width,'height':height,'alt':alt,'title':alt})
|
180 |
+
|
181 |
+
|
182 |
+
def htmlspace(n):
|
183 |
+
return " ".join(["" for i in range(n)])
|
184 |
+
def htmloptions(l):
|
185 |
+
return "".join(["<option>"+s+"</option>" for s in l])
|
186 |
+
|
187 |
+
|
188 |
+
if __name__=='__main__':
|
189 |
+
import pdb
|
190 |
+
|
191 |
+
doc = HTML()
|
192 |
+
doc.header().title('test of python-generated HTML page')
|
193 |
+
body=doc.body()
|
194 |
+
body.h(1,"1. Title of page")
|
195 |
+
body.p('a paragraph of text')
|
196 |
+
body.h(2,"2.1 second title")
|
197 |
+
p=body.p()
|
198 |
+
p.italic('another')
|
199 |
+
p.font(color='red').bold('paragraph')
|
200 |
+
p.span('of text')
|
201 |
+
body.h(3,'2.1.1. sub-sub-title')
|
202 |
+
body.p("Here is a list:")
|
203 |
+
ls=body.unordlist()
|
204 |
+
ls.item("first item")
|
205 |
+
ls.item("second item")
|
206 |
+
ls.item("final item")
|
207 |
+
body.hr()
|
208 |
+
body.table(border=1).fromlist([[1,2],[3,4]],header=['col1','col2'])
|
209 |
+
body.br()
|
210 |
+
body.center().image(img='/home/lear/revaud/coca-cola.jpg',width=500,height=300)
|
211 |
+
body.hr()
|
212 |
+
tab=body.table(border=0)
|
213 |
+
tab.row(['coca-cola']*5,header=True)
|
214 |
+
for i in range(3):
|
215 |
+
r = body.last('table').row()
|
216 |
+
for j in range(5):
|
217 |
+
r.cell(bgcolor=['#00FF00','red'][(i+j)%2]).image('/home/lear/revaud/coca-cola2.jpg',width=200)
|
218 |
+
|
219 |
+
doc.show('/tmp/test.html')
|
220 |
+
print('result stored in /tmp/test.html')
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
|
230 |
+
|
231 |
+
|
232 |
+
|
233 |
+
|
234 |
+
|
235 |
+
|
236 |
+
|
237 |
+
|
238 |
+
|
239 |
+
|
240 |
+
|
241 |
+
|
242 |
+
|
243 |
+
|
244 |
+
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
|
249 |
+
|
250 |
+
|
251 |
+
|
252 |
+
|
how/utils/io_helpers.py
ADDED
@@ -0,0 +1,105 @@
1 |
+
"""Helper functions related to io"""
|
2 |
+
|
3 |
+
import os.path
|
4 |
+
import sys
|
5 |
+
import shutil
|
6 |
+
import urllib.request
|
7 |
+
from pathlib import Path
|
8 |
+
import yaml
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
def progress(iterable, *, size=None, print_freq=1, handle=sys.stdout):
|
13 |
+
"""Generator wrapping an iterable to print progress"""
|
14 |
+
for i, element in enumerate(iterable):
|
15 |
+
yield element
|
16 |
+
|
17 |
+
if i == 0 or (i+1) % print_freq == 0 or (i+1) == size:
|
18 |
+
if size:
|
19 |
+
handle.write(f'\r>>>> {i+1}/{size} done...')
|
20 |
+
else:
|
21 |
+
handle.write(f'\r>>>> {i+1} done...')
|
22 |
+
|
23 |
+
handle.write("\n")
|
24 |
+
|
25 |
+
|
26 |
+
# Params
|
27 |
+
|
28 |
+
def load_params(path):
|
29 |
+
"""Return loaded parameters from a yaml file"""
|
30 |
+
with open(path, "r") as handle:
|
31 |
+
content = yaml.safe_load(handle)
|
32 |
+
return load_nested_templates(content, os.path.dirname(path))
|
33 |
+
|
34 |
+
def save_params(path, params):
|
35 |
+
"""Save given parameters to a yaml file"""
|
36 |
+
with open(path, "w") as handle:
|
37 |
+
yaml.safe_dump(params, handle, default_flow_style=False)
|
38 |
+
|
39 |
+
def load_nested_templates(params, root_path):
|
40 |
+
"""Find keys '__template__' in nested dictionary and replace corresponding value with loaded
|
41 |
+
yaml file"""
|
42 |
+
if not isinstance(params, dict):
|
43 |
+
return params
|
44 |
+
|
45 |
+
if "__template__" in params:
|
46 |
+
template_path = os.path.expanduser(params.pop("__template__"))
|
47 |
+
path = os.path.join(root_path, template_path)
|
48 |
+
root_path = os.path.dirname(path)
|
49 |
+
# Treat template as defaults
|
50 |
+
params = dict_deep_overlay(load_params(path), params)
|
51 |
+
|
52 |
+
for key, value in params.items():
|
53 |
+
params[key] = load_nested_templates(value, root_path)
|
54 |
+
|
55 |
+
return params
|
56 |
+
|
57 |
+
def dict_deep_overlay(defaults, params):
|
58 |
+
"""If defaults and params are both dictionaries, perform deep overlay (use params value for
|
59 |
+
keys defined in params), otherwise use defaults value"""
|
60 |
+
if isinstance(defaults, dict) and isinstance(params, dict):
|
61 |
+
for key in params:
|
62 |
+
defaults[key] = dict_deep_overlay(defaults.get(key, None), params[key])
|
63 |
+
return defaults
|
64 |
+
|
65 |
+
return params
|
66 |
+
|
67 |
+
def dict_deep_set(dct, key, value):
|
68 |
+
"""Set key to value for a nested dictionary where the key is a sequence (e.g. list)"""
|
69 |
+
if len(key) == 1:
|
70 |
+
dct[key[0]] = value
|
71 |
+
return
|
72 |
+
|
73 |
+
if key[0] not in dct or not isinstance(dct[key[0]], dict):
|
74 |
+
dct[key[0]] = {}
|
75 |
+
dict_deep_set(dct[key[0]], key[1:], value)
|
76 |
+
|
77 |
+
|
78 |
+
# Download
|
79 |
+
|
80 |
+
def download_files(names, root_path, base_url, logfunc=None):
|
81 |
+
"""Download file names from given url to given directory path. If logfunc given, use it to log
|
82 |
+
status."""
|
83 |
+
root_path = Path(root_path)
|
84 |
+
for name in names:
|
85 |
+
path = root_path / name
|
86 |
+
if path.exists():
|
87 |
+
continue
|
88 |
+
if logfunc:
|
89 |
+
logfunc(f"Downloading file '{name}'")
|
90 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
91 |
+
urllib.request.urlretrieve(base_url + name, path)
|
92 |
+
|
93 |
+
|
94 |
+
# Checkpoints
|
95 |
+
|
96 |
+
def save_checkpoint(state, is_best, keep_epoch, directory):
|
97 |
+
"""Save state dictionary to the directory providing whether the corresponding epoch is the best
|
98 |
+
and whether to keep it anyway"""
|
99 |
+
filename = os.path.join(directory, 'model_epoch%d.pth' % state['epoch'])
|
100 |
+
filename_best = os.path.join(directory, 'model_best.pth')
|
101 |
+
if is_best and keep_epoch:
|
102 |
+
torch.save(state, filename)
|
103 |
+
shutil.copyfile(filename, filename_best)
|
104 |
+
elif is_best or keep_epoch:
|
105 |
+
torch.save(state, filename_best if is_best else filename)
|
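A small sketch of the deep-overlay behaviour used by the '__template__' mechanism above (values are made up): the template yaml provides defaults and the loaded parameters override individual keys.

from how.utils.io_helpers import dict_deep_overlay

defaults = {"optimizer": {"lr": 0.01, "momentum": 0.9}, "epochs": 20}   # e.g. from a template yaml
overrides = {"optimizer": {"lr": 0.001}}                                # e.g. from the user's yaml

params = dict_deep_overlay(defaults, overrides)
print(params)  # {'optimizer': {'lr': 0.001, 'momentum': 0.9}, 'epochs': 20}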
how/utils/logging.py
ADDED
@@ -0,0 +1,63 @@
1 |
+
"""Logging-related functionality"""
|
2 |
+
|
3 |
+
import time
|
4 |
+
import logging
|
5 |
+
|
6 |
+
# Logging
|
7 |
+
|
8 |
+
def init_logger(log_path):
|
9 |
+
"""Return a logger instance which logs to stdout and, if log_path is not None, also to a file"""
|
10 |
+
logger = logging.getLogger("HOW")
|
11 |
+
logger.setLevel(logging.DEBUG)
|
12 |
+
|
13 |
+
stdout_handler = logging.StreamHandler()
|
14 |
+
stdout_handler.setLevel(logging.INFO)
|
15 |
+
stdout_handler.setFormatter(logging.Formatter('%(name)s %(levelname)s: %(message)s'))
|
16 |
+
logger.addHandler(stdout_handler)
|
17 |
+
|
18 |
+
if log_path:
|
19 |
+
file_handler = logging.FileHandler(log_path)
|
20 |
+
file_handler.setLevel(logging.DEBUG)
|
21 |
+
formatter = logging.Formatter('%(asctime)s %(name)s %(levelname)s: %(message)s')
|
22 |
+
file_handler.setFormatter(formatter)
|
23 |
+
logger.addHandler(file_handler)
|
24 |
+
|
25 |
+
return logger
|
26 |
+
|
27 |
+
|
28 |
+
# Stopwatch
|
29 |
+
|
30 |
+
class LoggingStopwatch:
|
31 |
+
"""Stopwatch context that produces one message when entered and another one when exited,
|
32 |
+
with the time spent in the context embedded in the exiting message.
|
33 |
+
|
34 |
+
:param str message: Message to be logged at the start and finish. If the first word
|
35 |
+
of the message ends with 'ing', convert to passive for finish message.
|
36 |
+
:param callable log_start: Will be called with given message at the start
|
37 |
+
:param callable log_finish: Will be called with built message at the finish. If None, use
|
38 |
+
log_start
|
39 |
+
"""
|
40 |
+
|
41 |
+
def __init__(self, message, log_start, log_finish=None):
|
42 |
+
self.message = message
|
43 |
+
self.log_start = log_start
|
44 |
+
self.log_finish = log_finish if log_finish is not None else log_start
|
45 |
+
self.time0 = None
|
46 |
+
|
47 |
+
def __enter__(self):
|
48 |
+
self.time0 = time.time()
|
49 |
+
if self.log_start:
|
50 |
+
self.log_start(self.message.capitalize())
|
51 |
+
|
52 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
53 |
+
# Build message
|
54 |
+
words = self.message.split(" ")
|
55 |
+
secs = "%.1fs" % (time.time() - self.time0)
|
56 |
+
if words[0].endswith("ing"):
|
57 |
+
words += [words.pop(0).replace("ing", "ed"), "in", secs]
|
58 |
+
else:
|
59 |
+
words += ["(%.1f)" % secs]
|
60 |
+
|
61 |
+
# Log message
|
62 |
+
if self.log_finish:
|
63 |
+
self.log_finish(" ".join(words).capitalize())
|
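A usage sketch for the stopwatch context, reusing the logger built by init_logger:

import time
from how.utils.logging import init_logger, LoggingStopwatch

logger = init_logger(log_path=None)  # stdout only
with LoggingStopwatch("training epoch 3", logger.info):
    time.sleep(0.2)  # stand-in for the actual work
# Logs "Training epoch 3" on enter and e.g. "Epoch 3 trained in 0.2s" on exit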
how/utils/plots.py
ADDED
@@ -0,0 +1,37 @@
1 |
+
"""Plotting classes"""
|
2 |
+
|
3 |
+
import matplotlib
|
4 |
+
matplotlib.use('Agg')
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
|
8 |
+
class EpochFigure:
|
9 |
+
"""Basic figure for plotting scores across epochs
|
10 |
+
|
11 |
+
:param str title: Figure title
|
12 |
+
:param str ylabel: Plot's y label
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, title, *, ylabel):
|
16 |
+
self.fig = plt.figure()
|
17 |
+
self.axes = self.fig.add_subplot(1, 1, 1)
|
18 |
+
self.title = title
|
19 |
+
self.ylabel = ylabel
|
20 |
+
|
21 |
+
def __del__(self):
|
22 |
+
plt.close(self.fig)
|
23 |
+
|
24 |
+
def __getattr__(self, name):
|
25 |
+
# Delegate method calls on self.axes
|
26 |
+
return getattr(self.axes, name)
|
27 |
+
|
28 |
+
def save(self, path):
|
29 |
+
"""Save figure to given path"""
|
30 |
+
self.axes.grid(b=True, which='major', color='k', linestyle='-')
|
31 |
+
self.axes.grid(b=True, which='minor', color='r', linestyle='-', alpha=0.2)
|
32 |
+
self.axes.minorticks_on()
|
33 |
+
self.axes.legend()
|
34 |
+
self.axes.set_xlabel('epoch')
|
35 |
+
self.axes.set_ylabel(self.ylabel)
|
36 |
+
self.axes.set_title(self.title)
|
37 |
+
self.fig.savefig(path)
|
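A minimal sketch of plotting a validation score across epochs with EpochFigure; the numbers are invented:

from how.utils.plots import EpochFigure

fig = EpochFigure("val_eccv20", ylabel="mAP")
epochs, scores = [1, 2, 3, 4], [0.41, 0.47, 0.52, 0.55]
fig.plot(epochs, scores, label="global_descriptor")  # plot() is delegated to the underlying axes
fig.save("fig_val_eccv20.jpg")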
how/utils/score_helpers.py
ADDED
@@ -0,0 +1,59 @@
1 |
+
"""Helper functions for computing evaluation scores"""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from cirtorch.utils.evaluate import compute_map
|
6 |
+
|
7 |
+
|
8 |
+
def compute_map_and_log(dataset, ranks, gnd, kappas=(1, 5, 10), logger=None):
|
9 |
+
"""Computed mAP and log it
|
10 |
+
|
11 |
+
:param str dataset: Dataset to compute the mAP on (e.g. roxford5k)
|
12 |
+
:param np.ndarray ranks: 2D matrix of ints corresponding to previously computed ranks
|
13 |
+
:param dict gnd: Ground-truth dataset structure
|
14 |
+
:param list kappas: Compute mean precision at each kappa
|
15 |
+
:param logging.Logger logger: If not None, use it to log mAP and all mP@kappa
|
16 |
+
:return dict: Dictionary with mAP and mP@kappa scores per difficulty (easy/medium/hard) for roxford5k and rparis6k, or for the single protocol otherwise
|
17 |
+
"""
|
18 |
+
# new evaluation protocol
|
19 |
+
if dataset.startswith('roxford5k') or dataset.startswith('rparis6k'):
|
20 |
+
gnd_t = []
|
21 |
+
for gndi in gnd:
|
22 |
+
g = {}
|
23 |
+
g['ok'] = np.concatenate([gndi['easy']])
|
24 |
+
g['junk'] = np.concatenate([gndi['junk'], gndi['hard']])
|
25 |
+
gnd_t.append(g)
|
26 |
+
mapE, apsE, mprE, prsE = compute_map(ranks, gnd_t, kappas)
|
27 |
+
|
28 |
+
gnd_t = []
|
29 |
+
for gndi in gnd:
|
30 |
+
g = {}
|
31 |
+
g['ok'] = np.concatenate([gndi['easy'], gndi['hard']])
|
32 |
+
g['junk'] = np.concatenate([gndi['junk']])
|
33 |
+
gnd_t.append(g)
|
34 |
+
mapM, apsM, mprM, prsM = compute_map(ranks, gnd_t, kappas)
|
35 |
+
|
36 |
+
gnd_t = []
|
37 |
+
for gndi in gnd:
|
38 |
+
g = {}
|
39 |
+
g['ok'] = np.concatenate([gndi['hard']])
|
40 |
+
g['junk'] = np.concatenate([gndi['junk'], gndi['easy']])
|
41 |
+
gnd_t.append(g)
|
42 |
+
mapH, apsH, mprH, prsH = compute_map(ranks, gnd_t, kappas)
|
43 |
+
|
44 |
+
if logger:
|
45 |
+
fmap = lambda x: np.around(x*100, decimals=2)
|
46 |
+
logger.info(f"Evaluated {dataset}: mAP E: {fmap(mapE)}, M: {fmap(mapM)}, H: {fmap(mapH)}")
|
47 |
+
logger.info(f"Evaluated {dataset}: mP@k{kappas} E: {fmap(mprE)}, M: {fmap(mprM)}, H: {fmap(mprH)}")
|
48 |
+
|
49 |
+
scores = {"map_easy": mapE.item(), "mp@k_easy": mprE, "ap_easy": apsE, "p@k_easy": prsE,
|
50 |
+
"map_medium": mapM.item(), "mp@k_medium": mprM, "ap_medium": apsM, "p@k_medium": prsM,
|
51 |
+
"map_hard": mapH.item(), "mp@k_hard": mprH, "ap_hard": apsH, "p@k_hard": prsH}
|
52 |
+
return scores
|
53 |
+
|
54 |
+
# old evaluation protocol
|
55 |
+
map_score, ap_scores, prk, pr_scores = compute_map(ranks, gnd, kappas=kappas)
|
56 |
+
if logger:
|
57 |
+
fmap = lambda x: np.around(x*100, decimals=2)
|
58 |
+
logger.info(f"Evaluated {dataset}: mAP {fmap(map_score)}, mP@k {fmap(prk)}")
|
59 |
+
return {"map": map_score, "mp@k": prk, "ap": ap_scores, "p@k": pr_scores}
|
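A toy call showing the expected input shapes under the old protocol (one query, two database images); real usage passes the ranks produced by descriptor search:

import numpy as np
from how.utils.score_helpers import compute_map_and_log

ranks = np.array([[1],    # ranks[r, q]: database index ranked r-th for query q
                  [0]])
gnd = [{"ok": np.array([1]), "junk": np.array([], dtype=int)}]  # image 1 is the only positive
scores = compute_map_and_log("custom", ranks, gnd, kappas=(1,))
print(scores["map"])  # 1.0, since the positive is ranked first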
how/utils/visualize.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
from how.utils.html import HTML
|
7 |
+
|
8 |
+
def visualize_attention_map(dataset_name, imgpaths, attentions, scales, outdir):
|
9 |
+
assert len(imgpaths) == len(attentions)
|
10 |
+
os.makedirs(outdir, exist_ok=True)
|
11 |
+
for i, imgpath in enumerate(imgpaths): # for each image
|
12 |
+
img_basename = os.path.splitext(os.path.basename(imgpath))[0]
|
13 |
+
atts = attentions[i]
|
14 |
+
# load image
|
15 |
+
img = cv2.imread(imgpath)
|
16 |
+
# generate the visu for each scale independently
|
17 |
+
for j,s in enumerate(scales):
|
18 |
+
a = atts[j]
|
19 |
+
img_s = cv2.resize(img, None, fx=s, fy=s)
|
20 |
+
heatmap_s = cv2.applyColorMap( (255*cv2.resize(a, (img_s.shape[1],img_s.shape[0]))).astype(np.uint8), cv2.COLORMAP_JET)
|
21 |
+
overlay = cv2.addWeighted(heatmap_s, 0.5, img_s, 0.5, 0)
|
22 |
+
cv2.imwrite(outdir+'{:s}_scale{:g}.jpg'.format(img_basename, s), overlay)
|
23 |
+
# generate the visu for the aggregation over scales
|
24 |
+
agg_atts = sum([cv2.resize(a, (img.shape[1],img.shape[0])) for a in atts]) / len(atts)
|
25 |
+
heatmap_s = cv2.applyColorMap( (255*agg_atts).astype(np.uint8), cv2.COLORMAP_JET)
|
26 |
+
overlay = cv2.addWeighted(heatmap_s, 0.5, img, 0.5, 0)
|
27 |
+
cv2.imwrite(outdir+'{:s}_aggregated.jpg'.format(img_basename), overlay)
|
28 |
+
# generate a html webpage for visualization
|
29 |
+
doc = HTML()
|
30 |
+
doc.header().title(dataset_name)
|
31 |
+
b = doc.body()
|
32 |
+
b.h(1, dataset_name+' (attention map)')
|
33 |
+
t = b.table(cellpadding=2, border=1)
|
34 |
+
for i, imgpath in enumerate(imgpaths):
|
35 |
+
img_basename = os.path.splitext(os.path.basename(imgpath))[0]
|
36 |
+
if i%3==0: t.row(['info','image','agg','scale 1']+['scale '+str(s) for s in scales if s!=1], header=True)
|
37 |
+
r = t.row()
|
38 |
+
r.cell(str(i)+': '+img_basename)
|
39 |
+
r.cell('<a href="{img:s}"><img src="{img:s}"/></a>'.format(img=imgpath))
|
40 |
+
r.cell('<a href="{img:s}"><img src="{img:s}"/></a>'.format(img='{:s}_aggregated.jpg'.format(img_basename)))
|
41 |
+
r.cell('<a href="{img:s}"><img src="{img:s}"/></a>'.format(img='{:s}_scale1.jpg'.format(img_basename)))
|
42 |
+
for s in scales:
|
43 |
+
if s==1: continue
|
44 |
+
r.cell('<a href="{img:s}"><img src="{img:s}"/></a>'.format(img='{:s}_scale{:g}.jpg'.format(img_basename,s)))
|
45 |
+
doc.save(outdir+'index.html')
|
46 |
+
|
47 |
+
|
48 |
+
def visualize_region_maps(dataset_name, imgpaths, attentions, regions, scales, outdir, topk=10):
|
49 |
+
assert len(imgpaths) == len(attentions)
|
50 |
+
assert len(attentions) == len(regions)
|
51 |
+
assert 1 in scales # we display the regions only for scale 1 (at least so far)
|
52 |
+
os.makedirs(outdir, exist_ok=True)
|
53 |
+
# generate visualization of each region
|
54 |
+
for i, imgpath in enumerate(imgpaths): # for each image
|
55 |
+
img_basename = os.path.splitext(os.path.basename(imgpath))[0]
|
56 |
+
regs = regions[i]
|
57 |
+
# load image
|
58 |
+
img = cv2.imread(imgpath)
|
59 |
+
# for each scale
|
60 |
+
for j,s in enumerate(scales):
|
61 |
+
if s!=1: continue # just consider scale 1
|
62 |
+
r = regs[j][-1]
|
63 |
+
img_s = cv2.resize(img, None, fx=s, fy=s)
|
64 |
+
for ir in range(r.shape[0]):
|
65 |
+
heatmap_s = cv2.applyColorMap( (255*cv2.resize(np.minimum(1,100*r[ir,:,:]), (img_s.shape[1],img_s.shape[0]))).astype(np.uint8), cv2.COLORMAP_JET) # factor 100 for easier visualization
|
66 |
+
overlay = cv2.addWeighted(heatmap_s, 0.5, img_s, 0.5, 0)
|
67 |
+
cv2.imwrite(outdir+'{:s}_region{:d}_scale{:g}.jpg'.format(img_basename, ir, s), overlay)
|
68 |
+
# generate a html webpage for visualization
|
69 |
+
doc = HTML()
|
70 |
+
doc.header().title(dataset_name)
|
71 |
+
b = doc.body()
|
72 |
+
b.h(1, dataset_name+' (region maps)')
|
73 |
+
t = b.table(cellpadding=2, border=1)
|
74 |
+
for i, imgpath in enumerate(imgpaths):
|
75 |
+
atts = attentions[i]
|
76 |
+
regs = regions[i]
|
77 |
+
for j,s in enumerate(scales):
|
78 |
+
a = atts[j]
|
79 |
+
rr = regs[j][-1] # -1 because it is a list of the history of regions
|
80 |
+
if s==1: break
|
81 |
+
argsort = np.argsort(-a)
|
82 |
+
img_basename = os.path.splitext(os.path.basename(imgpath))[0]
|
83 |
+
if i%3==0: t.row(['info','image']+['scale 1 - region {:d}'.format(ir) for ir in range(topk)], header=True)
|
84 |
+
r = t.row()
|
85 |
+
r.cell(str(i)+': '+img_basename)
|
86 |
+
r.cell('<a href="{img:s}"><img src="{img:s}"/></a>'.format(img=imgpath))
|
87 |
+
for ir in range(topk):
|
88 |
+
index = argsort[ir]
|
89 |
+
r.cell('<a href="{img:s}"><img src="{img:s}"/></a><br>index: {index:d}, att: {att:g}, rmax: {rmax:g}'.format(img='{:s}_region{:d}_scale{:g}.jpg'.format(img_basename,index,s), index=index, att=a[index], rmax=rr[index,:,:].max()))
|
90 |
+
doc.save(outdir+'index.html')
|
91 |
+
|
92 |
+
if __name__=='__main__':
|
93 |
+
dataset = 'roxford5k'
|
94 |
+
from how.utils import data_helpers
|
95 |
+
images, qimages, bbxs, gnd = data_helpers.load_dataset(dataset, data_root="/tmp-network/user/pweinzae/CNNImageRetrieval/data/")
|
96 |
+
import pickle
|
97 |
+
with open('/tmp-network/user/pweinzae/roxford5k_features_attentions.pkl', 'rb') as fid:
|
98 |
+
features, attentions = pickle.load(fid)
|
99 |
+
visualize_attention_map(dataset, qimages, attentions, scales=[2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25], outdir='/tmp-network/user/pweinzae/tmp/visu_attention_maps/'+dataset)
|
how/utils/whitening.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
"""Functions for training and applying whitening"""
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
|
6 |
+
def l2_normalize_vec(X):
|
7 |
+
"""L2-normalize given descriptors"""
|
8 |
+
return X / (np.linalg.norm(X, ord=2, axis=1, keepdims=True) + 1e-6)
|
9 |
+
|
10 |
+
|
11 |
+
def whitenapply(X, m, P, dimensions=None):
|
12 |
+
"""Apply whitening (m, P) on descriptors X. If dimensions not None, perform dim reduction."""
|
13 |
+
if not dimensions:
|
14 |
+
dimensions = P.shape[1]
|
15 |
+
|
16 |
+
X = np.dot(X-m, P[:, :dimensions])
|
17 |
+
return l2_normalize_vec(X)
|
18 |
+
|
19 |
+
|
20 |
+
def pcawhitenlearn_shrinkage(X, s=1.0):
|
21 |
+
"""Learn PCA whitening with shrinkage from given descriptors"""
|
22 |
+
N = X.shape[0]
|
23 |
+
|
24 |
+
# Learning PCA w/o annotations
|
25 |
+
m = X.mean(axis=0, keepdims=True)
|
26 |
+
Xc = X - m
|
27 |
+
Xcov = np.dot(Xc.T, Xc)
|
28 |
+
Xcov = (Xcov + Xcov.T) / (2*N)
|
29 |
+
eigval, eigvec = np.linalg.eig(Xcov)
|
30 |
+
order = eigval.argsort()[::-1]
|
31 |
+
eigval = eigval[order]
|
32 |
+
eigvec = eigvec[:, order]
|
33 |
+
|
34 |
+
P = np.dot(np.linalg.inv(np.diag(np.power(eigval, 0.5*s))), eigvec.T)
|
35 |
+
|
36 |
+
return m, P.T
|
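A minimal sketch of learning PCA whitening on random descriptors and applying it with dimensionality reduction:

import numpy as np
from how.utils.whitening import pcawhitenlearn_shrinkage, whitenapply

X = np.random.randn(1000, 128)             # stand-in for aggregated local descriptors
m, P = pcawhitenlearn_shrinkage(X, s=1.0)  # s controls the shrinkage of the whitening
Xw = whitenapply(X, m, P, dimensions=64)   # whiten, reduce to 64-D and L2-normalize
print(Xw.shape)                            # (1000, 64)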
requirements.txt
ADDED
@@ -0,0 +1,5 @@
1 |
+
numpy
|
2 |
+
pyaml
|
3 |
+
matplotlib
|
4 |
+
torch==1.3.1
|
5 |
+
torchvision==0.4.2
|