"""Gradio app for collecting adversarial MNIST digits and fine-tuning on them.

Users draw a digit, see the model's prediction, and flag wrong predictions
with the correct label.  Flagged samples are uploaded to a Hugging Face
dataset repo; every ``TRAIN_CUTOFF`` samples the model is fine-tuned on all
adversarial samples and re-evaluated on the corrupted-MNIST test set in
``./mnist_c``.
"""
# NOTE: import order is kept as in the original on purpose -- ``utils`` is
# star-imported, so moving it would change which names shadow which.
import functools
import os
import torch
import gradio as gr
import torchvision
from PIL import Image
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from huggingface_hub import Repository, upload_file
from torch.utils.data import Dataset
import numpy as np
from collections import Counter

n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10
random_seed = 1
TRAIN_CUTOFF = 5
WHAT_TO_DO = WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
METRIC_PATH = './metrics.json'
MODEL_STATE_PATH = 'model.pth'
OPTIMIZER_STATE_PATH = 'optimizer.pth'
MNIST_C_DIR = "./mnist_c"
REPOSITORY_DIR = "data"
LOCAL_DIR = 'data_local'
os.makedirs(LOCAL_DIR, exist_ok=True)

HF_TOKEN = os.getenv("HF_TOKEN")
HF_DATASET = "mnist-adversarial-dataset"
HF_REPO_ID = f"chrisjay/{HF_DATASET}"
DATASET_REPO_URL = f"https://huggingface.co/datasets/{HF_REPO_ID}"
# Local clone of the dataset repo; each sample lives in its own sub-folder.
ADVERSARIAL_DATA_DIR = "./data_mnist/data"

repo = Repository(
    local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()

torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)


class MNISTAdversarial_Dataset(Dataset):
    """User-flagged adversarial digits from the cloned dataset repo.

    Reads ``<data_dir>/data/<sample>/image.png`` and
    ``<data_dir>/data/<sample>/metadata.jsonl``; the label is the first
    record's ``correct_number``.  Sample folders missing either file are
    skipped.
    """

    def __init__(self, data_dir, transform):
        repo.git_pull()
        self.data_dir = os.path.join(data_dir, 'data')
        self.transform = transform
        self.images = []
        self.numbers = []
        for name in [f.name for f in os.scandir(self.data_dir)]:
            folder = os.path.join(self.data_dir, name)
            metadata_path = os.path.join(folder, 'metadata.jsonl')
            image_path = os.path.join(folder, 'image.png')
            if not (os.path.exists(image_path) and os.path.exists(metadata_path)):
                continue
            # Copy the pixels into memory so the file is closed right away
            # instead of keeping one open descriptor per sample.
            with Image.open(image_path) as img:
                self.images.append(img.copy())
            metadata = read_json_lines(metadata_path)
            self.numbers.append(metadata[0]['correct_number'])
        assert len(self.images) == len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.numbers[idx]
        img = self.transform(img)
        return img, label


@functools.lru_cache(maxsize=None)
def _load_mnist_c(corrupted_dir):
    """Load and concatenate ``test_images.npy``/``test_labels.npy`` of every
    sub-folder of *corrupted_dir*.

    Cached because the test loader and every per-digit evaluation read the
    same files; callers share the returned arrays and must not mutate them.
    """
    folders = [os.path.join(corrupted_dir, f.name) for f in os.scandir(corrupted_dir)]
    data = np.vstack([np.load(os.path.join(d, 'test_images.npy')) for d in folders])
    labels = np.hstack([np.load(os.path.join(d, 'test_labels.npy')) for d in folders])
    assert data.shape[0] == labels.shape[0]
    return data, labels


