# Web-page header residue from the original file view:
# author Nishank122; commit d24ef27 "Create app.py"; 2.18 kB; scan result: "No virus".
import functools
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from sklearn import datasets
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import NearestCentroid
@functools.lru_cache(maxsize=1)
def _training_data():
    """Load the Iris dataset once and return its first two feature columns.

    Returns:
        tuple: ``(X, y)`` where ``X`` is ``iris.data[:, :2]`` (sepal length
        and sepal width) and ``y`` is ``iris.target``.

    The result is cached because it never changes between requests; both
    arrays are marked read-only so no caller can corrupt the shared copy.
    """
    iris = datasets.load_iris()
    features = iris.data[:, :2]
    labels = iris.target
    features.setflags(write=False)
    labels.setflags(write=False)
    return features, labels


def _parse_float(value, label):
    """Return ``value`` converted to ``float``.

    Raises:
        ValueError: if ``value`` cannot be converted; the message names
            ``label`` so the user can tell which field is wrong.
    """
    try:
        return float(value)
    except (TypeError, ValueError) as err:
        raise ValueError(f"{label} must be a number, got {value!r}") from err


def _shrink_threshold(value):
    """Translate the UI shrinkage value into NearestCentroid's ``shrink_threshold``.

    ``None``, an empty string or ``0`` mean "no shrinkage" and become ``None``.
    Older scikit-learn treated a falsy threshold as disabled, while newer
    releases reject ``0`` because the parameter must be strictly positive.
    Any other value is converted to ``float`` and left to scikit-learn to
    validate.
    """
    if value is None or value == "":
        return None
    threshold = float(value)
    return threshold if threshold != 0 else None


def classify_iris(sepal_length, sepal_width, shrinkage):
    """Classify a sepal measurement with NearestCentroid and plot the result.

    The classifier is fit on the first two Iris features (sepal length and
    sepal width) and then predicts the class of the user's point.

    Args:
        sepal_length: Sepal length as text (anything ``float()`` accepts).
        sepal_width: Sepal width as text (anything ``float()`` accepts).
        shrinkage: ``shrink_threshold`` for ``NearestCentroid``; ``None``,
            ``""`` or ``0`` disable shrinkage.

    Returns:
        tuple[str, str]: ``"Predicted Class: <class index>"`` and the path of
        a PNG showing the decision regions, the training points and the new
        point. Each call writes a fresh temporary file (not deleted here), so
        concurrent requests cannot overwrite each other's image.

    Raises:
        ValueError: if a measurement is not numeric, or if scikit-learn
            rejects the shrinkage value.
    """
    new_point = np.array(
        [[
            _parse_float(sepal_length, "Sepal length"),
            _parse_float(sepal_width, "Sepal width"),
        ]]
    )
    X, y = _training_data()

    cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
    cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

    clf = NearestCentroid(shrink_threshold=_shrink_threshold(shrinkage))
    clf.fit(X, y)
    prediction = clf.predict(new_point)[0]

    # Object-oriented Figure API instead of pyplot. pyplot keeps global state
    # that is not thread-safe, and web handlers may run concurrently. A bare
    # Figure is not registered with pyplot, so nothing needs closing.
    fig = Figure()
    ax = fig.subplots()
    DecisionBoundaryDisplay.from_estimator(
        clf,
        # Include the query point so the shaded decision surface also covers
        # it when it lies outside the range of the training data.
        np.vstack([X, new_point]),
        cmap=cmap_light,
        ax=ax,
        response_method="predict",
    )
    # zorder keeps the query marker above the training points drawn after it.
    ax.scatter(
        new_point[:, 0], new_point[:, 1],
        color="black", marker="x", s=100, zorder=3, label="New Data Point",
    )
    ax.scatter(
        X[:, 0], X[:, 1],
        c=y, cmap=cmap_bold, edgecolor="k", s=20, label="Training Points",
    )
    ax.set_title("3-Class classification")
    ax.axis("tight")
    ax.legend()

    # A unique file per request replaces the old shared "decision_boundary.png".
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        fig.savefig(handle, format="png")
        plot_path = handle.name

    return f"Predicted Class: {prediction}", plot_path
# Build the Gradio interface. The two text boxes feed classify_iris's sepal
# measurements; the number box is its NearestCentroid shrink threshold.
# `layout="vertical"` was removed: it was a deprecated no-op in Gradio 3 and
# is not accepted by Gradio 4's Interface.
iface = gr.Interface(
    fn=classify_iris,
    inputs=[
        gr.Textbox(label="Sepal length"),
        gr.Textbox(label="Sepal width"),
        gr.Number(label="Shrink threshold"),
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(label="Decision boundary"),
    ],
    title="Iris Classification App",
    description="Enter the sepal length and sepal width of a new data point.",
    examples=[
        ["5.0", "3.5", 0.2],
        ["6.2", "2.8", 0.5],
    ],
)
# Start the web server; this runs at module level, i.e. on import as well as
# when the file is executed directly.
iface.launch()