lemonteaa committed on
Commit
d7ce189
1 Parent(s): cf88959

Create analysis/plot_log.py

Browse files
Files changed (1) hide show
  1. baseline/analysis/plot_log.py +68 -0
baseline/analysis/plot_log.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from scipy.optimize import curve_fit
4
+
5
def parse_file(file_path):
    """Parse a whitespace-separated training log into a list of records.

    Each record line is expected to look like
    ``step:<n>/<total> <train_loss|val_loss>:<float> <field> step_avg:<float>ms``
    i.e. the step is the first token, the loss the second and the average
    step time the fourth.  Blank lines and lines whose first token is not a
    ``step:`` field (headers, banners, ...) are skipped instead of crashing
    the parser.

    Args:
        file_path: Path of the log file to read.

    Returns:
        A list of dicts with the keys ``'step'`` (int), ``'loss'`` (float),
        ``'step_avg'`` (float, the ``ms`` suffix removed) and ``'is_train'``
        (bool, False when the loss token contains ``'val'``).

    Raises:
        IndexError, ValueError: If a line starting with ``step:`` is missing
            fields or contains values that cannot be converted to numbers.
    """
    data = []
    with open(file_path, 'r') as file:
        for line in file:
            parts = line.split()
            # Ignore empty lines and anything that is not a step record.
            if not parts or not parts[0].startswith('step:'):
                continue
            # "step:12/5100" -> 12
            step = int(parts[0].split(':')[1].split('/')[0])
            # The second token is either "train_loss:<x>" or "val_loss:<x>".
            is_train = 'val' not in parts[1]
            loss = float(parts[1].split(':')[1])
            # "step_avg:123.45ms" -> 123.45 ("nanms" becomes float nan).
            step_avg = float(parts[3].split(':')[1].replace('ms', ''))
            data.append({
                'step': step,
                'loss': loss,
                'step_avg': step_avg,
                'is_train': is_train,
            })
    return data
25
+
26
# Usage: parse the log and fit a power law (a straight line in log-log
# space) to the training loss as a function of the step.
file_path = 'baseline_log.txt'
data = parse_file(file_path)

# Extract the steps and losses of the training records (single pass over
# the parsed data instead of filtering it twice).
train_records = [d for d in data if d['is_train']]
steps = np.array([d['step'] for d in train_records])
losses = np.array([d['loss'] for d in train_records])

# log10 is only defined for strictly positive values: a record with step 0
# or a zero/NaN loss would turn into -inf/NaN and make curve_fit fail, so
# such points are dropped before taking the logarithm.
valid = (steps > 0) & (losses > 0) & np.isfinite(losses)
steps = steps[valid]
losses = losses[valid]
if steps.size < 2:
    raise ValueError(
        'Need at least two training records with positive step and loss '
        'to fit a line, found {} in {!r}'.format(steps.size, file_path)
    )

# Take the logarithm of the data
log_steps = np.log10(steps)
log_losses = np.log10(losses)


# Define a linear function
def linear_func(x, a, b):
    """Return the straight line ``a * x + b`` evaluated at ``x``."""
    return a * x + b


# Fit the linear function to the logarithmic data
popt, pcov = curve_fit(linear_func, log_steps, log_losses)

# Create the plot
plt.loglog(steps, losses, label='Data')

# Plot the fitted line: evaluate the fit on points evenly spaced in log
# space between the smallest and largest step, then map back with 10 ** y.
x_fit = np.logspace(np.log10(np.min(steps)), np.log10(np.max(steps)), 100)
y_fit = 10 ** linear_func(np.log10(x_fit), popt[0], popt[1])
plt.loglog(x_fit, y_fit, label='Fitted line', color='red')

# Add title and labels
plt.title('Loss as a function of step')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.legend()

# Print the fitted parameters (slope a and intercept b in log10 space)
print('Fitted parameters: a = {:.2f}, b = {:.2f}'.format(popt[0], popt[1]))

# Save the plot to a file
plt.savefig('loss_plot2.png')

# Show the plot
plt.show()