mnist-adversarial / utils.py
chrisjay's picture
work on trainin and dashboard statistics
866cafe
raw history blame
No virus
2.65 kB
import json
import hashlib
import random
import string
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# Text snippets (Markdown / HTML) exported by this module. They are plain
# strings; any ``{placeholder}`` fields are left for the caller to fill in
# (e.g. with ``str.format``).
# ---------------------------------------------------------------------------

# Markdown H1 heading.
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"

# Short introduction to dynamic adversarial data collection (DADC).
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""

# Step-by-step usage instructions. Placeholder: ``{num_samples}``.
WHAT_TO_DO = """
### What to do:
1. Draw a number from 0-9.
2. Click `Submit` and see the model's prediction.
3. If the model misclassifies it, Flag that example.
4. This will add your (adversarial) example to a dataset on which the model will be trained later.
5. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""

# Prompt asking the user to flag a wrong prediction (Markdown rule + quote).
MODEL_IS_WRONG = """
---
> Did the model get it wrong? Choose the correct prediction below and flag it. When you flag it, the instance is saved to our dataset and the model is trained on it.
"""

# Placeholder metric text.
# NOTE(review): "30/1000 (30%)" is internally inconsistent (30/1000 is 3%);
# left unchanged because the intended placeholder value is unknown.
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"

# Explains which dataset is used for the out-of-distribution evaluation.
DASHBOARD_EXPLANATION = "To see the effect of our model on out-of-distribution data, we test it on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."

# Caption for the adversarial-sample statistics. Placeholder: ``{num_adv_samples}``.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name():
    """Return a random 32-character string of ASCII letters and digits.

    Characters are drawn one at a time with ``random.choice``, so the result
    is reproducible under ``random.seed`` and is not a cryptographic token.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choice(alphabet) for _ in range(32))
def read_json(file):
    """Read ``file`` as UTF-8 text and return the parsed JSON document.

    Args:
        file: Path (or anything ``open`` accepts) of the JSON file.

    Returns:
        The Python object produced by ``json.load``.
    """
    with open(file, mode='r', encoding="utf8") as handle:
        parsed = json.load(handle)
    return parsed
def read_json_lines(file):
    """Parse a JSON Lines file (one JSON document per line) into a list.

    The file is read as UTF-8 and streamed line by line. Blank or
    whitespace-only lines are skipped, so a stray empty line does not abort
    the whole load with a ``json.JSONDecodeError``.

    Args:
        file: Path (or anything ``open`` accepts) of the JSON Lines file.

    Returns:
        A list with one parsed object per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file, 'r', encoding="utf8") as f:
        # Iterate the handle directly rather than materializing readlines().
        return [json.loads(line) for line in f if line.strip()]
def json_dump(thing):
    """Serialize ``thing`` to a compact JSON string with sorted keys.

    Keys are sorted and no whitespace is emitted, so equal inputs always
    produce the same text. Non-ASCII characters are written as-is rather
    than escaped.
    """
    compact_options = {
        "ensure_ascii": False,
        "sort_keys": True,
        "indent": None,
        "separators": (',', ':'),
    }
    return json.dumps(thing, **compact_options)
def get_hash(thing):  # stable-hashing
    """Return the MD5 hex digest of ``thing``'s ``json_dump`` serialization.

    The input is serialized with ``json_dump`` and UTF-8 encoded before
    hashing.
    """
    canonical_text = json_dump(thing)
    digest = hashlib.md5(canonical_text.encode('utf-8'))
    return str(digest.hexdigest())
def dump_json(thing, file):
    """Write ``thing`` to ``file`` as JSON, replacing any existing content.

    The file is opened in ``'w+'`` mode with UTF-8 encoding and serialized
    with ``json.dump`` using its default settings.

    Args:
        thing: JSON-serializable object to store.
        file: Destination path (or anything ``open`` accepts).
    """
    with open(file, mode='w+', encoding="utf8") as out:
        json.dump(thing, fp=out)
def plot_bar(value, name, x_name, y_name, title):
    """Draw a horizontal bar chart and return the figure that holds it.

    Args:
        value: Bar lengths, passed to ``Axes.barh`` as the widths.
        name: Bar positions/labels, passed to ``Axes.barh`` as ``y``.
        x_name: Label for the x axis.
        y_name: Label for the y axis.
        title: Title shown above the axes.

    Returns:
        The figure created by ``plt.subplots`` (10x4 inches, tight layout).
    """
    fig, ax = plt.subplots(figsize=(10, 4), tight_layout=True)
    ax.set_xlabel(x_name)
    ax.set_ylabel(y_name)
    ax.set_title(title)
    ax.barh(name, value)
    # NOTE(review): the figure is created through pyplot and never closed
    # here; if this is called repeatedly in a long-running process, confirm
    # that the caller closes it (e.g. ``plt.close(fig)``).
    return fig