huathedev commited on
Commit
fe8c632
1 Parent(s): 9a64105

Update ⓘ_Introduction.py

Browse files
Files changed (1) hide show
  1. ⓘ_Introduction.py +10 -890
ⓘ_Introduction.py CHANGED
@@ -1,24 +1,5 @@
1
- from streamlit import session_state as session
2
- import shutil
3
-
4
- import os
5
- import numpy as np
6
- from sklearn import neighbors
7
- from scipy.spatial import distance_matrix
8
- from pygco import cut_from_graph
9
- import streamlit_ext as ste
10
- import open3d as o3d
11
- import matplotlib.pyplot as plt
12
- import matplotlib.colors as mcolors
13
- from stqdm import stqdm
14
- import json
15
- from stpyvista import stpyvista
16
- import torch
17
- import torch.nn as nn
18
- from torch.autograd import Variable
19
- import torch.nn.functional as F
20
  import streamlit as st
21
- import pyvista as pv
22
 
23
  from PIL import Image
24
 
@@ -27,7 +8,7 @@ class TeethApp:
27
  # Font
28
  with open("utils/style.css") as css:
29
  st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
30
-
31
  # Logo
32
  self.image_path = "utils/teeth-295404_1280.png"
33
  self.image = Image.open(self.image_path)
@@ -49,881 +30,20 @@ class TeethApp:
49
  unsafe_allow_html=True,
50
  )
51
 
