testingMLmodels / app.py
kennard at laptop
minor edits
e437cd2
import gradio as gr
"""
def greet(name):
return "Hello " + name + "!!"
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch()
"""
import functools

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn import datasets
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import KNeighborsClassifier
from sklearn.neighbors import NearestCentroid
@functools.lru_cache(maxsize=1)
def _fit_iris_classifier():
    """Load the Iris sepal features and fit the 3-NN classifier once.

    Cached so every request reuses the same fitted model instead of
    reloading the dataset and refitting on each call. Callers must treat
    the returned arrays as read-only because they are shared between calls.

    Returns:
        tuple: ``(X, y, clf, target_names)`` where ``X`` is the first two
        Iris feature columns (sepal length, sepal width), ``y`` the integer
        class labels, ``clf`` the fitted ``KNeighborsClassifier`` and
        ``target_names`` the species name for each label.
    """
    iris = datasets.load_iris()
    X = iris.data[:, :2]
    y = iris.target
    clf = KNeighborsClassifier(n_neighbors=3)
    clf.fit(X, y)
    return X, y, clf, iris.target_names


def classify_iris(sepal_length, sepal_width, shrinkage):
    """Classify a new Iris sample from its sepal measurements and plot it.

    Args:
        sepal_length: Sepal length of the new sample; any value accepted by
            ``float()`` (the Gradio UI passes it as text).
        sepal_width: Sepal width of the new sample; any value accepted by
            ``float()``.
        shrinkage: Currently unused. It fed ``NearestCentroid``'s
            ``shrink_threshold`` before the model was switched to a
            3-nearest-neighbours classifier; it is kept so the three Gradio
            inputs still map onto this signature.

    Returns:
        tuple[str, str]: A "Predicted Class: ..." message (label index and
        species name) and the path of the saved PNG showing the decision
        regions, the training points and the new point.

    Raises:
        gr.Error: If either measurement is not a finite number, so the
            message is shown to the user instead of a generic error.
    """
    try:
        sepal_length = float(sepal_length)
        sepal_width = float(sepal_width)
    except (TypeError, ValueError) as err:
        raise gr.Error("Sepal length and sepal width must be numbers.") from err
    # NaN/inf parse fine with float() but make the classifier's predict() fail.
    if not (np.isfinite(sepal_length) and np.isfinite(sepal_width)):
        raise gr.Error("Sepal length and sepal width must be finite numbers.")

    X, y, clf, target_names = _fit_iris_classifier()

    new_data_point = np.array([[sepal_length, sepal_width]])
    prediction = clf.predict(new_data_point)[0]

    cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
    cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

    # NOTE(review): every request writes to the same fixed file; concurrent
    # requests may overwrite each other's plot. Consider a per-request path.
    plot_path = "decision_boundary.png"
    fig, ax = plt.subplots()
    try:
        DecisionBoundaryDisplay.from_estimator(
            clf, X, cmap=cmap_light, ax=ax, response_method="predict"
        )
        # zorder keeps the marker above the training points drawn after it.
        ax.scatter(
            new_data_point[:, 0],
            new_data_point[:, 1],
            color="black",
            marker="x",
            s=100,
            label="New Data Point",
            zorder=3,
        )
        ax.scatter(
            X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor="k", s=20,
            label="Training Points",
        )
        ax.set_title("3-Class classification")
        ax.axis("tight")
        ax.legend()
        fig.savefig(plot_path)
    finally:
        # Always release the figure so a long-running server does not
        # accumulate open pyplot figures when plotting fails.
        plt.close(fig)

    return f"Predicted Class: {prediction} ({target_names[prediction]})", plot_path
# Build the Gradio interface: two free-text sepal measurements plus a
# numeric "shrinkage" field that classify_iris accepts but does not use.
iface = gr.Interface(
    fn=classify_iris,
    inputs=["text", "text", "number"],
    outputs=["text", "image"],
    title="Iris Classification App",
    description="Enter the sepal length and sepal width of a new data point.",
    examples=[
        ["5.0", "3.5", 0.2],
        ["6.2", "2.8", 0.5],
    ],
)

# Start the web server only when run as a script (e.g. `python app.py`), so
# importing this module does not block on launching the app.
if __name__ == "__main__":
    iface.launch()