"""Gradio demo: nearest-centroid classification of Iris sepal measurements."""

import tempfile
from typing import Optional, Tuple, Union

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from sklearn import datasets
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.neighbors import NearestCentroid


def _save_decision_boundary_plot(clf, X, y, new_data_point) -> str:
    """Draw the decision regions, training points and query point; return the PNG path."""
    cmap_light = ListedColormap(["orange", "cyan", "cornflowerblue"])
    cmap_bold = ListedColormap(["darkorange", "c", "darkblue"])

    # Work on explicit figure/axes objects instead of pyplot's implicit
    # "current figure", so overlapping requests cannot draw into (or close)
    # each other's plots.
    fig, ax = plt.subplots()
    try:
        DecisionBoundaryDisplay.from_estimator(
            clf, X, cmap=cmap_light, ax=ax, response_method="predict"
        )
        ax.scatter(
            new_data_point[:, 0],
            new_data_point[:, 1],
            color="black",
            marker="x",
            s=100,
            label="New Data Point",
        )
        ax.scatter(
            X[:, 0],
            X[:, 1],
            c=y,
            cmap=cmap_bold,
            edgecolor="k",
            s=20,
            label="Training Points",
        )
        ax.set_title("3-Class classification")
        ax.axis("tight")
        ax.legend()

        # A unique file per call: a single fixed file name would be
        # overwritten by a concurrent request before its image is read back.
        # NOTE: the file is intentionally not deleted here because the caller
        # still has to read it; it is left in the system temp directory.
        with tempfile.NamedTemporaryFile(
            prefix="decision_boundary_", suffix=".png", delete=False
        ) as tmp:
            fig.savefig(tmp, format="png")
        return tmp.name
    finally:
        # Always release the figure, even when plotting fails.
        plt.close(fig)


def classify_iris(
    sepal_length: Union[str, float],
    sepal_width: Union[str, float],
    shrinkage: Optional[float],
) -> Tuple[str, str]:
    """Classify one sepal measurement with a nearest-centroid model.

    The model is fitted on the first two feature columns of the Iris dataset
    on every call, using the requested shrink threshold.

    Args:
        sepal_length: Sepal length of the new point; anything ``float()`` accepts.
        sepal_width: Sepal width of the new point; anything ``float()`` accepts.
        shrinkage: Shrink threshold for ``NearestCentroid``. ``None`` or ``0``
            disables shrinkage.

    Returns:
        A ``(message, plot_path)`` tuple: the text ``"Predicted Class: <label>"``
        and the path of a PNG showing the decision regions, the training
        points and the new point.

    Raises:
        gr.Error: If either measurement cannot be converted to a float.
    """
    try:
        sepal_length = float(sepal_length)
        sepal_width = float(sepal_width)
    except (TypeError, ValueError) as exc:
        # gr.Error is shown to the user with its message instead of a bare
        # "Error" banner.
        raise gr.Error("Sepal length and sepal width must be numbers.") from exc

    iris = datasets.load_iris()
    X = iris.data[:, :2]
    y = iris.target

    # Newer scikit-learn releases validate shrink_threshold as strictly
    # positive (or None); 0 used to mean "no shrinkage", so map it to None.
    clf = NearestCentroid(shrink_threshold=shrinkage or None)
    clf.fit(X, y)

    new_data_point = np.array([[sepal_length, sepal_width]])
    prediction = clf.predict(new_data_point)[0]

    plot_path = _save_decision_boundary_plot(clf, X, y, new_data_point)
    return f"Predicted Class: {prediction}", plot_path


# Create the Gradio interface.
# The former `layout="vertical"` argument is omitted: it is a deprecated
# no-op in Gradio 3.x and may be rejected as an unknown keyword by later
# releases.
iface = gr.Interface(
    fn=classify_iris,
    inputs=["text", "text", "number"],
    outputs=["text", "image"],
    title="Iris Classification App",
    description="Enter the sepal length and sepal width of a new data point.",
    examples=[
        ["5.0", "3.5", 0.2],
        ["6.2", "2.8", 0.5],
    ],
)

# Run the Gradio app
iface.launch()