File size: 3,259 Bytes
f240072
 
 
 
 
e4a62fe
866cafe
f240072
866cafe
 
 
 
 
 
 
d4894f5
 
 
 
866cafe
f240072
866cafe
2dc35ff
e467f01
866cafe
 
 
c4c6bd6
e4a62fe
866cafe
 
f240072
 
 
 
 
 
866cafe
 
 
 
f240072
e4a62fe
 
 
 
 
 
 
 
 
 
f240072
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866cafe
e4a62fe
35ee063
866cafe
 
 
e4a62fe
 
 
 
 
 
866cafe
 
e4a62fe
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85

import json
import hashlib
import random
import string
import warnings
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# User-facing text constants. The leading "#"/"###" markers, "__bold__" spans
# and [text](url) links suggest these are rendered as Markdown by the UI
# (NOTE(review): confirm against the code that consumes them).
# ---------------------------------------------------------------------------

# Page heading (a level-1 Markdown heading).
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
# Introductory paragraph explaining dynamic adversarial data collection.
# NOTE(review): lower-case name, unlike the other constants in this module.
description = """This project is about dynamic adversarial data collection (DADC). 
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly. 
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""
# Step-by-step instructions. Contains a `{num_samples}` placeholder that the
# caller is expected to fill in (presumably via str.format — confirm).
WHAT_TO_DO="""
### What to do:
1. Draw any number from 0-9. The model will automatically try to predict it after drawing.
2. If the model misclassifies it, Flag that example.
3. This will add your (adversarial) example to a dataset on which the model will be trained later.
4. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated. 
"""

# Prompt asking the user to pick the correct label and flag the example.
MODEL_IS_WRONG = """
--- 
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically.
"""
# HTML snippet with hard-coded metric values; presumably a fallback shown
# when no real test metric is available yet (TODO confirm with caller).
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"

# Dashboard captions. DASHBOARD_EXPLANATION contains a `{TEST_PER_SAMPLE}`
# placeholder to be filled in by the caller.
DASHBOARD_EXPLANATION="To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."
DASHBOARD_EXPLANATION_TEST="Test accuracy on out-of-distribution data for all numbers combined."

# Caption for the sample-distribution view; contains a `{num_adv_samples}`
# placeholder to be filled in by the caller.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."

def get_unique_name():
    """Return a random 32-character string of ASCII letters and digits.

    Characters are drawn with ``random.choice``, so the result is random
    rather than guaranteed unique, and it follows the global ``random`` state.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choice(alphabet) for _ in range(32))


def read_json(file):
    """Load and return the JSON document stored at path ``file``.

    The file is opened as UTF-8 text. Errors from ``open`` or the JSON
    parser propagate to the caller.
    """
    with open(file, mode='r', encoding="utf8") as handle:
        parsed = json.load(handle)
    return parsed

def read_json_lines(file):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line of ``file`` (opened as UTF-8 text) is parsed as one
    JSON value. Blank or whitespace-only lines are skipped, so a stray empty
    line no longer causes the whole file to be rejected.

    This is a best-effort reader: on any failure (missing file, malformed
    JSON, decoding error, ...) a warning carrying the error message is
    issued and ``None`` is returned instead of raising.

    Args:
        file: Path of the JSON Lines file to read.

    Returns:
        A list with one parsed value per non-blank line (empty list for an
        empty file), or ``None`` if reading or parsing failed.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            # Iterate the file lazily instead of materialising all lines
            # first; ``line.strip()`` filters out blank lines, which
            # ``json.loads`` would otherwise reject.
            return [json.loads(line) for line in f if line.strip()]
    except Exception as err:
        # Deliberately broad: callers rely on ``None`` rather than an
        # exception to signal an unreadable file.
        warnings.warn(f"{err}")
        return None


def json_dump(thing):
    """Serialize ``thing`` to a compact JSON string.

    Keys are sorted, no indentation or extra whitespace is emitted, and
    non-ASCII characters are written as-is rather than escaped.
    """
    compact_options = {
        "ensure_ascii": False,
        "sort_keys": True,
        "indent": None,
        "separators": (',', ':'),
    }
    return json.dumps(thing, **compact_options)

def get_hash(thing):
    """Return the MD5 hex digest of ``thing``'s ``json_dump`` serialization.

    The value is first serialized with ``json_dump`` and encoded as UTF-8,
    then hashed; the digest is returned as a hexadecimal string.
    """
    # stable-hashing
    serialized = json_dump(thing)
    digest = hashlib.md5(serialized.encode('utf-8'))
    return str(digest.hexdigest())


def dump_json(thing, file):
    """Write ``thing`` as JSON to the path ``file``.

    The file is opened in ``'w+'`` mode as UTF-8 text, so any existing
    content is truncated. Serialization uses ``json.dump`` defaults.
    """
    with open(file, mode='w+', encoding="utf8") as handle:
        json.dump(thing, handle)


def plot_bar(value, name, x_name, y_name, title, set_yticks=False, set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib figure.

    Args:
        value: Bar lengths, passed to ``Axes.barh`` as the widths.
        name: Bar positions along the y axis, passed to ``Axes.barh``.
            When either tick flag is set, ``min(name)`` and ``max(name)``
            must be usable as ``range`` bounds.
        x_name: Label for the x axis.
        y_name: Label for the y axis.
        title: Title of the plot.
        set_yticks: If true, place a y tick at every integer from
            ``min(name)`` to ``max(name)`` inclusive.
        set_xticks: If true, place x ticks at that same integer range.

    Returns:
        The figure that owns the axes the bars were drawn on.
    """
    figure, axes = plt.subplots(tight_layout=True)
    axes.set(xlabel=x_name, ylabel=y_name, title=title)

    if set_yticks or set_xticks:
        # One tick per integer across the span of ``name``.
        tick_positions = range(min(name), max(name) + 1)
        if set_yticks:
            axes.set_yticks(tick_positions)
        if set_xticks:
            # NOTE(review): x ticks are derived from ``name`` even though
            # barh plots ``value`` along x — confirm this is intended.
            axes.set_xticks(tick_positions)

    axes.barh(name, value)
    return figure