Hack90 commited on
Commit
27c7f11
1 Parent(s): a1d4679

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -22
app.py CHANGED
@@ -999,30 +999,34 @@ with ui.navset_card_tab(id="tab"):
999
  multiple=True,
1000
  selected=["compliment", "cross_entropy", "headless"]
1001
  )
1002
- def plot_loss_rates_model(df, param_types, loss_types, model_types):
1003
- # interplot each column to be same number of points
1004
- x = np.linspace(0, 1, 1000)
1005
- loss_rates = []
1006
- labels = []
1007
- #drop the column step
1008
- # df = df.drop(columns=['Step'])
1009
- for param_type in param_types:
1010
- for loss_type in loss_types:
1011
- for model_type in model_types:
1012
- y = df[(df['param_type'] ==param_type) & (df['loss_type'] == loss_type ) & (df['model_type'] == model_type)]['loss'].astype('float').values
1013
- print(y)
 
1014
  f = interp1d(np.linspace(0, 1, len(y)), y)
1015
  loss_rates.append(f(x))
1016
- labels.append(str(param_type) +'_'+loss_type +'_'+model_type)
1017
- fig, ax = plt.subplots()
1018
- print(loss_rates)
1019
- for i, loss_rate in enumerate(loss_rates):
1020
- ax.plot(x, loss_rate, label=labels[i])
1021
- ax.legend()
1022
- # ax.set_title(f'Loss rates for a parameter model across context windows')
1023
- ax.set_xlabel('Training steps')
1024
- ax.set_ylabel('Loss rate')
1025
- return fig
 
 
 
1026
 
1027
  import matplotlib as mpl
1028
  @render.plot()
 
999
  multiple=True,
1000
  selected=["compliment", "cross_entropy", "headless"]
1001
  )
1002
+ def plot_loss_rates_model(df, param_types, loss_types, model_types):
1003
+ # interplot each column to be same number of points
1004
+ x = np.linspace(0, 1, 1000)
1005
+ loss_rates = []
1006
+ labels = []
1007
+
1008
+ for param_type in param_types:
1009
+ for loss_type in loss_types:
1010
+ for model_type in model_types:
1011
+ y = df[(df['param_type'] == param_type) & (df['loss_type'] == loss_type) & (df['model_type'] == model_type)]['loss'].astype('float').values
1012
+ print(y)
1013
+
1014
+ if len(y) > 0:
1015
  f = interp1d(np.linspace(0, 1, len(y)), y)
1016
  loss_rates.append(f(x))
1017
+ labels.append(str(param_type) + '_' + loss_type + '_' + model_type)
1018
+
1019
+ fig, ax = plt.subplots()
1020
+ print(loss_rates)
1021
+
1022
+ for i, loss_rate in enumerate(loss_rates):
1023
+ ax.plot(x, loss_rate, label=labels[i])
1024
+
1025
+ ax.legend()
1026
+ ax.set_xlabel('Training steps')
1027
+ ax.set_ylabel('Loss rate')
1028
+
1029
+ return fig
1030
 
1031
  import matplotlib as mpl
1032
  @render.plot()