joel-woodfield commited on
Commit
d94ed04
·
1 Parent(s): 2c794eb

Show unregularized solution(s) in plot

Browse files
Files changed (1) hide show
  1. regularization.py +37 -0
regularization.py CHANGED
@@ -129,6 +129,24 @@ class Regularization:
129
  ]
130
  loss_levels.reverse()
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if plot_path:
133
  if loss_type == "l2":
134
  path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)
@@ -159,6 +177,7 @@ class Regularization:
159
  reg_values,
160
  loss_levels,
161
  reg_levels,
 
162
  path_w,
163
  )
164
 
@@ -170,6 +189,7 @@ class Regularization:
170
  reg_values: np.ndarray,
171
  loss_levels: list,
172
  reg_levels: list,
 
173
  path_w: np.ndarray | None,
174
  ):
175
  fig, ax = plt.subplots(figsize=(8, 8))
@@ -189,6 +209,12 @@ class Regularization:
189
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
190
  ax.clabel(cs2, inline=True, fontsize=8)
191
 
 
 
 
 
 
 
192
  # regularization path
193
  if path_w is not None:
194
  ax.plot(path_w[:, 0], path_w[:, 1], "r-")
@@ -197,9 +223,20 @@ class Regularization:
197
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
198
  reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
199
  handles = [loss_line, reg_line]
 
200
  if path_w is not None:
201
  path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
202
  handles.append(path_line)
 
 
 
 
 
 
 
 
 
 
203
  ax.legend(handles=handles)
204
 
205
  ax.grid(True)
 
129
  ]
130
  loss_levels.reverse()
131
 
132
+ try:
133
+ unregularized_w = np.linalg.solve(X.T @ X, X.T @ y)
134
+ except np.linalg.LinAlgError:
135
+ # the solutions are on a line
136
+ eig_vals, eig_vectors = np.linalg.eigh(X.T @ X)
137
+ line_direction = eig_vectors[:, np.argmin(eig_vals)]
138
+ m = line_direction[1] / line_direction[0]
139
+
140
+ candidate_w = np.linalg.lstsq(X, y, rcond=None)[0]
141
+ b = candidate_w[1] - m * candidate_w[0]
142
+
143
+ unregularized_w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
144
+ unregularized_w2 = m * unregularized_w1 + b
145
+ unregularized_w = np.stack((unregularized_w1, unregularized_w2), axis=-1)
146
+
147
+ mask = (unregularized_w2 <= w2_range[1]) & (unregularized_w2 >= w2_range[0])
148
+ unregularized_w = unregularized_w[mask]
149
+
150
  if plot_path:
151
  if loss_type == "l2":
152
  path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)
 
177
  reg_values,
178
  loss_levels,
179
  reg_levels,
180
+ unregularized_w,
181
  path_w,
182
  )
183
 
 
189
  reg_values: np.ndarray,
190
  loss_levels: list,
191
  reg_levels: list,
192
+ unregularized_w: np.ndarray,
193
  path_w: np.ndarray | None,
194
  ):
195
  fig, ax = plt.subplots(figsize=(8, 8))
 
209
  cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
210
  ax.clabel(cs2, inline=True, fontsize=8)
211
 
212
+ # unregularized solution
213
+ if unregularized_w.ndim == 1:
214
+ ax.plot(unregularized_w[0], unregularized_w[1], "bx", markersize=5, label="unregularized solution")
215
+ else:
216
+ ax.plot(unregularized_w[:, 0], unregularized_w[:, 1], "b-", label="unregularized solution")
217
+
218
  # regularization path
219
  if path_w is not None:
220
  ax.plot(path_w[:, 0], path_w[:, 1], "r-")
 
223
  loss_line = mlines.Line2D([], [], color='black', linestyle='-', label='loss')
224
  reg_line = mlines.Line2D([], [], color='black', linestyle='--', label='regularization')
225
  handles = [loss_line, reg_line]
226
+
227
  if path_w is not None:
228
  path_line = mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
229
  handles.append(path_line)
230
+
231
+ if unregularized_w.ndim == 1:
232
+ handles.append(
233
+ mlines.Line2D([], [], color='blue', marker='x', linestyle='None', label='unregularized solution')
234
+ )
235
+ else:
236
+ handles.append(
237
+ mlines.Line2D([], [], color='blue', linestyle='-', label='unregularized solution')
238
+ )
239
+
240
  ax.legend(handles=handles)
241
 
242
  ax.grid(True)