Hack90 commited on
Commit
44feaa5
1 Parent(s): 11e2945

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py CHANGED
@@ -237,3 +237,61 @@ with ui.navset_card_tab(id="tab"):
237
  )
238
  if fig:
239
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  )
238
  if fig:
239
  return fig
240
+ with ui.nav_panel("Logits View"):
241
+ ui.panel_title("Logits et all")
242
+ with ui.card():
243
+ ui.input_selectize(
244
+ "model_bigness",
245
+ "Select Model size:",
246
+ ["14", "31", "70", "160", "410"],
247
+ multiple=True,
248
+ selected=["70", "160"],
249
+ )
250
+ ui.input_selectize(
251
+ "loss_loss_loss",
252
+ "Select Loss Type:",
253
+ ["compliment", "cross_entropy", "headless", "2d_representation_GaussianPlusCE", "2d_representation_MSEPlusCE"],
254
+ multiple=True,
255
+ selected=["cross_entropy"],
256
+ )
257
+ ui.input_selectize(
258
+ "logits_select",
259
+ "Select logits:",
260
+ ["1", "2", "3", "4", "5", "6", "7", "8"],
261
+ multiple=True,
262
+ selected=["6"],
263
+ )
264
+
265
+ def plot_logits_representation(model_bigness, loss_type, logits):
266
+ num_rows = 2 # Number of rows in the subplot grid
267
+ num_cols = len(logits) # Number of columns based on the number of selected logits
268
+ fig, axs = plt.subplots(num_rows, num_cols, figsize=(20, 10))
269
+ axs = axs.flatten() # Flatten axs to handle 1D indexing
270
+
271
+ for size in model_bigness:
272
+ for loss in loss_type:
273
+ file_name = f"virus_pythia_{size}_1024_{loss}_logit_cumsums.npy"
274
+ if os.path.exists(file_name):
275
+ data = np.load(file_name, allow_pickle=True).item()
276
+ for k, logit in enumerate(logits):
277
+ logit_index = int(logit) - 1
278
+ axs[k].plot(data['lm_logits_y_cumsum'][0, :, logit_index], label=f'Generated_{loss}_{size}')
279
+ axs[k].plot(data['shift_labels_y_cumsum'][0, :, logit_index], label=f'Expected_{loss}_{size}')
280
+ axs[k].set_title(f'Logit: {logit}')
281
+ axs[k].legend()
282
+ else:
283
+ print(f"File not found: {file_name}")
284
+
285
+ for k in range(len(logits), num_cols):
286
+ fig.delaxes(axs[k]) # Remove any extra subplots if fewer logits are selected
287
+
288
+ plt.tight_layout()
289
+ return fig
290
+
291
+ @render.plot()
292
+ def plot_logits_representation_ui():
293
+ fig = plot_logits_representation(
294
+ input.model_bigness(), input.loss_loss_loss(), input.logits_select()
295
+ )
296
+ if fig:
297
+ return fig