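# Header text captured together with the file (appears to be repository-page
# residue, not program code):
#   caliex's picture
#   Create app.py
#   979f8b1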
import tempfile
# NOTE(review): `title` is reassigned further down in this file and nothing
# else from turtle is used here -- this looks like an accidental auto-import;
# confirm and remove.
from turtle import title

import gradio as gr
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes
from sklearn.linear_model import LinearRegression
from sklearn.metrics import PredictionErrorDisplay
from sklearn.model_selection import cross_val_predict
def predict_diabetes(subsample, plot_type):
    """Plot cross-validated linear-regression predictions for the diabetes data.

    A ``LinearRegression`` model is evaluated with 10-fold
    ``cross_val_predict`` on the ``load_diabetes`` dataset, and the selected
    ``PredictionErrorDisplay`` panels are drawn on a two-panel figure that is
    saved as a PNG file.

    Args:
        subsample: Forwarded to ``PredictionErrorDisplay.from_predictions``
            as the number of points to display. A whole-number float such as
            ``100.0`` is converted to ``int`` first, because scikit-learn
            interprets floats as a fraction that must lie strictly between
            0 and 1.
        plot_type: Collection of selected panel labels. Recognised labels are
            "Actual vs. Predicted" and "Residuals vs. Predicted". ``None`` is
            treated as an empty selection.

    Returns:
        Path of the PNG file that holds the rendered figure.
    """
    selected = plot_type or ()
    if isinstance(subsample, float) and subsample.is_integer():
        subsample = int(subsample)

    X, y = load_diabetes(return_X_y=True)
    y_pred = cross_val_predict(LinearRegression(), X, y, cv=10)

    # (label offered in the UI, PredictionErrorDisplay kind, panel title)
    panels = (
        (
            "Actual vs. Predicted",
            "actual_vs_predicted",
            "Actual vs. Predicted values",
        ),
        (
            "Residuals vs. Predicted",
            "residual_vs_predicted",
            "Residuals vs. Predicted Values",
        ),
    )

    fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
    try:
        for ax, (label, kind, panel_title) in zip(axs, panels):
            if label not in selected:
                # Hide the panel instead of leaving an empty frame with ticks.
                ax.set_visible(False)
                continue
            PredictionErrorDisplay.from_predictions(
                y,
                y_pred=y_pred,
                kind=kind,
                subsample=subsample,
                ax=ax,
                random_state=0,
            )
            ax.set_title(panel_title)

        fig.suptitle("Plotting cross-validated predictions")
        # Lay out this figure explicitly; plt.tight_layout() would act on
        # pyplot's global "current figure" instead.
        fig.tight_layout()

        # Save the figure as an image. A unique file per call keeps
        # overlapping requests from overwriting each other's output.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            image_path = tmp.name
        fig.savefig(image_path)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    return image_path
# Define the Gradio interface.
# Components are taken from the top-level `gr` namespace: the legacy
# `gr.inputs` / `gr.outputs` modules (and their `default=` keyword) are
# deprecated and were removed in Gradio 4; `value=` sets the initial state.
inputs = [
    gr.Slider(minimum=1, maximum=100, step=1, value=100, label="Subsample"),
    gr.CheckboxGroup(
        choices=["Actual vs. Predicted", "Residuals vs. Predicted"],
        value=["Actual vs. Predicted", "Residuals vs. Predicted"],
        label="Plot Types",
    ),
]
# predict_diabetes returns the path of a saved PNG file.
outputs = gr.Image(label="Cross-Validated Predictions", type="filepath")

title = "Plotting Cross-Validated Predictions"
description="This app plots cross-validated predictions for a linear regression model trained on the diabetes dataset. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_predict.html"

# One value per input component in each row: [subsample, plot types].
examples = [
    # Actual vs. Predicted plot only.
    [100, ["Actual vs. Predicted"]],
    # Residuals vs. Predicted plot only.
    [50, ["Residuals vs. Predicted"]],
    # Both plots.
    [75, ["Actual vs. Predicted", "Residuals vs. Predicted"]],
]

demo = gr.Interface(
    fn=predict_diabetes,
    title=title,
    description=description,
    examples=examples,
    inputs=inputs,
    outputs=outputs,
)
demo.launch()