File size: 2,446 Bytes
ff94788
 
 
e6dd35b
ff94788
 
 
 
 
e6dd35b
ff94788
 
 
 
 
 
 
 
e437cd2
 
 
ff94788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1df8408
 
 
ff94788
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cbd1fb1
ff94788
 
 
 
 
 
 
 
 
 
e6dd35b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90

import gradio as gr

"""
def greet(name):
    return "Hello " + name + "!!"

demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
"""


import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import NearestCentroid

from sklearn.neighbors import KNeighborsClassifier

from sklearn import datasets

def classify_iris(sepal_length, sepal_width, shrinkage):
    """Classify an iris sample from its sepal measurements and plot the result.

    A 3-nearest-neighbours classifier is fitted on the first two features
    (sepal length, sepal width) of scikit-learn's Iris dataset. The query
    point is then classified, and the classifier's decision regions are drawn
    together with the training points and the query point.

    Args:
        sepal_length: Sepal length of the new sample; any value ``float()``
            accepts (the Gradio UI passes it as text).
        sepal_width: Sepal width of the new sample; converted the same way.
        shrinkage: Currently ignored. It used to be the ``shrink_threshold``
            of a ``NearestCentroid`` model that has been replaced by k-NN.
            The parameter is kept so callers passing three values keep
            working.

    Returns:
        A ``(message, plot_path)`` tuple. ``message`` names the predicted
        species, and ``plot_path`` is the PNG file the figure was saved to.

    Raises:
        ValueError: If either measurement cannot be converted to ``float``.
    """
    # Convert input to float values (the UI delivers text).
    sepal_length = float(sepal_length)
    sepal_width = float(sepal_width)

    # Load the Iris dataset. Only the two sepal features are used, so the
    # decision boundary can be drawn in 2-D.
    iris = datasets.load_iris()
    X = iris.data[:, :2]
    y = iris.target

    # Light colours for the decision regions, bold ones for the samples.
    cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
    cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

    # NOTE: `shrinkage` would only apply to NearestCentroid(shrink_threshold=...);
    # the k-NN model used here has no such parameter.
    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(X, y)

    new_data_point = np.array([[sepal_length, sepal_width]])
    prediction = clf.predict(new_data_point)[0]

    fig, ax = plt.subplots()
    try:
        DecisionBoundaryDisplay.from_estimator(
            clf, X, cmap=cmap_light, ax=ax, response_method="predict"
        )

        # Plot the new data point. The raised zorder keeps it visible even
        # when a training sample (drawn afterwards) sits at the same spot.
        ax.scatter(
            new_data_point[:, 0],
            new_data_point[:, 1],
            color="black",
            marker="x",
            s=100,
            label="New Data Point",
            zorder=3,
        )

        # Plot the training points.
        ax.scatter(
            X[:, 0],
            X[:, 1],
            c=y,
            cmap=cmap_bold,
            edgecolor="k",
            s=20,
            label="Training Points",
        )

        ax.set_title("3-Class classification")
        ax.axis("tight")
        ax.legend()

        plot_path = "decision_boundary.png"
        fig.savefig(plot_path)
    finally:
        # Close exactly this figure, even if plotting or saving failed, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)

    # Report the species name rather than the bare integer label.
    return f"Predicted Class: {iris.target_names[prediction]}", plot_path

# Create the Gradio interface. The three inputs map positionally onto
# classify_iris(sepal_length, sepal_width, shrinkage). The third value is
# ignored by the current k-NN model, but it must still be supplied because
# classify_iris takes three arguments and each example row has three values.
iface = gr.Interface(
    fn=classify_iris,
    inputs=[
        gr.Textbox(label="Sepal length (cm)"),
        gr.Textbox(label="Sepal width (cm)"),
        gr.Number(label="Shrinkage (unused by the k-NN model)"),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Decision boundary"),
    ],
    title="Iris Classification App",
    description="Enter the sepal length and sepal width of a new data point.",
    examples=[
        ["5.0", "3.5", 0.2],
        ["6.2", "2.8", 0.5],
    ],
)

# Run the Gradio app
iface.launch()