mnist-adversarial / utils.py
chrisjay's picture
fix to dashboard not loading
e4a62fe
import json
import hashlib
import random
import string
import warnings
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# User-facing text snippets. Most of them use Markdown syntax (headings,
# links, bold via double underscores); presumably they are rendered by the
# UI layer, which is not visible in this module.
# ---------------------------------------------------------------------------

# Page title; the leading "# " is Markdown heading syntax.
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
# Short introduction to dynamic adversarial data collection (DADC).
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""
# Step-by-step instructions. Contains a `{num_samples}` placeholder --
# presumably filled in by the caller (e.g. via str.format); TODO confirm.
WHAT_TO_DO="""
### What to do:
1. Draw any number from 0-9. The model will automatically try to predict it after drawing.
2. If the model misclassifies it, Flag that example.
3. This will add your (adversarial) example to a dataset on which the model will be trained later.
4. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""
# Prompt asking the user to flag a wrong / low-confidence prediction; links to
# the dataset where flagged instances are saved.
MODEL_IS_WRONG = """
---
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically.
"""
# HTML-wrapped metric line with hard-coded numbers. Judging by the name it is
# a fallback shown when no real test metric is available -- TODO confirm.
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
# Dashboard caption. Contains a `{TEST_PER_SAMPLE}` placeholder -- presumably
# filled in by the caller; TODO confirm.
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
# Caption for the combined (all digits) test-accuracy view.
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."
# Caption for the sample-distribution view. Contains a `{num_adv_samples}`
# placeholder -- presumably filled in by the caller; TODO confirm.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name():
    """Return a random 32-character string of ASCII letters and digits.

    The characters are drawn with the module-level ``random`` generator, so
    the result is reproducible under ``random.seed`` and is not suitable as a
    security token.
    """
    alphabet = string.ascii_letters + string.digits
    # One independent draw per output character.
    return ''.join(random.choice(alphabet) for _ in range(32))
def read_json(file):
    """Load and return the JSON document stored in ``file``.

    The file is opened as UTF-8 text. Errors from opening the file or from
    parsing its content propagate to the caller.
    """
    with open(file, encoding="utf8") as handle:
        parsed = json.load(handle)
    return parsed
def read_json_lines(file):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line of ``file`` (opened as UTF-8 text) is parsed as one
    JSON value. Blank or whitespace-only lines are skipped, so a stray empty
    line (e.g. a doubled trailing newline) no longer discards the whole file.

    This is a best-effort reader: if the file cannot be opened or any
    non-blank line is not valid JSON, the error is reported through
    ``warnings.warn`` and ``None`` is returned instead of raising.

    Returns:
        A list with one parsed value per non-blank line (empty list for an
        empty file), or ``None`` on failure.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            # Iterate the file lazily instead of materialising readlines();
            # lines that are empty after stripping carry no record.
            return [json.loads(line) for line in f if line.strip()]
    except Exception as err:
        # Deliberately broad: callers rely on getting None rather than an
        # exception when the file is missing or malformed.
        warnings.warn(f"{err}")
        return None
def json_dump(thing):
    """Serialize ``thing`` to a compact, canonical JSON string.

    Keys are sorted and no whitespace is emitted, so equal inputs always
    produce the same text. Non-ASCII characters are written as-is rather
    than as ``\\uXXXX`` escapes.
    """
    compact_separators = (',', ':')  # no space after "," or ":"
    return json.dumps(
        thing,
        separators=compact_separators,
        indent=None,
        sort_keys=True,
        ensure_ascii=False,
    )
def get_hash(thing): # stable-hashing
    """Return the MD5 hex digest of the canonical JSON form of ``thing``.

    The input is serialized with ``json_dump`` and UTF-8 encoded before
    hashing, so the digest depends only on that serialized text.
    """
    canonical_bytes = json_dump(thing).encode('utf-8')
    digest = hashlib.md5(canonical_bytes)
    return str(digest.hexdigest())
def dump_json(thing,file):
    """Write ``thing`` as JSON to ``file``, replacing any existing content.

    The file is opened as UTF-8 text and serialized with ``json.dump``'s
    default settings. Returns ``None``.
    """
    with open(file, mode='w+', encoding="utf8") as sink:
        json.dump(thing, fp=sink)
def plot_bar(value,name,x_name,y_name,title,set_yticks=False,set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib figure.

    Args:
        value: Bar lengths (passed to ``barh`` as the widths).
        name: Bar positions on the y-axis (passed to ``barh`` as ``y``).
        x_name: Label for the x-axis.
        y_name: Label for the y-axis.
        title: Title of the axes.
        set_yticks: If true, put a y tick on every integer from ``min(name)``
            to ``max(name)``; ``name`` must then contain integers.
        set_xticks: If true, put x ticks on that same integer range.

    Returns:
        The figure that holds the chart.
    """
    figure, axes = plt.subplots(tight_layout=True)
    axes.set_xlabel(x_name)
    axes.set_ylabel(y_name)
    axes.set_title(title)

    if set_yticks or set_xticks:
        # Both axes take their ticks from `name`.
        # NOTE(review): with barh the x-axis shows `value`, so deriving the
        # x ticks from `name` may be unintended -- confirm against callers.
        every_integer = range(min(name), max(name) + 1)
        if set_yticks:
            axes.set_yticks(every_integer)
        if set_xticks:
            axes.set_xticks(every_integer)

    axes.barh(name, value)
    return figure