mahmoud669 committed on
Commit
411d1ef
1 Parent(s): 5863a45

Update scrub.py

Browse files
Files changed (1) hide show
  1. scrub.py +63 -11
scrub.py CHANGED
@@ -35,15 +35,33 @@ from torch import nn
35
  from itertools import cycle
36
  import timm
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
  class CustomDataset(Dataset):
40
 
41
- def __init__(self, root, transformations = None):
42
 
43
  self.transformations = transformations
44
  self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*/*"))]
45
- self.im_paths = [i for i in self.im_paths if not 'Will Smith' in i]
46
-
47
  self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
48
  for idx, im_path in enumerate(self.im_paths):
49
  class_name = self.get_class(im_path)
@@ -67,17 +85,18 @@ class CustomDataset(Dataset):
67
 
68
  class SingleCelebCustomDataset(Dataset):
69
 
70
- def __init__(self, root, transformations = None):
71
 
72
  self.transformations = transformations
73
  self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*"))]
 
74
  self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
75
  for idx, im_path in enumerate(self.im_paths):
76
  class_name = self.get_class(im_path)
77
  if class_name not in self.cls_names: self.cls_names[class_name] = count; self.cls_counts[class_name] = 1; count += 1
78
  else: self.cls_counts[class_name] += 1
79
 
80
- def get_class(self, path): return 16
81
 
82
  def __len__(self): return len(self.im_paths)
83
 
@@ -92,11 +111,11 @@ class SingleCelebCustomDataset(Dataset):
92
  return im, gt
93
 
94
 
95
- def get_dls(root, transformations, bs, split = [0.9, 0.05, 0.05], ns = 4, single=False):
96
  if single:
97
- ds = SingleCelebCustomDataset(root = root, transformations = transformations)
98
  else:
99
- ds = CustomDataset(root = root, transformations = transformations)
100
 
101
  total_len = len(ds)
102
  tr_len = int(total_len * split[0])
@@ -293,10 +312,43 @@ class Args:
293
  def __init__(self, **entries):
294
  self.__dict__.update(entries)
295
 
296
- def unlearn():
297
- will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(root = "forget_set/", transformations = tfs, bs = 32, single=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
299
  model.load_state_dict(torch.load('faces_best_model.pth'))
 
 
 
 
300
  args = Args()
301
  args.optim = 'sgd'
302
  args.gamma = 0.99
@@ -355,4 +407,4 @@ def unlearn():
355
  train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
356
  if epoch >= args.sstart:
357
  swa_model.update_parameters(model_s)
358
-
 
35
  from itertools import cycle
36
  import timm
37
 
38
# Maps the integer class index (as predicted by the model and passed around
# as `forget_class`) to the celebrity name it stands for.
reversed_map = dict(enumerate((
    'Angelina Jolie',
    'Brad Pitt',
    'Denzel Washington',
    'Hugh Jackman',
    'Jennifer Lawrence',
    'Johnny Depp',
    'Kate Winslet',
    'Leonardo DiCaprio',
    'Megan Fox',
    'Natalie Portman',
    'Nicole Kidman',
    'Robert Downey Jr',
    'Sandra Bullock',
    'Scarlett Johansson',
    'Tom Cruise',
    'Tom Hanks',
    'Will Smith',
)))
57
 
58
  class CustomDataset(Dataset):
59
 
60
def __init__(self, root, forget_class=16, transformations = None):
    """Index every image under ``root`` except those of the class being forgotten.

    ``root`` comes first because a parameter without a default may not follow
    one with a default (that is a SyntaxError); the order now mirrors
    ``SingleCelebCustomDataset.__init__``.

    Args:
        root: dataset directory; files are collected with the glob
            pattern ``{root}/*/*``.
        forget_class: key into ``reversed_map``; every path that contains
            the mapped name is left out of the dataset.
        transformations: stored unchanged as ``self.transformations``.

    Raises:
        KeyError: if ``forget_class`` is not a key of ``reversed_map``.
    """
    self.transformations = transformations
    # Resolve the name once instead of once per path; this also rejects an
    # unknown forget_class immediately.
    forget_name = reversed_map[forget_class]
    # Substring match against the whole path, exactly as before.
    self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*/*")) if forget_name not in im_path]

    self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
    for idx, im_path in enumerate(self.im_paths):
        class_name = self.get_class(im_path)
 
85
 
86
  class SingleCelebCustomDataset(Dataset):
87
 
88
def __init__(self, root, forget_class=16, transformations = None):
    """Index the images directly under ``root`` and tally them per label.

    Args:
        root: directory whose entries (glob pattern ``{root}/*``) become
            the samples, in sorted order.
        forget_class: stored as ``self.forget_class``; ``get_class``
            returns it as the label of every sample.
        transformations: stored unchanged as ``self.transformations``.
    """
    self.forget_class = forget_class
    self.transformations = transformations
    self.im_paths = sorted(glob(f"{root}/*"))

    # cls_names: label -> running index in order of first appearance
    # cls_counts: label -> number of samples carrying that label
    self.cls_names = {}
    self.cls_counts = {}
    for sample_path in self.im_paths:
        label = self.get_class(sample_path)
        if label in self.cls_names:
            self.cls_counts[label] += 1
        else:
            self.cls_names[label] = len(self.cls_names)
            self.cls_counts[label] = 1
98
 
99
def get_class(self, path):
    """Return the label for ``path``, which is always ``self.forget_class``."""
    # ``path`` is not inspected: every sample of this dataset shares one label.
    label = self.forget_class
    return label
100
 
101
  def __len__(self): return len(self.im_paths)
102
 
 
111
  return im, gt
112
 
113
 
114
def get_dls(root, transformations, bs, split = [0.9, 0.05, 0.05], ns = 4, single=False, forget_class=16):
    """Build the dataset for ``root`` and derive the split sizes.

    ``forget_class`` is placed last: a defaulted parameter may not precede
    the required ``transformations``/``bs`` (that is a SyntaxError), and
    appending it keeps the earlier positional order of the other parameters
    valid for existing callers.

    Args:
        root: dataset directory handed to the dataset class.
        transformations: handed to the dataset class unchanged.
        bs: not used in this part of the function.
            NOTE(review): presumably the DataLoader batch size - confirm.
        split: fractions of the dataset; ``split[0]`` sizes the training part.
        ns: not used in this part of the function.
            NOTE(review): presumably the DataLoader worker count - confirm.
        single: when true use ``SingleCelebCustomDataset`` (one label for
            every sample), otherwise ``CustomDataset``.
        forget_class: class index forwarded to the dataset class.

    NOTE(review): ``unlearn`` unpacks four values from the result
    (train/val/test loaders and ``classes``).
    """
    # Both dataset classes take the same keyword arguments, so only the
    # class itself depends on ``single``.
    dataset_cls = SingleCelebCustomDataset if single else CustomDataset
    ds = dataset_cls(root = root, forget_class=forget_class, transformations = transformations)

    total_len = len(ds)
    tr_len = int(total_len * split[0])
 
312
  def __init__(self, **entries):
313
  self.__dict__.update(entries)
314
 
315
+
316
def process_images_in_folder(folder_path, model):
    """Return the class index that ``model`` predicts most often for a folder of images.

    Every ``.png``/``.jpg``/``.jpeg`` file (extension matched
    case-insensitively) directly inside ``folder_path`` is opened, converted
    to RGB, run through the module-level ``preprocess`` transform and
    classified; the most frequent prediction over all images wins.

    Args:
        folder_path: directory containing the images.
        model: classifier called on a single-image batch; its training/eval
            mode is restored before returning.

    Returns:
        The majority predicted class index as a Python int.

    Raises:
        ValueError: if the folder contains no image file.
    """
    image_exts = ('.png', '.jpg', '.jpeg')
    # Sorted so that a tie in the vote is broken the same way on every run
    # (Counter.most_common keeps first-seen order among equal counts) instead
    # of depending on directory listing order.
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_exts)
    )
    if not image_files:
        raise ValueError(f"no image files (.png/.jpg/.jpeg) found in {folder_path!r}")

    # no_grad() alone leaves dropout/batch-norm in training behaviour, so
    # switch to eval mode for the predictions and put the mode back afterwards.
    was_training = model.training
    model.eval()
    # Collected across ALL images; the list must not be reset per image or
    # the vote degenerates to the last prediction.
    preds = []
    try:
        with torch.no_grad():
            for filename in image_files:
                file_path = os.path.join(folder_path, filename)
                # Context manager closes the file handle; convert() returns
                # an independent, fully loaded image.
                with Image.open(file_path) as raw_image:
                    image = raw_image.convert('RGB')
                image_tensor = preprocess(image).unsqueeze(0)  # add batch dimension
                output = model(image_tensor)
                # softmax is monotonic, so argmax over the raw outputs picks
                # the same class as argmax over the probabilities.
                preds.append(torch.argmax(output, dim=1).item())
    finally:
        model.train(was_training)

    forget_class, _ = Counter(preds).most_common(1)[0]
    return forget_class
344
+
345
def unlearn():
    """Load the trained classifier, work out which class to forget and set up the run.

    The 17-class ``rexnet_150`` weights are read from ``faces_best_model.pth``;
    the class to forget is the one the model predicts most often for the
    images in ``forget_set``. Loaders are then built for the forget set
    (single label) and for ``celeb-dataset`` (forget class filtered out).
    """
    model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
    model.load_state_dict(torch.load('faces_best_model.pth'))
    # `preprocess` is the per-image transform (it is applied to a single PIL
    # image inside process_images_in_folder); the folder-level majority vote
    # that yields the class index is process_images_in_folder itself.
    forget_class = process_images_in_folder('forget_set', model)
    will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(root = "forget_set", forget_class=forget_class, transformations = tfs, bs = 32, single=True)
    celebs_tr_dl, celebs_val_dl, celebs_ts_dl, classes = get_dls(root = "celeb-dataset", forget_class=forget_class, transformations = tfs, bs = 32)

    args = Args()
    args.optim = 'sgd'
    args.gamma = 0.99
 
407
  train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
408
  if epoch >= args.sstart:
409
  swa_model.update_parameters(model_s)
410
+ torch.save(model_s.state_dict(), 'celeb-model-unlearned.pth')