alkzar90 committed on
Commit
4555e57
1 Parent(s): b5beaf9

Fix indentation error

Browse files
Files changed (1) hide show
  1. app.py +22 -22
app.py CHANGED
@@ -71,29 +71,29 @@ if cost_function == 'RMSE-Loss':
71
  st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_1^m (\hat{y}_i - y_i)^2''')
72
 
73
 
74
- progress_bar = st.progress(0)
75
- status_text = st.empty()
76
- grad_loss = jax.grad(loss)
77
 
78
- # Perform gradient descent
79
- for i in range(NUM_ITER):
80
- # Update progress bar.
81
- progress_bar.progress(i + 1)
82
-
83
- # Update parameters.
84
- w -= learning_rate * grad_loss(w)
85
-
86
- # Update status text.
87
- if (i+1)%100==0:
88
- status_text.text(
89
- 'Trained loss is: %s' % loss(w))
90
-
91
- # add a line to the plot
92
- plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), '--')
93
- print(f"Trained loss at epoch #{i}:", loss(w))
94
- plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), '--', label='Final line')
95
-
96
- status_text.text('Done!')
97
 
98
 
99
  else:
 
71
  st.latex(r'''\ell(X, y, w)= \frac{1}{m}\sum_1^m (\hat{y}_i - y_i)^2''')
72
 
73
 
74
+ progress_bar = st.progress(0)
75
+ status_text = st.empty()
76
+ grad_loss = jax.grad(loss)
77
 
78
+ # Perform gradient descent
79
+ for i in range(NUM_ITER):
80
+ # Update progress bar.
81
+ progress_bar.progress(i + 1)
82
+
83
+ # Update parameters.
84
+ w -= learning_rate * grad_loss(w)
85
+
86
+ # Update status text.
87
+ if (i+1)%100==0:
88
+ # report the loss at the current epoch
89
+ status_text.text(
90
+ 'Trained loss is: %s' % loss(w))
91
+
92
+ # add a line to the plot
93
+ plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), 'c--')
94
+ # Plot the final line
95
+ plt.plot(X[jnp.dot(X, w).argsort(), 1], jnp.dot(X, w).sort(), 'k-', label='Final line')
96
+ status_text.text('Done!')
97
 
98
 
99
  else: