chrisjay committed on
Commit
f240072
1 Parent(s): 475a212

saving to dataset

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. app.py +118 -18
  3. utils.py +44 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__/*
2
+ data_local/*
3
+ flagged/*
app.py CHANGED
@@ -2,9 +2,12 @@ import os
2
  import torch
3
  import gradio as gr
4
  import torchvision
 
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import torch.optim as optim
 
 
8
 
9
 
10
  n_epochs = 3
@@ -13,8 +16,17 @@ batch_size_test = 1000
13
  learning_rate = 0.01
14
  momentum = 0.5
15
  log_interval = 10
16
-
17
  random_seed = 1
 
 
 
 
 
 
 
 
 
 
18
  torch.backends.cudnn.enabled = False
19
  torch.manual_seed(random_seed)
20
 
@@ -123,6 +135,13 @@ if os.path.exists(optimizer_state_dict):
123
 
124
 
125
  def image_classifier(inp):
 
 
 
 
 
 
 
126
  input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
127
  with torch.no_grad():
128
 
@@ -134,21 +153,102 @@ def image_classifier(inp):
134
  confidences.update({s:v})
135
  return confidences
136
 
137
- TITLE = "MNIST Adversarial: Try to fool the MNIST model"
138
- description = """This project is about dynamic adversarial data collection (DADC).
139
- The basic idea is to do data collection, but specifically collect “adversarial data”, the kind of data that is difficult for a model to predict correctly.
140
- This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.
141
-
142
- ### What to do:
143
- - Draw a number from 0-9.
144
- - Click `Submit` and see the model's prediciton.
145
- - If the model misclassifies it, Flag that example.
146
- - This will add your (adversarial) example to a dataset on which the model will be trained later.
147
- """
148
- gr.Interface(fn=image_classifier,
149
- inputs=gr.Image(source="canvas",shape=(28,28),invert_colors=True,image_mode="L",type="pil"),
150
- outputs=gr.outputs.Label(num_top_classes=10),
151
- allow_flagging="manual",
152
- title = TITLE,
153
- description=description).launch()
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import gradio as gr
4
  import torchvision
5
+ from utils import *
6
  import torch.nn as nn
7
  import torch.nn.functional as F
8
  import torch.optim as optim
9
+ from huggingface_hub import Repository, upload_file
10
+
11
 
12
 
13
  n_epochs = 3
 
16
  learning_rate = 0.01
17
  momentum = 0.5
18
  log_interval = 10
 
19
  random_seed = 1
20
+
21
+ REPOSITORY_DIR = "data"
22
+ LOCAL_DIR = 'data_local'
23
+ os.makedirs(LOCAL_DIR,exist_ok=True)
24
+
25
+
26
+ HF_TOKEN = os.getenv("HF_TOKEN")
27
+
28
+ HF_DATASET ="mnist-adversarial-dataset"
29
+
30
  torch.backends.cudnn.enabled = False
31
  torch.manual_seed(random_seed)
32
 
 
135
 
136
 
137
  def image_classifier(inp):
138
+ """
139
+ It takes an image as input and returns a dictionary of class labels and their corresponding
140
+ confidence scores.
141
+
142
+ :param inp: the image to be classified
143
+ :return: A dictionary of the class index and the confidence value.
144
+ """
145
  input_image = torchvision.transforms.ToTensor()(inp).unsqueeze(0)
146
  with torch.no_grad():
147
 
 
153
  confidences.update({s:v})
154
  return confidences
155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
def flag(input_image,correct_result):
    """
    Save a flagged drawing and its correct label locally, then upload both to the Hub dataset.

    :param input_image: the drawn digit; must expose a ``save(path)`` method (the canvas
        input is configured with ``type="pil"``).
    :param correct_result: the digit the user selected as the correct prediction.
    :return: An HTML snippet describing the outcome.
    :raises Exception: if the image cannot be written to disk (original error is chained).
    """
    # Without both the drawing and its label there is nothing meaningful to store.
    if input_image is None or correct_result is None:
        return '<div> Please draw a digit and choose the correct prediction before flagging. </div>'

    # Every flagged example gets its own uniquely named directory.
    metadata_name = get_unique_name()
    save_file_dir = os.path.join(LOCAL_DIR, metadata_name)
    os.makedirs(save_file_dir, exist_ok=True)

    # Write the image to file.
    image_output_filename = os.path.join(save_file_dir, 'image.png')
    try:
        input_image.save(image_output_filename)
    except Exception as err:
        # Keep the original exception type but preserve the root cause for debugging.
        raise Exception("Had issues saving PIL image to file") from err

    # Write the metadata next to the image.
    json_file_path = os.path.join(save_file_dir, 'metadata.jsonl')
    metadata = {
        'id': metadata_name,
        'file_name': 'image.png',
        'correct_number': correct_result,
    }
    dump_json(metadata, json_file_path)

    # Upload the image and then its metadata with the hub's upload_file.
    # Repo paths always use forward slashes, independent of the local OS separator.
    uploads = (
        (image_output_filename, 'image.png'),
        (json_file_path, 'metadata.jsonl'),
    )
    for local_path, file_name in uploads:
        upload_file(
            path_or_fileobj=local_path,
            path_in_repo=f'{REPOSITORY_DIR}/{metadata_name}/{file_name}',
            repo_id=f'chrisjay/{HF_DATASET}',
            repo_type='dataset',
            token=HF_TOKEN,
        )

    return '<div> Successfully saved to flagged dataset. </div>'
202
+
203
+
204
+
205
def main():
    """
    Build the Gradio Blocks interface for the adversarial MNIST demo and launch it.

    The page shows a drawing canvas with the model's predictions, plus a dropdown and
    a "Flag" button that pass the drawing and the selected label to ``flag``.
    """
    title_md = "# MNIST Adversarial: Try to fool this MNIST model"
    description_md = """This project is about dynamic adversarial data collection (DADC).
    The basic idea is to do data collection by collecting “adversarial data”, the kind of data that is difficult for a model to predict correctly.
    This kind of data is presumably the most valuable for a model, so this can be helpful in low-resource settings where data is hard to collect and label.

    ### What to do:
    - Draw a number from 0-9.
    - Click `Submit` and see the model's prediciton.
    - If the model misclassifies it, Flag that example.
    - This will add your (adversarial) example to a dataset on which the model will be trained later.
    """

    model_is_wrong_md = """
    > Did the model get it wrong? Choose the correct prediction below and flag it.

    When you flag it, the instance is saved to our dataset and the model is trained on it.
    """

    digit_choices = list(range(10))
    demo = gr.Blocks()

    with demo:
        gr.Markdown(title_md)

        with gr.Tabs():
            gr.Markdown(description_md)
            with gr.TabItem('MNIST'):
                # Canvas and prediction label sit side by side.
                with gr.Row():
                    image_input = gr.inputs.Image(
                        source="canvas",
                        shape=(28, 28),
                        invert_colors=True,
                        image_mode="L",
                        type="pil",
                    )
                    label_output = gr.outputs.Label(num_top_classes=10)

                submit = gr.Button("Submit")
                gr.Markdown(model_is_wrong_md)
                number_dropdown = gr.Dropdown(
                    choices=digit_choices,
                    type='value',
                    default=None,
                    label="What was the correct prediction?",
                )

                flag_btn = gr.Button("Flag")
                output_result = gr.outputs.HTML()

                # Wire the buttons to the classifier and to the dataset-saving callback.
                submit.click(image_classifier, inputs=[image_input], outputs=[label_output])
                flag_btn.click(flag, inputs=[image_input, number_dropdown], outputs=[output_result])

    demo.launch()
249
+
250
+
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()
utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import hashlib
4
+ import random
5
+ import string
6
+
7
+
8
+
9
def get_unique_name():
    """
    Return a random 32-character name made of ASCII letters and digits.

    Characters are drawn with ``secrets`` rather than the global ``random`` generator,
    so seeding ``random`` elsewhere cannot make two calls produce the same name.

    :return: A 32-character alphanumeric string.
    """
    import secrets  # local import: keeps this block self-contained

    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(32))
12
+
13
+
14
def read_json_lines(file):
    """
    Read a JSON Lines file and return its records in file order.

    Blank or whitespace-only lines are skipped instead of raising a decode error.

    :param file: path of the UTF-8 encoded JSON Lines file.
    :return: A list with one decoded JSON value per non-blank line.
    """
    with open(file,'r',encoding="utf8") as f:
        # Iterate the file lazily rather than loading every line first.
        return [json.loads(line) for line in f if line.strip()]
21
+
22
+
23
def json_dump(thing):
    """
    Serialize ``thing`` to a compact JSON string with sorted keys.

    :param thing: any JSON-serializable value.
    :return: The JSON text, without extra whitespace and with non-ASCII characters kept as-is.
    """
    compact_separators = (',', ':')
    serialized = json.dumps(
        thing,
        ensure_ascii=False,
        sort_keys=True,
        indent=None,
        separators=compact_separators,
    )
    return serialized
29
+
30
def get_hash(thing): # stable-hashing
    """
    Return the MD5 hex digest of the ``json_dump`` serialization of ``thing``.

    :param thing: any value accepted by ``json_dump``.
    :return: The hexadecimal digest as a string.
    """
    encoded = json_dump(thing).encode('utf-8')
    digest = hashlib.md5(encoded)
    return str(digest.hexdigest())
32
+
33
+
34
def dump_json(thing,file):
    """
    Write ``thing`` as JSON to ``file``, replacing any existing content.

    :param thing: any JSON-serializable value.
    :param file: path of the UTF-8 encoded file to write.
    """
    with open(file, mode='w+', encoding="utf8") as handle:
        json.dump(thing, fp=handle)
37
+
38
# NOTE(review): this repeats the read_json_lines definition found earlier in this
# module; being the later one, it is the definition in effect after import.
# Consider removing one of the two copies.
def read_json_lines(file):
    """
    Read a JSON Lines file and return its records in file order.

    Blank or whitespace-only lines are skipped instead of raising a decode error.

    :param file: path of the UTF-8 encoded JSON Lines file.
    :return: A list with one decoded JSON value per non-blank line.
    """
    with open(file,'r',encoding="utf8") as f:
        # Iterate the file lazily rather than loading every line first.
        return [json.loads(line) for line in f if line.strip()]