mnist-adversarial / utils.py
chrisjay's picture
fix to dashboard not loading
e4a62fe
raw
history blame contribute delete
No virus
3.26 kB
import json
import hashlib
import random
import string
import warnings
import matplotlib.pyplot as plt
# ---------------------------------------------------------------------------
# User-facing text constants. Most contain Markdown syntax (headings, bold,
# links); presumably they are rendered by the UI layer that imports this
# module -- confirm against the caller.
# ---------------------------------------------------------------------------

# Markdown-style top-level heading text (note the leading "# ").
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
# Introductory blurb on dynamic adversarial data collection (DADC).
# NOTE(review): lower-case name, unlike the other module-level text constants.
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""
# Step-by-step instructions. Contains a `{num_samples}` placeholder, so the
# caller presumably fills it with str.format(num_samples=...) before display.
WHAT_TO_DO="""
### What to do:
1. Draw any number from 0-9. The model will automatically try to predict it after drawing.
2. If the model misclassifies it, Flag that example.
3. This will add your (adversarial) example to a dataset on which the model will be trained later.
4. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""
# Prompt inviting the user to flag a wrong or low-confidence prediction;
# links to the dataset where flagged instances are saved.
MODEL_IS_WRONG = """
---
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically.
"""
# HTML snippet with hard-coded metric figures; going by the name, a fallback
# shown when no real test metric is available -- TODO confirm.
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"
# Dashboard caption. Contains a `{TEST_PER_SAMPLE}` placeholder to be filled
# in by the caller.
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
# Caption for the combined (all digits) out-of-distribution accuracy view.
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."
# Caption for the adversarial-sample statistics. Contains a
# `{num_adv_samples}` placeholder to be filled in by the caller.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name():
    """Return a random 32-character string of ASCII letters and digits.

    The characters are drawn with the module-level ``random`` generator, so
    the result is reproducible under ``random.seed`` and is not suitable as a
    security token. Uniqueness is probabilistic, not guaranteed.
    """
    alphabet = string.ascii_letters + string.digits
    picked = []
    for _ in range(32):
        picked.append(random.choice(alphabet))
    return ''.join(picked)
def read_json(file):
    """Load and return the JSON document stored at path ``file``.

    The file is read as UTF-8. Errors from opening the file or from parsing
    its contents propagate to the caller.
    """
    with open(file, 'r', encoding="utf8") as handle:
        raw_text = handle.read()
        return json.loads(raw_text)
def read_json_lines(file):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line of the UTF-8 file at path ``file`` is parsed as one
    JSON value. Blank or whitespace-only lines are skipped, so a stray empty
    line no longer causes the whole file to be discarded.

    This is a best-effort reader: if the file cannot be opened or any line
    fails to parse, a warning carrying the error message is emitted and
    ``None`` is returned instead of raising.

    Returns:
        A list of parsed values (empty for an empty file), or ``None`` on
        any error.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            # Iterate the file lazily instead of materialising every line
            # with readlines() first.
            return [json.loads(line) for line in f if line.strip()]
    except Exception as err:
        # Deliberately broad: callers rely on the warn-and-return-None
        # contract rather than on specific exception types.
        warnings.warn(f"{err}")
        return None
def json_dump(thing):
    """Serialise ``thing`` to a compact JSON string with sorted keys.

    Keys are sorted, no indentation or padding whitespace is emitted, and
    non-ASCII characters are written as-is rather than escaped, so equal
    inputs always produce the same text.
    """
    compact_options = {
        "ensure_ascii": False,
        "sort_keys": True,
        "indent": None,
        "separators": (',', ':'),
    }
    return json.dumps(thing, **compact_options)
def get_hash(thing):  # stable-hashing
    """Return the MD5 hex digest of the ``json_dump`` text for ``thing``.

    The value is first serialised with ``json_dump`` and the resulting text
    is UTF-8 encoded before hashing.
    """
    canonical_bytes = json_dump(thing).encode('utf-8')
    digest = hashlib.md5(canonical_bytes)
    return str(digest.hexdigest())
def dump_json(thing, file):
    """Write ``thing`` as JSON to the path ``file``, replacing its contents.

    The file is opened as UTF-8 and truncated before writing. ``json.dump``
    is called with its default options. Returns ``None``.
    """
    with open(file, mode='w+', encoding="utf8") as sink:
        json.dump(thing, fp=sink)
def plot_bar(value,name,x_name,y_name,title,set_yticks=False,set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib figure.

    Args:
        value: Bar lengths, passed to ``barh`` as the widths.
        name: Bar positions or labels, passed to ``barh`` as the y values.
        x_name: Label for the x-axis.
        y_name: Label for the y-axis.
        title: Title of the plot.
        set_yticks: If true, put a y tick on every integer from ``min(name)``
            to ``max(name)`` inclusive (requires integer ``name`` values).
        set_xticks: If true, apply the same integer tick range to the x-axis.

    Returns:
        The figure that owns the axes the bars were drawn on.
    """
    _, axes = plt.subplots(tight_layout=True)
    axes.set(xlabel=x_name, ylabel=y_name, title=title)

    def _integer_ticks():
        # One tick per integer across the span of `name`.
        return range(min(name), max(name) + 1, 1)

    if set_yticks:
        axes.set_yticks(_integer_ticks())
    if set_xticks:
        # NOTE(review): these ticks come from `name`, but with barh the
        # x-axis carries `value` -- confirm this is intended.
        axes.set_xticks(_integer_ticks())

    axes.barh(name, value)
    return axes.figure