Spaces:
Runtime error
Runtime error
File size: 2,652 Bytes
f240072 866cafe f240072 866cafe f240072 866cafe 4db7f81 866cafe f240072 866cafe f240072 866cafe 35ee063 866cafe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import json
import hashlib
import random
import string
import matplotlib.pyplot as plt
# User-facing text for the MNIST adversarial demo. The strings use Markdown
# syntax (headings, backticks, links); presumably they are rendered by the UI
# layer, which is not visible in this file.

# Page heading.
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"

# Short introduction to dynamic adversarial data collection (DADC).
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""

# Step-by-step instructions. Contains the `{num_samples}` placeholder, which
# the caller is expected to fill (e.g. with `str.format`) before display.
WHAT_TO_DO = """
### What to do:
1. Draw a number from 0-9.
2. Click `Submit` and see the model's prediction.
3. If the model misclassifies it, Flag that example.
4. This will add your (adversarial) example to a dataset on which the model will be trained later.
5. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""

# Prompt asking the user to pick the correct label and flag the example.
MODEL_IS_WRONG = """
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved to our dataset and the model is trained on it.
"""

# Placeholder metric text; presumably shown until a real test metric is
# available (TODO confirm against the caller).
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"

# Explains which dataset the dashboard numbers come from.
DASHBOARD_EXPLANATION = "To see the effect of our model on out-of-distribution data, we test it on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."

# Caption for the sample-distribution plot. Contains the `{num_adv_samples}`
# placeholder, to be filled by the caller before display.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name():
    """Return a random 32-character string of ASCII letters and digits.

    Each character is drawn independently with ``random.choice``, so the
    result is random rather than guaranteed to be unique.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choice(alphabet) for _ in range(32))
def read_json(file):
    """Load and return the JSON document stored at path ``file``.

    The file is opened as UTF-8 text. Whatever ``json.load`` produces
    (dict, list, str, number, ...) is returned unchanged.
    """
    with open(file, mode='r', encoding="utf8") as handle:
        parsed = json.load(handle)
    return parsed
def read_json_lines(file):
    """Parse a JSON Lines file into a list of decoded objects.

    The file at path ``file`` is opened as UTF-8 text and every non-blank
    line is decoded with ``json.loads``. Blank or whitespace-only lines
    (for example an extra empty line at the end of the file) are skipped
    instead of raising.

    Returns:
        A list with one decoded value per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    with open(file, 'r', encoding="utf8") as f:
        # Iterate the file lazily rather than materialising every line first.
        return [json.loads(line) for line in f if line.strip()]
def json_dump(thing):
    """Serialize ``thing`` to a compact JSON string with sorted keys.

    Keys are sorted and no whitespace is emitted between tokens, so equal
    inputs give the same text. Non-ASCII characters are kept as-is rather
    than escaped.
    """
    options = {
        'ensure_ascii': False,
        'sort_keys': True,
        'indent': None,
        'separators': (',', ':'),
    }
    return json.dumps(thing, **options)
def get_hash(thing):  # stable-hashing
    """Return the MD5 hex digest of ``thing``'s ``json_dump`` serialization.

    The value is serialized with ``json_dump`` and encoded as UTF-8 before
    hashing.
    """
    canonical = json_dump(thing)
    digest = hashlib.md5(canonical.encode('utf-8'))
    return str(digest.hexdigest())
def dump_json(thing, file):
    """Write ``thing`` as JSON to path ``file``, replacing existing content.

    The file is opened as UTF-8 text and ``json.dump`` runs with its
    default settings. Nothing is returned.
    """
    with open(file, mode='w+', encoding="utf8") as out:
        json.dump(thing, fp=out)
def plot_bar(value, name, x_name, y_name, title):
    """Build a horizontal bar chart and return its matplotlib figure.

    Args:
        value: bar lengths, passed to ``barh`` as the widths.
        name: bar positions/labels, passed to ``barh`` as the y values.
        x_name: label for the x axis.
        y_name: label for the y axis.
        title: title shown above the axes.

    Returns:
        The figure that owns the newly created axes.
    """
    _, axes = plt.subplots(tight_layout=True)
    axes.set_xlabel(x_name)
    axes.set_ylabel(y_name)
    axes.set_title(title)
    axes.barh(name, value)
    return axes.figure