# Source: Hugging Face Space "chrisjay/mnist-adversarial" app file
# (commit 866cafe, "work on training and dashboard statistics").
import os
import torch
import gradio as gr
import torchvision
from PIL import Image
from utils import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from huggingface_hub import Repository, upload_file
from torch.utils.data import Dataset
import numpy as np
from collections import Counter
# --- Training hyper-parameters (values follow the classic MNIST tutorial) ---
n_epochs = 10
batch_size_train = 128
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10  # print/checkpoint every `log_interval` batches
random_seed = 1

# Retrain the model every time the dataset grows by TRAIN_CUTOFF samples.
TRAIN_CUTOFF = 5
# WHAT_TO_DO is a template imported from utils (via `from utils import *`);
# bake the cutoff into the displayed instructions.
WHAT_TO_DO=WHAT_TO_DO.format(num_samples=TRAIN_CUTOFF)
METRIC_PATH = './metrics.json'  # accuracy history, keyed 'all' and '0'..'9'
REPOSITORY_DIR = "data"         # folder for samples inside the dataset repo
LOCAL_DIR = 'data_local'        # local staging dir for flagged samples
os.makedirs(LOCAL_DIR,exist_ok=True)

HF_TOKEN = os.getenv("HF_TOKEN")  # write token for the dataset repo
HF_DATASET ="mnist-adversarial-dataset"
DATASET_REPO_URL = f"https://huggingface.co/datasets/chrisjay/{HF_DATASET}"
# Clone (or reuse) the adversarial-sample dataset repo, then sync it.
repo = Repository(
    local_dir="data_mnist", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)
repo.git_pull()

# Deterministic CPU-friendly setup (cudnn disabled for reproducibility).
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)
class MNISTAdversarial_Dataset(Dataset):
    """Adversarial MNIST samples collected in the HF dataset repo.

    Each sample lives in its own folder under ``<data_dir>/data`` containing
    an ``image.png`` and a ``metadata.jsonl`` whose first record holds the
    human-supplied ``correct_number`` label.
    """

    def __init__(self, data_dir, transform):
        repo.git_pull()  # make sure we train on the latest flagged samples
        self.data_dir = os.path.join(data_dir, 'data')
        self.transform = transform
        self.images = []
        self.numbers = []
        for entry in os.scandir(self.data_dir):
            folder = os.path.join(self.data_dir, entry.name)
            metadata_path = os.path.join(folder, 'metadata.jsonl')
            image_path = os.path.join(folder, 'image.png')
            # Skip incomplete sample folders (e.g. a partially pushed commit).
            if os.path.exists(image_path) and os.path.exists(metadata_path):
                # Fix: load eagerly and copy so the underlying file handle is
                # released; bare Image.open keeps every file lazily open.
                with Image.open(image_path) as img:
                    self.images.append(img.copy())
                metadata = read_json_lines(metadata_path)
                self.numbers.append(metadata[0]['correct_number'])
        assert len(self.images) == len(self.numbers), f"Length of images and numbers must be the same. Got {len(self.images)} for images and {len(self.numbers)} for numbers."

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img, label = self.images[idx], self.numbers[idx]
        return self.transform(img), label
class MNISTCorrupted_By_Digit(Dataset):
    """MNIST-C test images for a single digit, capped at `limit` samples."""

    def __init__(self, transform, digit, limit=30):
        self.transform = transform
        self.digit = digit
        corrupted_dir = "./mnist_c"
        corruption_names = [entry.name for entry in os.scandir(corrupted_dir)]
        self.data = np.vstack(
            [np.load(os.path.join(corrupted_dir, name, 'test_images.npy'))
             for name in corruption_names])
        self.labels = np.hstack(
            [np.load(os.path.join(corrupted_dir, name, 'test_labels.npy'))
             for name in corruption_names])
        assert (self.data.shape[0] == self.labels.shape[0])
        mask = self.labels == self.digit
        data_masked = self.data[mask]
        labels_masked = self.labels[mask]
        # Just to be on the safe side, never slice past the available samples.
        limit = min(limit, data_masked.shape[0])
        self.data_for_use = data_masked[:limit]
        self.labels_for_use = labels_masked[:limit]
        assert (self.data_for_use.shape[0] == self.labels_for_use.shape[0])

    def __len__(self):
        return len(self.data_for_use)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image = self.data_for_use[idx]
        label = self.labels_for_use[idx]
        if self.transform:
            # Default torchvision transforms expect PIL input, not ndarray.
            image = self.transform(torchvision.transforms.ToPILImage()(image))
        return image, label
