Nishank122 commited on
Commit
d24ef27
1 Parent(s): 744492e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from matplotlib.colors import ListedColormap
5
+ from sklearn.inspection import DecisionBoundaryDisplay
6
+ from sklearn.neighbors import NearestCentroid
7
+ from sklearn import datasets
8
+
9
+ def classify_iris(sepal_length, sepal_width, shrinkage):
10
+ # Convert input to float values
11
+ sepal_length = float(sepal_length)
12
+ sepal_width = float(sepal_width)
13
+
14
+ # Load the Iris dataset
15
+ iris = datasets.load_iris()
16
+ X = iris.data[:, :2]
17
+ y = iris.target
18
+
19
+ # Create color maps
20
+ cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
21
+ cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])
22
+
23
+ # Create an instance of Nearest Centroid Classifier and fit the data
24
+ clf = NearestCentroid(shrink_threshold=shrinkage)
25
+ clf.fit(X, y)
26
+
27
+ # Create a new data point based on user input
28
+ new_data_point = np.array([[sepal_length, sepal_width]])
29
+
30
+ # Predict the class of the new data point
31
+ prediction = clf.predict(new_data_point)[0]
32
+
33
+ # Create a plot to display the decision boundary and data points
34
+ _, ax = plt.subplots()
35
+ DecisionBoundaryDisplay.from_estimator(
36
+ clf, X, cmap=cmap_light, ax=ax, response_method="predict"
37
+ )
38
+
39
+ # Plot the new data point
40
+ plt.scatter(new_data_point[:, 0], new_data_point[:, 1], color='black', marker='x', s=100, label="New Data Point")
41
+
42
+ # Plot the training points
43
+ plt.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20, label="Training Points")
44
+
45
+ plt.title("3-Class classification")
46
+ plt.axis("tight")
47
+ plt.legend()
48
+
49
+ # Save the plot to a file
50
+ plot_path = "decision_boundary.png"
51
+ plt.savefig(plot_path)
52
+ plt.close()
53
+
54
+ return f"Predicted Class: {prediction}", plot_path
55
+
56
+ # Create the Gradio interface
57
+ iface = gr.Interface(
58
+ fn=classify_iris,
59
+ inputs=["text", "text", "number"],
60
+ outputs=["text", "image"],
61
+ layout="vertical",
62
+ title="Iris Classification App",
63
+ description="Enter the sepal length and sepal width of a new data point.",
64
+ examples=[
65
+ ["5.0", "3.5", 0.2],
66
+ ["6.2", "2.8", 0.5]
67
+ ]
68
+ )
69
+
70
+ # Run the Gradio app
71
+ iface.launch()