Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| import time | |
| import ast | |
| import gradio as gr | |
| import io | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from PIL import Image | |
| from sklearn.gaussian_process import GaussianProcessRegressor | |
| from sklearn.gaussian_process.kernels import ( | |
| DotProduct, | |
| WhiteKernel, | |
| ConstantKernel, | |
| RBF, | |
| Matern, | |
| RationalQuadratic, | |
| ExpSineSquared, | |
| Kernel, | |
| ) | |
| import logging | |
# Configure the root logger once at import time: INFO and above are emitted,
# each record prefixed with a timestamp and its level name.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)

# Named logger shared by the rest of this module.
logger = logging.getLogger("ELVIS")
| from dataset import Dataset, DatasetView, get_function | |
@dataclass
class PlotOptions:
    """Flags selecting which layers of the plot are drawn.

    Instances are treated as values: ``update`` returns a modified copy
    instead of mutating in place, and ``__hash__`` is derived from the
    field values so that two option sets with the same flags hash equal.
    """

    # The ``@dataclass`` decorator is required: the class is constructed with
    # keyword arguments (see ``update``), which needs the generated __init__.
    show_training_data: bool = True
    show_true_function: bool = True
    show_mean_prediction: bool = True
    show_prediction_interval: bool = True

    def update(self, **kwargs):
        """Return a new ``PlotOptions`` with the given fields overridden.

        Fields not named in ``kwargs`` keep their current value. Keyword
        names that are not fields of this class are ignored.
        """
        return PlotOptions(
            show_training_data=kwargs.get("show_training_data", self.show_training_data),
            show_true_function=kwargs.get("show_true_function", self.show_true_function),
            show_mean_prediction=kwargs.get("show_mean_prediction", self.show_mean_prediction),
            show_prediction_interval=kwargs.get("show_prediction_interval", self.show_prediction_interval),
        )

    def __hash__(self):
        # Defined explicitly because a plain @dataclass (eq=True, frozen=False)
        # would otherwise leave the class unhashable.
        # NOTE(review): presumably needed so the UI framework can detect a
        # changed state value by hash -- confirm against the Gradio docs.
        return hash(
            (
                self.show_training_data,
                self.show_true_function,
                self.show_mean_prediction,
                self.show_prediction_interval,
            )
        )
def eval_kernel(kernel_str) -> Kernel:
    """Build a scikit-learn kernel from an expression string.

    The expression may call the whitelisted kernel constructors and combine
    them (and numeric literals) with arithmetic operators, for example
    ``"RBF() + WhiteKernel()"`` or ``"2.0 * Matern(nu=1.5)"``.

    Args:
        kernel_str: The kernel expression, typically typed by a user.

    Returns:
        The kernel object the expression evaluates to.

    Raises:
        ValueError: If the string is not valid Python syntax, uses a
            construct or name outside the whitelist, fails during
            evaluation, or does not evaluate to a ``Kernel``.
    """
    # Kernel constructors the expression is allowed to reference.
    allowed_names = {
        'RBF': RBF,
        'Matern': Matern,
        'RationalQuadratic': RationalQuadratic,
        'ExpSineSquared': ExpSineSquared,
        'DotProduct': DotProduct,
        'WhiteKernel': WhiteKernel,
        'ConstantKernel': ConstantKernel,
    }
    # Syntax node types an expression may contain: constructor calls with
    # positional/keyword arguments, literals, tuples/lists (e.g. for bounds)
    # and basic arithmetic. Attribute access, subscripts, lambdas,
    # comprehensions, etc. are deliberately absent.
    allowed_nodes = (
        ast.Expression,
        ast.Call,
        ast.keyword,
        ast.Name,
        ast.Load,
        ast.Constant,
        ast.BinOp,
        ast.Add,
        ast.Sub,
        ast.Mult,
        ast.Div,
        ast.Pow,
        ast.UnaryOp,
        ast.USub,
        ast.UAdd,
        ast.Tuple,
        ast.List,
    )

    try:
        tree = ast.parse(kernel_str, mode='eval')
    except SyntaxError as e:
        raise ValueError(f"Invalid syntax: {e}") from e

    # SECURITY: the input is untrusted and is passed to eval() below.
    # Removing __builtins__ alone is not a sandbox (attribute access on any
    # literal can reach arbitrary classes), so the tree is validated against
    # the whitelist before it is compiled.
    for node in ast.walk(tree):
        if not isinstance(node, allowed_nodes):
            raise ValueError(
                f"Error evaluating kernel: {type(node).__name__} is not allowed in a kernel expression"
            )
        if isinstance(node, ast.Name) and node.id not in allowed_names:
            raise ValueError(f"Error evaluating kernel: unknown name '{node.id}'")
        if isinstance(node, ast.Call) and not isinstance(node.func, ast.Name):
            raise ValueError("Error evaluating kernel: only kernel constructors may be called")

    try:
        result = eval(
            compile(tree, '<string>', 'eval'),
            {"__builtins__": None},  # no access to Python builtins such as open
            allowed_names,  # the only names the expression can resolve
        )
    except Exception as e:
        raise ValueError(f"Error evaluating kernel: {e}") from e

    # A bare literal such as "1.0" passes validation but is not a kernel.
    if not isinstance(result, Kernel):
        raise ValueError(
            f"Error evaluating kernel: expression produced {type(result).__name__}, not a kernel"
        )
    return result
