Spaces:
Sleeping
Sleeping
Upload 30 files
Browse files- cortex_DIM/__init__.py +0 -0
- cortex_DIM/__pycache__/__init__.cpython-39.pyc +0 -0
- cortex_DIM/configs/__init__.py +0 -0
- cortex_DIM/configs/convnets.py +98 -0
- cortex_DIM/configs/resnets.py +151 -0
- cortex_DIM/functions/__init__.py +0 -0
- cortex_DIM/functions/__pycache__/__init__.cpython-39.pyc +0 -0
- cortex_DIM/functions/__pycache__/gan_losses.cpython-39.pyc +0 -0
- cortex_DIM/functions/__pycache__/misc.cpython-39.pyc +0 -0
- cortex_DIM/functions/dim_losses.py +224 -0
- cortex_DIM/functions/gan_losses.py +95 -0
- cortex_DIM/functions/misc.py +39 -0
- cortex_DIM/nn_modules/__init__.py +0 -0
- cortex_DIM/nn_modules/__pycache__/__init__.cpython-39.pyc +0 -0
- cortex_DIM/nn_modules/__pycache__/mi_networks.cpython-39.pyc +0 -0
- cortex_DIM/nn_modules/__pycache__/misc.cpython-39.pyc +0 -0
- cortex_DIM/nn_modules/convnet.py +352 -0
- cortex_DIM/nn_modules/encoder.py +96 -0
- cortex_DIM/nn_modules/mi_networks.py +106 -0
- cortex_DIM/nn_modules/misc.py +130 -0
- cortex_DIM/nn_modules/resnet.py +297 -0
- gin.py +188 -0
- model.py +66 -0
- requirements.txt +8 -0
- utils/clean_data.py +192 -0
- utils/edge_features.py +113 -0
- utils/emb_model/Edge_64.pt +3 -0
- utils/emb_model/Node_64.pt +3 -0
- utils/node_features.py +123 -0
- utils/shape_features.py +625 -0
cortex_DIM/__init__.py
ADDED
File without changes
|
cortex_DIM/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (205 Bytes). View file
|
|
cortex_DIM/configs/__init__.py
ADDED
File without changes
|
cortex_DIM/configs/convnets.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Basic convnet hyperparameters.
|
2 |
+
|
3 |
+
conv_args are in format (dim_h, f_size, stride, pad batch_norm, dropout, nonlinearity, pool)
|
4 |
+
fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
|
5 |
+
|
6 |
+
'''
|
7 |
+
|
8 |
+
from cortex_DIM.nn_modules.encoder import ConvnetEncoder, FoldedConvnetEncoder
|
9 |
+
|
10 |
+
|
11 |
+
# Basic DCGAN-like encoders
|
12 |
+
|
13 |
+
_basic28x28 = dict(
|
14 |
+
Encoder=ConvnetEncoder,
|
15 |
+
conv_args=[(64, 5, 2, 2, True, False, 'ReLU', None),
|
16 |
+
(128, 5, 2, 2, True, False, 'ReLU', None)],
|
17 |
+
fc_args=[(1024, True, False, 'ReLU', None)],
|
18 |
+
local_idx=1,
|
19 |
+
fc_idx=0
|
20 |
+
)
|
21 |
+
|
22 |
+
_basic32x32 = dict(
|
23 |
+
Encoder=ConvnetEncoder,
|
24 |
+
conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
|
25 |
+
(128, 4, 2, 1, True, False, 'ReLU', None),
|
26 |
+
(256, 4, 2, 1, True, False, 'ReLU', None)],
|
27 |
+
fc_args=[(1024, True, False, 'ReLU')],
|
28 |
+
local_idx=1,
|
29 |
+
conv_idx=2,
|
30 |
+
fc_idx=0
|
31 |
+
)
|
32 |
+
|
33 |
+
_basic64x64 = dict(
|
34 |
+
Encoder=ConvnetEncoder,
|
35 |
+
conv_args=[(64, 4, 2, 1, True, False, 'ReLU', None),
|
36 |
+
(128, 4, 2, 1, True, False, 'ReLU', None),
|
37 |
+
(256, 4, 2, 1, True, False, 'ReLU', None),
|
38 |
+
(512, 4, 2, 1, True, False, 'ReLU', None)],
|
39 |
+
fc_args=[(1024, True, False, 'ReLU')],
|
40 |
+
local_idx=2,
|
41 |
+
conv_idx=3,
|
42 |
+
fc_idx=0
|
43 |
+
)
|
44 |
+
|
45 |
+
# Alexnet-like encoders
|
46 |
+
|
47 |
+
_alex64x64 = dict(
|
48 |
+
Encoder=ConvnetEncoder,
|
49 |
+
conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
50 |
+
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
51 |
+
(384, 3, 1, 1, True, False, 'ReLU', None),
|
52 |
+
(384, 3, 1, 1, True, False, 'ReLU', None),
|
53 |
+
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))],
|
54 |
+
fc_args=[(4096, True, False, 'ReLU'),
|
55 |
+
(4096, True, False, 'ReLU')],
|
56 |
+
local_idx=2,
|
57 |
+
conv_idx=4,
|
58 |
+
fc_idx=1
|
59 |
+
)
|
60 |
+
|
61 |
+
_foldalex64x64 = dict(
|
62 |
+
Encoder=FoldedConvnetEncoder,
|
63 |
+
crop_size=16,
|
64 |
+
conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
65 |
+
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
66 |
+
(384, 3, 1, 1, True, False, 'ReLU', None),
|
67 |
+
(384, 3, 1, 1, True, False, 'ReLU', None),
|
68 |
+
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2))],
|
69 |
+
fc_args=[(4096, True, False, 'ReLU'),
|
70 |
+
(4096, True, False, 'ReLU')],
|
71 |
+
local_idx=4,
|
72 |
+
fc_idx=1
|
73 |
+
)
|
74 |
+
|
75 |
+
_foldmultialex64x64 = dict(
|
76 |
+
Encoder=FoldedConvnetEncoder,
|
77 |
+
crop_size=16,
|
78 |
+
conv_args=[(96, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
79 |
+
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
80 |
+
(384, 3, 1, 1, True, False, 'ReLU', None),
|
81 |
+
(384, 3, 1, 1, True, False, 'ReLU', None),
|
82 |
+
(192, 3, 1, 1, True, False, 'ReLU', ('MaxPool2d', 3, 2)),
|
83 |
+
(192, 3, 1, 0, True, False, 'ReLU', None),
|
84 |
+
(192, 1, 1, 0, True, False, 'ReLU', None)],
|
85 |
+
fc_args=[(4096, True, False, 'ReLU')],
|
86 |
+
local_idx=4,
|
87 |
+
multi_idx=6,
|
88 |
+
fc_idx=1
|
89 |
+
)
|
90 |
+
|
91 |
+
configs = dict(
|
92 |
+
basic28x28=_basic28x28,
|
93 |
+
basic32x32=_basic32x32,
|
94 |
+
basic64x64=_basic64x64,
|
95 |
+
alex64x64=_alex64x64,
|
96 |
+
foldalex64x64=_foldalex64x64,
|
97 |
+
foldmultialex64x64=_foldmultialex64x64
|
98 |
+
)
|
cortex_DIM/configs/resnets.py
ADDED
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Configurations for ResNets
|
2 |
+
|
3 |
+
"""
|
4 |
+
|
5 |
+
from cortex_DIM.nn_modules.encoder import ResnetEncoder, FoldedResnetEncoder
|
6 |
+
|
7 |
+
|
8 |
+
_resnet19_32x32 = dict(
|
9 |
+
Encoder=ResnetEncoder,
|
10 |
+
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
|
11 |
+
res_args=[
|
12 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
13 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
14 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
15 |
+
1),
|
16 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
17 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
18 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
19 |
+
1),
|
20 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
21 |
+
(128, 3, 2, 1, True, False, 'ReLU', None),
|
22 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
23 |
+
1),
|
24 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
25 |
+
(128, 3, 1, 1, True, False, 'ReLU', None),
|
26 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
27 |
+
1),
|
28 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
29 |
+
(256, 3, 2, 1, True, False, 'ReLU', None),
|
30 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
31 |
+
1),
|
32 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
33 |
+
(256, 3, 1, 1, True, False, 'ReLU', None),
|
34 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
35 |
+
1)
|
36 |
+
],
|
37 |
+
fc_args=[(1024, True, False, 'ReLU')],
|
38 |
+
local_idx=4,
|
39 |
+
fc_idx=0
|
40 |
+
)
|
41 |
+
|
42 |
+
_foldresnet19_32x32 = dict(
|
43 |
+
Encoder=FoldedResnetEncoder,
|
44 |
+
crop_size=8,
|
45 |
+
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
|
46 |
+
res_args=[
|
47 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
48 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
49 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
50 |
+
1),
|
51 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
52 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
53 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
54 |
+
1),
|
55 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
56 |
+
(128, 3, 2, 1, True, False, 'ReLU', None),
|
57 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
58 |
+
1),
|
59 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
60 |
+
(128, 3, 1, 1, True, False, 'ReLU', None),
|
61 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
62 |
+
1),
|
63 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
64 |
+
(256, 3, 2, 1, True, False, 'ReLU', None),
|
65 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
66 |
+
1),
|
67 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
68 |
+
(256, 3, 1, 1, True, False, 'ReLU', None),
|
69 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
70 |
+
1)
|
71 |
+
],
|
72 |
+
fc_args=[(1024, True, False, 'ReLU')],
|
73 |
+
local_idx=6,
|
74 |
+
fc_idx=0
|
75 |
+
)
|
76 |
+
|
77 |
+
_resnet34_32x32 = dict(
|
78 |
+
Encoder=ResnetEncoder,
|
79 |
+
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
|
80 |
+
res_args=[
|
81 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
82 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
83 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
84 |
+
1),
|
85 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
86 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
87 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
88 |
+
2),
|
89 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
90 |
+
(128, 3, 2, 1, True, False, 'ReLU', None),
|
91 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
92 |
+
1),
|
93 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
94 |
+
(128, 3, 1, 1, True, False, 'ReLU', None),
|
95 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
96 |
+
5),
|
97 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
98 |
+
(256, 3, 2, 1, True, False, 'ReLU', None),
|
99 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
100 |
+
1),
|
101 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
102 |
+
(256, 3, 1, 1, True, False, 'ReLU', None),
|
103 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
104 |
+
2)
|
105 |
+
],
|
106 |
+
fc_args=[(1024, True, False, 'ReLU')],
|
107 |
+
local_idx=2,
|
108 |
+
fc_idx=0
|
109 |
+
)
|
110 |
+
|
111 |
+
_foldresnet34_32x32 = dict(
|
112 |
+
Encoder=FoldedResnetEncoder,
|
113 |
+
crop_size=8,
|
114 |
+
conv_before_args=[(64, 3, 2, 1, True, False, 'ReLU', None)],
|
115 |
+
res_args=[
|
116 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
117 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
118 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
119 |
+
1),
|
120 |
+
([(64, 1, 1, 0, True, False, 'ReLU', None),
|
121 |
+
(64, 3, 1, 1, True, False, 'ReLU', None),
|
122 |
+
(64 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
123 |
+
2),
|
124 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
125 |
+
(128, 3, 2, 1, True, False, 'ReLU', None),
|
126 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
127 |
+
1),
|
128 |
+
([(128, 1, 1, 0, True, False, 'ReLU', None),
|
129 |
+
(128, 3, 1, 1, True, False, 'ReLU', None),
|
130 |
+
(128 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
131 |
+
5),
|
132 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
133 |
+
(256, 3, 2, 1, True, False, 'ReLU', None),
|
134 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
135 |
+
1),
|
136 |
+
([(256, 1, 1, 0, True, False, 'ReLU', None),
|
137 |
+
(256, 3, 1, 1, True, False, 'ReLU', None),
|
138 |
+
(256 * 4, 1, 1, 0, True, False, 'ReLU', None)],
|
139 |
+
2)
|
140 |
+
],
|
141 |
+
fc_args=[(1024, True, False, 'ReLU')],
|
142 |
+
local_idx=12,
|
143 |
+
fc_idx=0
|
144 |
+
)
|
145 |
+
|
146 |
+
configs = dict(
|
147 |
+
resnet19_32x32=_resnet19_32x32,
|
148 |
+
resnet34_32x32=_resnet34_32x32,
|
149 |
+
foldresnet19_32x32=_foldresnet19_32x32,
|
150 |
+
foldresnet34_32x32=_foldresnet34_32x32
|
151 |
+
)
|
cortex_DIM/functions/__init__.py
ADDED
File without changes
|
cortex_DIM/functions/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (215 Bytes). View file
|
|
cortex_DIM/functions/__pycache__/gan_losses.cpython-39.pyc
ADDED
Binary file (2.2 kB). View file
|
|
cortex_DIM/functions/__pycache__/misc.cpython-39.pyc
ADDED
Binary file (1.04 kB). View file
|
|
cortex_DIM/functions/dim_losses.py
ADDED
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''cortex_DIM losses.
|
2 |
+
|
3 |
+
'''
|
4 |
+
|
5 |
+
import math
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from cortex_DIM.functions.gan_losses import get_positive_expectation, get_negative_expectation
|
11 |
+
|
12 |
+
|
13 |
+
def fenchel_dual_loss(l, g, measure=None):
|
14 |
+
'''Computes the f-divergence distance between positive and negative joint distributions.
|
15 |
+
|
16 |
+
Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
|
17 |
+
Squared Hellinger `H2`, Chi-squeared `X2`, `KL`, and reverse KL `RKL`.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
l: Local feature map.
|
21 |
+
g: Global features.
|
22 |
+
measure: f-divergence measure.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
torch.Tensor: Loss.
|
26 |
+
|
27 |
+
'''
|
28 |
+
N, local_units, n_locs = l.size()
|
29 |
+
l = l.permute(0, 2, 1)
|
30 |
+
l = l.reshape(-1, local_units)
|
31 |
+
|
32 |
+
u = torch.mm(g, l.t())
|
33 |
+
u = u.reshape(N, N, -1)
|
34 |
+
mask = torch.eye(N).cuda()
|
35 |
+
n_mask = 1 - mask
|
36 |
+
|
37 |
+
E_pos = get_positive_expectation(u, measure, average=False).mean(2)
|
38 |
+
E_neg = get_negative_expectation(u, measure, average=False).mean(2)
|
39 |
+
E_pos = (E_pos * mask).sum() / mask.sum()
|
40 |
+
E_neg = (E_neg * n_mask).sum() / n_mask.sum()
|
41 |
+
loss = E_neg - E_pos
|
42 |
+
return loss
|
43 |
+
|
44 |
+
|
45 |
+
def multi_fenchel_dual_loss(l, m, measure=None):
|
46 |
+
'''Computes the f-divergence distance between positive and negative joint distributions.
|
47 |
+
|
48 |
+
Used for multiple globals.
|
49 |
+
|
50 |
+
Divergences supported are Jensen-Shannon `JSD`, `GAN` (equivalent to JSD),
|
51 |
+
Squared Hellinger `H2`, Chi-squeared `X2`, `KL`, and reverse KL `RKL`.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
l: Local feature map.
|
55 |
+
m: Multiple globals feature map.
|
56 |
+
measure: f-divergence measure.
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
torch.Tensor: Loss.
|
60 |
+
|
61 |
+
'''
|
62 |
+
N, units, n_locals = l.size()
|
63 |
+
n_multis = m.size(2)
|
64 |
+
|
65 |
+
l = l.view(N, units, n_locals)
|
66 |
+
l = l.permute(0, 2, 1)
|
67 |
+
l = l.reshape(-1, units)
|
68 |
+
|
69 |
+
m = m.view(N, units, n_multis)
|
70 |
+
m = m.permute(0, 2, 1)
|
71 |
+
m = m.reshape(-1, units)
|
72 |
+
|
73 |
+
u = torch.mm(m, l.t())
|
74 |
+
u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
|
75 |
+
|
76 |
+
mask = torch.eye(N).cuda()
|
77 |
+
n_mask = 1 - mask
|
78 |
+
|
79 |
+
E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
|
80 |
+
E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)
|
81 |
+
E_pos = (E_pos * mask).sum() / mask.sum()
|
82 |
+
E_neg = (E_neg * n_mask).sum() / n_mask.sum()
|
83 |
+
loss = E_neg - E_pos
|
84 |
+
return loss
|
85 |
+
|
86 |
+
|
87 |
+
def nce_loss(l, g):
|
88 |
+
'''Computes the noise contrastive estimation-based loss.
|
89 |
+
|
90 |
+
Args:
|
91 |
+
l: Local feature map.
|
92 |
+
g: Global features.
|
93 |
+
|
94 |
+
Returns:
|
95 |
+
torch.Tensor: Loss.
|
96 |
+
|
97 |
+
'''
|
98 |
+
N, local_units, n_locs = l.size()
|
99 |
+
l_p = l.permute(0, 2, 1)
|
100 |
+
u_p = torch.matmul(l_p, g.unsqueeze(dim=2))
|
101 |
+
|
102 |
+
l_n = l_p.reshape(-1, local_units)
|
103 |
+
u_n = torch.mm(g, l_n.t())
|
104 |
+
u_n = u_n.reshape(N, N, n_locs)
|
105 |
+
|
106 |
+
mask = torch.eye(N).unsqueeze(dim=2).cuda()
|
107 |
+
n_mask = 1 - mask
|
108 |
+
|
109 |
+
u_n = (n_mask * u_n) - (10. * (1 - n_mask)) # mask out "self" examples
|
110 |
+
u_n = u_n.reshape(N, -1).unsqueeze(dim=1).expand(-1, n_locs, -1)
|
111 |
+
|
112 |
+
pred_lgt = torch.cat([u_p, u_n], dim=2)
|
113 |
+
pred_log = F.log_softmax(pred_lgt, dim=2)
|
114 |
+
loss = -pred_log[:, :, 0].mean()
|
115 |
+
return loss
|
116 |
+
|
117 |
+
|
118 |
+
def multi_nce_loss(l, m):
|
119 |
+
'''
|
120 |
+
|
121 |
+
Used for multiple globals.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
l: Local feature map.
|
125 |
+
m: Multiple globals feature map.
|
126 |
+
|
127 |
+
Returns:
|
128 |
+
torch.Tensor: Loss.
|
129 |
+
|
130 |
+
'''
|
131 |
+
N, units, n_locals = l.size()
|
132 |
+
_, _ , n_multis = m.size()
|
133 |
+
|
134 |
+
l = l.view(N, units, n_locals)
|
135 |
+
m = m.view(N, units, n_multis)
|
136 |
+
l_p = l.permute(0, 2, 1)
|
137 |
+
m_p = m.permute(0, 2, 1)
|
138 |
+
u_p = torch.matmul(l_p, m).unsqueeze(2)
|
139 |
+
|
140 |
+
l_n = l_p.reshape(-1, units)
|
141 |
+
m_n = m_p.reshape(-1, units)
|
142 |
+
u_n = torch.mm(m_n, l_n.t())
|
143 |
+
u_n = u_n.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
|
144 |
+
|
145 |
+
mask = torch.eye(N)[:, :, None, None].cuda()
|
146 |
+
n_mask = 1 - mask
|
147 |
+
|
148 |
+
u_n = (n_mask * u_n) - (10. * (1 - n_mask)) # mask out "self" examples
|
149 |
+
u_n = u_n.reshape(N, N * n_locals, n_multis).unsqueeze(dim=1).expand(-1, n_locals, -1, -1)
|
150 |
+
|
151 |
+
pred_lgt = torch.cat([u_p, u_n], dim=2)
|
152 |
+
pred_log = F.log_softmax(pred_lgt, dim=2)
|
153 |
+
loss = -pred_log[:, :, 0].mean()
|
154 |
+
|
155 |
+
return loss
|
156 |
+
|
157 |
+
|
158 |
+
def donsker_varadhan_loss(l, g):
|
159 |
+
'''
|
160 |
+
|
161 |
+
Args:
|
162 |
+
l: Local feature map.
|
163 |
+
g: Global features.
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
torch.Tensor: Loss.
|
167 |
+
|
168 |
+
'''
|
169 |
+
N, local_units, n_locs = l.size()
|
170 |
+
l = l.permute(0, 2, 1)
|
171 |
+
l = l.reshape(-1, local_units)
|
172 |
+
|
173 |
+
u = torch.mm(g, l.t())
|
174 |
+
u = u.reshape(N, N, n_locs)
|
175 |
+
|
176 |
+
mask = torch.eye(N).cuda()
|
177 |
+
n_mask = (1 - mask)[:, :, None]
|
178 |
+
|
179 |
+
E_pos = (u.mean(2) * mask).sum() / mask.sum()
|
180 |
+
|
181 |
+
u -= 100 * (1 - n_mask)
|
182 |
+
u_max = torch.max(u)
|
183 |
+
E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum())
|
184 |
+
loss = E_neg - E_pos
|
185 |
+
return loss
|
186 |
+
|
187 |
+
|
188 |
+
def multi_donsker_varadhan_loss(l, m):
|
189 |
+
'''
|
190 |
+
|
191 |
+
Used for multiple globals.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
l: Local feature map.
|
195 |
+
m: Multiple globals feature map.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
torch.Tensor: Loss.
|
199 |
+
|
200 |
+
'''
|
201 |
+
N, units, n_locals = l.size()
|
202 |
+
n_multis = m.size(2)
|
203 |
+
|
204 |
+
l = l.view(N, units, n_locals)
|
205 |
+
l = l.permute(0, 2, 1)
|
206 |
+
l = l.reshape(-1, units)
|
207 |
+
|
208 |
+
m = m.view(N, units, n_multis)
|
209 |
+
m = m.permute(0, 2, 1)
|
210 |
+
m = m.reshape(-1, units)
|
211 |
+
|
212 |
+
u = torch.mm(m, l.t())
|
213 |
+
u = u.reshape(N, n_multis, N, n_locals).permute(0, 2, 3, 1)
|
214 |
+
|
215 |
+
mask = torch.eye(N).cuda()
|
216 |
+
n_mask = 1 - mask
|
217 |
+
|
218 |
+
E_pos = (u.mean(2) * mask).sum() / mask.sum()
|
219 |
+
|
220 |
+
u -= 100 * (1 - n_mask)
|
221 |
+
u_max = torch.max(u)
|
222 |
+
E_neg = torch.log((n_mask * torch.exp(u - u_max)).sum() + 1e-6) + u_max - math.log(n_mask.sum())
|
223 |
+
loss = E_neg - E_pos
|
224 |
+
return loss
|
cortex_DIM/functions/gan_losses.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from cortex_DIM.functions.misc import log_sum_exp
|
11 |
+
|
12 |
+
|
13 |
+
def raise_measure_error(measure):
|
14 |
+
supported_measures = ['GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1']
|
15 |
+
raise NotImplementedError(
|
16 |
+
'Measure `{}` not supported. Supported: {}'.format(measure,
|
17 |
+
supported_measures))
|
18 |
+
|
19 |
+
|
20 |
+
def get_positive_expectation(p_samples, measure, average=True):
|
21 |
+
"""Computes the positive part of a divergence / difference.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
p_samples: Positive samples.
|
25 |
+
measure: Measure to compute for.
|
26 |
+
average: Average the result over samples.
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
torch.Tensor
|
30 |
+
|
31 |
+
"""
|
32 |
+
log_2 = math.log(2.)
|
33 |
+
|
34 |
+
if measure == 'GAN':
|
35 |
+
Ep = - F.softplus(-p_samples)
|
36 |
+
elif measure == 'JSD':
|
37 |
+
Ep = log_2 - F.softplus(- p_samples)
|
38 |
+
elif measure == 'X2':
|
39 |
+
Ep = p_samples ** 2
|
40 |
+
elif measure == 'KL':
|
41 |
+
Ep = p_samples + 1.
|
42 |
+
elif measure == 'RKL':
|
43 |
+
Ep = -torch.exp(-p_samples)
|
44 |
+
elif measure == 'DV':
|
45 |
+
Ep = p_samples
|
46 |
+
elif measure == 'H2':
|
47 |
+
Ep = 1. - torch.exp(-p_samples)
|
48 |
+
elif measure == 'W1':
|
49 |
+
Ep = p_samples
|
50 |
+
else:
|
51 |
+
raise_measure_error(measure)
|
52 |
+
|
53 |
+
if average:
|
54 |
+
return Ep.mean()
|
55 |
+
else:
|
56 |
+
return Ep
|
57 |
+
|
58 |
+
|
59 |
+
def get_negative_expectation(q_samples, measure, average=True):
|
60 |
+
"""Computes the negative part of a divergence / difference.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
q_samples: Negative samples.
|
64 |
+
measure: Measure to compute for.
|
65 |
+
average: Average the result over samples.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
torch.Tensor
|
69 |
+
|
70 |
+
"""
|
71 |
+
log_2 = math.log(2.)
|
72 |
+
|
73 |
+
if measure == 'GAN':
|
74 |
+
Eq = F.softplus(-q_samples) + q_samples
|
75 |
+
elif measure == 'JSD':
|
76 |
+
Eq = F.softplus(-q_samples) + q_samples - log_2
|
77 |
+
elif measure == 'X2':
|
78 |
+
Eq = -0.5 * ((torch.sqrt(q_samples ** 2) + 1.) ** 2)
|
79 |
+
elif measure == 'KL':
|
80 |
+
Eq = torch.exp(q_samples)
|
81 |
+
elif measure == 'RKL':
|
82 |
+
Eq = q_samples - 1.
|
83 |
+
elif measure == 'DV':
|
84 |
+
Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
|
85 |
+
elif measure == 'H2':
|
86 |
+
Eq = torch.exp(q_samples) - 1.
|
87 |
+
elif measure == 'W1':
|
88 |
+
Eq = q_samples
|
89 |
+
else:
|
90 |
+
raise_measure_error(measure)
|
91 |
+
|
92 |
+
if average:
|
93 |
+
return Eq.mean()
|
94 |
+
else:
|
95 |
+
return Eq
|
cortex_DIM/functions/misc.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Miscilaneous functions.
|
2 |
+
|
3 |
+
"""
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def log_sum_exp(x, axis=None):
|
9 |
+
"""Log sum exp function
|
10 |
+
|
11 |
+
Args:
|
12 |
+
x: Input.
|
13 |
+
axis: Axis over which to perform sum.
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
torch.Tensor: log sum exp
|
17 |
+
|
18 |
+
"""
|
19 |
+
x_max = torch.max(x, axis)[0]
|
20 |
+
y = torch.log((torch.exp(x - x_max)).sum(axis)) + x_max
|
21 |
+
return y
|
22 |
+
|
23 |
+
|
24 |
+
def random_permute(X):
|
25 |
+
"""Randomly permutes a tensor.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
X: Input tensor.
|
29 |
+
|
30 |
+
Returns:
|
31 |
+
torch.Tensor
|
32 |
+
|
33 |
+
"""
|
34 |
+
X = X.transpose(1, 2)
|
35 |
+
b = torch.rand((X.size(0), X.size(1))).cuda()
|
36 |
+
idx = b.sort(0)[1]
|
37 |
+
adx = torch.range(0, X.size(1) - 1).long()
|
38 |
+
X = X[idx, adx[None, :]].transpose(1, 2)
|
39 |
+
return X
|
cortex_DIM/nn_modules/__init__.py
ADDED
File without changes
|
cortex_DIM/nn_modules/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (216 Bytes). View file
|
|
cortex_DIM/nn_modules/__pycache__/mi_networks.cpython-39.pyc
ADDED
Binary file (2.63 kB). View file
|
|
cortex_DIM/nn_modules/__pycache__/misc.cpython-39.pyc
ADDED
Binary file (3.65 kB). View file
|
|
cortex_DIM/nn_modules/convnet.py
ADDED
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Convnet encoder module.
|
2 |
+
|
3 |
+
'''
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
#from cortex.built_ins.networks.utils import get_nonlinearity
|
9 |
+
|
10 |
+
from cortex_DIM.nn_modules.misc import Fold, Unfold, View
|
11 |
+
|
12 |
+
|
13 |
+
def infer_conv_size(w, k, s, p):
|
14 |
+
'''Infers the next size after convolution.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
w: Input size.
|
18 |
+
k: Kernel size.
|
19 |
+
s: Stride.
|
20 |
+
p: Padding.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
int: Output size.
|
24 |
+
|
25 |
+
'''
|
26 |
+
x = (w - k + 2 * p) // s + 1
|
27 |
+
return x
|
28 |
+
|
29 |
+
|
30 |
+
class Convnet(nn.Module):
|
31 |
+
'''Basic convnet convenience class.
|
32 |
+
|
33 |
+
Attributes:
|
34 |
+
conv_layers: nn.Sequential of nn.Conv2d layers with batch norm,
|
35 |
+
dropout, nonlinearity.
|
36 |
+
fc_layers: nn.Sequential of nn.Linear layers with batch norm,
|
37 |
+
dropout, nonlinearity.
|
38 |
+
reshape: Simple reshape layer.
|
39 |
+
conv_shape: Shape of the convolutional output.
|
40 |
+
|
41 |
+
'''
|
42 |
+
|
43 |
+
def __init__(self, *args, **kwargs):
|
44 |
+
super().__init__()
|
45 |
+
self.create_layers(*args, **kwargs)
|
46 |
+
|
47 |
+
def create_layers(self, shape, conv_args=None, fc_args=None):
|
48 |
+
'''Creates layers
|
49 |
+
|
50 |
+
conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
|
51 |
+
fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
|
52 |
+
|
53 |
+
Args:
|
54 |
+
shape: Shape of input.
|
55 |
+
conv_args: List of tuple of convolutional arguments.
|
56 |
+
fc_args: List of tuple of fully-connected arguments.
|
57 |
+
'''
|
58 |
+
|
59 |
+
self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
|
60 |
+
|
61 |
+
dim_x, dim_y, dim_out = self.conv_shape
|
62 |
+
dim_r = dim_x * dim_y * dim_out
|
63 |
+
self.reshape = View(-1, dim_r)
|
64 |
+
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
|
65 |
+
|
66 |
+
def create_conv_layers(self, shape, conv_args):
|
67 |
+
'''Creates a set of convolutional layers.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
shape: Input shape.
|
71 |
+
conv_args: List of tuple of convolutional arguments.
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
nn.Sequential: a sequence of convolutional layers.
|
75 |
+
|
76 |
+
'''
|
77 |
+
|
78 |
+
conv_layers = nn.Sequential()
|
79 |
+
conv_args = conv_args or []
|
80 |
+
|
81 |
+
dim_x, dim_y, dim_in = shape
|
82 |
+
|
83 |
+
for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
|
84 |
+
name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
|
85 |
+
conv_block = nn.Sequential()
|
86 |
+
|
87 |
+
if dim_out is not None:
|
88 |
+
conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm))
|
89 |
+
conv_block.add_module(name + 'conv', conv)
|
90 |
+
dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
|
91 |
+
else:
|
92 |
+
dim_out = dim_in
|
93 |
+
|
94 |
+
if dropout:
|
95 |
+
conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
|
96 |
+
if batch_norm:
|
97 |
+
bn = nn.BatchNorm2d(dim_out)
|
98 |
+
conv_block.add_module(name + 'bn', bn)
|
99 |
+
|
100 |
+
if nonlinearity:
|
101 |
+
nonlinearity = get_nonlinearity(nonlinearity)
|
102 |
+
conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
|
103 |
+
|
104 |
+
if pool:
|
105 |
+
(pool_type, kernel, stride) = pool
|
106 |
+
Pool = getattr(nn, pool_type)
|
107 |
+
conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
|
108 |
+
dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
|
109 |
+
|
110 |
+
conv_layers.add_module(name, conv_block)
|
111 |
+
|
112 |
+
dim_in = dim_out
|
113 |
+
|
114 |
+
dim_out = dim_in
|
115 |
+
|
116 |
+
return conv_layers, (dim_x, dim_y, dim_out)
|
117 |
+
|
118 |
+
def create_linear_layers(self, dim_in, fc_args):
|
119 |
+
'''
|
120 |
+
|
121 |
+
Args:
|
122 |
+
dim_in: Number of input units.
|
123 |
+
fc_args: List of tuple of fully-connected arguments.
|
124 |
+
|
125 |
+
Returns:
|
126 |
+
nn.Sequential.
|
127 |
+
|
128 |
+
'''
|
129 |
+
|
130 |
+
fc_layers = nn.Sequential()
|
131 |
+
fc_args = fc_args or []
|
132 |
+
|
133 |
+
for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args):
|
134 |
+
name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
|
135 |
+
fc_block = nn.Sequential()
|
136 |
+
|
137 |
+
if dim_out is not None:
|
138 |
+
fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out))
|
139 |
+
else:
|
140 |
+
dim_out = dim_in
|
141 |
+
|
142 |
+
if dropout:
|
143 |
+
fc_block.add_module(name + 'do', nn.Dropout(p=dropout))
|
144 |
+
if batch_norm:
|
145 |
+
bn = nn.BatchNorm1d(dim_out)
|
146 |
+
fc_block.add_module(name + 'bn', bn)
|
147 |
+
|
148 |
+
if nonlinearity:
|
149 |
+
nonlinearity = get_nonlinearity(nonlinearity)
|
150 |
+
fc_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
|
151 |
+
|
152 |
+
fc_layers.add_module(name, fc_block)
|
153 |
+
|
154 |
+
dim_in = dim_out
|
155 |
+
|
156 |
+
return fc_layers, dim_in
|
157 |
+
|
158 |
+
def next_size(self, dim_x, dim_y, k, s, p):
|
159 |
+
'''Infers the next size of a convolutional layer.
|
160 |
+
|
161 |
+
Args:
|
162 |
+
dim_x: First dimension.
|
163 |
+
dim_y: Second dimension.
|
164 |
+
k: Kernel size.
|
165 |
+
s: Stride.
|
166 |
+
p: Padding.
|
167 |
+
|
168 |
+
Returns:
|
169 |
+
(int, int): (First output dimension, Second output dimension)
|
170 |
+
|
171 |
+
'''
|
172 |
+
if isinstance(k, int):
|
173 |
+
kx, ky = (k, k)
|
174 |
+
else:
|
175 |
+
kx, ky = k
|
176 |
+
|
177 |
+
if isinstance(s, int):
|
178 |
+
sx, sy = (s, s)
|
179 |
+
else:
|
180 |
+
sx, sy = s
|
181 |
+
|
182 |
+
if isinstance(p, int):
|
183 |
+
px, py = (p, p)
|
184 |
+
else:
|
185 |
+
px, py = p
|
186 |
+
return (infer_conv_size(dim_x, kx, sx, px),
|
187 |
+
infer_conv_size(dim_y, ky, sy, py))
|
188 |
+
|
189 |
+
def forward(self, x: torch.Tensor, return_full_list=False):
|
190 |
+
'''Forward pass
|
191 |
+
|
192 |
+
Args:
|
193 |
+
x: Input.
|
194 |
+
return_full_list: Optional, returns all layer outputs.
|
195 |
+
|
196 |
+
Returns:
|
197 |
+
torch.Tensor or list of torch.Tensor.
|
198 |
+
|
199 |
+
'''
|
200 |
+
if return_full_list:
|
201 |
+
conv_out = []
|
202 |
+
for conv_layer in self.conv_layers:
|
203 |
+
x = conv_layer(x)
|
204 |
+
conv_out.append(x)
|
205 |
+
else:
|
206 |
+
conv_out = self.conv_layers(x)
|
207 |
+
x = conv_out
|
208 |
+
|
209 |
+
x = self.reshape(x)
|
210 |
+
|
211 |
+
if return_full_list:
|
212 |
+
fc_out = []
|
213 |
+
for fc_layer in self.fc_layers:
|
214 |
+
x = fc_layer(x)
|
215 |
+
fc_out.append(x)
|
216 |
+
else:
|
217 |
+
fc_out = self.fc_layers(x)
|
218 |
+
|
219 |
+
return conv_out, fc_out
|
220 |
+
|
221 |
+
|
222 |
+
class FoldedConvnet(Convnet):
|
223 |
+
'''Convnet with strided crop input.
|
224 |
+
|
225 |
+
'''
|
226 |
+
|
227 |
+
def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None):
|
228 |
+
'''Creates layers
|
229 |
+
|
230 |
+
conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
|
231 |
+
fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)
|
232 |
+
|
233 |
+
Args:
|
234 |
+
shape: Shape of input.
|
235 |
+
crop_size: Size of crops
|
236 |
+
conv_args: List of tuple of convolutional arguments.
|
237 |
+
fc_args: List of tuple of fully-connected arguments.
|
238 |
+
'''
|
239 |
+
|
240 |
+
self.crop_size = crop_size
|
241 |
+
|
242 |
+
dim_x, dim_y, dim_in = shape
|
243 |
+
if dim_x != dim_y:
|
244 |
+
raise ValueError('x and y dimensions must be the same to use Folded encoders.')
|
245 |
+
|
246 |
+
self.final_size = 2 * (dim_x // self.crop_size) - 1
|
247 |
+
|
248 |
+
self.unfold = Unfold(dim_x, self.crop_size)
|
249 |
+
self.refold = Fold(dim_x, self.crop_size)
|
250 |
+
|
251 |
+
shape = (self.crop_size, self.crop_size, dim_in)
|
252 |
+
|
253 |
+
self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
|
254 |
+
|
255 |
+
dim_x, dim_y, dim_out = self.conv_shape
|
256 |
+
dim_r = dim_x * dim_y * dim_out
|
257 |
+
self.reshape = View(-1, dim_r)
|
258 |
+
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
|
259 |
+
|
260 |
+
def create_conv_layers(self, shape, conv_args):
|
261 |
+
'''Creates a set of convolutional layers.
|
262 |
+
|
263 |
+
Args:
|
264 |
+
shape: Input shape.
|
265 |
+
conv_args: List of tuple of convolutional arguments.
|
266 |
+
|
267 |
+
Returns:
|
268 |
+
nn.Sequential: A sequence of convolutional layers.
|
269 |
+
|
270 |
+
'''
|
271 |
+
|
272 |
+
conv_layers = nn.Sequential()
|
273 |
+
conv_args = conv_args or []
|
274 |
+
dim_x, dim_y, dim_in = shape
|
275 |
+
|
276 |
+
for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
|
277 |
+
name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
|
278 |
+
conv_block = nn.Sequential()
|
279 |
+
|
280 |
+
if dim_out is not None:
|
281 |
+
conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not(batch_norm))
|
282 |
+
conv_block.add_module(name + 'conv', conv)
|
283 |
+
dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
|
284 |
+
else:
|
285 |
+
dim_out = dim_in
|
286 |
+
|
287 |
+
if dropout:
|
288 |
+
conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
|
289 |
+
if batch_norm:
|
290 |
+
bn = nn.BatchNorm2d(dim_out)
|
291 |
+
conv_block.add_module(name + 'bn', bn)
|
292 |
+
|
293 |
+
if nonlinearity:
|
294 |
+
nonlinearity = get_nonlinearity(nonlinearity)
|
295 |
+
conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)
|
296 |
+
|
297 |
+
if pool:
|
298 |
+
(pool_type, kernel, stride) = pool
|
299 |
+
Pool = getattr(nn, pool_type)
|
300 |
+
conv_block.add_module('pool', Pool(kernel_size=kernel, stride=stride))
|
301 |
+
dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
|
302 |
+
|
303 |
+
conv_layers.add_module(name, conv_block)
|
304 |
+
|
305 |
+
dim_in = dim_out
|
306 |
+
|
307 |
+
if dim_x != dim_y:
|
308 |
+
raise ValueError('dim_x and dim_y do not match.')
|
309 |
+
|
310 |
+
if dim_x == 1:
|
311 |
+
dim_x = self.final_size
|
312 |
+
dim_y = self.final_size
|
313 |
+
|
314 |
+
dim_out = dim_in
|
315 |
+
|
316 |
+
return conv_layers, (dim_x, dim_y, dim_out)
|
317 |
+
|
318 |
+
def forward(self, x: torch.Tensor, return_full_list=False):
|
319 |
+
'''Forward pass
|
320 |
+
|
321 |
+
Args:
|
322 |
+
x: Input.
|
323 |
+
return_full_list: Optional, returns all layer outputs.
|
324 |
+
|
325 |
+
Returns:
|
326 |
+
torch.Tensor or list of torch.Tensor.
|
327 |
+
|
328 |
+
'''
|
329 |
+
|
330 |
+
x = self.unfold(x)
|
331 |
+
|
332 |
+
conv_out = []
|
333 |
+
for conv_layer in self.conv_layers:
|
334 |
+
x = conv_layer(x)
|
335 |
+
if x.size(2) == 1:
|
336 |
+
x = self.refold(x)
|
337 |
+
conv_out.append(x)
|
338 |
+
|
339 |
+
x = self.reshape(x)
|
340 |
+
|
341 |
+
if return_full_list:
|
342 |
+
fc_out = []
|
343 |
+
for fc_layer in self.fc_layers:
|
344 |
+
x = fc_layer(x)
|
345 |
+
fc_out.append(x)
|
346 |
+
else:
|
347 |
+
fc_out = self.fc_layers(x)
|
348 |
+
|
349 |
+
if not return_full_list:
|
350 |
+
conv_out = conv_out[-1]
|
351 |
+
|
352 |
+
return conv_out, fc_out
|
cortex_DIM/nn_modules/encoder.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Basic cortex_DIM encoder.
|
2 |
+
|
3 |
+
'''
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet
|
8 |
+
from cortex_DIM.nn_modules.resnet import ResNet, FoldedResNet
|
9 |
+
|
10 |
+
|
11 |
+
def create_encoder(Module):
|
12 |
+
class Encoder(Module):
|
13 |
+
'''Encoder used for cortex_DIM.
|
14 |
+
|
15 |
+
'''
|
16 |
+
|
17 |
+
def __init__(self, *args, local_idx=None, multi_idx=None, conv_idx=None, fc_idx=None, **kwargs):
|
18 |
+
'''
|
19 |
+
|
20 |
+
Args:
|
21 |
+
args: Arguments for parent class.
|
22 |
+
local_idx: Index in list of convolutional layers for local features.
|
23 |
+
multi_idx: Index in list of convolutional layers for multiple globals.
|
24 |
+
conv_idx: Index in list of convolutional layers for intermediate features.
|
25 |
+
fc_idx: Index in list of fully-connected layers for intermediate features.
|
26 |
+
kwargs: Keyword arguments for the parent class.
|
27 |
+
'''
|
28 |
+
|
29 |
+
super().__init__(*args, **kwargs)
|
30 |
+
|
31 |
+
if local_idx is None:
|
32 |
+
raise ValueError('`local_idx` must be set')
|
33 |
+
|
34 |
+
conv_idx = conv_idx or local_idx
|
35 |
+
|
36 |
+
self.local_idx = local_idx
|
37 |
+
self.multi_idx = multi_idx
|
38 |
+
self.conv_idx = conv_idx
|
39 |
+
self.fc_idx = fc_idx
|
40 |
+
|
41 |
+
def forward(self, x: torch.Tensor):
|
42 |
+
'''
|
43 |
+
|
44 |
+
Args:
|
45 |
+
x: Input tensor.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
local_out, multi_out, hidden_out, global_out
|
49 |
+
|
50 |
+
'''
|
51 |
+
|
52 |
+
outs = super().forward(x, return_full_list=True)
|
53 |
+
if len(outs) == 2:
|
54 |
+
conv_out, fc_out = outs
|
55 |
+
else:
|
56 |
+
conv_before_out, res_out, conv_after_out, fc_out = outs
|
57 |
+
conv_out = conv_before_out + res_out + conv_after_out
|
58 |
+
|
59 |
+
local_out = conv_out[self.local_idx]
|
60 |
+
|
61 |
+
if self.multi_idx is not None:
|
62 |
+
multi_out = conv_out[self.multi_idx]
|
63 |
+
else:
|
64 |
+
multi_out = None
|
65 |
+
|
66 |
+
if len(fc_out) > 0:
|
67 |
+
if self.fc_idx is not None:
|
68 |
+
hidden_out = fc_out[self.fc_idx]
|
69 |
+
else:
|
70 |
+
hidden_out = None
|
71 |
+
global_out = fc_out[-1]
|
72 |
+
else:
|
73 |
+
hidden_out = None
|
74 |
+
global_out = None
|
75 |
+
|
76 |
+
conv_out = conv_out[self.conv_idx]
|
77 |
+
|
78 |
+
return local_out, conv_out, multi_out, hidden_out, global_out
|
79 |
+
|
80 |
+
return Encoder
|
81 |
+
|
82 |
+
|
83 |
+
class ConvnetEncoder(create_encoder(Convnet)):
|
84 |
+
pass
|
85 |
+
|
86 |
+
|
87 |
+
class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
|
88 |
+
pass
|
89 |
+
|
90 |
+
|
91 |
+
class ResnetEncoder(create_encoder(ResNet)):
|
92 |
+
pass
|
93 |
+
|
94 |
+
|
95 |
+
class FoldedResnetEncoder(create_encoder(FoldedResNet)):
|
96 |
+
pass
|
cortex_DIM/nn_modules/mi_networks.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Module for networks used for computing MI.
|
2 |
+
|
3 |
+
"""
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from cortex_DIM.nn_modules.misc import Permute
|
10 |
+
|
11 |
+
|
12 |
+
class MIFCNet(nn.Module):
|
13 |
+
"""Simple custom network for computing MI.
|
14 |
+
|
15 |
+
"""
|
16 |
+
def __init__(self, n_input, n_units):
|
17 |
+
"""
|
18 |
+
|
19 |
+
Args:
|
20 |
+
n_input: Number of input units.
|
21 |
+
n_units: Number of output units.
|
22 |
+
"""
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
assert(n_units >= n_input)
|
26 |
+
|
27 |
+
self.linear_shortcut = nn.Linear(n_input, n_units)
|
28 |
+
self.block_nonlinear = nn.Sequential(
|
29 |
+
nn.Linear(n_input, n_units),
|
30 |
+
nn.BatchNorm1d(n_units),
|
31 |
+
nn.ReLU(),
|
32 |
+
nn.Linear(n_units, n_units)
|
33 |
+
)
|
34 |
+
|
35 |
+
# initialize the initial projection to a sort of noisy copy
|
36 |
+
eye_mask = np.zeros((n_units, n_input), dtype=np.uint8)
|
37 |
+
for i in range(n_input):
|
38 |
+
eye_mask[i, i] = 1
|
39 |
+
|
40 |
+
self.linear_shortcut.weight.data.uniform_(-0.01, 0.01)
|
41 |
+
self.linear_shortcut.weight.data.masked_fill_(torch.tensor(eye_mask), 1.)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
"""
|
45 |
+
|
46 |
+
Args:
|
47 |
+
x: Input tensor.
|
48 |
+
|
49 |
+
Returns:
|
50 |
+
torch.Tensor: network output.
|
51 |
+
|
52 |
+
"""
|
53 |
+
h = self.block_nonlinear(x) + self.linear_shortcut(x)
|
54 |
+
return h
|
55 |
+
|
56 |
+
|
57 |
+
class MI1x1ConvNet(nn.Module):
|
58 |
+
"""Simple custorm 1x1 convnet.
|
59 |
+
|
60 |
+
"""
|
61 |
+
def __init__(self, n_input, n_units):
|
62 |
+
"""
|
63 |
+
|
64 |
+
Args:
|
65 |
+
n_input: Number of input units.
|
66 |
+
n_units: Number of output units.
|
67 |
+
"""
|
68 |
+
|
69 |
+
super().__init__()
|
70 |
+
|
71 |
+
self.block_nonlinear = nn.Sequential(
|
72 |
+
nn.Conv1d(n_input, n_units, kernel_size=1, stride=1, padding=0, bias=False),
|
73 |
+
nn.BatchNorm1d(n_units),
|
74 |
+
nn.ReLU(),
|
75 |
+
nn.Conv1d(n_units, n_units, kernel_size=1, stride=1, padding=0, bias=True),
|
76 |
+
)
|
77 |
+
|
78 |
+
self.block_ln = nn.Sequential(
|
79 |
+
Permute(0, 2, 1),
|
80 |
+
nn.LayerNorm(n_units),
|
81 |
+
Permute(0, 2, 1)
|
82 |
+
)
|
83 |
+
|
84 |
+
self.linear_shortcut = nn.Conv1d(n_input, n_units, kernel_size=1,
|
85 |
+
stride=1, padding=0, bias=False)
|
86 |
+
|
87 |
+
# initialize shortcut to be like identity (if possible)
|
88 |
+
if n_units >= n_input:
|
89 |
+
eye_mask = np.zeros((n_units, n_input, 1), dtype=np.uint8)
|
90 |
+
for i in range(n_input):
|
91 |
+
eye_mask[i, i, 0] = 1
|
92 |
+
self.linear_shortcut.weight.data.uniform_(-0.01, 0.01)
|
93 |
+
self.linear_shortcut.weight.data.masked_fill_(torch.tensor(eye_mask), 1.)
|
94 |
+
|
95 |
+
def forward(self, x):
|
96 |
+
"""
|
97 |
+
|
98 |
+
Args:
|
99 |
+
x: Input tensor.
|
100 |
+
|
101 |
+
Returns:
|
102 |
+
torch.Tensor: network output.
|
103 |
+
|
104 |
+
"""
|
105 |
+
h = self.block_ln(self.block_nonlinear(x) + self.linear_shortcut(x))
|
106 |
+
return h
|
cortex_DIM/nn_modules/misc.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Various miscellaneous modules
|
2 |
+
|
3 |
+
'''
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
class View(torch.nn.Module):
|
9 |
+
"""Basic reshape module.
|
10 |
+
|
11 |
+
"""
|
12 |
+
def __init__(self, *shape):
|
13 |
+
"""
|
14 |
+
|
15 |
+
Args:
|
16 |
+
*shape: Input shape.
|
17 |
+
"""
|
18 |
+
super().__init__()
|
19 |
+
self.shape = shape
|
20 |
+
|
21 |
+
def forward(self, input):
|
22 |
+
"""Reshapes tensor.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
input: Input tensor.
|
26 |
+
|
27 |
+
Returns:
|
28 |
+
torch.Tensor: Flattened tensor.
|
29 |
+
|
30 |
+
"""
|
31 |
+
return input.view(*self.shape)
|
32 |
+
|
33 |
+
|
34 |
+
class Unfold(torch.nn.Module):
|
35 |
+
"""Module for unfolding tensor.
|
36 |
+
|
37 |
+
Performs strided crops on 2d (image) tensors. Stride is assumed to be half the crop size.
|
38 |
+
|
39 |
+
"""
|
40 |
+
def __init__(self, img_size, fold_size):
|
41 |
+
"""
|
42 |
+
|
43 |
+
Args:
|
44 |
+
img_size: Input size.
|
45 |
+
fold_size: Crop size.
|
46 |
+
"""
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
fold_stride = fold_size // 2
|
50 |
+
self.fold_size = fold_size
|
51 |
+
self.fold_stride = fold_stride
|
52 |
+
self.n_locs = 2 * (img_size // fold_size) - 1
|
53 |
+
self.unfold = torch.nn.Unfold((self.fold_size, self.fold_size),
|
54 |
+
stride=(self.fold_stride, self.fold_stride))
|
55 |
+
|
56 |
+
def forward(self, x):
|
57 |
+
"""Unfolds tensor.
|
58 |
+
|
59 |
+
Args:
|
60 |
+
x: Input tensor.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
torch.Tensor: Unfolded tensor.
|
64 |
+
|
65 |
+
"""
|
66 |
+
N = x.size(0)
|
67 |
+
x = self.unfold(x).reshape(N, -1, self.fold_size, self.fold_size, self.n_locs * self.n_locs)\
|
68 |
+
.permute(0, 4, 1, 2, 3)\
|
69 |
+
.reshape(N * self.n_locs * self.n_locs, -1, self.fold_size, self.fold_size)
|
70 |
+
return x
|
71 |
+
|
72 |
+
|
73 |
+
class Fold(torch.nn.Module):
|
74 |
+
"""Module (re)folding tensor.
|
75 |
+
|
76 |
+
Undoes the strided crops above. Works only on 1x1.
|
77 |
+
|
78 |
+
"""
|
79 |
+
def __init__(self, img_size, fold_size):
|
80 |
+
"""
|
81 |
+
|
82 |
+
Args:
|
83 |
+
img_size: Images size.
|
84 |
+
fold_size: Crop size.
|
85 |
+
"""
|
86 |
+
super().__init__()
|
87 |
+
self.n_locs = 2 * (img_size // fold_size) - 1
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
"""(Re)folds tensor.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
x: Input tensor.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
torch.Tensor: Refolded tensor.
|
97 |
+
|
98 |
+
"""
|
99 |
+
dim_c, dim_x, dim_y = x.size()[1:]
|
100 |
+
x = x.reshape(-1, self.n_locs * self.n_locs, dim_c, dim_x * dim_y)
|
101 |
+
x = x.reshape(-1, self.n_locs * self.n_locs, dim_c, dim_x * dim_y)\
|
102 |
+
.permute(0, 2, 3, 1)\
|
103 |
+
.reshape(-1, dim_c * dim_x * dim_y, self.n_locs, self.n_locs).contiguous()
|
104 |
+
return x
|
105 |
+
|
106 |
+
|
107 |
+
class Permute(torch.nn.Module):
|
108 |
+
"""Module for permuting axes.
|
109 |
+
|
110 |
+
"""
|
111 |
+
def __init__(self, *perm):
|
112 |
+
"""
|
113 |
+
|
114 |
+
Args:
|
115 |
+
*perm: Permute axes.
|
116 |
+
"""
|
117 |
+
super().__init__()
|
118 |
+
self.perm = perm
|
119 |
+
|
120 |
+
def forward(self, input):
|
121 |
+
"""Permutes axes of tensor.
|
122 |
+
|
123 |
+
Args:
|
124 |
+
input: Input tensor.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
torch.Tensor: permuted tensor.
|
128 |
+
|
129 |
+
"""
|
130 |
+
return input.permute(*self.perm)
|
cortex_DIM/nn_modules/resnet.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''Module for making resnet encoders.
|
2 |
+
|
3 |
+
'''
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from cortex_DIM.nn_modules.convnet import Convnet
|
9 |
+
from cortex_DIM.nn_modules.misc import Fold, Unfold, View
|
10 |
+
|
11 |
+
|
12 |
+
_nonlin_idx = 6
|
13 |
+
|
14 |
+
|
15 |
+
class ResBlock(Convnet):
|
16 |
+
'''Residual block for ResNet
|
17 |
+
|
18 |
+
'''
|
19 |
+
|
20 |
+
def create_layers(self, shape, conv_args=None):
|
21 |
+
'''Creates layers
|
22 |
+
|
23 |
+
Args:
|
24 |
+
shape: Shape of input.
|
25 |
+
conv_args: Layer arguments for block.
|
26 |
+
'''
|
27 |
+
|
28 |
+
# Move nonlinearity to a separate step for residual.
|
29 |
+
final_nonlin = conv_args[-1][_nonlin_idx]
|
30 |
+
conv_args[-1] = list(conv_args[-1])
|
31 |
+
conv_args[-1][_nonlin_idx] = None
|
32 |
+
conv_args.append((None, 0, 0, 0, False, False, final_nonlin, None))
|
33 |
+
|
34 |
+
super().create_layers(shape, conv_args=conv_args)
|
35 |
+
|
36 |
+
if self.conv_shape != shape:
|
37 |
+
dim_x, dim_y, dim_in = shape
|
38 |
+
dim_x_, dim_y_, dim_out = self.conv_shape
|
39 |
+
stride = dim_x // dim_x_
|
40 |
+
next_x, _ = self.next_size(dim_x, dim_y, 1, stride, 0)
|
41 |
+
assert next_x == dim_x_, (self.conv_shape, shape)
|
42 |
+
|
43 |
+
self.downsample = nn.Sequential(
|
44 |
+
nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=stride, padding=0, bias=False),
|
45 |
+
nn.BatchNorm2d(dim_out),
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
self.downsample = None
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor):
|
51 |
+
'''Forward pass
|
52 |
+
|
53 |
+
Args:
|
54 |
+
x: Input.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
torch.Tensor or list of torch.Tensor.
|
58 |
+
|
59 |
+
'''
|
60 |
+
|
61 |
+
if self.downsample is not None:
|
62 |
+
residual = self.downsample(x)
|
63 |
+
else:
|
64 |
+
residual = x
|
65 |
+
|
66 |
+
x = self.conv_layers[-1](self.conv_layers[:-1](x) + residual)
|
67 |
+
|
68 |
+
return x
|
69 |
+
|
70 |
+
|
71 |
+
class ResNet(Convnet):
|
72 |
+
def create_layers(self, shape, conv_before_args=None, res_args=None, conv_after_args=None, fc_args=None):
|
73 |
+
'''Creates layers
|
74 |
+
|
75 |
+
Args:
|
76 |
+
shape: Shape of the input.
|
77 |
+
conv_before_args: Arguments for convolutional layers before residuals.
|
78 |
+
res_args: Residual args.
|
79 |
+
conv_after_args: Arguments for convolutional layers after residuals.
|
80 |
+
fc_args: Fully-connected arguments.
|
81 |
+
|
82 |
+
'''
|
83 |
+
|
84 |
+
dim_x, dim_y, dim_in = shape
|
85 |
+
shape = (dim_x, dim_y, dim_in)
|
86 |
+
self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)
|
87 |
+
self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
|
88 |
+
self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)
|
89 |
+
|
90 |
+
dim_x, dim_y, dim_out = self.conv_after_shape
|
91 |
+
dim_r = dim_x * dim_y * dim_out
|
92 |
+
self.reshape = View(-1, dim_r)
|
93 |
+
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
|
94 |
+
|
95 |
+
def create_res_layers(self, shape, block_args=None):
|
96 |
+
'''Creates a set of residual blocks.
|
97 |
+
|
98 |
+
Args:
|
99 |
+
shape: input shape.
|
100 |
+
block_args: Arguments for blocks.
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
nn.Sequential: sequence of residual blocks.
|
104 |
+
|
105 |
+
'''
|
106 |
+
|
107 |
+
res_layers = nn.Sequential()
|
108 |
+
block_args = block_args or []
|
109 |
+
|
110 |
+
for i, (conv_args, n_blocks) in enumerate(block_args):
|
111 |
+
block = ResBlock(shape, conv_args=conv_args)
|
112 |
+
res_layers.add_module('block_{}_0'.format(i), block)
|
113 |
+
|
114 |
+
for j in range(1, n_blocks):
|
115 |
+
shape = block.conv_shape
|
116 |
+
block = ResBlock(shape, conv_args=conv_args)
|
117 |
+
res_layers.add_module('block_{}_{}'.format(i, j), block)
|
118 |
+
shape = block.conv_shape
|
119 |
+
|
120 |
+
return res_layers, shape
|
121 |
+
|
122 |
+
def forward(self, x: torch.Tensor, return_full_list=False):
|
123 |
+
'''Forward pass
|
124 |
+
|
125 |
+
Args:
|
126 |
+
x: Input.
|
127 |
+
return_full_list: Optional, returns all layer outputs.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
torch.Tensor or list of torch.Tensor.
|
131 |
+
|
132 |
+
'''
|
133 |
+
|
134 |
+
if return_full_list:
|
135 |
+
conv_before_out = []
|
136 |
+
for conv_layer in self.conv_before_layers:
|
137 |
+
x = conv_layer(x)
|
138 |
+
conv_before_out.append(x)
|
139 |
+
else:
|
140 |
+
conv_before_out = self.conv_layers(x)
|
141 |
+
x = conv_before_out
|
142 |
+
|
143 |
+
if return_full_list:
|
144 |
+
res_out = []
|
145 |
+
for res_layer in self.res_layers:
|
146 |
+
x = res_layer(x)
|
147 |
+
res_out.append(x)
|
148 |
+
else:
|
149 |
+
res_out = self.res_layers(x)
|
150 |
+
x = res_out
|
151 |
+
|
152 |
+
if return_full_list:
|
153 |
+
conv_after_out = []
|
154 |
+
for conv_layer in self.conv_after_layers:
|
155 |
+
x = conv_layer(x)
|
156 |
+
conv_after_out.append(x)
|
157 |
+
else:
|
158 |
+
conv_after_out = self.conv_after_layers(x)
|
159 |
+
x = conv_after_out
|
160 |
+
|
161 |
+
x = self.reshape(x)
|
162 |
+
|
163 |
+
if return_full_list:
|
164 |
+
fc_out = []
|
165 |
+
for fc_layer in self.fc_layers:
|
166 |
+
x = fc_layer(x)
|
167 |
+
fc_out.append(x)
|
168 |
+
else:
|
169 |
+
fc_out = self.fc_layers(x)
|
170 |
+
|
171 |
+
return conv_before_out, res_out, conv_after_out, fc_out
|
172 |
+
|
173 |
+
|
174 |
+
class FoldedResNet(ResNet):
|
175 |
+
'''Resnet with strided crop input.
|
176 |
+
|
177 |
+
'''
|
178 |
+
|
179 |
+
def create_layers(self, shape, crop_size=8, conv_before_args=None, res_args=None,
|
180 |
+
conv_after_args=None, fc_args=None):
|
181 |
+
'''Creates layers
|
182 |
+
|
183 |
+
Args:
|
184 |
+
shape: Shape of the input.
|
185 |
+
crop_size: Size of the crops.
|
186 |
+
conv_before_args: Arguments for convolutional layers before residuals.
|
187 |
+
res_args: Residual args.
|
188 |
+
conv_after_args: Arguments for convolutional layers after residuals.
|
189 |
+
fc_args: Fully-connected arguments.
|
190 |
+
|
191 |
+
'''
|
192 |
+
self.crop_size = crop_size
|
193 |
+
|
194 |
+
dim_x, dim_y, dim_in = shape
|
195 |
+
self.final_size = 2 * (dim_x // self.crop_size) - 1
|
196 |
+
|
197 |
+
self.unfold = Unfold(dim_x, self.crop_size)
|
198 |
+
self.refold = Fold(dim_x, self.crop_size)
|
199 |
+
|
200 |
+
shape = (self.crop_size, self.crop_size, dim_in)
|
201 |
+
self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)
|
202 |
+
|
203 |
+
self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
|
204 |
+
self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)
|
205 |
+
self.conv_after_shape = self.res_shape
|
206 |
+
|
207 |
+
dim_x, dim_y, dim_out = self.conv_after_shape
|
208 |
+
dim_r = dim_x * dim_y * dim_out
|
209 |
+
self.reshape = View(-1, dim_r)
|
210 |
+
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
|
211 |
+
|
212 |
+
def create_res_layers(self, shape, block_args=None):
|
213 |
+
'''Creates a set of residual blocks.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
shape: input shape.
|
217 |
+
block_args: Arguments for blocks.
|
218 |
+
|
219 |
+
Returns:
|
220 |
+
nn.Sequential: sequence of residual blocks.
|
221 |
+
|
222 |
+
'''
|
223 |
+
|
224 |
+
res_layers = nn.Sequential()
|
225 |
+
block_args = block_args or []
|
226 |
+
|
227 |
+
for i, (conv_args, n_blocks) in enumerate(block_args):
|
228 |
+
block = ResBlock(shape, conv_args=conv_args)
|
229 |
+
res_layers.add_module('block_{}_0'.format(i), block)
|
230 |
+
|
231 |
+
for j in range(1, n_blocks):
|
232 |
+
shape = block.conv_shape
|
233 |
+
block = ResBlock(shape, conv_args=conv_args)
|
234 |
+
res_layers.add_module('block_{}_{}'.format(i, j), block)
|
235 |
+
shape = block.conv_shape
|
236 |
+
dim_x, dim_y = shape[:2]
|
237 |
+
|
238 |
+
if dim_x != dim_y:
|
239 |
+
raise ValueError('dim_x and dim_y do not match.')
|
240 |
+
|
241 |
+
if dim_x == 1:
|
242 |
+
shape = (self.final_size, self.final_size, shape[2])
|
243 |
+
|
244 |
+
return res_layers, shape
|
245 |
+
|
246 |
+
def forward(self, x: torch.Tensor, return_full_list=False):
|
247 |
+
'''Forward pass
|
248 |
+
|
249 |
+
Args:
|
250 |
+
x: Input.
|
251 |
+
return_full_list: Optional, returns all layer outputs.
|
252 |
+
|
253 |
+
Returns:
|
254 |
+
torch.Tensor or list of torch.Tensor.
|
255 |
+
|
256 |
+
'''
|
257 |
+
x = self.unfold(x)
|
258 |
+
|
259 |
+
conv_before_out = []
|
260 |
+
for conv_layer in self.conv_before_layers:
|
261 |
+
x = conv_layer(x)
|
262 |
+
if x.size(2) == 1:
|
263 |
+
x = self.refold(x)
|
264 |
+
conv_before_out.append(x)
|
265 |
+
|
266 |
+
res_out = []
|
267 |
+
for res_layer in self.res_layers:
|
268 |
+
x = res_layer(x)
|
269 |
+
res_out.append(x)
|
270 |
+
|
271 |
+
if x.size(2) == 1:
|
272 |
+
x = self.refold(x)
|
273 |
+
res_out[-1] = x
|
274 |
+
|
275 |
+
conv_after_out = []
|
276 |
+
for conv_layer in self.conv_after_layers:
|
277 |
+
x = conv_layer(x)
|
278 |
+
if x.size(2) == 1:
|
279 |
+
x = self.refold(x)
|
280 |
+
conv_after_out.append(x)
|
281 |
+
|
282 |
+
x = self.reshape(x)
|
283 |
+
|
284 |
+
if return_full_list:
|
285 |
+
fc_out = []
|
286 |
+
for fc_layer in self.fc_layers:
|
287 |
+
x = fc_layer(x)
|
288 |
+
fc_out.append(x)
|
289 |
+
else:
|
290 |
+
fc_out = self.fc_layers(x)
|
291 |
+
|
292 |
+
if not return_full_list:
|
293 |
+
conv_before_out = conv_before_out[-1]
|
294 |
+
res_out = res_out[-1]
|
295 |
+
conv_after_out = conv_after_out[-1]
|
296 |
+
|
297 |
+
return conv_before_out, res_out, conv_after_out, fc_out
|
gin.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sklearn import preprocessing
|
2 |
+
from sklearn.ensemble import RandomForestClassifier
|
3 |
+
from sklearn.linear_model import LogisticRegression
|
4 |
+
from sklearn.metrics import accuracy_score
|
5 |
+
from sklearn.model_selection import GridSearchCV, KFold, StratifiedKFold
|
6 |
+
from sklearn.model_selection import cross_val_score
|
7 |
+
from sklearn.svm import SVC, LinearSVC
|
8 |
+
from torch.nn import Sequential, Linear, ReLU
|
9 |
+
from torch_geometric.data import DataLoader
|
10 |
+
from torch_geometric.datasets import TUDataset
|
11 |
+
from torch_geometric.nn import GINConv, global_add_pool
|
12 |
+
from tqdm import tqdm
|
13 |
+
import numpy as np
|
14 |
+
import os.path as osp
|
15 |
+
import sys
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
|
19 |
+
|
20 |
+
class Encoder(torch.nn.Module):
|
21 |
+
def __init__(self, num_features, dim, num_gc_layers):
|
22 |
+
super(Encoder, self).__init__()
|
23 |
+
|
24 |
+
# num_features = dataset.num_features
|
25 |
+
# dim = 32
|
26 |
+
self.num_gc_layers = num_gc_layers
|
27 |
+
|
28 |
+
# self.nns = []
|
29 |
+
self.convs = torch.nn.ModuleList()
|
30 |
+
self.bns = torch.nn.ModuleList()
|
31 |
+
|
32 |
+
for i in range(num_gc_layers):
|
33 |
+
|
34 |
+
if i:
|
35 |
+
nn = Sequential(Linear(dim, dim), ReLU(), Linear(dim, dim))
|
36 |
+
else:
|
37 |
+
nn = Sequential(Linear(num_features, dim), ReLU(), Linear(dim, dim))
|
38 |
+
conv = GINConv(nn)
|
39 |
+
bn = torch.nn.BatchNorm1d(dim)
|
40 |
+
|
41 |
+
self.convs.append(conv)
|
42 |
+
self.bns.append(bn)
|
43 |
+
|
44 |
+
|
45 |
+
def forward(self, x, edge_index, batch):
|
46 |
+
if x is None:
|
47 |
+
x = torch.ones((batch.shape[0], 1)).to(device)
|
48 |
+
|
49 |
+
xs = []
|
50 |
+
for i in range(self.num_gc_layers):
|
51 |
+
|
52 |
+
x = F.relu(self.convs[i](x, edge_index))
|
53 |
+
x = self.bns[i](x)
|
54 |
+
xs.append(x)
|
55 |
+
# if i == 2:
|
56 |
+
# feature_map = x2
|
57 |
+
|
58 |
+
xpool = [global_add_pool(x, batch) for x in xs]
|
59 |
+
x = torch.cat(xpool, 1)
|
60 |
+
return x, torch.cat(xs, 1)
|
61 |
+
|
62 |
+
def get_embeddings(self, loader):
|
63 |
+
|
64 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
65 |
+
ret = []
|
66 |
+
y = []
|
67 |
+
with torch.no_grad():
|
68 |
+
for data in loader:
|
69 |
+
data.to(device)
|
70 |
+
x, edge_index, batch = data.x, data.edge_index, data.batch
|
71 |
+
if x is None:
|
72 |
+
x = torch.ones((batch.shape[0],1)).to(device)
|
73 |
+
x, _ = self.forward(x, edge_index, batch)
|
74 |
+
ret.append(x.cpu().numpy())
|
75 |
+
# y.append(data.aid)
|
76 |
+
ret = np.concatenate(ret, 0)
|
77 |
+
# y = np.concatenate(y, 0)
|
78 |
+
return ret
|
79 |
+
# return ret, y
|
80 |
+
|
81 |
+
class Net(torch.nn.Module):
|
82 |
+
def __init__(self):
|
83 |
+
super(Net, self).__init__()
|
84 |
+
|
85 |
+
try:
|
86 |
+
num_features = dataset.num_features
|
87 |
+
except:
|
88 |
+
num_features = 1
|
89 |
+
dim = 32
|
90 |
+
|
91 |
+
self.encoder = Encoder(num_features, dim)
|
92 |
+
|
93 |
+
self.fc1 = Linear(dim*5, dim)
|
94 |
+
self.fc2 = Linear(dim, dataset.num_classes)
|
95 |
+
|
96 |
+
def forward(self, x, edge_index, batch):
|
97 |
+
if x is None:
|
98 |
+
x = torch.ones(batch.shape[0]).to(device)
|
99 |
+
|
100 |
+
x, _ = self.encoder(x, edge_index, batch)
|
101 |
+
x = F.relu(self.fc1(x))
|
102 |
+
x = F.dropout(x, p=0.5, training=self.training)
|
103 |
+
x = self.fc2(x)
|
104 |
+
return F.log_softmax(x, dim=-1)
|
105 |
+
|
106 |
+
def train(epoch):
|
107 |
+
model.train()
|
108 |
+
|
109 |
+
if epoch == 51:
|
110 |
+
for param_group in optimizer.param_groups:
|
111 |
+
param_group['lr'] = 0.5 * param_group['lr']
|
112 |
+
|
113 |
+
loss_all = 0
|
114 |
+
for data in train_loader:
|
115 |
+
data = data.to(device)
|
116 |
+
optimizer.zero_grad()
|
117 |
+
# print(data.x.shape)
|
118 |
+
# [ num_nodes x num_node_labels ]
|
119 |
+
# print(data.edge_index.shape)
|
120 |
+
# [2 x num_edges ]
|
121 |
+
# print(data.batch.shape)
|
122 |
+
# [ num_nodes ]
|
123 |
+
output = model(data.x, data.edge_index, data.batch)
|
124 |
+
loss = F.nll_loss(output, data.y)
|
125 |
+
loss.backward()
|
126 |
+
loss_all += loss.item() * data.num_graphs
|
127 |
+
optimizer.step()
|
128 |
+
|
129 |
+
return loss_all / len(train_dataset)
|
130 |
+
|
131 |
+
def test(loader):
|
132 |
+
model.eval()
|
133 |
+
|
134 |
+
correct = 0
|
135 |
+
for data in loader:
|
136 |
+
data = data.to(device)
|
137 |
+
output = model(data.x, data.edge_index, data.batch)
|
138 |
+
pred = output.max(dim=1)[1]
|
139 |
+
correct += pred.eq(data.y).sum().item()
|
140 |
+
return correct / len(loader.dataset)
|
141 |
+
|
142 |
+
|
143 |
+
if __name__ == '__main__':
|
144 |
+
for percentage in [ 1.]:
|
145 |
+
for DS in [sys.argv[1]]:
|
146 |
+
if 'REDDIT' in DS:
|
147 |
+
epochs = 200
|
148 |
+
else:
|
149 |
+
epochs = 100
|
150 |
+
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', DS)
|
151 |
+
accuracies = [[] for i in range(epochs)]
|
152 |
+
#kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=None)
|
153 |
+
dataset = TUDataset(path, name=DS) #.shuffle()
|
154 |
+
num_graphs = len(dataset)
|
155 |
+
print('Number of graphs', len(dataset))
|
156 |
+
dataset = dataset[:int(num_graphs * percentage)]
|
157 |
+
dataset = dataset.shuffle()
|
158 |
+
|
159 |
+
kf = KFold(n_splits=10, shuffle=True, random_state=None)
|
160 |
+
for train_index, test_index in kf.split(dataset):
|
161 |
+
|
162 |
+
# x_train, x_test = x[train_index], x[test_index]
|
163 |
+
# y_train, y_test = y[train_index], y[test_index]
|
164 |
+
train_dataset = [dataset[int(i)] for i in list(train_index)]
|
165 |
+
test_dataset = [dataset[int(i)] for i in list(test_index)]
|
166 |
+
print('len(train_dataset)', len(train_dataset))
|
167 |
+
print('len(test_dataset)', len(test_dataset))
|
168 |
+
|
169 |
+
train_loader = DataLoader(train_dataset, batch_size=128)
|
170 |
+
test_loader = DataLoader(test_dataset, batch_size=128)
|
171 |
+
# print('train', len(train_loader))
|
172 |
+
# print('test', len(test_loader))
|
173 |
+
|
174 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
175 |
+
model = Net().to(device)
|
176 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
177 |
+
|
178 |
+
for epoch in range(1, epochs+1):
|
179 |
+
train_loss = train(epoch)
|
180 |
+
train_acc = test(train_loader)
|
181 |
+
test_acc = test(test_loader)
|
182 |
+
accuracies[epoch-1].append(test_acc)
|
183 |
+
tqdm.write('Epoch: {:03d}, Train Loss: {:.7f}, '
|
184 |
+
'Train Acc: {:.7f}, Test Acc: {:.7f}'.format(epoch, train_loss,
|
185 |
+
train_acc, test_acc))
|
186 |
+
tmp = np.mean(accuracies, axis=1)
|
187 |
+
print(percentage, DS, np.argmax(tmp), np.max(tmp), np.std(accuracies[np.argmax(tmp)]))
|
188 |
+
input()
|
model.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from cortex_DIM.nn_modules.mi_networks import MIFCNet, MI1x1ConvNet
|
2 |
+
from torch import optim
|
3 |
+
from torch.autograd import Variable
|
4 |
+
import json
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class GlobalDiscriminator(nn.Module):
|
12 |
+
def __init__(self, args, input_dim):
|
13 |
+
super().__init__()
|
14 |
+
|
15 |
+
self.l0 = nn.Linear(32, 32)
|
16 |
+
self.l1 = nn.Linear(32, 32)
|
17 |
+
|
18 |
+
self.l2 = nn.Linear(512, 1)
|
19 |
+
def forward(self, y, M, data):
|
20 |
+
|
21 |
+
adj = Variable(data['adj'].float(), requires_grad=False).cuda()
|
22 |
+
# h0 = Variable(data['feats'].float()).cuda()
|
23 |
+
batch_num_nodes = data['num_nodes'].int().numpy()
|
24 |
+
M, _ = self.encoder(M, adj, batch_num_nodes)
|
25 |
+
# h = F.relu(self.c0(M))
|
26 |
+
# h = self.c1(h)
|
27 |
+
# h = h.view(y.shape[0], -1)
|
28 |
+
h = torch.cat((y, M), dim=1)
|
29 |
+
h = F.relu(self.l0(h))
|
30 |
+
h = F.relu(self.l1(h))
|
31 |
+
return self.l2(h)
|
32 |
+
|
33 |
+
class PriorDiscriminator(nn.Module):
|
34 |
+
def __init__(self, input_dim):
|
35 |
+
super().__init__()
|
36 |
+
self.l0 = nn.Linear(input_dim, input_dim)
|
37 |
+
self.l1 = nn.Linear(input_dim, input_dim)
|
38 |
+
self.l2 = nn.Linear(input_dim, 1)
|
39 |
+
|
40 |
+
def forward(self, x):
|
41 |
+
h = F.relu(self.l0(x))
|
42 |
+
h = F.relu(self.l1(h))
|
43 |
+
return torch.sigmoid(self.l2(h))
|
44 |
+
|
45 |
+
class FF(nn.Module):
|
46 |
+
def __init__(self, input_dim):
|
47 |
+
super().__init__()
|
48 |
+
# self.c0 = nn.Conv1d(input_dim, 512, kernel_size=1)
|
49 |
+
# self.c1 = nn.Conv1d(512, 512, kernel_size=1)
|
50 |
+
# self.c2 = nn.Conv1d(512, 1, kernel_size=1)
|
51 |
+
self.block = nn.Sequential(
|
52 |
+
nn.Linear(input_dim, input_dim),
|
53 |
+
nn.ReLU(),
|
54 |
+
nn.Linear(input_dim, input_dim),
|
55 |
+
nn.ReLU(),
|
56 |
+
nn.Linear(input_dim, input_dim),
|
57 |
+
nn.ReLU()
|
58 |
+
)
|
59 |
+
self.linear_shortcut = nn.Linear(input_dim, input_dim)
|
60 |
+
# self.c0 = nn.Conv1d(input_dim, 512, kernel_size=1, stride=1, padding=0)
|
61 |
+
# self.c1 = nn.Conv1d(512, 512, kernel_size=1, stride=1, padding=0)
|
62 |
+
# self.c2 = nn.Conv1d(512, 1, kernel_size=1, stride=1, padding=0)
|
63 |
+
|
64 |
+
def forward(self, x):
|
65 |
+
return self.block(x) + self.linear_shortcut(x)
|
66 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.25.0
|
2 |
+
scikit-learn==1.2.2
|
3 |
+
shapely==2.0.1
|
4 |
+
statistics
|
5 |
+
collection
|
6 |
+
torch==2.0.1
|
7 |
+
opencv-python==4.8.0.74
|
8 |
+
torch-geometric==2.4.0
|
utils/clean_data.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from collections import defaultdict
|
3 |
+
import shapely
|
4 |
+
import shapely.wkt
|
5 |
+
from shapely.geometry import LineString, MultiLineString, Polygon, Point, MultiPoint
|
6 |
+
from shapely.prepared import prep
|
7 |
+
|
8 |
+
|
9 |
+
def list_duplicates(seq):
|
10 |
+
tally = defaultdict(list)
|
11 |
+
for i,item in enumerate(seq):
|
12 |
+
tally[item].append(i)
|
13 |
+
return ((key,locs) for key,locs in tally.items() if len(locs)>1)
|
14 |
+
|
15 |
+
|
16 |
+
# clean wall and wkt
|
17 |
+
# wall = ((IOOI))
|
18 |
+
# wkt = 'POLYGON ((x0 y0, x1 y1, x2 y2, x3 y3, x0 y0))'
|
19 |
+
def read_wall_wkt(wall, wkt):
|
20 |
+
# wall to list
|
21 |
+
wall_l = wall.split("), (")[0]
|
22 |
+
wall_l = wall_l.split("((")[1]
|
23 |
+
wall_l = wall_l.split("))")[0]
|
24 |
+
wall_c = [*wall_l]
|
25 |
+
|
26 |
+
#clean wkt
|
27 |
+
wkt_l = wkt.split("((")[1]
|
28 |
+
wkt_l = wkt_l.split("))")[0]
|
29 |
+
wkt_l = wkt_l.split("), (")
|
30 |
+
|
31 |
+
if len(wkt_l) == 1:
|
32 |
+
wkt_c = wkt
|
33 |
+
else:
|
34 |
+
wkt_c = "POLYGON ((" + wkt_l[0] + "))"
|
35 |
+
|
36 |
+
wkt_c = wkt_c.split("((")[1]
|
37 |
+
wkt_c = wkt_c.split("))")[0]
|
38 |
+
wkt_c = wkt_c.split(", ")
|
39 |
+
|
40 |
+
# remove duplicate point
|
41 |
+
num_p = len(wkt_c) - 1
|
42 |
+
remove_index = []
|
43 |
+
for dup in sorted(list_duplicates(wkt_c)):
|
44 |
+
dup_index = dup[1]
|
45 |
+
if 0 in dup_index and num_p in dup_index and len(dup_index) == 2:
|
46 |
+
pass
|
47 |
+
|
48 |
+
elif 0 in dup_index and num_p in dup_index and len(dup_index) > 2:
|
49 |
+
dup_index_num = len(dup_index)-1
|
50 |
+
for j in range(1, dup_index_num):
|
51 |
+
ri = dup_index[j]
|
52 |
+
remove_hindex.append(ri)
|
53 |
+
|
54 |
+
else:
|
55 |
+
dup_index_num = len(dup_index)-1
|
56 |
+
for j in range(dup_index_num):
|
57 |
+
ri = dup_index[j]
|
58 |
+
remove_hindex.append(ri)
|
59 |
+
|
60 |
+
wall_f = []
|
61 |
+
wkt_f = []
|
62 |
+
for p in range(len(wkt_c)):
|
63 |
+
if p not in remove_index:
|
64 |
+
wkt_u = wkt_c[p]
|
65 |
+
wkt_f.append(wkt_u)
|
66 |
+
|
67 |
+
if p < (len(wkt_c)-1):
|
68 |
+
wall_u = wall_c[p]
|
69 |
+
wall_f.append(wall_u)
|
70 |
+
|
71 |
+
wkt_f = ", ".join(wkt_f)
|
72 |
+
wkt_f = "POLYGON ((" + wkt_f + "))"
|
73 |
+
|
74 |
+
return wall_f, wkt_f
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
def clean_geometry(wall, wkt):
|
79 |
+
# load geometry
|
80 |
+
geo = shapely.wkt.loads(wkt)
|
81 |
+
|
82 |
+
# move to (0,0)
|
83 |
+
geo_centroid = geo.centroid
|
84 |
+
translation_vector = (-geo_centroid.x, -geo_centroid.y)
|
85 |
+
moved_coords = [(x + translation_vector[0], y + translation_vector[1]) for x, y in geo.exterior.coords]
|
86 |
+
moved_geo = shapely.wkt.loads('POLYGON ((' + ', '.join([f'{x} {y}' for x, y in moved_coords]) + '))')
|
87 |
+
|
88 |
+
# if counterclockwise
|
89 |
+
if moved_geo.exterior.is_ccw:
|
90 |
+
geo_ccw = moved_geo
|
91 |
+
wall_ccw = wall
|
92 |
+
else:
|
93 |
+
geo_ccw = shapely.geometry.polygon.orient(moved_geo, 1)
|
94 |
+
|
95 |
+
walltypes = len(list(set(wall)))
|
96 |
+
if walltypes == 1:
|
97 |
+
wall_ccw = wall
|
98 |
+
else:
|
99 |
+
wall_ccw = wall[::-1]
|
100 |
+
|
101 |
+
|
102 |
+
# ccw_geo
|
103 |
+
coor_ccw = geo_ccw.exterior.coords
|
104 |
+
coor_ccw = list(coor_ccw)
|
105 |
+
coor_ccw = coor_ccw[:-1]
|
106 |
+
|
107 |
+
coor_ccw_num = len(coor_ccw)
|
108 |
+
coor_ccw_xpy_lst = []
|
109 |
+
for i in range(coor_ccw_num):
|
110 |
+
coor_ccw_x = coor_ccw[i][0]
|
111 |
+
coor_ccw_y = coor_ccw[i][1]
|
112 |
+
coor_ccw_xpy = coor_ccw_x + coor_ccw_y
|
113 |
+
coor_ccw_xpy_lst.append(coor_ccw_xpy)
|
114 |
+
|
115 |
+
coor_ccw_xpy_min_index = np.array(coor_ccw_xpy_lst).argmin()
|
116 |
+
coor_ccw_sort_index = []
|
117 |
+
for i in range(len(coor_ccw_xpy_lst)):
|
118 |
+
index_max = len(coor_ccw_xpy_lst) - 1 - coor_ccw_xpy_min_index
|
119 |
+
if i <= index_max:
|
120 |
+
sort_index = coor_ccw_xpy_min_index + i
|
121 |
+
else:
|
122 |
+
sort_index = i - len(coor_ccw_xpy_lst) + coor_ccw_xpy_min_index
|
123 |
+
coor_ccw_sort_index.append(sort_index)
|
124 |
+
|
125 |
+
|
126 |
+
coor_sort_lst = []
|
127 |
+
wall_sort_lst = []
|
128 |
+
for i in range(len(coor_ccw_sort_index)):
|
129 |
+
sort_index = coor_ccw_sort_index[i]
|
130 |
+
sort_coor = coor_ccw[sort_index]
|
131 |
+
sort_wall = wall_ccw[sort_index]
|
132 |
+
coor_sort_lst.append(sort_coor)
|
133 |
+
wall_sort_lst.append(sort_wall)
|
134 |
+
|
135 |
+
geo_s = Polygon(coor_sort_lst)
|
136 |
+
wall_s = wall_sort_lst
|
137 |
+
return wall_s, geo_s
|
138 |
+
|
139 |
+
|
140 |
+
def segments(polyline):
|
141 |
+
return list(map(LineString, zip(polyline.coords[:-1], polyline.coords[1:])))
|
142 |
+
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
def points4cv(x, y, xmin_abs, ymin_abs, scale):
|
147 |
+
points = []
|
148 |
+
for j in range(len(x)):
|
149 |
+
xp =x[j]
|
150 |
+
yp =y[j]
|
151 |
+
|
152 |
+
xp = (xp + xmin_abs +1) * scale
|
153 |
+
yp = (yp + ymin_abs +1) * scale
|
154 |
+
p = [int(xp), int(yp)]
|
155 |
+
points.append(p)
|
156 |
+
|
157 |
+
p_4_cv = np.array(points)
|
158 |
+
return p_4_cv
|
159 |
+
|
160 |
+
|
161 |
+
|
162 |
+
|
163 |
+
def gridpoints(apa_geo, size):
|
164 |
+
latmin, lonmin, latmax, lonmax = apa_geo.bounds
|
165 |
+
prep_moved_apa_geo = prep(apa_geo)
|
166 |
+
|
167 |
+
# construct a rectangular mesh
|
168 |
+
gp = []
|
169 |
+
for lat in np.arange(latmin, latmax, size):
|
170 |
+
for lon in np.arange(lonmin, lonmax, size):
|
171 |
+
gp.append(Point((round(lat,5), round(lon,5))))
|
172 |
+
gps = prep_moved_apa_geo.contains(gp)
|
173 |
+
gpf = [i for indx,i in enumerate(gp) if gps[indx] == True]
|
174 |
+
grid_points = MultiPoint(gpf)
|
175 |
+
return grid_points
|
176 |
+
|
177 |
+
|
178 |
+
|
179 |
+
def exterior_wall(apa_line, apa_wall):
|
180 |
+
apa_wall_O = [i for indx,i in enumerate(segments(apa_line)) if apa_wall[indx] == "O"]
|
181 |
+
apa_wall_O = MultiLineString(apa_wall_O)
|
182 |
+
return apa_wall_O
|
183 |
+
|
184 |
+
|
185 |
+
def geo_coor(apa_geo):
|
186 |
+
apa_coor = apa_geo.exterior.coords
|
187 |
+
apa_coor = list(apa_coor)
|
188 |
+
return apa_coor
|
189 |
+
|
190 |
+
|
191 |
+
|
192 |
+
|
utils/edge_features.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
import shapely
|
5 |
+
from shapely.geometry import LineString
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch_geometric.utils import to_undirected, remove_self_loops
|
9 |
+
from torch_geometric.data import Data
|
10 |
+
|
11 |
+
def segments(polyline):
|
12 |
+
return list(map(LineString, zip(polyline.coords[:-1], polyline.coords[1:])))
|
13 |
+
|
14 |
+
|
15 |
+
def edge_graph(apa_line, apa_wall):
|
16 |
+
wall_seg = segments(apa_line)
|
17 |
+
num_seg = len(wall_seg)
|
18 |
+
|
19 |
+
|
20 |
+
edge_lst = []
|
21 |
+
for l in range(num_seg):
|
22 |
+
seg = wall_seg[l]
|
23 |
+
seg_length = seg.length
|
24 |
+
seg_pro = apa_wall[l]
|
25 |
+
south_cos = wall_segment_cosine("south", seg)
|
26 |
+
east_cos = wall_segment_cosine("east", seg)
|
27 |
+
north_cos = wall_segment_cosine("north", seg)
|
28 |
+
west_cos = wall_segment_cosine("west", seg)
|
29 |
+
|
30 |
+
if south_cos < 0:
|
31 |
+
south_cos = 0
|
32 |
+
if east_cos < 0:
|
33 |
+
east_cos = 0
|
34 |
+
if north_cos < 0:
|
35 |
+
north_cos = 0
|
36 |
+
if west_cos < 0:
|
37 |
+
west_cos = 0
|
38 |
+
|
39 |
+
if seg_pro == "I":
|
40 |
+
south_cos = 0
|
41 |
+
east_cos = 0
|
42 |
+
north_cos = 0
|
43 |
+
west_cos = 0
|
44 |
+
|
45 |
+
if seg_pro == "O":
|
46 |
+
seg_boo = 1
|
47 |
+
else:
|
48 |
+
seg_boo = 0
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
edge = [seg_boo, seg_length, south_cos, north_cos, west_cos, east_cos]
|
53 |
+
edge_lst.append(edge)
|
54 |
+
|
55 |
+
|
56 |
+
ms_lst = []
|
57 |
+
me_lst = []
|
58 |
+
for k in range(num_seg):
|
59 |
+
if k == (num_seg - 1):
|
60 |
+
ms = k
|
61 |
+
me = 0
|
62 |
+
else:
|
63 |
+
ms = k
|
64 |
+
me = k+1
|
65 |
+
ms_lst.append(ms)
|
66 |
+
me_lst.append(me)
|
67 |
+
mse = [ms_lst, me_lst]
|
68 |
+
|
69 |
+
datasets = []
|
70 |
+
for i in range(2):
|
71 |
+
node_features = torch.FloatTensor(edge_lst)
|
72 |
+
x = node_features
|
73 |
+
|
74 |
+
edge_index = torch.tensor(mse, dtype=torch.long)
|
75 |
+
edge_index, _ = remove_self_loops(edge_index)
|
76 |
+
edge_index = to_undirected(edge_index=edge_index)
|
77 |
+
|
78 |
+
data = Data(x=x, edge_index=edge_index)
|
79 |
+
datasets.append(data)
|
80 |
+
return datasets
|
81 |
+
|
82 |
+
|
83 |
+
|
84 |
+
def wall_segment_cosine(direction, apa_line_seg):
|
85 |
+
seg_s = list(apa_line_seg.coords)[0]
|
86 |
+
seg_e = list(apa_line_seg.coords)[1]
|
87 |
+
|
88 |
+
normal_x = seg_e[0] - seg_s[0]
|
89 |
+
normal_y = seg_e[1] - seg_s[1]
|
90 |
+
|
91 |
+
normal_s = (-normal_y, normal_x)
|
92 |
+
normal_e = (normal_y, -normal_x)
|
93 |
+
|
94 |
+
o = np.array([-normal_y, normal_x])
|
95 |
+
w = np.array([normal_y, -normal_x])
|
96 |
+
|
97 |
+
if direction == "south":
|
98 |
+
d = np.array([-normal_y, normal_x-1])
|
99 |
+
if direction == "east":
|
100 |
+
d = np.array([-normal_y+1, normal_x])
|
101 |
+
if direction == "north":
|
102 |
+
d = np.array([-normal_y, normal_x+1])
|
103 |
+
if direction == "west":
|
104 |
+
d = np.array([-normal_y-1, normal_x])
|
105 |
+
|
106 |
+
od = d - o
|
107 |
+
ow = w - o
|
108 |
+
|
109 |
+
cosine = np.dot(od, ow) / (np.linalg.norm(od) * np.linalg.norm(ow))
|
110 |
+
return cosine
|
111 |
+
|
112 |
+
|
113 |
+
|
utils/emb_model/Edge_64.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:127b70aeaa5ebe3eadc2bdf26f0521bc87cd78db3454087c86f8b12eecc9980c
|
3 |
+
size 174179
|
utils/emb_model/Node_64.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c22fee85b8744349a7746eb61324f5259032fac224edb536aea06542a12cdd67
|
3 |
+
size 173681
|
utils/node_features.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import math
|
3 |
+
|
4 |
+
import shapely
|
5 |
+
from shapely.geometry import Polygon, LineString
|
6 |
+
|
7 |
+
import torch
|
8 |
+
from torch_geometric.utils import to_undirected, remove_self_loops
|
9 |
+
from torch_geometric.data import Data
|
10 |
+
|
11 |
+
|
12 |
+
def node_graph(apa_coor, apa_geo):
|
13 |
+
num_op = len(apa_coor)
|
14 |
+
apa_coor = apa_coor[0:-1]
|
15 |
+
# apa_coor.pop(num_op-1)
|
16 |
+
|
17 |
+
node_lst = []
|
18 |
+
num_p = len(apa_coor)
|
19 |
+
for j in range(num_p):
|
20 |
+
p = apa_coor[j]
|
21 |
+
if j == 0:
|
22 |
+
sindex = -1
|
23 |
+
oindex = j
|
24 |
+
eindex = 1
|
25 |
+
elif j == (len(apa_coor)-1):
|
26 |
+
sindex = j-1
|
27 |
+
oindex = j
|
28 |
+
eindex = 0
|
29 |
+
else:
|
30 |
+
sindex = j-1
|
31 |
+
oindex = j
|
32 |
+
eindex = j+1
|
33 |
+
|
34 |
+
sp = apa_coor[sindex]
|
35 |
+
s = np.array(sp)
|
36 |
+
|
37 |
+
op = apa_coor[oindex]
|
38 |
+
o = np.array(op)
|
39 |
+
ox = op[0]
|
40 |
+
oy = op[1]
|
41 |
+
|
42 |
+
ep = apa_coor[eindex]
|
43 |
+
e = np.array(ep)
|
44 |
+
|
45 |
+
Area = apa_geo.area
|
46 |
+
local_polygon = Polygon((sp, op, ep))
|
47 |
+
larea = (local_polygon.area) / Area
|
48 |
+
|
49 |
+
se = LineString((sp, ep))
|
50 |
+
llength = se.length / math.sqrt(Area)
|
51 |
+
|
52 |
+
osv = s - o
|
53 |
+
oev = e - o
|
54 |
+
|
55 |
+
langle = angle_between(osv, oev)
|
56 |
+
if langle < 0:
|
57 |
+
langle = langle + (2*math.pi)
|
58 |
+
|
59 |
+
|
60 |
+
oop = (0, 0)
|
61 |
+
oo = np.array(oop)
|
62 |
+
regional_polygon = Polygon((sp, oop, ep))
|
63 |
+
regional_polygon_area = regional_polygon.area
|
64 |
+
rarea = regional_polygon_area / Area
|
65 |
+
|
66 |
+
regional_polygon_perimeter = regional_polygon.length / 2
|
67 |
+
rperimeter = regional_polygon_perimeter / math.sqrt(Area)
|
68 |
+
|
69 |
+
rradius = (regional_polygon_area / regional_polygon_perimeter) / math.sqrt(Area)
|
70 |
+
|
71 |
+
oosv = s - oo
|
72 |
+
ooev = e - oo
|
73 |
+
rangle = angle_between(oosv, ooev)
|
74 |
+
if rangle < 0:
|
75 |
+
rangle = rangle + (2*math.pi)
|
76 |
+
|
77 |
+
#ox, oy,
|
78 |
+
nl = [larea, llength, langle, rarea, rperimeter, rradius, rangle]
|
79 |
+
|
80 |
+
node_lst.append(nl)
|
81 |
+
|
82 |
+
|
83 |
+
ms_lst = []
|
84 |
+
me_lst = []
|
85 |
+
for k in range(num_p):
|
86 |
+
if k == (num_p - 1):
|
87 |
+
ms = k
|
88 |
+
me = 0
|
89 |
+
else:
|
90 |
+
ms = k
|
91 |
+
me = k+1
|
92 |
+
ms_lst.append(ms)
|
93 |
+
me_lst.append(me)
|
94 |
+
mse = [ms_lst, me_lst]
|
95 |
+
|
96 |
+
datasets = []
|
97 |
+
for i in range(2):
|
98 |
+
node_f = torch.FloatTensor(node_lst)
|
99 |
+
x = node_f
|
100 |
+
|
101 |
+
edge_index = torch.tensor(mse, dtype=torch.long)
|
102 |
+
edge_index, _ = remove_self_loops(edge_index)
|
103 |
+
edge_index = to_undirected(edge_index=edge_index)
|
104 |
+
|
105 |
+
data = Data(x=x, edge_index=edge_index)
|
106 |
+
datasets.append(data)
|
107 |
+
return datasets
|
108 |
+
|
109 |
+
|
110 |
+
|
111 |
+
|
112 |
+
def angle_between(v1, v2):
|
113 |
+
""" Returns the angle in radians between vectors 'v1' and 'v2'
|
114 |
+
The sign of the angle is dependent on the order of v1 and v2
|
115 |
+
so acos(norm(dot(v1, v2))) does not work and atan2 has to be used, see:
|
116 |
+
https://stackoverflow.com/questions/21483999/using-atan2-to-find-angle-between-two-vectors
|
117 |
+
"""
|
118 |
+
arg1 = np.cross(v1, v2)
|
119 |
+
arg2 = np.dot(v1, v2)
|
120 |
+
angle = np.arctan2(arg1, arg2)
|
121 |
+
return angle
|
122 |
+
|
123 |
+
|
utils/shape_features.py
ADDED
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
from statistics import mean, stdev
|
5 |
+
from collections import defaultdict
|
6 |
+
|
7 |
+
import shapely
|
8 |
+
import shapely.wkt
|
9 |
+
from shapely.geometry import Point, MultiPoint, LineString, MultiLineString, Polygon, LinearRing
|
10 |
+
from shapely.ops import voronoi_diagram, substring, unary_union, nearest_points
|
11 |
+
from shapely import affinity
|
12 |
+
from shapely.prepared import prep
|
13 |
+
|
14 |
+
import cv2 as cv
|
15 |
+
|
16 |
+
|
17 |
+
def segments(polyline):
|
18 |
+
return list(map(LineString, zip(polyline.coords[:-1], polyline.coords[1:])))
|
19 |
+
|
20 |
+
|
21 |
+
def scale_move_x(x, xmin_abs, scale):
|
22 |
+
xn = (x / scale) - 1 - xmin_abs
|
23 |
+
return xn
|
24 |
+
|
25 |
+
def scale_move_y(y, ymin_abs, scale):
|
26 |
+
yn = (y / scale) - 1 - ymin_abs
|
27 |
+
return yn
|
28 |
+
|
29 |
+
def scale_area(a, scale):
|
30 |
+
a = a / (scale**2)
|
31 |
+
return a
|
32 |
+
|
33 |
+
def scale_perimeter(p, scale):
|
34 |
+
p = p / scale
|
35 |
+
return p
|
36 |
+
|
37 |
+
|
38 |
+
def wall_segment_cosine(direction, apa_line_seg):
|
39 |
+
seg_s = list(apa_line_seg.coords)[0]
|
40 |
+
seg_e = list(apa_line_seg.coords)[1]
|
41 |
+
|
42 |
+
normal_x = seg_e[0] - seg_s[0]
|
43 |
+
normal_y = seg_e[1] - seg_s[1]
|
44 |
+
|
45 |
+
normal_s = (-normal_y, normal_x)
|
46 |
+
normal_e = (normal_y, -normal_x)
|
47 |
+
|
48 |
+
o = np.array([-normal_y, normal_x])
|
49 |
+
w = np.array([normal_y, -normal_x])
|
50 |
+
|
51 |
+
if direction == "south":
|
52 |
+
d = np.array([-normal_y, normal_x-1])
|
53 |
+
if direction == "east":
|
54 |
+
d = np.array([-normal_y+1, normal_x])
|
55 |
+
if direction == "north":
|
56 |
+
d = np.array([-normal_y, normal_x+1])
|
57 |
+
if direction == "west":
|
58 |
+
d = np.array([-normal_y-1, normal_x])
|
59 |
+
|
60 |
+
od = d - o
|
61 |
+
ow = w - o
|
62 |
+
|
63 |
+
cosine = np.dot(od, ow) / (np.linalg.norm(od) * np.linalg.norm(ow))
|
64 |
+
return cosine
|
65 |
+
|
66 |
+
|
67 |
+
# Dir_S_longestedge, Dir_N_longestedge, Dir_W_longestedge, Dir_E_longestedge, Dir_S_max, Dir_N_max, Dir_W_max, Dir_E_max, Facade_length, Facade_ratio
|
68 |
+
def wall_direction_ratio(apa_line, apa_wall):
|
69 |
+
apa_wall_O = [i for indx,i in enumerate(segments(apa_line)) if apa_wall[indx] == "O"]
|
70 |
+
apa_wall_O = MultiLineString(apa_wall_O)
|
71 |
+
|
72 |
+
wall_O_length = []
|
73 |
+
wall_O_south = []
|
74 |
+
wall_O_east = []
|
75 |
+
wall_O_north = []
|
76 |
+
wall_O_west = []
|
77 |
+
apa_wall_O_num = len(apa_wall_O.geoms)
|
78 |
+
|
79 |
+
if apa_wall_O_num > 0:
|
80 |
+
for i in range(apa_wall_O_num):
|
81 |
+
wall_seg = apa_wall_O.geoms[i]
|
82 |
+
wall_length = wall_seg.length
|
83 |
+
south_cos = wall_segment_cosine("south", wall_seg)
|
84 |
+
east_cos = wall_segment_cosine("east", wall_seg)
|
85 |
+
north_cos = wall_segment_cosine("north", wall_seg)
|
86 |
+
west_cos = wall_segment_cosine("west", wall_seg)
|
87 |
+
|
88 |
+
if south_cos < 0:
|
89 |
+
south_cos = 0
|
90 |
+
if east_cos < 0:
|
91 |
+
east_cos = 0
|
92 |
+
if north_cos < 0:
|
93 |
+
north_cos = 0
|
94 |
+
if west_cos < 0:
|
95 |
+
west_cos = 0
|
96 |
+
|
97 |
+
wall_O_length.append(wall_length)
|
98 |
+
wall_O_south.append(south_cos)
|
99 |
+
wall_O_east.append(east_cos)
|
100 |
+
wall_O_north.append(north_cos)
|
101 |
+
wall_O_west.append(west_cos)
|
102 |
+
|
103 |
+
|
104 |
+
max_length_index = np.array(wall_O_length).argmax()
|
105 |
+
Dir_S_longestedge = wall_O_south[max_length_index]
|
106 |
+
Dir_N_longestedge = wall_O_north[max_length_index]
|
107 |
+
Dir_W_longestedge = wall_O_west[max_length_index]
|
108 |
+
Dir_E_longestedge = wall_O_east[max_length_index]
|
109 |
+
|
110 |
+
Dir_S_max = max(wall_O_south)
|
111 |
+
Dir_N_max = max(wall_O_north)
|
112 |
+
Dir_W_max = max(wall_O_west)
|
113 |
+
Dir_E_max = max(wall_O_east)
|
114 |
+
|
115 |
+
Facade_length = apa_wall_O.length
|
116 |
+
apa_line_length = apa_line.length
|
117 |
+
Facade_ratio = Facade_length / apa_line_length
|
118 |
+
else:
|
119 |
+
Dir_S_longestedge = 0
|
120 |
+
Dir_N_longestedge = 0
|
121 |
+
Dir_W_longestedge = 0
|
122 |
+
Dir_E_longestedge = 0
|
123 |
+
Dir_S_max = 0
|
124 |
+
Dir_N_max = 0
|
125 |
+
Dir_W_max = 0
|
126 |
+
Dir_E_max = 0
|
127 |
+
Facade_length = 0
|
128 |
+
Facade_ratio = 0
|
129 |
+
|
130 |
+
return Dir_S_longestedge, Dir_N_longestedge, Dir_W_longestedge, Dir_E_longestedge, Dir_S_max, Dir_N_max, Dir_W_max, Dir_E_max, Facade_length, Facade_ratio
|
131 |
+
|
132 |
+
|
133 |
+
# apa_geo
|
134 |
+
def apartment_perimeter(apa_geo):
|
135 |
+
perimeter =apa_geo.length
|
136 |
+
return perimeter
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
def apartment_area(apa_geo):
|
141 |
+
area =apa_geo.area
|
142 |
+
return area
|
143 |
+
|
144 |
+
|
145 |
+
def boundingbox(apa_geo):
|
146 |
+
boundingbox = apa_geo.bounds
|
147 |
+
return boundingbox
|
148 |
+
|
149 |
+
|
150 |
+
# BBox_width_x, BBox_height_y, Aspect_ratio, Extent, ULC_x, ULC_y, LRC_x, LRC_y
|
151 |
+
def boundingbox_features(apa_geo):
|
152 |
+
# [Aspect_ratio, Extent] ---> https://docs.opencv.org/3.4/d1/d32/tutorial_py_contour_properties.html
|
153 |
+
|
154 |
+
bbox_xy = boundingbox(apa_geo)
|
155 |
+
bbox_geo = Polygon([(bbox_xy[0], bbox_xy[1]), (bbox_xy[2], bbox_xy[1]), (bbox_xy[2], bbox_xy[3]), (bbox_xy[0], bbox_xy[3])])
|
156 |
+
|
157 |
+
BBox_width_x = bbox_xy[2] - bbox_xy[0]
|
158 |
+
BBox_height_y = bbox_xy[3] - bbox_xy[1]
|
159 |
+
Aspect_ratio = BBox_width_x / BBox_height_y
|
160 |
+
|
161 |
+
bbox_geo_area = bbox_geo.area
|
162 |
+
Area = apartment_area(apa_geo)
|
163 |
+
Extent = Area / bbox_geo_area
|
164 |
+
|
165 |
+
ULC_x = bbox_xy[0]
|
166 |
+
ULC_y = bbox_xy[3]
|
167 |
+
LRC_x = bbox_xy[2]
|
168 |
+
LRC_y = bbox_xy[1]
|
169 |
+
|
170 |
+
return BBox_width_x, BBox_height_y, Aspect_ratio, Extent, ULC_x, ULC_y, LRC_x, LRC_y
|
171 |
+
|
172 |
+
|
173 |
+
# Max_diameter
|
174 |
+
def max_diameter(apa_geo):
|
175 |
+
# [Max_diameter] ---> https://www.mvtec.com/doc/halcon/12/en/diameter_region.html
|
176 |
+
apa_coor = list(apa_geo.exterior.coords)
|
177 |
+
|
178 |
+
pp_dis_lst = []
|
179 |
+
for i in apa_coor:
|
180 |
+
for j in apa_coor:
|
181 |
+
pp_dis = Point(i).distance(Point(j))
|
182 |
+
pp_dis_lst.append(pp_dis)
|
183 |
+
|
184 |
+
max_diameter = max(pp_dis_lst)
|
185 |
+
return max_diameter
|
186 |
+
|
187 |
+
|
188 |
+
|
189 |
+
def fractality(apa_geo):
|
190 |
+
# [Fractality] ---> https://onlinelibrary.wiley.com/doi/epdf/10.1111/j.1538-4632.2000.tb00419.x
|
191 |
+
# Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
192 |
+
Area = apartment_area(apa_geo)
|
193 |
+
Perimeter = apartment_perimeter(apa_geo)
|
194 |
+
|
195 |
+
fractality = 1 - ((math.log(Area) / (2 * math.log(Perimeter))))
|
196 |
+
return fractality
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
def circularity(apa_geo):
|
201 |
+
# [Circularity] ---> https://www.mvtec.com/doc/halcon/12/en/circularity.html
|
202 |
+
apa_coor = list(apa_geo.exterior.coords)
|
203 |
+
op_dis_lst = []
|
204 |
+
for i in apa_coor:
|
205 |
+
op_dis = Point((0, 0)).distance(Point(i))
|
206 |
+
op_dis_lst.append(op_dis)
|
207 |
+
|
208 |
+
Max_radius = max(op_dis_lst)
|
209 |
+
|
210 |
+
Area = apartment_area(apa_geo)
|
211 |
+
|
212 |
+
circularity = Area / ((math.pi) * (Max_radius**2))
|
213 |
+
return circularity
|
214 |
+
|
215 |
+
|
216 |
+
|
217 |
+
def outer_radius(p_4_cv, xmin_abs, ymin_abs, scale):
|
218 |
+
# [Outer_radius] ---> https://docs.opencv.org/4.x/d3/dc0/group__imgproc__shape.html#ga8ce13c24081bbc7151e9326f412190f1
|
219 |
+
(xmin,ymin),radius = cv.minEnclosingCircle(p_4_cv)
|
220 |
+
mini_Enclosing_Cir_x = scale_move_x(xmin, xmin_abs, scale)
|
221 |
+
mini_Enclosing_Cir_y = scale_move_y(ymin, ymin_abs, scale)
|
222 |
+
|
223 |
+
mini_Enclosing_Cir_radius = scale_perimeter(radius, scale)
|
224 |
+
outer_radius = mini_Enclosing_Cir_radius
|
225 |
+
return outer_radius
|
226 |
+
|
227 |
+
|
228 |
+
|
229 |
+
def inner_radius(apa_geo, apa_line):
|
230 |
+
# [Inner_radius] ---> https://www.sthu.org/blog/14-skeleton-offset-topology/index.html
|
231 |
+
dis_p = []
|
232 |
+
for i in np.arange(0, apa_line.length, 0.1):
|
233 |
+
s = substring(apa_line, i, i+0.1)
|
234 |
+
dis_p.append(s.boundary.geoms[0])
|
235 |
+
mp = MultiPoint(dis_p)
|
236 |
+
|
237 |
+
regions = voronoi_diagram(mp)
|
238 |
+
|
239 |
+
vo_p = []
|
240 |
+
for i in range(len(regions.geoms)):
|
241 |
+
vo = regions.geoms[i]
|
242 |
+
b = list(vo.exterior.coords)
|
243 |
+
for j in range(len(b)):
|
244 |
+
p = Point(b[j])
|
245 |
+
vo_p.append(p)
|
246 |
+
vo_p = MultiPoint(vo_p)
|
247 |
+
vo_p = unary_union(vo_p)
|
248 |
+
vo_p_b = []
|
249 |
+
for i in range(len(vo_p.geoms)):
|
250 |
+
t_c_p = vo_p.geoms[i]
|
251 |
+
pc = apa_geo.contains(t_c_p)
|
252 |
+
vo_p_b.append(pc)
|
253 |
+
vo_filtered_p = [i for indx,i in enumerate(vo_p.geoms) if vo_p_b[indx] == True]
|
254 |
+
|
255 |
+
vo_d = []
|
256 |
+
for i in range(len(vo_filtered_p)):
|
257 |
+
c = Point(vo_filtered_p[i])
|
258 |
+
d_min = c.distance(apa_line)
|
259 |
+
vo_d.append(d_min)
|
260 |
+
|
261 |
+
vo_r_max = max(vo_d)
|
262 |
+
vo_r_max_index = vo_d.index(vo_r_max)
|
263 |
+
vo_c_max = vo_filtered_p[vo_r_max_index]
|
264 |
+
vo_c_max = list(vo_c_max.coords)
|
265 |
+
|
266 |
+
max_Inner_Circle_x = vo_c_max[0][0]
|
267 |
+
max_Inner_Circle_y = vo_c_max[0][1]
|
268 |
+
max_Inner_Circle_r = vo_r_max
|
269 |
+
inner_radius = max_Inner_Circle_r
|
270 |
+
return inner_radius
|
271 |
+
|
272 |
+
|
273 |
+
|
274 |
+
def roundness_features(apa_line):
|
275 |
+
# [Dist_mean, Dist_sigma, Roundness] ---> https://www.mvtec.com/doc/halcon/12/en/roundness.html
|
276 |
+
rou_p = []
|
277 |
+
for i in np.arange(0, apa_line.length, 0.5):
|
278 |
+
s = substring(apa_line, i, i+0.5)
|
279 |
+
rou_p.append(s.boundary.geoms[0])
|
280 |
+
rp = MultiPoint(rou_p)
|
281 |
+
|
282 |
+
ro_dis = []
|
283 |
+
for i in range(len(rp.geoms)):
|
284 |
+
rpp = rp.geoms[i]
|
285 |
+
ro = Point(rpp).distance(Point((0, 0)))
|
286 |
+
ro_dis.append(ro)
|
287 |
+
|
288 |
+
dist_mean = mean(ro_dis)
|
289 |
+
# dist_sigma = stdev(ro_dis)
|
290 |
+
|
291 |
+
dev_lst = []
|
292 |
+
for i in ro_dis:
|
293 |
+
dev = (i - dist_mean)**2
|
294 |
+
dev_lst.append(dev)
|
295 |
+
dist_sigma = mean(dev_lst)
|
296 |
+
dist_sigma = math.sqrt(dist_sigma)
|
297 |
+
roundness = 1 - (dist_sigma/dist_mean)
|
298 |
+
|
299 |
+
return dist_mean, dist_sigma, roundness
|
300 |
+
|
301 |
+
|
302 |
+
def compactness(apa_geo):
|
303 |
+
# [Compactness] ---> https://fisherzachary.github.io/public/r-output.html
|
304 |
+
Area = apartment_area(apa_geo)
|
305 |
+
Perimeter = apartment_perimeter(apa_geo)
|
306 |
+
|
307 |
+
compactness = (4*(math.pi)) * (Area / (Perimeter**2))
|
308 |
+
return compactness
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
def equivalent_diameter(apa_geo):
|
313 |
+
# https://docs.opencv.org/4.x/d1/d32/tutorial_py_contour_properties.html
|
314 |
+
Area = apartment_area(apa_geo)
|
315 |
+
|
316 |
+
equivalent_diameter = math.sqrt((4 * Area) / math.pi)
|
317 |
+
return equivalent_diameter
|
318 |
+
|
319 |
+
|
320 |
+
|
321 |
+
|
322 |
+
def shape_membership_index(apa_line):
|
323 |
+
# [Shape_membership_index] ---> Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
324 |
+
|
325 |
+
line_smi = LineString([(0, 0), (30, 0)])
|
326 |
+
|
327 |
+
numl = 30
|
328 |
+
line_rot_degree = 360 / numl
|
329 |
+
line_rot = []
|
330 |
+
for an in range(numl):
|
331 |
+
ang = an*line_rot_degree
|
332 |
+
lr = affinity.rotate(line_smi, ang, (0, 0))
|
333 |
+
line_rot.append(lr)
|
334 |
+
line_rot = MultiLineString(line_rot)
|
335 |
+
smip = shapely.intersection(apa_line, line_rot)
|
336 |
+
|
337 |
+
|
338 |
+
simo_dis = []
|
339 |
+
for i in range(len(smip.geoms)):
|
340 |
+
sim_p = smip.geoms[i]
|
341 |
+
simo = Point(sim_p).distance(Point((0, 0)))
|
342 |
+
simo_dis.append(simo)
|
343 |
+
sim_r_max = max(simo_dis)
|
344 |
+
|
345 |
+
simo_maxd = []
|
346 |
+
for j in simo_dis:
|
347 |
+
rmax_d = j / sim_r_max
|
348 |
+
simo_maxd.append(rmax_d)
|
349 |
+
|
350 |
+
simo_maxd_mean = mean(simo_maxd)
|
351 |
+
|
352 |
+
simo_rad = []
|
353 |
+
for j in range(len(simo_dis)):
|
354 |
+
s = simo_dis[j]
|
355 |
+
|
356 |
+
if j == (len(simo_dis) - 1):
|
357 |
+
nu = 0
|
358 |
+
else:
|
359 |
+
nu = j+1
|
360 |
+
e = simo_dis[nu]
|
361 |
+
|
362 |
+
if s <= e:
|
363 |
+
a = np.array([1,s])
|
364 |
+
b = np.array([0,s])
|
365 |
+
c = np.array([1,e])
|
366 |
+
else:
|
367 |
+
a = np.array([1,e])
|
368 |
+
b = np.array([0,e])
|
369 |
+
c = np.array([1,s])
|
370 |
+
|
371 |
+
ba = a - b
|
372 |
+
bc = c - b
|
373 |
+
|
374 |
+
cosine_angle = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
|
375 |
+
angle_rad = np.arccos(cosine_angle)
|
376 |
+
|
377 |
+
simo_rad.append(angle_rad)
|
378 |
+
|
379 |
+
simo_rad_min = min(simo_rad)
|
380 |
+
simo_rad_max = max(simo_rad)
|
381 |
+
simo_cos = math.cos(simo_rad_max - simo_rad_min)
|
382 |
+
shape_membership_index = simo_cos * simo_maxd_mean
|
383 |
+
return shape_membership_index
|
384 |
+
|
385 |
+
|
386 |
+
def convexity(p_4_cv, apa_geo, xmin_abs, ymin_abs, scale):
|
387 |
+
# [Convexity] ---> Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
388 |
+
|
389 |
+
hull = cv.convexHull(p_4_cv)
|
390 |
+
hull_x = []
|
391 |
+
hull_y = []
|
392 |
+
for h in range(len(hull)):
|
393 |
+
h_x = hull[h][0][0]
|
394 |
+
h_x = scale_move_x(h_x, xmin_abs, scale)
|
395 |
+
hull_x.append(h_x)
|
396 |
+
|
397 |
+
h_y = hull[h][0][1]
|
398 |
+
h_y = scale_move_y(h_y, ymin_abs, scale)
|
399 |
+
hull_y.append(h_y)
|
400 |
+
|
401 |
+
hull_xy = []
|
402 |
+
for i in range(len(hull_x)):
|
403 |
+
hx = hull_x[i]
|
404 |
+
hy = hull_y[i]
|
405 |
+
hull_xy.append((hx, hy))
|
406 |
+
hull_geo = Polygon(hull_xy)
|
407 |
+
Hull_area = hull_geo.area
|
408 |
+
|
409 |
+
Area = apartment_area(apa_geo)
|
410 |
+
convexity = Area / Hull_area
|
411 |
+
return convexity, hull_geo
|
412 |
+
|
413 |
+
|
414 |
+
|
415 |
+
def rectangle_features(p_4_cv, apa_geo, xmin_abs, ymin_abs, scale):
|
416 |
+
# [Rectangularity] ---> Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
417 |
+
rect = cv.minAreaRect(p_4_cv)
|
418 |
+
miniRect_rotation_angle = rect[2]
|
419 |
+
box = cv.boxPoints(rect)
|
420 |
+
box = np.intp(box)
|
421 |
+
|
422 |
+
miniRect_x = []
|
423 |
+
miniRect_y = []
|
424 |
+
for b in range(len(box)):
|
425 |
+
|
426 |
+
b_x = box[b][0]
|
427 |
+
b_x = scale_move_x(b_x, xmin_abs, scale)
|
428 |
+
miniRect_x.append(b_x)
|
429 |
+
|
430 |
+
b_y = box[b][1]
|
431 |
+
b_y = scale_move_y(b_y, ymin_abs, scale)
|
432 |
+
miniRect_y.append(b_y)
|
433 |
+
|
434 |
+
miniRec_xy = []
|
435 |
+
for i in range(len(miniRect_x)):
|
436 |
+
minirecx = miniRect_x[i]
|
437 |
+
minirecy = miniRect_y[i]
|
438 |
+
miniRec_xy.append((minirecx, minirecy))
|
439 |
+
miniRect_geo = Polygon(miniRec_xy)
|
440 |
+
miniRect_area = miniRect_geo.area
|
441 |
+
|
442 |
+
Area = apartment_area(apa_geo)
|
443 |
+
rectangularity = Area / miniRect_area
|
444 |
+
rect_phi = (miniRect_rotation_angle * math.pi) / 180
|
445 |
+
|
446 |
+
miniRect_line = miniRect_geo.boundary
|
447 |
+
miniRect_segments = segments(miniRect_line)
|
448 |
+
|
449 |
+
seg_len = []
|
450 |
+
for s in miniRect_segments:
|
451 |
+
seg_len.append(s.length)
|
452 |
+
rect_width = max(seg_len)
|
453 |
+
rect_height = min(seg_len)
|
454 |
+
return rectangularity, rect_phi, rect_width, rect_height
|
455 |
+
|
456 |
+
|
457 |
+
def squareness(apa_geo):
|
458 |
+
# [Squareness] ---> Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
459 |
+
|
460 |
+
Area = apartment_area(apa_geo)
|
461 |
+
Perimeter = apartment_perimeter(apa_geo)
|
462 |
+
|
463 |
+
squareness = (4*(math.sqrt(Area))) / Perimeter
|
464 |
+
return squareness
|
465 |
+
|
466 |
+
|
467 |
+
|
468 |
+
def moments(apa_geo):
|
469 |
+
# https://leancrew.com/all-this/2018/01/python-module-for-section-properties/
|
470 |
+
pts = list(apa_geo.exterior.coords)
|
471 |
+
|
472 |
+
if pts[0] != pts[-1]:
|
473 |
+
pts = pts + pts[:1]
|
474 |
+
x = [ c[0] for c in pts ]
|
475 |
+
y = [ c[1] for c in pts ]
|
476 |
+
sxx = syy = sxy = 0
|
477 |
+
a = apartment_area(apa_geo)
|
478 |
+
cx = apa_geo.centroid.x
|
479 |
+
cy = apa_geo.centroid.y
|
480 |
+
for i in range(len(pts) - 1):
|
481 |
+
sxx += (y[i]**2 + y[i]*y[i+1] + y[i+1]**2)*(x[i]*y[i+1] - x[i+1]*y[i])
|
482 |
+
syy += (x[i]**2 + x[i]*x[i+1] + x[i+1]**2)*(x[i]*y[i+1] - x[i+1]*y[i])
|
483 |
+
sxy += (x[i]*y[i+1] + 2*x[i]*y[i] + 2*x[i+1]*y[i+1] + x[i+1]*y[i])*(x[i]*y[i+1] - x[i+1]*y[i])
|
484 |
+
return sxx/12 - a*cy**2, syy/12 - a*cx**2, sxy/24 - a*cx*cy
|
485 |
+
|
486 |
+
|
487 |
+
def moment_index(apa_geo, Convexity, Compactness):
|
488 |
+
# https://www.researchgate.net/publication/228557311_A_COMBINED_AUTOMATED_GENERALIZATION_MODEL_BASED_ON_THE_RELATIVE_FORCES_BETWEEN_SPATIAL_OBJECTS
|
489 |
+
Ixx, Iyy, Ixy = moments(apa_geo)
|
490 |
+
ratio = max(Ixx, Iyy) / min(Ixx, Iyy)
|
491 |
+
# Convexity, Hull_geo = convexity(p_4_cv)
|
492 |
+
# Compactness = compactness(apa_geo)
|
493 |
+
moment_index = (Convexity * Compactness) / ratio
|
494 |
+
return moment_index
|
495 |
+
|
496 |
+
|
497 |
+
|
498 |
+
def ndetour_index(apa_geo, Hull_geo):
|
499 |
+
# [nDetour_index] ---> Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
500 |
+
|
501 |
+
Hull_line = Hull_geo.boundary
|
502 |
+
Hull_length = Hull_line.length
|
503 |
+
Area = apartment_area(apa_geo)
|
504 |
+
ndetour_index = (2 * math.sqrt(Area * math.pi)) / Hull_length
|
505 |
+
return ndetour_index
|
506 |
+
|
507 |
+
|
508 |
+
def ncohesion_index(apa_geo, grid_points):
|
509 |
+
# [nCohesion_index] ---> Basaraner, M. and Cetinkaya, S. (2017) ‘Performance of shape indices and classification schemes for characterising perceptual shape complexity of building footprints in GIS’, International Journal of Geographical Information Science, 31(10), pp. 1952–1977. doi:10.1080/13658816.2017.1346257.
|
510 |
+
|
511 |
+
grid_p = grid_points.geoms
|
512 |
+
grid_n = len(grid_p)
|
513 |
+
gg_dis_lst = []
|
514 |
+
for i in grid_p:
|
515 |
+
for j in grid_p:
|
516 |
+
gg_dis = Point(i).distance(Point(j))
|
517 |
+
gg_dis_lst.append(gg_dis)
|
518 |
+
|
519 |
+
Area = apartment_area(apa_geo)
|
520 |
+
ncohesion_index = (0.9054 * math.sqrt(Area / math.pi)) / (sum(gg_dis_lst) / (grid_n * (grid_n-1)))
|
521 |
+
return ncohesion_index
|
522 |
+
|
523 |
+
|
524 |
+
|
525 |
+
def nproximity_nspin_index(apa_geo, grid_points):
|
526 |
+
grid_p = grid_points.geoms
|
527 |
+
|
528 |
+
go_dis_lst = []
|
529 |
+
for i in grid_p:
|
530 |
+
go_dis = Point(i).distance(Point(0,0))
|
531 |
+
go_dis_lst.append(go_dis)
|
532 |
+
|
533 |
+
go_dis_mean = mean(go_dis_lst)
|
534 |
+
Area = apartment_area(apa_geo)
|
535 |
+
nproximity_index = ((2 / 3) * math.sqrt(Area / math.pi)) / go_dis_mean
|
536 |
+
|
537 |
+
nspin_index = (0.5 * (Area / math.pi)) / (go_dis_mean**2)
|
538 |
+
|
539 |
+
return nproximity_index, nspin_index
|
540 |
+
|
541 |
+
|
542 |
+
|
543 |
+
def nexchange_index(apa_geo):
|
544 |
+
Area = apartment_area(apa_geo)
|
545 |
+
|
546 |
+
eac_r = math.sqrt(Area / math.pi)
|
547 |
+
eac = Point(0,0).buffer(eac_r)
|
548 |
+
eac_inter = apa_geo.intersection(eac)
|
549 |
+
|
550 |
+
if eac_inter.geom_type == "Polygon":
|
551 |
+
eac_area = eac_inter.area
|
552 |
+
else:
|
553 |
+
eacga_lst = []
|
554 |
+
for i in range(len(eac_inter.geoms)):
|
555 |
+
eacg = eac_inter.geoms[i]
|
556 |
+
eacga = eacg.area
|
557 |
+
eacga_lst.append(eacga)
|
558 |
+
eac_area = sum(eacga_lst)
|
559 |
+
nexchange_index = eac_area / Area
|
560 |
+
return nexchange_index
|
561 |
+
|
562 |
+
|
563 |
+
|
564 |
+
def nperimeter_index(apa_geo):
|
565 |
+
Area = apartment_area(apa_geo)
|
566 |
+
Perimeter = apartment_perimeter(apa_geo)
|
567 |
+
|
568 |
+
nperimeter_index = (2 * math.sqrt(math.pi * Area)) / Perimeter
|
569 |
+
return nperimeter_index
|
570 |
+
|
571 |
+
|
572 |
+
|
573 |
+
def ndepth_index(apa_geo, apa_line, grid_points):
|
574 |
+
moved_apa_line = apa_line
|
575 |
+
grid_p = grid_points.geoms
|
576 |
+
|
577 |
+
nea_len_lst = []
|
578 |
+
for i in grid_p:
|
579 |
+
nea_line = LineString(nearest_points(moved_apa_line, i))
|
580 |
+
nea_len = nea_line.length
|
581 |
+
nea_len_lst.append(nea_len)
|
582 |
+
nea_len_mean = mean(nea_len_lst)
|
583 |
+
|
584 |
+
Area = apartment_area(apa_geo)
|
585 |
+
ndepth_index = (3 * nea_len_mean) / math.sqrt(Area / math.pi)
|
586 |
+
return ndepth_index
|
587 |
+
|
588 |
+
|
589 |
+
|
590 |
+
def ngirth_index(apa_geo, Inner_radius):
|
591 |
+
Area = apartment_area(apa_geo)
|
592 |
+
|
593 |
+
ngirth_index = Inner_radius / math.sqrt(Area / math.pi)
|
594 |
+
return ngirth_index
|
595 |
+
|
596 |
+
|
597 |
+
|
598 |
+
def nrange_index(apa_geo, Outer_radius):
|
599 |
+
Area = apartment_area(apa_geo)
|
600 |
+
|
601 |
+
nrange_index = math.sqrt(Area / math.pi) / Outer_radius
|
602 |
+
return nrange_index
|
603 |
+
|
604 |
+
|
605 |
+
|
606 |
+
def ntraversal_index(apa_geo, apa_line):
|
607 |
+
rou_p = []
|
608 |
+
for i in np.arange(0, apa_line.length, 0.5):
|
609 |
+
s = substring(apa_line, i, i+0.5)
|
610 |
+
rou_p.append(s.boundary.geoms[0])
|
611 |
+
rp = MultiPoint(rou_p)
|
612 |
+
|
613 |
+
rp_n = len(rp.geoms)
|
614 |
+
bb_dis_lst = []
|
615 |
+
for i in rp.geoms:
|
616 |
+
for j in rp.geoms:
|
617 |
+
bb_dis = Point(i).distance(Point(j))
|
618 |
+
bb_dis_lst.append(bb_dis)
|
619 |
+
bb_dis_mean = sum(bb_dis_lst) / (rp_n * (rp_n-1))
|
620 |
+
|
621 |
+
Area = apartment_area(apa_geo)
|
622 |
+
|
623 |
+
ntraversal_index = (4 * (math.sqrt(Area / math.pi) / math.pi)) / bb_dis_mean
|
624 |
+
return ntraversal_index
|
625 |
+
|