mahmoud669 committed on
Commit
5863a45
1 Parent(s): 8a811c4

Create scrub.py

Files changed (1)
  1. scrub.py +358 -0
scrub.py ADDED
@@ -0,0 +1,358 @@
import copy
import os
import sys
import time
from glob import glob

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms as T
from tqdm import tqdm

torch.manual_seed(2024)  # reproducible splits


class CustomDataset(Dataset):
    """Celebrity-faces dataset laid out as <root>/<class name>/<image>.
    The 'Will Smith' class (the identity to forget) is excluded, so this
    dataset serves as the retain set."""

    def __init__(self, root, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        self.im_paths = [i for i in self.im_paths if "Will Smith" not in i]

        # Map each class folder name to a contiguous label index and count images per class.
        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        return os.path.dirname(path).split("/")[-1]

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        gt = self.cls_names[self.get_class(im_path)]

        if self.transformations is not None:
            im = self.transformations(im)

        return im, gt


class SingleCelebCustomDataset(Dataset):
    """Flat dataset of a single celebrity (<root>/<image>), used as the forget
    set. Every image gets the fixed label 16, assumed to be the forgotten
    identity's index in the 17-class model. (The original code routed the label
    through cls_names, which remapped it to 0; the maximize objective below
    does not use the label either way.)"""

    FORGET_CLASS = 16

    def __init__(self, root, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*"))
        self.cls_names = {self.FORGET_CLASS: self.FORGET_CLASS}
        self.cls_counts = {self.FORGET_CLASS: len(self.im_paths)}

    def get_class(self, path):
        return self.FORGET_CLASS

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        gt = self.cls_names[self.get_class(im_path)]

        if self.transformations is not None:
            im = self.transformations(im)

        return im, gt


def get_dls(root, transformations, bs, split=[0.9, 0.05, 0.05], ns=4, single=False):
    if single:
        ds = SingleCelebCustomDataset(root=root, transformations=transformations)
    else:
        ds = CustomDataset(root=root, transformations=transformations)

    total_len = len(ds)
    tr_len = int(total_len * split[0])
    vl_len = int(total_len * split[1])
    ts_len = total_len - (tr_len + vl_len)

    tr_ds, vl_ds, ts_ds = random_split(dataset=ds, lengths=[tr_len, vl_len, ts_len])

    tr_dl = DataLoader(tr_ds, batch_size=bs, shuffle=True, num_workers=ns)
    val_dl = DataLoader(vl_ds, batch_size=bs, shuffle=False, num_workers=ns)
    ts_dl = DataLoader(ts_ds, batch_size=1, shuffle=False, num_workers=ns)

    return tr_dl, val_dl, ts_dl, ds.cls_names

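# Expected directory layout (paths are illustrative, not from the commit):
#   <root>/<class name>/<image>.jpg   for the full dataset (single=False)
#   <root>/<image>.jpg                for the flat forget set (single=True)
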

def param_dist(model, swa_model, p):
    # This is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
    dist = 0.
    for p1, p2 in zip(model.parameters(), swa_model.parameters()):
        dist += torch.norm(p1 - p2, p='fro')
    return p * dist

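# param_dist is a proximity penalty: p times the summed Frobenius norms of the
# parameter differences, pulling the student toward its SWA running average so
# the gradient-ascent steps cannot drift arbitrarily far from it.
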

def adjust_learning_rate_new(epoch, optimizer, LUT):
    """
    new learning rate schedule according to RotNet
    """
    lr = next((lr for (max_epoch, lr) in LUT if max_epoch > epoch), LUT[-1][1])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def sgda_adjust_learning_rate(epoch, opt, optimizer):
    """Sets the learning rate to the initial LR decayed by lr_decay_rate at each decay epoch"""
    steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
    new_lr = opt.sgda_learning_rate
    if steps > 0:
        new_lr = opt.sgda_learning_rate * (opt.lr_decay_rate ** steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr

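# Worked example with the values unlearn() sets below (note: unlearn() never
# actually calls this helper): sgda_learning_rate = 0.005,
# lr_decay_epochs = [3, 5, 9], lr_decay_rate = 0.0005. At epoch 4, steps = 1,
# so new_lr = 0.005 * 0.0005 = 2.5e-6; at epoch 6, steps = 2, so
# new_lr = 0.005 * 0.0005**2 = 1.25e-9.
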

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

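# train_distill runs one epoch of SCRUB-style alternating distillation:
# the "maximize" split pushes the student away from the teacher on the forget
# set (loss = -KL divergence), while the "minimize" split distills the teacher
# into the student on the retain set (cross-entropy + KL), optionally plus the
# param_dist smoothing penalty.
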

def train_distill(epoch, train_loader, module_list, swa_model, criterion_list, optimizer, opt, split, quiet=False):
    """One epoch distillation"""
    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]  # not used by the plain 'kd' distill mode

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    kd_losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.distill in ['crd']:
            input, target, index, contrast_idx = data
        else:
            input, target = data
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()
            if opt.distill in ['crd']:
                contrast_idx = contrast_idx.cuda()
                index = index.cuda()

        # ===================forward=====================
        #feat_s, logit_s = model_s(input, is_feat=True, preact=False)
        logit_s = model_s(input)
        with torch.no_grad():
            #feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
            #feat_t = [f.detach() for f in feat_t]
            logit_t = model_t(input)

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        if split == "minimize":
            loss = opt.gamma * loss_cls + opt.alpha * loss_div
        elif split == "maximize":
            # gradient *ascent* on the forget set: move the student away from the teacher
            loss = -loss_div

        loss = loss + param_dist(model_s, swa_model, opt.smoothing)

        if split == "minimize" and not quiet:
            # topk=(1, 1) makes accuracy() return two values; only top-1 is kept
            acc1, _ = accuracy(logit_s, target, topk=(1, 1))
            losses.update(val=loss.item(), n=input.size(0))
            top1.update(val=acc1[0], n=input.size(0))
        elif split == "maximize" and not quiet:
            kd_losses.update(val=loss.item(), n=input.size(0))
        elif split == "linear" and not quiet:
            acc1, _ = accuracy(logit_s, target, topk=(1, 1))
            losses.update(val=loss.item(), n=input.size(0))
            top1.update(val=acc1[0], n=input.size(0))
            kd_losses.update(val=loss.item(), n=input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        #nn.utils.clip_grad_value_(model_s.parameters(), clip)
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        if not quiet:
            if split == "minimize":  # fixed typo (was "mainimize"), so this log can actually fire
                if idx % opt.print_freq == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                           epoch, idx, len(train_loader), batch_time=batch_time,
                           data_time=data_time, loss=losses, top1=top1))
                    sys.stdout.flush()

    if split == "minimize":
        #if not quiet:
        #    print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1))
        return top1.avg, losses.avg
    else:
        return kd_losses.avg


class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # reduction='sum' replaces the deprecated size_average=False
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss

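# Quick sanity check (illustrative, not part of the commit): identical student
# and teacher logits give a (near-)zero distillation loss.
#   kd = DistillKL(T=4)
#   z = torch.randn(8, 17)
#   assert kd(z, z).item() < 1e-4
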

class Args:
    """Simple attribute bag for hyperparameters."""
    def __init__(self, **entries):
        self.__dict__.update(entries)

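# `unlearn()` below references `tfs` and `avg_fn`, neither of which is defined
# anywhere in this file. The two definitions here are plausible reconstructions
# (assumptions, not the committed code): standard ImageNet-style preprocessing
# for the timm backbone, and an EMA-style averaging function for
# torch.optim.swa_utils.AveragedModel.

# Assumed preprocessing; the resize and normalisation values are a guess.
tfs = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Assumed EMA coefficient; avg_fn(p_avg, p, n) returns the new running average.
beta = 0.1
def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
    return (1 - beta) * averaged_model_parameter + beta * model_parameter
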

def unlearn():
    # Forget-set loaders (flat directory of the identity to scrub).
    will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(root="forget_set/", transformations=tfs, bs=32, single=True)
    # The retain-set loader `celebs_tr_dl` is never defined in this file;
    # the root path below is an assumption.
    celebs_tr_dl, _, _, _ = get_dls(root="celeb_faces/", transformations=tfs, bs=32)

    model = timm.create_model("rexnet_150", pretrained=True, num_classes=17)
    model.load_state_dict(torch.load('faces_best_model.pth', map_location='cpu'))

    args = Args()
    args.optim = 'sgd'
    args.gamma = 0.99
    args.alpha = 0.001
    args.smoothing = 0.0
    args.msteps = 4
    args.clip = 0.2
    args.sstart = 10
    args.kd_T = 4
    args.distill = 'kd'
    args.print_freq = 100  # assumed value; never set in the original file

    args.sgda_batch_size = 64
    args.del_batch_size = 64
    args.sgda_epochs = 6
    args.sgda_learning_rate = 0.005
    args.lr_decay_epochs = [3, 5, 9]
    args.lr_decay_rate = 0.0005
    args.sgda_weight_decay = 5e-4
    args.sgda_momentum = 0.9

    model_t = copy.deepcopy(model)  # frozen teacher
    model_s = copy.deepcopy(model)  # student to be scrubbed
    swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(args.kd_T)
    criterion_kd = DistillKL(args.kd_T)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    if args.optim == "sgd":
        optimizer = optim.SGD(trainable_list.parameters(),
                              lr=args.sgda_learning_rate,
                              momentum=args.sgda_momentum,
                              weight_decay=args.sgda_weight_decay)

    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        import torch.backends.cudnn as cudnn
        cudnn.benchmark = True
        swa_model.cuda()

    for epoch in tqdm(range(1, args.sgda_epochs + 1)):
        maximize_loss = 0
        if epoch <= args.msteps:
            # ascent on the forget set
            maximize_loss = train_distill(epoch, will_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "maximize")
        # descent on the retain set
        train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
        if epoch >= args.sstart:
            # note: with sstart = 10 > sgda_epochs = 6 this never fires
            swa_model.update_parameters(model_s)

    return model_s  # the scrubbed student (added; the original function returned nothing)
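
# Hypothetical entry point (not part of the commit): run the scrub and save
# the result; the output filename is an assumption.
if __name__ == "__main__":
    scrubbed = unlearn()
    torch.save(scrubbed.state_dict(), "faces_scrubbed_model.pth")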