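# MNIST Adversarial (Gradio Space): users draw digits that fool the model,
# flagged samples are uploaded to a Hugging Face dataset repo, and the model
# is periodically retrained on them.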
import os
import torch
import gradio as gr
import torchvision
from PIL import Image
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from huggingface_hub import Repository, upload_file
from torch.utils.data import Dataset
import numpy as np
import matplotlib.pyplot as plt  # used by the dashboard plots below
from collections import Counter
with open('app.css','r') as f:
BLOCK_CSS = f.read()
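# Training hyper-parameters and app constants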
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
adv_learning_rate = 0.001
momentum = 0.5
log_interval = 10
random_seed = 1
TRAIN_CUTOFF = 10
TEST_PER_SAMPLE = 5000
DASHBOARD_EXPLANATION = DASHBOARD_EXPLANATION.format(TEST_PER_SAMPLE=TEST_PER_SAMPLE)
WHAT_TO_DO = WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
MODEL_PATH = 'model'
METRIC_PATH = os.path.join(MODEL_PATH,'metrics.json')
MODEL_WEIGHTS_PATH = os.path.join(MODEL_PATH,'mnist_model.pth')
OPTIMIZER_PATH = os.path.join(MODEL_PATH,'optimizer.pth')
REPOSITORY_DIR = "data"
LOCAL_DIR = 'data_local'
os.makedirs(LOCAL_DIR,exist_ok=True)
GET_STATISTICS_MESSAGE = "Get Statistics"
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_REPO = 'mnist-adversarial-model'
HF_DATASET = "mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
MODEL_REPO_URL = f"https://huggingface.co/chrisjay/{MODEL_REPO}"
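# Clone the dataset and model repos locally so new samples and retrained
# weights can be committed back to the Hub.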
repo = Repository(
local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()
model_repo = Repository(
local_dir=MODEL_PATH, clone_from=MODEL_REPO_URL, use_auth_token=HF_TOKEN, repo_type="model"
)
model_repo.git_pull()
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
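# Dataset of user-flagged adversarial drawings pulled from the dataset repo:
# each sample folder holds an image.png and a metadata.jsonl with the
# correct label.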
class MNISTAdversarial_Dataset(Dataset):
def __init__(self,data_dir,transform):
repo.git_pull()
self.data_dir = os.path.join(data_dir,'data')
self.transform = transform
files = [f.name for f in os.scandir(self.data_dir)]
self.images = []
self.numbers = []
for f in files:
            self.FOLDER = os.path.join(self.data_dir, f)
            metadata_path = os.path.join(self.FOLDER, 'metadata.jsonl')
            image_path = os.path.join(self.FOLDER, 'image.png')
if os.path.exists(image_path) and os.path.exists(metadata_path):
metadata = read_json_lines(metadata_path)
if metadata is not None:
img = Image.open(image_path)
self.images.append(img)
self.numbers.append(metadata[0]['correct_number'])
assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."
def __len__(self):
return len(self.images)
def __getitem__(self,idx):
img, label = self.images[idx], self.numbers[idx]
img = self.transform(img)
return img, label
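# MNIST-C test samples for a single digit, used to track per-digit accuracy
# on the dashboard.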
class MNISTCorrupted_By_Digit(Dataset):
def __init__(self,transform,digit,limit=TEST_PER_SAMPLE):
self.transform = transform
self.digit = digit
        corrupted_dir = "./mnist_c"
        files = [f.name for f in os.scandir(corrupted_dir)]
        images = [np.load(os.path.join(corrupted_dir, f, 'test_images.npy')) for f in files]
        labels = [np.load(os.path.join(corrupted_dir, f, 'test_labels.npy')) for f in files]
self.data = np.vstack(images)
self.labels = np.hstack(labels)
assert (self.data.shape[0] == self.labels.shape[0])
mask = self.labels == self.digit
data_masked = self.data[mask]
# Just to be on the safe side, ensure limit is more than the minimum
limit = min(limit,data_masked.shape[0])
self.data_for_use = data_masked[:limit]
self.labels_for_use = self.labels[mask][:limit]
assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0])
def __len__(self):
return len(self.data_for_use)
def __getitem__(self,idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image = self.data_for_use[idx]
label = self.labels_for_use[idx]
if self.transform:
image_pil = torchvision.transforms.ToPILImage()(image) # Need to transform to PIL before using default transforms
image = self.transform(image_pil)
return image, label
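# Full MNIST-C test set (capped at TEST_PER_SAMPLE images per corruption type),
# used as the held-out test set after each retraining step.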
class MNISTCorrupted(Dataset):
def __init__(self,transform):
self.transform = transform
        corrupted_dir = "./mnist_c"
        files = [f.name for f in os.scandir(corrupted_dir)]
        images = [np.load(os.path.join(corrupted_dir, f, 'test_images.npy'))[:TEST_PER_SAMPLE] for f in files]
        labels = [np.load(os.path.join(corrupted_dir, f, 'test_labels.npy'))[:TEST_PER_SAMPLE] for f in files]
self.data = np.vstack(images)
self.labels = np.hstack(labels)
assert (self.data.shape[0] == self.labels.shape[0])
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
if torch.is_tensor(idx):
idx = idx.tolist()
image = self.data[idx]
label = self.labels[idx]
if self.transform:
image_pil = torchvision.transforms.ToPILImage()(image) # Need to transform to PIL before using default transforms
image = self.transform(image_pil)
return image, label
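# Standard MNIST normalization (mean 0.1307, std 0.3081).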
TRAIN_TRANSFORM = torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize(
(0.1307,), (0.3081,))
])
test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
batch_size=batch_size_test, shuffle=False)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
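# Two convolutional layers with max-pooling and dropout, followed by two
# fully connected layers; outputs log-probabilities over the 10 digits.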
class MNIST_Model(nn.Module):
def __init__(self):
super(MNIST_Model, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
        return F.log_softmax(x, dim=1)
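# Train for the given number of epochs, periodically checkpointing model and
# optimizer state to MODEL_WEIGHTS_PATH and OPTIMIZER_PATH.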
def train(epochs,network,optimizer,train_loader):
train_losses=[]
network.train()
for epoch in range(epochs):
for batch_idx, (data, target) in enumerate(train_loader):
optimizer.zero_grad()
output = network(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
train_losses.append(loss.item())
torch.save(network.state_dict(), MODEL_WEIGHTS_PATH)
torch.save(optimizer.state_dict(), OPTIMIZER_PATH)
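# Evaluate on the MNIST-C test loader; returns a markdown metric string and
# the accuracy.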
def test():
test_losses=[]
network.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
output = network(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.data.max(1, keepdim=True)[1]
correct += pred.eq(target.data.view_as(pred)).sum()
test_loss /= len(test_loader.dataset)
test_losses.append(test_loss)
acc = 100. * correct / len(test_loader.dataset)
acc = acc.item()
test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
test_loss,acc)
print(test_metric)
return test_metric,acc
network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
momentum=momentum)
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./files/', train=True, download=True,
transform=TRAIN_TRANSFORM),
batch_size=batch_size_train, shuffle=True)
test_iid_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('./files/', train=False, download=True,
transform=TRAIN_TRANSFORM),
batch_size=batch_size_test, shuffle=True)
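# Resume from the latest checkpoint in the model repo if one exists.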
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
network_state_dict = torch.load(model_state_dict)
network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(optimizer_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
# Train model
#n_epochs=20
#train(n_epochs,network,optimizer,train_loader)
#test()
def train_and_test(train_model=True):
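    """
    Optionally retrain the network on the community adversarial dataset, then
    evaluate on MNIST-C overall and per digit, append the accuracies to
    metrics.json and push the updated model to the Hub.
    """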
if train_model:
        # Retrain on the adversarial dataset and evaluate
train_dataset = MNISTAdversarial_Dataset('./data_mnist',TRAIN_TRANSFORM)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size_test, shuffle=True)
train(n_epochs,network,optimizer,train_loader)
test_metric,test_acc = test()
network.eval()
if os.path.exists(METRIC_PATH):
metric_dict = read_json(METRIC_PATH)
        metric_dict['all'] = metric_dict.get('all', []) + [test_acc]
else:
metric_dict={}
metric_dict['all'] = [test_acc]
for i in range(10):
data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM,i)
dataloader_per_digit = torch.utils.data.DataLoader(data_per_digit,batch_size=len(data_per_digit), shuffle=False)
        data_per_digit, label_per_digit = next(iter(dataloader_per_digit))
output = network(data_per_digit)
pred = output.data.max(1, keepdim=True)[1]
correct = pred.eq(label_per_digit.data.view_as(pred)).sum()
acc = 100. * correct / len(data_per_digit)
acc=acc.item()
        metric_dict[str(i)] = metric_dict.get(str(i), []) + [acc]
dump_json(thing=metric_dict,file=METRIC_PATH)
# Push models and metrics to hub
model_repo.push_to_hub()
return test_metric
# Load the latest weights from the model repo, falling back to the bundled best weights
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
model_repo.git_pull()
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
network_state_dict = torch.load(model_state_dict)
network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(optimizer_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
else:
# Use best weights
BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
network_state_dict = torch.load(BEST_WEIGHTS_MODEL)
network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(BEST_WEIGHTS_OPTIMIZER)
optimizer.load_state_dict(optimizer_state_dict)
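# Bootstrap metrics.json with an initial evaluation if it does not exist yet.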
if not os.path.exists(METRIC_PATH):
_ = train_and_test(False)
def image_classifier(inp):
"""
It loads the latest model weights from the model repository, and then uses those weights to make a
prediction on the input image.
:param inp: the image to be classified
:return: A dictionary of the form {class_number: confidence}
"""
# Get latest model weights ----------------
model_repo.git_pull()
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
which_weights=''
if os.path.exists(model_state_dict) and os.path.exists(optimizer_state_dict):
which_weights = "Using weights from model repo"
network_state_dict = torch.load(model_state_dict)
network.load_state_dict(network_state_dict)
optimizer_state_dict = torch.load(optimizer_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
else:
# Use best weights
which_weights = "Using default best weights"
BEST_WEIGHTS_MODEL = "best_weights/mnist_model.pth"
BEST_WEIGHTS_OPTIMIZER = "best_weights/optimizer.pth"
network.load_state_dict(torch.load(BEST_WEIGHTS_MODEL))
optimizer.load_state_dict(torch.load(BEST_WEIGHTS_OPTIMIZER))
network.eval()
input_image = TRAIN_TRANSFORM(inp).unsqueeze(0)
with torch.no_grad():
        prediction = torch.exp(network(input_image)[0])  # model outputs log-probabilities; exp recovers probabilities
#pred_number = prediction.data.max(1, keepdim=True)[1]
    sorted_prediction = torch.sort(prediction, descending=True)
    confidences = {}
    for s, v in zip(sorted_prediction.indices.numpy().tolist(), sorted_prediction.values.numpy().tolist()):
        confidences[s] = v
return confidences
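# Example (hypothetical file name): classify a local 28x28 grayscale drawing.
#   confidences = image_classifier(Image.open('my_digit.png').convert('L'))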
def flag(input_image,correct_result,adversarial_number):
"""
It takes in an image, the correct result, and the number of adversarial images that have been
uploaded so far. It saves the image and metadata to a local directory, uploads the image and
metadata to the hub, and then pulls the data from the hub to the local directory. If the number of
images in the local directory is divisible by the TRAIN_CUTOFF, then it trains the model on the
adversarial data
:param input_image: The adversarial image that you want to save
:param correct_result: The correct number that the image represents
:param adversarial_number: This is the number of adversarial examples that have been uploaded to the
dataset
    :return: An HTML status message and the updated adversarial_number
"""
    adversarial_number = 0 if adversarial_number is None else adversarial_number
metadata_name = get_unique_name()
SAVE_FILE_DIR = os.path.join(LOCAL_DIR,metadata_name)
os.makedirs(SAVE_FILE_DIR,exist_ok=True)
image_output_filename = os.path.join(SAVE_FILE_DIR,'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as e:
        raise Exception(f"Had issues saving PIL image to {image_output_filename}") from e
# Write metadata.json to file
json_file_path = os.path.join(SAVE_FILE_DIR,'metadata.jsonl')
metadata= {'id':metadata_name,'file_name':'image.png',
'correct_number':correct_result
}
dump_json(metadata,json_file_path)
# Simply upload the image file and metadata using the hub's upload_file
# Upload the image
    repo_image_path = os.path.join(REPOSITORY_DIR, metadata_name, 'image.png')
_ = upload_file(path_or_fileobj = image_output_filename,
path_in_repo =repo_image_path,
repo_id=f'chrisjay/{HF_DATASET}',
repo_type='dataset',
token=HF_TOKEN
)
# Upload the metadata
    repo_json_path = os.path.join(REPOSITORY_DIR, metadata_name, 'metadata.jsonl')
_ = upload_file(path_or_fileobj = json_file_path,
path_in_repo =repo_json_path,
repo_id=f'chrisjay/{HF_DATASET}',
repo_type='dataset',
token=HF_TOKEN
)
adversarial_number+=1
output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data. </div>'
repo.git_pull()
length_of_dataset = len([f for f in os.scandir("./data_mnist/data")])
test_metric = f"<html> {DEFAULT_TEST_METRIC} </html>"
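    # Retrain once the dataset has grown by another TRAIN_CUTOFF samples.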
    if length_of_dataset % TRAIN_CUTOFF == 0:
test_metric_ = train_and_test()
test_metric = f"<html> {test_metric_} </html>"
output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
return output,adversarial_number
def get_number_dict(DATA_DIR):
"""
It takes a directory as input, and returns a list of the number of times each number appears in the
metadata.jsonl files in that directory
:param DATA_DIR: The directory where the data is stored
"""
files = [f.name for f in os.scandir(DATA_DIR)]
metadata_jsons = [read_json_lines(os.path.join(os.path.join(DATA_DIR,f),'metadata.jsonl')) for f in files]
numbers = [m[0]['correct_number'] for m in metadata_jsons if m is not None]
numbers_count = Counter(numbers)
numbers_count_keys = list(numbers_count.keys())
numbers_count_values = [numbers_count[k] for k in numbers_count_keys]
return numbers_count_keys,numbers_count_values
def get_statistics():
"""
It loads the model and optimizer state dicts, pulls the latest data from the repo, gets the number
of adversarial samples per digit, plots the distribution of adversarial samples per digit, plots the
test accuracy per digit per train step, and plots the test accuracy for all digits per train step
    :return: the adversarial-sample distribution plot, the per-digit accuracy
        plot, the overall accuracy plot, a status HTML string, and the stats
        explanation markdown
"""
model_repo.git_pull()
model_state_dict = MODEL_WEIGHTS_PATH
optimizer_state_dict = OPTIMIZER_PATH
if os.path.exists(model_state_dict):
network_state_dict = torch.load(model_state_dict)
network.load_state_dict(network_state_dict)
if os.path.exists(optimizer_state_dict):
optimizer_state_dict = torch.load(optimizer_state_dict)
optimizer.load_state_dict(optimizer_state_dict)
repo.git_pull()
DATA_DIR = './data_mnist/data'
numbers_count_keys,numbers_count_values = get_number_dict(DATA_DIR)
STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values))
    plt_digits = plot_bar(numbers_count_values, numbers_count_keys, 'Number of adversarial samples', "Digit", "Distribution of adversarial samples per digit", True)
fig_d, ax_d = plt.subplots(tight_layout=True)
if os.path.exists(METRIC_PATH):
metric_dict = read_json(METRIC_PATH)
for i in range(10):
try:
                x_i = [j+1 for j in range(len(metric_dict[str(i)]))]
                ax_d.plot(x_i, metric_dict[str(i)], label=str(i))
except Exception:
continue
ax_d.set_xticks(range(0, len(metric_dict['0'])+1, 1))
else:
metric_dict={}
fig_d.legend()
ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',title="Test Accuracy over digits per train step")
done_html = f"""<div style="color: green">
    <p> ✅ Statistics loaded successfully! Click `{GET_STATISTICS_MESSAGE}` to reload.</p>
</div>
"""
    # Plot for total test accuracy over all digits per train step
    fig_all, ax_all = plt.subplots(tight_layout=True)
    all_accs = metric_dict.get('all', [])
    x_all = [i+1 for i in range(len(all_accs))]
    ax_all.plot(x_all, all_accs)
    ax_all.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy', title="Test Accuracy for all digits")
    if x_all:
        ax_all.set_xticks(range(0, x_all[-1]+1, 1))
return plt_digits,ax_d.figure,ax_all.figure,done_html,STATS_EXPLANATION_
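# Build the Gradio UI: a drawing tab for collecting adversarial samples and a
# dashboard tab for the statistics plots.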
def main():
block = gr.Blocks(css=BLOCK_CSS)
with block:
gr.Markdown(TITLE)
gr.Markdown(description)
with gr.Tabs():
with gr.TabItem('MNIST'):
gr.Markdown(WHAT_TO_DO)
#test_metric = gr.outputs.HTML("")
with gr.Row():
                    image_input = gr.inputs.Image(source="canvas", shape=(28, 28), invert_colors=True, image_mode="L", type="pil")
label_output = gr.outputs.Label(num_top_classes=2)
gr.Markdown(MODEL_IS_WRONG)
number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
gr.Markdown('Please wait a while after you press `Flag`. It takes time.')
flag_btn = gr.Button("Flag")
output_result = gr.outputs.HTML()
adversarial_number = gr.Variable(value=0)
image_input.change(image_classifier,inputs = [image_input],outputs=[label_output])
flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,adversarial_number])
with gr.TabItem('Dashboard') as dashboard:
get_stat = gr.Button(f'{GET_STATISTICS_MESSAGE}')
notification = gr.HTML(f"""<div style="color: green">
<p> ⌛ Click `{GET_STATISTICS_MESSAGE}` to generate statistics... </p>
</div>
""")
                stats = gr.Markdown()
                stat_adv_image = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION_TEST)
                test_results_all = gr.Plot(type="matplotlib")
#dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification,stats])
get_stat.click(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,test_results_all,notification,stats])
block.launch()
if __name__ == "__main__":
main()