File size: 2,652 Bytes
f240072
 
 
 
 
866cafe
f240072
866cafe
 
 
 
 
 
 
 
 
 
 
 
 
f240072
866cafe
4db7f81
866cafe
 
 
 
 
 
f240072
 
 
 
 
 
866cafe
 
 
 
f240072
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
866cafe
 
35ee063
866cafe
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72

import json
import hashlib
import random
import string
import matplotlib.pyplot as plt

# Page title text; the leading "# " is Markdown heading syntax
# (presumably rendered as Markdown by the UI -- confirm against the caller).
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"
# Introductory blurb describing dynamic adversarial data collection (DADC).
description = """This project is about dynamic adversarial data collection (DADC). 
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly. 
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""
# Step-by-step user instructions in Markdown. Contains a ``{num_samples}``
# placeholder that must be filled in (e.g. with ``str.format``) before display.
WHAT_TO_DO = """
### What to do:
1. Draw a number from 0-9.
2. Click `Submit` and see the model's prediction.
3. If the model misclassifies it, Flag that example.
4. This will add your (adversarial) example to a dataset on which the model will be trained later.
5. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated. 
"""

# Markdown prompt asking the user to pick the correct label and flag the
# example when the model's prediction is wrong or low-confidence.
MODEL_IS_WRONG = """
### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved to our dataset and the model is trained on it.
"""
# HTML-wrapped test-metric text with hard-coded numbers; presumably a
# placeholder used until real metrics are available -- confirm against the caller.
DEFAULT_TEST_METRIC = "<html> Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) </html>"

# Markdown text linking to the MNIST Corrupted test dataset used for evaluation.
DASHBOARD_EXPLANATION="To see the effect of our model on out-of-distribution data, we test it on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."

# Markdown text with a ``{num_adv_samples}`` placeholder that must be filled in
# (e.g. with ``str.format``) before display.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."

def get_unique_name():
    """Return a random 32-character name of ASCII letters and digits.

    Characters are drawn one at a time with ``random.choice``, so the result
    depends on the state of the global ``random`` generator.
    """
    alphabet = string.ascii_letters + string.digits
    picked = (random.choice(alphabet) for _ in range(32))
    return ''.join(picked)


def read_json(file):
    """Parse the UTF-8 text file at ``file`` as one JSON document and return it."""
    with open(file, mode='r', encoding="utf8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)

def read_json_lines(file):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line of ``file`` (opened as UTF-8 text) is parsed as one
    JSON value. Blank or whitespace-only lines are skipped, so a stray empty
    line does not abort the whole read with a ``json.JSONDecodeError``.

    Args:
        file: Path of the JSON Lines file to read.

    Returns:
        A list with one parsed value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(file, 'r', encoding="utf8") as f:
        # Iterate the handle directly instead of materialising every line
        # with readlines() first.
        return [json.loads(line) for line in f if line.strip()]


def json_dump(thing):
    """Serialize ``thing`` to a compact JSON string with sorted keys.

    No whitespace is emitted around separators and non-ASCII characters are
    written as-is rather than escaped, so equal inputs yield equal text.
    """
    options = {
        'ensure_ascii': False,
        'sort_keys': True,
        'indent': None,
        'separators': (',', ':'),
    }
    return json.dumps(thing, **options)

def get_hash(thing):  # stable-hashing
    """Return the MD5 hex digest of ``json_dump(thing)`` encoded as UTF-8."""
    serialized = json_dump(thing)
    digest = hashlib.md5(serialized.encode('utf-8'))
    return str(digest.hexdigest())


def dump_json(thing,file):
    """Write ``thing`` as JSON to ``file``, replacing any existing content.

    The file is opened as UTF-8 text and serialized with ``json.dump``
    defaults.
    """
    with open(file, mode='w+', encoding="utf8") as sink:
        json.dump(thing, fp=sink)


def plot_bar(value,name,x_name,y_name,title):
    """Draw a horizontal bar chart and return the figure that holds it.

    ``name`` and ``value`` are passed to ``Axes.barh`` as its first and second
    positional arguments; ``x_name``, ``y_name`` and ``title`` become the
    x-axis label, y-axis label and axes title.
    """
    figure, axes = plt.subplots(tight_layout=True)

    # Labels and title are set before the bars are drawn.
    axes.set_xlabel(x_name)
    axes.set_ylabel(y_name)
    axes.set_title(title)

    axes.barh(name, value)

    return figure