Hack90 commited on
Commit
884d71b
·
verified ·
1 Parent(s): c265b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py CHANGED
@@ -973,6 +973,92 @@ with ui.navset_card_tab(id="tab"):
973
  mpl.rcParams.update(mpl.rcParamsDefault)
974
  fig = plot_loss_rates(df, '14M')
975
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
976
 
977
 
978
  # @render.image
 
973
  mpl.rcParams.update(mpl.rcParamsDefault)
974
  fig = plot_loss_rates(df, '14M')
975
  return fig
976
+ with ui.nav_panel("Model loss analysis"):
977
+ ui.page_opts(fillable=True)
978
+ ui.panel_title("Neurips stuff")
979
+
980
+ with ui.card():
981
+ param_type = ui.input_selectize(
982
+ "param_type",
983
+ "Select Param Type:",
984
+ ["14", "31", "70", "160"],
985
+ multiple=True,
986
+ selected=None
987
+ )
988
+
989
+ with ui.card():
990
+ model_type = ui.input_selectize(
991
+ "model_type",
992
+ "Select Model Type:",
993
+ ["pythia", "denseformer"],
994
+ multiple=True,
995
+ selected=None
996
+ )
997
+
998
+ with ui.card():
999
+ loss_type = ui.input_selectize(
1000
+ "loss_type",
1001
+ "Select Loss Type:",
1002
+ ["compliment", "cross_entropy", "headless"],
1003
+ multiple=True,
1004
+ selected=None
1005
+ )
1006
+
1007
+ @output
1008
+ @render.plot
1009
+ def plot_training_loss():
1010
+ if csv_file() is None:
1011
+ return None
1012
+
1013
+ df = pd.read_csv(csv_file())
1014
+
1015
+ filtered_df = df[
1016
+ (df["param_type"].isin(param_type()))
1017
+ & (df["model_type"].isin(model_type()))
1018
+ & (df["loss_type"].isin(loss_type()))
1019
+ ]
1020
+
1021
+ if filtered_df.empty:
1022
+ return None
1023
+
1024
+ # Define colors for sizes and shapes for loss types
1025
+ size_colors = {
1026
+ "14": "blue",
1027
+ "31": "green",
1028
+ "70": "orange",
1029
+ "160": "red"
1030
+ }
1031
+ loss_markers = {
1032
+ "compliment": "o",
1033
+ "cross_entropy": "^",
1034
+ "headless": "s"
1035
+ }
1036
+
1037
+ # Create a relplot using Seaborn
1038
+ g = sns.relplot(
1039
+ data=filtered_df,
1040
+ x="epoch",
1041
+ y="loss",
1042
+ hue="param_type",
1043
+ style="loss_type",
1044
+ palette=size_colors,
1045
+ markers=loss_markers,
1046
+ height=6,
1047
+ aspect=1.5
1048
+ )
1049
+
1050
+ # Customize the plot
1051
+ g.set_xlabels("Epoch")
1052
+ g.set_ylabels("Loss")
1053
+ g.fig.suptitle("Training Loss by Size and Loss Type", fontsize=16)
1054
+ g.add_legend(title="Size")
1055
+
1056
+ # Create a separate legend for loss types
1057
+ loss_legend = plt.legend(title="Loss Type", loc="upper right", labels=["Compliment", "Cross Entropy", "Headless"])
1058
+ plt.gca().add_artist(loss_legend)
1059
+
1060
+ plt.tight_layout()
1061
+ return g.fig
1062
 
1063
 
1064
  # @render.image