class MNISTCorrupted_By_Digit(Dataset):
    """At most *limit* corrupted-MNIST test samples whose label is *digit*."""

    def __init__(self, transform, digit, limit=30):
        self.transform = transform
        self.digit = digit
        self.data, self.labels = _load_mnist_c(MNIST_C_DIR)
        mask = self.labels == self.digit
        data_masked = self.data[mask]
        # Never ask for more samples than this digit actually has.
        limit = min(limit, data_masked.shape[0])
        self.data_for_use = data_masked[:limit]
        self.labels_for_use = self.labels[mask][:limit]
        assert self.data_for_use.shape[0] == self.labels_for_use.shape[0]

    def __len__(self):
        return len(self.data_for_use)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data_for_use[idx]
        label = self.labels_for_use[idx]
        if self.transform:
            # The default transforms expect a PIL image, not an ndarray.
            image_pil = torchvision.transforms.ToPILImage()(image)
            image = self.transform(image_pil)
        return image, label


class MNISTCorrupted(Dataset):
    """All corrupted-MNIST test samples found under ``./mnist_c``."""

    def __init__(self, transform):
        self.transform = transform
        self.data, self.labels = _load_mnist_c(MNIST_C_DIR)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            # The default transforms expect a PIL image, not an ndarray.
            image_pil = torchvision.transforms.ToPILImage()(image)
            image = self.transform(image_pil)
        return image, label


TRAIN_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

test_loader = torch.utils.data.DataLoader(
    MNISTCorrupted(TRAIN_TRANSFORM), batch_size=batch_size_test, shuffle=False
)


# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Two-conv LeNet-style classifier for 1x28x28 inputs; returns log-probabilities."""

    def __init__(self):
        super(MNIST_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # x is (batch, 10) after the view above, so normalize over classes.
        return F.log_softmax(x, dim=1)


def train(epochs, network, optimizer, train_loader):
    """Train *network* for *epochs* epochs with NLL loss.

    Every ``log_interval`` batches it prints progress and checkpoints the
    model and optimizer state to ``MODEL_STATE_PATH``/``OPTIMIZER_STATE_PATH``.
    """
    network.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                torch.save(network.state_dict(), MODEL_STATE_PATH)
                torch.save(optimizer.state_dict(), OPTIMIZER_STATE_PATH)


def test():
    """Evaluate the global ``network`` on ``test_loader``.

    Leaves the network in eval mode.

    :return: ``(test_metric, acc)`` -- a Markdown summary string and the
        accuracy in percent as a float.
    """
    network.eval()
    total = len(test_loader.dataset)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            # .item() so the UI shows "123/456", not "tensor(123)/456".
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= total
    acc = 100. * correct / total
    test_metric = '〽Current test metric - Avg. loss: `{:.4f}`, Accuracy: `{}/{}` (`{:.0f}%`)\n'.format(
        test_loss, correct, total, acc
    )
    return test_metric, acc


# Re-seed so the weight initialisation below is reproducible.
torch.manual_seed(random_seed)
network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum)


def _load_checkpoints():
    """Restore network/optimizer state from the files ``train`` writes, if present."""
    if os.path.exists(MODEL_STATE_PATH):
        network.load_state_dict(torch.load(MODEL_STATE_PATH))
    if os.path.exists(OPTIMIZER_STATE_PATH):
        optimizer.load_state_dict(torch.load(OPTIMIZER_STATE_PATH))


_load_checkpoints()
# Serve predictions without dropout.  train() switches to train mode and
# test() switches back to eval mode, so this stays true between trainings.
network.eval()


def image_classifier(inp):
    """
    It takes an image as input and returns a dictionary of class labels and their corresponding confidence scores.

    :param inp: the image to be classified (a PIL image from the canvas)
    :return: A dictionary of the class index and the confidence value, ordered
        from most to least confident.
    """
    # Same preprocessing (ToTensor + Normalize) as training and testing.
    input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)
    with torch.no_grad():
        prediction = F.softmax(network(input_image)[0], dim=0)
    sorted_prediction = torch.sort(prediction, descending=True)
    return {
        index: value
        for index, value in zip(
            sorted_prediction.indices.tolist(), sorted_prediction.values.tolist()
        )
    }


