"""Shared text constants and small utilities for the MNIST adversarial (DADC) demo."""

import json
import hashlib
import random
import string

import matplotlib.pyplot as plt

# Page title (markdown heading).
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"

# Short project description shown to users.
description = """This project is about dynamic adversarial data collection (DADC).
The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly.
This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
"""

# Usage instructions; callers must fill the {num_samples} placeholder via str.format.
WHAT_TO_DO = """
### What to do:
1. Draw any number from 0-9.
2. Click `Submit` and see the model's prediction.
3. If the model misclassifies it, flag that example.
4. This will add your (adversarial) example to a dataset on which the model will be trained later.
5. The model will fine-tune on the adversarial samples after every __{num_samples}__ samples have been generated.
"""

# Prompt shown next to the flagging controls.
MODEL_IS_WRONG = """
---

### Did the model get it wrong or give a low-confidence prediction?
Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically.
"""

# Placeholder metric text, presumably shown before a real evaluation result exists.
DEFAULT_TEST_METRIC = " Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) "

DASHBOARD_EXPLANATION = "To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543)."

# Callers must fill the {num_adv_samples} placeholder via str.format.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name(length=32):
    """Return a random alphanumeric string of ``length`` characters (default 32).

    Characters are drawn from ASCII letters and digits using the
    non-cryptographic ``random`` module, so the result is suitable as a unique
    name but must not be used as a secret token.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.choices(alphabet, k=length))


def read_json(file):
    """Load and return the JSON document stored in ``file`` (UTF-8)."""
    with open(file, 'r', encoding="utf8") as f:
        return json.load(f)


def read_json_lines(file):
    """Parse a JSON Lines file into a list with one decoded value per line.

    Blank or whitespace-only lines are skipped, so a trailing empty line or a
    stray separator no longer raises ``json.JSONDecodeError``.
    """
    with open(file, 'r', encoding="utf8") as f:
        # Iterate the file lazily instead of materialising readlines().
        return [json.loads(line) for line in f if line.strip()]


def json_dump(thing):
    """Serialise ``thing`` to a compact, key-sorted JSON string.

    Sorted keys and fixed separators make the output deterministic for equal
    inputs, which is what ``get_hash`` relies on.
    """
    return json.dumps(thing, ensure_ascii=False, sort_keys=True, indent=None, separators=(',', ':'))


def get_hash(thing):
    """Return a stable hex MD5 digest of ``thing``'s canonical JSON form.

    Equal JSON-serialisable values (regardless of dict key order) hash equally.
    MD5 is used here as a fingerprint, not for security.
    """
    return hashlib.md5(json_dump(thing).encode('utf-8')).hexdigest()


def dump_json(thing, file):
    """Write ``thing`` as JSON to ``file`` (UTF-8), overwriting any existing content."""
    # 'w' suffices: the file is only written, never read back through this handle.
    with open(file, 'w', encoding="utf8") as f:
        json.dump(thing, f)


def plot_bar(value, name, x_name, y_name, title):
    """Draw a horizontal bar chart and return its matplotlib Figure.

    Args:
        value: bar lengths (plotted along the x axis).
        name: bar labels/positions (plotted along the y axis).
        x_name: x-axis label.
        y_name: y-axis label.
        title: chart title.
    """
    fig, ax = plt.subplots(tight_layout=True)
    ax.set(xlabel=x_name, ylabel=y_name, title=title)
    ax.barh(name, value)
    return fig