import os
import sys
import time
import copy
import shutil
from glob import glob
from collections import Counter

import numpy as np
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
from torch.utils.data import random_split, Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

torch.manual_seed(2024)
reversed_map = {
    0: 'Angelina Jolie',
    1: 'Brad Pitt',
    2: 'Denzel Washington',
    3: 'Hugh Jackman',
    4: 'Jennifer Lawrence',
    5: 'Johnny Depp',
    6: 'Kate Winslet',
    7: 'Leonardo DiCaprio',
    8: 'Megan Fox',
    9: 'Natalie Portman',
    10: 'Nicole Kidman',
    11: 'Robert Downey Jr',
    12: 'Sandra Bullock',
    13: 'Scarlett Johansson',
    14: 'Tom Cruise',
    15: 'Tom Hanks',
    16: 'Will Smith'
}

preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
class CustomDataset(Dataset):
    """Retain-set dataset: every class except the one being forgotten."""
    def __init__(self, root, forget_class=16, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        print(f"CustomDataset: found {len(self.im_paths)} images under {root}")
        # Drop every image belonging to the class we want to unlearn.
        self.im_paths = [p for p in self.im_paths if reversed_map[forget_class] not in p]
        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        # The class name is the parent directory of the image file.
        return os.path.basename(os.path.dirname(path))

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        gt = self.cls_names[self.get_class(im_path)]
        if self.transformations is not None:
            im = self.transformations(im)
        return im, gt
class SingleCelebCustomDataset(Dataset):
    """Forget-set dataset: every image is labelled with the class being forgotten."""
    def __init__(self, root, forget_class=16, transformations=None):
        self.transformations = transformations
        self.im_paths = sorted(glob(f"{root}/*/*"))
        print(f"SingleCelebCustomDataset: found {len(self.im_paths)} images under {root}")
        self.forget_class = forget_class
        self.cls_names, self.cls_counts, count = {}, {}, 0
        for im_path in self.im_paths:
            class_name = self.get_class(im_path)
            if class_name not in self.cls_names:
                self.cls_names[class_name] = count
                self.cls_counts[class_name] = 1
                count += 1
            else:
                self.cls_counts[class_name] += 1

    def get_class(self, path):
        # All images in this dataset belong to the class being forgotten.
        return self.forget_class

    def __len__(self):
        return len(self.im_paths)

    def __getitem__(self, idx):
        im_path = self.im_paths[idx]
        im = Image.open(im_path).convert("RGB")
        # Label every sample with the forget-class index itself.
        gt = self.get_class(im_path)
        if self.transformations is not None:
            im = self.transformations(im)
        return im, gt
def get_dls(root, transformations, bs, forget_class=16, split=[0.9, 0.05, 0.05], ns=4, single=False):
    if single:
        ds = SingleCelebCustomDataset(root=root, forget_class=forget_class, transformations=transformations)
    else:
        ds = CustomDataset(root=root, forget_class=forget_class, transformations=transformations)
    total_len = len(ds)
    tr_len = int(total_len * split[0])
    vl_len = int(total_len * split[1])
    ts_len = total_len - (tr_len + vl_len)
    tr_ds, vl_ds, ts_ds = random_split(dataset=ds, lengths=[tr_len, vl_len, ts_len])
    tr_dl = DataLoader(tr_ds, batch_size=bs, shuffle=True, num_workers=ns)
    val_dl = DataLoader(vl_ds, batch_size=bs, shuffle=False, num_workers=ns)
    ts_dl = DataLoader(ts_ds, batch_size=1, shuffle=False, num_workers=ns)
    return tr_dl, val_dl, ts_dl, ds.cls_names
def param_dist(model, swa_model, p):
    # This is from https://github.com/ojus1/SmoothedGradientDescentAscent/blob/main/SGDA.py
    # Frobenius-norm distance between the current model and its SWA average,
    # scaled by the smoothing coefficient p.
    dist = 0.
    for p1, p2 in zip(model.parameters(), swa_model.parameters()):
        dist += torch.norm(p1 - p2, p='fro')
    return p * dist
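
# Illustration of param_dist (assumption: this helper is for demonstration only and
# is not called anywhere in the pipeline). For two identical models the penalty is
# zero; perturbing one model's weights makes it strictly positive.
def _param_dist_sketch():
    a = nn.Linear(4, 2)
    b = copy.deepcopy(a)
    zero_penalty = param_dist(a, b, p=0.01)      # ~0: parameters coincide
    with torch.no_grad():
        for q in b.parameters():
            q.add_(1.0)                          # perturb the copy
    positive_penalty = param_dist(a, b, p=0.01)  # > 0: parameters now differ
    return zero_penalty, positive_penalty
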
def adjust_learning_rate_new(epoch, optimizer, LUT):
    """New learning rate schedule according to RotNet."""
    lr = next((lr for (max_epoch, lr) in LUT if max_epoch > epoch), LUT[-1][1])
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
def sgda_adjust_learning_rate(epoch, opt, optimizer):
    """Sets the learning rate to the initial LR decayed by lr_decay_rate at each decay epoch."""
    steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
    new_lr = opt.sgda_learning_rate
    if steps > 0:
        new_lr = opt.sgda_learning_rate * (opt.lr_decay_rate ** steps)
    for param_group in optimizer.param_groups:
        param_group['lr'] = new_lr
    return new_lr
class AverageMeter(object):
    """Computes and stores the average and current value."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
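
# Small sanity check for accuracy() (assumption: demonstration helper only, never
# called by the pipeline). With logits that clearly favour the true class, top-1
# accuracy should come out at 100%.
def _accuracy_sketch():
    logits = torch.tensor([[5.0, 0.0, 0.0],
                           [0.0, 5.0, 0.0]])
    targets = torch.tensor([0, 1])
    (top1,) = accuracy(logits, targets, topk=(1,))
    return top1.item()  # expected: 100.0
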
def train_distill(epoch, train_loader, module_list, swa_model, criterion_list, optimizer, opt, split, quiet=False):
    """One epoch of distillation (minimize on the retain set, maximize on the forget set)."""
    # set modules as train()
    for module in module_list:
        module.train()
    # set teacher as eval()
    module_list[-1].eval()

    criterion_cls = criterion_list[0]
    criterion_div = criterion_list[1]
    criterion_kd = criterion_list[2]

    model_s = module_list[0]
    model_t = module_list[-1]

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    kd_losses = AverageMeter()
    top1 = AverageMeter()

    end = time.time()
    for idx, data in enumerate(train_loader):
        if opt.distill in ['crd']:
            input, target, index, contrast_idx = data
        else:
            input, target = data
        data_time.update(time.time() - end)

        input = input.float()
        if torch.cuda.is_available():
            input = input.cuda()
            target = target.cuda()
            if opt.distill in ['crd']:
                contrast_idx = contrast_idx.cuda()
                index = index.cuda()

        # ===================forward=====================
        #feat_s, logit_s = model_s(input, is_feat=True, preact=False)
        logit_s = model_s(input)
        with torch.no_grad():
            #feat_t, logit_t = model_t(input, is_feat=True, preact=preact)
            #feat_t = [f.detach() for f in feat_t]
            logit_t = model_t(input)

        # cls + kl div
        loss_cls = criterion_cls(logit_s, target)
        loss_div = criterion_div(logit_s, logit_t)

        if split == "minimize":
            # Retain step: stay accurate and keep matching the teacher.
            loss = opt.gamma * loss_cls + opt.alpha * loss_div
        elif split == "maximize":
            # Forget step: push the student away from the teacher.
            loss = -loss_div

        loss = loss + param_dist(model_s, swa_model, opt.smoothing)

        if split == "minimize" and not quiet:
            acc1, _ = accuracy(logit_s, target, topk=(1, 1))
            losses.update(val=loss.item(), n=input.size(0))
            top1.update(val=acc1[0], n=input.size(0))
        elif split == "maximize" and not quiet:
            kd_losses.update(val=loss.item(), n=input.size(0))
        elif split == "linear" and not quiet:
            acc1, _ = accuracy(logit_s, target, topk=(1, 1))
            losses.update(val=loss.item(), n=input.size(0))
            top1.update(val=acc1[0], n=input.size(0))
            kd_losses.update(val=loss.item(), n=input.size(0))

        # ===================backward=====================
        optimizer.zero_grad()
        loss.backward()
        #nn.utils.clip_grad_value_(model_s.parameters(), clip)
        optimizer.step()

        # ===================meters=====================
        batch_time.update(time.time() - end)
        end = time.time()
        if not quiet:
            if split == "minimize":
                # Periodic progress logging; a default interval is assumed when
                # opt.print_freq is not configured.
                if idx % getattr(opt, 'print_freq', 50) == 0:
                    print('Epoch: [{0}][{1}/{2}]\t'
                          'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                          'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                          'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                          'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                              epoch, idx, len(train_loader), batch_time=batch_time,
                              data_time=data_time, loss=losses, top1=top1))
                    sys.stdout.flush()

    if split == "minimize":
        #if not quiet:
        #    print(' * Acc@1 {top1.avg:.3f} '.format(top1=top1))
        return top1.avg, losses.avg
    else:
        return kd_losses.avg
class DistillKL(nn.Module):
    """Distilling the Knowledge in a Neural Network"""
    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        # reduction='sum' is the non-deprecated equivalent of size_average=False;
        # the sum is then averaged over the batch and rescaled by T^2 as usual for KD.
        loss = F.kl_div(p_s, p_t, reduction='sum') * (self.T ** 2) / y_s.shape[0]
        return loss
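
# Minimal check of DistillKL (assumption: demonstration helper only, not part of the
# original pipeline). The KD loss between identical student and teacher logits is
# (numerically) zero and grows as the two distributions diverge.
def _distill_kl_sketch():
    criterion = DistillKL(T=4)
    logits = torch.randn(8, 17)
    same = criterion(logits, logits).item()                   # ~0
    different = criterion(logits, torch.randn(8, 17)).item()  # > 0
    return same, different
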
class Args:
    """Simple attribute container for hyperparameters."""
    def __init__(self, **entries):
        self.__dict__.update(entries)
# Predict the identity to forget by classifying each image in the forget-set folder.
def get_forget_class(folder_path, model):
    # The forget set is expected to contain a single inner folder of images.
    inner_folder = os.listdir(folder_path)[0]
    folder_path = os.path.join(folder_path, inner_folder)
    image_files = os.listdir(folder_path)
    print(f"get_forget_class: found {len(image_files)} files in {folder_path}")
    preds = []
    # Process each image in the folder.
    for filename in image_files:
        # Check if the file is an image.
        if filename.endswith(('.png', '.jpg', '.jpeg')):
            file_path = os.path.join(folder_path, filename)
            image = Image.open(file_path).convert('RGB')
            # Apply preprocessing and add a batch dimension.
            image_tensor = preprocess(image).unsqueeze(0)
            # Perform inference.
            with torch.no_grad():
                output = model(image_tensor)
                probabilities = F.softmax(output, dim=1)
                pred_class = torch.argmax(probabilities, dim=1)
                preds.append(pred_class.item())
    # The forget class is the most frequently predicted label.
    freq = Counter(preds)
    forget_class, _ = freq.most_common(1)[0]
    return forget_class
beta = 0.1

def avg_fn(averaged_model_parameter, model_parameter, num_averaged):
    # Exponential moving average used by AveragedModel: keep (1 - beta) of the
    # running average and mix in beta of the current student parameters.
    return (1 - beta) * averaged_model_parameter + beta * model_parameter
def unlearn():
    # Load the original (pre-unlearning) classifier.
    model = timm.create_model("rexnet_150", pretrained=True, num_classes=17)
    model.load_state_dict(torch.load('faces_best_model.pth', map_location=torch.device('cpu')))
    model_eval = copy.deepcopy(model)
    model_eval.eval()

    print("Identifying the class to forget...")
    forget_class = get_forget_class('forget_set', model_eval)

    mean, std, im_size = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225], 224
    tfs = transforms.Compose([
        transforms.Resize((im_size, im_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    # Forget-set loaders (single class) and retain-set loaders (all other classes).
    will_tr_dl, will_val_dl, will_ts_dl, classes = get_dls(root="forget_set", forget_class=forget_class, transformations=tfs, bs=32, single=True)
    celebs_tr_dl, celebs_val_dl, celebs_ts_dl, classes = get_dls(root="celeb-dataset", forget_class=forget_class, transformations=tfs, bs=32)

    print("Preparing for unlearning...")
    args = Args()
    args.optim = 'sgd'
    args.gamma = 0.99
    args.alpha = 0.001
    args.smoothing = 0.0
    args.msteps = 4
    args.clip = 0.2
    args.sstart = 10
    args.kd_T = 4
    args.distill = 'kd'
    args.sgda_batch_size = 64
    args.del_batch_size = 64
    args.sgda_epochs = 6
    args.sgda_learning_rate = 0.005
    args.lr_decay_epochs = [3, 5, 9]
    args.lr_decay_rate = 0.0005
    args.sgda_weight_decay = 5e-4
    args.sgda_momentum = 0.9

    # The teacher is a copy of the original model; the student is the copy being unlearned.
    model_t = copy.deepcopy(model)
    model_s = copy.deepcopy(model)
    swa_model = torch.optim.swa_utils.AveragedModel(model_s, avg_fn=avg_fn)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(args.kd_T)
    criterion_kd = DistillKL(args.kd_T)
    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    if args.optim == "sgd":
        optimizer = optim.SGD(trainable_list.parameters(),
                              lr=args.sgda_learning_rate,
                              momentum=args.sgda_momentum,
                              weight_decay=args.sgda_weight_decay)

    module_list.append(model_t)

    print("Starting unlearning...")
    for epoch in tqdm(range(1, args.sgda_epochs + 1)):
        print(f"\n==> epoch {epoch}")
        maximize_loss = 0
        if epoch <= args.msteps:
            # Maximize step: push the student away from the teacher on the forget set.
            maximize_loss = train_distill(epoch, will_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "maximize")
        # Minimize step: recover accuracy on the retain set while matching the teacher.
        train_acc, train_loss = train_distill(epoch, celebs_tr_dl, module_list, swa_model, criterion_list, optimizer, args, "minimize")
        if epoch >= args.sstart:
            swa_model.update_parameters(model_s)

    torch.save(model_s.state_dict(), 'celeb-model-unlearned.pth')
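
# Entry point (assumption: this script is meant to be executed directly rather than
# imported; the guard below simply triggers the unlearning run).
if __name__ == "__main__":
    unlearn()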