tushifire commited on
Commit
2f090b9
·
1 Parent(s): dfad0f0

Initial Commit

Browse files
Files changed (2) hide show
  1. app.py +71 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import matplotlib.pyplot as plt
3
+ from sklearn.tree import DecisionTreeRegressor
4
+ import gradio as gr
5
+
6
+ # Create a random dataset
7
+ rng = np.random.RandomState(1)
8
+ X = np.sort(200 * rng.rand(100, 1) - 100, axis=0)
9
+ y = np.array([np.pi * np.sin(X).ravel(), np.pi * np.cos(X).ravel()]).T
10
+ y[::5, :] += 0.5 - rng.rand(20, 2)
11
+
12
+
13
+ def plot_multi_tree(d1,d2,d3):
14
+ # Fit regression model
15
+ regr_1 = DecisionTreeRegressor(max_depth=d1)
16
+ regr_2 = DecisionTreeRegressor(max_depth=d2)
17
+ regr_3 = DecisionTreeRegressor(max_depth=d3)
18
+ regr_1.fit(X, y)
19
+ regr_2.fit(X, y)
20
+ regr_3.fit(X, y)
21
+
22
+ # Predict
23
+ X_test = np.arange(-100.0, 100.0, 0.01)[:, np.newaxis]
24
+ y_1 = regr_1.predict(X_test)
25
+ y_2 = regr_2.predict(X_test)
26
+ y_3 = regr_3.predict(X_test)
27
+
28
+ # Plot the results
29
+ fig = plt.figure()
30
+ s = 25
31
+ plt.scatter(y[:, 0], y[:, 1], c="navy", s=s, edgecolor="black", label="data")
32
+ plt.scatter(
33
+ y_1[:, 0],
34
+ y_1[:, 1],
35
+ c="cornflowerblue",
36
+ s=s,
37
+ edgecolor="black",
38
+ label= f"max_depth={d1}",
39
+ )
40
+ plt.scatter(y_2[:, 0], y_2[:, 1], c="red", s=s, edgecolor="black", label= f"max_depth={d2}")
41
+ plt.scatter(
42
+ y_3[:, 0], y_3[:, 1], c="orange", s=s, edgecolor="black", label= f"max_depth={d3}"
43
+ )
44
+ plt.xlim([-6, 6])
45
+ plt.ylim([-6, 6])
46
+ plt.xlabel("target 1")
47
+ plt.ylabel("target 2")
48
+ plt.title("Multi-output Decision Tree Regression")
49
+ plt.legend(loc="best")
50
+ return fig
51
+
52
+
53
+
54
+
55
+ title = " Illustration of multi-output regression with decision tree.🌲 "
56
+ with gr.Blocks(title=title) as demo:
57
+ gr.Markdown(f"## {title}")
58
+
59
+ with gr.Row():
60
+ d1 = gr.Slider(minimum=0, maximum=20, step=1, value = 2,
61
+ label = "Depth 1")
62
+ d2 = gr.Slider(minimum=0, maximum=20, step=1, value = 5,
63
+ label = "Depth 2")
64
+ d3 = gr.Slider(minimum=0, maximum=20, step=1, value = 8,
65
+ label = "Depth 3")
66
+
67
+ btn = gr.Button(value="Submit")
68
+ btn.click(plot_multi_tree, inputs= [d1,d2,d3], outputs= gr.Plot(label='Multi-output regression with decision trees') ) #
69
+
70
+
71
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ scikit-learn==1.2.1