Spaces:
Build error
Build error
Ren Jiawei
commited on
Commit
•
ac0541e
1
Parent(s):
1c55e0d
update
Browse files
app.py
CHANGED
@@ -19,14 +19,14 @@ with open('shape_names.txt') as f:
|
|
19 |
|
20 |
model_gda = GDANET()
|
21 |
model_gda = nn.DataParallel(model_gda)
|
22 |
-
|
23 |
-
model_gda.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
24 |
model_gda.eval()
|
25 |
|
26 |
model_dgcnn = DGCNN()
|
27 |
model_dgcnn = nn.DataParallel(model_dgcnn)
|
28 |
-
|
29 |
-
model_dgcnn.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/dgcnn.t7', map_location=torch.device('cpu')))
|
30 |
model_dgcnn.eval()
|
31 |
|
32 |
def pyplot_draw_point_cloud(points, corruption):
|
@@ -68,11 +68,11 @@ def load_dataset(corruption_idx, severity):
|
|
68 |
]
|
69 |
corruption_type = corruptions[corruption_idx]
|
70 |
if corruption_type == 'clean':
|
71 |
-
|
72 |
-
f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '.h5'))
|
73 |
else:
|
74 |
-
|
75 |
-
f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '_{}'.format(severity - 1) + '.h5'))
|
76 |
data = f['data'][:].astype('float32')
|
77 |
label = f['label'][:].astype('int64')
|
78 |
f.close()
|
|
|
19 |
|
20 |
model_gda = GDANET()
|
21 |
model_gda = nn.DataParallel(model_gda)
|
22 |
+
model_gda.load_state_dict(torch.load('./GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
23 |
+
# model_gda.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/GDANet_WOLFMix.t7', map_location=torch.device('cpu')))
|
24 |
model_gda.eval()
|
25 |
|
26 |
model_dgcnn = DGCNN()
|
27 |
model_dgcnn = nn.DataParallel(model_dgcnn)
|
28 |
+
model_dgcnn.load_state_dict(torch.load('./dgcnn.t7', map_location=torch.device('cpu')))
|
29 |
+
# model_dgcnn.load_state_dict(torch.load('/Users/renjiawei/Downloads/pretrained_models/dgcnn.t7', map_location=torch.device('cpu')))
|
30 |
model_dgcnn.eval()
|
31 |
|
32 |
def pyplot_draw_point_cloud(points, corruption):
|
|
|
68 |
]
|
69 |
corruption_type = corruptions[corruption_idx]
|
70 |
if corruption_type == 'clean':
|
71 |
+
f = h5py.File(osp.join('modelnet_c', corruption_type + '.h5'))
|
72 |
+
# f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '.h5'))
|
73 |
else:
|
74 |
+
f = h5py.File(osp.join('modelnet_c', corruption_type + '_{}'.format(severity-1) + '.h5'))
|
75 |
+
# f = h5py.File(osp.join('/Users/renjiawei/Downloads/modelnet_c', corruption_type + '_{}'.format(severity - 1) + '.h5'))
|
76 |
data = f['data'][:].astype('float32')
|
77 |
label = f['label'][:].astype('int64')
|
78 |
f.close()
|