class MNISTCorrupted(Dataset):
    """The full MNIST-C test set: every corruption's images stacked together."""

    def __init__(self, transform):
        self.transform = transform
        corrupted_dir = "./mnist_c"
        corruption_names = [entry.name for entry in os.scandir(corrupted_dir)]
        image_arrays = [
            np.load(os.path.join(corrupted_dir, name, 'test_images.npy'))
            for name in corruption_names]
        label_arrays = [
            np.load(os.path.join(corrupted_dir, name, 'test_labels.npy'))
            for name in corruption_names]
        self.data = np.vstack(image_arrays)
        self.labels = np.hstack(label_arrays)
        assert (self.data.shape[0] == self.labels.shape[0])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        image, label = self.data[idx], self.labels[idx]
        if self.transform:
            # Default torchvision transforms expect PIL input, not ndarray.
            image = self.transform(torchvision.transforms.ToPILImage()(image))
        return image, label
# Mean/std are the canonical MNIST normalization constants; the same
# transform is reused for both training and evaluation.
TRAIN_TRANSFORM = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        (0.1307,), (0.3081,))
])

# Original (clean) MNIST training loader, kept for reference:
'''
train_loader = torch.utils.data.DataLoader(
torchvision.datasets.MNIST('files/', train=True, download=True,
transform=TRAIN_TRANSFORM),
batch_size=batch_size_train, shuffle=True)
'''

# Evaluation always runs over the full corrupted MNIST-C test set.
test_loader = torch.utils.data.DataLoader(MNISTCorrupted(TRAIN_TRANSFORM),
                                          batch_size=batch_size_test, shuffle=False)
# Source: https://nextjournal.com/gkoehler/pytorch-mnist
class MNIST_Model(nn.Module):
    """Small CNN for MNIST: two conv layers + two fully-connected layers.

    Architecture from https://nextjournal.com/gkoehler/pytorch-mnist.
    Input: (N, 1, 28, 28) tensors; output: (N, 10) per-class log-probabilities.
    """

    def __init__(self):
        super(MNIST_Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)  # flatten: 20 channels x 4 x 4 spatial
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        # Fix: pass dim explicitly — the implicit-dim form of log_softmax is
        # deprecated and ambiguous for batched input; classes live on dim 1.
        return F.log_softmax(x, dim=1)
def train(epochs, network, optimizer, train_loader, log_interval=10):
    """Train `network` for `epochs` passes over `train_loader`.

    :param epochs: number of full passes over the data
    :param network: model producing log-probabilities (trained with NLL loss)
    :param optimizer: optimizer over `network.parameters()`
    :param train_loader: DataLoader yielding (data, target) batches
    :param log_interval: log/checkpoint every this many batches (generalized
        from the former module-level constant; the default matches its value)
    :return: list of logged loss values (previously collected but discarded)
    """
    train_losses = []
    network.train()  # enable dropout layers
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = network(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                train_losses.append(loss.item())
                # Checkpoint so the latest weights survive a restart.
                torch.save(network.state_dict(), 'model.pth')
                torch.save(optimizer.state_dict(), 'optimizer.pth')
    return train_losses
def test():
    """Evaluate the global `network` on the global MNIST-C `test_loader`.

    :return: (markdown metric string, accuracy percentage as a float)
    """
    test_losses = []
    network.eval()  # disable dropout for evaluation
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Fix: `size_average=False` is deprecated; reduction='sum' is the
            # modern equivalent (sum per-sample losses, average at the end).
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    acc = 100. * correct / len(test_loader.dataset)
    acc = acc.item()  # `correct` is a tensor; convert the ratio to a float
    test_metric = '〽Current test metric - Avg. loss: `{:.4f}`, Accuracy: `{}/{}` (`{:.0f}%`)\n'.format(
        test_loss, correct, len(test_loader.dataset), acc)
    return test_metric, acc
# Re-seed so weight initialization is deterministic across restarts.
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

network = MNIST_Model()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

# Resume from the latest checkpoint when one exists.
model_state_dict = 'model.pth'
# Fix: train() saves the optimizer as 'optimizer.pth'; the previous
# 'optmizer.pth' typo meant the optimizer state was never restored.
optimizer_state_dict = 'optimizer.pth'
if os.path.exists(model_state_dict):
    network_state_dict = torch.load(model_state_dict)
    network.load_state_dict(network_state_dict)
if os.path.exists(optimizer_state_dict):
    optimizer_state_dict = torch.load(optimizer_state_dict)
    optimizer.load_state_dict(optimizer_state_dict)

# Train
#train(n_epochs,network,optimizer)
def image_classifier(inp):
    """
    It takes an image as input and returns a dictionary of class labels and their corresponding
    confidence scores.

    :param inp: the image to be classified (PIL image from the canvas)
    :return: A dictionary of the class index and the confidence value.
    """
    # NOTE(review): training applies Normalize((0.1307,), (0.3081,)) but
    # inference only applies ToTensor — confirm whether inputs should be
    # normalized here as well.
    input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
    # Fix: put the model in eval mode so dropout is disabled and predictions
    # are deterministic (train() leaves the model in training mode).
    network.eval()
    with torch.no_grad():
        prediction = torch.nn.functional.softmax(network(input_image)[0], dim=0)
    sorted_prediction = torch.sort(prediction, descending=True)
    return {
        cls: conf
        for cls, conf in zip(sorted_prediction.indices.numpy().tolist(),
                             sorted_prediction.values.numpy().tolist())
    }
def train_and_test():
    """Fine-tune on the adversarial dataset, then evaluate on MNIST-C.

    Appends the overall accuracy (key 'all') and the per-digit accuracies
    (keys '0'..'9') for this train step to METRIC_PATH.

    :return: the markdown test-metric string from test()
    """
    train_dataset = MNISTAdversarial_Dataset('./data_mnist', TRAIN_TRANSFORM)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size_test, shuffle=True)
    train(n_epochs, network, optimizer, train_loader)

    test_metric, test_acc = test()

    metric_dict = read_json(METRIC_PATH) if os.path.exists(METRIC_PATH) else {}
    # Fix: the old `a if cond else [] + [x]` parsed as `a if cond else ([x])`,
    # so a new accuracy was never appended once history existed.
    metric_dict.setdefault('all', []).append(test_acc)

    for digit in range(10):
        data_per_digit = MNISTCorrupted_By_Digit(TRAIN_TRANSFORM, digit)
        dataloader_per_digit = torch.utils.data.DataLoader(
            data_per_digit, batch_size=len(data_per_digit), shuffle=False)
        # Fix: `iterator.next()` is Python 2; use the next() builtin.
        images_per_digit, labels_per_digit = next(iter(dataloader_per_digit))
        output = network(images_per_digit)
        pred = output.data.max(1, keepdim=True)[1]
        correct = pred.eq(labels_per_digit.data.view_as(pred)).sum()
        acc = (100. * correct / len(images_per_digit)).item()
        # setdefault also covers a metric file that lacks this digit's key.
        metric_dict.setdefault(str(digit), []).append(acc)

    dump_json(thing=metric_dict, file=METRIC_PATH)
    return test_metric
