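# NOTE: the lines "Spaces: / Sleeping / Sleeping" were page-header text captured
# together with this file; they are not part of the program.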
from io import BytesIO

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from PIL import Image
from sklearn.cluster import KMeans
# Load the training data (path is resolved relative to the working directory).
DATA_PATH = 'Flipcart.com Clusturing Model.csv'
# Positional columns used as the two clustering features. The UI and plot
# label them "Age" and "Purchase Rating" -- presumably that is what columns
# 2 and 4 of the CSV hold (TODO confirm against the CSV header).
FEATURE_COLUMNS = [2, 4]
N_CLUSTERS = 4

dataset = pd.read_csv(DATA_PATH)
X = dataset.iloc[:, FEATURE_COLUMNS].values

# Fit the K-Means model once at import time; y_means holds the cluster label
# of every training row and is used to colour the plot.
# n_init is pinned explicitly: scikit-learn changed its default from 10 to
# 'auto' in 1.4, so leaving it implicit makes the fitted clusters depend on
# the installed version (and older versions warn about it).
kmeans = KMeans(n_clusters=N_CLUSTERS, init='k-means++', n_init=10, random_state=42)
y_means = kmeans.fit_predict(X)
# Function to perform clustering and return cluster labels and the cluster visualization image | |
def cluster_data(age, purchase_rating): | |
features = np.array([age, purchase_rating]).reshape(1, -1) | |
cluster = kmeans.predict(features)[0] | |
# Scatter plot to visualize clusters | |
plt.figure(figsize=(8, 6)) | |
plt.scatter(X[y_means == 0, 0], X[y_means == 0, 1], s=100, c='magenta', label='Cluster 1') | |
plt.scatter(X[y_means == 1, 0], X[y_means == 1, 1], s=100, c='blue', label='Cluster 2') | |
plt.scatter(X[y_means == 2, 0], X[y_means == 2, 1], s=100, c='red', label='Cluster 3') | |
plt.scatter(X[y_means == 3, 0], X[y_means == 3, 1], s=100, c='cyan', label='Cluster 4') | |
plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], s=300, c='black', label='Centroids') | |
plt.title('Cluster of Amazon users') | |
plt.xlabel('Age') | |
plt.ylabel('Purchase Rating') | |
plt.legend() | |
plt.grid(True) | |
# Save the plot as an image | |
image_buffer = BytesIO() | |
plt.savefig(image_buffer, format='png') | |
image_buffer.seek(0) | |
# Create a PIL image from the buffer | |
pil_image = Image.open(image_buffer) | |
return f'Data point belongs to Cluster {cluster}', pil_image | |
# Create a Gradio interface for the clustering model: two numeric inputs map
# to a cluster message and a cluster-visualization image.
# Uses the top-level component classes (gr.Number / gr.Textbox / gr.Image);
# the gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and
# removed in Gradio 4.
iface = gr.Interface(
    fn=cluster_data,
    inputs=[
        gr.Number(label="Age"),
        gr.Number(label="Purchase Rating"),
    ],
    outputs=[
        gr.Textbox(label="Cluster"),
        gr.Image(label="Cluster Visualization", type="pil"),
    ],
    examples=[
        [23, 44],
        [26, 91],
        [72, 5],
    ],
    # Re-run cluster_data whenever an input changes (no Submit button).
    live=True,
    description=" Press flag if any erroneous output comes ",
    theme=gr.themes.Soft(),
    title="Flipcart User Segmentation",
)

# Launch the Gradio app
iface.launch(inline=False)