Spaces:
Runtime error
Runtime error
File size: 2,184 Bytes
d24ef27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import NearestCentroid
from sklearn import datasets
def classify_iris(sepal_length, sepal_width, shrinkage):
    """Classify a point from its sepal measurements and plot the decision regions.

    A NearestCentroid classifier is fitted on the first two columns of the
    Iris data (``iris.data[:, :2]``). It predicts the class of the
    user-supplied point. The decision regions, the training points and the
    new point are then drawn and saved to a freshly created PNG file.

    Args:
        sepal_length: First feature value; anything accepted by ``float()``
            (the Gradio UI passes the raw text-box string).
        sepal_width: Second feature value; converted the same way.
        shrinkage: ``shrink_threshold`` passed to NearestCentroid. ``None``
            or 0 disables shrinkage.

    Returns:
        A ``(message, plot_path)`` tuple: ``"Predicted Class: <label>"``, where
        <label> is the integer target from ``iris.target``, and the path of
        the saved PNG. Each call writes its own file, so concurrent requests
        never overwrite each other's plot.

    Raises:
        ValueError: if a measurement cannot be converted by ``float()``.
            scikit-learn may also reject an invalid threshold, such as a
            negative one.
    """
    import tempfile  # local import: only this function needs it

    sepal_length = float(sepal_length)
    sepal_width = float(sepal_width)
    # A threshold of 0 shrinks nothing, which is equivalent to no shrinkage.
    # Recent scikit-learn versions only accept a strictly positive value or
    # None, and an empty Gradio number box may yield None, so map both to None.
    shrink_threshold = shrinkage if shrinkage else None

    # Load the Iris dataset and keep only the first two feature columns.
    iris = datasets.load_iris()
    X = iris.data[:, :2]
    y = iris.target

    # Light colors fill the decision regions; bold colors mark training points.
    cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
    cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

    clf = NearestCentroid(shrink_threshold=shrink_threshold)
    clf.fit(X, y)

    new_data_point = np.array([[sepal_length, sepal_width]])
    prediction = clf.predict(new_data_point)[0]

    fig, ax = plt.subplots()
    try:
        DecisionBoundaryDisplay.from_estimator(
            clf, X, cmap=cmap_light, ax=ax, response_method="predict"
        )
        # zorder keeps the query marker above the regions and training points.
        ax.scatter(
            new_data_point[:, 0],
            new_data_point[:, 1],
            color="black",
            marker="x",
            s=100,
            zorder=3,
            label="New Data Point",
        )
        ax.scatter(
            X[:, 0], X[:, 1],
            c=y, cmap=cmap_bold, edgecolor="k", s=20, label="Training Points",
        )
        ax.set_title("3-Class classification")
        ax.axis("tight")
        ax.legend()
        # A unique file per call avoids the race on a shared fixed filename.
        # delete=False because the caller (Gradio) reads the file after return.
        with tempfile.NamedTemporaryFile(
            prefix="decision_boundary_", suffix=".png", delete=False
        ) as tmp:
            fig.savefig(tmp, format="png")
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    return f"Predicted Class: {prediction}", tmp.name
# Build the Gradio UI. Two text boxes take the sepal measurements, a number
# box takes the NearestCentroid shrink threshold, and the outputs are the
# prediction text and the decision-boundary image.
# The former `layout="vertical"` argument was removed: it is a legacy Gradio
# option that newer releases no longer accept.
iface = gr.Interface(
    fn=classify_iris,
    inputs=[
        gr.Textbox(label="Sepal length"),
        gr.Textbox(label="Sepal width"),
        gr.Number(label="Shrink threshold"),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Decision boundary"),
    ],
    title="Iris Classification App",
    description=(
        "Enter the sepal length and sepal width of a new data point, "
        "and a shrink threshold for the Nearest Centroid classifier."
    ),
    examples=[
        ["5.0", "3.5", 0.2],
        ["6.2", "2.8", 0.5],
    ],
)

# Launch only when run as a script (as a Space's app.py is), so importing
# this module does not start a web server.
if __name__ == "__main__":
    iface.launch()
|