import os import torch import gradio as gr import torchvision from PIL import Image from utils import * import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from huggingface_hub import Repository, upload_file from torch.utils.data import Dataset import numpy as np from collections import Counter with open('app.css','r') as f: BLOCK_CSS = f.read() n_epochs = 10 batch_size_train = 128 batch_size_test = 1000 learning_rate = 0.01 adv_learning_rate= 0.001 momentum = 0.5 log_interval = 10 random_seed = 1 TRAIN_CUTOFF = 10 TEST_PER_SAMPLE = 5000 DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE) WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF) MODEL_PATH = 'model' METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json') MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH,'mnist_model.pth') OPTIMIZER_PATH = os.path.join(MODEL_PATH,'optimizer.pth') REPOSITORY_DIR = "data" LOCAL_DIR = 'data_local' os.makedirs(LOCAL_DIR,exist_ok=True) GET_STATISTICS_MESSAGE = "Get Statistics" HF_TOKEN = os.getenv("HF_TOKEN") MODEL_REPO = 'mnist-adversarial-model' HF_DATASET ="mnist-adversarial-dataset" DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}" MODEL_REPO_URL = f"https://huggingface.co/model/chrisjay/{MODEL_REPO}" repo = Repository( local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN ) repo.git_pull() model_repo = Repository( local_dir=MODEL_PATH, clone_from=MODEL_REPO_URL, use_auth_token=HF_TOKEN, repo_type="model" ) model_repo.git_pull() torch.backends.cudnn.enabled = False torch.manual_seed(random_seed) class MNISTAdversarial_Dataset(Dataset): def __init__(self,data_dir,transform): repo.git_pull() self.data_dir = os.path.join(data_dir,'data') self.transform = transform files = [f.name for f in os.scandir(self.data_dir)] self.images = [] self.numbers = [] for f in files: self.FOLDER = os.path.join(os.path.join(self.data_dir,f)) metadata_path = os.path.join(self.FOLDER,'metadata.jsonl') image_path =os.path.join(self.FOLDER,'image.png') if os.path.exists(image_path) and os.path.exists(metadata_path): metadata = read_json_lines(metadata_path) if metadata is not None: img = Image.open(image_path) self.images.append(img) self.numbers.append(metadata[0]['correct_number']) assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers." def __len__(self): return len(self.images) def __getitem__(self,idx): img, label = self.images[idx], self.numbers[idx] img = self.transform(img) return img, label class MNISTCorrupted_By_Digit(Dataset): def __init__(self,transform,digit,limit=TEST_PER_SAMPLE): self.transform = transform self.digit = digit corrupted_dir="./mnist_c" files = [f.name for f in os.scandir(corrupted_dir)] images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy')) for f in files] labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy')) for f in files] self.data = np.vstack(images) self.labels = np.hstack(labels) assert (self.data.shape[0] == self.labels.shape[0]) mask = self.labels == self.digit data_masked = self.data[mask] # Just to be on the safe side, ensure limit is more than the minimum limit = min(limit,data_masked.shape[0]) self.data_for_use = data_masked[:limit] self.labels_for_use = self.labels[mask][:limit] assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0]) def __len__(self): return len(self.data_for_use) def __getitem__(self,idx): if torch.is_tensor(idx): idx = idx.tolist() image = self.data_for_use[idx] label = self.labels_for_use[idx] if self.transform: image_pil = torchvision.transforms.ToPILImage()(image) # Need to transform to PIL before using default transforms image = self.transform(image_pil) return image, label class MNISTCorrupted(Dataset): def __init__(self,transform): self.transform = transform corrupted_dir="./mnist_c" files = [f.name for f in os.scandir(corrupted_dir)] images = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_images.npy'))[:TEST_PER_SAMPLE] for f in files] labels = [np.load(os.path.join(os.path.join(corrupted_dir,f),'test_labels.npy'))[:TEST_PER_SAMPLE] for f in files] self.data = np.vstack(images) self.labels = np.hstack(labels) assert (self.data.shape[0] == self.labels.shape[0]) def __len__(self): return len(self.data) def __getitem__(self, idx): if torch.is_tensor(idx): idx = idx.tolist() image = self.data[idx] label = self.labels[idx] if self.transform: image_pil = torchvision.transforms.ToPILImage()(image) # Need to transform to PIL before using default transforms image = self.transform(image_pil) return image, label TRAIN_TRANSFORM = torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize( (0.1307,), (0.3081,)) ]) test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM), batch_size=batch_size_test, shuffle=False) # Source: https://nextjournal.com/gkoehler/pytorch-mnist class MNIST_Model(nn.Module): def __init__(self): super(MNIST_Model, self).__init__() self.conv1 = nn.Conv2d(1, 10, kernel_size=5) self.conv2 = nn.Conv2d(10, 20, kernel_size=5) self.conv2_drop = nn.Dropout2d() self.fc1 = nn.Linear(320, 50) self.fc2 = nn.Linear(50, 10) def forward(self, x): x = F.relu(F.max_pool2d(self.conv1(x), 2)) x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) x = x.view(-1, 320) x = F.relu(self.fc1(x)) x = F.dropout(x, training=self.training) x = self.fc2(x) return F.log_softmax(x) def train(epochs,network,optimizer,train_loader): train_losses=[] network.train() for epoch in range(epochs): for batch_idx, (data, target) in enumerate(train_loader): optimizer.zero_grad() output = network(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() if batch_idx % log_interval == 0: print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( epoch, batch_idx * len(data), len(train_loader.dataset), 100. * batch_idx / len(train_loader), loss.item())) train_losses.append(loss.item()) torch.save(network.state_dict(), MODEL_WEIGHTS_PATH) torch.save(optimizer.state_dict(), OPTIMIZER_PATH) def test(): test_losses=[] network.eval() test_loss = 0 correct = 0 with torch.no_grad(): for data, target in test_loader: output = network(data) test_loss += F.nll_loss(output, target, size_average=False).item() pred = output.data.max(1, keepdim=True)[1] correct += pred.eq(target.data.view_as(pred)).sum() test_loss /= len(test_loader.dataset) test_losses.append(test_loss) acc = 100. * correct / len(test_loader.dataset) acc = acc.item() test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format( test_loss,acc) print(test_metric) return test_metric,acc random_seed = 1 torch.backends.cudnn.enabled = False torch.manual_seed(random_seed) network = MNIST_Model() optimizer = optim.SGD(network.parameters(), lr=learning_rate, momentum=momentum) train_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST('./files/', train=True, download=True, transform=TRAIN_TRANSFORM), batch_size=batch_size_train, shuffle=True) test_iid_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST('./files/', train=False, download=True, transform=TRAIN_TRANSFORM), batch_size=batch_size_test, shuffle=True) model_state_dict = MODEL_WEIGHTS_PATH optimizer_state_dict = OPTIMIZER_PATH if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict): network_state_dict = torch.load(model_state_dict) network.load_state_dict(network_state_dict) optimizer_state_dict = torch.load(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict) # Train model #n_epochs=20 #train(n_epochs,network,optimizer,train_loader) #test() def train_and_test(train_model=True): if train_model: # Train for one epoch and test train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM) train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True) train(n_epochs,network,optimizer,train_loader) test_metric,test_acc = test() network.eval() if os.path.exists(METRIC_PATH): metric_dict = read_json(METRIC_PATH) metric_dict['all'] = metric_dict['all']+ [test_acc] if 'all' in metric_dict else [] + [test_acc] else: metric_dict={} metric_dict['all'] = [test_acc] for i in range(10): data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM,i) dataloader_per_digit = torch.utils.data.DataLoader(data_per_digit,batch_size=len(data_per_digit), shuffle=False) data_per_digit, label_per_digit = iter(dataloader_per_digit).next() output = network(data_per_digit) pred = output.data.max(1, keepdim=True)[1] correct = pred.eq(label_per_digit.data.view_as(pred)).sum() acc = 100. * correct / len(data_per_digit) acc=acc.item() if os.path.exists(METRIC_PATH): metric_dict[str(i)].append(acc) else: metric_dict[str(i)] = [acc] dump_json(thing=metric_dict,file=METRIC_PATH) # Push models and metrics to hub model_repo.push_to_hub() return test_metric # Update model weights again model_state_dict = MODEL_WEIGHTS_PATH optimizer_state_dict = OPTIMIZER_PATH model_repo.git_pull() if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict): network_state_dict = torch.load(model_state_dict) network.load_state_dict(network_state_dict) optimizer_state_dict = torch.load(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict) else: # Use best weights BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth" BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth" network_state_dict = torch.load(BEST_WEIGHTS_MODEL) network.load_state_dict(network_state_dict) optimizer_state_dict = torch.load(BEST_WEIGHTS_OPTIMIZER) optimizer.load_state_dict(optimizer_state_dict) if not os.path.exists(METRIC_PATH): _ = train_and_test(False) def image_classifier(inp): """ It loads the latest model weights from the model repository, and then uses those weights to make a prediction on the input image. :param inp: the image to be classified :return: A dictionary of the form {class_number: confidence} """ # Get latest model weights ---------------- model_repo.git_pull() model_state_dict = MODEL_WEIGHTS_PATH optimizer_state_dict = OPTIMIZER_PATH which_weights='' if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict): which_weights = "Using weights from model repo" network_state_dict = torch.load(model_state_dict) network.load_state_dict(network_state_dict) optimizer_state_dict = torch.load(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict) else: # Use best weights which_weights = "Using default best weights" BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth" BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth" network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL)) optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER)) network.eval() input_image = TRAIN_TRANSFORM(inp).unsqueeze(0) with torch.no_grad(): prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0) #pred_number = prediction.data.max(1, keepdim=True)[1] sorted_prediction = torch.sort(prediction,descending=True) confidences={} for s,v in zip(sorted_prediction.indices.numpy().tolist(),sorted_prediction.values.numpy().tolist()): confidences.update({s:v}) return confidences def flag(input_image,correct_result,adversarial_number): """ It takes in an image, the correct result, and the number of adversarial images that have been uploaded so far. It saves the image and metadata to a local directory, uploads the image and metadata to the hub, and then pulls the data from the hub to the local directory. If the number of images in the local directory is divisible by the TRAIN_CUTOFF, then it trains the model on the adversarial data :param input_image: The adversarial image that you want to save :param correct_result: The correct number that the image represents :param adversarial_number: This is the number of adversarial examples that have been uploaded to the dataset :return: The output is the output of the flag function. """ adversarial_number = 0 if None else adversarial_number metadata_name = get_unique_name() SAVE_FILE_DIR = os.path.join(LOCAL_DIR,metadata_name) os.makedirs(SAVE_FILE_DIR,exist_ok=True) image_output_filename = os.path.join(SAVE_FILE_DIR,'image.png') try: input_image.save(image_output_filename) except Exception: raise Exception(f"Had issues saving PIL image to file") # Write metadata.json to file json_file_path = os.path.join(SAVE_FILE_DIR,'metadata.jsonl') metadata= {'id':metadata_name,'file_name':'image.png', 'correct_number':correct_result } dump_json(metadata,json_file_path) # Simply upload the image file and metadata using the hub's upload_file # Upload the image repo_image_path = os.path.join(REPOSITORY_DIR,os.path.join(metadata_name,'image.png')) _ = upload_file(path_or_fileobj = image_output_filename, path_in_repo =repo_image_path, repo_id=f'chrisjay/{HF_DATASET}', repo_type='dataset', token=HF_TOKEN ) # Upload the metadata repo_json_path = os.path.join(REPOSITORY_DIR,os.path.join(metadata_name,'metadata.jsonl')) _ = upload_file(path_or_fileobj = json_file_path, path_in_repo =repo_json_path, repo_id=f'chrisjay/{HF_DATASET}', repo_type='dataset', token=HF_TOKEN ) adversarial_number+=1 output = f'
✔ ({adversarial_number}) Successfully saved your adversarial data.
' repo.git_pull() length_of_dataset = len([f for f in os.scandir("./data_mnist/data")]) test_metric = f" {DEFAULT_TEST_METRIC} " if length_of_dataset % TRAIN_CUTOFF ==0: test_metric_ = train_and_test() test_metric = f" {test_metric_} " output = f'
✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data!
' return output,adversarial_number def get_number_dict(DATA_DIR): """ It takes a directory as input, and returns a list of the number of times each number appears in the metadata.jsonl files in that directory :param DATA_DIR: The directory where the data is stored """ files = [f.name for f in os.scandir(DATA_DIR)] metadata_jsons = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl')) for f in files] numbers = [m[0]['correct_number'] for m in metadata_jsons if m is not None] numbers_count = Counter(numbers) numbers_count_keys = list(numbers_count.keys()) numbers_count_values = [numbers_count[k] for k in numbers_count_keys] return numbers_count_keys,numbers_count_values def get_statistics(): """ It loads the model and optimizer state dicts, pulls the latest data from the repo, gets the number of adversarial samples per digit, plots the distribution of adversarial samples per digit, plots the test accuracy per digit per train step, and plots the test accuracy for all digits per train step :return: the following: """ model_repo.git_pull() model_state_dict = MODEL_WEIGHTS_PATH optimizer_state_dict = OPTIMIZER_PATH if os.path.exists(model_state_dict): network_state_dict = torch.load(model_state_dict) network.load_state_dict(network_state_dict) if os.path.exists(optimizer_state_dict): optimizer_state_dict = torch.load(optimizer_state_dict) optimizer.load_state_dict(optimizer_state_dict) repo.git_pull() DATA_DIR = './data_mnist/data' numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR) STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values)) plt_digits = plot_bar(numbers_count_values,numbers_count_keys,'Number of adversarial samples',"Digit",f"Distribution of adversarial samples per digit",True) fig_d, ax_d = plt.subplots(tight_layout=True) if os.path.exists(METRIC_PATH): metric_dict = read_json(METRIC_PATH) for i in range(10): try: x_i = [i+1 for i in range(len(metric_dict[str(i)]))] ax_d.plot(x_i, metric_dict[str(i)],label=str(i)) except Exception: continue ax_d.set_xticks(range(0, len(metric_dict['0'])+1, 1)) else: metric_dict={} fig_d.legend() ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step") done_html = f"""

✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}`to reload.

""" # Plot for total test accuracy for all digits fig_all, ax_all = plt.subplots(tight_layout=True) x_i = [i+1 for i in range(len(metric_dict['all']))] ax_all.plot(x_i, metric_dict['all']) ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy for all digits") ax_all.set_xticks(range(0, x_i[-1]+1, 1)) return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_ def main(): block = gr.Blocks(css=BLOCK_CSS) with block: gr.Markdown(TITLE) gr.Markdown(description) with gr.Tabs(): with gr.TabItem('MNIST'): gr.Markdown(WHAT_TO_DO) #test_metric = gr.outputs.HTML("") with gr.Row(): image_input =gr.inputs.Image(source="canvas",shape=(28,28),invert_colors=True,image_mode="L",type="pil") label_output = gr.outputs.Label(num_top_classes=2) gr.Markdown(MODEL_IS_WRONG) number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?") gr.Markdown('Please wait a while after you press `Flag`. It takes time.') flag_btn = gr.Button("Flag") output_result = gr.outputs.HTML() adversarial_number = gr.Variable(value=0) image_input.change(image_classifier,inputs = [image_input],outputs=[label_output]) flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number]) with gr.TabItem('Dashboard') as dashboard: get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}') notification = gr.HTML(f"""

⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics...

""") stats = gr.Markdown() stat_adv_image =gr.Plot(type="matplotlib") gr.Markdown(DASHBOARD_EXPLANATION) test_results=gr.Plot(type="matplotlib") gr.Markdown(DASHBOARD_EXPLANATION_TEST) test_results_all=gr.Plot(type="matplotlib") #dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats]) get_stat.click(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats]) block.launch() if __name__ == "__main__": main()