def _digit_accuracy(digit):
    """Return the accuracy (percent) of ``network`` on the corrupted samples of *digit*."""
    dataset = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM, digit)
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    images, labels = next(iter(loader))
    with torch.no_grad():
        output = network(images)
    pred = output.max(1, keepdim=True)[1]
    correct = pred.eq(labels.view_as(pred)).sum().item()
    return 100. * correct / len(images)


def train_and_test():
    """Fine-tune on every adversarial sample, then evaluate.

    Trains for ``n_epochs`` epochs, runs ``test()``, and appends the overall
    accuracy (key ``'all'``) and per-digit accuracies (keys ``'0'``..``'9'``)
    to the lists stored in ``METRIC_PATH``.

    :return: the test metric string produced by ``test()``.
    """
    train_dataset = MNISTAdversarial_Dataset('./data_mnist', TRAIN_TRANSFORM)
    # NOTE(review): uses batch_size_test (1000), i.e. effectively full-batch
    # on the small adversarial set -- confirm batch_size_train wasn't intended.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size_test, shuffle=True
    )
    train(n_epochs, network, optimizer, train_loader)
    test_metric, test_acc = test()  # also puts the network back in eval mode

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    metric_dict['all'] = metric_dict.get('all', []) + [test_acc]
    for digit in range(10):
        metric_dict.setdefault(str(digit), []).append(_digit_accuracy(digit))
    dump_json(thing=metric_dict, file=METRIC_PATH)
    return test_metric


def flag(input_image, correct_result, adversarial_number):
    """Save a mis-predicted drawing with its correct label and maybe retrain.

    The image and its metadata are written under ``LOCAL_DIR`` and uploaded
    to the dataset repo.  When the repo's sample count becomes a multiple of
    ``TRAIN_CUTOFF``, ``train_and_test()`` is run.

    :param input_image: PIL image drawn on the canvas.
    :param correct_result: the label chosen in the dropdown (None if unset).
    :param adversarial_number: per-session count of flagged samples.
    :return: ``(output_html, test_metric_html, adversarial_number)``.
    """
    adversarial_number = 0 if adversarial_number is None else adversarial_number
    if input_image is None or correct_result is None:
        # Storing a sample without a label would poison every later training run.
        output = "\n⚠ Please draw a digit and select the correct number before flagging.\n"
        return output, f" {DEFAULT_TEST_METRIC} ", adversarial_number

    metadata_name = get_unique_name()
    SAVE_FILE_DIR = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(SAVE_FILE_DIR, exist_ok=True)
    image_output_filename = os.path.join(SAVE_FILE_DIR, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        raise Exception("Had issues saving PIL image to file") from err

    # Write metadata.jsonl to file
    json_file_path = os.path.join(SAVE_FILE_DIR, 'metadata.jsonl')
    metadata = {'id': metadata_name, 'file_name': 'image.png',
                'correct_number': correct_result}
    dump_json(metadata, json_file_path)

    # Upload the image and its metadata with the hub's upload_file.
    # Hub paths always use '/', regardless of the local OS separator.
    repo_image_path = f"{REPOSITORY_DIR}/{metadata_name}/image.png"
    upload_file(path_or_fileobj=image_output_filename,
                path_in_repo=repo_image_path,
                repo_id=HF_REPO_ID,
                repo_type='dataset',
                token=HF_TOKEN)
    repo_json_path = f"{REPOSITORY_DIR}/{metadata_name}/metadata.jsonl"
    upload_file(path_or_fileobj=json_file_path,
                path_in_repo=repo_json_path,
                repo_id=HF_REPO_ID,
                repo_type='dataset',
                token=HF_TOKEN)

    adversarial_number += 1
    output = f"\n✔ ({adversarial_number}) Successfully saved your adversarial data.\n"
    repo.git_pull()
    length_of_dataset = len(os.listdir(ADVERSARIAL_DATA_DIR))
    test_metric = f" {DEFAULT_TEST_METRIC} "
    if length_of_dataset % TRAIN_CUTOFF == 0:
        test_metric_ = train_and_test()
        test_metric = f" {test_metric_} "
        output = f"\n✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data!\n"
    return output, test_metric, adversarial_number


def get_number_dict(DATA_DIR):
    """Count flagged samples per correct digit under *DATA_DIR*.

    Sample folders without ``metadata.jsonl`` are skipped, as the training
    dataset does.

    :return: ``(digits, counts)`` as two parallel lists.
    """
    numbers = []
    for name in [f.name for f in os.scandir(DATA_DIR)]:
        metadata_path = os.path.join(DATA_DIR, name, 'metadata.jsonl')
        if not os.path.exists(metadata_path):
            continue
        numbers.append(read_json_lines(metadata_path)[0]['correct_number'])
    numbers_count = Counter(numbers)
    return list(numbers_count.keys()), list(numbers_count.values())


def get_statistics():
    """Build the dashboard figures.

    Reloads the checkpoints and the dataset repo, then returns
    ``(plt_digits, fig_d, done_html)``:
    - a bar chart of flagged samples per digit;
    - a plot of per-digit MNIST_C accuracy per adversarial train step, read
      from ``METRIC_PATH``;
    - an HTML status string.
    """
    _load_checkpoints()
    repo.git_pull()
    numbers_count_keys, numbers_count_values = get_number_dict(ADVERSARIAL_DATA_DIR)
    plt_digits = plot_bar(numbers_count_values, numbers_count_keys,
                          'Number of adversarial samples', "Digit",
                          "Distribution of adversarial samples over digits")
    fig_d, ax_d = plt.subplots(figsize=(10, 4), tight_layout=True)
    if os.path.exists(METRIC_PATH):
        # Read-only on purpose: writing the dict back here could overwrite
        # metrics that train_and_test() appended after this read.
        metric_dict = read_json(METRIC_PATH)
        for digit in range(10):
            accuracies = metric_dict.get(str(digit))
            if accuracies is None:
                continue
            steps = list(range(1, len(accuracies) + 1))
            ax_d.plot(steps, accuracies, label=str(digit))
    fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',
             title="Test Accuracy over digits per train step")
    done_html = """

✅ Statistics loaded successfully!

"""
    return plt_digits, fig_d, done_html


