James Horwath committed on
Commit 0dd7ea7
1 Parent(s): 1e58dce

first commit

app.py ADDED
@@ -0,0 +1,471 @@
+ import gradio as gr
+ import numpy as np
+ import matplotlib as mpl
+ mpl.use('agg')
+ import matplotlib.pyplot as plt
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from torch.utils.data import TensorDataset, DataLoader
+ from sklearn.decomposition import PCA
+ from sklearn.cluster import KMeans, AgglomerativeClustering, DBSCAN
+ from sklearn.manifold import TSNE
+ from umap import UMAP
+ import plotly.express as px
+ import pandas as pd
+
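+ # AI-NERD demo app: convolutional autoencoders for 256x256 XPCS two-time correlation (C2) maps.
+ # recon_model reconstructs the input image; recon_encoder exposes the latent vector used for clustering.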
+ class recon_encoder(nn.Module):
+
+     def __init__(self, latent_size, nconv=16, pool=4, drop=0.05):
+         super(recon_encoder, self).__init__()
+
+
+         self.encoder = nn.Sequential(  # Sequential offers similar functionality to TF, avoiding the need for a separate model definition and explicit activation calls
+             nn.Conv2d(in_channels=1, out_channels=nconv, kernel_size=3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv, nconv, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.MaxPool2d((pool,pool)),
+
+             nn.Conv2d(nconv, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.MaxPool2d((pool,pool)),
+
+             nn.Conv2d(nconv*2, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.MaxPool2d((pool,pool)),
+
+             #nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.MaxPool2d((pool,pool)),
+         )
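+         # Shape check: a 256x256 input is pooled 4x three times (256 -> 64 -> 16 -> 4), and the
+         # final block has nconv*4 = 64 channels, so the flattened feature vector is 64*4*4 = 1024.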
+
+
+         self.bottleneck = nn.Sequential(
+             # FC layer at the bottleneck -- dropout might not make sense here
+             nn.Flatten(),
+             nn.Linear(1024, latent_size),
+             #nn.Dropout(drop),
+             nn.ReLU(),
+             # nn.Linear(latent_size, 1024),
+             # #nn.Dropout(drop),
+             # nn.ReLU(),
+             # nn.Unflatten(1,(64,4,4))  # dim 0 is the batch dimension
+         )
+
+
+         self.decoder1 = nn.Sequential(
+
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             nn.Conv2d(nconv*4, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             #nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             nn.Conv2d(nconv*2, 1, 3, stride=1, padding=(1,1)),  # single-channel output in amplitude mode (a mu/sigma output would use 2 channels)
+             nn.Sigmoid()  # Amplitude mode
+         )
+
+
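+     # torch.cuda.amp.autocast enables mixed-precision inference on GPU; on the CPU used by
+     # this demo it has no effect (newer PyTorch versions may simply log a warning).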
+     def forward(self, x):
+         with torch.cuda.amp.autocast():
+             x1 = self.encoder(x)
+             x1 = self.bottleneck(x1)
+             #print(x1.shape)
+             return x1
+
+
+     #Helper function to calculate the size of the flattened array from the conv layer shapes
+     def calc_fc_shape(self):
+         x0 = torch.zeros([256, 256]).unsqueeze(0)  # add a channel dimension -> (1, 256, 256)
+         x0 = self.encoder(x0)
+
+         self.conv_block_output_shape = x0.shape
+         #print ("Output of conv block shape is", self.conv_block_output_shape)
+         self.flattened_size = x0.flatten().shape[0]
+         #print ("Flattened layer size is", self.flattened_size)
+         return self.flattened_size
+
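+ # recon_model is the full autoencoder: the same encoder as above, plus a bottleneck that projects
+ # back up from the latent vector and a decoder that reconstructs the input image.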
+ class recon_model(nn.Module):
+
+     def __init__(self, latent_size, nconv=16, pool=4, drop=0.05):
+         super(recon_model, self).__init__()
+
+
+         self.encoder = nn.Sequential(  # Sequential offers similar functionality to TF, avoiding the need for a separate model definition and explicit activation calls
+             nn.Conv2d(in_channels=1, out_channels=nconv, kernel_size=3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv, nconv, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.MaxPool2d((pool,pool)),
+
+             nn.Conv2d(nconv, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.MaxPool2d((pool,pool)),
+
+             nn.Conv2d(nconv*2, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.MaxPool2d((pool,pool)),
+
+             #nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.MaxPool2d((pool,pool)),
+         )
+
+
+         self.bottleneck = nn.Sequential(
+             # FC layer at the bottleneck -- dropout might not make sense here
+             nn.Flatten(),
+             nn.Linear(1024, latent_size),
+             #nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Linear(latent_size, 1024),
+             #nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Unflatten(1, (64, 4, 4))  # dim 0 is the batch dimension
+         )
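+         # The bottleneck above restores a (64, 4, 4) feature map from the latent vector, so the
+         # decoder below can upsample it back to a 256x256 image.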
+
+
+         self.decoder1 = nn.Sequential(
+
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*4, nconv*4, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             nn.Conv2d(nconv*4, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             nn.Dropout(drop),
+             nn.ReLU(),
+             nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             #nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.Conv2d(nconv*2, nconv*2, 3, stride=1, padding=(1,1)),
+             #nn.Dropout(drop),
+             #nn.ReLU(),
+             #nn.Upsample(scale_factor=pool, mode='bilinear'),
+
+             nn.Conv2d(nconv*2, 1, 3, stride=1, padding=(1,1)),  # single-channel output in amplitude mode (a mu/sigma output would use 2 channels)
+             nn.Sigmoid()  # Amplitude mode
+         )
+
+
+     def forward(self, x):
+         with torch.cuda.amp.autocast():
+             x1 = self.encoder(x)
+             x1 = self.bottleneck(x1)
+             #print(x1.shape)
+             return self.decoder1(x1)
+
+
+     #Helper function to calculate the size of the flattened array from the conv layer shapes
+     def calc_fc_shape(self):
+         x0 = torch.zeros([256, 256]).unsqueeze(0)  # add a channel dimension -> (1, 256, 256)
+         x0 = self.encoder(x0)
+
+         self.conv_block_output_shape = x0.shape
+         #print ("Output of conv block shape is", self.conv_block_output_shape)
+         self.flattened_size = x0.flatten().shape[0]
+         #print ("Flattened layer size is", self.flattened_size)
+         return self.flattened_size
+
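+ # Load the full autoencoder directly, then build a standalone encoder and populate it with the
+ # matching weights from the full model's state dict (keys the encoder lacks are filtered out).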
+ full_model = torch.load('betst_model_100x_0064.pth', map_location=torch.device('cpu'))
+ encoder_model = recon_encoder(latent_size=64)
+ encoder_state_dict = encoder_model.state_dict()
+
+ checkpoint = torch.load('betst_model_100x_0064_statedict.pth', map_location=torch.device('cpu'))
+ pretrained_dict = {k: v for k, v in checkpoint.items() if k in encoder_state_dict}
+
+ encoder_model.load_state_dict(pretrained_dict)
+ #
+ #all_data = np.load('E031_256.npy').astype(np.float32)
+ #all_data = all_data.reshape(-1,1,256,256)
+ #dataloader = DataLoader(all_data,batch_size=32,shuffle=False)
+
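+ # The uploaded .npy file is expected to hold a stack of 256x256 images; it is reshaped to
+ # (N, 1, 256, 256) to add the single channel dimension the network expects.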
+ def load_data(file):
+     all_data = np.load(file.name).astype(np.float32)
+     all_data = all_data.reshape(-1, 1, 256, 256)
+     dataloader = DataLoader(all_data, batch_size=32, shuffle=False)
+     return all_data, dataloader, 'upload complete: {}'.format(all_data.shape)
+
+
+ def show_image(selection, all_data):
+     fig1, ax1 = plt.subplots()
+     ax1.imshow(all_data[selection][0], plt.cm.inferno, origin='lower')
+     ax1.axis('off')
+     fig1.tight_layout()
+
+     fig2, ax2 = plt.subplots()
+     prediction = full_model(torch.tensor(all_data[selection].reshape(-1, 1, 256, 256))).detach().cpu().numpy()
+     ax2.imshow(prediction[0, 0], plt.cm.inferno, origin='lower')
+     ax2.axis('off')
+     fig2.tight_layout()
+
+     return fig1, fig2
+
+ def encode_data(dataloader):
+     preds_full = []
+     preds_enc = []
+
+     for i, images in enumerate(dataloader):
+         if i > 5:  # demo cap: only the first 6 batches (192 images) are processed
+             break
+         pred_full = full_model(images)
+         pred_enc = encoder_model(images)
+         for j in range(images.shape[0]):
+             preds_full.append(pred_full[j].detach().cpu().numpy())
+             preds_enc.append(pred_enc[j].detach().cpu().numpy())
+
+     processed_images = np.array(preds_full).squeeze()
+     encoded_images = np.array(preds_enc)
+     message = 'finished'
+
+     return message, processed_images, encoded_images
+
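+ # Small debugging helper, wired to the commented-out 'check state info' button further down.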
+ def print_state(state):
+     return state.shape
+
+ def latent_vis(encoded_data, decomp_method, clustering_method, cluster_number, all_data):
+     if decomp_method == 'PCA':
+         pca = PCA(n_components=2)
+         decomp = pca.fit_transform(encoded_data)
+     elif decomp_method == 'tSNE':
+         tsne = TSNE(n_components=2)
+         decomp = tsne.fit_transform(encoded_data)
+     elif decomp_method == 'UMAP':
+         reducer = UMAP()
+         decomp = reducer.fit_transform(encoded_data)
+
+     if clustering_method == 'KMeans':
+         kmeans = KMeans(n_clusters=int(cluster_number))
+         cluster_labels = kmeans.fit_predict(encoded_data)
+     elif clustering_method == 'Agglomerative':
+         agg = AgglomerativeClustering(n_clusters=int(cluster_number))
+         cluster_labels = agg.fit_predict(encoded_data)
+     elif clustering_method == 'DBSCAN':
+         dbscan = DBSCAN()  # DBSCAN ignores cluster_number and finds the number of clusters itself (label -1 marks noise)
+         cluster_labels = dbscan.fit_predict(encoded_data)
+
+     df = pd.DataFrame(decomp, columns=['x','y'])
+     df['cluster'] = cluster_labels
+     df['value'] = np.arange(len(decomp))  # point index, shown on hover
+
+     fig = px.scatter(df, x='x', y='y', color='cluster', color_continuous_scale='viridis',
+                      hover_name='value',
+                      hover_data={'x': False, 'y': False, 'cluster': False, 'value': False})
+     # fig = px.scatter(x=decomp[:,0],y=decomp[:,1],color=clusters,hover_data=np.arange(len(decomp)))
+     fig.update_layout(clickmode='event+select')
+     fig.update_traces(marker=dict(size=12),
+                       selector=dict(mode='markers'))
+
+     unique_labels = np.unique(cluster_labels)
+     fig1 = plt.figure(figsize=(20, 5))
+     n_rows = 1
+     n_cols = len(unique_labels)
+     colors = plt.cm.viridis(np.linspace(0, 1, len(unique_labels)))
+
+     for k, i in enumerate(unique_labels):
+         ind = np.where(cluster_labels == i)[0]
+         #ax.scatter(decomp[cluster_labels[:] == i,0],decomp[cluster_labels[:] == i,1],color=colors[i],label='class {}'.format(i))
+
+         r = np.random.choice(ind)  # show one randomly chosen example image per cluster
+         ax1 = fig1.add_subplot(n_rows, n_cols, k + 1)
+         ax1.imshow(all_data[r][0], plt.cm.inferno, origin='lower')
+         ax1.set_title('Class {}: {}'.format(i, len(ind)), color=colors[k], fontsize=20, weight='bold')
+
+     #ax.legend()
+
+     #fig.tight_layout()
+     fig1.tight_layout()
+     return decomp, cluster_labels, fig, fig1
+
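+ # interactive_vis is currently unused; the click binding that would call it (near the bottom
+ # of this file) is commented out.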
+ def interactive_vis(decomp, clusters, images):
+     df = pd.DataFrame(decomp, columns=['x','y'])
+     df['cluster'] = clusters
+     df['value'] = np.arange(len(decomp))
+     df['im'] = list(images)  # store one image array per row; a bare 3-D array cannot be assigned to a column
+
+     fig = px.scatter(df, x='x', y='y', color='cluster', custom_data=['im'], color_continuous_scale='viridis',
+                      hover_name='value',
+                      hover_data={'x': False, 'y': False, 'cluster': False, 'value': False})
+     # fig = px.scatter(x=decomp[:,0],y=decomp[:,1],color=clusters,hover_data=np.arange(len(decomp)))
+     fig.update_layout(clickmode='event+select')
+     fig.update_traces(marker=dict(size=20),
+                       selector=dict(mode='markers'))
+
+     return fig
+
+ def neighbor_vis(decomp, neighbor_index, n_neighbors, all_data):
+     neighbor_index = int(neighbor_index)
+
+     # Euclidean distance from the selected point to every point in the 2-D embedding
+     d = np.sqrt((decomp[:,0] - decomp[neighbor_index,0]) ** 2 + (decomp[:,1] - decomp[neighbor_index,1]) ** 2)
+     ar = np.argsort(d)
+
+     n_rows = int(np.ceil(n_neighbors / 5))
+     n_cols = 5
+     fig = plt.figure(figsize=(20, 5 * n_rows))
+
+     # first panel: the selected point itself
+     n = 1
+     ax = fig.add_subplot(n_rows, n_cols, n)
+     ax.imshow(all_data[neighbor_index][0], plt.cm.inferno, origin='lower')
+     ax.set_title('{}'.format(neighbor_index), fontsize=20, weight='bold')
+     ax.axis('off')
+     n += 1
+
+     # skip ar[0] (the selected point itself) and take the n_neighbors-1 closest points
+     neighbors = ar[1:n_neighbors]
+
+     for i in neighbors:
+         ax = fig.add_subplot(n_rows, n_cols, n)
+         ax.imshow(all_data[i][0], plt.cm.inferno, origin='lower')
+         ax.set_title('{}'.format(i), fontsize=20)
+         ax.axis('off')
+         n += 1
+
+     return fig
+
+
+ intro_text1 = '# AI-NERD: Artificial Intelligence for Non-Equilibrium Relaxation Dynamics'
+ intro_text2 = 'AI-NERD is a platform for applying unsupervised image classification to X-ray Photon Correlation Spectroscopy (XPCS) data. Here, we demonstrate how raw experimental data can be automatically processed and clustered, and how latent space analysis can be used to understand the physics of relaxing systems without any background information or assumptions.<br><br>Please see our [preprint](https://arxiv.org/abs/2212.03984) for more information.<br><br>'
+ l = 900
+ with gr.Blocks() as demo:
+     gr.Markdown(intro_text1)
+     gr.Markdown(intro_text2)
+
+     gr.Markdown('### Evaluation of Training Results')
+     gr.Markdown('Use the dropdown menu below to select a sample image. The frame on the left will show the raw C2 data, and the frame on the right will show the neural network output. After sampling individual images, click _Process All Images_ to run the entire dataset through the autoencoder.')
+
+     with gr.Row():
+         file_path = gr.File()
+         with gr.Column():
+             upload_status = gr.Textbox(label='file upload status')
+             file_upload = gr.Button(value='load data')
+
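+     # gr.State components hold per-session values (the uploaded array and its DataLoader)
+     # between event callbacks without rendering anything in the UI.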
+     all_data = gr.State()
+     dataloader = gr.State()
+     file_upload.click(load_data, file_path, [all_data, dataloader, upload_status])
+
+     selection = gr.Dropdown(list(np.arange(2000)), value=200, label='select sample image')
+     with gr.Row():
+         output_image_1 = gr.Plot(label='input C2 data')
+         output_image_2 = gr.Plot(label='Autoencoder Reproduction')
+
+     selection.change(show_image, [selection, all_data], [output_image_1, output_image_2])
+
+     with gr.Row():
+         process_all = gr.Button(value='Process All Images')
+         status = gr.Textbox(label='batch processing status')
+
+     proc_im = gr.State()
+     enc_im = gr.State()
+     process_all.click(encode_data, inputs=[dataloader], outputs=[status, proc_im, enc_im], show_progress=True, status_tracker=None)
+
+     # check_type = gr.Button(value='check state info')
+     # check_stat = gr.Textbox()
+     # check_type.click(print_state,inputs=proc_im,outputs=check_stat)
+     gr.Markdown('<br><br>')
+     gr.Markdown('### Latent Space Visualization')
+     gr.Markdown('Select the decomposition and clustering method for latent space visualization.')
+     with gr.Row():
+         with gr.Column():
+             decomp_method = gr.Dropdown(choices=['PCA','tSNE','UMAP'], label='select decomposition method', value='UMAP')
+             with gr.Row():
+                 clustering_method = gr.Dropdown(choices=['KMeans','Agglomerative','DBSCAN'], label='select clustering algorithm', value='KMeans')
+                 cluster_number = gr.Number(label='input number of clusters', value=5)
+
+     process_vis = gr.Button(value='Visualize Latent Space')
+     latent_scatter = gr.Plot()
+     latent_sample = gr.Plot()
+
+     save_decomp_coords = gr.State()
+     save_cluster_labels = gr.State()
+     process_vis.click(latent_vis, [enc_im, decomp_method, clustering_method, cluster_number, all_data], [save_decomp_coords, save_cluster_labels, latent_scatter, latent_sample])
+
+     gr.Markdown('<br><br><br>')
+     gr.Markdown('### Visualize Nearest Neighbors')
+     gr.Markdown('Hover over data points in the scatter plot above to identify the index of points of interest. Enter the desired index in the box below, and click _Visualize Neighbors_.')
+
+     with gr.Row():
+         with gr.Column():
+             neighbor_index = gr.Number(label='input point index', value=110)
+             n_neighbors = gr.Slider(label='select number of neighbors to view', minimum=5, maximum=10, value=5, step=1)
+
+     neighbor_button = gr.Button(value='Visualize Neighbors')
+
+     neighbor_plot = gr.Plot()
+     neighbor_button.click(neighbor_vis, [save_decomp_coords, neighbor_index, n_neighbors, all_data], neighbor_plot)
+     #neighbor_button.click(interactive_vis,[save_decomp_coords,save_cluster_labels,proc_im],interactive_plot)
+
+
+ demo.launch()
betst_model_100x_0064.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e89e9b45bece44b11c9e00978bb54e5533ea6b146b310a0edad4f2839c98a4b
+ size 1538571
betst_model_100x_0064_statedict.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af36b48cb6c7818e3e68d443620f4ca61e7c8332ad61e98170111a0feb6c8713
+ size 1528395
partial_dataset.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7f2874320efe4cfef0f3c4e59534593a7e8e772aafbd4dfbc4599c1beacb4550
+ size 5243008
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ numpy
+ matplotlib
+ scikit-learn
+ umap-learn
+ plotly
+ pandas