caliex commited on
Commit
979f8b1
1 Parent(s): b3b4a36

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from turtle import title
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ from sklearn.datasets import load_diabetes
5
+ from sklearn.linear_model import LinearRegression
6
+ from sklearn.model_selection import cross_val_predict
7
+ from sklearn.metrics import PredictionErrorDisplay
8
+
9
+
10
+ def predict_diabetes(subsample, plot_type):
11
+ X, y = load_diabetes(return_X_y=True)
12
+ lr = LinearRegression()
13
+ y_pred = cross_val_predict(lr, X, y, cv=10)
14
+
15
+ fig, axs = plt.subplots(ncols=2, figsize=(8, 4))
16
+ if "Actual vs. Predicted" in plot_type:
17
+ PredictionErrorDisplay.from_predictions(
18
+ y,
19
+ y_pred=y_pred,
20
+ kind="actual_vs_predicted",
21
+ subsample=subsample,
22
+ ax=axs[0],
23
+ random_state=0,
24
+ )
25
+ axs[0].set_title("Actual vs. Predicted values")
26
+ if "Residuals vs. Predicted" in plot_type:
27
+ PredictionErrorDisplay.from_predictions(
28
+ y,
29
+ y_pred=y_pred,
30
+ kind="residual_vs_predicted",
31
+ subsample=subsample,
32
+ ax=axs[1],
33
+ random_state=0,
34
+ )
35
+ axs[1].set_title("Residuals vs. Predicted Values")
36
+
37
+ fig.suptitle("Plotting cross-validated predictions")
38
+ plt.tight_layout()
39
+ plt.close(fig)
40
+
41
+ # Save the figure as an image
42
+ image_path = "predictions.png"
43
+ fig.savefig(image_path)
44
+ return image_path
45
+
46
+
47
+ # Define the Gradio interface
48
+ inputs = [
49
+ gr.inputs.Slider(minimum=1, maximum=100, step=1, default=100, label="Subsample"),
50
+ gr.inputs.CheckboxGroup(["Actual vs. Predicted", "Residuals vs. Predicted"], label="Plot Types", default=["Actual vs. Predicted", "Residuals vs. Predicted"])
51
+ ]
52
+ outputs = gr.outputs.Image(label="Cross-Validated Predictions", type="pil")
53
+
54
+ title = "Plotting Cross-Validated Predictions"
55
+ description="This app plots cross-validated predictions for a linear regression model trained on the diabetes dataset. See the original scikit-learn example here: https://scikit-learn.org/stable/auto_examples/model_selection/plot_cv_predict.html"
56
+ examples = [
57
+ [
58
+ 100,
59
+ ["Actual vs. Predicted"],
60
+ "Plotting cross-validated predictions with Actual vs. Predicted plot.",
61
+ ],
62
+ [
63
+ 50,
64
+ ["Residuals vs. Predicted"],
65
+ "Plotting cross-validated predictions with Residuals vs. Predicted plot.",
66
+ ],
67
+ [
68
+ 75,
69
+ ["Actual vs. Predicted", "Residuals vs. Predicted"],
70
+ "Plotting cross-validated predictions with both Actual vs. Predicted and Residuals vs. Predicted plots.",
71
+ ],
72
+ ]
73
+
74
+ gr.Interface(fn=predict_diabetes, title=title, description=description, examples=examples, inputs=inputs, outputs=outputs).launch()