def main():
    """Build and launch the Gradio UI (drawing tab and dashboard tab)."""
    # NOTE(review): the original indentation was lost; the layout nesting
    # below (Row contents, tab contents) is reconstructed -- confirm.
    block = gr.Blocks()
    with block:
        gr.Markdown(TITLE)
        gr.Markdown(description)
        with gr.Tabs():
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                test_metric = gr.outputs.HTML(DEFAULT_TEST_METRIC)
                with gr.Row():
                    image_input = gr.inputs.Image(source="canvas", shape=(28, 28),
                                                  invert_colors=True, image_mode="L",
                                                  type="pil")
                    label_output = gr.outputs.Label(num_top_classes=10)
                submit = gr.Button("Submit")
                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(choices=list(range(10)), type='value',
                                              default=None,
                                              label="What was the correct prediction?")
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                adversarial_number = gr.Variable(value=0)
                submit.click(image_classifier, inputs=[image_input],
                             outputs=[label_output])
                flag_btn.click(flag,
                               inputs=[image_input, number_dropdown, adversarial_number],
                               outputs=[output_result, test_metric, adversarial_number])
            with gr.TabItem('Dashboard') as dashboard:
                notification = gr.HTML("""

⌛ Creating statistics...

""")
                _, numbers_count_values_ = get_number_dict(ADVERSARIAL_DATA_DIR)
                STATS_EXPLANATION_ = STATS_EXPLANATION.format(
                    num_adv_samples=sum(numbers_count_values_))
                gr.Markdown(STATS_EXPLANATION_)
                stat_adv_image = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results = gr.Plot(type="matplotlib")
                dashboard.select(get_statistics, inputs=[],
                                 outputs=[stat_adv_image, test_results, notification])
    block.launch()


if __name__ == "__main__":
    main()