File size: 2,184 Bytes
d24ef27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from sklearn import datasets
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import NearestCentroid

def classify_iris(sepal_length, sepal_width, shrinkage):
    """Classify a new iris sample and plot the Nearest Centroid decision regions.

    A ``NearestCentroid`` classifier is fit on the first two Iris features
    (sepal length, sepal width) on every call, so the shrinkage chosen in the
    UI takes effect immediately.

    Args:
        sepal_length: Sepal length of the new sample; anything accepted by
            ``float()`` (the Gradio text box delivers a string).
        sepal_width: Sepal width of the new sample; converted the same way.
        shrinkage: ``shrink_threshold`` for ``NearestCentroid``. ``None``
            (empty number box) or a value <= 0 disables shrinkage, because
            scikit-learn only accepts a strictly positive threshold or ``None``.

    Returns:
        ``(label, plot_path)``: ``label`` is ``"Predicted Class: <index>"`` and
        ``plot_path`` is the path of a freshly written PNG showing the decision
        regions, the training points and the new point.

    Raises:
        ValueError: If ``sepal_length``, ``sepal_width`` or ``shrinkage``
            cannot be converted to ``float``.
    """
    # Convert input to float values
    sepal_length = float(sepal_length)
    sepal_width = float(sepal_width)

    # Map "no / non-positive shrinkage" to None instead of letting
    # scikit-learn reject 0 or receive a meaningless negative threshold.
    shrink_threshold = None
    if shrinkage is not None and float(shrinkage) > 0:
        shrink_threshold = float(shrinkage)

    # Load the Iris dataset, keeping only the two sepal features
    iris = datasets.load_iris()
    X = iris.data[:, :2]
    y = iris.target

    # Create color maps (light for regions, bold for points)
    cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
    cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

    # Create an instance of Nearest Centroid Classifier and fit the data
    clf = NearestCentroid(shrink_threshold=shrink_threshold)
    clf.fit(X, y)

    # Create a new data point based on user input and predict its class
    new_data_point = np.array([[sepal_length, sepal_width]])
    prediction = clf.predict(new_data_point)[0]

    # Build a standalone Figure rather than going through pyplot: pyplot's
    # global "current figure/axes" state is not safe when the web server
    # handles requests concurrently, and this figure needs no plt.close().
    fig = Figure()
    ax = fig.add_subplot()
    DecisionBoundaryDisplay.from_estimator(
        clf, X, cmap=cmap_light, ax=ax, response_method="predict"
    )

    # Plot the new data point; the raised zorder keeps the marker visible
    # on top of the training points drawn afterwards.
    ax.scatter(
        new_data_point[:, 0],
        new_data_point[:, 1],
        color="black",
        marker="x",
        s=100,
        label="New Data Point",
        zorder=3,
    )

    # Plot the training points
    ax.scatter(
        X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20,
        label="Training Points",
    )

    ax.set_xlabel(iris.feature_names[0])
    ax.set_ylabel(iris.feature_names[1])
    ax.set_title("3-Class classification")
    ax.axis("tight")
    ax.legend()

    # Write to a unique temporary file so concurrent requests cannot
    # overwrite (or be served) each other's image. delete=False because the
    # caller reads the file after this function returns.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        fig.savefig(tmp, format="png")
    plot_path = tmp.name

    return f"Predicted Class: {prediction}", plot_path

# Create the Gradio interface: two text fields for the sepal measurements
# (classify_iris converts them with float()), a number field for the
# NearestCentroid shrinkage threshold, and text + image outputs for the
# prediction and the decision-boundary plot.
iface = gr.Interface(
    fn=classify_iris,
    inputs=["text", "text", "number"],
    outputs=["text", "image"],
    layout="vertical",
    title="Iris Classification App",
    description=(
        "Enter the sepal length and sepal width of a new data point, "
        "and a shrinkage threshold for the Nearest Centroid classifier."
    ),
    examples=[
        ["5.0", "3.5", 0.2],
        ["6.2", "2.8", 0.5],
    ],
)

# Run the Gradio app (executed unconditionally when this module runs)
iface.launch()