Spaces:
Sleeping
Sleeping
mahmoud669
committed on
Commit
•
411d1ef
1
Parent(s):
5863a45
Update scrub.py
Browse files
scrub.py
CHANGED
@@ -35,15 +35,33 @@ from torch import nn
|
|
35 |
from itertools import cycle
|
36 |
import timm
|
37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
|
39 |
class CustomDataset(Dataset):
|
40 |
|
41 |
-
def __init__(self, root, transformations = None):
|
42 |
|
43 |
self.transformations = transformations
|
44 |
self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*/*"))]
|
45 |
-
self.im_paths = [i for i in self.im_paths if not
|
46 |
-
|
47 |
self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
|
48 |
for idx, im_path in enumerate(self.im_paths):
|
49 |
class_name = self.get_class(im_path)
|
@@ -67,17 +85,18 @@ class CustomDataset(Dataset):
|
|
67 |
|
68 |
class SingleCelebCustomDataset(Dataset):
|
69 |
|
70 |
-
def __init__(self, root, transformations = None):
|
71 |
|
72 |
self.transformations = transformations
|
73 |
self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*"))]
|
|
|
74 |
self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
|
75 |
for idx, im_path in enumerate(self.im_paths):
|
76 |
class_name = self.get_class(im_path)
|
77 |
if class_name not in self.cls_names: self.cls_names[class_name] = count; self.cls_counts[class_name] = 1; count += 1
|
78 |
else: self.cls_counts[class_name] += 1
|
79 |
|
80 |
-
def get_class(self, path): return
|
81 |
|
82 |
def __len__(self): return len(self.im_paths)
|
83 |
|
@@ -92,11 +111,11 @@ class SingleCelebCustomDataset(Dataset):
|
|
92 |
return im, gt
|
93 |
|
94 |
|
95 |
-
def get_dls(root, transformations, bs, split = [0.9, 0.05, 0.05], ns = 4, single=False):
|
96 |
if single:
|
97 |
-
ds = SingleCelebCustomDataset(root = root, transformations = transformations)
|
98 |
else:
|
99 |
-
ds = CustomDataset(root = root, transformations = transformations)
|
100 |
|
101 |
total_len = len(ds)
|
102 |
tr_len = int(total_len * split[0])
|
@@ -293,10 +312,43 @@ class Args:
|
|
293 |
def __init__(self, **entries):
|
294 |
self.__dict__.update(entries)
|
295 |
|
296 |
-
|
297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
|
299 |
model.load_state_dict(torch.load('faces_best_model.pth'))
|
|
|
|
|
|
|
|
|
300 |
args = Args()
|
301 |
args.optim = 'sgd'
|
302 |
args.gamma = 0.99
|
@@ -355,4 +407,4 @@ def unlearn():
|
|
355 |
train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
|
356 |
if epoch >= args.sstart:
|
357 |
swa_model.update_parameters(model_s)
|
358 |
-
|
|
|
35 |
from itertools import cycle
|
36 |
import timm
|
37 |
|
38 |
+
# Lookup table from integer class index to celebrity name (17 entries, 0-16).
# The list position of each name is its class index, so the order below must
# not be changed.
# NOTE(review): presumably these indices match the label order the classifier
# was trained with -- confirm against the training pipeline.
reversed_map = dict(enumerate([
    'Angelina Jolie',
    'Brad Pitt',
    'Denzel Washington',
    'Hugh Jackman',
    'Jennifer Lawrence',
    'Johnny Depp',
    'Kate Winslet',
    'Leonardo DiCaprio',
    'Megan Fox',
    'Natalie Portman',
    'Nicole Kidman',
    'Robert Downey Jr',
    'Sandra Bullock',
    'Scarlett Johansson',
    'Tom Cruise',
    'Tom Hanks',
    'Will Smith',
]))
|
57 |
|
58 |
class CustomDataset(Dataset):
|
59 |
|
60 |
+
def __init__(self, forget_class=16, root, transformations = None):
|
61 |
|
62 |
self.transformations = transformations
|
63 |
self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*/*"))]
|
64 |
+
self.im_paths = [i for i in self.im_paths if not reversed_map[forget_class] in i]
|
|
|
65 |
self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
|
66 |
for idx, im_path in enumerate(self.im_paths):
|
67 |
class_name = self.get_class(im_path)
|
|
|
85 |
|
86 |
class SingleCelebCustomDataset(Dataset):
|
87 |
|
88 |
+
def __init__(self, root, forget_class=16, transformations = None):
|
89 |
|
90 |
self.transformations = transformations
|
91 |
self.im_paths = [im_path for im_path in sorted(glob(f"{root}/*"))]
|
92 |
+
self.forget_class = forget_class
|
93 |
self.cls_names, self.cls_counts, count, data_count = {}, {}, 0, 0
|
94 |
for idx, im_path in enumerate(self.im_paths):
|
95 |
class_name = self.get_class(im_path)
|
96 |
if class_name not in self.cls_names: self.cls_names[class_name] = count; self.cls_counts[class_name] = 1; count += 1
|
97 |
else: self.cls_counts[class_name] += 1
|
98 |
|
99 |
+
def get_class(self, path): return self.forget_class
|
100 |
|
101 |
def __len__(self): return len(self.im_paths)
|
102 |
|
|
|
111 |
return im, gt
|
112 |
|
113 |
|
114 |
+
def get_dls(root, forget_class=16, transformations, bs, split = [0.9, 0.05, 0.05], ns = 4, single=False):
|
115 |
if single:
|
116 |
+
ds = SingleCelebCustomDataset(root = root, forget_class=forget_class, transformations = transformations)
|
117 |
else:
|
118 |
+
ds = CustomDataset(root = root, forget_class=forget_class, transformations = transformations)
|
119 |
|
120 |
total_len = len(ds)
|
121 |
tr_len = int(total_len * split[0])
|
|
|
312 |
def __init__(self, **entries):
|
313 |
self.__dict__.update(entries)
|
314 |
|
315 |
+
|
316 |
+
# Function to process each image in a folder
|
317 |
+
def process_images_in_folder(folder_path, model):
    """Return the class index that ``model`` predicts most often for a folder.

    Every file directly inside ``folder_path`` whose name ends in ``.png``,
    ``.jpg`` or ``.jpeg`` (case-insensitive) is opened with PIL, converted to
    RGB, passed through ``preprocess`` and classified by ``model``. The
    predictions of *all* images are then combined by majority vote.

    Args:
        folder_path: Directory containing the images to classify.
        model: A ``torch.nn.Module`` that maps a batch holding one
            preprocessed image to per-class scores.

    Returns:
        The predicted class index (``int``) that occurs most frequently.

    Raises:
        ValueError: If ``folder_path`` contains no file with a supported
            image extension.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    # Sorted so that ties in the majority vote are broken the same way on
    # every run; os.listdir() returns entries in arbitrary order.
    image_files = sorted(
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_extensions)
    )
    if not image_files:
        raise ValueError(f"no image files found in {folder_path!r}")

    # Inference must run in eval mode so that dropout is disabled and
    # batch-norm uses (and does not update) its running statistics. The
    # caller's mode is restored afterwards.
    was_training = model.training
    model.eval()

    # One accumulator for the whole folder: resetting it per image would make
    # the vote below reflect only the last image processed.
    preds = []
    try:
        with torch.no_grad():
            for filename in image_files:
                file_path = os.path.join(folder_path, filename)

                # Context manager closes the underlying file handle promptly.
                with Image.open(file_path) as raw_image:
                    image = raw_image.convert('RGB')

                # `preprocess` is defined elsewhere in this file.
                # NOTE(review): presumably a torchvision-style transform that
                # returns a CHW tensor -- confirm.
                image_tensor = preprocess(image).unsqueeze(0)  # add batch dim

                output = model(image_tensor)
                # argmax over the raw scores equals argmax over softmax
                # probabilities, so the softmax step is unnecessary here.
                preds.append(torch.argmax(output, dim=1).item())
    finally:
        model.train(was_training)

    forget_class, _ = Counter(preds).most_common(1)[0]
    return forget_class
|
344 |
+
|
345 |
+
def unlearn():
|
346 |
model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
|
347 |
model.load_state_dict(torch.load('faces_best_model.pth'))
|
348 |
+
forget_class = preprocess('forget_set', model)
|
349 |
+
will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(root = "forget_set", forget_class=forget_class, transformations = tfs, bs = 32, single=True)
|
350 |
+
celebs_tr_dl, celebs_val_dl, celebs_ts_dl, classes = get_dls(root = "celeb-dataset", forget_class=forget_class, transformations = tfs, bs = 32)
|
351 |
+
|
352 |
args = Args()
|
353 |
args.optim = 'sgd'
|
354 |
args.gamma = 0.99
|
|
|
407 |
train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
|
408 |
if epoch >= args.sstart:
|
409 |
swa_model.update_parameters(model_s)
|
410 |
+
torch.save(model_s.state_dict(), 'celeb-model-unlearned.pth')
|