import json
import hashlib
import random
import string
import warnings

import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# User-facing Markdown text for the "MNIST Adversarial" app.
# Strings containing {name} fields look like str.format templates that the
# caller is expected to fill in — TODO confirm against the UI code.
# ---------------------------------------------------------------------------

# Page title (Markdown level-1 heading).
TITLE = "# MNIST Adversarial: Try to fool this MNIST model"

# Intro paragraph explaining dynamic adversarial data collection (DADC).
description = """This project is about dynamic adversarial data collection (DADC). The basic idea is to collect “adversarial data” - the kind of data that is difficult for a model to predict correctly. This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label. """

# Step-by-step instructions for the user; contains a {num_samples} field.
# NOTE(review): in this literal the "###" heading and the numbered steps sit
# on a single line; Markdown only renders them as a heading and a list when
# they are separated by newlines — confirm the intended line breaks.
WHAT_TO_DO = """ ### What to do: 1. Draw any number from 0-9. The model will automatically try to predict it after drawing. 2. If the model misclassifies it, Flag that example. 3. This will add your (adversarial) example to a dataset on which the model will be trained later. 4. The model will finetune on the adversarial samples after every __{num_samples}__ samples have been generated. """

# Prompt asking the user to pick the correct label and flag the example when
# the model is wrong or unsure; links to the adversarial dataset on the
# Hugging Face Hub.
MODEL_IS_WRONG = """ --- ### Did the model get it wrong or has a low confidence? Choose the correct prediction below and flag it. When you flag it, the instance is saved [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset) and the model learns from it periodically. """

# Placeholder test-metric text.
# NOTE(review): 30/1000 is 3%, not the "(30%)" shown — harmless if this is
# only a placeholder, but confirm.
DEFAULT_TEST_METRIC = " Current test metric - Avg. loss: 1000, Accuracy: 30/1000 (30%) "

# Explains the out-of-distribution (MNIST Corrupted) tracking; contains a
# {TEST_PER_SAMPLE} field for the number of test samples per digit.
DASHBOARD_EXPLANATION = "To test the effect of adversarial training on out-of-distribution data, we track the performance progress of the model on the [MNIST Corrupted test dataset](https://zenodo.org/record/3239543). We are using {TEST_PER_SAMPLE} samples per digit."

# Explanatory text for the combined (all digits) out-of-distribution accuracy.
DASHBOARD_EXPLANATION_TEST = "Test accuracy on out-of-distribution data for all numbers combined."

# Explains the adversarial-sample distribution; contains a {num_adv_samples}
# field and links to the dataset.
STATS_EXPLANATION = "Here is the distribution of the __{num_adv_samples}__ adversarial samples we've got. The dataset can be found [here](https://huggingface.co/datasets/chrisjay/mnist-adversarial-dataset)."
def get_unique_name(length=32):
    """Return a random name of ``length`` ASCII letters/digits (default 32).

    Draws from ``random.SystemRandom`` (OS entropy) instead of the shared
    module-level ``random`` generator, so a ``random.seed(...)`` call anywhere
    in the process cannot make "unique" names repeat between runs.
    """
    alphabet = string.ascii_letters + string.digits
    return ''.join(random.SystemRandom().choices(alphabet, k=length))


def read_json(file):
    """Load and return the JSON document stored in ``file`` (read as UTF-8).

    Errors from ``open`` and ``json.load`` propagate to the caller.
    """
    with open(file, 'r', encoding="utf8") as f:
        return json.load(f)


def read_json_lines(file):
    """Parse a JSON Lines file into a list holding one decoded value per line.

    Blank (whitespace-only) lines are skipped. If the file cannot be opened or
    read, or a line is not valid JSON, a warning is emitted and ``None`` is
    returned instead of raising.
    """
    try:
        with open(file, 'r', encoding="utf8") as f:
            # Iterate the file lazily; skip blank lines (e.g. a stray trailing
            # newline) instead of letting json.loads reject the whole file.
            return [json.loads(line) for line in f if line.strip()]
    except (OSError, ValueError) as err:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError.
        warnings.warn(f"{err}")
        return None


def json_dump(thing):
    """Serialize ``thing`` to a compact, canonical JSON string.

    Keys are sorted and separators carry no whitespace, so equal objects
    always produce identical text; non-ASCII characters are kept as-is.
    """
    return json.dumps(thing, ensure_ascii=False, sort_keys=True, indent=None, separators=(',', ':'))


def get_hash(thing):
    """Return a stable hex digest of ``thing``'s canonical JSON form.

    Stable across processes (unlike the built-in ``hash``) because it hashes
    ``json_dump`` output. MD5 is used as a content fingerprint here, not for
    security.
    """
    return hashlib.md5(json_dump(thing).encode('utf-8')).hexdigest()


def dump_json(thing, file):
    """Write ``thing`` as JSON to ``file`` (UTF-8), overwriting any content.

    Uses ``json.dump`` defaults (ASCII-escaped, keys unsorted), which differs
    from the canonical form produced by ``json_dump``.
    """
    # 'w' suffices: the file is only written, never read back here.
    with open(file, 'w', encoding="utf8") as f:
        json.dump(thing, f)


def plot_bar(value, name, x_name, y_name, title, set_yticks=False, set_xticks=False):
    """Draw a horizontal bar chart and return its matplotlib Figure.

    Args:
        value: bar lengths (``barh`` plots them along the x axis).
        name: bar positions (``barh`` plots them along the y axis).
        x_name: x-axis label.
        y_name: y-axis label.
        title: chart title.
        set_yticks: if true, place a y tick at every integer from
            ``min(name)`` to ``max(name)``.
        set_xticks: if true, place an x tick at every integer from
            ``min(name)`` to ``max(name)``.

    Both tick options require ``name`` to be non-empty integers, because
    ``range`` rejects floats. The figure is not closed here, so it stays
    registered with pyplot.
    """
    fig, ax = plt.subplots(tight_layout=True)
    ax.set(xlabel=x_name, ylabel=y_name, title=title)
    if set_yticks:
        ax.set_yticks(range(min(name), max(name) + 1, 1))
    if set_xticks:
        # NOTE(review): barh puts ``value`` on the x axis, yet these ticks are
        # derived from ``name`` — confirm this is intended.
        ax.set_xticks(range(min(name), max(name) + 1, 1))
    ax.barh(name, value)
    return ax.figure