taesiri committed on
Commit
8390f90
•
1 Parent(s): ec6b92a

Initial Commit

README.md CHANGED
@@ -1,37 +1,13 @@
 ---
 title: ConvolutionalHoughMatchingNetworks
 emoji: 📚
- colorFrom: green
+ colorFrom: red
 colorTo: yellow
 sdk: gradio
 app_file: app.py
 pinned: false
 ---

- # Configuration
-
- `title`: _string_
- Display title for the Space
-
- `emoji`: _string_
- Space emoji (emoji-only character allowed)
-
- `colorFrom`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
- `colorTo`: _string_
- Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
- `sdk`: _string_
- Can be either `gradio` or `streamlit`
-
- `sdk_version` : _string_
- Only applicable for `streamlit` SDK.
- See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
-
- `app_file`: _string_
- Path to your main application file (which contains either `gradio` or `streamlit` Python code).
- Path is relative to the root of the repository.
-
- `pinned`: _boolean_
- Whether the Space stays on top of your list.
+ # Convolutional Hough Matching Networks
+
+ A demo for Convolutional Hough Matching Networks. [[Paper](https://arxiv.org/abs/2109.05221)] [[Official Github Repo](https://github.com/juhongm999/chm.git)]
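For orientation before the code: the Space exposes a single function, `generate_correspondences`, through a Gradio UI. A minimal sketch of calling it outside the UI (illustrative only; `source.jpg`/`target.jpg` are placeholder paths, and importing `app` also executes the weight download and the `iface.launch()` at the bottom of app.py):

```python
# Illustrative sketch, not part of the commit: exercising the demo
# function directly. Importing app runs the whole script, including
# the model download and iface.launch().
from PIL import Image
from app import generate_correspondences

src = Image.open('source.jpg').convert('RGB')   # placeholder inputs
trg = Image.open('target.jpg').convert('RGB')
fig = generate_correspondences(src, trg, min_x=15, max_x=215, min_y=15, max_y=215)
fig.savefig('correspondences.png')
```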
app.py ADDED
@@ -0,0 +1,149 @@
+ from itertools import product
+
+ from torch.utils.data import DataLoader
+ import torch
+ from model.base.geometry import Geometry
+ from common.evaluation import Evaluator
+ from common.logger import AverageMeter
+ from common.logger import Logger
+ from data import download
+ from model import chmnet
+ import matplotlib
+ from matplotlib import pyplot as plt
+ from matplotlib.patches import ConnectionPatch
+ from PIL import Image
+ import numpy as np
+ import os
+ import torchvision
+ import torchvision.transforms as transforms
+ import torchvision.transforms.functional as TF
+ import torchvision.models as models
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import random
+ import gradio as gr
+
+ # Downloading the Model
+ torchvision.datasets.utils.download_file_from_google_drive('1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt')
+
+ # Model Initialization
+ args = dict({
+     'alpha': [0.05, 0.1],
+     'benchmark': 'pfpascal',
+     'bsz': 90,
+     'datapath': '../Datasets_CHM',
+     'img_size': 240,
+     'ktype': 'psi',
+     'load': 'pas_psi.pt',
+     'thres': 'img'
+ })
+
+ model = chmnet.CHMNet(args['ktype'])
+ model.load_state_dict(torch.load(args['load'], map_location=torch.device('cpu')))
+ Evaluator.initialize(args['alpha'])
+ Geometry.initialize(img_size=args['img_size'])
+ model.eval()
+
+ # Transforms
+ chm_transform = transforms.Compose(
+     [transforms.Resize(args['img_size']),
+      transforms.CenterCrop((args['img_size'], args['img_size'])),
+      transforms.ToTensor(),
+      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
+
+ chm_transform_plot = transforms.Compose(
+     [transforms.Resize(args['img_size']),
+      transforms.CenterCrop((args['img_size'], args['img_size']))])
+
+ # A Helper Function
+ to_np = lambda x: x.data.to('cpu').numpy()
+
+ # Colors for Plotting
+ cmap = matplotlib.cm.get_cmap('Spectral')
+ rgba = cmap(0.5)
+ colors = []
+ for k in range(49):
+     colors.append(cmap(k / 49.0))
+
+
+ # CHM MODEL
+ def run_chm(source_image, target_image, selected_points, number_src_points, chm_transform, display_transform):
+     # Convert to Tensor
+     src_img_tnsr = chm_transform(source_image).unsqueeze(0)
+     tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
+
+     keypoints = torch.tensor(selected_points).unsqueeze(0)
+     n_pts = torch.tensor(np.asarray([number_src_points]))
+
+     # RUN CHM ------------------------------------------------------------------------
+     with torch.no_grad():
+         corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
+         prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)
+
+     # VISUALIZATION
+     src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
+     tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()
+
+     src_points_converted = []
+     w, h = display_transform(source_image).size
+
+     for x, y in zip(src_points[0], src_points[1]):
+         src_points_converted.append([int(x * w / args['img_size']), int(y * h / args['img_size'])])
+
+     src_points_converted = np.asarray(src_points_converted[:number_src_points])
+     tgt_points_converted = []
+
+     w, h = display_transform(target_image).size
+     for x, y in zip(tgt_points[0], tgt_points[1]):
+         tgt_points_converted.append([int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)])
+
+     tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])
+
+     tgt_grid = []
+     for x, y in zip(tgt_points[0], tgt_points[1]):
+         tgt_grid.append([int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)])
+
+     # PLOT
+     fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
+
+     ax[0].imshow(display_transform(source_image))
+     ax[0].scatter(src_points_converted[:, 0], src_points_converted[:, 1], c=colors[:number_src_points])
+     ax[0].set_title('Source')
+     ax[0].set_xticks([])
+     ax[0].set_yticks([])
+
+     ax[1].imshow(display_transform(target_image))
+     ax[1].scatter(tgt_points_converted[:, 0], tgt_points_converted[:, 1], c=colors[:number_src_points])
+     ax[1].set_title('Target')
+     ax[1].set_xticks([])
+     ax[1].set_yticks([])
+
+     for TL in range(49):
+         ax[0].text(x=src_points_converted[TL][0], y=src_points_converted[TL][1], s=str(TL), fontdict=dict(color='red', size=10))
+
+     for TL in range(49):
+         ax[1].text(x=tgt_points_converted[TL][0], y=tgt_points_converted[TL][1], s=str(TL), fontdict=dict(color='orange', size=8))
+
+     plt.tight_layout()
+     fig.suptitle('CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights', fontsize=16)
+     return fig
+
+
+ # Wrapper
+ def generate_correspondences(source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
+     A = np.linspace(min_x, max_x, 7)
+     B = np.linspace(min_y, max_y, 7)
+     point_list = list(product(A, B))
+     new_points = np.asarray(point_list, dtype=np.float64).T
+     return run_chm(source_image, target_image, selected_points=new_points, number_src_points=49, chm_transform=chm_transform, display_transform=chm_transform_plot)
+
+
+ # GRADIO APP
+ iface = gr.Interface(fn=generate_correspondences,
+                      inputs=[gr.inputs.Image(shape=(240, 240), type='pil'),
+                              gr.inputs.Image(shape=(240, 240), type='pil'),
+                              gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='MinX'),
+                              gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='MaxX'),
+                              gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='MinY'),
+                              gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='MaxY')],
+                      outputs='plot')
+ iface.launch()
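A detail worth flagging in `run_chm` above: source keypoints enter in pixel coordinates (`normalized=False`), but `Geometry.transfer_kps` returns predictions in normalized `[-1, 1]` coordinates, which the plotting code maps back to pixels with `((x + 1) / 2) * w`. A toy check of that mapping:

```python
# Toy check of the normalized-to-pixel mapping used in run_chm.
w = 240                               # display width after CenterCrop
x_norm = 0.0                          # image center in [-1, 1] coordinates
x_px = int(((x_norm + 1) / 2.0) * w)
assert x_px == 120                    # normalized 0.0 lands mid-image
```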
common/__pycache__/evaluation.cpython-38.pyc ADDED
Binary file (1.3 kB)
common/__pycache__/logger.cpython-38.pyc ADDED
Binary file (4.23 kB)
common/evaluation.py ADDED
@@ -0,0 +1,32 @@
+ r""" Evaluates CHMNet with PCK """
+
+ import torch
+
+
+ class Evaluator:
+     r""" Computes evaluation metrics of PCK """
+     @classmethod
+     def initialize(cls, alpha):
+         cls.alpha = torch.tensor(alpha).unsqueeze(1)
+
+     @classmethod
+     def evaluate(cls, prd_kps, batch):
+         r""" Computes percentage of correct keypoints (PCK) with multiple alpha {0.05, 0.1, 0.15} """
+         pcks = []
+         for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])):
+             pckthres = batch['pckthres'][idx]
+             npt = batch['n_pts'][idx]
+             prd_pts = pk[:, :npt]
+             trg_pts = tk[:, :npt]
+
+             l2dist = (prd_pts - trg_pts).pow(2).sum(dim=0).pow(0.5).unsqueeze(0).repeat(len(cls.alpha), 1)
+             thres = pckthres.expand_as(l2dist).float() * cls.alpha
+             pck = torch.le(l2dist, thres).sum(dim=1) / float(npt)
+             if len(pck) == 1: pck = pck[0]
+             pcks.append(pck)
+
+         eval_result = {'pck': pcks}
+
+         return eval_result
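`Evaluator.evaluate` implements PCK: a predicted keypoint counts as correct when its L2 distance to the ground truth is at most `alpha` times the threshold (image side or bounding-box size, depending on the `thres` setting). A toy run with made-up numbers:

```python
import torch

# Made-up example: 3 keypoints (x row / y row), threshold 240, alpha = 0.1
prd = torch.tensor([[10., 100., 50.],
                    [10., 100., 50.]])
trg = torch.tensor([[12., 100., 90.],
                    [10., 104., 50.]])
l2dist = (prd - trg).pow(2).sum(dim=0).pow(0.5)  # per-keypoint distances: 2, 4, 40
pck = torch.le(l2dist, 0.1 * 240).sum() / 3.0    # 40 > 24, so 2 of 3 count
print(pck)                                       # tensor(0.6667)
```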
common/logger.py ADDED
@@ -0,0 +1,117 @@
+ r""" Logging """
+
+ import datetime
+ import logging
+ import os
+
+ from tensorboardX import SummaryWriter
+ import torch
+
+
+ class Logger:
+     r""" Writes results of training/testing """
+     @classmethod
+     def initialize(cls, args, training):
+         logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
+         logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
+         if logpath == '': logpath = logtime
+
+         cls.logpath = os.path.join('logs', logpath + '.log')
+         cls.benchmark = args.benchmark
+         os.makedirs(cls.logpath)
+
+         logging.basicConfig(filemode='w',
+                             filename=os.path.join(cls.logpath, 'log.txt'),
+                             level=logging.INFO,
+                             format='%(message)s',
+                             datefmt='%m-%d %H:%M:%S')
+
+         # Console log config
+         console = logging.StreamHandler()
+         console.setLevel(logging.INFO)
+         formatter = logging.Formatter('%(message)s')
+         console.setFormatter(formatter)
+         logging.getLogger('').addHandler(console)
+
+         # Tensorboard writer
+         cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))
+
+         # Log arguments
+         if training:
+             logging.info(':======== Convolutional Hough Matching Networks =========')
+             for arg_key in args.__dict__:
+                 logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
+             logging.info(':========================================================\n')
+
+     @classmethod
+     def info(cls, msg):
+         r""" Writes message to .txt """
+         logging.info(msg)
+
+     @classmethod
+     def save_model(cls, model, epoch, val_pck):
+         torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
+         cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))
+
+
+ class AverageMeter:
+     r""" Stores loss, evaluation results, selected layers """
+     def __init__(self, benchmark):
+         r""" Constructor of AverageMeter """
+         self.buffer_keys = ['pck']
+         self.buffer = {}
+         for key in self.buffer_keys:
+             self.buffer[key] = []
+
+         self.loss_buffer = []
+
+     def update(self, eval_result, loss=None):
+         for key in self.buffer_keys:
+             self.buffer[key] += eval_result[key]
+
+         if loss is not None:
+             self.loss_buffer.append(loss)
+
+     def write_result(self, split, epoch):
+         msg = '\n*** %s ' % split
+         msg += '[@Epoch %02d] ' % epoch
+
+         if len(self.loss_buffer) > 0:
+             msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
+
+         for key in self.buffer_keys:
+             msg += '%s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
+         msg += '***\n'
+         Logger.info(msg)
+
+     def write_process(self, batch_idx, datalen, epoch):
+         msg = '[Epoch: %02d] ' % epoch
+         msg += '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
+         if len(self.loss_buffer) > 0:
+             msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
+             msg += 'Avg Loss: %5.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer))
+
+         for key in self.buffer_keys:
+             msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]) * 100)
+         Logger.info(msg)
+
+     def write_test_process(self, batch_idx, datalen):
+         msg = '[Batch: %04d/%04d] ' % (batch_idx + 1, datalen)
+
+         for key in self.buffer_keys:
+             if key == 'pck':
+                 pcks = torch.stack(self.buffer[key]).mean(dim=0) * 100
+                 val = ''
+                 for p in pcks:
+                     val += '%5.2f ' % p.item()
+                 msg += 'Avg %s: %s ' % (key.upper(), val)
+             else:
+                 msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
+         Logger.info(msg)
+
+     def get_test_result(self):
+         result = {}
+         for key in self.buffer_keys:
+             result[key] = torch.stack(self.buffer[key]).mean(dim=0) * 100
+
+         return result
data/__pycache__/dataset.cpython-38.pyc ADDED
Binary file (3.95 kB)
data/__pycache__/download.cpython-38.pyc ADDED
Binary file (2.56 kB)
data/__pycache__/pfpascal.cpython-38.pyc ADDED
Binary file (3.91 kB)
data/__pycache__/pfwillow.cpython-38.pyc ADDED
Binary file (2.85 kB)
data/__pycache__/spair.cpython-38.pyc ADDED
Binary file (5.51 kB)
data/dataset.py ADDED
@@ -0,0 +1,140 @@
+ r""" Superclass for semantic correspondence datasets """
+
+ import os
+
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from PIL import Image
+ import torch
+
+ from model.base.geometry import Geometry
+
+
+ class CorrespondenceDataset(Dataset):
+     r""" Parent class of PFPascal, PFWillow, and SPair """
+     def __init__(self, benchmark, datapath, thres, split):
+         r""" CorrespondenceDataset constructor """
+         super(CorrespondenceDataset, self).__init__()
+
+         # {Directory name, Layout path, Image path, Annotation path, PCK threshold}
+         self.metadata = {
+             'pfwillow': ('PF-WILLOW',
+                          'test_pairs.csv',
+                          '',
+                          '',
+                          'bbox'),
+             'pfpascal': ('PF-PASCAL',
+                          '_pairs.csv',
+                          'JPEGImages',
+                          'Annotations',
+                          'img'),
+             'spair': ('SPair-71k',
+                       'Layout/large',
+                       'JPEGImages',
+                       'PairAnnotation',
+                       'bbox')
+         }
+
+         # Directory path for train, val, or test splits
+         base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
+         if benchmark == 'pfpascal':
+             self.spt_path = os.path.join(base_path, split + '_pairs.csv')
+         elif benchmark == 'spair':
+             self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split + '.txt')
+         else:
+             self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])
+
+         # Directory path for images
+         self.img_path = os.path.join(base_path, self.metadata[benchmark][2])
+
+         # Directory path for annotations
+         if benchmark == 'spair':
+             self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
+         else:
+             self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])
+
+         # Miscellaneous
+         self.max_pts = 40
+         self.split = split
+         self.img_size = Geometry.img_size
+         self.benchmark = benchmark
+         self.range_ts = torch.arange(self.max_pts)
+         self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
+         self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
+                                              transforms.ToTensor(),
+                                              transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                                                   std=[0.229, 0.224, 0.225])])
+
+         # To get initialized in subclass constructors
+         self.train_data = []
+         self.src_imnames = []
+         self.trg_imnames = []
+         self.cls = []
+         self.cls_ids = []
+         self.src_kps = []
+         self.trg_kps = []
+
+     def __len__(self):
+         r""" Returns the number of pairs """
+         return len(self.train_data)
+
+     def __getitem__(self, idx):
+         r""" Constructs and returns a batch """
+
+         # Image names
+         batch = dict()
+         batch['src_imname'] = self.src_imnames[idx]
+         batch['trg_imname'] = self.trg_imnames[idx]
+
+         # Object category
+         batch['category_id'] = self.cls_ids[idx]
+         batch['category'] = self.cls[batch['category_id']]
+
+         # Image as PIL (original width, original height)
+         src_pil = self.get_image(self.src_imnames, idx)
+         trg_pil = self.get_image(self.trg_imnames, idx)
+         batch['src_imsize'] = src_pil.size
+         batch['trg_imsize'] = trg_pil.size
+
+         # Image as tensor
+         batch['src_img'] = self.transform(src_pil)
+         batch['trg_img'] = self.transform(trg_pil)
+
+         # Keypoints (re-scaled)
+         batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
+         batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
+         batch['n_pts'] = torch.tensor(num_pts)
+
+         # Total number of pairs in training split
+         batch['datalen'] = len(self.train_data)
+
+         return batch
+
+     def get_image(self, imnames, idx):
+         r""" Reads a PIL image from path """
+         path = os.path.join(self.img_path, imnames[idx])
+         return Image.open(path).convert('RGB')
+
+     def get_pckthres(self, batch, imsize):
+         r""" Computes the PCK threshold """
+         if self.thres == 'bbox':
+             bbox = batch['trg_bbox'].clone()
+             bbox_w = (bbox[2] - bbox[0])
+             bbox_h = (bbox[3] - bbox[1])
+             pckthres = torch.max(bbox_w, bbox_h)
+         elif self.thres == 'img':
+             imsize_t = batch['trg_img'].size()
+             pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
+         else:
+             raise Exception('Invalid pck threshold type: %s' % self.thres)
+         return pckthres.float()
+
+     def get_points(self, pts_list, idx, org_imsize):
+         r""" Returns keypoints of an image """
+         xy, n_pts = pts_list[idx].size()
+         pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
+         x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
+         y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
+         kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
+
+         return kps, n_pts
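One convention from `get_points` worth remembering while reading the rest of the data code: keypoints are rescaled to the square model input and padded out to `max_pts` columns with the sentinel `-2`, the same value `Geometry.normalize_kps` treats as "no keypoint". A toy check with assumed numbers:

```python
import torch

# Assumed: 2 keypoints in a 480x360 image, model input 240, max_pts = 4
kps = torch.tensor([[100., 400.],   # x coordinates
                    [90., 180.]])   # y coordinates
img_size, max_pts, n_pts = 240, 4, 2
x_crds = kps[0] * (img_size / 480)               # -> [50., 200.]
y_crds = kps[1] * (img_size / 360)               # -> [60., 120.]
pad_pts = torch.zeros((2, max_pts - n_pts)) - 2  # sentinel columns
out = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
print(out.shape, out[:, 2:])                     # torch.Size([2, 4]), all -2
```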
data/download.py ADDED
@@ -0,0 +1,91 @@
+ r""" Functions to download semantic correspondence datasets """
+
+ import tarfile
+ import os
+
+ import requests
+
+ from . import pfpascal
+ from . import pfwillow
+ from . import spair
+
+
+ def load_dataset(benchmark, datapath, thres, split='test'):
+     r""" Instantiate a correspondence dataset """
+     correspondence_benchmark = {
+         'spair': spair.SPairDataset,
+         'pfpascal': pfpascal.PFPascalDataset,
+         'pfwillow': pfwillow.PFWillowDataset
+     }
+
+     dataset = correspondence_benchmark.get(benchmark)
+     if dataset is None:
+         raise Exception('Invalid benchmark dataset %s.' % benchmark)
+
+     return dataset(benchmark, datapath, thres, split)
+
+
+ def download_from_google(token_id, filename):
+     r""" Download desired filename from Google drive """
+
+     print('Downloading %s ...' % os.path.basename(filename))
+
+     url = 'https://docs.google.com/uc?export=download'
+     destination = filename + '.tar.gz'
+     session = requests.Session()
+
+     response = session.get(url, params={'id': token_id}, stream=True)
+     token = get_confirm_token(response)
+
+     if token:
+         params = {'id': token_id, 'confirm': token}
+         response = session.get(url, params=params, stream=True)
+     save_response_content(response, destination)
+     file = tarfile.open(destination, 'r:gz')
+
+     print("Extracting %s ..." % destination)
+     file.extractall(filename)
+     file.close()
+
+     os.remove(destination)
+     os.rename(filename, filename + '_tmp')
+     os.rename(os.path.join(filename + '_tmp', os.path.basename(filename)), filename)
+     os.rmdir(filename + '_tmp')
+
+
+ def get_confirm_token(response):
+     r""" Retrieves confirm token """
+     for key, value in response.cookies.items():
+         if key.startswith('download_warning'):
+             return value
+
+     return None
+
+
+ def save_response_content(response, destination):
+     r""" Saves the response to the destination """
+     chunk_size = 32768
+
+     with open(destination, "wb") as file:
+         for chunk in response.iter_content(chunk_size):
+             if chunk:
+                 file.write(chunk)
+
+
+ def download_dataset(datapath, benchmark):
+     r""" Downloads semantic correspondence benchmark dataset from Google drive """
+     if not os.path.isdir(datapath):
+         os.mkdir(datapath)
+
+     file_data = {
+         # 'spair': ('1s73NVEFPro260H1tXxCh1ain7oApR8of', 'SPair-71k')  # old version
+         'spair': ('1KSvB0k2zXA06ojWNvFjBv0Ake426Y76k', 'SPair-71k'),
+         'pfpascal': ('1OOwpGzJnTsFXYh-YffMQ9XKM_Kl_zdzg', 'PF-PASCAL'),
+         'pfwillow': ('1tDP0y8RO5s45L-vqnortRaieiWENQco_', 'PF-WILLOW')
+     }
+
+     file_id, filename = file_data[benchmark]
+     abs_filepath = os.path.join(datapath, filename)
+
+     if not os.path.isdir(abs_filepath):
+         download_from_google(file_id, abs_filepath)
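`download_from_google` follows the usual Google Drive large-file flow: the first GET may answer with a `download_warning` cookie, whose value must be echoed back as a `confirm` parameter before the payload actually streams. Typical use of the module's two public entry points (paths are illustrative):

```python
from data import download

# Fetch and unpack PF-PASCAL into ./Datasets_CHM (skipped if already present),
# then instantiate the corresponding dataset object.
download.download_dataset('Datasets_CHM', 'pfpascal')
dataset = download.load_dataset('pfpascal', 'Datasets_CHM', thres='img', split='test')
print(len(dataset), 'image pairs')
```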
data/pfpascal.py ADDED
@@ -0,0 +1,108 @@
+ r""" PF-PASCAL dataset """
+
+ import os
+
+ import scipy.io as sio
+ import pandas as pd
+ import numpy as np
+ import torch
+
+ from .dataset import CorrespondenceDataset
+
+
+ class PFPascalDataset(CorrespondenceDataset):
+
+     def __init__(self, benchmark, datapath, thres, split):
+         r""" PF-PASCAL dataset constructor """
+         super(PFPascalDataset, self).__init__(benchmark, datapath, thres, split)
+
+         self.train_data = pd.read_csv(self.spt_path)
+         self.src_imnames = np.array(self.train_data.iloc[:, 0])
+         self.trg_imnames = np.array(self.train_data.iloc[:, 1])
+         self.cls = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
+                     'bus', 'car', 'cat', 'chair', 'cow',
+                     'diningtable', 'dog', 'horse', 'motorbike', 'person',
+                     'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
+         self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1
+
+         if split == 'trn':
+             self.flip = self.train_data.iloc[:, 3].values.astype('int')
+         self.src_kps = []
+         self.trg_kps = []
+         self.src_bbox = []
+         self.trg_bbox = []
+         for src_imname, trg_imname, cls in zip(self.src_imnames, self.trg_imnames, self.cls_ids):
+             src_anns = os.path.join(self.ann_path, self.cls[cls],
+                                     os.path.basename(src_imname))[:-4] + '.mat'
+             trg_anns = os.path.join(self.ann_path, self.cls[cls],
+                                     os.path.basename(trg_imname))[:-4] + '.mat'
+
+             src_kp = torch.tensor(read_mat(src_anns, 'kps')).float()
+             trg_kp = torch.tensor(read_mat(trg_anns, 'kps')).float()
+             src_box = torch.tensor(read_mat(src_anns, 'bbox')[0].astype(float))
+             trg_box = torch.tensor(read_mat(trg_anns, 'bbox')[0].astype(float))
+
+             src_kps = []
+             trg_kps = []
+             for src_kk, trg_kk in zip(src_kp, trg_kp):
+                 if len(torch.isnan(src_kk).nonzero()) != 0 or \
+                         len(torch.isnan(trg_kk).nonzero()) != 0:
+                     continue
+                 else:
+                     src_kps.append(src_kk)
+                     trg_kps.append(trg_kk)
+             self.src_kps.append(torch.stack(src_kps).t())
+             self.trg_kps.append(torch.stack(trg_kps).t())
+             self.src_bbox.append(src_box)
+             self.trg_bbox.append(trg_box)
+
+         self.src_imnames = list(map(lambda x: os.path.basename(x), self.src_imnames))
+         self.trg_imnames = list(map(lambda x: os.path.basename(x), self.trg_imnames))
+
+     def __getitem__(self, idx):
+         r""" Constructs and returns a batch for PF-PASCAL dataset """
+         batch = super(PFPascalDataset, self).__getitem__(idx)
+
+         # Object bounding-box (resized following self.img_size)
+         batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize'])
+         batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize'])
+         batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize'])
+
+         # Horizontal flipping of key-points during training
+         if self.split == 'trn' and self.flip[idx]:
+             self.horizontal_flip(batch)
+             batch['flip'] = 1
+         else:
+             batch['flip'] = 0
+
+         return batch
+
+     def get_bbox(self, bbox_list, idx, imsize):
+         r""" Returns object bounding-box """
+         bbox = bbox_list[idx].clone()
+         bbox[0::2] *= (self.img_size / imsize[0])
+         bbox[1::2] *= (self.img_size / imsize[1])
+         return bbox
+
+     def horizontal_flip(self, batch):
+         tmp = batch['src_bbox'][0].clone()
+         batch['src_bbox'][0] = batch['src_img'].size(2) - batch['src_bbox'][2]
+         batch['src_bbox'][2] = batch['src_img'].size(2) - tmp
+
+         tmp = batch['trg_bbox'][0].clone()
+         batch['trg_bbox'][0] = batch['trg_img'].size(2) - batch['trg_bbox'][2]
+         batch['trg_bbox'][2] = batch['trg_img'].size(2) - tmp
+
+         batch['src_kps'][0][:batch['n_pts']] = batch['src_img'].size(2) - batch['src_kps'][0][:batch['n_pts']]
+         batch['trg_kps'][0][:batch['n_pts']] = batch['trg_img'].size(2) - batch['trg_kps'][0][:batch['n_pts']]
+
+         batch['src_img'] = torch.flip(batch['src_img'], dims=(2,))
+         batch['trg_img'] = torch.flip(batch['trg_img'], dims=(2,))
+
+
+ def read_mat(path, obj_name):
+     r""" Reads specified objects from a Matlab data file (.mat) """
+     mat_contents = sio.loadmat(path)
+     mat_obj = mat_contents[obj_name]
+
+     return mat_obj
data/pfwillow.py ADDED
@@ -0,0 +1,56 @@
+ r""" PF-WILLOW dataset """
+
+ import os
+
+ import pandas as pd
+ import numpy as np
+ import torch
+
+ from .dataset import CorrespondenceDataset
+
+
+ class PFWillowDataset(CorrespondenceDataset):
+
+     def __init__(self, benchmark, datapath, thres, split):
+         r""" PF-WILLOW dataset constructor """
+         super(PFWillowDataset, self).__init__(benchmark, datapath, thres, split)
+
+         self.train_data = pd.read_csv(self.spt_path)
+         self.src_imnames = np.array(self.train_data.iloc[:, 0])
+         self.trg_imnames = np.array(self.train_data.iloc[:, 1])
+         self.src_kps = self.train_data.iloc[:, 2:22].values
+         self.trg_kps = self.train_data.iloc[:, 22:].values
+         self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
+                     'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
+                     'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']
+         self.cls_ids = list(map(lambda names: self.cls.index(names.split('/')[1]), self.src_imnames))
+         self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames))
+         self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames))
+
+     def __getitem__(self, idx):
+         r""" Constructs and returns a batch for the PF-WILLOW dataset """
+         batch = super(PFWillowDataset, self).__getitem__(idx)
+         batch['pckthres'] = self.get_pckthres(batch)
+
+         return batch
+
+     def get_pckthres(self, batch):
+         r""" Computes the PCK threshold """
+         if self.thres == 'bbox':
+             return max(batch['trg_kps'].max(1)[0] - batch['trg_kps'].min(1)[0]).clone()
+         elif self.thres == 'img':
+             return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
+         else:
+             raise Exception('Invalid pck evaluation level: %s' % self.thres)
+
+     def get_points(self, pts_list, idx, org_imsize):
+         r""" Returns keypoints of an image """
+         point_coords = pts_list[idx, :].reshape(2, 10)
+         point_coords = torch.tensor(point_coords.astype(np.float32))
+         xy, n_pts = point_coords.size()
+         pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
+         x_crds = point_coords[0] * (self.img_size / org_imsize[0])
+         y_crds = point_coords[1] * (self.img_size / org_imsize[1])
+         kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
+
+         return kps, n_pts
data/spair.py ADDED
@@ -0,0 +1,105 @@
+ r""" SPair-71k dataset """
+
+ import json
+ import glob
+ import os
+
+ import torch.nn.functional as F
+ import torch
+ from PIL import Image
+ import numpy as np
+
+ from .dataset import CorrespondenceDataset
+
+
+ class SPairDataset(CorrespondenceDataset):
+
+     def __init__(self, benchmark, datapath, thres, split):
+         r""" SPair-71k dataset constructor """
+         super(SPairDataset, self).__init__(benchmark, datapath, thres, split)
+
+         self.train_data = open(self.spt_path).read().split('\n')
+         self.train_data = self.train_data[:len(self.train_data) - 1]
+         self.src_imnames = list(map(lambda x: x.split('-')[1] + '.jpg', self.train_data))
+         self.trg_imnames = list(map(lambda x: x.split('-')[2].split(':')[0] + '.jpg', self.train_data))
+         self.seg_path = os.path.abspath(os.path.join(self.img_path, os.pardir, 'Segmentation'))
+         self.cls = os.listdir(self.img_path)
+         self.cls.sort()
+
+         anntn_files = []
+         for data_name in self.train_data:
+             anntn_files.append(glob.glob('%s/%s.json' % (self.ann_path, data_name))[0])
+         anntn_files = list(map(lambda x: json.load(open(x)), anntn_files))
+         self.src_kps = list(map(lambda x: torch.tensor(x['src_kps']).t().float(), anntn_files))
+         self.trg_kps = list(map(lambda x: torch.tensor(x['trg_kps']).t().float(), anntn_files))
+         self.src_bbox = list(map(lambda x: torch.tensor(x['src_bndbox']).float(), anntn_files))
+         self.trg_bbox = list(map(lambda x: torch.tensor(x['trg_bndbox']).float(), anntn_files))
+         self.cls_ids = list(map(lambda x: self.cls.index(x['category']), anntn_files))
+
+         self.vpvar = list(map(lambda x: torch.tensor(x['viewpoint_variation']), anntn_files))
+         self.scvar = list(map(lambda x: torch.tensor(x['scale_variation']), anntn_files))
+         self.trncn = list(map(lambda x: torch.tensor(x['truncation']), anntn_files))
+         self.occln = list(map(lambda x: torch.tensor(x['occlusion']), anntn_files))
+
+     def __getitem__(self, idx):
+         r""" Constructs and returns a batch for the SPair-71k dataset """
+         sample = super(SPairDataset, self).__getitem__(idx)
+
+         sample['src_mask'] = self.get_mask(sample, sample['src_imname'])
+         sample['trg_mask'] = self.get_mask(sample, sample['trg_imname'])
+
+         sample['src_bbox'] = self.get_bbox(self.src_bbox, idx, sample['src_imsize'])
+         sample['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, sample['trg_imsize'])
+         sample['pckthres'] = self.get_pckthres(sample, sample['trg_imsize'])
+
+         sample['vpvar'] = self.vpvar[idx]
+         sample['scvar'] = self.scvar[idx]
+         sample['trncn'] = self.trncn[idx]
+         sample['occln'] = self.occln[idx]
+
+         return sample
+
+     def get_mask(self, sample, imname):
+         mask_path = os.path.join(self.seg_path, sample['category'], imname.split('.')[0] + '.png')
+
+         tensor_mask = torch.tensor(np.array(Image.open(mask_path)))
+
+         class_dict = {'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
+                       'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9,
+                       'diningtable': 10, 'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14,
+                       'pottedplant': 15, 'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19}
+
+         class_id = class_dict[sample['category']] + 1
+         tensor_mask[tensor_mask != class_id] = 0
+         tensor_mask[tensor_mask == class_id] = 255
+
+         tensor_mask = F.interpolate(tensor_mask.unsqueeze(0).unsqueeze(0).float(),
+                                     size=(self.img_size, self.img_size),
+                                     mode='bilinear', align_corners=True).int().squeeze()
+
+         return tensor_mask
+
+     def get_image(self, img_names, idx):
+         r""" Returns a PIL image """
+         path = os.path.join(self.img_path, self.cls[self.cls_ids[idx]], img_names[idx])
+
+         return Image.open(path).convert('RGB')
+
+     def get_pckthres(self, sample, imsize):
+         r""" Computes the PCK threshold """
+         return super(SPairDataset, self).get_pckthres(sample, imsize)
+
+     def get_points(self, pts_list, idx, imsize):
+         r""" Returns keypoints of an image """
+         return super(SPairDataset, self).get_points(pts_list, idx, imsize)
+
+     def match_idx(self, kps, n_pts):
+         r""" Samples the nearest feature (receptive field) indices """
+         return super(SPairDataset, self).match_idx(kps, n_pts)
+
+     def get_bbox(self, bbox_list, idx, imsize):
+         r""" Returns object bounding-box """
+         bbox = bbox_list[idx].clone()
+         bbox[0::2] *= (self.img_size / imsize[0])
+         bbox[1::2] *= (self.img_size / imsize[1])
+         return bbox
model/__pycache__/chmlearner.cpython-38.pyc ADDED
Binary file (1.85 kB)
model/__pycache__/chmnet.cpython-38.pyc ADDED
Binary file (1.8 kB)
model/base/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (4.14 kB)
model/base/__pycache__/chm.cpython-38.pyc ADDED
Binary file (6.85 kB)
model/base/__pycache__/chm_kernel.cpython-38.pyc ADDED
Binary file (2.03 kB)
model/base/__pycache__/correlation.cpython-38.pyc ADDED
Binary file (2.09 kB)
model/base/__pycache__/geometry.cpython-38.pyc ADDED
Binary file (4.69 kB)
model/base/backbone.py ADDED
@@ -0,0 +1,136 @@
+ r""" ResNet-101 backbone network """
+
+ import torch.utils.model_zoo as model_zoo
+ import torch.nn as nn
+ import torch
+
+
+ __all__ = ['Backbone', 'resnet101']
+
+
+ model_urls = {
+     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+     'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+     'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+ }
+
+
+ def conv3x3(in_planes, out_planes, stride=1):
+     r""" 3x3 convolution with padding """
+     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                      padding=1, groups=2, bias=False)
+
+
+ def conv1x1(in_planes, out_planes, stride=1):
+     r""" 1x1 convolution """
+     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False)
+
+
+ class Bottleneck(nn.Module):
+     expansion = 4
+
+     def __init__(self, inplanes, planes, stride=1, downsample=None):
+         super(Bottleneck, self).__init__()
+         self.conv1 = conv1x1(inplanes, planes)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = conv3x3(planes, planes, stride)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = conv1x1(planes, planes * self.expansion)
+         self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+         self.relu = nn.ReLU(inplace=True)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x):
+         identity = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+         out = self.relu(out)
+
+         out = self.conv3(out)
+         out = self.bn3(out)
+
+         if self.downsample is not None:
+             identity = self.downsample(x)
+
+         out += identity
+         out = self.relu(out)
+
+         return out
+
+
+ class Backbone(nn.Module):
+     def __init__(self, block, layers, zero_init_residual=False):
+         super(Backbone, self).__init__()
+
+         self.inplanes = 128
+         self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2,
+                                bias=False)
+         self.bn1 = nn.BatchNorm2d(128)
+         self.relu = nn.ReLU(inplace=True)
+         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+         self.layer1 = self._make_layer(block, 128, layers[0])
+         self.layer2 = self._make_layer(block, 256, layers[1], stride=2)
+         self.layer3 = self._make_layer(block, 512, layers[2], stride=2)
+         self.layer4 = self._make_layer(block, 1024, layers[3], stride=2)
+         self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+         self.fc = nn.Linear(512 * block.expansion, 1000)
+
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d):
+                 nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+             elif isinstance(m, nn.BatchNorm2d):
+                 nn.init.constant_(m.weight, 1)
+                 nn.init.constant_(m.bias, 0)
+
+         # Zero-initialize the last BN in each residual branch,
+         # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+         if zero_init_residual:
+             for m in self.modules():
+                 if isinstance(m, Bottleneck):
+                     nn.init.constant_(m.bn3.weight, 0)
+
+     def _make_layer(self, block, planes, blocks, stride=1):
+         downsample = None
+         if stride != 1 or self.inplanes != planes * block.expansion:
+             downsample = nn.Sequential(
+                 conv1x1(self.inplanes, planes * block.expansion, stride),
+                 nn.BatchNorm2d(planes * block.expansion),
+             )
+
+         layers = []
+         layers.append(block(self.inplanes, planes, stride, downsample))
+         self.inplanes = planes * block.expansion
+         for _ in range(1, blocks):
+             layers.append(block(self.inplanes, planes))
+
+         return nn.Sequential(*layers)
+
+
+ def resnet101(pretrained=False, **kwargs):
+     """Constructs a ResNet-101 model.
+
+     Args:
+         pretrained (bool): If True, returns a model pre-trained on ImageNet
+     """
+     model = Backbone(Bottleneck, [3, 4, 23, 3], **kwargs)
+     if pretrained:
+         weights = model_zoo.load_url(model_urls['resnet101'])
+
+         for key in weights:
+             if key.split('.')[0] == 'fc':
+                 weights[key] = weights[key].clone()
+                 continue
+             weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)
+
+         model.load_state_dict(weights)
+     return model
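The backbone runs both images through one grouped-convolution ResNet-101: source and target are concatenated channel-wise (6 input channels), every convolution uses `groups=2` so the two streams never mix, and `resnet101(pretrained=True)` duplicates the single-stream ImageNet weights along the output-channel dimension so each group starts from the same filters. A quick shape check of the grouping idea (toy sizes):

```python
import torch
import torch.nn as nn

# groups=2 keeps the two 3-channel streams independent inside one conv
conv = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2, bias=False)
x = torch.cat([torch.randn(1, 3, 64, 64),    # "source" channels
               torch.randn(1, 3, 64, 64)],   # "target" channels
              dim=1)
print(conv(x).shape)   # torch.Size([1, 128, 32, 32]); the first 64 output
                       # channels see only the source, the last 64 only the target
```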
model/base/chm.py ADDED
@@ -0,0 +1,190 @@
+ r""" 4D and 6D convolutional Hough matching layers """
+
+ from torch.nn.modules.conv import _ConvNd
+ import torch.nn.functional as F
+ import torch.nn as nn
+ import torch
+
+ from common.logger import Logger
+ from . import chm_kernel
+
+
+ def fast4d(corr, kernel, bias=None):
+     r""" Optimized implementation of 4D convolution """
+     bsz, ch, srch, srcw, trgh, trgw = corr.size()
+     out_channels, _, kernel_size, kernel_size, kernel_size, kernel_size = kernel.size()
+     psz = kernel_size // 2
+
+     out_corr = torch.zeros((bsz, out_channels, srch, srcw, trgh, trgw))
+     corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)
+
+     for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
+         inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
+         inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()
+
+         add_sid = max(psz - pidx, 0)
+         add_fid = min(srch, srch + psz - pidx)
+         slc_sid = max(pidx - psz, 0)
+         slc_fid = min(srch, srch - psz + pidx)
+
+         out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]
+
+     if bias is not None:
+         out_corr += bias.view(1, out_channels, 1, 1, 1, 1)
+
+     return out_corr
+
+
+ def fast6d(corr, kernel, bias, diagonal_idx):
+     r""" Optimized implementation of 6D convolutional Hough matching
+     NOTE: this function only supports a kernel size of (3, 3, 5, 5, 5, 5).
+     """
+     bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
+     _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()
+     corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
+     kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
+     corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
+     corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
+         contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)
+
+     ndiag = s6d + (ks6d // 2) * 2
+     first_sum = []
+     for didx in diagonal_idx:
+         first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
+     first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)
+
+     corr = []
+     for didx in diagonal_idx:
+         corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))
+     sidx = ks6d // 2
+     eidx = ndiag - sidx
+     corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
+     corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)
+
+     reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
+     corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
+         view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
+     return corr
+
+
+ def init_param_idx4d(param_dict):
+     param_idx = []
+     for key in param_dict:
+         curr_offset = int(key.split('_')[-1])
+         param_idx.append(torch.tensor(param_dict[key]))
+     return param_idx
+
+
+ class CHM4d(_ConvNd):
+     r""" 4D convolutional Hough matching layer
+     NOTE: this layer only supports in_channels=1 and out_channels=1.
+     """
+     def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
+         super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
+                                     (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
+                                     1, bias, padding_mode='zeros')
+
+         # Zero kernel initialization
+         self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
+         self.nkernels = in_channels * out_channels
+
+         # Initialize kernel indices
+         param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
+         param_shared = param_dict4d is not None
+
+         if param_shared:
+             # Initialize the shared parameters (multiplied by the number of times being shared)
+             self.param_idx = init_param_idx4d(param_dict4d)
+             weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3
+             for weight, param_idx in zip(weights.sort()[0], self.param_idx):
+                 weight *= len(param_idx)
+             self.weight = nn.Parameter(weights)
+         else:  # full kernel initialization
+             self.param_idx = None
+             self.weight = nn.Parameter(torch.abs(self.weight))
+         if bias: self.bias = nn.Parameter(torch.tensor(0.0))
+         Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))
+
+     def forward(self, x):
+         kernel = self.init_kernel()
+         x = fast4d(x, kernel, self.bias)
+         return x
+
+     def init_kernel(self):
+         # Initialize CHM kernel (divided by the number of times being shared)
+         ksz = self.kernel_size[-1]
+         if self.param_idx is None:
+             kernel = self.weight
+         else:
+             kernel = torch.zeros_like(self.zero_kernel4d)
+             for idx, pdx in enumerate(self.param_idx):
+                 kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
+                 for jdx, kernel_single in enumerate(kernel):
+                     weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx)
+                     kernel_single.view(-1)[pdx] += weight
+             kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
+         return kernel
+
+
+ class CHM6d(_ConvNd):
+     r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)
+     NOTE: this layer only supports in_channels=1 and out_channels=1.
+     """
+     def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
+         kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
+         super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
+                                     (0,) * 6, (1,) * 6, False, (0,) * 6,
+                                     1, bias=True, padding_mode='zeros')
+
+         # Zero kernel initialization
+         self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
+         self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
+         self.nkernels = in_channels * out_channels
+
+         # Initialize kernel indices
+         # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space)
+         self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
+         param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
+         param_shared = param_dict4d is not None
+
+         if param_shared:  # psi & iso kernel initialization
+             if ktype == 'psi':
+                 self.param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
+             elif ktype == 'iso':
+                 self.param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
+             self.param_dict6d = [torch.tensor(i) for i in self.param_dict6d]
+
+             # Initialize the shared parameters (multiplied by the number of times being shared)
+             self.param_idx = init_param_idx4d(param_dict4d)
+             self.param = []
+             for param_dict6d in self.param_dict6d:
+                 weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
+                 for weight, param_idx in zip(weights, self.param_idx):
+                     weight *= (len(param_idx) * len(param_dict6d))
+                 self.param.append(nn.Parameter(weights))
+             self.param = nn.ParameterList(self.param)
+         else:  # full kernel initialization
+             self.param_idx = None
+             self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
+         Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
+         self.weight = None
+
+     def forward(self, corr):
+         kernel = self.init_kernel()
+         corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
+         return corr
+
+     def init_kernel(self):
+         # Initialize CHM kernel (divided by the number of times being shared)
+         if self.param_idx is None:
+             return self.param
+
+         kernel6d = torch.zeros_like(self.zero_kernel6d)
+         for idx, (param, param_dict6d) in enumerate(zip(self.param, self.param_dict6d)):
+             ksz4d = self.kernel_size[-1]
+             kernel4d = torch.zeros_like(self.zero_kernel4d)
+             for jdx, pdx in enumerate(self.param_idx):
+                 kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
+             kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
+         kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)
+
+         return kernel6d
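`fast4d` realizes a 4D convolution as `ksz` separate 3D convolutions: the source-height axis is folded into the batch, each slice of the 4D kernel along that axis is applied with `F.conv3d`, and the partial results are added back at the appropriate vertical offset. A shape-level sanity sketch (toy sizes, random kernel; run from the repo root so `model` and `common` are importable):

```python
import torch
from model.base.chm import fast4d

corr = torch.randn(1, 1, 6, 6, 6, 6)     # (bsz, ch, srch, srcw, trgh, trgw)
kernel = torch.randn(1, 1, 3, 3, 3, 3)   # 4D kernel with ksz = 3
out = fast4d(corr, kernel)
print(out.shape)                         # torch.Size([1, 1, 6, 6, 6, 6]):
                                         # padding ksz//2 preserves all sides
```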
model/base/chm_kernel.py ADDED
@@ -0,0 +1,66 @@
+ r""" CHM 4D kernel (psi, iso, and full) generator """
+
+ import torch
+
+ from .geometry import Geometry
+
+
+ class KernelGenerator:
+     def __init__(self, ksz, ktype):
+         self.ksz = ksz
+         self.idx4d = Geometry.init_idx4d(ksz)
+         self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
+         self.center = (ksz // 2, ksz // 2)
+         self.ktype = ktype
+
+     def quadrant(self, crd):
+         if crd[0] < self.center[0]:
+             horz_quad = -1
+         elif crd[0] > self.center[0]:
+             horz_quad = 1
+         else:
+             horz_quad = 0
+
+         if crd[1] < self.center[1]:
+             vert_quad = -1
+         elif crd[1] > self.center[1]:
+             vert_quad = 1
+         else:
+             vert_quad = 0
+
+         return horz_quad, vert_quad
+
+     def generate(self):
+         return None if self.ktype == 'full' else self.generate_chm_kernel()
+
+     def generate_chm_kernel(self):
+         param_dict = {}
+         for idx in self.idx4d:
+             src_i, src_j, trg_i, trg_j = idx
+             d_tail = Geometry.get_distance((src_i, src_j), self.center)
+             d_head = Geometry.get_distance((trg_i, trg_j), self.center)
+             d_off = Geometry.get_distance((src_i, src_j), (trg_i, trg_j))
+             horz_quad, vert_quad = self.quadrant((src_j, src_i))
+
+             src_crd = (src_i, src_j)
+             trg_crd = (trg_i, trg_j)
+
+             key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
+             coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)
+
+             if param_dict.get(key) is None: param_dict[key] = []
+             param_dict[key].append(coord1d)
+
+         return param_dict
+
+     def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
+         if self.ktype == 'iso':
+             return '%d' % d_off
+         elif self.ktype == 'psi':
+             d_max = max(d_head, d_tail)
+             d_min = min(d_head, d_tail)
+             return '%d_%d_%d' % (d_max, d_min, d_off)
+         else:
+             raise Exception('not implemented.')
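`KernelGenerator` groups the cells of a 4D kernel that should share one learnable weight: `iso` keys cells only by the squared source-target offset, while `psi` additionally keys by the (max, min) squared distances of the two endpoints to the kernel center; `full` returns `None`, meaning every cell is free. The number of distinct keys is the per-kernel parameter count, which can be inspected directly:

```python
from model.base.chm_kernel import KernelGenerator

# Count the shared weights of a 5x5x5x5 kernel under each sharing scheme
for ktype in ('psi', 'iso'):
    param_dict = KernelGenerator(5, ktype).generate()
    print(ktype, len(param_dict), 'shared parameters')
```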
model/base/correlation.py ADDED
@@ -0,0 +1,68 @@
+ r""" Provides functions that create/manipulate correlation matrices """
+
+ import math
+
+ from torch.nn.functional import interpolate as resize
+ import torch
+
+ from .geometry import Geometry
+
+
+ class Correlation:
+
+     @classmethod
+     def mutual_nn_filter(cls, correlation_matrix, eps=1e-30):
+         r""" Mutual nearest neighbor filtering (Rocco et al. NeurIPS'18) """
+         corr_src_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0]
+         corr_trg_max = torch.max(correlation_matrix, dim=1, keepdim=True)[0]
+         corr_src_max[corr_src_max == 0] += eps
+         corr_trg_max[corr_trg_max == 0] += eps
+
+         corr_src = correlation_matrix / corr_src_max
+         corr_trg = correlation_matrix / corr_trg_max
+
+         return correlation_matrix * (corr_src * corr_trg)
+
+     @classmethod
+     def build_correlation6d(cls, src_feat, trg_feat, scales, conv2ds):
+         r""" Builds a 6-dimensional correlation tensor """
+
+         bsz, _, side, side = src_feat.size()
+
+         # Construct feature pairs with multiple scales
+         _src_feats = []
+         _trg_feats = []
+         for scale, conv in zip(scales, conv2ds):
+             s = (round(side * math.sqrt(scale)),) * 2
+             _src_feat = conv(resize(src_feat, s, mode='bilinear', align_corners=True))
+             _trg_feat = conv(resize(trg_feat, s, mode='bilinear', align_corners=True))
+             _src_feats.append(_src_feat)
+             _trg_feats.append(_trg_feat)
+
+         # Build multiple 4-dimensional correlation tensors
+         corr6d = []
+         for src_feat in _src_feats:
+             ch = src_feat.size(1)
+
+             src_side = src_feat.size(-1)
+             src_feat = src_feat.view(bsz, ch, -1).transpose(1, 2)
+             src_norm = src_feat.norm(p=2, dim=2, keepdim=True)
+
+             for trg_feat in _trg_feats:
+                 trg_side = trg_feat.size(-1)
+                 trg_feat = trg_feat.view(bsz, ch, -1)
+                 trg_norm = trg_feat.norm(p=2, dim=1, keepdim=True)
+
+                 correlation = torch.bmm(src_feat, trg_feat) / torch.bmm(src_norm, trg_norm)
+                 correlation = correlation.view(bsz, src_side, src_side, trg_side, trg_side).contiguous()
+                 corr6d.append(correlation)
+
+         # Resize the spatial sizes of the 4D tensors to the same size
+         for idx, correlation in enumerate(corr6d):
+             corr6d[idx] = Geometry.interpolate4d(correlation, [side, side])
+
+         # Build the 6-dimensional correlation tensor
+         corr6d = torch.stack(corr6d).view(len(scales), len(scales),
+                                           bsz, side, side, side, side).permute(2, 0, 1, 3, 4, 5, 6)
+         return corr6d.clamp(min=0)
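`mutual_nn_filter` damps correlations that are not reciprocally maximal: each score is rescaled by its ratio to its row maximum and to its column maximum, so only mutual nearest neighbours keep (close to) full weight. A toy check (run from the repo root):

```python
import torch
from model.base.correlation import Correlation

c = torch.tensor([[[1.0, 0.2],
                   [0.9, 0.3]]])   # (bsz, src, trg) correlation matrix
print(Correlation.mutual_nn_filter(c))
# (0,0) is a mutual max and keeps 1.00; (1,0) is its row's max but not its
# column's max, so it drops from 0.90 to 0.81; the off-maxima shrink hard.
```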
model/base/geometry.py ADDED
@@ -0,0 +1,133 @@
+ r""" Provides functions that manipulate boxes and points """
+
+ import math
+
+ import torch.nn.functional as F
+ import torch
+
+
+ class Geometry(object):
+
+     @classmethod
+     def initialize(cls, img_size):
+         cls.img_size = img_size
+
+         cls.spatial_side = int(img_size / 8)
+         norm_grid1d = torch.linspace(-1, 1, cls.spatial_side)
+
+         cls.norm_grid_x = norm_grid1d.view(1, -1).repeat(cls.spatial_side, 1).view(1, 1, -1)
+         cls.norm_grid_y = norm_grid1d.view(-1, 1).repeat(1, cls.spatial_side).view(1, 1, -1)
+         cls.grid = torch.stack(list(reversed(torch.meshgrid(norm_grid1d, norm_grid1d)))).permute(1, 2, 0)
+
+         cls.feat_idx = torch.arange(0, cls.spatial_side).float()
+
+     @classmethod
+     def normalize_kps(cls, kps):
+         kps = kps.clone().detach()
+         kps[kps != -2] -= (cls.img_size // 2)
+         kps[kps != -2] /= (cls.img_size // 2)
+         return kps
+
+     @classmethod
+     def unnormalize_kps(cls, kps):
+         kps = kps.clone().detach()
+         kps[kps != -2] *= (cls.img_size // 2)
+         kps[kps != -2] += (cls.img_size // 2)
+         return kps
+
+     @classmethod
+     def attentive_indexing(cls, kps, thres=0.1):
+         r""" kps: normalized keypoints x, y (N, 2)
+         returns an attentive index map (N, spatial_side, spatial_side)
+         """
+         nkps = kps.size(0)
+         kps = kps.view(nkps, 1, 1, 2)
+
+         eps = 1e-5
+         attmap = (cls.grid.unsqueeze(0).repeat(nkps, 1, 1, 1) - kps).pow(2).sum(dim=3)
+         attmap = (attmap + eps).pow(0.5)
+         attmap = (thres - attmap).clamp(min=0).view(nkps, -1)
+         attmap = attmap / attmap.sum(dim=1, keepdim=True)
+         attmap = attmap.view(nkps, cls.spatial_side, cls.spatial_side)
+
+         return attmap
+
+     @classmethod
+     def apply_gaussian_kernel(cls, corr, sigma=17):
+         bsz, side, side = corr.size()
+
+         center = corr.max(dim=2)[1]
+         center_y = center // cls.spatial_side
+         center_x = center % cls.spatial_side
+
+         y = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
+         x = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)
+
+         y = y.unsqueeze(3).repeat(1, 1, 1, cls.spatial_side)
+         x = x.unsqueeze(2).repeat(1, 1, cls.spatial_side, 1)
+
+         gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
+         filtered_corr = gauss_kernel * corr.view(bsz, -1, cls.spatial_side, cls.spatial_side)
+         filtered_corr = filtered_corr.view(bsz, side, side)
+
+         return filtered_corr
+
+     @classmethod
+     def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
+         r""" Transfers keypoints by weighted average """
+
+         if not normalized:
+             src_kps = Geometry.normalize_kps(src_kps)
+         confidence_ts = cls.apply_gaussian_kernel(confidence_ts)
+
+         pdf = F.softmax(confidence_ts, dim=2)
+         prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
+         prd_y = (pdf * cls.norm_grid_y).sum(dim=2)
+
+         prd_kps = []
+         for idx, (x, y, src_kp, np) in enumerate(zip(prd_x, prd_y, src_kps, n_pts)):
+             max_pts = src_kp.size()[1]
+             prd_xy = torch.stack([x, y]).t()
+
+             src_kp = src_kp[:, :np].t()
+             attmap = cls.attentive_indexing(src_kp).view(np, -1)
+             prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()
+             pads = (torch.zeros((2, max_pts - np)) - 2)
+             prd_kp = torch.cat([prd_kp, pads], dim=1)
+             prd_kps.append(prd_kp)
+
+         return torch.stack(prd_kps)
+
+     @staticmethod
+     def get_coord1d(coord4d, ksz):
+         i, j, k, l = coord4d
+         coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
+         return coord1d
+
+     @staticmethod
+     def get_distance(coord1, coord2):
+         delta_y = int(math.pow(coord1[0] - coord2[0], 2))
+         delta_x = int(math.pow(coord1[1] - coord2[1], 2))
+         dist = delta_y + delta_x
+         return dist
+
+     @staticmethod
+     def interpolate4d(tensor4d, size):
+         bsz, h1, w1, h2, w2 = tensor4d.size()
+         tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
+         tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
+         tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
+         tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
+         tensor4d = tensor4d.view(bsz, size[0], size[0], size[0], size[0])
+
+         return tensor4d
+
+     @staticmethod
+     def init_idx4d(ksz):
+         i0 = torch.arange(0, ksz).repeat(ksz ** 3)
+         i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
+         i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
+         i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
+         idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()
+
+         return idx4d
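`transfer_kps` is a soft-argmax: softmax over the target side turns each source location's correlation row into a distribution, and the predicted coordinate is the expectation of the normalized grid under it (then `attentive_indexing` blends the predictions of grid cells near each source keypoint). The core idea in one dimension:

```python
import torch
import torch.nn.functional as F

grid = torch.linspace(-1, 1, 5)                # normalized 1D coordinates
scores = torch.tensor([0., 0., 4., 0., 0.])    # correlation peak at the center
x_hat = (F.softmax(scores, dim=0) * grid).sum()
print(x_hat)                                   # ~0.0: expectation sits at the peak
```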
model/chmlearner.py ADDED
@@ -0,0 +1,52 @@
+ r""" Convolutional Hough matching layers """
+
+ import torch.nn as nn
+ import torch
+
+ from .base.correlation import Correlation
+ from .base.geometry import Geometry
+ from .base.chm import CHM4d, CHM6d
+
+
+ class CHMLearner(nn.Module):
+
+     def __init__(self, ktype, feat_dim):
+         super(CHMLearner, self).__init__()
+
+         # Scale-wise feature transformation
+         self.scales = [0.5, 1, 2]
+         self.conv2ds = nn.ModuleList([nn.Conv2d(feat_dim, feat_dim // 4, kernel_size=3, padding=1, bias=False) for _ in self.scales])
+
+         # CHM layers
+         ksz_translation = 5
+         ksz_scale = 3
+         self.chm6d = CHM6d(1, 1, ksz_scale, ksz_translation, ktype)
+         self.chm4d = CHM4d(1, 1, ksz_translation, ktype, bias=True)
+
+         # Activations
+         self.relu = nn.ReLU(inplace=True)
+         self.sigmoid = nn.Sigmoid()
+         self.softplus = nn.Softplus()
+
+     def forward(self, src_feat, trg_feat):
+         corr = Correlation.build_correlation6d(src_feat, trg_feat, self.scales, self.conv2ds).unsqueeze(1)
+         bsz, ch, s, s, h, w, h, w = corr.size()
+
+         # CHM layer (6D)
+         corr = self.chm6d(corr)
+         corr = self.sigmoid(corr)
+
+         # Scale-space maxpool
+         corr = corr.view(bsz, -1, h, w, h, w).max(dim=1)[0]
+         corr = Geometry.interpolate4d(corr, [h * 2, w * 2]).unsqueeze(1)
+
+         # CHM layer (4D)
+         corr = self.chm4d(corr).squeeze(1)
+
+         # To ensure non-negative vote scores & soft cyclic constraints
+         corr = self.softplus(corr)
+         corr = Correlation.mutual_nn_filter(corr.view(bsz, corr.size(-1) ** 2, corr.size(-1) ** 2).contiguous())
+
+         return corr
model/chmnet.py ADDED
@@ -0,0 +1,42 @@
+ r""" Convolutional Hough Matching Networks """
+
+ import torch.nn as nn
+ import torch
+
+ from . import chmlearner as chmlearner
+ from .base import backbone
+
+
+ class CHMNet(nn.Module):
+     def __init__(self, ktype):
+         super(CHMNet, self).__init__()
+
+         self.backbone = backbone.resnet101(pretrained=True)
+         self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)
+
+     def forward(self, src_img, trg_img):
+         src_feat, trg_feat = self.extract_features(src_img, trg_img)
+         correlation = self.learner(src_feat, trg_feat)
+         return correlation
+
+     def extract_features(self, src_img, trg_img):
+         feat = self.backbone.conv1.forward(torch.cat([src_img, trg_img], dim=1))
+         feat = self.backbone.bn1.forward(feat)
+         feat = self.backbone.relu.forward(feat)
+         feat = self.backbone.maxpool.forward(feat)
+
+         for idx in range(1, 5):
+             feat = self.backbone.__getattr__('layer%d' % idx)(feat)
+
+             if idx == 3:
+                 src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
+                 trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
+                 return src_feat, trg_feat
+
+     @classmethod
+     def training_objective(cls, prd_kps, trg_kps, npts):
+         l2dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
+         loss = []
+         for dist, npt in zip(l2dist, npts):
+             loss.append(dist[:npt].mean())
+         return torch.stack(loss).mean()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ gradio==2.4.5
+ matplotlib==3.4.3
+ numpy==1.21.2
+ pandas==1.3.4
+ Pillow==8.4.0
+ requests==2.26.0
+ scipy==1.7.1
+ tensorboardX==2.4.1
+ torch==1.10.0
+ torchvision==0.11.1