Spaces:
Sleeping
Sleeping
| import io | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import matplotlib.lines as mlines | |
| import numpy as np | |
| from PIL import Image | |
| import plotly.graph_objects as go | |
| from sklearn.datasets import make_regression | |
| from sklearn.linear_model import ElasticNet | |
| import logging | |
# Configure the root logger once at import time: INFO and above, with a
# timestamped "[LEVEL] message" layout, then expose the app's named logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger("ELVIS")
| from dataset import Dataset, DatasetView | |
def min_corresponding_entries(W1, W2, w1, tol=0.1):
    """Return the smallest entry of ``W2`` among positions where ``W1 <= w1``.

    Args:
        W1: Array that is compared element-wise against the threshold.
        W2: Array of the same shape as ``W1`` to take the minimum from.
        w1: Threshold applied to ``W1``.
        tol: Unused; kept so existing callers that pass it keep working.

    Raises:
        ValueError: If no entry of ``W1`` is less than or equal to ``w1``.
    """
    candidates = W2[W1 <= w1]
    if not candidates.size:
        raise ValueError("No entries in W1 less than equal to w1")
    return candidates.min()
def l1_norm(W):
    """Return the L1 norm (sum of absolute values) along the last axis of ``W``."""
    magnitudes = np.abs(W)
    return magnitudes.sum(axis=-1)
def l2_norm(W):
    """Return the Euclidean (L2) norm along the last axis of ``W``."""
    return np.linalg.norm(W, ord=2, axis=-1)
def l1_loss(W, y, X):
    """Mean absolute error of the linear model for every weight vector in a grid.

    Args:
        W: Grid of weight vectors; the last axis holds the weights, the
            leading axes are the grid dimensions.
        y: Targets, one per sample.
        X: Design matrix of shape ``(n_samples, n_weights)``.

    Returns:
        Array with the grid shape of ``W`` (``W.shape[:-1]``) holding the mean
        absolute error of each weight vector.
    """
    # Use the actual grid shape instead of assuming a square grid: the two
    # weight axes can have different lengths.
    grid_shape = W.shape[:-1]
    targets = y.reshape(1, -1)
    preds = W.reshape(-1, W.shape[-1]) @ X.T
    return np.mean(np.abs(targets - preds), axis=1).reshape(grid_shape)
def l2_loss(W, y, X):
    """Mean squared error of the linear model for every weight vector in a grid.

    Args:
        W: Grid of weight vectors; the last axis holds the weights, the
            leading axes are the grid dimensions.
        y: Targets, one per sample.
        X: Design matrix of shape ``(n_samples, n_weights)``.

    Returns:
        Array with the grid shape of ``W`` (``W.shape[:-1]``) holding the mean
        squared error of each weight vector.
    """
    # Use the actual grid shape instead of assuming a square grid: the two
    # weight axes can have different lengths.
    grid_shape = W.shape[:-1]
    targets = y.reshape(1, -1)
    preds = W.reshape(-1, W.shape[-1]) @ X.T
    return np.mean((targets - preds) ** 2, axis=1).reshape(grid_shape)