52
-
53
- class STN3d(nn.Module):
54
- def __init__(self, channel):
55
- super(STN3d, self).__init__()
56
- self.conv1 = torch.nn.Conv1d(channel, 64, 1)
57
- self.conv2 = torch.nn.Conv1d(64, 128, 1)
58
- self.conv3 = torch.nn.Conv1d(128, 1024, 1)
59
- self.fc1 = nn.Linear(1024, 512)
60
- self.fc2 = nn.Linear(512, 256)
61
- self.fc3 = nn.Linear(256, 9)
62
- self.relu = nn.ReLU()
63
-
64
- self.bn1 = nn.BatchNorm1d(64)
65
- self.bn2 = nn.BatchNorm1d(128)
66
- self.bn3 = nn.BatchNorm1d(1024)
67
- self.bn4 = nn.BatchNorm1d(512)
68
- self.bn5 = nn.BatchNorm1d(256)
69
-
70
- def forward(self, x):
71
- batchsize = x.size()[0]
72
- x = F.relu(self.bn1(self.conv1(x)))
73
- x = F.relu(self.bn2(self.conv2(x)))
74
- x = F.relu(self.bn3(self.conv3(x)))
75
- x = torch.max(x, 2, keepdim=True)[0]
76
- x = x.view(-1, 1024)
77
-
78
- x = F.relu(self.bn4(self.fc1(x)))
79
- x = F.relu(self.bn5(self.fc2(x)))
80
- x = self.fc3(x)
81
-
82
- iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat(
83
- batchsize, 1)
84
- if x.is_cuda:
85
- iden = iden.to(x.get_device())
86
- x = x + iden
87
- x = x.view(-1, 3, 3)
88
- return x
89
-
90
- class STNkd(nn.Module):
91
- def __init__(self, k=64):
92
- super(STNkd, self).__init__()
93
- self.conv1 = torch.nn.Conv1d(k, 64, 1)
94
- self.conv2 = torch.nn.Conv1d(64, 128, 1)
95
- self.conv3 = torch.nn.Conv1d(128, 512, 1)
96
- self.fc1 = nn.Linear(512, 256)
97
- self.fc2 = nn.Linear(256, 128)
98
- self.fc3 = nn.Linear(128, k * k)
99
- self.relu = nn.ReLU()
100
-
101
- self.bn1 = nn.BatchNorm1d(64)
102
- self.bn2 = nn.BatchNorm1d(128)
103
- self.bn3 = nn.BatchNorm1d(512)
104
- self.bn4 = nn.BatchNorm1d(256)
105
- self.bn5 = nn.BatchNorm1d(128)
106
-
107
- self.k = k
108
-
109
- def forward(self, x):
110
- batchsize = x.size()[0]
111
- x = F.relu(self.bn1(self.conv1(x)))
112
- x = F.relu(self.bn2(self.conv2(x)))
113
- x = F.relu(self.bn3(self.conv3(x)))
114
- x = torch.max(x, 2, keepdim=True)[0]
115
- x = x.view(-1, 512)
116
-
117
- x = F.relu(self.bn4(self.fc1(x)))
118
- x = F.relu(self.bn5(self.fc2(x)))
119
- x = self.fc3(x)
120
-
121
- iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1, self.k * self.k).repeat(
122
- batchsize, 1)
123
- if x.is_cuda:
124
- iden = iden.to(x.get_device())
125
- x = x + iden
126
- x = x.view(-1, self.k, self.k)
127
- return x
128
-
129
- class MeshSegNet(nn.Module):
130
- def __init__(self, num_classes=17, num_channels=15, with_dropout=True, dropout_p=0.5):
131
- super(MeshSegNet, self).__init__()
132
- self.num_classes = num_classes
133
- self.num_channels = num_channels
134
- self.with_dropout = with_dropout
135
- self.dropout_p = dropout_p
136
-
137
- # MLP-1 [64, 64]
138
- self.mlp1_conv1 = torch.nn.Conv1d(self.num_channels, 64, 1)
139
- self.mlp1_conv2 = torch.nn.Conv1d(64, 64, 1)
140
- self.mlp1_bn1 = nn.BatchNorm1d(64)
141
- self.mlp1_bn2 = nn.BatchNorm1d(64)
142
- # FTM (feature-transformer module)
143
- self.fstn = STNkd(k=64)
144
- # GLM-1 (graph-contrained learning modulus)
145
- self.glm1_conv1_1 = torch.nn.Conv1d(64, 32, 1)
146
- self.glm1_conv1_2 = torch.nn.Conv1d(64, 32, 1)
147
- self.glm1_bn1_1 = nn.BatchNorm1d(32)
148
- self.glm1_bn1_2 = nn.BatchNorm1d(32)
149
- self.glm1_conv2 = torch.nn.Conv1d(32+32, 64, 1)
150
- self.glm1_bn2 = nn.BatchNorm1d(64)
151
- # MLP-2
152
- self.mlp2_conv1 = torch.nn.Conv1d(64, 64, 1)
153
- self.mlp2_bn1 = nn.BatchNorm1d(64)
154
- self.mlp2_conv2 = torch.nn.Conv1d(64, 128, 1)
155
- self.mlp2_bn2 = nn.BatchNorm1d(128)
156
- self.mlp2_conv3 = torch.nn.Conv1d(128, 512, 1)
157
- self.mlp2_bn3 = nn.BatchNorm1d(512)
158
- # GLM-2 (graph-contrained learning modulus)
159
- self.glm2_conv1_1 = torch.nn.Conv1d(512, 128, 1)
160
- self.glm2_conv1_2 = torch.nn.Conv1d(512, 128, 1)
161
- self.glm2_conv1_3 = torch.nn.Conv1d(512, 128, 1)
162
- self.glm2_bn1_1 = nn.BatchNorm1d(128)
163
- self.glm2_bn1_2 = nn.BatchNorm1d(128)
164
- self.glm2_bn1_3 = nn.BatchNorm1d(128)
165
- self.glm2_conv2 = torch.nn.Conv1d(128*3, 512, 1)
166
- self.glm2_bn2 = nn.BatchNorm1d(512)
167
- # MLP-3
168
- self.mlp3_conv1 = torch.nn.Conv1d(64+512+512+512, 256, 1)
169
- self.mlp3_conv2 = torch.nn.Conv1d(256, 256, 1)
170
- self.mlp3_bn1_1 = nn.BatchNorm1d(256)
171
- self.mlp3_bn1_2 = nn.BatchNorm1d(256)
172
- self.mlp3_conv3 = torch.nn.Conv1d(256, 128, 1)
173
- self.mlp3_conv4 = torch.nn.Conv1d(128, 128, 1)
174
- self.mlp3_bn2_1 = nn.BatchNorm1d(128)
175
- self.mlp3_bn2_2 = nn.BatchNorm1d(128)
176
- # output
177
- self.output_conv = torch.nn.Conv1d(128, self.num_classes, 1)
178
- if self.with_dropout:
179
- self.dropout = nn.Dropout(p=self.dropout_p)
180
-
181
- def forward(self, x, a_s, a_l):
182
- batchsize = x.size()[0]
183
- n_pts = x.size()[2]
184
- # MLP-1
185
- x = F.relu(self.mlp1_bn1(self.mlp1_conv1(x)))
186
- x = F.relu(self.mlp1_bn2(self.mlp1_conv2(x)))
187
- # FTM
188
- trans_feat = self.fstn(x)
189
- x = x.transpose(2, 1)
190
- x_ftm = torch.bmm(x, trans_feat)
191
- # GLM-1
192
- sap = torch.bmm(a_s, x_ftm)
193
- sap = sap.transpose(2, 1)
194
- x_ftm = x_ftm.transpose(2, 1)
195
- x = F.relu(self.glm1_bn1_1(self.glm1_conv1_1(x_ftm)))
196
- glm_1_sap = F.relu(self.glm1_bn1_2(self.glm1_conv1_2(sap)))
197
- x = torch.cat([x, glm_1_sap], dim=1)
198
- x = F.relu(self.glm1_bn2(self.glm1_conv2(x)))
199
- # MLP-2
200
- x = F.relu(self.mlp2_bn1(self.mlp2_conv1(x)))
201
- x = F.relu(self.mlp2_bn2(self.mlp2_conv2(x)))
202
- x_mlp2 = F.relu(self.mlp2_bn3(self.mlp2_conv3(x)))
203
- if self.with_dropout:
204
- x_mlp2 = self.dropout(x_mlp2)
205
- # GLM-2
206
- x_mlp2 = x_mlp2.transpose(2, 1)
207
- sap_1 = torch.bmm(a_s, x_mlp2)
208
- sap_2 = torch.bmm(a_l, x_mlp2)
209
- x_mlp2 = x_mlp2.transpose(2, 1)
210
- sap_1 = sap_1.transpose(2, 1)
211
- sap_2 = sap_2.transpose(2, 1)
212
- x = F.relu(self.glm2_bn1_1(self.glm2_conv1_1(x_mlp2)))
213
- glm_2_sap_1 = F.relu(self.glm2_bn1_2(self.glm2_conv1_2(sap_1)))
214
- glm_2_sap_2 = F.relu(self.glm2_bn1_3(self.glm2_conv1_3(sap_2)))
215
- x = torch.cat([x, glm_2_sap_1, glm_2_sap_2], dim=1)
216
- x_glm2 = F.relu(self.glm2_bn2(self.glm2_conv2(x)))
217
- # GMP
218
- x = torch.max(x_glm2, 2, keepdim=True)[0]
219
- # Upsample
220
- x = torch.nn.Upsample(n_pts)(x)
221
- # Dense fusion
222
- x = torch.cat([x, x_ftm, x_mlp2, x_glm2], dim=1)
223
- # MLP-3
224
- x = F.relu(self.mlp3_bn1_1(self.mlp3_conv1(x)))
225
- x = F.relu(self.mlp3_bn1_2(self.mlp3_conv2(x)))
226
- x = F.relu(self.mlp3_bn2_1(self.mlp3_conv3(x)))
227
- if self.with_dropout:
228
- x = self.dropout(x)
229
- x = F.relu(self.mlp3_bn2_2(self.mlp3_conv4(x)))
230
- # output
231
- x = self.output_conv(x)
232
- x = x.transpose(2,1).contiguous()
233
- x = torch.nn.Softmax(dim=-1)(x.view(-1, self.num_classes))
234
- x = x.view(batchsize, n_pts, self.num_classes)
235
-
236
- return x
237
-
238
- def clone_runoob(li1):
239
- li_copy = li1[:]
240
- return li_copy
241
-
242
- # 对离群点重新进行分类
243
- def class_inlier_outlier(label_list, mean_points,cloud, ind, label_index, points, labels):
244
- label_change = clone_runoob(labels)
245
- outlier_index = clone_runoob(label_index)
246
- ind_reverse = clone_runoob(ind)
247
- # 得到离群点的label下标
248
- ind_reverse.reverse()
249
- for i in ind_reverse:
250
- outlier_index.pop(i)
251
-
252
- # 获取离群点
253
- inlier_cloud = cloud.select_by_index(ind)
254
- outlier_cloud = cloud.select_by_index(ind, invert=True)
255
- outlier_points = np.array(outlier_cloud.points)
256
-
257
- for i in range(len(outlier_points)):
258
- distance = []
259
- for j in range(len(mean_points)):
260
- dis = np.linalg.norm(outlier_points[i] - mean_points[j], ord=2) # 计算tooth和GT质心之间的距离
261
- distance.append(dis)
262
- min_index = distance.index(min(distance)) # 获取和离群点质心最近label的index
263
- outlier_label = label_list[min_index] # 获取离群点应该的label
264
- index = outlier_index[i]
265
- label_change[index] = outlier_label
266
-
267
- return label_change
268
-
269
- # 利用knn算法消除离群点
270
- def remove_outlier(points, labels):
271
- # points = np.array(point_cloud_o3d_orign.points)
272
- # global label_list
273
- same_label_points = {}
274
-
275
- same_label_index = {}
276
-
277
- mean_points = [] # 所有label种类对应点云的质心坐标
278
-
279
- label_list = []
280
- for i in range(len(labels)):
281
- label_list.append(labels[i])
282
- label_list = list(set(label_list)) # 去重获从小到大排序取GT_label=[0, 11, 12, 13, 14, 15, 16, 17, 21, 22, 23, 24, 25, 26, 27]
283
- label_list.sort()
284
- label_list = label_list[1:]
285
-
286
- for i in label_list:
287
- key = i
288
- points_list = []
289
- all_label_index = []
290
- for j in range(len(labels)):
291
- if labels[j] == i:
292
- points_list.append(points[j].tolist())
293
- all_label_index.append(j) # 得到label为 i 的点对应的label的下标
294
- same_label_points[key] = points_list
295
- same_label_index[key] = all_label_index
296
-
297
- tooth_mean = np.mean(points_list, axis=0)
298
- mean_points.append(tooth_mean)
299
- # print(mean_points)
300
-
301
- for i in label_list:
302
- points_array = same_label_points[i]
303
- # 建立一个o3d的点云对象
304
- pcd = o3d.geometry.PointCloud()
305
- # 使用Vector3dVector方法转换
306
- pcd.points = o3d.utility.Vector3dVector(points_array)
307
-
308
- # 对label i 对应的点云进行统计离群值去除,找出离群点并显示
309
- # 统计式离群点移除
310
- cl, ind = pcd.remove_statistical_outlier(nb_neighbors=200, std_ratio=2.0) # cl是选中的点,ind是选中点index
311
- # 可视化
312
- # display_inlier_outlier(pcd, ind)
313
-
314
- # 对分出来的离群点重新分类
315
- label_index = same_label_index[i]
316
- labels = class_inlier_outlier(label_list, mean_points, pcd, ind, label_index, points, labels)
317
- # print(f"label_change{labels[4400]}")
318
-
319
- return labels
320
-
321
-
322
- # 消除离群点,保存最后的输出
323
- def remove_outlier_main(jaw, pcd_points, labels, instances_labels):
324
- # point_cloud_o3d_orign = o3d.io.read_point_cloud('E:/tooth/data/MeshSegNet-master/test_upsample_15/upsample_01K17AN8_upper_refined.pcd')
325
- # 原始点
326
- points = pcd_points.copy()
327
- label = remove_outlier(points, labels)
328
-
329
- # 保存json文件
330
- label_dict = {}
331
- label_dict["id_patient"] = ""
332
- label_dict["jaw"] = jaw
333
- label_dict["labels"] = label.tolist()
334
- label_dict["instances"] = instances_labels.tolist()
335
- b = json.dumps(label_dict)
336
- with open('dental-labels4' + '.json', 'w') as f_obj:
337
- f_obj.write(b)
338
- f_obj.close()
339
-
340
-
341
- same_points_list = {}
342
-
343
-
344
- # 体素下采样
345
- def voxel_filter(point_cloud, leaf_size):
346
- same_points_list = {}
347
- filtered_points = []
348
- # step1 计算边界点
349
- x_max, y_max, z_max = np.amax(point_cloud, axis=0) # 计算 x,y,z三个维度的最值
350
- x_min, y_min, z_min = np.amin(point_cloud, axis=0)
351
-
352
- # step2 确定体素的尺寸
353
- size_r = leaf_size
354
-
355
- # step3 计算每个 volex的维度 voxel grid
356
- Dx = (x_max - x_min) // size_r + 1
357
- Dy = (y_max - y_min) // size_r + 1
358
- Dz = (z_max - z_min) // size_r + 1
359
-
360
- # print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
361
-
362
- # step4 计算每个点在volex grid内每一个维度的值
363
- h = list() # h 为保存索引的列表
364
- for i in range(len(point_cloud)):
365
- hx = np.floor((point_cloud[i][0] - x_min) // size_r)
366
- hy = np.floor((point_cloud[i][1] - y_min) // size_r)
367
- hz = np.floor((point_cloud[i][2] - z_min) // size_r)
368
- h.append(hx + hy * Dx + hz * Dx * Dy)
369
- # print(h[60581])
370
-
371
- # step5 对h值进行排序
372
- h = np.array(h)
373
- h_indice = np.argsort(h) # 提取索引,返回h里面的元素按从小到大排序的 索引
374
- h_sorted = h[h_indice] # 升序
375
- count = 0 # 用于维度的累计
376
- step = 20
377
- # 将h值相同的点放入到同一个grid中,并进行筛选
378
- for i in range(1, len(h_sorted)): # 0-19999个数据点
379
- # if i == len(h_sorted)-1:
380
- # print("aaa")
381
- if h_sorted[i] == h_sorted[i - 1] and (i != len(h_sorted) - 1):
382
- continue
383
- elif h_sorted[i] == h_sorted[i - 1] and (i == len(h_sorted) - 1):
384
- point_idx = h_indice[count:]
385
- key = h_sorted[i - 1]
386
- same_points_list[key] = point_idx
387
- _G = np.mean(point_cloud[point_idx], axis=0) # 所有点的重心
388
- _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2) # 计算到重心的距离
389
- _d.sort()
390
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
391
- for j in inx:
392
- index = point_idx[j]
393
- filtered_points.append(point_cloud[index])
394
- count = i
395
- elif h_sorted[i] != h_sorted[i - 1] and (i == len(h_sorted) - 1):
396
- point_idx1 = h_indice[count:i]
397
- key1 = h_sorted[i - 1]
398
- same_points_list[key1] = point_idx1
399
- _G = np.mean(point_cloud[point_idx1], axis=0) # 所有点的重心
400
- _d = np.linalg.norm(point_cloud[point_idx1] - _G, axis=1, ord=2) # 计算到重心的距离
401
- _d.sort()
402
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
403
- for j in inx:
404
- index = point_idx1[j]
405
- filtered_points.append(point_cloud[index])
406
-
407
- point_idx2 = h_indice[i:]
408
- key2 = h_sorted[i]
409
- same_points_list[key2] = point_idx2
410
- _G = np.mean(point_cloud[point_idx2], axis=0) # 所有点的重心
411
- _d = np.linalg.norm(point_cloud[point_idx2] - _G, axis=1, ord=2) # 计算到重心的距离
412
- _d.sort()
413
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
414
- for j in inx:
415
- index = point_idx2[j]
416
- filtered_points.append(point_cloud[index])
417
- count = i
418
-
419
- else:
420
- point_idx = h_indice[count: i]
421
- key = h_sorted[i - 1]
422
- same_points_list[key] = point_idx
423
- _G = np.mean(point_cloud[point_idx], axis=0) # 所有点的重心
424
- _d = np.linalg.norm(point_cloud[point_idx] - _G, axis=1, ord=2) # 计算到重心的距离
425
- _d.sort()
426
- inx = [j for j in range(0, len(_d), step)] # 获取指定间隔元素下标
427
- for j in inx:
428
- index = point_idx[j]
429
- filtered_points.append(point_cloud[index])
430
- count = i
431
-
432
- # 把点云格式改成array,并对外返回
433
- # print(f'filtered_points[0]为{filtered_points[0]}')
434
- filtered_points = np.array(filtered_points, dtype=np.float64)
435
- return filtered_points,same_points_list
436
-
437
-
438
- # 体素上采样
439
- def voxel_upsample(same_points_list, point_cloud, filtered_points, filter_labels, leaf_size):
440
- upsample_label = []
441
- upsample_point = []
442
- upsample_index = []
443
- # step1 计算边界点
444
- x_max, y_max, z_max = np.amax(point_cloud, axis=0) # 计算 x,y,z三个维度的最值
445
- x_min, y_min, z_min = np.amin(point_cloud, axis=0)
446
- # step2 确定体素的尺寸
447
- size_r = leaf_size
448
- # step3 计算每个 volex的维度 voxel grid
449
- Dx = (x_max - x_min) // size_r + 1
450
- Dy = (y_max - y_min) // size_r + 1
451
- Dz = (z_max - z_min) // size_r + 1
452
- print("Dx x Dy x Dz is {} x {} x {}".format(Dx, Dy, Dz))
453
-
454
- # step4 计算每个点(采样后的点)在volex grid内每一个维度的值
455
- h = list()
456
- for i in range(len(filtered_points)):
457
- hx = np.floor((filtered_points[i][0] - x_min) // size_r)
458
- hy = np.floor((filtered_points[i][1] - y_min) // size_r)
459
- hz = np.floor((filtered_points[i][2] - z_min) // size_r)
460
- h.append(hx + hy * Dx + hz * Dx * Dy)
461
-
462
- # step5 根据h值查询字典same_points_list
463
- h = np.array(h)
464
- count = 0
465
- for i in range(1, len(h)):
466
- if h[i] == h[i - 1] and i != (len(h) - 1):
467
- continue
468
- elif h[i] == h[i - 1] and i == (len(h) - 1):
469
- label = filter_labels[count:]
470
- key = h[i - 1]
471
- count = i
472
- # 累计label次数,classcount:{‘A’:2,'B':1}
473
- classcount = {}
474
- for i in range(len(label)):
475
- vote = label[i]
476
- classcount[vote] = classcount.get(vote, 0) + 1
477
- # 对map的value排序
478
- sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
479
- # key = h[i-1]
480
- point_index = same_points_list[key] # h对应的point index列表
481
- for j in range(len(point_index)):
482
- upsample_label.append(sortedclass[0][0])
483
- index = point_index[j]
484
- upsample_point.append(point_cloud[index])
485
- upsample_index.append(index)
486
- elif h[i] != h[i - 1] and (i == len(h) - 1):
487
- label1 = filter_labels[count:i]
488
- key1 = h[i - 1]
489
- label2 = filter_labels[i:]
490
- key2 = h[i]
491
- count = i
492
-
493
- classcount = {}
494
- for i in range(len(label1)):
495
- vote = label1[i]
496
- classcount[vote] = classcount.get(vote, 0) + 1
497
- sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
498
- # key1 = h[i-1]
499
- point_index = same_points_list[key1]
500
- for j in range(len(point_index)):
501
- upsample_label.append(sortedclass[0][0])
502
- index = point_index[j]
503
- upsample_point.append(point_cloud[index])
504
- upsample_index.append(index)
505
-
506
- # label2 = filter_labels[i:]
507
- classcount = {}
508
- for i in range(len(label2)):
509
- vote = label2[i]
510
- classcount[vote] = classcount.get(vote, 0) + 1
511
- sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
512
- # key2 = h[i]
513
- point_index = same_points_list[key2]
514
- for j in range(len(point_index)):
515
- upsample_label.append(sortedclass[0][0])
516
- index = point_index[j]
517
- upsample_point.append(point_cloud[index])
518
- upsample_index.append(index)
519
- else:
520
- label = filter_labels[count:i]
521
- key = h[i - 1]
522
- count = i
523
- classcount = {}
524
- for i in range(len(label)):
525
- vote = label[i]
526
- classcount[vote] = classcount.get(vote, 0) + 1
527
- sortedclass = sorted(classcount.items(), key=lambda x: (x[1]), reverse=True)
528
- # key = h[i-1]
529
- point_index = same_points_list[key] # h对应的point index列表
530
- for j in range(len(point_index)):
531
- upsample_label.append(sortedclass[0][0])
532
- index = point_index[j]
533
- upsample_point.append(point_cloud[index])
534
- upsample_index.append(index)
535
- # count = i
536
-
537
- # 恢复原始顺序
538
- # print(f'upsample_index[0]的值为{upsample_index[0]}')
539
- # print(f'upsample_index的总长度为{len(upsample_index)}')
540
-
541
- # 恢复index原始顺序
542
- upsample_index = np.array(upsample_index)
543
- upsample_index_indice = np.argsort(upsample_index) # 提取索引,返回h里面的元素按从小到大排序的 索引
544
- upsample_index_sorted = upsample_index[upsample_index_indice]
545
-
546
- upsample_point = np.array(upsample_point)
547
- upsample_label = np.array(upsample_label)
548
- # 恢复point和label的原始顺序
549
- upsample_point_sorted = upsample_point[upsample_index_indice]
550
- upsample_label_sorted = upsample_label[upsample_index_indice]
551
-
552
- return upsample_point_sorted, upsample_label_sorted
553
-
554
-
555
- # 利用knn算法上采样
556
- def KNN_sklearn_Load_data(voxel_points, center_points, labels):
557
- # 载入数据
558
- # x_train, x_test, y_train, y_test = train_test_split(center_points, labels, test_size=0.1)
559
- # 构建模型
560
- model = neighbors.KNeighborsClassifier(n_neighbors=3)
561
- model.fit(center_points, labels)
562
- prediction = model.predict(voxel_points.reshape(1, -1))
563
- # meshtopoints_labels = classification_report(voxel_points, prediction)
564
- return prediction[0]
565
-
566
-
567
- # 加载点进行knn上采样
568
- def Load_data(voxel_points, center_points, labels):
569
- meshtopoints_labels = []
570
- # meshtopoints_labels.append(SVC_sklearn_Load_data(voxel_points[i], center_points, labels))
571
- for i in range(0, voxel_points.shape[0]):
572
- meshtopoints_labels.append(KNN_sklearn_Load_data(voxel_points[i], center_points, labels))
573
- return np.array(meshtopoints_labels)
574
-
575
- # 将三角网格数据上采样回原始点云数据
576
- def mesh_to_points_main(jaw, pcd_points, center_points, labels):
577
- points = pcd_points.copy()
578
- # 下采样
579
- voxel_points, same_points_list = voxel_filter(points, 0.6)
580
-
581
- after_labels = Load_data(voxel_points, center_points, labels)
582
-
583
- upsample_point, upsample_label = voxel_upsample(same_points_list, points, voxel_points, after_labels, 0.6)
584
-
585
- new_pcd = o3d.geometry.PointCloud()
586
- new_pcd.points = o3d.utility.Vector3dVector(upsample_point)
587
- instances_labels = upsample_label.copy()
588
- # '''
589
- # o3d.io.write_point_cloud(os.path.join(save_path, 'upsample_' + name + '.pcd'), new_pcd, write_ascii=True)
590
- for i in stqdm(range(0, upsample_label.shape[0])):
591
- if jaw == 'upper':
592
- if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
593
- upsample_label[i] = upsample_label[i] + 10
594
- elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
595
- upsample_label[i] = upsample_label[i] + 12
596
- else:
597
- if (upsample_label[i] >= 1) and (upsample_label[i] <= 8):
598
- upsample_label[i] = upsample_label[i] + 30
599
- elif (upsample_label[i] >= 9) and (upsample_label[i] <= 16):
600
- upsample_label[i] = upsample_label[i] + 32
601
- remove_outlier_main(jaw, pcd_points, upsample_label, instances_labels)
602
-
603
-
604
- # 将原始点云数据转换为三角网格
605
- def mesh_grid(pcd_points):
606
- new_pcd,_ = voxel_filter(pcd_points, 0.6)
607
- # pcd需要有法向量
608
-
609
- # estimate radius for rolling ball
610
- pcd_new = o3d.geometry.PointCloud()
611
- pcd_new.points = o3d.utility.Vector3dVector(new_pcd)
612
- pcd_new.estimate_normals()
613
- distances = pcd_new.compute_nearest_neighbor_distance()
614
- avg_dist = np.mean(distances)
615
- radius = 6 * avg_dist
616
- mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
617
- pcd_new,
618
- o3d.utility.DoubleVector([radius, radius * 2]))
619
- # o3d.io.write_triangle_mesh("./tooth date/test.ply", mesh)
620
-
621
- return mesh
622
-
623
-
624
- # 读取obj文件内容
625
- def read_obj(obj_path):
626
- jaw = None
627
- with open(obj_path) as file:
628
- points = []
629
- faces = []
630
- while 1:
631
- line = file.readline()
632
- if not line:
633
- break
634
- strs = line.split(" ")
635
- if strs[0] == "v":
636
- points.append((float(strs[1]), float(strs[2]), float(strs[3])))
637
- elif strs[0] == "f":
638
- faces.append((int(strs[1]), int(strs[2]), int(strs[3])))
639
- elif strs[1][0:5] == 'lower':
640
- jaw = 'lower'
641
- elif strs[1][0:5] == 'upper':
642
- jaw = 'upper'
643
-
644
- points = np.array(points)
645
- faces = np.array(faces)
646
-
647
- if jaw is None:
648
- raise ValueError("Jaw type not found in OBJ file")
649
-
650
- return points, faces, jaw
651
-
652
-
653
- # obj文件转为pcd文件
654
- def obj2pcd(obj_path):
655
- if os.path.exists(obj_path):
656
- print('yes')
657
- points, _, jaw = read_obj(obj_path)
658
- pcd_list = []
659
- num_points = np.shape(points)[0]
660
- for i in range(num_points):
661
- new_line = str(points[i, 0]) + ' ' + str(points[i, 1]) + ' ' + str(points[i, 2])
662
- pcd_list.append(new_line.split())
663
-
664
- pcd_points = np.array(pcd_list).astype(np.float64)
665
- return pcd_points, jaw
666
-
667
-
668
- def segmentation_main(obj_path):
669
- upsampling_method = 'KNN'
670
-
671
- model_path = 'Mesh_Segementation_MeshSegNet_17_classes_60samples_best.tar'
672
- num_classes = 17
673
- num_channels = 15
674
-
675
- # set model
676
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
677
- model = MeshSegNet(num_classes=num_classes, num_channels=num_channels).to(device, dtype=torch.float)
678
-
679
- # load trained model
680
- # checkpoint = torch.load(os.path.join(model_path, model_name), map_location='cpu')
681
- checkpoint = torch.load(model_path, map_location='cpu')
682
- model.load_state_dict(checkpoint['model_state_dict'])
683
- del checkpoint
684
- model = model.to(device, dtype=torch.float)
685
-
686
- # cudnn
687
- torch.backends.cudnn.benchmark = True
688
- torch.backends.cudnn.enabled = True
689
-
690
- # Predicting
691
- model.eval()
692
- with torch.no_grad():
693
- pcd_points, jaw = obj2pcd(obj_path)
694
- mesh = mesh_grid(pcd_points)
695
-
696
- # move mesh to origin
697
- with st.spinner("Patience please, AI at work. Grab a coffee while you wait ☕."):
698
- vertices_points = np.asarray(mesh.vertices)
699
- triangles_points = np.asarray(mesh.triangles)
700
- N = triangles_points.shape[0]
701
- cells = np.zeros((triangles_points.shape[0], 9))
702
- cells = vertices_points[triangles_points].reshape(triangles_points.shape[0], 9)
703
-
704
- mean_cell_centers = mesh.get_center()
705
- cells[:, 0:3] -= mean_cell_centers[0:3]
706
- cells[:, 3:6] -= mean_cell_centers[0:3]
707
- cells[:, 6:9] -= mean_cell_centers[0:3]
708
-
709
- v1 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
710
- v2 = np.zeros([triangles_points.shape[0], 3], dtype='float32')
711
- v1[:, 0] = cells[:, 0] - cells[:, 3]
712
- v1[:, 1] = cells[:, 1] - cells[:, 4]
713
- v1[:, 2] = cells[:, 2] - cells[:, 5]
714
- v2[:, 0] = cells[:, 3] - cells[:, 6]
715
- v2[:, 1] = cells[:, 4] - cells[:, 7]
716
- v2[:, 2] = cells[:, 5] - cells[:, 8]
717
- mesh_normals = np.cross(v1, v2)
718
- mesh_normal_length = np.linalg.norm(mesh_normals, axis=1)
719
- mesh_normals[:, 0] /= mesh_normal_length[:]
720
- mesh_normals[:, 1] /= mesh_normal_length[:]
721
- mesh_normals[:, 2] /= mesh_normal_length[:]
722
-
723
- # prepare input
724
- points = vertices_points.copy()
725
- points[:, 0:3] -= mean_cell_centers[0:3]
726
- normals = np.nan_to_num(mesh_normals).copy()
727
- barycenters = np.zeros((triangles_points.shape[0], 3))
728
- s = np.sum(vertices_points[triangles_points], 1)
729
- barycenters = 1 / 3 * s
730
- center_points = barycenters.copy()
731
- barycenters -= mean_cell_centers[0:3]
732
-
733
- # normalized data
734
- maxs = points.max(axis=0)
735
- mins = points.min(axis=0)
736
- means = points.mean(axis=0)
737
- stds = points.std(axis=0)
738
- nmeans = normals.mean(axis=0)
739
- nstds = normals.std(axis=0)
740
-
741
- for i in range(3):
742
- cells[:, i] = (cells[:, i] - means[i]) / stds[i] # point 1
743
- cells[:, i + 3] = (cells[:, i + 3] - means[i]) / stds[i] # point 2
744
- cells[:, i + 6] = (cells[:, i + 6] - means[i]) / stds[i] # point 3
745
- barycenters[:, i] = (barycenters[:, i] - mins[i]) / (maxs[i] - mins[i])
746
- normals[:, i] = (normals[:, i] - nmeans[i]) / nstds[i]
747
-
748
- X = np.column_stack((cells, barycenters, normals))
749
-
750
- # computing A_S and A_L
751
- A_S = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
752
- A_L = np.zeros([X.shape[0], X.shape[0]], dtype='float32')
753
- D = distance_matrix(X[:, 9:12], X[:, 9:12])
754
- A_S[D < 0.1] = 1.0
755
- A_S = A_S / np.dot(np.sum(A_S, axis=1, keepdims=True), np.ones((1, X.shape[0])))
756
-
757
- A_L[D < 0.2] = 1.0
758
- A_L = A_L / np.dot(np.sum(A_L, axis=1, keepdims=True), np.ones((1, X.shape[0])))
759
-
760
- # numpy -> torch.tensor
761
- X = X.transpose(1, 0)
762
- X = X.reshape([1, X.shape[0], X.shape[1]])
763
- X = torch.from_numpy(X).to(device, dtype=torch.float)
764
- A_S = A_S.reshape([1, A_S.shape[0], A_S.shape[1]])
765
- A_L = A_L.reshape([1, A_L.shape[0], A_L.shape[1]])
766
- A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
767
- A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)
768
-
769
- tensor_prob_output = model(X, A_S, A_L).to(device, dtype=torch.float)
770
- patch_prob_output = tensor_prob_output.cpu().numpy()
771
-
772
- # refinement
773
- with st.spinner("Refining..."):
774
- round_factor = 100
775
- patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6
776
-
777
- # unaries
778
- unaries = -round_factor * np.log10(patch_prob_output)
779
- unaries = unaries.astype(np.int32)
780
- unaries = unaries.reshape(-1, num_classes)
781
-
782
- # parawisex
783
- pairwise = (1 - np.eye(num_classes, dtype=np.int32))
784
-
785
- cells = cells.copy()
786
-
787
- cell_ids = np.asarray(triangles_points)
788
- lambda_c = 20
789
- edges = np.empty([1, 3], order='C')
790
- for i_node in stqdm(range(cells.shape[0])):
791
- # Find neighbors
792
- nei = np.sum(np.isin(cell_ids, cell_ids[i_node, :]), axis=1)
793
- nei_id = np.where(nei == 2)
794
- for i_nei in nei_id[0][:]:
795
- if i_node < i_nei:
796
- cos_theta = np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]) / np.linalg.norm(
797
- normals[i_node, 0:3]) / np.linalg.norm(normals[i_nei, 0:3])
798
- if cos_theta >= 1.0:
799
- cos_theta = 0.9999
800
- theta = np.arccos(cos_theta)
801
- phi = np.linalg.norm(barycenters[i_node, :] - barycenters[i_nei, :])
802
- if theta > np.pi / 2.0:
803
- edges = np.concatenate(
804
- (edges, np.array([i_node, i_nei, -np.log10(theta / np.pi) * phi]).reshape(1, 3)), axis=0)
805
- else:
806
- beta = 1 + np.linalg.norm(np.dot(normals[i_node, 0:3], normals[i_nei, 0:3]))
807
- edges = np.concatenate(
808
- (edges, np.array([i_node, i_nei, -beta * np.log10(theta / np.pi) * phi]).reshape(1, 3)),
809
- axis=0)
810
- edges = np.delete(edges, 0, 0)
811
- edges[:, 2] *= lambda_c * round_factor
812
- edges = edges.astype(np.int32)
813
-
814
- refine_labels = cut_from_graph(edges, unaries, pairwise)
815
- refine_labels = refine_labels.reshape([-1, 1])
816
-
817
- predicted_labels_3 = refine_labels.reshape(refine_labels.shape[0])
818
- mesh_to_points_main(jaw, pcd_points, center_points, predicted_labels_3)
819
-
820
- import pyvista as pv
821
- with st.spinner("Rendering..."):
822
- # Load the .obj file
823
- mesh = pv.read('file.obj')
824
-
825
- # Load the JSON file
826
- with open('dental-labels4.json', 'r') as file:
827
- labels_data = json.load(file)
828
-
829
- # Assuming labels_data['labels'] is a list of labels
830
- labels = labels_data['labels']
831
-
832
- # Make sure the number of labels matches the number of vertices or faces
833
- assert len(labels) == mesh.n_points or len(labels) == mesh.n_cells
834
-
835
- # If labels correspond to vertices
836
- if len(labels) == mesh.n_points:
837
- mesh.point_data['Labels'] = labels
838
- # If labels correspond to faces
839
- elif len(labels) == mesh.n_cells:
840
- mesh.cell_data['Labels'] = labels
841
-
842
- # Create a pyvista plotter
843
- plotter = pv.Plotter()
844
-
845
- cmap = plt.cm.get_cmap('jet', 27) # Using a colormap with sufficient distinct colors
846
-
847
- colors = cmap(np.linspace(0, 1, 27)) # Generate colors
848
-
849
- # Convert colors to a format acceptable by PyVista
850
- colormap = mcolors.ListedColormap(colors)
851
-
852
- # Add the mesh to the plotter with labels as a scalar field
853
- #plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap='jet')
854
- plotter.add_mesh(mesh, scalars='Labels', show_scalar_bar=True, cmap=colormap, clim=[0, 27])
855
-
856
- # Show the plot
857
- #plotter.show()
858
- ## Send to streamlit
859
- with st.expander("**View Segmentation Result** - ", expanded=False):
860
- stpyvista(plotter)
861
-
862
  # Configure Streamlit page
863
- st.set_page_config(page_title="Teeth Segmentation", page_icon="🦷")
864
 
865
- class Segment(TeethApp):
866
  def __init__(self):
867
  TeethApp.__init__(self)
868
  self.build_app()
869
 
870
  def build_app(self):
871
-
872
- st.title("Segment Intra-oral Scans")
873
- st.markdown("Identify and segment teeth. Segmentation is performed using MeshSegNet, a deep learning model trained on both upper and lower jaws.")
874
-
875
- inputs = st.radio(
876
- "Select scan for segmentation:",
877
- ("Upload Scan", "Example Scan"),
878
- )
879
- import pyvista as pv
880
- if inputs == "Example Scan":
881
- st.markdown("Expected time per prediction: 7-10 min.")
882
- mesh = pv.read("ZOUIF2W4_upper.obj")
883
- plotter = pv.Plotter()
884
-
885
- # Add the mesh to the plotter
886
- plotter.add_mesh(mesh, color='white', show_edges=False)
887
- segment = st.button(
888
- "✔️ Submit",
889
- help="Submit 3D scan for segmentation",
890
- )
891
- with st.expander("View Scan", expanded=False):
892
- stpyvista(plotter)
893
-
894
- if segment:
895
- segmentation_main("ZOUIF2W4_upper.obj")
896
-
897
-
898
-
899
- elif inputs == "Upload Scan":
900
- file = st.file_uploader("Please upload an OBJ Object file", type=["OBJ"])
901
- st.markdown("Expected time per prediction: 7-10 min.")
902
- if file is not None:
903
- # save the uploaded file to disk
904
- with open("file.obj", "wb") as buffer:
905
- shutil.copyfileobj(file, buffer)
906
- # 复制数据
907
- obj_path = "file.obj"
908
-
909
- mesh = pv.read(obj_path)
910
- plotter = pv.Plotter()
911
-
912
- # Add the mesh to the plotter
913
- plotter.add_mesh(mesh, color='white', show_edges=False)
914
- segment = st.button(
915
- "✔️ Submit",
916
- help="Submit 3D scan for segmentation",
917
- )
918
- with st.expander("View Scan", expanded=False):
919
- stpyvista(plotter)
920
-
921
- if segment:
922
- segmentation_main(obj_path)
923
-
924
-
925
-
926
-
927
 
928
  if __name__ == "__main__":
929
- app = Segment()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from streamlit import session_state as session
3
 
4
  from PIL import Image
5
 
 
8
  # Font
9
  with open("utils/style.css") as css:
10
  st.markdown(f"<style>{css.read()}</style>", unsafe_allow_html=True)
11
+
12
  # Logo
13
  self.image_path = "utils/teeth-295404_1280.png"
14
  self.image = Image.open(self.image_path)
 
30
  unsafe_allow_html=True,
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  # Configure Streamlit page
34
+ st.set_page_config(page_title="Teeth Segmentation", page_icon="")
35
 
36
+ class Intro(TeethApp):
37
  def __init__(self):
38
  TeethApp.__init__(self)
39
  self.build_app()
40
 
41
  def build_app(self):
42
+ st.title("AI-assited Tooth Segmentation")
43
+ st.markdown("This app automatically segments intra-oral scans of teeth using machine learning.")
44
+ st.markdown("Head to the 'Segment' tab to try it out!")
45
+ st.markdown("**Example:**")
46
+ st.image("illu.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
+ app = Intro()