@dataclass
class ModelState:
    """A Gaussian process model together with the settings that produced it.

    The hash covers only ``kernel`` and ``distribution``, so two states built
    from the same settings hash equal even when they hold different model
    objects.
    """

    # The ``@dataclass`` decorator is required: instances are created with
    # keyword arguments (model=..., kernel=..., distribution=...).
    model: GaussianProcessRegressor  # the regressor used for prediction/sampling
    kernel: str  # kernel expression string the model was built from
    distribution: str  # "prior" or "posterior" (see GpVisualizer.init_model)

    def __hash__(self):
        # Defined explicitly because a plain @dataclass (eq=True, frozen=False)
        # would otherwise leave the class unhashable; the model object itself
        # is intentionally excluded from the hash.
        return hash(
            (
                self.kernel,
                self.distribution,
            )
        )
class GpVisualizer:
    """Gradio application that fits a Gaussian process to a dataset and plots it.

    The app keeps three pieces of state (dataset, model state, plot options)
    and re-renders a static image whenever one of them changes.
    """

    def __init__(self, width, height):
        """Store the canvas size and the shared plotting resources.

        Args:
            width: Canvas width; stored as ``canvas_width``.
            height: Canvas height; stored as ``canvas_height``.
        """
        # NOTE: ``plot`` currently renders a fixed 8x8 inch figure and does
        # not read these two attributes.
        self.canvas_width = width
        self.canvas_height = height
        self.plot_cmap = plt.get_cmap("tab20")
        self.css = """
        .hidden-button {
            display: none;
        }"""

    def plot(
        self,
        dataset: Dataset,
        model_state: ModelState,
        plot_options: PlotOptions,
        sample_y: bool = False,
        sample_y_seed: int = 0,
    ) -> Image.Image:
        """Render the dataset and the model's predictions to an image.

        Args:
            dataset: Source of the training points (``x``, ``y``) and, in
                ``"generate"`` mode, of the function passed to
                ``get_function``.
            model_state: Holds the model used for prediction and sampling.
            plot_options: Flags selecting which layers are drawn.
            sample_y: When true, also draw one function sampled from the
                model.
            sample_y_seed: Random state passed to the model's ``sample_y``.

        Returns:
            The rendered plot as a PIL image (PNG-encoded in memory).
        """
        logger.info("Plotting")
        started = time.perf_counter()

        x_train = dataset.x
        y_train = dataset.y

        # Choose the evaluation grid. In "generate" mode the grid and the
        # reference curve come from get_function; otherwise the grid spans
        # the training inputs padded by 1 on each side. With no data and no
        # generated function there is nothing to predict on.
        if dataset.mode == "generate":
            x_test, y_test = get_function(dataset.function, xlim=(-2, 2), nsample=100)
            y_pred, y_std = model_state.model.predict(x_test, return_std=True)
        elif x_train.shape[0] > 0:
            x_test = np.linspace(x_train.min() - 1, x_train.max() + 1, 100).reshape(-1, 1)
            y_test = None
            y_pred, y_std = model_state.model.predict(x_test, return_std=True)
        else:
            x_test = y_test = y_pred = y_std = None

        # Bounds of the 95% interval (mean +/- 1.96 standard deviations).
        if y_pred is not None and y_std is not None:
            lower = y_pred - 1.96 * y_std
            upper = y_pred + 1.96 * y_std
        else:
            lower = upper = None

        # Exactly one figure is created per call, and it is always closed,
        # so repeated renders do not accumulate open figures. All drawing
        # goes through ``ax`` rather than pyplot's global "current figure".
        fig, ax = plt.subplots(figsize=(8, 8))
        try:
            ax.set_title("")
            ax.set_xlabel("x")
            ax.set_ylabel("y")

            # The y-limits cover the interval plus the reference values:
            # the generated function if present, else the training targets.
            if y_test is not None:
                y_reference = y_test
            elif y_train.shape[0] > 0:
                y_reference = y_train
            else:
                y_reference = None
            if y_reference is not None and lower is not None:
                min_y = min(y_reference.min(), lower.min())
                max_y = max(y_reference.max(), upper.max())
                ax.set_ylim(min_y - 1, max_y + 1)

            if plot_options.show_training_data:
                ax.scatter(
                    x_train.flatten(),
                    y_train,
                    label='training data',
                    color=self.plot_cmap(0),
                )
            if plot_options.show_true_function and x_test is not None and y_test is not None:
                ax.plot(
                    x_test.flatten(),
                    y_test,
                    label='true function',
                    color=self.plot_cmap(1),
                )
            if plot_options.show_mean_prediction and x_test is not None and y_pred is not None:
                ax.plot(
                    x_test.flatten(),
                    y_pred,
                    linestyle="--",
                    label='mean prediction',
                    color=self.plot_cmap(2),
                )
            if plot_options.show_prediction_interval and x_test is not None and lower is not None:
                ax.fill_between(
                    x_test.flatten(),
                    lower,
                    upper,
                    color=self.plot_cmap(3),
                    alpha=0.2,
                    label='95% prediction interval',
                )
            if x_test is not None and sample_y:
                y_sample = model_state.model.sample_y(
                    x_test, random_state=sample_y_seed
                ).flatten()
                ax.plot(
                    x_test.flatten(),
                    y_sample,
                    linestyle=":",
                    label="model sample",
                    color=self.plot_cmap(4),
                )

            # Only add a legend when at least one labelled layer was drawn;
            # otherwise matplotlib emits a "no artists with labels" warning.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend()

            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight", pad_inches=0)
        finally:
            plt.close(fig)

        buf.seek(0)
        img = Image.open(buf)

        logger.info("Plotting took %.4f seconds", time.perf_counter() - started)
        return img

    def init_model(
        self,
        kernel: str,
        dataset: Dataset,
        distribution: str,
    ) -> GaussianProcessRegressor:
        """Create a Gaussian process regressor for the given kernel expression.

        Args:
            kernel: Kernel expression string, evaluated by ``eval_kernel``.
            dataset: Training data; only used for ``"posterior"``.
            distribution: ``"posterior"`` fits the model to the dataset when
                it has at least one point; ``"prior"`` leaves it unfitted.

        Returns:
            The (possibly unfitted) regressor.

        Raises:
            ValueError: If ``distribution`` is neither ``"prior"`` nor
                ``"posterior"``, or if the kernel expression is rejected by
                ``eval_kernel``.
        """
        model = GaussianProcessRegressor(kernel=eval_kernel(kernel))
        if distribution == "posterior":
            # An empty dataset cannot be fitted; the model stays unfitted.
            if dataset.x.shape[0] > 0:
                model.fit(dataset.x, dataset.y)
        elif distribution != "prior":
            raise ValueError(f"Unknown distribution: {distribution}")
        return model

    def update_dataset(
        self,
        dataset: Dataset,
        model_state: ModelState,
        plot_options: PlotOptions,
    ) -> tuple[ModelState, Image.Image]:
        """Rebuild the model for a changed dataset and re-render the plot.

        The kernel and distribution of ``model_state`` are reused.

        Returns:
            The new model state and the new plot image.
        """
        logger.info("updating dataset")
        model = self.init_model(
            model_state.kernel,
            dataset,
            model_state.distribution,
        )
        model_state = ModelState(
            model=model, kernel=model_state.kernel, distribution=model_state.distribution
        )
        new_canvas = self.plot(dataset, model_state, plot_options)
        return model_state, new_canvas

    def update_model(
        self,
        kernel_str: str,
        distribution: str,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> tuple[ModelState, Image.Image]:
        """Rebuild the model from a new kernel expression and/or distribution.

        If the model cannot be built (for example because the kernel
        expression is invalid), the error is logged and shown to the user,
        and the previous ``model_state`` is kept and re-plotted.

        Args:
            kernel_str: Kernel expression from the UI.
            distribution: Distribution choice from the UI; lower-cased
                before use.
            model_state: Current state, used as the fallback on failure.
            dataset: Current dataset.
            plot_options: Current plot options.

        Returns:
            The resulting model state and the new plot image.
        """
        logger.info("updating kernel")
        try:
            model = self.init_model(
                kernel_str,
                dataset,
                distribution.lower(),
            )
            model_state = ModelState(
                model=model, kernel=kernel_str, distribution=distribution.lower()
            )
        except Exception as e:
            # UI boundary: report the problem instead of failing the event.
            logger.error(f"Error updating kernel: {e}")
            gr.Info(f" ⚠️ Error updating kernel: {e}")
        new_canvas = self.plot(dataset, model_state, plot_options)
        return model_state, new_canvas

    def sample(
        self,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> Image.Image:
        """Re-render the plot with one function sampled from the model.

        The sampling seed is derived from the current time, so successive
        calls generally draw different samples.
        """
        logger.info("sampling from model")
        seed = int(time.time() * 100) % 10000
        return self.plot(
            dataset,
            model_state,
            plot_options,
            sample_y=True,
            sample_y_seed=seed,
        )

    def clear_sample(
        self,
        model_state: ModelState,
        dataset: Dataset,
        plot_options: PlotOptions,
    ) -> Image.Image:
        """Re-render the plot without a sampled function."""
        logger.info("clearing sample from model")
        return self.plot(
            dataset,
            model_state,
            plot_options,
            sample_y=False,
        )

    @staticmethod
    def _option_setter(field_name):
        """Return a handler that writes a checkbox value into one plot option."""

        def set_option(val, options):
            return options.update(**{field_name: val})

        return set_option

    def launch(self):
        """Build the Gradio interface, wire its events, and start the app."""
        with gr.Blocks(css=self.css) as demo:
            # app title
            gr.HTML("<div style='text-align:left; font-size:40px; font-weight: bold;'>Gaussian Process Visualizer</div>")

            # states
            dataset = gr.State(Dataset())
            plot_options = gr.State(PlotOptions())
            kernel = "RBF() + WhiteKernel()"
            model = self.init_model(kernel, dataset.value, "posterior")
            model_state = gr.State(
                ModelState(model=model, kernel=kernel, distribution="posterior")
            )

            # GUI elements and layout
            with gr.Row():
                with gr.Column(scale=2):
                    canvas = gr.Image(
                        value=self.plot(
                            dataset.value,
                            model_state.value,
                            plot_options.value,
                        ),
                        container=True,
                    )

                with gr.Column(scale=1):
                    with gr.Tab("Dataset"):
                        dataset_view = DatasetView()
                        dataset_view.build(state=dataset)
                        dataset.change(
                            fn=self.update_dataset,
                            inputs=[dataset, model_state, plot_options],
                            outputs=[model_state, canvas],
                        )

                    with gr.Tab("Model"):
                        kernel_box = gr.Textbox(
                            label="Kernel",
                            value=model_state.value.kernel,
                            interactive=True,
                        )
                        kernel_submit = gr.Button("Update Kernel")
                        distribution = gr.Radio(
                            label="Distribution",
                            choices=["Prior", "Posterior"],
                            value="Posterior",
                        )

                        # Pressing Enter in the box, clicking the button and
                        # switching the distribution all rebuild the model.
                        for register in (
                            kernel_box.submit,
                            kernel_submit.click,
                            distribution.change,
                        ):
                            register(
                                fn=self.update_model,
                                inputs=[kernel_box, distribution, model_state, dataset, plot_options],
                                outputs=[model_state, canvas],
                            )

                        sample_button = gr.Button("Sample")
                        clear_sample_button = gr.Button("Clear Sample")
                        sample_button.click(
                            fn=self.sample,
                            inputs=[model_state, dataset, plot_options],
                            outputs=[canvas],
                        )
                        clear_sample_button.click(
                            fn=self.clear_sample,
                            inputs=[model_state, dataset, plot_options],
                            outputs=[canvas],
                        )

                    with gr.Tab("Plot Options"):
                        option_labels = (
                            ("show_training_data", "Show Training Data"),
                            ("show_true_function", "Show True Function"),
                            ("show_mean_prediction", "Show Mean Prediction"),
                            ("show_prediction_interval", "Show Prediction Interval"),
                        )
                        # Create all checkboxes first, then wire them: each
                        # one replaces a single field of the options state.
                        checkboxes = [
                            (field_name, gr.Checkbox(label=label, value=True))
                            for field_name, label in option_labels
                        ]
                        for field_name, checkbox in checkboxes:
                            checkbox.change(
                                fn=self._option_setter(field_name),
                                inputs=[checkbox, plot_options],
                                outputs=[plot_options],
                            )

                        # Any change of the options state re-renders the plot.
                        plot_options.change(
                            fn=self.plot,
                            inputs=[dataset, model_state, plot_options],
                            outputs=[canvas],
                        )

                    with gr.Tab("Usage"):
                        # Explicit encoding so the text does not depend on
                        # the platform's default locale.
                        with open("usage.md", "r", encoding="utf-8") as f:
                            usage_md = f.read()
                        gr.Markdown(usage_md)

        demo.launch()
# Script entry: construct the visualizer with a 1200x900 canvas and start
# the Gradio app as soon as this module is executed.
visualizer = GpVisualizer(
    width=1200,
    height=900,
)
visualizer.launch()