# Scraped from the Hugging Face Space file view (author: chrisjay).
# Last commit: "reset model weights and deleted metrics" (213a820); file size 16.3 kB.
import os
import torch
import gradio as gr
import torchvision
from PIL import Image
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from huggingface_hub import Repository, upload_file
from torch.utils.data import Dataset
import numpy as np
from collections import Counter
# Custom CSS for the Gradio Blocks UI, read once at import time.
with open('app.css','r') as f:
    BLOCK_CSS = f.read()
# Training / evaluation hyper-parameters.
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # train() prints progress every `log_interval` batches
random_seed = 1
# flag() calls train_and_test() whenever the number of samples in the dataset
# repo is a multiple of TRAIN_CUTOFF.
TRAIN_CUTOFF = 10
# WHAT_TO_DO is presumably provided by `from utils import *`; fill in its
# `num_samples` placeholder.
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
METRIC_PATH = './metrics.json'  # accuracy history written by train_and_test(), read by get_statistics()
REPOSITORY_DIR = "data"  # folder inside the HF dataset repo where flag() uploads samples
LOCAL_DIR = 'data_local'  # local folder where flag() writes a sample before uploading it
os.makedirs(LOCAL_DIR,exist_ok=True)
HF_TOKEN = os.getenv("HF_TOKEN")
HF_DATASET ="mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
# Clone the adversarial dataset repo into ./data_mnist and pull the latest samples.
repo = Repository(
    local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
class MNISTAdversarial_Dataset(Dataset):
    """Adversarial samples flagged through the app, read from the cloned dataset repo.

    Every sub-folder of ``<data_dir>/data`` that contains both an ``image.png``
    and a ``metadata.jsonl`` becomes one sample. Its label is the
    ``correct_number`` field of the first metadata record. Folders missing
    either file are skipped.
    """

    def __init__(self, data_dir, transform):
        """
        :param data_dir: root of the cloned dataset repository (samples live in ``data/``).
        :param transform: callable applied to each PIL image in ``__getitem__``.
        """
        # Refresh the local clone so newly uploaded samples are included.
        repo.git_pull()
        self.data_dir = os.path.join(data_dir, 'data')
        self.transform = transform
        self.images = []
        self.numbers = []
        # Sorted so the sample order does not depend on the filesystem.
        with os.scandir(self.data_dir) as entries:
            folders = sorted(entry.name for entry in entries if entry.is_dir())
        for name in folders:
            folder = os.path.join(self.data_dir, name)
            metadata_path = os.path.join(folder, 'metadata.jsonl')
            image_path = os.path.join(folder, 'image.png')
            if not (os.path.exists(image_path) and os.path.exists(metadata_path)):
                continue
            # Image.open is lazy and keeps the file open; copy() loads the pixels
            # into memory so the handle can be closed right away instead of
            # leaking one open file per sample.
            with Image.open(image_path) as img:
                self.images.append(img.copy())
            metadata = read_json_lines(metadata_path)
            self.numbers.append(metadata[0]['correct_number'])
        assert len(self.images)==len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transform(image), label)`` for sample ``idx``."""
        img, label = self.images[idx], self.numbers[idx]
        img = self.transform(img)
        return img, label
class MNISTCorrupted_By_Digit(Dataset):
    """MNIST-C test samples of one digit, pooled across all corruption folders.

    :param transform: optional callable applied to a PIL version of each image.
    :param digit: label to keep.
    :param limit: maximum number of samples kept (fewer if not enough exist).
    :param corrupted_dir: directory with one sub-folder per corruption, each
        holding ``test_images.npy`` and ``test_labels.npy``.
    """

    def __init__(self, transform, digit, limit=300, corrupted_dir="./mnist_c"):
        self.transform = transform
        self.digit = digit
        # Sorted so the same `limit` samples are selected on every run and
        # filesystem. Plain files (e.g. a README) are not corruption folders.
        with os.scandir(corrupted_dir) as entries:
            folders = sorted(entry.name for entry in entries if entry.is_dir())
        images = [np.load(os.path.join(corrupted_dir, name, 'test_images.npy')) for name in folders]
        labels = [np.load(os.path.join(corrupted_dir, name, 'test_labels.npy')) for name in folders]
        self.data = np.vstack(images)
        self.labels = np.hstack(labels)
        assert (self.data.shape[0] == self.labels.shape[0])
        mask = self.labels == self.digit
        data_masked = self.data[mask]
        # Never ask for more samples than exist for this digit.
        limit = min(limit, data_masked.shape[0])
        self.data_for_use = data_masked[:limit]
        self.labels_for_use = self.labels[mask][:limit]
        assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0])

    def __len__(self):
        return len(self.data_for_use)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data_for_use[idx]
        label = self.labels_for_use[idx]
        if self.transform:
            # Convert to PIL first so the standard PIL-based transforms apply.
            image_pil = torchvision.transforms.ToPILImage()(image)
            image = self.transform(image_pil)
        return image, label
class MNISTCorrupted(Dataset):
    """MNIST-C evaluation set: the leading test samples of every corruption folder.

    :param transform: optional callable applied to a PIL version of each image.
    :param limit_per_corruption: number of samples taken from the start of each
        corruption's arrays (300 by default, as before).
    :param corrupted_dir: directory with one sub-folder per corruption, each
        holding ``test_images.npy`` and ``test_labels.npy``.
    """

    def __init__(self, transform, limit_per_corruption=300, corrupted_dir="./mnist_c"):
        self.transform = transform
        # Sorted so samples are concatenated in the same order on every
        # filesystem. Plain files in the directory are skipped.
        with os.scandir(corrupted_dir) as entries:
            folders = sorted(entry.name for entry in entries if entry.is_dir())
        images = [np.load(os.path.join(corrupted_dir, name, 'test_images.npy'))[:limit_per_corruption]
                  for name in folders]
        labels = [np.load(os.path.join(corrupted_dir, name, 'test_labels.npy'))[:limit_per_corruption]
                  for name in folders]
        self.data = np.vstack(images)
        self.labels = np.hstack(labels)
        assert (self.data.shape[0] == self.labels.shape[0])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is transformed when a transform is set."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            # Convert to PIL first so the standard PIL-based transforms apply.
            image_pil = torchvision.transforms.ToPILImage()(image)
            image = self.transform(image_pil)
        return image, label
# Preprocessing shared by training and evaluation: PIL image -> tensor, then
# standardised with (0.1307, 0.3081), presumably the usual MNIST pixel mean/std.
TRAIN_TRANSFORM = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)
# Unshuffled evaluation loader over MNISTCorrupted, consumed by test().
test_loader = torch.utils.data.DataLoader(
    dataset=MNISTCorrupted(TRAIN_TRANSFORM),
    batch_size=batch_size_test,
    shuffle=False,
)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small two-convolution CNN for 28x28 single-channel digit images.

    ``forward`` returns per-class log-probabilities of shape ``(batch, 10)``,
    suitable for ``F.nll_loss``. Layer attribute names are part of the
    ``model.pth`` checkpoint keys and must stay stable.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 28x28 -conv5-> 24x24 -pool-> 12x12 -conv5-> 8x8 -pool-> 4x4, so
        # 20 channels * 4 * 4 = 320 features feed fc1.
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Explicit class dimension: implicit-dim log_softmax is deprecated and
        # emits a warning on every call (for 2-D input it picked dim=1 anyway).
        return F.log_softmax(x, dim=1)
def train(epochs, network, optimizer, train_loader, *, log_every=None,
          model_path='model.pth', optimizer_path='optimizer.pth'):
    """Train ``network`` with NLL loss and checkpoint it to disk.

    :param epochs: number of passes over ``train_loader``.
    :param network: model returning log-probabilities.
    :param optimizer: optimizer over ``network``'s parameters.
    :param train_loader: iterable of ``(data, target)`` batches.
    :param log_every: print, record the loss and checkpoint every this many
        batches; defaults to the module-level ``log_interval``.
    :param model_path: where the model ``state_dict`` is saved.
    :param optimizer_path: where the optimizer ``state_dict`` is saved.
    :return: the losses recorded at each logging step.
    """
    if log_every is None:
        log_every = log_interval
    train_losses = []
    network.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_every == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                train_losses.append(loss.item())
                torch.save(network.state_dict(), model_path)
                torch.save(optimizer.state_dict(), optimizer_path)
    # Always persist the final state. Otherwise updates made after the last
    # logging batch are lost whenever the checkpoint is reloaded from disk.
    torch.save(network.state_dict(), model_path)
    torch.save(optimizer.state_dict(), optimizer_path)
    return train_losses
def test(*, model=None, loader=None):
    """Evaluate a model on the MNIST-C evaluation loader.

    :param model: network returning log-probabilities; defaults to the
        module-level ``network``.
    :param loader: DataLoader of ``(images, labels)``; defaults to the
        module-level ``test_loader``.
    :return: ``(test_metric, acc)``, a markdown summary line and the accuracy
        in percent (float).
    """
    model = network if model is None else model
    loader = test_loader if loader is None else loader
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in loader:
            output = model(data)
            # Summed per batch (replaces the deprecated size_average=False);
            # averaged over the whole dataset below.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += int(pred.eq(target.view_as(pred)).sum())
    num_samples = len(loader.dataset)
    test_loss /= num_samples
    acc = 100. * correct / num_samples
    test_metric = '〽Current test metric -> Avg. loss: `{:.4f}`, Accuracy: `{:.0f}%`\n'.format(
        test_loss, acc)
    return test_metric, acc
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)
# Checkpoint paths. They must match the file names train() writes. The old
# 'optmizer.pth' typo meant the optimizer state was never restored.
model_state_dict = 'model.pth'
optimizer_state_dict = 'optimizer.pth'
if os.path.exists(model_state_dict):
    network_state_dict = torch.load(model_state_dict, map_location='cpu')
    network.load_state_dict(network_state_dict)
if os.path.exists(optimizer_state_dict):
    optimizer.load_state_dict(torch.load(optimizer_state_dict, map_location='cpu'))
# Training is not run at start-up. flag() triggers train_and_test() once
# enough adversarial samples have been collected.
def image_classifier(inp, *, model=None, transform=None):
    """
    It takes an image as input and returns a dictionary of class labels and their corresponding
    confidence scores.
    :param inp: the image to be classified
    :param model: network returning log-probabilities; defaults to the module-level ``network``
    :param transform: preprocessing applied to ``inp``; defaults to ``TRAIN_TRANSFORM`` so
        inference sees the same normalisation as training and testing
    :return: A dictionary of the class index and the confidence value, ordered from the
        most to the least likely class.
    """
    model = network if model is None else model
    transform = TRAIN_TRANSFORM if transform is None else transform
    input_image = transform(inp).unsqueeze(0)
    # Dropout must be off for inference. A freshly built network is in train
    # mode, so set eval explicitly, then restore whatever mode it was in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            log_probs = model(input_image)[0]
    finally:
        model.train(was_training)
    prediction = torch.softmax(log_probs, dim=0)
    sorted_prediction = torch.sort(prediction, descending=True)
    return dict(zip(sorted_prediction.indices.tolist(), sorted_prediction.values.tolist()))
def _digit_accuracy(model, digit):
    """Return the accuracy (percent) of ``model`` on the MNIST-C samples of one digit."""
    dataset = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM, digit)
    loader = torch.utils.data.DataLoader(dataset, batch_size=len(dataset), shuffle=False)
    # A single batch holds every sample. DataLoader iterators have no .next() method.
    images, labels = next(iter(loader))
    model.eval()
    with torch.no_grad():
        pred = model(images).argmax(dim=1)
    correct = int(pred.eq(labels.view_as(pred)).sum())
    return 100. * correct / len(images)


def train_and_test():
    """Retrain on the adversarial dataset, evaluate, and append to the metrics history.

    Appends the overall MNIST-C accuracy under ``'all'`` and the per-digit
    accuracy under ``'0'`` ... ``'9'`` in METRIC_PATH.

    :return: the markdown summary line produced by ``test()``.
    """
    train_dataset = MNISTAdversarial_Dataset('./data_mnist', TRAIN_TRANSFORM)
    # Training batches use the training batch size (previously batch_size_test).
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True)
    train(n_epochs, network, optimizer, train_loader)
    test_metric, test_acc = test()

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    # setdefault+append fixes the old precedence bug, where `a if c else [] + [x]`
    # never appended to an existing 'all' list. It also avoids KeyError for
    # digit keys missing from an existing file.
    metric_dict.setdefault('all', []).append(test_acc)
    for digit in range(10):
        metric_dict.setdefault(str(digit), []).append(_digit_accuracy(network, digit))
    dump_json(thing=metric_dict, file=METRIC_PATH)
    return test_metric
def _upload_to_dataset(local_path, path_in_repo):
    """Upload one local file into the ``chrisjay/<HF_DATASET>`` dataset repository."""
    return upload_file(path_or_fileobj=local_path,
                       path_in_repo=path_in_repo,
                       repo_id=f'chrisjay/{HF_DATASET}',
                       repo_type='dataset',
                       token=HF_TOKEN)


def flag(input_image, correct_result, adversarial_number):
    """Save a user-flagged sample, upload it, and retrain every TRAIN_CUTOFF samples.

    :param input_image: the PIL image drawn on the canvas.
    :param correct_result: the digit the user says is correct.
    :param adversarial_number: per-session count of samples flagged so far (may be None).
    :return: ``(html_message, updated_adversarial_number)``.
    :raises Exception: if the image cannot be written to disk.
    """
    # The old `0 if None else x` always evaluated to x; test the argument itself.
    adversarial_number = 0 if adversarial_number is None else adversarial_number
    # Without an image and a label there is nothing valid to store. A None label
    # would also break training (nll_loss targets) later.
    if input_image is None or correct_result is None:
        output = '<div> Please draw a digit and select the correct number before flagging. </div>'
        return output, adversarial_number

    metadata_name = get_unique_name()
    SAVE_FILE_DIR = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(SAVE_FILE_DIR, exist_ok=True)
    image_output_filename = os.path.join(SAVE_FILE_DIR, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        raise Exception("Had issues saving PIL image to file") from err

    # Write the metadata file next to the image.
    json_file_path = os.path.join(SAVE_FILE_DIR, 'metadata.jsonl')
    metadata = {'id': metadata_name, 'file_name': 'image.png',
                'correct_number': correct_result}
    dump_json(metadata, json_file_path)

    # Hub paths are always '/'-separated, whatever the local OS uses.
    _upload_to_dataset(image_output_filename, f"{REPOSITORY_DIR}/{metadata_name}/image.png")
    _upload_to_dataset(json_file_path, f"{REPOSITORY_DIR}/{metadata_name}/metadata.jsonl")

    adversarial_number += 1
    output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data. </div>'
    repo.git_pull()
    with os.scandir("./data_mnist/data") as entries:
        length_of_dataset = sum(1 for _ in entries)
    # Retrain whenever the dataset size reaches a multiple of TRAIN_CUTOFF.
    # The metric string train_and_test() returns is not shown in the current UI.
    if length_of_dataset % TRAIN_CUTOFF == 0:
        train_and_test()
        output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
    return output, adversarial_number
def get_number_dict(DATA_DIR):
    """Count the adversarial samples per digit.

    :param DATA_DIR: folder with one sub-folder per sample, each expected to
        hold a ``metadata.jsonl`` whose first record has ``correct_number``.
    :return: ``(digits, counts)`` as parallel lists, in first-seen order.
    """
    numbers = []
    with os.scandir(DATA_DIR) as entries:
        for entry in entries:
            metadata_path = os.path.join(DATA_DIR, entry.name, 'metadata.jsonl')
            # flag() uploads the image before its metadata, so a sample can
            # briefly lack metadata. Skip it instead of failing the dashboard.
            if not os.path.exists(metadata_path):
                continue
            numbers.append(read_json_lines(metadata_path)[0]['correct_number'])
    numbers_count = Counter(numbers)
    return list(numbers_count.keys()), list(numbers_count.values())
def get_statistics():
    """Reload the latest checkpoints and build the dashboard outputs.

    :return: ``(plt_digits, fig_d, done_html, STATS_EXPLANATION_)``: the bar
        chart of adversarial samples per digit, a line plot of per-digit MNIST-C
        accuracy per train step, a status HTML snippet, and the explanation text.
    """
    model_state_dict = 'model.pth'
    # Must match the file name written by train() (was the typo 'optmizer.pth').
    optimizer_state_dict = 'optimizer.pth'
    if os.path.exists(model_state_dict):
        network.load_state_dict(torch.load(model_state_dict, map_location='cpu'))
    if os.path.exists(optimizer_state_dict):
        optimizer.load_state_dict(torch.load(optimizer_state_dict, map_location='cpu'))

    repo.git_pull()
    DATA_DIR = './data_mnist/data'
    numbers_count_keys, numbers_count_values = get_number_dict(DATA_DIR)
    STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples=sum(numbers_count_values))
    plt_digits = plot_bar(numbers_count_values, numbers_count_keys,
                          'Number of adversarial samples', "Digit",
                          "Distribution of adversarial samples over digits")

    # NOTE(review): `plt` is not imported in this file; presumably it comes from
    # `from utils import *`. Confirm.
    fig_d, ax_d = plt.subplots(tight_layout=True)
    # Read-only: METRIC_PATH is owned by train_and_test(). Rewriting it here was
    # pointless and could clobber a concurrent update.
    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    plotted = False
    for digit in range(10):
        history = metric_dict.get(str(digit))
        if not history:
            continue
        steps = list(range(1, len(history) + 1))
        ax_d.plot(steps, history, label=str(digit))
        plotted = True
    if plotted:
        fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy', title="Test Accuracy over digits per train step")
    done_html = """<div style="color: green">