def flag(input_image, correct_result, adversarial_number):
    """Save a flagged (misclassified) sample and push it to the dataset repo.

    Retrains the model whenever the dataset size is a multiple of TRAIN_CUTOFF.

    :param input_image: PIL image the model got wrong
    :param correct_result: the digit the user says is correct
    :param adversarial_number: running count of samples flagged this session
    :return: (status HTML, test-metric HTML, updated adversarial_number)
    """
    # Fix: the original `0 if None else adversarial_number` always took the
    # else branch, so an unset counter was never reset to 0.
    adversarial_number = 0 if adversarial_number is None else adversarial_number

    metadata_name = get_unique_name()
    SAVE_FILE_DIR = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(SAVE_FILE_DIR, exist_ok=True)
    image_output_filename = os.path.join(SAVE_FILE_DIR, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        # Chain the original error instead of discarding it.
        raise Exception("Had issues saving PIL image to file") from err

    # Write metadata.jsonl next to the image.
    json_file_path = os.path.join(SAVE_FILE_DIR, 'metadata.jsonl')
    metadata = {'id': metadata_name, 'file_name': 'image.png',
                'correct_number': correct_result
                }
    dump_json(metadata, json_file_path)

    # Simply upload the image file and metadata using the hub's upload_file
    # Upload the image
    repo_image_path = os.path.join(REPOSITORY_DIR, metadata_name, 'image.png')
    _ = upload_file(path_or_fileobj=image_output_filename,
                    path_in_repo=repo_image_path,
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN
                    )
    # Upload the metadata
    repo_json_path = os.path.join(REPOSITORY_DIR, metadata_name, 'metadata.jsonl')
    _ = upload_file(path_or_fileobj=json_file_path,
                    path_in_repo=repo_json_path,
                    repo_id=f'chrisjay/{HF_DATASET}',
                    repo_type='dataset',
                    token=HF_TOKEN
                    )

    adversarial_number += 1
    output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data. </div>'

    repo.git_pull()  # sync so the dataset-size check sees the new sample
    length_of_dataset = len([f for f in os.scandir("./data_mnist/data")])
    test_metric = f"<html> {DEFAULT_TEST_METRIC} </html>"
    # Retrain every TRAIN_CUTOFF new samples.
    if length_of_dataset % TRAIN_CUTOFF == 0:
        test_metric_ = train_and_test()
        test_metric = f"<html> {test_metric_} </html>"
        output = f'<div> ✔ ({adversarial_number}) Successfully saved your adversarial data and trained the model on adversarial data! </div>'
    return output, test_metric, adversarial_number
def get_number_dict(DATA_DIR):
    """Count flagged samples per digit under DATA_DIR.

    :return: (list of digits seen, list of matching sample counts)
    """
    numbers = [
        read_json_lines(os.path.join(DATA_DIR, entry.name, 'metadata.jsonl'))[0]['correct_number']
        for entry in os.scandir(DATA_DIR)
    ]
    counts = Counter(numbers)
    digits = list(counts.keys())
    sample_counts = [counts[d] for d in digits]
    return digits, sample_counts
def get_statistics():
    """Build dashboard plots: sample histogram + per-digit accuracy curves.

    :return: (bar-chart figure, accuracy-lines figure, status HTML)
    """
    # Reload the latest checkpoint so the dashboard reflects current weights.
    model_state_dict = 'model.pth'
    # Fix: train() saves 'optimizer.pth'; the 'optmizer.pth' typo meant the
    # optimizer state was silently never reloaded here.
    optimizer_state_dict = 'optimizer.pth'
    if os.path.exists(model_state_dict):
        network_state_dict = torch.load(model_state_dict)
        network.load_state_dict(network_state_dict)
    if os.path.exists(optimizer_state_dict):
        optimizer_state_dict = torch.load(optimizer_state_dict)
        optimizer.load_state_dict(optimizer_state_dict)

    repo.git_pull()  # sync before counting samples
    DATA_DIR = './data_mnist/data'
    numbers_count_keys, numbers_count_values = get_number_dict(DATA_DIR)

    plt_digits = plot_bar(numbers_count_values, numbers_count_keys,
                          'Number of adversarial samples', "Digit",
                          f"Distribution of adversarial samples over digits")

    fig_d, ax_d = plt.subplots(figsize=(10, 4), tight_layout=True)
    if os.path.exists(METRIC_PATH):
        metric_dict = read_json(METRIC_PATH)
        for digit in range(10):
            try:
                history = metric_dict[str(digit)]
                steps = [step + 1 for step in range(len(history))]
                ax_d.plot(steps, history, label=str(digit))
            except Exception:
                # Digit has no recorded history yet — skip its curve.
                continue
        dump_json(thing=metric_dict, file=METRIC_PATH)
    else:
        metric_dict = {}
    fig_d.legend()
    ax_d.set(xlabel='Adversarial train steps', ylabel='MNIST_C Test Accuracy',
             title="Test Accuracy over digits per train step")
    done_html = """<div style="color: green">
    <p> ✅ Statistics loaded successfully!</p>
    </div>
    """
    return plt_digits, fig_d, done_html
def main():
    """Build and launch the Gradio app: an MNIST drawing tab plus a
    statistics Dashboard tab (uses the legacy gr.inputs/gr.outputs API)."""
    #block = gr.Blocks(css=BLOCK_CSS)
    block = gr.Blocks()
    with block:
        gr.Markdown(TITLE)
        gr.Markdown(description)
        with gr.Tabs():
            with gr.TabItem('MNIST'):
                gr.Markdown(WHAT_TO_DO)
                # Shows the current test metric; refreshed after each retrain.
                test_metric = gr.outputs.HTML(DEFAULT_TEST_METRIC)
                with gr.Row():
                    # Canvas is inverted so drawings match MNIST polarity
                    # (white digit on black background), 28x28 grayscale PIL.
                    image_input =gr.inputs.Image(source="canvas",shape=(28,28),invert_colors=True,image_mode="L",type="pil")
                    label_output = gr.outputs.Label(num_top_classes=10)
                submit = gr.Button("Submit")
                gr.Markdown(MODEL_IS_WRONG)
                number_dropdown = gr.Dropdown(choices=[i for i in range(10)],type='value',default=None,label="What was the correct prediction?")
                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()
                # Per-session counter of how many samples this user flagged.
                adversarial_number = gr.Variable(value=0)
                submit.click(image_classifier,inputs = [image_input],outputs=[label_output])
                flag_btn.click(flag,inputs=[image_input,number_dropdown,adversarial_number],outputs=[output_result,test_metric,adversarial_number])
            with gr.TabItem('Dashboard') as dashboard:
                notification = gr.HTML("""<div style="color: green">
                <p> ⌛ Creating statistics... </p>
                </div>
                """)
                _,numbers_count_values_ = get_number_dict('./data_mnist/data')
                STATS_EXPLANATION_ = STATS_EXPLANATION.format(num_adv_samples = sum(numbers_count_values_))
                gr.Markdown(STATS_EXPLANATION_)
                stat_adv_image =gr.Plot(type="matplotlib")
                gr.Markdown(DASHBOARD_EXPLANATION)
                test_results=gr.Plot(type="matplotlib")
                # Recompute statistics each time the Dashboard tab is opened.
                dashboard.select(get_statistics,inputs=[],outputs=[stat_adv_image,test_results,notification])
    block.launch()


if __name__ == "__main__":
    main()