def l2_loss_regularization_path(y, X, regularization_type):
    """Compute the regularization path of a squared-loss linear model.

    Args:
        y: Targets, one per sample.
        X: Design matrix.
        regularization_type: ``"l1"`` or ``"l2"``.

    Returns:
        The coefficients returned by ``ElasticNet.path``, transposed so that
        each row is one point on the path.

    Raises:
        ValueError: If ``regularization_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    if regularization_type == "l2":
        l1_ratio = 0
        # np.concatenate instead of np.concat: the latter only exists in
        # NumPy >= 2.0. The explicit 0 adds the unregularized end of the path.
        alphas = np.concatenate([np.zeros(1), np.logspace(-2, 2, 100)])
    elif regularization_type == "l1":
        l1_ratio = 1
        # let scikit-learn choose the alpha grid
        alphas = None
    else:
        raise ValueError("regularization_type must be 'l1' or 'l2'")

    _, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alphas)
    return coefs.T
class Regularization:
    """Gradio app that plots loss and regularizer contours in weight space.

    The model is ``y = w1 * x1 + w2 * x2``; losses and regularizer values are
    evaluated on a grid over ``(w1, w2)`` and drawn as contour lines.
    """

    LOSS_TYPES = ['l1', 'l2']
    REGULARIZER_TYPES = ['l1', 'l2']

    # name -> function evaluating the loss on a grid of weight vectors
    LOSS_FUNCTIONS = {
        'l1': l1_loss,
        'l2': l2_loss,
    }

    # name -> function evaluating the regularizer on a grid of weight vectors
    REGULARIZER_FUNCTIONS = {
        'l1': l1_norm,
        'l2': l2_norm,
    }

    # file the contour plot is written to so the export tab can offer it
    FIGURE_NAME = "loss_and_regularization_plot.svg"

    def __init__(self, width, height):
        """Store the canvas size and the CSS used to hide the download button."""
        self.canvas_width = width
        self.canvas_height = height

        self.css = """
        .hidden-button {
            display: none;
        }
        """

    def compute_and_plot_loss_and_reg(
        self,
        dataset: Dataset,
        loss_type: str,
        reg_type: str,
        reg_levels: list,
        w1_range: list,
        w2_range: list,
        num_dots: int,
        plot_path: bool,
    ):
        """Evaluate loss and regularizer on the weight grid and render the plot.

        Args:
            dataset: Object providing the design matrix ``X`` and targets ``y``.
            loss_type: Key into ``LOSS_FUNCTIONS``.
            reg_type: Key into ``REGULARIZER_FUNCTIONS``.
            reg_levels: Regularizer values at which contours are drawn.
            w1_range: ``[min, max]`` of the w1 axis.
            w2_range: ``[min, max]`` of the w2 axis.
            num_dots: Number of grid points per axis.
            plot_path: Whether to draw the regularization path (l2 loss only).

        Returns:
            The rendered plot as a PIL image.
        """
        X = dataset.X
        y = dataset.y

        # Contour levels have to be ascending, and the reversal of the loss
        # levels below relies on it, so do not depend on the input order.
        reg_levels = sorted(reg_levels)

        W1, W2 = self._build_parameter_grid(w1_range, w2_range, num_dots)
        losses = self._compute_losses(X, y, loss_type, W1, W2)
        reg_values = self._compute_reg_values(W1, W2, reg_type)

        # For each regularizer level, the smallest loss reachable within that
        # level. Larger levels reach smaller losses, hence the reversal to
        # obtain ascending loss levels.
        # NOTE(review): two levels that reach the same minimum yield duplicate
        # loss levels, which contour may reject - confirm and handle if needed.
        loss_levels = [
            min_corresponding_entries(reg_values, losses, level)
            for level in reg_levels
        ]
        loss_levels.reverse()

        unregularized_w = self._unregularized_solution(
            X, y, w1_range, w2_range, num_dots
        )

        # A regularization path is only available for the l2 loss.
        path_w = None
        if plot_path and loss_type == "l2":
            path_w = l2_loss_regularization_path(y, X, regularization_type=reg_type)

        return self.plot_loss_and_reg(
            W1,
            W2,
            losses,
            reg_values,
            loss_levels,
            reg_levels,
            unregularized_w,
            path_w,
        )

    def _unregularized_solution(self, X, y, w1_range, w2_range, num_dots):
        """Return the least-squares solution, or a line of solutions if singular.

        Returns a 1-D weight vector when the normal equations are solvable.
        Otherwise returns a 2-D array of points on the solution line, sampled
        over ``w1_range`` and clipped to ``w2_range``.
        """
        gram = X.T @ X
        try:
            return np.linalg.solve(gram, X.T @ y)
        except np.linalg.LinAlgError:
            # The solutions lie on a line along the eigenvector belonging to
            # the smallest eigenvalue, passing through one particular solution.
            eig_vals, eig_vectors = np.linalg.eigh(gram)
            direction = eig_vectors[:, np.argmin(eig_vals)]
            slope = direction[1] / direction[0]

            candidate_w = np.linalg.lstsq(X, y, rcond=None)[0]
            intercept = candidate_w[1] - slope * candidate_w[0]

            line_w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
            line_w2 = slope * line_w1 + intercept
            inside = (line_w2 <= w2_range[1]) & (line_w2 >= w2_range[0])
            return np.stack((line_w1, line_w2), axis=-1)[inside]

    @staticmethod
    def _figure_to_image(fig):
        """Render a matplotlib figure to an in-memory PNG and open it with PIL."""
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        buf.seek(0)
        return Image.open(buf)

    def plot_loss_and_reg(
        self,
        W1: np.ndarray,
        W2: np.ndarray,
        losses: np.ndarray,
        reg_values: np.ndarray,
        loss_levels: list,
        reg_levels: list,
        unregularized_w: np.ndarray,
        path_w: np.ndarray | None,
    ):
        """Draw loss and regularizer contours and return the plot as an image.

        The figure is also written to ``FIGURE_NAME`` for the export tab.

        Args:
            W1, W2: Meshgrid of the weight axes.
            losses: Loss value at every grid point.
            reg_values: Regularizer value at every grid point.
            loss_levels: Loss contour levels, in the reverse order of the
                regularizer levels they belong to.
            reg_levels: Regularizer contour levels.
            unregularized_w: One weight vector (1-D) or a set of points on a
                solution line (2-D).
            path_w: Points of the regularization path, or ``None``.
        """
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_title("")
        ax.set_xlabel("w1")
        ax.set_ylabel("w2")

        cmap = plt.get_cmap("viridis")
        N = len(reg_levels)
        # max(..., 1) avoids a division by zero when only one level is given
        colors = [cmap(i / max(N - 1, 1)) for i in range(N)]

        # regularizer contours
        cs1 = ax.contour(W1, W2, reg_values, levels=reg_levels, colors=colors, linestyles="dashed")
        ax.clabel(cs1, inline=True, fontsize=8)

        # loss contours; colours reversed so that a loss contour shares its
        # colour with the regularizer level it was derived from
        cs2 = ax.contour(W1, W2, losses, levels=loss_levels, colors=colors[::-1])
        ax.clabel(cs2, inline=True, fontsize=8)

        # unregularized solution: a single point or a line of solutions
        single_solution = unregularized_w.ndim == 1
        if single_solution:
            ax.plot(unregularized_w[0], unregularized_w[1], "bx", markersize=5, label="unregularized solution")
        else:
            ax.plot(unregularized_w[:, 0], unregularized_w[:, 1], "b-", label="unregularized solution")

        # regularization path
        if path_w is not None:
            ax.plot(path_w[:, 0], path_w[:, 1], "r-")

        # legend built from proxy lines, since contour sets carry no labels
        handles = [
            mlines.Line2D([], [], color='black', linestyle='-', label='loss'),
            mlines.Line2D([], [], color='black', linestyle='--', label='regularization'),
        ]
        if path_w is not None:
            handles.append(
                mlines.Line2D([], [], color='red', linestyle='-', label='regularization path')
            )
        if single_solution:
            handles.append(
                mlines.Line2D([], [], color='blue', marker='x', linestyle='None', label='unregularized solution')
            )
        else:
            handles.append(
                mlines.Line2D([], [], color='blue', linestyle='-', label='unregularized solution')
            )
        ax.legend(handles=handles)
        ax.grid(True)

        # write both outputs before the figure is closed
        img = self._figure_to_image(fig)
        fig.savefig(f"{self.FIGURE_NAME}")
        plt.close(fig)

        return img

    def plot_data(self, dataset: Dataset):
        """Return a 3-D plotly figure with the dataset's surface and samples."""
        mesh_x1, mesh_x2, y = dataset.get_function(nsample=100)

        surface = go.Surface(
            z=y,
            x=mesh_x1,
            y=mesh_x2,
            colorscale='Viridis',
            opacity=0.8,
            name='True function',
        )
        samples = go.Scatter3d(
            x=dataset.X[:, 0],
            y=dataset.X[:, 1],
            z=dataset.y,
            mode='markers',
            marker=dict(
                size=3,
                color='red',
                opacity=0.8,
                symbol='circle',
            ),
            name='Data Points',
        )

        fig = go.Figure()
        fig.add_trace(surface)
        fig.add_trace(samples)
        fig.update_layout(
            title="Data",
            scene={
                "xaxis": {"title": "X1", "nticks": 6},
                "yaxis": {"title": "X2", "nticks": 6},
                "zaxis": {"title": "Y", "nticks": 6},
                "camera": {"eye": {"x": -1.5, "y": -1.5, "z": 1.2}},
            },
            width=800,
            height=600,
        )
        return fig

    def plot_strength_vs_weight(self, dataset: Dataset, loss_type: str, reg_type: str):
        """Plot both weights against the regularization strength.

        Only the l2 loss is supported; for any other loss a blank white image
        is returned.
        """
        if loss_type != "l2":
            return Image.new("RGB", (800, 800), color="white")

        X = dataset.X
        y = dataset.y

        # np.concatenate instead of np.concat: the latter only exists in
        # NumPy >= 2.0.
        alpha_grid = np.concatenate([np.zeros(1), np.logspace(-2, 2, 100)])
        l1_ratio = 1 if reg_type == "l1" else 0
        alphas, coefs, *_ = ElasticNet.path(X, y, l1_ratio=l1_ratio, alphas=alpha_grid)
        coefs = coefs.T

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.plot(alphas, coefs[:, 0], label="w1")
        ax.plot(alphas, coefs[:, 1], label="w2")
        ax.set_xscale("log")
        ax.set_xlabel("Regularization strength (alpha)")
        ax.set_ylabel("Weight value")
        ax.legend()

        img = self._figure_to_image(fig)
        plt.close(fig)
        return img

    def update_loss_type(self, loss_type: str):
        """Validate and return the selected loss type."""
        if loss_type not in self.LOSS_TYPES:
            raise ValueError(f"loss_type must be one of {self.LOSS_TYPES}")
        return loss_type

    def update_reg_path_visibility(self, loss_type: str):
        """Show the regularization-path checkbox only for the l2 loss."""
        return gr.update(visible=(loss_type == "l2"))

    def update_regularizer(self, reg_type: str):
        """Validate and return the selected regularizer type."""
        if reg_type not in self.REGULARIZER_TYPES:
            raise ValueError(f"reg_type must be one of {self.REGULARIZER_TYPES}")
        return reg_type

    @staticmethod
    def _parse_float_list(text: str) -> list:
        """Parse a comma-separated string into a list of floats."""
        return [float(token) for token in text.split(",")]

    def update_reg_levels(self, reg_levels_input: str):
        """Parse the comma-separated regularizer levels."""
        return self._parse_float_list(reg_levels_input)

    def update_w1_range(self, w1_range_input: str):
        """Parse the comma-separated w1 range."""
        return self._parse_float_list(w1_range_input)

    def update_w2_range(self, w2_range_input: str):
        """Parse the comma-separated w2 range."""
        return self._parse_float_list(w2_range_input)

    def update_resolution(self, num_dots: int):
        """Return the selected grid resolution unchanged."""
        return num_dots

    def update_plot_path(self, plot_path: bool):
        """Return the regularization-path flag unchanged."""
        return plot_path

    def _build_parameter_grid(
        self,
        w1_range: list,
        w2_range: list,
        num_dots: int,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Build the ``(w1, w2)`` meshgrid, making sure it contains ``(0, 0)``.

        Zero is inserted into an axis that does not already contain it, so the
        two axes may end up with different lengths. The insertion position
        comes from ``np.searchsorted``, which expects each range to be given
        in ascending order.
        """
        w1 = np.linspace(w1_range[0], w1_range[1], num_dots)
        w2 = np.linspace(w2_range[0], w2_range[1], num_dots)

        # include (0, 0)
        if 0 not in w1:
            w1 = np.insert(w1, np.searchsorted(w1, 0), 0)
        if 0 not in w2:
            w2 = np.insert(w2, np.searchsorted(w2, 0), 0)

        W1, W2 = np.meshgrid(w1, w2)
        return W1, W2

    def _compute_losses(
        self,
        X: np.ndarray,
        y: np.ndarray,
        loss_type: str,
        W1: np.ndarray,
        W2: np.ndarray,
    ) -> np.ndarray:
        """Evaluate the selected loss at every point of the weight grid."""
        weights = np.stack((W1, W2), axis=-1)
        return self.LOSS_FUNCTIONS[loss_type](weights, y, X)

    def _compute_reg_values(
        self,
        W1: np.ndarray,
        W2: np.ndarray,
        reg_type: str,
    ) -> np.ndarray:
        """Evaluate the selected regularizer at every point of the weight grid."""
        weights = np.stack((W1, W2), axis=-1)
        return self.REGULARIZER_FUNCTIONS[reg_type](weights)

    def launch(self):
        """Build the Gradio interface, wire up its events and start the app."""
        # NOTE(review): the widget nesting below was reconstructed from a
        # source whose indentation was lost - confirm the layout.
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Regularization visualizer</div>")

            # states
            dataset = gr.State(Dataset())
            loss_type = gr.State("l2")
            reg_type = gr.State("l2")
            reg_levels = gr.State([10, 20, 30])
            w1_range = gr.State([-100, 100])
            w2_range = gr.State([-100, 100])
            num_dots = gr.State(500)
            plot_regularization_path = gr.State(False)

            # inputs shared by every event that redraws a plot
            contour_inputs = [
                dataset,
                loss_type,
                reg_type,
                reg_levels,
                w1_range,
                w2_range,
                num_dots,
                plot_regularization_path,
            ]
            strength_inputs = [dataset, loss_type, reg_type]

            def then_redraw_contours(event):
                """Chain a redraw of the loss/regularization plot onto an event."""
                return event.then(
                    fn=self.compute_and_plot_loss_and_reg,
                    inputs=contour_inputs,
                    outputs=self.loss_and_regularization_plot,
                )

            def then_redraw_strength(event):
                """Chain a redraw of the strength-vs-weight plot onto an event."""
                return event.then(
                    fn=self.plot_strength_vs_weight,
                    inputs=strength_inputs,
                    outputs=self.strength_vs_weight,
                )

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Tab("Loss and Regularization"):
                        self.loss_and_regularization_plot = gr.Image(
                            value=self.compute_and_plot_loss_and_reg(
                                dataset.value,
                                loss_type.value,
                                reg_type.value,
                                reg_levels.value,
                                w1_range.value,
                                w2_range.value,
                                num_dots.value,
                                plot_regularization_path.value,
                            ),
                            container=True,
                        )
                    with gr.Tab("Data"):
                        self.data_3d_plot = gr.Plot(
                            value=self.plot_data(dataset.value), container=True
                        )
                    with gr.Tab("Strength vs weight"):
                        self.strength_vs_weight = gr.Image(
                            value=self.plot_strength_vs_weight(
                                dataset.value, loss_type.value, reg_type.value
                            ),
                            container=True,
                        )

                with gr.Column(scale=1):
                    with gr.Tab("Settings"):
                        with gr.Row():
                            model_textbox = gr.Textbox(
                                label="Model",
                                value="y = w1 * x1 + w2 * x2",
                                interactive=False,
                            )
                        with gr.Row():
                            loss_type_selection = gr.Dropdown(
                                choices=['l1', 'l2'],
                                label='Loss type',
                                value='l2',
                                visible=True,
                            )
                        with gr.Group():
                            with gr.Row():
                                regularizer_type_selection = gr.Dropdown(
                                    choices=['l1', 'l2'],
                                    label='Regularizer type',
                                    value='l2',
                                    visible=True,
                                )
                                reg_textbox = gr.Textbox(
                                    label="Regularizer levels",
                                    value="10, 20, 30",
                                    interactive=True,
                                )
                            with gr.Row():
                                w1_textbox = gr.Textbox(
                                    label="w1 range",
                                    value="-100, 100",
                                    interactive=True,
                                )
                                w2_textbox = gr.Textbox(
                                    label="w2 range",
                                    value="-100, 100",
                                    interactive=True,
                                )
                            with gr.Row():
                                resolution_slider = gr.Slider(
                                    minimum=100,
                                    maximum=1000,
                                    value=500,
                                    step=1,
                                    label="Resolution (#points)",
                                )
                            submit_button = gr.Button("Submit changes")
                        with gr.Row():
                            path_checkbox = gr.Checkbox(label="Show regularization path", value=False)

                    with gr.Tab("Data"):
                        dataset_view = DatasetView()
                        dataset_view.build(state=dataset)
                        # a changed dataset invalidates all three plots
                        then_redraw_strength(
                            dataset.change(
                                fn=self.compute_and_plot_loss_and_reg,
                                inputs=contour_inputs,
                                outputs=self.loss_and_regularization_plot,
                            ).then(
                                fn=self.plot_data,
                                inputs=[dataset],
                                outputs=self.data_3d_plot,
                            )
                        )

                    with gr.Tab("Export"):
                        # use hidden download button to generate files on the fly
                        # https://github.com/gradio-app/gradio/issues/9230#issuecomment-2323771634
                        with gr.Row():
                            btn_export_plot_loss_reg = gr.Button("Loss and Regularization Plot")
                            btn_export_plot_loss_reg_hidden = gr.DownloadButton(
                                label="You should not see this",
                                elem_id="btn_export_plot_loss_reg_hidden",
                                elem_classes="hidden-button"
                            )

                    with gr.Tab("Usage"):
                        # read inside a context manager so the file is closed
                        with open('usage.md', 'r') as usage_file:
                            gr.Markdown(usage_file.read())

            # event handlers for GUI elements
            # settings: store the new value in its state, then redraw
            then_redraw_strength(
                then_redraw_contours(
                    loss_type_selection.change(
                        fn=self.update_loss_type,
                        inputs=[loss_type_selection],
                        outputs=[loss_type],
                    ).then(
                        fn=self.update_reg_path_visibility,
                        inputs=[loss_type_selection],
                        outputs=[path_checkbox],
                    )
                )
            )

            then_redraw_strength(
                then_redraw_contours(
                    regularizer_type_selection.change(
                        fn=self.update_regularizer,
                        inputs=[regularizer_type_selection],
                        outputs=[reg_type],
                    )
                )
            )

            then_redraw_strength(
                then_redraw_contours(
                    reg_textbox.submit(
                        self.update_reg_levels,
                        inputs=[reg_textbox],
                        outputs=[reg_levels],
                    )
                )
            )

            then_redraw_contours(
                w1_textbox.submit(
                    self.update_w1_range,
                    inputs=[w1_textbox],
                    outputs=[w1_range],
                )
            )

            then_redraw_contours(
                w2_textbox.submit(
                    self.update_w2_range,
                    inputs=[w2_textbox],
                    outputs=[w2_range],
                )
            )

            # the submit button applies all three text boxes at once
            then_redraw_contours(
                submit_button.click(
                    self.update_w1_range,
                    inputs=[w1_textbox],
                    outputs=[w1_range],
                ).then(
                    self.update_w2_range,
                    inputs=[w2_textbox],
                    outputs=[w2_range],
                ).then(
                    self.update_reg_levels,
                    inputs=[reg_textbox],
                    outputs=[reg_levels],
                )
            )

            then_redraw_contours(
                resolution_slider.change(
                    self.update_resolution,
                    inputs=[resolution_slider],
                    outputs=[num_dots],
                )
            )

            then_redraw_contours(
                path_checkbox.change(
                    self.update_plot_path,
                    inputs=[path_checkbox],
                    outputs=[plot_regularization_path],
                )
            )

            # export: hand the file to the hidden download button, then click
            # that button from JavaScript to start the download
            btn_export_plot_loss_reg.click(
                fn=lambda: self.FIGURE_NAME,
                inputs=None,
                outputs=[btn_export_plot_loss_reg_hidden],
            ).then(
                fn=None,
                inputs=None,
                outputs=None,
                js="() => document.querySelector('#btn_export_plot_loss_reg_hidden').click()"
            )

        demo.launch()
# Build the visualizer and start the Gradio app as soon as the script runs.
visualizer = Regularization(height=900, width=1200)
visualizer.launch()