from typing import Any
import os

import torch
import ignite.distributed as idist
from torch.utils import data as torch_data
from torch.utils.data import Subset
from tqdm import tqdm

from modelguidedattacks.data import get_dataset
from modelguidedattacks.cls_models.accuracy import (
    get_correct_subset_for_models,
    DATASET_METADATA_DIR,
)

from .classification_wrapper import TopKClassificationWrapper

def get_gt_labels(dataset: torch_data.Dataset, train: bool, dataset_name: str):
    """Collect the ground-truth label of every sample in `dataset`, caching
    the result on disk so later runs skip the full pass over the data."""
    training_str = "train" if train else "val"
    save_name = os.path.join(DATASET_METADATA_DIR, f"{dataset_name}_labels_{training_str}.p")

    if os.path.exists(save_name):
        print("Found labels cache")
        return torch.load(save_name)

    dataloader = torch_data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

    gt_labels = []
    for batch in tqdm(dataloader):
        gt_labels.extend(batch[1].tolist())

    gt_labels = torch.tensor(gt_labels)
    torch.save(gt_labels, save_name)

    return gt_labels
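
# Example usage (a sketch: "cifar10" is a hypothetical dataset name; the
# unpacking mirrors how get_dataset is used in setup_data below):
#
#     train_ds, _ = get_dataset("cifar10")
#     labels = get_gt_labels(train_ds, train=True, dataset_name="cifar10")
#     assert labels.ndim == 1 and len(labels) == len(train_ds)
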
def class_balanced_sampling(dataset, gt_labels: torch.Tensor,
                            correct_labels: list, total_samples=1000):
    """Sample up to `total_samples` dataset indices, restricted to
    `correct_labels` (indices every compared model classifies correctly),
    keeping the per-class counts as balanced as possible."""
    num_classes = len(dataset.classes)

    correct_labels = torch.tensor(correct_labels)
    correct_mask = torch.zeros((len(dataset),), dtype=torch.bool)
    correct_mask[correct_labels] = True

    total_sampled_indices = 0
    sampled_indices = [[] for _ in range(num_classes)]
    shuffled_inds = torch.randperm(len(dataset))

    for sample_i in shuffled_inds:
        if not correct_mask[sample_i]:
            continue

        sample_class = gt_labels[sample_i]

        # Let each class grow to at most one sample beyond the current
        # mean count, so no class races ahead of the others
        desired_samples_in_class = (total_sampled_indices // num_classes) + 1
        if len(sampled_indices[sample_class]) < desired_samples_in_class:
            sampled_indices[sample_class].append(sample_i.item())
            total_sampled_indices += 1

        if total_sampled_indices >= total_samples:
            break

    flattened_indices = []
    for class_samples in sampled_indices:
        flattened_indices.extend(class_samples)

    return torch.tensor(flattened_indices)
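
# Minimal sketch of the balancing behaviour. The Toy dataset below is
# hypothetical; any object exposing `classes` and `__len__` works:
#
#     class Toy:
#         classes = [0, 1, 2]
#         def __len__(self): return 6
#
#     gt = torch.tensor([0, 0, 0, 1, 1, 2])    # label of each sample
#     correct = [0, 1, 2, 3, 5]                # indices all models got right
#     idx = class_balanced_sampling(Toy(), gt, correct, total_samples=3)
#     # -> one index per class; which class-0 index is kept depends on randperm
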
def sample_attack_labels(dataset, gt_labels, k, sampler):
    """
    dataset: Dataset we're generating attack labels for
    gt_labels: Tensor of ground-truth class indices, one per sample
    k: attack size (number of target classes per sample)
    sampler: ["random"] (only random sampling is implemented; the argument
             is kept for interface compatibility)
    """
    # Sample from a uniform distribution and argsort to simulate a
    # batched randperm over class indices
    attack_label_uniforms = torch.rand((len(gt_labels), len(dataset.classes)))

    # We don't want to sample the gt class for any sample, so push its
    # score below every uniform draw before sorting
    batch_inds = torch.arange(len(gt_labels))
    attack_label_uniforms[batch_inds, gt_labels] = -1.

    attack_labels = attack_label_uniforms.argsort(dim=-1, descending=True)[:, :k]

    return attack_labels
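
# The batched-randperm trick in isolation (illustrative values only):
#
#     u = torch.rand((4, 10))             # 4 samples, 10 classes
#     gt = torch.tensor([0, 3, 7, 9])
#     u[torch.arange(4), gt] = -1.        # gt class now sorts last
#     targets = u.argsort(dim=-1, descending=True)[:, :3]
#     assert not (targets == gt[:, None]).any()
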
def setup_data(config: Any, rank):
    """Download datasets and create train/eval dataloaders.

    Parameters
    ----------
    config: needs to contain `dataset`, `k`, `attack_sampling`,
        `compare_models`, `overfit`, `train_batch_size`,
        `eval_batch_size`, and `num_workers`
    """
    dataset_train, dataset_eval = get_dataset(config.dataset)

    train_subset = None
    val_subset = None
    attack_labels_train = None
    attack_labels_val = None

    # Subset selection and attack-label sampling are stochastic, so only
    # rank 0 computes them; the results are broadcast to the other ranks
    if rank == 0:
        gt_labels_train = get_gt_labels(dataset_train, True, config.dataset)
        gt_labels_val = get_gt_labels(dataset_eval, False, config.dataset)

        attack_labels_train = sample_attack_labels(dataset_train, gt_labels_train, k=config.k,
                                                   sampler=config.attack_sampling)
        attack_labels_val = sample_attack_labels(dataset_eval, gt_labels_val, k=config.k,
                                                 sampler=config.attack_sampling)

        device = "cuda" if torch.cuda.is_available() else "cpu"

        correct_train_set = get_correct_subset_for_models(config.compare_models,
                                                          config.dataset, device,
                                                          train=True)
        correct_eval_set = get_correct_subset_for_models(config.compare_models,
                                                         config.dataset, device,
                                                         train=False)

        # Balanced sampling over the samples every compared model gets right
        train_subset = class_balanced_sampling(dataset_train, gt_labels_train,
                                               correct_train_set)
        val_subset = class_balanced_sampling(dataset_eval, gt_labels_val,
                                             correct_eval_set)

        if config.overfit:
            # Shrink both splits to the same 16 samples; both subsets are
            # drawn with the same default size, so shared indices are valid
            rand_inds = torch.randperm(len(val_subset))[:16]
            train_subset = train_subset[rand_inds]
            val_subset = val_subset[rand_inds]

    train_subset = idist.broadcast(train_subset, safe_mode=True)
    val_subset = idist.broadcast(val_subset, safe_mode=True)
    attack_labels_train = idist.broadcast(attack_labels_train, safe_mode=True)
    attack_labels_val = idist.broadcast(attack_labels_val, safe_mode=True)

    dataset_train = TopKClassificationWrapper(dataset_train, k=config.k,
                                              attack_labels=attack_labels_train)
    dataset_eval = TopKClassificationWrapper(dataset_eval, k=config.k,
                                             attack_labels=attack_labels_val)

    dataset_train = Subset(dataset_train, train_subset)
    dataset_eval = Subset(dataset_eval, val_subset)

    # auto_dataloader wires up a distributed-aware sampler and scales the
    # batch size / workers for the current idist backend
    dataloader_train = idist.auto_dataloader(
        dataset_train,
        batch_size=config.train_batch_size,
        shuffle=not config.overfit,
        num_workers=config.num_workers,
    )
    dataloader_eval = idist.auto_dataloader(
        dataset_eval,
        batch_size=config.eval_batch_size,
        shuffle=True,
        num_workers=config.num_workers,
    )

    return dataloader_train, dataloader_eval
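
# Sketch of a call site; the SimpleNamespace fields mirror exactly what
# setup_data reads, but the concrete values here are hypothetical:
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(dataset="cifar10", k=5, attack_sampling="random",
#                              compare_models=["resnet18"], overfit=False,
#                              train_batch_size=64, eval_batch_size=128,
#                              num_workers=4)
#     train_loader, eval_loader = setup_data(config, rank=idist.get_rank())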