qq456cvb committed
Commit 6b799b1 · 1 Parent(s): a48b54f

upload files

Files changed (3)
  1. app.py +0 -3
  2. finetune.py +282 -0
  3. requirements.txt +5 -1
app.py CHANGED
@@ -4,7 +4,6 @@ import numpy as np
  import os
  import requests
  import spaces
- import timm
  import torch
  import torchvision.transforms as T
  import types
@@ -13,8 +12,6 @@ import torch.nn.functional as F
 
  from PIL import Image
  from tqdm import tqdm
- from sklearn.decomposition import PCA
- from torch_kmeans import KMeans, CosineSimilarity
 
  cmap = plt.get_cmap("tab20")
  imagenet_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
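
Note on the app.py hunks: they only drop imports (timm, sklearn's PCA, torch_kmeans) that the demo presumably no longer needs; the preprocessing built around imagenet_transform is unchanged. As a hedged illustration of how that transform is typically applied before inference (the image path below is a placeholder, not part of the commit):

import torch
import torchvision.transforms as T
from PIL import Image

imagenet_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

img = Image.open("example.jpg").convert("RGB")   # placeholder path, any RGB image works
x = T.ToTensor()(img)                            # 3 x H x W, float in [0, 1]
x = imagenet_transform(x)                        # per-channel ImageNet normalization
batch = x.unsqueeze(0)                           # 1 x 3 x H x W, ready for the model
print(batch.shape)
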
finetune.py ADDED
@@ -0,0 +1,282 @@
+ import json
+ import math
+ import pickle
+ import sys
+ import time
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Any, Dict, Mapping
+
+ import cv2
+ import matplotlib.cm as cm
+ import numpy as np
+ import pytorch_lightning as pl
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import tqdm
+ from PIL import Image
+ from pytorch_lightning.loggers import TensorBoardLogger
+ from sklearn.decomposition import PCA
+ from torch.nn.parameter import Parameter
+ from torch.utils.data import ConcatDataset, DataLoader, Subset
+ from torchvision.transforms import functional
+
+
+ class _LoRA_qkv(nn.Module):
+     """
+     In Dinov2 it is implemented as
+     self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+     B, N, C = x.shape
+     qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+     q, k, v = qkv.unbind(0)
+     """
+
+     def __init__(
+         self,
+         qkv: nn.Module,
+         linear_a_q: nn.Module,
+         linear_b_q: nn.Module,
+         linear_a_v: nn.Module,
+         linear_b_v: nn.Module,
+     ):
+         super().__init__()
+         self.qkv = qkv
+         self.linear_a_q = linear_a_q
+         self.linear_b_q = linear_b_q
+         self.linear_a_v = linear_a_v
+         self.linear_b_v = linear_b_v
+         self.dim = qkv.in_features
+         self.w_identity = torch.eye(qkv.in_features)
+
+     def forward(self, x):
+         qkv = self.qkv(x)  # B,N,3*org_C
+         new_q = self.linear_b_q(self.linear_a_q(x))
+         new_v = self.linear_b_v(self.linear_a_v(x))
+
+         qkv[:, :, : self.dim] += new_q
+         qkv[:, :, -self.dim:] += new_v
+         return qkv
+
+
+ def sigmoid(tensor, temp=1.0):
+     """Temperature-controlled sigmoid.
+
+     Takes a torch tensor and passes it through a sigmoid whose sharpness is controlled by temp.
+     """
+     exponent = -tensor / temp
+     # clamp the input tensor for stability
+     exponent = torch.clamp(exponent, min=-50, max=50)
+     y = 1.0 / (1.0 + torch.exp(exponent))
+     return y
+
+
+ def interpolate_features(descriptors, pts, h, w, normalize=True, patch_size=14, stride=14):
+     last_coord_h = ((h - patch_size) // stride) * stride + (patch_size / 2)
+     last_coord_w = ((w - patch_size) // stride) * stride + (patch_size / 2)
+     ah = 2 / (last_coord_h - (patch_size / 2))
+     aw = 2 / (last_coord_w - (patch_size / 2))
+     bh = 1 - last_coord_h * 2 / (last_coord_h - (patch_size / 2))
+     bw = 1 - last_coord_w * 2 / (last_coord_w - (patch_size / 2))
+
+     a = torch.tensor([[aw, ah]]).to(pts).float()
+     b = torch.tensor([[bw, bh]]).to(pts).float()
+     keypoints = a * pts + b
+
+     # Expand dimensions for grid sampling
+     keypoints = keypoints.unsqueeze(-3)  # Shape becomes [batch_size, 1, num_keypoints, 2]
+
+     # Interpolate using bilinear sampling
+     interpolated_features = F.grid_sample(descriptors, keypoints, align_corners=True, padding_mode='border')
+
+     # interpolated_features will have shape [batch_size, channels, 1, num_keypoints]
+     interpolated_features = interpolated_features.squeeze(-2)
+
+     return F.normalize(interpolated_features, dim=1) if normalize else interpolated_features
+
+
+ class FinetuneDINO(pl.LightningModule):
+     def __init__(self, r, backbone_size, reg=False, datasets=None):
+         super().__init__()
+         assert r > 0
+         self.backbone_size = backbone_size
+         self.backbone_archs = {
+             "small": "vits14",
+             "base": "vitb14",
+             "large": "vitl14",
+             "giant": "vitg14",
+         }
+         self.embedding_dims = {
+             "small": 384,
+             "base": 768,
+             "large": 1024,
+             "giant": 1536,
+         }
+         self.backbone_arch = self.backbone_archs[self.backbone_size]
+         if reg:
+             self.backbone_arch = f"{self.backbone_arch}_reg"
+         self.embedding_dim = self.embedding_dims[self.backbone_size]
+
+         self.backbone_name = f"dinov2_{self.backbone_arch}"
+         dinov2 = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=self.backbone_name)
+         self.datasets = datasets
+
+         self.lora_layer = list(range(len(dinov2.blocks)))  # Only apply lora to the image encoder by default
+         # create for storage, then we can init them or load weights
+         self.w_As = []  # These are linear layers
+         self.w_Bs = []
+         # freeze first
+         for param in dinov2.parameters():
+             param.requires_grad = False
+
+         # finetune the last 4 blocks
+         for t_layer_i, blk in enumerate(dinov2.blocks[-4:]):
+             # If we only want a few LoRA layers instead of all
+             if t_layer_i not in self.lora_layer:
+                 continue
+             w_qkv_linear = blk.attn.qkv
+             self.dim = w_qkv_linear.in_features
+             w_a_linear_q = nn.Linear(self.dim, r, bias=False)
+             w_b_linear_q = nn.Linear(r, self.dim, bias=False)
+             w_a_linear_v = nn.Linear(self.dim, r, bias=False)
+             w_b_linear_v = nn.Linear(r, self.dim, bias=False)
+             self.w_As.append(w_a_linear_q)
+             self.w_Bs.append(w_b_linear_q)
+             self.w_As.append(w_a_linear_v)
+             self.w_Bs.append(w_b_linear_v)
+             blk.attn.qkv = _LoRA_qkv(
+                 w_qkv_linear,
+                 w_a_linear_q,
+                 w_b_linear_q,
+                 w_a_linear_v,
+                 w_b_linear_v,
+             )
+         self.reset_parameters()
+
+         self.dinov2 = dinov2
+         self.downsample_factor = 8
+
+         self.refine_conv = nn.Conv2d(self.embedding_dim, self.embedding_dim, kernel_size=3, stride=1, padding=1)
+
+         self.thresh3d_pos = 5e-3
+         self.thres3d_neg = 0.1
+
+         self.patch_size = 14
+         self.target_res = 640
+
+         self.input_transform = T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
+
+     def reset_parameters(self) -> None:
+         for w_A in self.w_As:
+             nn.init.kaiming_uniform_(w_A.weight, a=math.sqrt(5))
+         for w_B in self.w_Bs:
+             nn.init.zeros_(w_B.weight)
+
+     def on_save_checkpoint(self, checkpoint: Dict[str, Any]):
+         num_layer = len(self.w_As)  # actually, it is half
+         a_tensors = {f"w_a_{i:03d}": self.w_As[i].weight for i in range(num_layer)}
+         b_tensors = {f"w_b_{i:03d}": self.w_Bs[i].weight for i in range(num_layer)}
+
+         checkpoint['state_dict'] = {
+             'refine_conv': self.refine_conv.state_dict(),
+         }
+         checkpoint.update(a_tensors)
+         checkpoint.update(b_tensors)
+
+     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
+         # no-op: LoRA and refine_conv weights are restored in on_load_checkpoint instead
+         pass
+
+     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
+         # print(checkpoint.keys())
+         self.refine_conv.load_state_dict(checkpoint['state_dict']['refine_conv'])
+
+         for i, w_A_linear in enumerate(self.w_As):
+             saved_key = f"w_a_{i:03d}"
+             saved_tensor = checkpoint[saved_key]
+             w_A_linear.weight = Parameter(saved_tensor)
+
+         for i, w_B_linear in enumerate(self.w_Bs):
+             saved_key = f"w_b_{i:03d}"
+             saved_tensor = checkpoint[saved_key]
+             w_B_linear.weight = Parameter(saved_tensor)
+         self.loaded = True
+
+     def get_nearest(self, query, database):
+         dist = torch.cdist(query, database)
+         min_dist, min_idx = torch.min(dist, -1)
+         return min_dist, min_idx
+
+     def get_feature(self, rgbs, pts, normalize=True):
+         tgt_size = (int(rgbs.shape[-2] * self.target_res / rgbs.shape[-1]), self.target_res)
+         if rgbs.shape[-2] > rgbs.shape[-1]:
+             tgt_size = (self.target_res, int(rgbs.shape[-1] * self.target_res / rgbs.shape[-2]))
+
+         patch_h, patch_w = tgt_size[0] // self.downsample_factor, tgt_size[1] // self.downsample_factor
+         rgb_resized = functional.resize(rgbs, (patch_h * self.patch_size, patch_w * self.patch_size))
+
+         resize_factor = [(patch_w * self.patch_size) / rgbs.shape[-1], (patch_h * self.patch_size) / rgbs.shape[-2]]
+
+         pts = pts * torch.tensor(resize_factor).to(pts.device)
+
+         result = self.dinov2.forward_features(self.input_transform(rgb_resized))
+
+         feature = result['x_norm_patchtokens'].reshape(rgb_resized.shape[0], patch_h, patch_w, -1).permute(0, 3, 1, 2)
+         feature = self.refine_conv(feature)
+
+         feature = interpolate_features(feature, pts, h=patch_h * 14, w=patch_w * 14, normalize=False).permute(0, 2, 1)
+         if normalize:
+             feature = F.normalize(feature, p=2, dim=-1)
+         return feature
+
+     def get_feature_wo_kp(self, rgbs, normalize=True):
+         tgt_size = (int(rgbs.shape[-2] * self.target_res / rgbs.shape[-1]), self.target_res)
+         if rgbs.shape[-2] > rgbs.shape[-1]:
+             tgt_size = (self.target_res, int(rgbs.shape[-1] * self.target_res / rgbs.shape[-2]))
+
+         patch_h, patch_w = tgt_size[0] // self.downsample_factor, tgt_size[1] // self.downsample_factor
+         rgb_resized = functional.resize(rgbs, (patch_h * self.patch_size, patch_w * self.patch_size))
+
+         result = self.dinov2.forward_features(self.input_transform(rgb_resized))
+         feature = result['x_norm_patchtokens'].reshape(rgbs.shape[0], patch_h, patch_w, -1).permute(0, 3, 1, 2)
+         feature = self.refine_conv(feature)
+         feature = functional.resize(feature, (rgbs.shape[-2], rgbs.shape[-1])).permute(0, 2, 3, 1)
+         if normalize:
+             feature = F.normalize(feature, p=2, dim=-1)
+         return feature
+
+     def training_step(self, batch, batch_idx):
+         # print(batch['obj_name_1'])
+         rgb_1, pts2d_1, pts3d_1 = batch['rgb_1'], batch['pts2d_1'], batch['pts3d_1']
+         rgb_2, pts2d_2, pts3d_2 = batch['rgb_2'], batch['pts2d_2'], batch['pts3d_2']
+
+         desc_1 = self.get_feature(rgb_1, pts2d_1, normalize=True)
+         desc_2 = self.get_feature(rgb_2, pts2d_2, normalize=True)
+
+         kp3d_dist = torch.cdist(pts3d_1, pts3d_2)  # B x S x T
+         sim = torch.bmm(desc_1, desc_2.transpose(-1, -2))  # B x S x T
+
+         pos_idxs = torch.nonzero(kp3d_dist < self.thresh3d_pos, as_tuple=False)
+         pos_sim = sim[pos_idxs[:, 0], pos_idxs[:, 1], pos_idxs[:, 2]]
+         rpos = sigmoid(pos_sim - 1., temp=0.01) + 1  # si = 1 # pos
+         neg_mask = kp3d_dist[pos_idxs[:, 0], pos_idxs[:, 1]] > self.thres3d_neg  # pos x T
+         rall = rpos + torch.sum(sigmoid(sim[pos_idxs[:, 0], pos_idxs[:, 1]] - 1., temp=0.01) * neg_mask.float(), -1)  # pos
+         ap1 = rpos / rall
+
+         # change the order
+         rpos = sigmoid(1. - pos_sim, temp=0.01) + 1  # si = 1 # pos
+         neg_mask = kp3d_dist[pos_idxs[:, 0], pos_idxs[:, 1]] > self.thres3d_neg  # pos x T
+         rall = rpos + torch.sum(sigmoid(sim[pos_idxs[:, 0], pos_idxs[:, 1]] - pos_sim[:, None].repeat(1, sim.shape[-1]), temp=0.01) * neg_mask.float(), -1)  # pos
+         ap2 = rpos / rall
+
+         ap = (ap1 + ap2) / 2
+
+         loss = torch.mean(1. - ap)
+
+         self.log('loss', loss, prog_bar=True)
+         return loss
+
+     def configure_optimizers(self):
+         return torch.optim.AdamW([layer.weight for layer in self.w_As]
+                                  + [layer.weight for layer in self.w_Bs]
+                                  + list(self.refine_conv.parameters()), lr=1e-5, weight_decay=1e-4)
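
For reference, the _LoRA_qkv wrapper above only re-parameterizes the query and value slices of DINOv2's fused qkv projection: the frozen weights are left untouched and the low-rank adapters are the only new trainable parameters. A minimal, self-contained sketch of that mechanism, using a plain nn.Linear as a stand-in for blk.attn.qkv and hypothetical sizes (dim=384, rank=4), not the author's training setup:

import torch
import torch.nn as nn

from finetune import _LoRA_qkv  # the wrapper added in this commit

dim, rank = 384, 4                        # hypothetical embedding dim and LoRA rank
qkv = nn.Linear(dim, dim * 3, bias=True)  # stand-in for a DINOv2 block's attn.qkv
qkv.requires_grad_(False)                 # the backbone projection stays frozen

lora_qkv = _LoRA_qkv(
    qkv,
    nn.Linear(dim, rank, bias=False),  # A_q: dim -> rank
    nn.Linear(rank, dim, bias=False),  # B_q: rank -> dim
    nn.Linear(dim, rank, bias=False),  # A_v: dim -> rank
    nn.Linear(rank, dim, bias=False),  # B_v: rank -> dim
)

x = torch.randn(2, 196, dim)              # B x N x C dummy patch tokens
out = lora_qkv(x)                         # low-rank updates added to the q and v slices
print(out.shape)                          # torch.Size([2, 196, 1152]), same as qkv(x)
print([n for n, p in lora_qkv.named_parameters() if p.requires_grad])
# only the linear_a_* / linear_b_* adapter weights are trainable

In the module itself, reset_parameters() zero-initializes every B matrix, so training starts from the unmodified backbone features and the adapters only gradually deviate from them.
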
requirements.txt CHANGED
@@ -5,4 +5,8 @@ spaces
  matplotlib
  pillow
  torch==2.2.0
- torchvision==0.17.0
+ torchvision==0.17.0
+ albumentations
+ pytorch-lightning==2.2.5
+ opencv-python
+ scikit-learn
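
The pins matter here: torchvision 0.17.0 is the release matched to torch 2.2.0, and pytorch-lightning 2.2.5 is pinned alongside them for finetune.py. A quick sanity check, assuming an environment installed from this requirements.txt, is to print the imported versions before launching training:

import albumentations
import cv2
import pytorch_lightning as pl
import sklearn
import torch
import torchvision

# Versions as pinned (or merely expected) by requirements.txt
print("torch            ", torch.__version__)        # expected 2.2.0
print("torchvision      ", torchvision.__version__)  # expected 0.17.0
print("pytorch-lightning", pl.__version__)           # expected 2.2.5
print("opencv-python    ", cv2.__version__)
print("albumentations   ", albumentations.__version__)
print("scikit-learn     ", sklearn.__version__)
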