Hack90 commited on
Commit
6bbcf95
·
verified ·
1 Parent(s): 3ec9f4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -2
app.py CHANGED
@@ -992,7 +992,7 @@ with ui.navset_card_tab(id="tab"):
992
  "Select Model Type:",
993
  ["pythia", "denseformer"],
994
  multiple=True,
995
- selected=None
996
  )
997
 
998
  with ui.card():
@@ -1001,7 +1001,7 @@ with ui.navset_card_tab(id="tab"):
1001
  "Select Loss Type:",
1002
  ["compliment", "cross_entropy", "headless"],
1003
  multiple=True,
1004
- selected=None
1005
  )
1006
 
1007
  @output
@@ -1018,8 +1018,49 @@ with ui.navset_card_tab(id="tab"):
1018
  & (df["loss_type"].isin(loss_type()))
1019
  ]
1020
 
 
1021
  if filtered_df.empty:
1022
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1023
 
1024
  # Define colors for sizes and shapes for loss types
1025
  size_colors = {
 
992
  "Select Model Type:",
993
  ["pythia", "denseformer"],
994
  multiple=True,
995
+ selected='pythia'
996
  )
997
 
998
  with ui.card():
 
1001
  "Select Loss Type:",
1002
  ["compliment", "cross_entropy", "headless"],
1003
  multiple=True,
1004
+ selected='compliment'
1005
  )
1006
 
1007
  @output
 
1018
  & (df["loss_type"].isin(loss_type()))
1019
  ]
1020
 
1021
+
1022
  if filtered_df.empty:
1023
  return None
1024
+
1025
+ # Define colors for sizes and shapes for loss types
1026
+ size_colors = {
1027
+ "14": "blue",
1028
+ "31": "green",
1029
+ "70": "orange",
1030
+ "160": "red"
1031
+ }
1032
+
1033
+ loss_markers = {
1034
+ "compliment": "o",
1035
+ "cross_entropy": "^",
1036
+ "headless": "s"
1037
+ }
1038
+
1039
+ # Create the plot
1040
+ fig, ax = plt.subplots(figsize=(10, 6))
1041
+
1042
+ # Plot each combination of size and loss type
1043
+ for size in filtered_df["param_type"].unique():
1044
+ for loss_type in filtered_df["loss_type"].unique():
1045
+ data = filtered_df[(filtered_df["param_type"] == size) & (filtered_df["loss_type"] == loss_type)]
1046
+ ax.plot(data["epoch"], data["loss"], marker=loss_markers[loss_type], color=size_colors[size], label=f"{size} - {loss_type}")
1047
+
1048
+ # Customize the plot
1049
+ ax.set_xlabel("Epoch")
1050
+ ax.set_ylabel("Loss")
1051
+ ax.set_title("Training Loss by Size and Loss Type", fontsize=16)
1052
+
1053
+ # Create a legend for sizes
1054
+ size_legend = ax.legend(title="Size", loc="upper right")
1055
+ ax.add_artist(size_legend)
1056
+
1057
+ # Create a separate legend for loss types
1058
+ loss_legend_labels = ["Compliment", "Cross Entropy", "Headless"]
1059
+ loss_legend_handles = [plt.Line2D([0], [0], marker=loss_markers[loss_type], color='black', linestyle='None', markersize=8) for loss_type in loss_markers]
1060
+ loss_legend = ax.legend(loss_legend_handles, loss_legend_labels, title="Loss Type", loc="upper right")
1061
+
1062
+ plt.tight_layout()
1063
+ return fig
1064
 
1065
  # Define colors for sizes and shapes for loss types
1066
  size_colors = {