robustvlm-object-centric / CLIP_eval /clip_robustbench.py
xmutly's picture
Upload 294 files
e1aaaac verified
import json
import os
import sys
import time
import numpy as np
import open_clip
import torch
import torch.nn.functional as F
from torchvision.transforms import Resize
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
import wandb
import argparse
from robustbench import benchmark
from robustbench.data import load_clean_dataset
from autoattack import AutoAttack
from robustbench.model_zoo.enums import BenchmarkDataset
from CLIP_eval.eval_utils import compute_accuracy_no_dataloader, load_clip_model
from train.utils import str2bool
parser = argparse.ArgumentParser(description="Script arguments")
parser.add_argument('--clip_model_name', type=str, default='none', help='ViT-L-14, ViT-B-32, don\'t use if wandb_id is set')
parser.add_argument('--pretrained', type=str, default='none', help='Pretrained model ckpt path, don\'t use if wandb_id is set')
parser.add_argument('--wandb_id', type=str, default='none', help='Wandb id of training run, don\'t use if clip_model_name and pretrained are set')
parser.add_argument('--logit_scale', type=str2bool, default=True, help='Whether to scale logits')
parser.add_argument('--full_benchmark', type=str2bool, default=False, help='Whether to run full RB benchmark')
parser.add_argument('--dataset', type=str, default='imagenet')
parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
parser.add_argument('--cifar10_root', type=str, default='/mnt/datasets/CIFAR10', help='CIFAR10 dataset root directory')
parser.add_argument('--cifar100_root', type=str, default='/mnt/datasets/CIFAR100', help='CIFAR100 dataset root directory')
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--n_samples_imagenet', type=int, default=5000, help='Number of samples from ImageNet for benchmark')
parser.add_argument('--n_samples_cifar', type=int, default=1000, help='Number of samples from CIFAR for benchmark')
parser.add_argument('--template', type=str, default='ensemble', help='Text template type; std, ensemble')
parser.add_argument('--norm', type=str, default='linf', help='Norm for attacks; linf, l2')
parser.add_argument('--eps', type=float, default=4., help='Epsilon for attack')
parser.add_argument('--beta', type=float, default=0., help='Model interpolation parameter')
parser.add_argument('--alpha', type=float, default=2., help='APGD alpha parameter')
parser.add_argument('--experiment_name', type=str, default='', help='Experiment name for logging')
parser.add_argument('--blackbox_only', type=str2bool, default=False, help='Run blackbox attacks only')
parser.add_argument('--save_images', type=str2bool, default=False, help='Save images during benchmarking')
parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
parser.add_argument('--devices', type=str, default='', help='Device IDs for CUDA')
CIFAR10_LABELS = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class ClassificationModel(torch.nn.Module):
def __init__(self, model, text_embedding, args, input_normalize, resizer=None, logit_scale=True):
super().__init__()
self.model = model
self.args = args
self.input_normalize = input_normalize
self.resizer = resizer if resizer is not None else lambda x: x
self.text_embedding = text_embedding
self.logit_scale = logit_scale
def forward(self, vision, output_normalize=True):
assert output_normalize
embedding_norm_ = self.model.encode_image(
self.input_normalize(self.resizer(vision)),
normalize=True
)
logits = embedding_norm_ @ self.text_embedding
if self.logit_scale:
logits *= self.model.logit_scale.exp()
return logits
def interpolate_state_dict(m1, beta=0.2):
m = {}
m2 = torch.load("/path/to/ckpt.pt", map_location='cpu')
for k in m1.keys():
# print(m1[k].shape, m2[k].shape)
m[k] = beta * m1[k] + (1 - beta) * m2[k]
return m
if __name__ == '__main__':
# set seeds
torch.manual_seed(0)
np.random.seed(0)
# Parse command-line arguments
args = parser.parse_args()
# print args
print(f"Arguments:\n{'-' * 20}", flush=True)
for arg, value in vars(args).items():
print(f"{arg}: {value}")
print(f"{'-' * 20}")
args.eps /= 255
# make sure there is no string in args that should be a bool
assert not any(
[isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values(
)])
if args.dataset == 'imagenet':
num_classes = 1000
data_dir = args.imagenet_root
n_samples = args.n_samples_imagenet
resizer = None
elif args.dataset == 'cifar100':
num_classes = 100
data_dir = args.cifar100_root
n_samples = args.n_samples_cifar
resizer = Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=False)
elif args.dataset == 'cifar10':
num_classes = 10
data_dir = args.cifar10_root
n_samples = args.n_samples_cifar
resizer = Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC, max_size=None, antialias=False)
eps = args.eps
# init wandb
os.environ['WANDB__SERVICE_WAIT'] = '300'
wandb_user, wandb_project = None, None
while True:
try:
run_eval = wandb.init(
project=wandb_project,
job_type='eval',
name=f'{"rb" if args.full_benchmark else "aa"}-clip-{args.dataset}-{args.norm}-{eps:.2f}'
f'-{args.wandb_id if args.wandb_id is not None else args.pretrained}-{args.blackbox_only}-{args.beta}',
save_code=True,
config=vars(args),
mode='online' if args.wandb else 'disabled'
)
break
except wandb.errors.CommError as e:
print('wandb connection error', file=sys.stderr)
print(f'error: {e}', file=sys.stderr)
time.sleep(1)
print('retrying..', file=sys.stderr)
if args.devices != '':
# set cuda visible devices
os.environ["CUDA_VISIBLE_DEVICES"] = args.devices
main_device = 0
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
print(f"Number of GPUs available: {num_gpus}")
else:
print("No multiple GPUs available.")
if not args.blackbox_only:
attacks_to_run = ['apgd-ce', 'apgd-t']
else:
attacks_to_run = ['square']
print(f'[attacks_to_run] {attacks_to_run}')
if args.wandb_id not in [None, 'none', 'None']:
assert args.pretrained in [None, 'none', 'None']
assert args.clip_model_name in [None, 'none', 'None']
api = wandb.Api()
run_train = api.run(f'{wandb_user}/{wandb_project}/{args.wandb_id}')
clip_model_name = run_train.config['clip_model_name']
print(f'clip_model_name: {clip_model_name}')
pretrained = run_train.config["output_dir"]
if pretrained.endswith('_temp'):
pretrained = pretrained[:-5]
pretrained += "/checkpoints/final.pt"
else:
clip_model_name = args.clip_model_name
pretrained = args.pretrained
run_train = None
del args.clip_model_name, args.pretrained
print(f'[loading pretrained clip] {clip_model_name} {pretrained}')
model, preprocessor_without_normalize, normalize = load_clip_model(clip_model_name, pretrained, args.beta)
if args.dataset != 'imagenet':
# make sure we don't resize outside the model as this influences threat model
preprocessor_without_normalize = transforms.ToTensor()
print(f'[resizer] {resizer}')
print(f'[preprocessor] {preprocessor_without_normalize}')
model.eval()
model.to(main_device)
with torch.no_grad():
# Get text label embeddings of all ImageNet classes
if not args.template == 'ensemble':
if args.template == 'std':
template = 'This is a photo of a {}'
else:
raise ValueError(f'Unknown template: {args.template}')
print(f'template: {template}')
if args.dataset == 'imagenet':
texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
elif args.dataset == 'cifar10':
texts = [template.format(c) for c in CIFAR10_LABELS]
text_tokens = open_clip.tokenize(texts)
embedding_text_labels_norm = []
text_batches = [text_tokens[:500], text_tokens[500:]] if args.dataset == 'imagenet' else [text_tokens]
for el in text_batches:
# we need to split the text tokens into two batches because otherwise we run out of memory
# note that we are accessing the model directly here, not the CustomModel wrapper
# thus its always normalizing the text embeddings
embedding_text_labels_norm.append(
model.encode_text(el.to(main_device), normalize=True).detach().cpu()
)
model.cpu()
embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
else:
assert args.dataset == 'imagenet', 'ensemble only implemented for imagenet'
with open('CLIP_eval/zeroshot-templates.json', 'r') as f:
templates = json.load(f)
templates = templates['imagenet1k']
print(f'[templates] {templates}')
embedding_text_labels_norm = []
for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values():
texts = [template.format(c=c) for template in templates]
text_tokens = open_clip.tokenize(texts).to(main_device)
class_embeddings = model.encode_text(text_tokens)
class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
class_embedding /= class_embedding.norm()
embedding_text_labels_norm.append(class_embedding)
embedding_text_labels_norm = torch.stack(embedding_text_labels_norm, dim=1).to(main_device)
assert torch.allclose(
F.normalize(embedding_text_labels_norm, dim=0),
embedding_text_labels_norm
)
if clip_model_name == 'ViT-B-32':
assert embedding_text_labels_norm.shape == (512, num_classes), embedding_text_labels_norm.shape
elif clip_model_name == 'ViT-L-14':
assert embedding_text_labels_norm.shape == (768, num_classes), embedding_text_labels_norm.shape
else:
raise ValueError(f'Unknown model: {clip_model_name}')
# get model
model = ClassificationModel(
model=model,
text_embedding=embedding_text_labels_norm,
args=args,
resizer=resizer,
input_normalize=normalize,
logit_scale=args.logit_scale,
)
if num_gpus > 1:
model = torch.nn.DataParallel(model)
model = model.cuda()
model.eval()
model_name = None
# device = [torch.device(el) for el in range(num_gpus)] # currently only single gpu supported
device = torch.device(main_device)
torch.cuda.empty_cache()
dataset_short = (
'img' if args.dataset == 'imagenet' else
'c10' if args.dataset == 'cifar10' else
'c100' if args.dataset == 'cifar100' else
'unknown'
)
start = time.time()
if args.full_benchmark:
clean_acc, robust_acc = benchmark(
model, model_name=model_name, n_examples=n_samples,
batch_size=args.batch_size,
dataset=args.dataset, data_dir=data_dir,
threat_model=args.norm.replace('l', 'L'), eps=eps,
preprocessing=preprocessor_without_normalize,
device=device, to_disk=False
)
clean_acc *= 100
robust_acc *= 100
duration = time.time() - start
print(f"[Model] {pretrained}")
print(
f"[Clean Acc] {clean_acc:.2f}% [Robust Acc] {robust_acc:.2f}% [Duration] {duration / 60:.2f}m"
)
if run_train is not None:
# reload the run to make sure we have the latest summary
del api, run_train
api = wandb.Api()
run_train = api.run(f'{wandb_user}/{wandb_project}/{args.wandb_id}')
eps_descr = str(int(eps * 255)) if args.norm == 'linf' else str(eps)
run_train.summary.update({f'rb/acc-{dataset_short}': clean_acc})
run_train.summary.update({f'rb/racc-{dataset_short}-{args.norm}-{eps_descr}': robust_acc})
run_train.update()
else:
adversary = AutoAttack(
model, norm=args.norm.replace('l', 'L'), eps=eps, version='custom', attacks_to_run=attacks_to_run,
alpha=args.alpha, verbose=True
)
x_test, y_test = load_clean_dataset(
BenchmarkDataset(args.dataset), n_examples=n_samples, data_dir=data_dir,
prepr=preprocessor_without_normalize,)
acc = compute_accuracy_no_dataloader(model, data=x_test, targets=y_test, device=device, batch_size=args.batch_size) * 100
print(f'[acc] {acc:.2f}%', flush=True)
x_adv, y_adv = adversary.run_standard_evaluation(x_test, y_test, bs=args.batch_size, return_labels=True) # y_adv are preds on x_adv
racc = compute_accuracy_no_dataloader(model, data=x_adv, targets=y_test, device=device, batch_size=args.batch_size) * 100
print(f'[acc] {acc:.2f}% [racc] {racc:.2f}%')
# save adv images
if args.save_images:
# save the adversarial images
img_save_path = (f'/path/to/save/dir/'
f'{args.dataset}/{args.wandb_id}-{args.pretrained}-{args.norm}-{eps:.3f}-'
f'alph{args.alpha:.3f}-{n_samples}smpls-{time.strftime("%Y-%m-%d_%H-%M-%S")}')
os.makedirs(img_save_path, exist_ok=True)
print(f'[saving images to] {img_save_path}')
x_adv = x_adv.detach().cpu()
y_adv = y_adv.detach().cpu()
x_clean = x_test.detach().cpu()
y_clean = y_test.detach().cpu()
torch.save(x_adv, f'{img_save_path}/x_adv.pt')
torch.save(y_adv, f'{img_save_path}/y_adv.pt')
torch.save(x_clean, f'{img_save_path}/x_clean.pt')
torch.save(y_clean, f'{img_save_path}/y_clean.pt')
with open(f'{img_save_path}/args.json', 'w') as f:
json.dump(vars(args), f)
with open(f'{img_save_path}/results.json', 'w') as f:
f.write(f"acc:{acc:.2f}%")
f.write(f"Racc:{racc:.2f}%")
# write to wandb
if run_train is not None:
# reload the run to make sure we have the latest summary
del api, run_train
api = wandb.Api()
run_train = api.run(f'{wandb_user}/{wandb_project}/{args.wandb_id}')
if args.dataset == 'imagenet':
assert args.norm == 'linf'
eps_descr = str(int(eps * 255))
if eps_descr == '4':
descr = dataset_short
else:
descr = f'{dataset_short}-eps{eps_descr}'
if n_samples != 5000:
acc = f'{acc:.2f}*'
racc = f'{racc:.2f}*'
elif args.dataset == 'cifar10':
if args.norm == 'linf':
descr = dataset_short
else:
descr = f'{dataset_short}-{args.norm}'
if n_samples != 10000:
acc = f'{acc:.2f}*'
racc = f'{racc:.2f}*'
else:
raise ValueError(f'Unknown dataset: {args.dataset}')
run_train.summary.update({f'aa/acc-{dataset_short}': acc})
run_train.summary.update({f'aa/racc-{descr}': racc})
run_train.summary.update()
run_train.update()
run_eval.finish()