import copy
import os
import sys
import time
from collections import Counter
from glob import glob

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm

torch.manual_seed(2024)

# Class index -> celebrity name for the 17-way face classifier.
reversed_map = {
    0: 'Angelina Jolie',
    1: 'Brad Pitt',
    2: 'Denzel Washington',
    3: 'Hugh Jackman',
    4: 'Jennifer Lawrence',
    5: 'Johnny Depp',
    6: 'Kate Winslet',
    7: 'Leonardo DiCaprio',
    8: 'Megan Fox',
    9: 'Natalie Portman',
    10: 'Nicole Kidman',
    11: 'Robert Downey Jr',
    12: 'Sandra Bullock',
    13: 'Scarlett Johansson',
    14: 'Tom Cruise',
    15: 'Tom Hanks',
    16: 'Will Smith'
}

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


class CustomDataset(Dataset):
    """Retain set: every class under `root` except the one being forgotten."""

    def __init__(self, root, forget_class=16, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        # Drop every image that belongs to the class we want to unlearn.
        self.im_paths = [p for p in self.im_paths if reversed_map[forget_class] not in p]
        print(f"CustomDataset: {len(self.im_paths)} retain images found under {root}")

        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        return os.path.dirname(path).split("/")[-1]

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        gt = self.cls_names[self.get_class(im_path)]
        if self.transformations is not None:
            im = self.transformations(im)
        return im, gt


class SingleCelebCustomDataset(Dataset):
    """Forget set: every image is labelled with the single class to be unlearned."""

    def __init__(self, root, forget_class=16, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        print(f"SingleCelebCustomDataset: {len(self.im_paths)} forget images found under {root}")
        self.forget_class = forget_class

        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        return self.forget_class

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        # Every forget-set sample is labelled with the forget class itself.
        gt = self.get_class(im_path)
        if self.transformations is not None:
            im = self.transformations(im)
        return im, gt


def get_dls(root, transformations, bs, forget_class=16, split=[0.9, 0.05, 0.05], ns=4, single=False):
    """Build train/val/test dataloaders for either the forget set (single=True) or the retain set."""
    if single:
        ds = SingleCelebCustomDataset(root=root, forget_class=forget_class, transformations=transformations)
    else:
        ds = CustomDataset(root=root, forget_class=forget_class, transformations=transformations)

    total_len = len(ds)
    tr_len = int(total_len * split[0])
    vl_len = int(total_len * split[1])
    ts_len = total_len - (tr_len + vl_len)
    tr_ds, vl_ds, ts_ds = random_split(
        dataset=ds, lengths=[tr_len, vl_len, ts_len])

    tr_dl = DataLoader(tr_ds, batch_size=bs, shuffle=True, num_workers=ns)
    val_dl = DataLoader(vl_ds, batch_size=bs, shuffle=False, num_workers=ns)
    ts_dl = DataLoader(ts_ds, batch_size=1, shuffle=False, num_workers=ns)
    return tr_dl, val_dl, ts_dl, ds.cls_names


def param_dist(model, swa_model, p):
    # From https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
    dist = 0.
    for p1, p2 in zip(model.parameters(), swa_model.parameters()):
        dist += torch.norm(p1 - p2, p='fro')
    return p * dist


def adjust_learning_rate_new(epoch, optimizer, LUT):
    """New learning-rate schedule following RotNet."""
    lr = next((lr for (max_epoch, lr) in LUT if max_epoch > epoch), LUT[-1][1])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def sgda_adjust_learning_rate(epoch, opt, optimizer):
    """Decay the initial SGDA learning rate by opt.lr_decay_rate at each epoch listed in opt.lr_decay_epochs."""
    steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
    new_lr = opt.sgda_learning_rate
    if steps > 0:
        new_lr = opt.sgda_learning_rate * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr
    return new_lr


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def train_distill(epoch, train_loader, module_list, swa_model, criterion_list, optimizer, opt, split, quiet=False):
    """One epoch of distillation: "maximize" pushes the student away from the teacher on the forget set,
    "minimize" keeps it close to the teacher (plus a classification loss) on the retain set."""
    # Student modules train; the teacher (last entry in module_list) stays in eval mode.
    for module in module_list:
        module.train()
    module_list[-1].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    kd_losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.distill in ['crd']:
            input, target, index, contrast_idx = data
        else:
            input, target = data
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()
            if opt.distill in ['crd']:
                contrast_idx = contrast_idx.cuda()
                index = index.cuda()

        # ===================forward=====================
        logit_s = model_s(input)
        with torch.no_grad():
            logit_t = model_t(input)

        # Classification loss plus KL divergence against the teacher.
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        if split == "minimize":
            loss = opt.gamma * loss_cls + opt.alpha * loss_div
        elif split == "maximize":
            loss = -loss_div

        # Proximity term that keeps the student close to the SWA model.
        loss = loss + param_dist(model_s, swa_model, opt.smoothing)

        if split == "minimize" and not quiet:
            acc1, _ = accuracy(logit_s, target, topk=(1, 1))
            losses.update(val=loss.item(),
                          n=input.size(0))
            top1.update(val=acc1[0], n=input.size(0))
        elif split == "maximize" and not quiet:
            kd_losses.update(val=loss.item(), n=input.size(0))
        elif split == "linear" and not quiet:
            acc1, _ = accuracy(logit_s, target, topk=(1, 1))
            losses.update(val=loss.item(), n=input.size(0))
            top1.update(val=acc1[0], n=input.size(0))
            kd_losses.update(val=loss.item(), n=input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        # nn.utils.clip_grad_value_(model_s.parameters(), opt.clip)
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()

        # Periodic progress logging for the retain ("minimize") phase.
        if not quiet:
            if split == "minimize":
                if idx % opt.print_freq == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                              epoch, idx, len(train_loader), batch_time=batch_time,
                              data_time=data_time, loss=losses, top1=top1))
                    sys.stdout.flush()

    if split == "minimize":
        return top1.avg, losses.avg
    else:
        return kd_losses.avg


class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""

    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # reduction='sum' replaces the deprecated size_average=False.
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss


class Args:
    """Simple attribute container for hyper-parameters."""

    def __init__(self, **entries):
        self.__dict__.update(entries)


def get_forget_class(folder_path, model):
    """Run the trained model on every image in the forget set and return the most frequent prediction."""
    # The forget set is expected to contain a single inner folder of images.
    inner_folder = os.listdir(folder_path)[0]
    folder_path = os.path.join(folder_path, inner_folder)
    image_files = os.listdir(folder_path)
    print(f"get_forget_class: scoring {len(image_files)} files from {folder_path}")

    preds = []
    for filename in image_files:
        # Only score image files.
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            file_path = os.path.join(folder_path, filename)
            image = Image.open(file_path).convert('RGB')
            image_tensor = preprocess(image).unsqueeze(0)  # add batch dimension

            with torch.no_grad():
                output = model(image_tensor)
                probabilities = F.softmax(output, dim=1)
                pred_class = torch.argmax(probabilities, dim=1)
                preds.append(pred_class.item())

    # Majority vote over the per-image predictions.
    freq = Counter(preds)
    forget_class, _ = freq.most_common(1)[0]
    return forget_class


beta = 0.1


def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
    """EMA update for the SWA wrapper: keep (1 - beta) of the average and add beta of the new weights."""
    return (1 - beta) * averaged_model_parameter + beta * model_parameter


def unlearn():
    # Teacher, student and the evaluation copy all start from the original trained classifier.
    model = timm.create_model("rexnet_150", pretrained=True, num_classes=17)
    model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))

    model_eval = copy.deepcopy(model)
    model_eval.eval()

    print("Initializing: detecting the class to forget ...")
    forget_class = get_forget_class('forget_set', model_eval)

    mean, std, im_size = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], 224
    tfs = transforms.Compose([transforms.Resize((im_size, im_size)),
                              transforms.ToTensor(),
                              transforms.Normalize(mean=mean, std=std)])

    # Forget-set loaders (single=True labels every image with forget_class).
    will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(
        root="forget_set", forget_class=forget_class, transformations=tfs, bs=32,
        single=True)
    # Retain-set loaders (the forget class is filtered out).
    celebs_tr_dl, celebs_val_dl, celebs_ts_dl, classes = get_dls(
        root="celeb-dataset", forget_class=forget_class, transformations=tfs, bs=32)

    print("Preparing models, losses and optimizer for unlearning ...")
    args = Args()
    args.optim = 'sgd'
    args.gamma = 0.99
    args.alpha = 0.001
    args.smoothing = 0.0
    args.msteps = 4
    args.clip = 0.2
    args.sstart = 10
    args.kd_T = 4
    args.distill = 'kd'
    args.sgda_batch_size = 64
    args.del_batch_size = 64
    args.sgda_epochs = 6
    args.sgda_learning_rate = 0.005
    args.lr_decay_epochs = [3, 5, 9]
    args.lr_decay_rate = 0.0005
    args.sgda_weight_decay = 5e-4
    args.sgda_momentum = 0.9
    args.print_freq = 100  # batches between progress prints in train_distill (value assumed; missing from the original args)

    model_t = copy.deepcopy(model)  # frozen teacher
    model_s = copy.deepcopy(model)  # student to be unlearned

    swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(args.kd_T)
    criterion_kd = DistillKL(args.kd_T)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    if args.optim == "sgd":
        optimizer = optim.SGD(trainable_list.parameters(),
                              lr=args.sgda_learning_rate,
                              momentum=args.sgda_momentum,
                              weight_decay=args.sgda_weight_decay)

    # Teacher goes last so train_distill can pick it out with module_list[-1].
    module_list.append(model_t)

    # Move the models and losses to the GPU when one is available; train_distill already moves the batches.
    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        swa_model.cuda()

    print("Starting unlearning ...")
    for epoch in tqdm(range(1, args.sgda_epochs + 1)):
        print("\n==============================> epoch:", epoch)
        maximize_loss = 0
        # Gradient-ascent ("maximize") phase on the forget set during the first msteps epochs.
        if epoch <= args.msteps:
            maximize_loss = train_distill(epoch, will_tr_dl, module_list, swa_model,
                                          criterion_list, optimizer, args, "maximize")
        # Gradient-descent ("minimize") phase on the retain set every epoch.
        train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model,
                                              criterion_list, optimizer, args, "minimize")
        if epoch >= args.sstart:
            swa_model.update_parameters(model_s)

    torch.save(model_s.state_dict(), 'celeb-model-unlearned.pth')
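

# A minimal entry point, assuming this file is meant to be run directly as a script;
# it may equally be imported and unlearn() called from an external harness.
if __name__ == "__main__":
    unlearn()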