<p> ✅ Statistics loaded successfully!</p>
</div>
"""
    return plt_digits, fig_d, done_html, STATS_EXPLANATION_
def main():
    """Build the Gradio Blocks UI (classify/flag tab and dashboard tab) and launch it."""
    block = gr.Blocks(css=BLOCK_CSS)
    with block:
        gr.Markdown(TITLE)
        gr.Markdown(description)
        with gr.Tabs():
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                with gr.Row():
                    # Single-channel ('L') 28x28 PIL image, matching the model's input size.
                    image_input = gr.inputs.Image(source="canvas", shape=(28, 28), invert_colors=True, image_mode="L", type="pil")
                    label_output = gr.outputs.Label(num_top_classes=2)
                submit = gr.Button("Submit")
                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(choices=list(range(10)), type='value', default=None, label="What was the correct prediction?")
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                # Per-session count of flagged samples, passed through flag() and back.
                adversarial_number = gr.Variable(value=0)
                submit.click(image_classifier, inputs=[image_input], outputs=[label_output])
                flag_btn.click(flag, inputs=[image_input, number_dropdown, adversarial_number], outputs=[output_result, adversarial_number])
            with gr.TabItem('Dashboard') as dashboard:
                notification = gr.HTML("""<div style="color: green">
<p> ⌛ Creating statistics... </p>
</div>
""")
                stats = gr.Markdown()
                stat_adv_image = gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results = gr.Plot(type="matplotlib")
                # Recompute the statistics each time the dashboard tab is selected.
                dashboard.select(get_statistics, inputs=[], outputs=[stat_adv_image, test_results, notification, stats])
    block.launch()


if __name__ == "__main__":
    main()