import copy
import os
import sys
import time
from collections import Counter
from glob import glob

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm

torch.manual_seed(2024)
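# Class index -> celebrity name for the 17-way classifier.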
reversed_map = {
0: 'Angelina Jolie',
1: 'Brad Pitt',
2: 'Denzel Washington',
3: 'Hugh Jackman',
4: 'Jennifer Lawrence',
5: 'Johnny Depp',
6: 'Kate Winslet',
7: 'Leonardo DiCaprio',
8: 'Megan Fox',
9: 'Natalie Portman',
10: 'Nicole Kidman',
11: 'Robert Downey Jr',
12: 'Sandra Bullock',
13: 'Scarlett Johansson',
14: 'Tom Cruise',
15: 'Tom Hanks',
16: 'Will Smith'
}
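# Preprocessing used for single-image inference when inferring the forget class
# (same resize/normalization as the training transforms built in unlearn()).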
preprocess = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
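# Retain set: all celebrity images except those of the class being forgotten.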
class CustomDataset(Dataset):
    def __init__(self, root, forget_class=16, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        # Exclude every image belonging to the identity being forgotten.
        self.im_paths = [p for p in self.im_paths if reversed_map[forget_class] not in p]
        print(f"CustomDataset: {len(self.im_paths)} retained images found under {root}")
        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        return os.path.dirname(path).split("/")[-1]

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        gt = self.cls_names[self.get_class(im_path)]
        if self.transformations is not None:
            im = self.transformations(im)
        return im, gt
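# Forget set: every image under `root` belongs to the single identity being unlearned,
# so each sample is labeled with the forget class index.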
class SingleCelebCustomDataset(Dataset):
    def __init__(self, root, forget_class=16, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        print(f"SingleCelebCustomDataset: {len(self.im_paths)} forget-set images found under {root}")
        self.forget_class = forget_class
        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        return self.forget_class

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        # Label every forget-set image with the forget class index itself
        # (the original lookup through cls_names mapped it to 0).
        gt = self.get_class(im_path)
        if self.transformations is not None:
            im = self.transformations(im)
        return im, gt
def get_dls(root, transformations, bs, forget_class=16, split=[0.9, 0.05, 0.05], ns=4, single=False):
    if single:
        ds = SingleCelebCustomDataset(root=root, forget_class=forget_class, transformations=transformations)
    else:
        ds = CustomDataset(root=root, forget_class=forget_class, transformations=transformations)
    total_len = len(ds)
    tr_len = int(total_len * split[0])
    vl_len = int(total_len * split[1])
    ts_len = total_len - (tr_len + vl_len)
    tr_ds, vl_ds, ts_ds = random_split(dataset=ds, lengths=[tr_len, vl_len, ts_len])
    tr_dl = DataLoader(tr_ds, batch_size=bs, shuffle=True, num_workers=ns)
    val_dl = DataLoader(vl_ds, batch_size=bs, shuffle=False, num_workers=ns)
    ts_dl = DataLoader(ts_ds, batch_size=1, shuffle=False, num_workers=ns)
    return tr_dl, val_dl, ts_dl, ds.cls_names
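# Proximal term: Frobenius-norm distance between the current weights and the
# SWA-averaged weights, scaled by the smoothing coefficient p.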
def param_dist(model, swa_model, p):
#This is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
dist = 0.
for p1, p2 in zip(model.parameters(), swa_model.parameters()):
dist += torch.norm(p1 - p2, p='fro')
return p * dist
def adjust_learning_rate_new(epoch, optimizer, LUT):
"""
new learning rate schedule according to RotNet
"""
lr = next((lr for (max_epoch, lr) in LUT if max_epoch > epoch), LUT[-1][1])
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def sgda_adjust_learning_rate(epoch, opt, optimizer):
"""Sets the learning rate to the initial LR decayed by decay rate every steep step"""
steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
new_lr = opt.sgda_learning_rate
if steps > 0:
new_lr = opt.sgda_learning_rate * (opt.lr_decay_rate ** steps)
for param_group in optimizer.param_groups:
param_group['lr'] = new_lr
return new_lr
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
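# One epoch of SCRUB-style distillation. With split="maximize" the student is pushed
# away from the teacher on the forget set (negative KL term); with split="minimize"
# it is distilled toward the teacher and fit to the labels on the retained data.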
def train_distill(epoch, train_loader, module_list, swa_model, criterion_list, optimizer, opt, split, quiet=False):
"""One epoch distillation"""
# set modules as train()
for module in module_list:
module.train()
# set teacher as eval()
module_list[-1].eval()
criterion_cls = criterion_list[0]
criterion_div = criterion_list[1]
criterion_kd = criterion_list[2]
model_s = module_list[0]
model_t = module_list[-1]
batch_time = AverageMeter()
data_time = AverageMeter()
losses = AverageMeter()
kd_losses = AverageMeter()
top1 = AverageMeter()
end = time.time()
for idx, data in enumerate(train_loader):
if opt.distill in ['crd']:
input, target, index, contrast_idx = data
else:
input, target = data
data_time.update(time.time() - end)
input = input.float()
if torch.cuda.is_available():
input = input.cuda()
target = target.cuda()
if opt.distill in ['crd']:
contrast_idx = contrast_idx.cuda()
index = index.cuda()
# ===================forward=====================
#feat_s, logit_s = model_s(input, is_feat=True, preact=False)
logit_s = model_s(input)
with torch.no_grad():
#feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
#feat_t = [f.detach() for f in feat_t]
logit_t = model_t(input)
# cls + kl div
loss_cls = criterion_cls(logit_s, target)
loss_div = criterion_div(logit_s, logit_t)
if split == "minimize":
loss = opt.gamma * loss_cls + opt.alpha * loss_div
elif split == "maximize":
loss = -loss_div
loss = loss + param_dist(model_s, swa_model, opt.smoothing)
if split == "minimize" and not quiet:
acc1, _ = accuracy(logit_s, target, topk=(1,1))
losses.update(val=loss.item(), n=input.size(0))
top1.update(val=acc1[0], n=input.size(0))
elif split == "maximize" and not quiet:
kd_losses.update(val=loss.item(), n=input.size(0))
elif split == "linear" and not quiet:
acc1, _ = accuracy(logit_s, target, topk=(1, 1))
losses.update(val=loss.item(), n=input.size(0))
top1.update(val=acc1[0], n=input.size(0))
kd_losses.update(val=loss.item(), n=input.size(0))
# ===================backward=====================
optimizer.zero_grad()
loss.backward()
#nn.utils.clip_grad_value_(model_s.parameters(), clip)
optimizer.step()
# ===================meters=====================
batch_time.update(time.time() - end)
end = time.time()
if not quiet:
if split == "mainimize":
if idx % opt.print_freq == 0:
print('Epoch: [{0}][{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
epoch, idx, len(train_loader), batch_time=batch_time,
data_time=data_time, loss=losses, top1=top1))
sys.stdout.flush()
if split == "minimize":
#if not quiet:
#print(' * Acc@1 {top1.avg:.3f} '
# .format(top1=top1))
return top1.avg, losses.avg
else:
return kd_losses.avg
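# Temperature-scaled KL-divergence loss for knowledge distillation (Hinton et al.),
# summed over classes and averaged over the batch.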
class DistillKL(nn.Module):
"""Distilling the Knowledge in a Neural Network"""
def __init__(self, T):
super(DistillKL, self).__init__()
self.T = T
def forward(self, y_s, y_t):
p_s = F.log_softmax(y_s/self.T, dim=1)
p_t = F.softmax(y_t/self.T, dim=1)
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
return loss
class Args:
def __init__(self, **entries):
self.__dict__.update(entries)
# Infer which class to forget: classify every image in the forget folder with the
# original model and take the majority-vote prediction.
def get_forget_class(folder_path, model):
    # The forget set is stored as <folder_path>/<identity_name>/<images>; descend into the inner folder.
    inner_folder = os.listdir(folder_path)[0]
    folder_path = os.path.join(folder_path, inner_folder)
    image_files = os.listdir(folder_path)
    print(f"Inferring the forget class from {len(image_files)} images in {folder_path}")
    preds = []
    # Classify each image in the folder with the original model.
    for filename in image_files:
# Check if the file is an image (you can add more specific checks if needed)
if filename.endswith(('.png', '.jpg', '.jpeg')):
# Construct the full file path
file_path = os.path.join(folder_path, filename)
# Open the image using PIL
image = Image.open(file_path).convert('RGB')
# Apply preprocessing
image_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
# Perform inference
with torch.no_grad():
output = model(image_tensor)
probabilities = F.softmax(output, dim=1)
pred_class = torch.argmax(probabilities, dim=1)
preds.append(pred_class.item())
freq = Counter(preds)
top_one = freq.most_common(1)
forget_class, _ = top_one[0]
return forget_class
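# Exponential-moving-average update used by the SWA wrapper below:
# new_average = (1 - beta) * current_average + beta * new_parameter.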
beta = 0.1
def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
    return (1 - beta) * averaged_model_parameter + beta * model_parameter
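# Unlearning pipeline: load the trained classifier, infer which identity to forget,
# build forget/retain loaders, then alternate maximize (forget) and minimize (retain)
# passes of train_distill before saving the unlearned weights.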
def unlearn():
model = timm.create_model("rexnet_150", pretrained = True, num_classes = 17)
model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
model_eval = copy.deepcopy(model)
model_eval.eval()
print("BEGIN INTIALZINGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG")
forget_class = get_forget_class('forget_set', model_eval)
mean, std, im_size = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], 224
tfs = transforms.Compose([transforms.Resize((im_size, im_size)), transforms.ToTensor(), transforms.Normalize(mean = mean, std = std)])
will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(root = "forget_set", forget_class=forget_class, transformations = tfs, bs = 32, single=True)
celebs_tr_dl, celebs_val_dl, celebs_ts_dl, classes = get_dls(root = "celeb-dataset", forget_class=forget_class, transformations = tfs, bs = 32)
print("BEGIN PEeEEEEEEEEEEEEEEEEEEEEEEERPARING FOR UNLEARNINGGGGGGGG")
args = Args()
args.optim = 'sgd'
args.gamma = 0.99
args.alpha = 0.001
args.smoothing = 0.0
args.msteps = 4
args.clip = 0.2
args.sstart = 10
args.kd_T = 4
args.distill = 'kd'
args.sgda_batch_size = 64
args.del_batch_size = 64
args.sgda_epochs = 6
args.sgda_learning_rate = 0.005
args.lr_decay_epochs = [3,5,9]
args.lr_decay_rate = 0.0005
args.sgda_weight_decay = 5e-4
    args.sgda_momentum = 0.9
    args.print_freq = 100  # progress-print interval; not in the original args but needed once the "minimize" logging block runs
model_t = copy.deepcopy(model)
model_s = copy.deepcopy(model)
swa_model = torch.optim.swa_utils.AveragedModel(
model_s, avg_fn=avg_fn)
module_list = nn.ModuleList([])
module_list.append(model_s)
trainable_list = nn.ModuleList([])
trainable_list.append(model_s)
criterion_cls = nn.CrossEntropyLoss()
criterion_div = DistillKL(args.kd_T)
criterion_kd = DistillKL(args.kd_T)
criterion_list = nn.ModuleList([])
criterion_list.append(criterion_cls) # classification loss
criterion_list.append(criterion_div) # KL divergence loss, original knowledge distillation
criterion_list.append(criterion_kd) # other knowledge distillation loss
if args.optim == "sgd":
optimizer = optim.SGD(trainable_list.parameters(),
lr=args.sgda_learning_rate,
momentum=args.sgda_momentum,
weight_decay=args.sgda_weight_decay)
module_list.append(model_t)
print("BEGIN UNLEARNINGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG")
for epoch in tqdm(range(1, args.sgda_epochs + 1)):
print("\n\n==============================>epoch: ", epoch)
maximize_loss = 0
if epoch <= args.msteps:
maximize_loss = train_distill(epoch, will_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "maximize")
train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
if epoch >= args.sstart:
swa_model.update_parameters(model_s)
torch.save(model_s.state_dict(), 'celeb-model-unlearned.pth')
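# Hypothetical entry point (not part of the original script): assumes the checkpoint
# 'faces_best_model.pth' and the 'forget_set' / 'celeb-dataset' folders exist in the
# current working directory.
if __name__ == "__main__":
    unlearn()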