Michelle Lam committed
Commit 4adf2d3
Parent: b04690b

Restructures cached data to be organized by user, then model (to simplify retrieval). Refactors code throughout to support this change.
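In terms of on-disk paths (as they appear in the diff below), the move is from per-artifact top-level directories to one tree per user, then model; the per-file index scheme also shifts from 1-based to 0-based:

    data/labels/{model}/{i}.pkl               ->  data/output/{user}/{model}/labels/{i}.pkl
    data/perf/{model}/{i}.pkl                 ->  data/output/{user}/{model}/perf/{i}.pkl
    data/preds_dfs/{model}.pkl                ->  data/output/{user}/{model}/preds_df.pkl
    data/charts/{user}_{model}.pkl            ->  data/output/{user}/{model}/chart__overall_vis.pkl
    data/user_reports/{user}_{scaffold}.pkl   ->  data/output/{user}/{model}/reports.pkl

The data/users_to_models.pkl registry is dropped; a user's models are now discovered by listing data/output/{user}/.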

audit_utils.py CHANGED
@@ -37,13 +37,6 @@ alt.renderers.enable('altair_saver', fmts=['vega-lite', 'png'])
 
 # Data-loading
 module_dir = "./"
-perf_dir = f"data/perf/"
-
-# # TEMP reset
-# with open(f"./data/users_to_models.pkl", "wb") as f:
-#     users_to_models = {}
-#     pickle.dump(users_to_models, f)
-
 with open(os.path.join(module_dir, "data/input/ids_to_comments.pkl"), "rb") as f:
     ids_to_comments = pickle.load(f)
 with open(os.path.join(module_dir, "data/input/comments_to_ids.pkl"), "rb") as f:
@@ -56,9 +49,6 @@ model_eval_df = pd.read_pickle(os.path.join(module_dir, "data/input/split_data/m
 ratings_df_full = pd.read_pickle(os.path.join(module_dir, "data/input/ratings_df_full.pkl"))
 worker_info_df = pd.read_pickle("./data/input/worker_info_df.pkl")
 
-with open(f"./data/users_to_models.pkl", "rb") as f:
-    users_to_models = pickle.load(f)
-
 topic_ids = system_preds_df.topic_id
 topics = system_preds_df.topic
 topic_ids_to_topics = {topic_ids[i]: topics[i] for i in range(len(topic_ids))}
@@ -71,11 +61,17 @@ def get_toxic_threshold():
 
 def get_user_model_names(user):
     # Fetch the user's models
-    if user not in users_to_models:
-        users_to_models[user] = []
-    user_models = users_to_models[user]
-    user_models.sort()
-    return user_models
+    output_dir = f"./data/output"
+    users = [name for name in os.listdir(output_dir) if os.path.isdir(os.path.join(output_dir, name))]
+    if user not in users:
+        # User does not exist
+        return []
+    else:
+        # Fetch trained model names for the user
+        user_dir = f"./data/output/{user}"
+        user_models = [name for name in os.listdir(user_dir) if os.path.isdir(os.path.join(user_dir, name))]
+        user_models.sort()
+        return user_models
 
 def get_unique_topics():
     return unique_topics
@@ -122,6 +118,64 @@ internal_to_readable = {v: k for k, v in readable_to_internal.items()}
 def get_system_preds_df():
     return system_preds_df
 
+########################################
+# Data storage helper functions
+# Set up all directories for new user
+def setup_user_dirs(cur_user):
+    user_dir = f"./data/output/{cur_user}"
+    if not os.path.isdir(user_dir):
+        os.mkdir(user_dir)
+
+def setup_model_dirs(cur_user, cur_model):
+    model_dir = f"./data/output/{cur_user}/{cur_model}"
+    if not os.path.isdir(model_dir):
+        os.mkdir(model_dir)  # Set up model dir
+        # Set up subdirs
+        os.mkdir(os.path.join(model_dir, "labels"))
+        os.mkdir(os.path.join(model_dir, "perf"))
+
+def setup_user_model_dirs(cur_user, cur_model):
+    setup_user_dirs(cur_user)
+    setup_model_dirs(cur_user, cur_model)
+
+# Charts
+def get_chart_file(cur_user, cur_model):
+    chart_dir = f"./data/output/{cur_user}/{cur_model}"
+    return os.path.join(chart_dir, f"chart__overall_vis.pkl")
+
+# Labels
+def get_label_dir(cur_user, cur_model):
+    return f"./data/output/{cur_user}/{cur_model}/labels"
+
+def get_n_label_files(cur_user, cur_model):
+    label_dir = get_label_dir(cur_user, cur_model)
+    return len([name for name in os.listdir(label_dir) if os.path.isfile(os.path.join(label_dir, name))])
+
+def get_label_file(cur_user, cur_model, label_i=None):
+    if label_i is None:
+        # Get index to add on to end of list
+        label_i = get_n_label_files(cur_user, cur_model)
+    label_dir = get_label_dir(cur_user, cur_model)
+    return os.path.join(label_dir, f"{label_i}.pkl")
+
+# Performance
+def get_perf_dir(cur_user, cur_model):
+    return f"./data/output/{cur_user}/{cur_model}/perf"
+
+def get_n_perf_files(cur_user, cur_model):
+    perf_dir = get_perf_dir(cur_user, cur_model)
+    return len([name for name in os.listdir(perf_dir) if os.path.isfile(os.path.join(perf_dir, name))])
+
+def get_perf_file(cur_user, cur_model, perf_i=None):
+    if perf_i is None:
+        # Get index to add on to end of list
+        perf_i = get_n_perf_files(cur_user, cur_model)
+    perf_dir = get_perf_dir(cur_user, cur_model)
+    return os.path.join(perf_dir, f"{perf_i}.pkl")
+
+# Predictions dataframe
+def get_preds_file(cur_user, cur_model):
+    preds_dir = f"./data/output/{cur_user}/{cur_model}"
+    return os.path.join(preds_dir, f"preds_df.pkl")
+
+# Reports
+def get_reports_file(cur_user, cur_model):
+    return f"./data/output/{cur_user}/{cur_model}/reports.pkl"
+
 ########################################
 # General utils
 def get_metric_ind(metric):
@@ -182,23 +236,24 @@ def plot_metric_histogram(metric, user_metric, other_metric_vals, n_bins=10):
     return (bar + rule).interactive()
 
 # Generates the summary plot across all topics for the user
-def show_overall_perf(variant, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median"):
+def show_overall_perf(cur_model, error_type, cur_user, threshold=TOXIC_THRESHOLD, topic_vis_method="median"):
     # Your perf (calculate using model and testset)
-    with open(os.path.join(module_dir, f"data/preds_dfs/{variant}.pkl"), "rb") as f:
+    preds_file = get_preds_file(cur_user, cur_model)
+    with open(preds_file, "rb") as f:
         preds_df = pickle.load(f)
 
-    # Read from file
-    chart_dir = "./data/charts"
-    chart_file = os.path.join(chart_dir, f"{cur_user}_{variant}.pkl")
+    chart_file = get_chart_file(cur_user, cur_model)
     if os.path.isfile(chart_file):
+        # Read from file if it exists
         with open(chart_file, "r") as f:
             topic_overview_plot_json = json.load(f)
     else:
+        # Otherwise, generate chart and save to file
        if topic_vis_method == "median": # Default
            preds_df_grp = preds_df.groupby(["topic", "user_id"]).median()
        elif topic_vis_method == "mean":
            preds_df_grp = preds_df.groupby(["topic", "user_id"]).mean()
-        topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=variant)
+        topic_overview_plot_json = plot_overall_vis(preds_df=preds_df_grp, n_topics=200, threshold=threshold, error_type=error_type, cur_user=cur_user, cur_model=cur_model)
 
     return {
         "topic_overview_plot_json": json.loads(topic_overview_plot_json),
@@ -260,22 +315,23 @@ def get_grp_model_labels(n_label_per_bin, score_bins, grp_ids):
 
 ########################################
 # GET_PERSONALIZED_MODEL utils
-def fetch_existing_data(model_name, last_label_i):
+def fetch_existing_data(user, model_name):
     # Check if we have cached model performance
-    perf_dir = f"./data/perf/{model_name}"
-    label_dir = f"./data/labels/{model_name}"
-    if os.path.isdir(os.path.join(module_dir, perf_dir)):
+    n_perf_files = get_n_perf_files(user, model_name)
+    if n_perf_files > 0:
         # Fetch cached results
-        last_i = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
-        with open(os.path.join(module_dir, perf_dir, f"{last_i}.pkl"), "rb") as f:
+        perf_file = get_perf_file(user, model_name, n_perf_files - 1)  # Get last performance file
+        with open(perf_file, "rb") as f:
             mae, mse, rmse, avg_diff = pickle.load(f)
     else:
         raise Exception(f"Model {model_name} does not exist")
 
     # Fetch previous user-provided labels
     ratings_prev = None
-    if last_label_i > 0:
-        with open(os.path.join(module_dir, label_dir, f"{last_i}.pkl"), "rb") as f:
+    n_label_files = get_n_label_files(user, model_name)
+    if n_label_files > 0:
+        label_file = get_label_file(user, model_name, n_label_files - 1)  # Get last label file
+        with open(label_file, "rb") as f:
            ratings_prev = pickle.load(f)
     return mae, mse, rmse, avg_diff, ratings_prev
 
@@ -283,15 +339,12 @@ def fetch_existing_data(model_name, last_label_i):
 # Trains an updated model with the specified name, user, and ratings
 # Saves ratings, performance metrics, and pre-computed predictions to files
 # - model_name: name of the model to train
-# - last_label_i: index of the last label file (0 if none exists)
 # - ratings: dictionary of comments to ratings
 # - user: user name
 # - top_n: number of comments to train on (used when a set was held out for original user study)
 # - topic: topic to train on (used when tuning for a specific topic)
-def train_updated_model(model_name, last_label_i, ratings, user, top_n=None, topic=None, debug=False):
+def train_updated_model(model_name, ratings, user, top_n=None, topic=None, debug=False):
     # Check if there is previously-labeled data; if so, combine it with this data
-    perf_dir = f"./data/perf/{model_name}"
-    label_dir = f"./data/labels/{model_name}"
     labeled_df = format_labeled_data(ratings) # Treat ratings as full batch of all ratings
     ratings_prev = None
 
@@ -303,9 +356,11 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=None, top
         labeled_df = labeled_df.head(top_n)
     else:
         # For topic tuning, need to fetch old labels
-        if (last_label_i > 0):
+        n_label_files = get_n_label_files(user, model_name)
+        if n_label_files > 0:
             # Concatenate previous set of labels with this new batch of labels
-            with open(os.path.join(module_dir, label_dir, f"{last_label_i}.pkl"), "rb") as f:
+            label_file = get_label_file(user, model_name, n_label_files - 1)  # Get last label file
+            with open(label_file, "rb") as f:
                 ratings_prev = pickle.load(f)
             labeled_df_prev = format_labeled_data(ratings_prev)
             labeled_df_prev = labeled_df_prev[labeled_df_prev["rating"] != -1]
@@ -314,7 +369,8 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=None, top
     if debug:
         print("len ratings for training:", len(labeled_df))
     # Save this batch of labels
-    with open(os.path.join(module_dir, label_dir, f"{last_label_i + 1}.pkl"), "wb") as f:
+    label_file = get_label_file(user, model_name)
+    with open(label_file, "wb") as f:
         pickle.dump(ratings, f)
 
     # Train model
@@ -323,25 +379,16 @@ def train_updated_model(model_name, last_label_i, ratings, user, top_n=None, top
     # Compute performance metrics
     mae, mse, rmse, avg_diff = users_perf(cur_model)
     # Save performance metrics
-    if not os.path.isdir(os.path.join(module_dir, perf_dir)):
-        os.mkdir(os.path.join(module_dir, perf_dir))
-    last_perf_i = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
-    with open(os.path.join(module_dir, perf_dir, f"{last_perf_i + 1}.pkl"), "wb") as f:
+    perf_file = get_perf_file(user, model_name)
+    with open(perf_file, "wb") as f:
         pickle.dump((mae, mse, rmse, avg_diff), f)
 
     # Pre-compute predictions for full dataset
     cur_preds_df = get_preds_df(cur_model, ["A"], sys_eval_df=ratings_df_full)
     # Save pre-computed predictions
-    with open(os.path.join(module_dir, f"./data/preds_dfs/{model_name}.pkl"), "wb") as f:
+    preds_file = get_preds_file(user, model_name)
+    with open(preds_file, "wb") as f:
         pickle.dump(cur_preds_df, f)
-
-    # Handle user
-    if user not in users_to_models:
-        users_to_models[user] = []  # New user
-    if model_name not in users_to_models[user]:
-        users_to_models[user].append(model_name)  # New model
-    with open(f"./data/users_to_models.pkl", "wb") as f:
-        pickle.dump(users_to_models, f)
 
     ratings_prev = ratings
     return mae, mse, rmse, avg_diff, ratings_prev
@@ -494,13 +541,12 @@ def train_model(train_df, model_eval_df, model_type="SVD", sim_type=None, user_b
 
     return algo, perf
 
-def plot_train_perf_results(model_name, mae):
-    perf_dir = f"./data/perf/{model_name}"
-    n_perf_files = len([name for name in os.listdir(os.path.join(module_dir, perf_dir)) if os.path.isfile(os.path.join(module_dir, perf_dir, name))])
-
+def plot_train_perf_results(user, model_name, mae):
+    n_perf_files = get_n_perf_files(user, model_name)
     all_rows = []
-    for i in range(1, n_perf_files + 1):
-        with open(os.path.join(module_dir, perf_dir, f"{i}.pkl"), "rb") as f:
+    for i in range(n_perf_files):
+        perf_file = get_perf_file(user, model_name, i)
+        with open(perf_file, "rb") as f:
            mae, mse, rmse, avg_diff = pickle.load(f)
        all_rows.append([i, mae, "Your MAE"])
 
@@ -779,9 +825,8 @@ def plot_overall_vis(preds_df, error_type, cur_user, cur_model, n_topics=None, b
 
     plot = (bkgd + annotation + chart + rule).properties(height=(plot_dim_height), width=plot_dim_width).resolve_scale(color='independent').to_json()
 
     # Save to file
-    chart_dir = "./data/charts"
-    chart_file = os.path.join(chart_dir, f"{cur_user}_{cur_model}.pkl")
+    chart_file = get_chart_file(cur_user, cur_model)
     with open(chart_file, "w") as f:
         json.dump(plot, f)
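For orientation, a minimal sketch (not part of the commit) of how the storage helpers above compose in a train-then-read cycle; the user and model names are hypothetical, and ./data/output is assumed to exist:

import pickle
import audit_utils as utils  # module shown above; note it loads its data/input files on import

user, model = "alice", "model_1"          # hypothetical names
utils.setup_user_model_dirs(user, model)  # creates ./data/output/alice/model_1/{labels,perf}

# Writes append: with no index argument, get_label_file()/get_perf_file()
# return the next free slot (the current file count): 0.pkl, 1.pkl, ...
with open(utils.get_label_file(user, model), "wb") as f:
    pickle.dump({"some comment text": 3}, f)

# Reads target the last slot explicitly (count - 1), as in fetch_existing_data()
n_labels = utils.get_n_label_files(user, model)
if n_labels > 0:
    with open(utils.get_label_file(user, model, n_labels - 1), "rb") as f:
        ratings_prev = pickle.load(f)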
 
indie_label_svelte/src/Auditing.svelte CHANGED
@@ -109,8 +109,7 @@
         if (!personalized_models.includes(personalized_model)) {
             personalized_models.push(personalized_model);
         }
-        handleAuditButton();
-        handleClusterButton(); // re-render cluster results
+        getAuditResults();
     });
 
     // Save current error type
@@ -120,16 +119,13 @@
         handleClusterButton();
     }
 
-    // Handle topic-specific training
-    // let topic_training = null;
-
     async function updateTopicChosen() {
         if (topic != null) {
             topic_chosen.update((value) => topic);
         }
     }
 
-    function getAuditSettings() {
+    function getAuditResults() {
         let req_params = {
             user: cur_user,
             scaffold_method: scaffold_method,
@@ -152,12 +148,12 @@
             clusters_for_tuning = r["clusters_for_tuning"];
             topic = clusters[0]["options"][0]["text"];
             topic_chosen.update((value) => topic);
-            handleAuditButton(); // TEMP
-            handleClusterButton(); // TEMP
+            handleAuditButton();
+            handleClusterButton();
         });
     }
     onMount(async () => {
-        getAuditSettings();
+        getAuditResults();
     });
 
     function handleAuditButton() {
@@ -193,6 +189,7 @@
         let req_params = {
             cluster: topic,
             topic_df_ids: [],
+            cur_user: cur_user,
             pers_model: personalized_model,
             example_sort: "descending", // TEMP
             comparison_group: "status_quo", // TEMP
@@ -422,7 +419,7 @@
     <p>Next, you can optionally search for more comments to serve as evidence through manual keyword search (for individual words or phrases).</p>
     <div class="section_indent">
         {#key error_type}
-            <KeywordSearch clusters={clusters} personalized_model={personalized_model} bind:evidence={evidence} use_model={use_model} on:change/>
+            <KeywordSearch clusters={clusters} personalized_model={personalized_model} cur_user={cur_user} bind:evidence={evidence} use_model={use_model} on:change/>
         {/key}
     </div>
 </div>
indie_label_svelte/src/HypothesisPanel.svelte CHANGED
@@ -137,6 +137,7 @@
             cur_user: cur_user,
             reports: JSON.stringify(all_reports),
             scaffold_method: scaffold_method,
+            model: model,
         };
         let params = new URLSearchParams(req_params).toString();
         const response = await fetch("./save_reports?" + params);
indie_label_svelte/src/KeywordSearch.svelte CHANGED
@@ -4,12 +4,11 @@
 
     import Button, { Label } from "@smui/button";
     import Textfield from "@smui/textfield";
     import LinearProgress from "@smui/linear-progress";
-    import Chip, { Set, Text } from '@smui/chips';
-
 
     export let clusters;
     export let personalized_model;
+    export let cur_user;
     export let evidence;
     export let width_pct = 80;
     export let use_model = true;
@@ -29,6 +28,7 @@
         let req_params = {
             cluster: cur_iter_cluster,
             topic_df_ids: topic_df_ids,
+            cur_user: cur_user,
             pers_model: personalized_model,
             example_sort: "descending", // TEMP
             comparison_group: "status_quo", // TEMP
@@ -41,9 +41,6 @@
         const response = await fetch("./get_cluster_results?" + params);
         const text = await response.text();
         const data = JSON.parse(text);
-        // if (data["cluster_comments"] == null) {
-        //     return false
-        // }
         topic_df_ids = data["topic_df_ids"];
         return data;
     }
server.py CHANGED
@@ -101,7 +101,7 @@ def get_audit():
         overall_perf = None
     else:
         overall_perf = utils.show_overall_perf(
-            variant=pers_model,
+            cur_model=pers_model,
             error_type=error_type,
             cur_user=cur_user,
             topic_vis_method=topic_vis_method,
@@ -117,6 +117,7 @@ def get_audit():
 @app.route("/get_cluster_results")
 def get_cluster_results(debug=DEBUG):
     pers_model = request.args.get("pers_model")
+    cur_user = request.args.get("cur_user")
     cluster = request.args.get("cluster")
     topic_df_ids = request.args.getlist("topic_df_ids")
     topic_df_ids = [int(val) for val in topic_df_ids[0].split(",") if val != ""]
@@ -130,7 +131,8 @@ def get_cluster_results(debug=DEBUG):
 
     # Prepare cluster df (topic_df)
     topic_df = None
-    with open(f"data/preds_dfs/{pers_model}.pkl", "rb") as f:
+    preds_file = utils.get_preds_file(cur_user, pers_model)
+    with open(preds_file, "rb") as f:
         topic_df = pickle.load(f)
     if search_type == "cluster":
         # Display examples with comment, your pred, and other users' pred
@@ -226,19 +228,12 @@ def get_group_model():
         grp_ids=grp_ids,
     )
 
-    # print("ratings_grp", ratings_grp)
-
     # Modify model name
     model_name = f"{model_name}_group_gender{sel_gender}_relig{sel_relig}_pol{sel_pol}_race{sel_race_orig}_lgbtq_{sel_lgbtq}"
-
-    label_dir = f"./data/labels/{model_name}"
-    # Create directory for labels if it doesn't yet exist
-    if not os.path.isdir(label_dir):
-        os.mkdir(label_dir)
-    last_label_i = len([name for name in os.listdir(label_dir) if (os.path.isfile(os.path.join(label_dir, name)) and name.endswith('.pkl'))])
+    utils.setup_user_model_dirs(user, model_name)
 
     # Train group model
-    mae, mse, rmse, avg_diff, ratings_prev = utils.train_updated_model(model_name, last_label_i, ratings_grp, user)
+    mae, mse, rmse, avg_diff, ratings_prev = utils.train_updated_model(model_name, ratings_grp, user)
 
     duration = time.time() - start
     print("Time to train/cache:", duration)
@@ -317,35 +312,33 @@ def get_comments_to_label_topic():
 ########################################
 # ROUTE: /GET_PERSONALIZED_MODEL
 @app.route("/get_personalized_model")
-def get_personalized_model():
+def get_personalized_model(debug=DEBUG):
     model_name = request.args.get("model_name")
     ratings_json = request.args.get("ratings")
     mode = request.args.get("mode")
     user = request.args.get("user")
     ratings = json.loads(ratings_json)
-    print(ratings)
+    if debug:
+        print(ratings)
     start = time.time()
 
-    label_dir = f"./data/labels/{model_name}"
-    # Create directory for labels if it doesn't yet exist
-    if not os.path.isdir(label_dir):
-        os.mkdir(label_dir)
-    last_label_i = len([name for name in os.listdir(label_dir) if (os.path.isfile(os.path.join(label_dir, name)) and name.endswith('.pkl'))])
+    utils.setup_user_model_dirs(user, model_name)
 
     # Handle existing or new model cases
     if mode == "view":
         # Fetch prior model performance
-        mae, mse, rmse, avg_diff, ratings_prev = utils.fetch_existing_data(model_name, last_label_i)
+        mae, mse, rmse, avg_diff, ratings_prev = utils.fetch_existing_data(user, model_name)
     elif mode == "train":
         # Train model and cache predictions using new labels
         print("get_personalized_model train")
-        mae, mse, rmse, avg_diff, ratings_prev = utils.train_updated_model(model_name, last_label_i, ratings, user)
+        mae, mse, rmse, avg_diff, ratings_prev = utils.train_updated_model(model_name, ratings, user)
 
-    duration = time.time() - start
-    print("Time to train/cache:", duration)
+    if debug:
+        duration = time.time() - start
+        print("Time to train/cache:", duration)
 
-    perf_plot, mae_status = utils.plot_train_perf_results(model_name, mae)
+    perf_plot, mae_status = utils.plot_train_perf_results(user, model_name, mae)
     perf_plot_json = perf_plot.to_json()
 
     def round_metric(x):
@@ -358,7 +351,6 @@ def get_personalized_model():
         "mse": round_metric(mse),
         "rmse": round_metric(rmse),
         "avg_diff": round_metric(avg_diff),
-        "duration": duration,
         "ratings_prev": ratings_prev,
         "perf_plot_json": json.loads(perf_plot_json),
     }
@@ -379,17 +371,12 @@ def get_personalized_model_topic():
 
     # Modify model name
     model_name = f"{model_name}_{topic}"
-
-    label_dir = f"./data/labels/{model_name}"
-    # Create directory for labels if it doesn't yet exist
-    if not os.path.isdir(label_dir):
-        os.mkdir(label_dir)
-    last_label_i = len([name for name in os.listdir(label_dir) if (os.path.isfile(os.path.join(label_dir, name)) and name.endswith('.pkl'))])
+    utils.setup_user_model_dirs(user, model_name)
 
     # Handle existing or new model cases
     # Train model and cache predictions using new labels
     print("get_personalized_model_topic train")
-    mae, mse, rmse, avg_diff, ratings_prev = utils.train_updated_model(model_name, last_label_i, ratings, user, topic=topic)
+    mae, mse, rmse, avg_diff, ratings_prev = utils.train_updated_model(model_name, ratings, user, topic=topic)
 
     duration = time.time() - start
     print("Time to train/cache:", duration)
@@ -416,15 +403,13 @@ def get_reports():
     if topic_vis_method == "null":
         topic_vis_method = "fp_fn"
 
-    # Load reports for current user from stored files
-    report_dir = f"./data/user_reports"
-    user_file = os.path.join(report_dir, f"{cur_user}_{scaffold_method}.pkl")
-
-    if not os.path.isfile(user_file):
+    # Load reports for current user from stored file
+    reports_file = utils.get_reports_file(cur_user, model)
+    if not os.path.isfile(reports_file):
         if scaffold_method == "fixed":
             reports = get_fixed_scaffold()
         elif (scaffold_method == "personal" or scaffold_method == "personal_group" or scaffold_method == "personal_test"):
-            reports = get_personal_scaffold(model, topic_vis_method)
+            reports = get_personal_scaffold(cur_user, model, topic_vis_method)
         elif scaffold_method == "prompts":
             reports = get_prompts_scaffold()
         elif scaffold_method == "tutorial":
@@ -442,7 +427,7 @@ def get_reports():
         ]
     else:
         # Load from pickle file
-        with open(user_file, "rb") as f:
+        with open(reports_file, "rb") as f:
             reports = pickle.load(f)
 
     results = {
@@ -544,11 +529,12 @@ def get_topic_errors(df, topic_vis_method, threshold=2):
 
     return topic_errors
 
-def get_personal_scaffold(model, topic_vis_method, n_topics=200, n=5):
+def get_personal_scaffold(cur_user, model, topic_vis_method, n_topics=200, n=5):
     threshold = utils.get_toxic_threshold()
 
     # Get topics with greatest amount of error
-    with open(f"./data/preds_dfs/{model}.pkl", "rb") as f:
+    preds_file = utils.get_preds_file(cur_user, model)
+    with open(preds_file, "rb") as f:
         preds_df = pickle.load(f)
     system_preds_df = utils.get_system_preds_df()
     preds_df_mod = preds_df.merge(system_preds_df, on="item_id", how="left", suffixes=('', '_sys'))
@@ -653,11 +639,11 @@ def save_reports():
     reports_json = request.args.get("reports")
     reports = json.loads(reports_json)
     scaffold_method = request.args.get("scaffold_method")
+    model = request.args.get("model")
 
-    # Save reports for current user to stored files
-    report_dir = f"./data/user_reports"
-    # Save to pickle file
-    with open(os.path.join(report_dir, f"{cur_user}_{scaffold_method}.pkl"), "wb") as f:
+    # Save reports for current user to file
+    reports_file = utils.get_reports_file(cur_user, model)
+    with open(reports_file, "wb") as f:
         pickle.dump(reports, f)
 
     results = {