Hack90 commited on
Commit
bab98a8
1 Parent(s): e47e7fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -0
app.py CHANGED
@@ -1036,6 +1036,64 @@ with ui.navset_card_tab(id="tab"):
1036
  mpl.rcParams.update(mpl.rcParamsDefault)
1037
  fig = plot_loss_rates_model(df, input.param_type(),input.loss_type(),input.model_type())
1038
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1039
  # @output
1040
  # @render.plot
1041
  # def plot_training_loss():
 
1036
  mpl.rcParams.update(mpl.rcParamsDefault)
1037
  fig = plot_loss_rates_model(df, input.param_type(),input.loss_type(),input.model_type())
1038
  return fig
1039
+ with ui.nav_panel("Scaling Laws"):
1040
+ ui.page_opts(fillable=True)
1041
+ ui.panel_title("Params & Losses")
1042
+
1043
+ with ui.card():
1044
+
1045
+ ui.input_selectize(
1046
+ "model_type",
1047
+ "Select Model Type:",
1048
+ ["pythia", "denseformer", "evo"],
1049
+ multiple=True,
1050
+ selected=['evo','denseformer']
1051
+ )
1052
+ ui.input_selectize(
1053
+ "loss_type",
1054
+ "Select Loss Type:",
1055
+ ["compliment", "cross_entropy", "headless", "2d", "2d_representation_MSEPlusCE"],
1056
+ multiple=False,
1057
+ selected="cross_entropy"
1058
+ )
1059
+ def plot_loss_rates_model_scale(df, loss_type, model_types):
1060
+ df = df[df['loss_type'] == loss_type]
1061
+ # interplot each column to be same number of points
1062
+ params = []
1063
+ loss_rates = []
1064
+ labels = []
1065
+ for model_type in model_types:
1066
+ df_new = df[df['model_type']]
1067
+ losses = []
1068
+ params_model = []
1069
+ for paramy in df_new['num_params'].unique():
1070
+ loss = df_new[df_new['num_params']==paramy].min()
1071
+ par = int(paramy)
1072
+ losses.append(loss)
1073
+ params_model.append(par)
1074
+ loss_rates.append(losses)
1075
+ params.append(params_model)
1076
+ labels.append(model_type)
1077
+
1078
+ fig, ax = plt.subplots()
1079
+
1080
+ for i, loss_rate in enumerate(loss_rates):
1081
+ ax.plot(x=params[i], y=loss_rate, label=labels[i])
1082
+
1083
+ ax.legend()
1084
+ ax.set_xlabel('Params')
1085
+ ax.set_ylabel('Loss')
1086
+
1087
+ return fig
1088
+
1089
+ import matplotlib as mpl
1090
+ @render.plot()
1091
+ def plot_loss_rates_model_scale():
1092
+ fig = None
1093
+ df = pd.read_csv('training_data_5.csv')
1094
+ mpl.rcParams.update(mpl.rcParamsDefault)
1095
+ fig = plot_loss_rates_model(df,input.loss_type(),input.model_type())
1096
+ return fig
1097
  # @output
1098
  # @render.plot
1099
  # def plot_training_loss():