isshagle commited on
Commit
f87211d
1 Parent(s): 6edf3a1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -0
app.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import gradio as gr
4
+ from sklearn.cluster import KMeans
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ from io import BytesIO
8
+
9
+ # Load your dataset
10
+ dataset = pd.read_csv('Flipcart.com Clusturing Model.csv')
11
+ X = dataset.iloc[:, [2, 4]].values
12
+
13
+ # Create a K-Means clustering model
14
+ kmeans = KMeans(n_clusters=4, init='k-means++', random_state=42)
15
+ y_means = kmeans.fit_predict(X)
16
+
17
+ # Function to perform clustering and return cluster labels and the cluster visualization image
18
+ def cluster_data(age, purchase_rating):
19
+ features = np.array([age, purchase_rating]).reshape(1, -1)
20
+ cluster = kmeans.predict(features)[0]
21
+
22
+ # Scatter plot to visualize clusters
23
+ plt.figure(figsize=(8, 6))
24
+ plt.scatter(X[y_means == 0, 0], X[y_means == 0, 1], s=100, c='magenta', label='Cluster 1')
25
+ plt.scatter(X[y_means == 1, 0], X[y_means == 1, 1], s=100, c='blue', label='Cluster 2')
26
+ plt.scatter(X[y_means == 2, 0], X[y_means == 2, 1], s=100, c='red', label='Cluster 3')
27
+ plt.scatter(X[y_means == 3, 0], X[y_means == 3, 1], s=100, c='cyan', label='Cluster 4')
28
+ plt.scatter(kmeans.cluster_centers_[:, 0], kmeans.cluster_centers_[:, 1], s=300, c='black', label='Centroids')
29
+ plt.title('Cluster of Amazon users')
30
+ plt.xlabel('Age')
31
+ plt.ylabel('Purchase Rating')
32
+ plt.legend()
33
+ plt.grid(True)
34
+
35
+ # Save the plot as an image
36
+ image_buffer = BytesIO()
37
+ plt.savefig(image_buffer, format='png')
38
+ image_buffer.seek(0)
39
+
40
+ # Create a PIL image from the buffer
41
+ pil_image = Image.open(image_buffer)
42
+
43
+ return f'Data point belongs to Cluster {cluster}', pil_image
44
+
45
+ # Create a Gradio interface for the clustering model
46
+ iface = gr.Interface(
47
+ fn=cluster_data,
48
+ inputs=[
49
+ gr.inputs.Number(label="Age"),
50
+ gr.inputs.Number(label="Purchase Rating")
51
+ ],
52
+ outputs=[
53
+ gr.outputs.Textbox(label="Cluster"),
54
+ gr.outputs.Image(label="Cluster Visualization", type="pil")
55
+ ],
56
+ examples = [[23,44],
57
+ [26,91],
58
+ [72,5]],
59
+ live = True,
60
+ description = " Press flag if any erroneous output comes ",
61
+ theme=gr.themes.Soft(),
62
+ title = "Flipcart User Segmentation"
63
+ )
64
+
65
+ # Launch the Gradio app
66
+ iface.launch(inline=False)