glenn-jocher commited on
Commit
c1a2a7a
1 Parent(s): 8074745

hyperparameter evolution bug fix (#566)

Browse files
Files changed (2) hide show
  1. train.py +2 -2
  2. utils/utils.py +16 -12
train.py CHANGED
@@ -465,7 +465,7 @@ if __name__ == '__main__':
465
  # Evolve hyperparameters (optional)
466
  else:
467
  # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
468
- meta = {'lr0': (1, 1e-5, 1e-2), # initial learning rate (SGD=1E-2, Adam=1E-3)
469
  'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
470
  'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
471
  'giou': (1, 0.02, 0.2), # GIoU loss gain
@@ -534,6 +534,6 @@ if __name__ == '__main__':
534
  print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
535
 
536
  # Plot results
537
- plot_evolution_results(yaml_file)
538
  print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
539
  'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))
 
465
  # Evolve hyperparameters (optional)
466
  else:
467
  # Hyperparameter evolution metadata (mutation scale 0-1, lower_limit, upper_limit)
468
+ meta = {'lr0': (1, 1e-5, 1e-1), # initial learning rate (SGD=1E-2, Adam=1E-3)
469
  'momentum': (0.1, 0.6, 0.98), # SGD momentum/Adam beta1
470
  'weight_decay': (1, 0.0, 0.001), # optimizer weight decay
471
  'giou': (1, 0.02, 0.2), # GIoU loss gain
 
534
  print_mutation(hyp.copy(), results, yaml_file, opt.bucket)
535
 
536
  # Plot results
537
+ plot_evolution(yaml_file)
538
  print('Hyperparameter evolution complete. Best results saved as: %s\nCommand to train a new model with these '
539
  'hyperparameters: $ python train.py --hyp %s' % (yaml_file, yaml_file))
utils/utils.py CHANGED
@@ -919,6 +919,15 @@ def increment_dir(dir, comment=''):
919
 
920
 
921
  # Plotting functions ---------------------------------------------------------------------------------------------------
 
 
 
 
 
 
 
 
 
922
  def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
923
  # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
924
  def butter_lowpass(cutoff, fs, order):
@@ -1130,13 +1139,6 @@ def plot_study_txt(f='study.txt', x=None): # from utils.utils import *; plot_st
1130
 
1131
  def plot_labels(labels, save_dir=''):
1132
  # plot dataset labels
1133
- def hist2d(x, y, n=100):
1134
- xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
1135
- hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
1136
- xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
1137
- yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
1138
- return np.log(hist[xidx, yidx])
1139
-
1140
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
1141
  nc = int(c.max() + 1) # number of classes
1142
 
@@ -1154,23 +1156,25 @@ def plot_labels(labels, save_dir=''):
1154
  plt.close()
1155
 
1156
 
1157
- def plot_evolution_results(yaml_file='hyp_evolved.yaml'): # from utils.utils import *; plot_evolution_results()
1158
  # Plot hyperparameter evolution results in evolve.txt
1159
  with open(yaml_file) as f:
1160
  hyp = yaml.load(f, Loader=yaml.FullLoader)
1161
  x = np.loadtxt('evolve.txt', ndmin=2)
1162
  f = fitness(x)
1163
  # weights = (f - f.min()) ** 2 # for weighted results
1164
- plt.figure(figsize=(14, 10), tight_layout=True)
1165
  matplotlib.rc('font', **{'size': 8})
1166
  for i, (k, v) in enumerate(hyp.items()):
1167
  y = x[:, i + 7]
1168
  # mu = (y * weights).sum() / weights.sum() # best weighted result
1169
  mu = y[f.argmax()] # best single result
1170
- plt.subplot(4, 6, i + 1)
1171
- plt.plot(mu, f.max(), 'o', markersize=10)
1172
- plt.plot(y, f, '.')
1173
  plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
 
 
1174
  print('%15s: %.3g' % (k, mu))
1175
  plt.savefig('evolve.png', dpi=200)
1176
  print('\nPlot saved as evolve.png')
 
919
 
920
 
921
  # Plotting functions ---------------------------------------------------------------------------------------------------
922
+ def hist2d(x, y, n=100):
923
+ # 2d histogram used in labels.png and evolve.png
924
+ xedges, yedges = np.linspace(x.min(), x.max(), n), np.linspace(y.min(), y.max(), n)
925
+ hist, xedges, yedges = np.histogram2d(x, y, (xedges, yedges))
926
+ xidx = np.clip(np.digitize(x, xedges) - 1, 0, hist.shape[0] - 1)
927
+ yidx = np.clip(np.digitize(y, yedges) - 1, 0, hist.shape[1] - 1)
928
+ return np.log(hist[xidx, yidx])
929
+
930
+
931
  def butter_lowpass_filtfilt(data, cutoff=1500, fs=50000, order=5):
932
  # https://stackoverflow.com/questions/28536191/how-to-filter-smooth-with-scipy-numpy
933
  def butter_lowpass(cutoff, fs, order):
 
1139
 
1140
  def plot_labels(labels, save_dir=''):
1141
  # plot dataset labels
 
 
 
 
 
 
 
1142
  c, b = labels[:, 0], labels[:, 1:].transpose() # classes, boxes
1143
  nc = int(c.max() + 1) # number of classes
1144
 
 
1156
  plt.close()
1157
 
1158
 
1159
+ def plot_evolution(yaml_file='runs/evolve/hyp_evolved.yaml'): # from utils.utils import *; plot_evolution()
1160
  # Plot hyperparameter evolution results in evolve.txt
1161
  with open(yaml_file) as f:
1162
  hyp = yaml.load(f, Loader=yaml.FullLoader)
1163
  x = np.loadtxt('evolve.txt', ndmin=2)
1164
  f = fitness(x)
1165
  # weights = (f - f.min()) ** 2 # for weighted results
1166
+ plt.figure(figsize=(10, 10), tight_layout=True)
1167
  matplotlib.rc('font', **{'size': 8})
1168
  for i, (k, v) in enumerate(hyp.items()):
1169
  y = x[:, i + 7]
1170
  # mu = (y * weights).sum() / weights.sum() # best weighted result
1171
  mu = y[f.argmax()] # best single result
1172
+ plt.subplot(5, 5, i + 1)
1173
+ plt.scatter(y, f, c=hist2d(y, f, 20), cmap='viridis', alpha=.8, edgecolors='none')
1174
+ plt.plot(mu, f.max(), 'k+', markersize=15)
1175
  plt.title('%s = %.3g' % (k, mu), fontdict={'size': 9}) # limit to 40 characters
1176
+ if i % 5 != 0:
1177
+ plt.yticks([])
1178
  print('%15s: %.3g' % (k, mu))
1179
  plt.savefig('evolve.png', dpi=200)
1180
  print('\nPlot saved as evolve.png')