tareknaous commited on
Commit
90b19f5
1 Parent(s): 80e3f03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -2
app.py CHANGED
@@ -6,6 +6,12 @@ from scipy import ndimage
6
  from skimage import measure, color, io
7
  from tensorflow.keras.preprocessing import image
8
  from scipy import ndimage
 
 
 
 
 
 
9
 
10
  #Function that predicts on only 1 sample
11
  def predict_sample(image):
@@ -39,7 +45,7 @@ def create_input_image(data, visualize=False):
39
  return input
40
 
41
 
42
-
43
 
44
  def get_instances(prediction, data, max_filter_size=1):
45
  #Adjust format (clusters to be 255 and rest is 0)
@@ -107,4 +113,89 @@ def get_instances(prediction, data, max_filter_size=1):
107
  cluster_ids = cluster_ids.astype('int8')
108
  cluster_ids[cluster_ids == -11] = 0
109
 
110
- return cluster_ids
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  from skimage import measure, color, io
7
  from tensorflow.keras.preprocessing import image
8
  from scipy import ndimage
9
+ import skimage.io as io
10
+ import skimage.transform as trans
11
+ import numpy as np
12
+ import tensorflow as tf
13
+ from huggingface_hub.keras_mixin import from_pretrained_keras
14
+
15
 
16
  #Function that predicts on only 1 sample
17
  def predict_sample(image):
 
45
  return input
46
 
47
 
48
+ model= from_pretrained_keras("tareknaous/unet-visual-clustering")
49
 
50
  def get_instances(prediction, data, max_filter_size=1):
51
  #Adjust format (clusters to be 255 and rest is 0)
 
113
  cluster_ids = cluster_ids.astype('int8')
114
  cluster_ids[cluster_ids == -11] = 0
115
 
116
+ return cluster_ids
117
+
118
+
119
+ import gradio as gr
120
+ from itertools import cycle, islice
121
+
122
+
123
+ def visual_clustering(cluster_type, num_clusters, num_samples, random_state, median_kernel_size, max_kernel_size):
124
+
125
+ NUM_CLUSTERS = num_clusters
126
+ CLUSTER_STD = 4 * np.ones(NUM_CLUSTERS)
127
+
128
+ if cluster_type == "blobs":
129
+ data = datasets.make_blobs(n_samples=num_samples, centers=NUM_CLUSTERS, random_state=random_state,center_box=(0, 256), cluster_std=CLUSTER_STD)
130
+
131
+ elif cluster_type == "varied blobs":
132
+ cluster_std = 1.5 * np.ones(NUM_CLUSTERS)
133
+ data = datasets.make_blobs(n_samples=num_samples, centers=NUM_CLUSTERS, cluster_std=cluster_std, random_state=random_state)
134
+
135
+ elif cluster_type == "aniso":
136
+ X, y = datasets.make_blobs(n_samples=num_samples, centers=NUM_CLUSTERS, random_state=random_state, center_box=(-30, 30))
137
+ transformation = [[0.8, -0.6], [-0.4, 0.8]]
138
+ X_aniso = np.dot(X, transformation)
139
+ data = (X_aniso, y)
140
+
141
+ elif cluster_type == "noisy moons":
142
+ data = datasets.make_moons(n_samples=num_samples, noise=.05)
143
+
144
+ elif cluster_type == "noisy circles":
145
+ data = datasets.make_circles(n_samples=num_samples, factor=.01, noise=.05)
146
+
147
+ max_x = max(data[0][:, 0])
148
+ min_x = min(data[0][:, 0])
149
+ new_max = 256
150
+ new_min = 0
151
+
152
+ data[0][:, 0] = (((data[0][:, 0] - min_x)*(new_max-new_min))/(max_x-min_x))+ new_min
153
+
154
+ max_y = max(data[0][:, 1])
155
+ min_y = min(data[0][:, 1])
156
+ new_max_y = 256
157
+ new_min_y = 0
158
+
159
+ data[0][:, 1] = (((data[0][:, 1] - min_y)*(new_max_y-new_min_y))/(max_y-min_y))+ new_min_y
160
+
161
+ fig1 = plt.figure()
162
+ plt.scatter(data[0][:, 0], data[0][:, 1], s=1, c='black')
163
+ plt.close()
164
+
165
+ input = create_input_image(data[0])
166
+ filtered = ndimage.median_filter(input, size=median_kernel_size)
167
+ result = predict_sample(filtered)
168
+ y_km = get_instances(result, data[0], max_filter_size=max_kernel_size)
169
+
170
+ colors = np.array(list(islice(cycle(["#000000", '#377eb8', '#ff7f00', '#4daf4a',
171
+ '#f781bf', '#a65628', '#984ea3',
172
+ '#999999', '#e41a1c', '#dede00' ,'#491010']),
173
+ int(max(y_km) + 1))))
174
+ #add black color for outliers (if any)
175
+ colors = np.append(colors, ["#000000"])
176
+
177
+ fig2 = plt.figure()
178
+ plt.scatter(data[0][:, 0], data[0][:, 1], s=10, color=colors[y_km.astype('int8')])
179
+ plt.close()
180
+
181
+ return fig1, fig2
182
+
183
+ iface = gr.Interface(
184
+
185
+ fn=visual_clustering,
186
+
187
+ inputs=[
188
+ gr.inputs.Dropdown(["blobs", "varied blobs", "aniso", "noisy moons", "noisy circles" ]),
189
+ gr.inputs.Slider(1, 10, step=1, label='Number of Clusters'),
190
+ gr.inputs.Slider(10000, 1000000, step=10000, label='Number of Samples'),
191
+ gr.inputs.Slider(1, 100, step=1, label='Random State'),
192
+ gr.inputs.Slider(1, 100, step=1, label='Denoising Filter Kernel Size'),
193
+ gr.inputs.Slider(1,100, step=1, label='Max Filter Kernel Size')
194
+ ],
195
+
196
+ outputs=[
197
+ gr.outputs.Image(type='plot', label='Dataset'),
198
+ gr.outputs.Image(type='plot', label='Clustering Result')
199
+ ]
200
+ )
201
+ iface.launch()