first_implementation_of_RNN
Browse files
matching_chains_classificationRNNS.py
ADDED
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
3 |
+
import statistics
|
4 |
+
import pandas as pd
|
5 |
+
from sklearn.ensemble import RandomForestClassifier
|
6 |
+
from sklearn.metrics import classification_report, confusion_matrix, \
|
7 |
+
accuracy_score, roc_auc_score, roc_curve, f1_score, recall_score, precision_score
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import copy
|
10 |
+
from sklearn import preprocessing, tree
|
11 |
+
from sklearn.linear_model import LogisticRegression, LinearRegression
|
12 |
+
from sklearn.tree import DecisionTreeClassifier
|
13 |
+
from scipy.spatial import distance
|
14 |
+
from sklearn.naive_bayes import GaussianNB
|
15 |
+
import itertools
|
16 |
+
import os
|
17 |
+
from sklearn.model_selection import train_test_split
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import random
|
20 |
+
from sklearn.utils import shuffle
|
21 |
+
from imblearn.under_sampling import NearMiss,TomekLinks
|
22 |
+
from imblearn.over_sampling import SMOTE
|
23 |
+
from collections import Counter
|
24 |
+
from imblearn.combine import SMOTETomek, SMOTEENN
|
25 |
+
from sklearn.model_selection import StratifiedKFold
|
26 |
+
from imblearn.pipeline import make_pipeline
|
27 |
+
|
28 |
+
from matplotlib import pyplot
|
29 |
+
from scipy import interp
|
30 |
+
from sklearn.metrics import roc_curve,auc
|
31 |
+
|
32 |
+
#keras
|
33 |
+
from keras.models import Sequential
|
34 |
+
from keras.layers import Dense, SimpleRNN, LSTM
|
35 |
+
|
36 |
+
# ------------------------- READ PER-TIMESTEP INPUTS --------------------------

def _timestep_files(directory):
    """Return the paths of the entries of *directory* in natural name order.

    ``os.listdir`` returns names in arbitrary order, but every list built
    below is indexed by timestep and the lists are later combined by
    position.  Sorting makes the order deterministic and identical across
    the input directories.  Digit runs are compared numerically so that
    e.g. ``step_2`` sorts before ``step_10``.

    NOTE(review): assumes the file names of the different directories sort
    into the same (chronological) timestep order -- confirm against the data.
    """
    import re  # local import keeps this block self-contained

    def natural_key(name):
        # re.split with a capture group alternates text / digit-run tokens,
        # so odd positions are always the digit runs.
        tokens = re.split(r'(\d+)', name)
        return [int(tok) if pos % 2 else tok for pos, tok in enumerate(tokens)]

    names = sorted(os.listdir(directory), key=natural_key)
    return [os.path.join(directory, name) for name in names]


# ComE node embeddings per timestep: one list of [id, emb...] records per file.
ComE_id_embs = [
    np.genfromtxt(emb_file, dtype=None).tolist()
    for emb_file in _timestep_files('ComE_per_timestep/embs')
]

# ComE predicted labels per timestep.
ComE_lbls = [
    np.genfromtxt(lbl_file, dtype=None).tolist()
    for lbl_file in _timestep_files('ComE_per_timestep/labels_pred')
]

# Node ids per timestep (first field of every embedding record).
node_ids = [[id_emb[0] for id_emb in step] for step in ComE_id_embs]

# {node_id: cluster label} per timestep; a label is paired with a node by its
# position in the timestep's files.
id_clr = [
    {node: ComE_lbls[step_idx][pos] for pos, node in enumerate(step_nodes)}
    for step_idx, step_nodes in enumerate(node_ids)
]

# {cluster label: [node ids]} per timestep, grouped in a single pass.
clustered_nodes_init = []
for node_to_clr in id_clr:
    members_by_clr = {}
    for node, clrid in node_to_clr.items():
        members_by_clr.setdefault(clrid, []).append(node)
    clustered_nodes_init.append(members_by_clr)

# Per timestep: the member lists ordered by cluster label, so a cluster is
# addressed by its position in that ordering from here on.
clustered_nodes = [
    [members for _, members in sorted(step_clusters.items())]
    for step_clusters in clustered_nodes_init
]

# ------------------------------ READ FEATURES -------------------------------

# ComE FEATURES

# Column names assigned to the ComE feature files (their header row is skipped).
_COME_COLUMNS = [
    'node_id',
    'distin_med_eucl', 'distin_med_cos', 'distin_med_l1',
    'distout_med_eucl', 'distout_med_cos', 'distout_med_l1',
    'distin_eucl_max', 'distin_eucl_min', 'distin_eucl_avg',
    'distin_cos_max', 'distin_cos_min', 'distin_cos_avg',
    'distin_l1_max', 'distin_l1_min', 'distin_l1_avg',
    'distout_eucl_max', 'distout_eucl_min', 'distout_eucl_avg',
    'distout_cos_max', 'distout_cos_min', 'distout_cos_avg',
    'distout_l1_max', 'distout_l1_min', 'distout_l1_avg',
    'dist_glob_max_eucl', 'dist_glob_min_eucl', 'dist_glob_avg_eucl',
    'dist_glob_max_cos', 'dist_glob_min_cos', 'dist_glob_avg_cos',
    'dist_glob_max_l1', 'dist_glob_min_l1', 'dist_glob_avg_l1',
]

# Euclidean max/min/avg column groups shared by the subsets below.
_COME_IN = ['distin_eucl_max', 'distin_eucl_min', 'distin_eucl_avg']
_COME_OUT = ['distout_eucl_max', 'distout_eucl_min', 'distout_eucl_avg']
_COME_GLB = ['dist_glob_max_eucl', 'dist_glob_min_eucl', 'dist_glob_avg_eucl']

id_ComE_feats_clr = []
id_ComE_feats_out = []
id_ComE_feats_gbl = []
id_ComE_feats_clrout = []
id_ComE_feats_clrgbl = []
id_ComE_feats_all = []

# NOTE(review): the 'gbl' subset uses 'distout_med_eucl' as its median column
# (the column list above has no global-median column) -- confirm this is
# intended and not a copy/paste of the 'out' subset.
_COME_SUBSETS = (
    (id_ComE_feats_clr, ['node_id', 'distin_med_eucl'] + _COME_IN),
    (id_ComE_feats_out, ['node_id', 'distout_med_eucl'] + _COME_OUT),
    (id_ComE_feats_gbl, ['node_id', 'distout_med_eucl'] + _COME_GLB),
    (id_ComE_feats_clrout,
     ['node_id', 'distin_med_eucl', 'distout_med_eucl'] + _COME_IN + _COME_OUT),
    (id_ComE_feats_clrgbl, ['node_id', 'distin_med_eucl'] + _COME_IN + _COME_GLB),
    (id_ComE_feats_all,
     ['node_id', 'distin_med_eucl', 'distout_med_eucl']
     + _COME_IN + _COME_OUT + _COME_GLB),
)

for feat_file in _timestep_files('ComE_features_per_timestep/'):
    df_ComE = pd.read_csv(feat_file, names=_COME_COLUMNS, skiprows=1)
    for per_step_rows, columns in _COME_SUBSETS:
        # rows are [node_id, feature...]; sort by node id (first column)
        per_step_rows.append(sorted(df_ComE[columns].values.tolist()))

# Classic FEATURES

# Column names assigned to the classic feature files (header row is skipped).
_CLASSIC_COLUMNS = [
    'node_id',
    'degree', 'betweenness', 'closeness', 'eigenvector',
    'degree_ntwk', 'betweenness_ntwk', 'closeness_ntwk', 'eigenvector_ntwk',
]

id_classic_clr = []
id_classic_gbl = []
id_classic_all = []
id_classic_nodeg = []

_CLASSIC_SUBSETS = (
    (id_classic_all, _CLASSIC_COLUMNS),
    (id_classic_clr,
     ['node_id', 'degree', 'betweenness', 'closeness', 'eigenvector']),
    (id_classic_gbl,
     ['node_id', 'degree_ntwk', 'betweenness_ntwk', 'closeness_ntwk',
      'eigenvector_ntwk']),
    # "no degree" variant: selected from the parsed frame.  Re-reading the
    # 9-column file with only 7 names would make pandas use the leading
    # columns as an index and attach the names to the wrong columns.
    (id_classic_nodeg,
     ['node_id', 'betweenness', 'closeness', 'eigenvector',
      'betweenness_ntwk', 'closeness_ntwk', 'eigenvector_ntwk']),
)

for feat_file in _timestep_files('classic_features_per_timestep/classic_features'):
    df_classic = pd.read_csv(feat_file, names=_CLASSIC_COLUMNS, skiprows=1)
    for per_step_rows, columns in _CLASSIC_SUBSETS:
        # rows are [node_id, feature...]; sort by node id (first column)
        per_step_rows.append(sorted(df_classic[columns].values.tolist()))
|
189 |
+
|
190 |
+
# Combined feature rows per timestep: each ComE "clrout" row followed by the
# classic features (node id dropped) of the row at the same position in
# id_classic_all.  Rows are paired by position, not by node id.
id_combo_ComE_clrout_classic_all = [
    [
        come_row[:] + id_classic_all[step_idx][row_idx][1:]
        for row_idx, come_row in enumerate(step_rows)
    ]
    for step_idx, step_rows in enumerate(id_ComE_feats_clrout)
]
|
198 |
+
|
199 |
+
#-------------------------------- MATCHING ------------------------------------
|
200 |
+
|
201 |
+
# [clr_x_tn, clr_y_tn+1, common_nodes_tn_tn+1]
|
202 |
+
#print(clustered_nodes[0])
|
203 |
+
# For every pair of consecutive timesteps, match each cluster of the earlier
# step to the cluster of the later step it shares the most nodes with.
# Entries are [cluster index at t, cluster index at t+1, number of common nodes].
matching = []
for curr_clusters, next_clusters in zip(clustered_nodes, clustered_nodes[1:]):
    step_matches = []
    for curr_idx, curr_nodes in enumerate(curr_clusters):
        candidates = [
            [curr_idx, next_idx, len(list(set(curr_nodes) & set(next_nodes)))]
            for next_idx, next_nodes in enumerate(next_clusters)
        ]
        best_overlap = max(cand[-1] for cand in candidates)
        # on ties the last candidate with the maximal overlap is kept
        best = [cand for cand in candidates if cand[-1] == best_overlap][-1]
        step_matches.append(best)
    matching.append(step_matches)
|
219 |
+
#print(matching,a)
|
220 |
+
|
221 |
+
|
222 |
+
|
223 |
+
|
224 |
+
#--------------------------------- CHAINS -------------------------------------
|
225 |
+
#************ SCD #Stay-Change-Drop
|
226 |
+
# 2-chain
|
227 |
+
def twoChain_scd(features):
    """Build the 2-chains with a Stay/Change/Drop label.

    For every transition in ``matching[:-1]`` and every node of a matched
    cluster at step ``t`` one record is produced:
    ``[node_id, <features of the node at step t>, label]`` where the label is

    * 0 (stay)   -- the node is in the matched cluster at step ``t+1``,
    * 1 (change) -- the node exists at step ``t+1`` but in another cluster,
    * 2 (drop)   -- the node is absent from step ``t+1``.

    Reads the module-level ``matching``, ``clustered_nodes`` and ``node_ids``.

    NOTE(review): the last transition in ``matching`` is skipped although its
    label could be computed here -- confirm this is intended.

    :param features: per timestep, a list of ``[node_id, feat...]`` rows.
    :return: per timestep, the list of records described above.
    """
    chains_per_step = []
    for step_idx, step_matches in enumerate(matching[:-1]):
        records = []
        for match in step_matches:
            curr_clr, next_clr = match[0], match[1]
            for node in clustered_nodes[step_idx][curr_clr]:
                record = [node]
                # append the features of every row that carries this node id
                for feat_row in features[step_idx]:
                    if feat_row[0] == node:
                        record += feat_row[1:]
                if node in clustered_nodes[step_idx + 1][next_clr]:
                    label = 0  # stay
                elif node in node_ids[step_idx + 1]:
                    label = 1  # move
                else:
                    label = 2  # drop
                record.append(label)
                records.append(record)
        chains_per_step.append(records)
    return chains_per_step
|
275 |
+
|
276 |
+
def chains_scd(prev_chain_scd, features, a):
    """Extend every chain of *prev_chain_scd* by one more (features, label) segment.

    Works on a deep copy of ``prev_chain_scd[:-1]``; the input is not modified.
    For the chains of step ``t`` the source step is ``t + 2 + a``:

    * rows whose last label is 0 or 1 get the node's features at the source
      step appended, followed by a new Stay(0)/Change(1)/Drop(2) label for
      the transition from the source step to the step after it;
    * all other rows get ``-1`` padding (one value per feature plus one for
      the event) inserted just before their last label, so the label stays
      last.

    Reads the module-level ``matching``, ``clustered_nodes`` and ``node_ids``.

    NOTE(review): a row whose node is in no cluster of ``matching`` at the
    source step gets features (if any) but no new label -- confirm.
    NOTE(review): with ``a == 0`` the appended segment is step ``t + 2``
    while ``twoChain_scd`` records step ``t``; confirm the offset is intended.

    :param prev_chain_scd: per-timestep chains of the previous length.
    :param features: per timestep, a list of ``[node_id, feat...]`` rows.
    :param a: extra offset added to the source step index.
    :return: the extended per-timestep chains.
    """
    extended = copy.deepcopy(prev_chain_scd[:-1])
    for step_idx, step_rows in enumerate(extended):
        src = step_idx + 2 + a
        for row in step_rows:
            if row[-1] not in (0, 1):
                # add -1 * (#feats + event) in front of the final label
                row[-1:-1] = [-1] * (len(features[0][0][1:]) + 1)
                continue
            for feat_row in features[src]:
                if feat_row[0] == row[0]:
                    row.extend(feat_row[1:])
            for match in matching[src]:
                if row[0] not in clustered_nodes[src][match[0]]:
                    continue
                if row[0] in clustered_nodes[src + 1][match[1]]:
                    row.append(0)  # stay
                elif row[0] in node_ids[src + 1]:
                    row.append(1)  # move
                else:
                    row.append(2)  # drop
                break
    return extended
|
298 |
+
|
299 |
+
# ----------------------------------------------------------------------------
|
300 |
+
|
301 |
+
#************ SL
|
302 |
+
def chains_sl(chainsSCD):  # Stay-Leave
    """Return a deep copy of *chainsSCD* relabelled as Stay(0)/Leave(1).

    Every row whose last element is 2 (drop) gets that element replaced by 1;
    all other rows are copied unchanged.  The input is not modified.
    """
    relabelled = copy.deepcopy(chainsSCD)
    for record in relabelled:
        if record[-1] != 2:
            continue
        record[-1] = 1
    return relabelled
|
308 |
+
|
309 |
+
# ----------------------------------------------------------------------------
|
310 |
+
|
311 |
+
#************ SC #Stay-Change
|
312 |
+
def chains_sc(chainsSCD):
    """Return the Stay/Change rows of *chainsSCD*.

    Rows whose last element is 2 (drop) are left out.  The kept rows are the
    same objects as in the input (no copy is made).
    """
    return [record for record in chainsSCD if record[-1] != 2]
|
318 |
+
|
319 |
+
# ----------------------------------------------------------------------------
|
320 |
+
|
321 |
+
def per_chain_all_chains_scd(feats):
    """Build the 2- to 9-chains (Stay/Change/Drop) for one feature set.

    The 2-chains come from ``twoChain_scd``; each longer chain is obtained by
    calling ``chains_scd`` on the previous one with offsets 0..6.  Every chain
    set is then flattened from per-timestep lists into one list of rows.

    :param feats: per timestep, a list of ``[node_id, feat...]`` rows.
    :return: a 9-tuple -- the flattened 2-, 3-, ..., 9-chain row lists followed
        by the concatenation of all eight.
    """
    # per-timestep (unflattened) chains; each level is built from the previous
    stepwise = [twoChain_scd(feats)]
    for offset in range(7):
        stepwise.append(chains_scd(stepwise[-1], feats, offset))

    # flatten each chain length over its timesteps
    flattened = [[row for step in chain for row in step] for chain in stepwise]

    # merge chains
    merged = [row for chain in flattened for row in chain]
    return tuple(flattened) + (merged,)
|
352 |
+
|
353 |
+
# ----------------------------------------------------------------------------
|
354 |
+
|
355 |
+
# CHAINS ----------------------
|
356 |
+
# ComE
|
357 |
+
|
358 |
+
# clr
|
359 |
+
def _sl_and_sc(chain_scd):
    """Return the (Stay-Leave, Stay-Change) variants of one SCD row list."""
    return chains_sl(chain_scd), chains_sc(chain_scd)


# ComE -- clr
(two_chain_ComE_clr_scd, three_chain_ComE_clr_scd, four_chain_ComE_clr_scd,
 five_chain_ComE_clr_scd, six_chain_ComE_clr_scd, seven_chain_ComE_clr_scd,
 eight_chain_ComE_clr_scd, nine_chain_ComE_clr_scd,
 chains_ComE_clr_scd) = per_chain_all_chains_scd(id_ComE_feats_clr)
# per chain
two_chain_ComE_clr_sl, two_chain_ComE_clr_sc = _sl_and_sc(two_chain_ComE_clr_scd)
three_chain_ComE_clr_sl, three_chain_ComE_clr_sc = _sl_and_sc(three_chain_ComE_clr_scd)
four_chain_ComE_clr_sl, four_chain_ComE_clr_sc = _sl_and_sc(four_chain_ComE_clr_scd)
five_chain_ComE_clr_sl, five_chain_ComE_clr_sc = _sl_and_sc(five_chain_ComE_clr_scd)
six_chain_ComE_clr_sl, six_chain_ComE_clr_sc = _sl_and_sc(six_chain_ComE_clr_scd)
seven_chain_ComE_clr_sl, seven_chain_ComE_clr_sc = _sl_and_sc(seven_chain_ComE_clr_scd)
eight_chain_ComE_clr_sl, eight_chain_ComE_clr_sc = _sl_and_sc(eight_chain_ComE_clr_scd)
nine_chain_ComE_clr_sl, nine_chain_ComE_clr_sc = _sl_and_sc(nine_chain_ComE_clr_scd)
# all chains: SL / SC
chains_ComE_clr_sl, chains_ComE_clr_sc = _sl_and_sc(chains_ComE_clr_scd)

# ComE -- out
(two_chain_ComE_out_scd, three_chain_ComE_out_scd, four_chain_ComE_out_scd,
 five_chain_ComE_out_scd, six_chain_ComE_out_scd, seven_chain_ComE_out_scd,
 eight_chain_ComE_out_scd, nine_chain_ComE_out_scd,
 chains_ComE_out_scd) = per_chain_all_chains_scd(id_ComE_feats_out)
# per chain
two_chain_ComE_out_sl, two_chain_ComE_out_sc = _sl_and_sc(two_chain_ComE_out_scd)
three_chain_ComE_out_sl, three_chain_ComE_out_sc = _sl_and_sc(three_chain_ComE_out_scd)
four_chain_ComE_out_sl, four_chain_ComE_out_sc = _sl_and_sc(four_chain_ComE_out_scd)
five_chain_ComE_out_sl, five_chain_ComE_out_sc = _sl_and_sc(five_chain_ComE_out_scd)
six_chain_ComE_out_sl, six_chain_ComE_out_sc = _sl_and_sc(six_chain_ComE_out_scd)
seven_chain_ComE_out_sl, seven_chain_ComE_out_sc = _sl_and_sc(seven_chain_ComE_out_scd)
eight_chain_ComE_out_sl, eight_chain_ComE_out_sc = _sl_and_sc(eight_chain_ComE_out_scd)
nine_chain_ComE_out_sl, nine_chain_ComE_out_sc = _sl_and_sc(nine_chain_ComE_out_scd)
# all chains: SL / SC
chains_ComE_out_sl, chains_ComE_out_sc = _sl_and_sc(chains_ComE_out_scd)

# ComE -- clrout
(two_chain_ComE_clrout_scd, three_chain_ComE_clrout_scd, four_chain_ComE_clrout_scd,
 five_chain_ComE_clrout_scd, six_chain_ComE_clrout_scd, seven_chain_ComE_clrout_scd,
 eight_chain_ComE_clrout_scd, nine_chain_ComE_clrout_scd,
 chains_ComE_clrout_scd) = per_chain_all_chains_scd(id_ComE_feats_clrout)
# per chain
two_chain_ComE_clrout_sl, two_chain_ComE_clrout_sc = _sl_and_sc(two_chain_ComE_clrout_scd)
three_chain_ComE_clrout_sl, three_chain_ComE_clrout_sc = _sl_and_sc(three_chain_ComE_clrout_scd)
four_chain_ComE_clrout_sl, four_chain_ComE_clrout_sc = _sl_and_sc(four_chain_ComE_clrout_scd)
five_chain_ComE_clrout_sl, five_chain_ComE_clrout_sc = _sl_and_sc(five_chain_ComE_clrout_scd)
six_chain_ComE_clrout_sl, six_chain_ComE_clrout_sc = _sl_and_sc(six_chain_ComE_clrout_scd)
seven_chain_ComE_clrout_sl, seven_chain_ComE_clrout_sc = _sl_and_sc(seven_chain_ComE_clrout_scd)
eight_chain_ComE_clrout_sl, eight_chain_ComE_clrout_sc = _sl_and_sc(eight_chain_ComE_clrout_scd)
nine_chain_ComE_clrout_sl, nine_chain_ComE_clrout_sc = _sl_and_sc(nine_chain_ComE_clrout_scd)
# all chains: SL / SC
chains_ComE_clrout_sl, chains_ComE_clrout_sc = _sl_and_sc(chains_ComE_clrout_scd)

# ----------------------------------------------------------------------------

# Classic -- clr
(two_chain_classic_clr_scd, three_chain_classic_clr_scd, four_chain_classic_clr_scd,
 five_chain_classic_clr_scd, six_chain_classic_clr_scd, seven_chain_classic_clr_scd,
 eight_chain_classic_clr_scd, nine_chain_classic_clr_scd,
 chains_classic_clr_scd) = per_chain_all_chains_scd(id_classic_clr)
# per chain
two_chain_classic_clr_sl, two_chain_classic_clr_sc = _sl_and_sc(two_chain_classic_clr_scd)
three_chain_classic_clr_sl, three_chain_classic_clr_sc = _sl_and_sc(three_chain_classic_clr_scd)
four_chain_classic_clr_sl, four_chain_classic_clr_sc = _sl_and_sc(four_chain_classic_clr_scd)
five_chain_classic_clr_sl, five_chain_classic_clr_sc = _sl_and_sc(five_chain_classic_clr_scd)
six_chain_classic_clr_sl, six_chain_classic_clr_sc = _sl_and_sc(six_chain_classic_clr_scd)
seven_chain_classic_clr_sl, seven_chain_classic_clr_sc = _sl_and_sc(seven_chain_classic_clr_scd)
eight_chain_classic_clr_sl, eight_chain_classic_clr_sc = _sl_and_sc(eight_chain_classic_clr_scd)
nine_chain_classic_clr_sl, nine_chain_classic_clr_sc = _sl_and_sc(nine_chain_classic_clr_scd)
# all chains: SL / SC
chains_classic_clr_sl, chains_classic_clr_sc = _sl_and_sc(chains_classic_clr_scd)

# Classic -- gbl
(two_chain_classic_gbl_scd, three_chain_classic_gbl_scd, four_chain_classic_gbl_scd,
 five_chain_classic_gbl_scd, six_chain_classic_gbl_scd, seven_chain_classic_gbl_scd,
 eight_chain_classic_gbl_scd, nine_chain_classic_gbl_scd,
 chains_classic_gbl_scd) = per_chain_all_chains_scd(id_classic_gbl)
# per chain
two_chain_classic_gbl_sl, two_chain_classic_gbl_sc = _sl_and_sc(two_chain_classic_gbl_scd)
three_chain_classic_gbl_sl, three_chain_classic_gbl_sc = _sl_and_sc(three_chain_classic_gbl_scd)
four_chain_classic_gbl_sl, four_chain_classic_gbl_sc = _sl_and_sc(four_chain_classic_gbl_scd)
five_chain_classic_gbl_sl, five_chain_classic_gbl_sc = _sl_and_sc(five_chain_classic_gbl_scd)
six_chain_classic_gbl_sl, six_chain_classic_gbl_sc = _sl_and_sc(six_chain_classic_gbl_scd)
seven_chain_classic_gbl_sl, seven_chain_classic_gbl_sc = _sl_and_sc(seven_chain_classic_gbl_scd)
eight_chain_classic_gbl_sl, eight_chain_classic_gbl_sc = _sl_and_sc(eight_chain_classic_gbl_scd)
nine_chain_classic_gbl_sl, nine_chain_classic_gbl_sc = _sl_and_sc(nine_chain_classic_gbl_scd)
# all chains: SL / SC
chains_classic_gbl_sl, chains_classic_gbl_sc = _sl_and_sc(chains_classic_gbl_scd)

# Classic -- all
(two_chain_classic_all_scd, three_chain_classic_all_scd, four_chain_classic_all_scd,
 five_chain_classic_all_scd, six_chain_classic_all_scd, seven_chain_classic_all_scd,
 eight_chain_classic_all_scd, nine_chain_classic_all_scd,
 chains_classic_all_scd) = per_chain_all_chains_scd(id_classic_all)
# per chain
two_chain_classic_all_sl, two_chain_classic_all_sc = _sl_and_sc(two_chain_classic_all_scd)
three_chain_classic_all_sl, three_chain_classic_all_sc = _sl_and_sc(three_chain_classic_all_scd)
four_chain_classic_all_sl, four_chain_classic_all_sc = _sl_and_sc(four_chain_classic_all_scd)
five_chain_classic_all_sl, five_chain_classic_all_sc = _sl_and_sc(five_chain_classic_all_scd)
six_chain_classic_all_sl, six_chain_classic_all_sc = _sl_and_sc(six_chain_classic_all_scd)
seven_chain_classic_all_sl, seven_chain_classic_all_sc = _sl_and_sc(seven_chain_classic_all_scd)
eight_chain_classic_all_sl, eight_chain_classic_all_sc = _sl_and_sc(eight_chain_classic_all_scd)
nine_chain_classic_all_sl, nine_chain_classic_all_sc = _sl_and_sc(nine_chain_classic_all_scd)
# all chains: SL / SC
chains_classic_all_sl, chains_classic_all_sc = _sl_and_sc(chains_classic_all_scd)
|
530 |
+
|
531 |
+
|
532 |
+
#RNN model
|
533 |
+
def create_RNN(hidden_units, dense_units, input_shape, activation):
    """Build and compile the recurrent classifier.

    The network is one LSTM layer followed by one Dense output layer,
    compiled with binary cross-entropy loss, the Adam optimizer and the
    accuracy metric.  The model summary is printed as a side effect.

    :param hidden_units: number of units of the LSTM layer.
    :param dense_units: number of units of the Dense output layer.
    :param input_shape: shape of one input sample, passed to the LSTM layer.
    :param activation: activation of the Dense output layer.
    :return: the compiled Keras ``Sequential`` model.
    """
    recurrent = LSTM(hidden_units, input_shape=input_shape)
    head = Dense(units=dense_units, activation=activation)
    model = Sequential()
    for layer in (recurrent, head):
        model.add(layer)
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy'])
    model.summary()
    return model
|
540 |
+
|
541 |
+
# stratified kfold
|
542 |
+
def Classification(chain):
    """Train and evaluate the recurrent model on *chain* with stratified 5-fold CV.

    Each row of *chain* is ``[node_id, value..., label]``: the first element
    is dropped, the last one is the target, and the values in between are
    fed to the network as a sequence of scalar steps.  Shorter rows are
    right-padded with -99 to the longest row.

    The model ends in a single sigmoid unit trained with binary
    cross-entropy, so the labels are expected to be binary (0/1), e.g. a
    Stay-Leave chain.

    Prints the class distribution, the per-fold score and, after all folds,
    the mean accuracy with its standard deviation.

    :param chain: list of chain rows as described above.
    """
    chain = shuffle(np.array(chain))
    print("chain2 = ", chain.shape)
    rows = chain.tolist()
    X = [row[1:-1] for row in rows]
    Y = [row[-1] for row in rows]

    # padding: bring every sequence to the length of the longest one
    longest = len(max(X, key=len))
    print(longest)
    for row in X:
        row.extend([-99] * (longest - len(row)))

    X = np.array(X)
    Y = np.array(Y)
    print('Y_dataset:', Counter(Y))

    # sequence length comes from the data instead of a hard-coded 4, so any
    # feature set / chain length can be classified
    n_steps = X.shape[1]

    skf = StratifiedKFold(n_splits=5)
    cvscores = []
    for train_index, test_index in skf.split(X, Y):
        X_train, X_test = X[train_index], X[test_index]
        y_train, y_test = Y[train_index], Y[test_index]

        y_train_cnt = pd.DataFrame([Counter(y_train)]).transpose()
        print(y_train_cnt)
        y_train_cnt.sort_index(inplace=True)
        print('y_train:', y_train_cnt)

        # the LSTM expects (samples, timesteps, features per timestep)
        X_train_3d = X_train.reshape((X_train.shape[0], n_steps, 1))
        X_test_3d = X_test.reshape((X_test.shape[0], n_steps, 1))

        # one sigmoid unit: matches the 1-D binary targets used with
        # binary_crossentropy (two units do not match the target shape)
        rnn_model = create_RNN(100, 1, input_shape=(n_steps, 1),
                               activation='sigmoid')

        rnn_model.fit(X_train_3d, y_train, epochs=50, batch_size=5, verbose=2)

        scores = rnn_model.evaluate(X_test_3d, y_test, verbose=0)
        print("score", scores)
        # scores[0] is the loss, scores[1] the accuracy metric
        print("%s: %.2f%%" % (rnn_model.metrics_names[1], scores[1]*100))
        cvscores.append(scores[1] * 100)

    print("%.2f%% (+/- %.2f%%)" % (np.mean(cvscores), np.std(cvscores)))
|
590 |
+
|
591 |
+
|
592 |
+
|
593 |
+
|
594 |
+
# Classic features
|
595 |
+
# clr - 1 chain
|
596 |
+
# Append the experiment name to both result logs, then run the experiment
# on the Stay-Leave 2-chains built from the classic "clr" features.
for results_path in ('results_details.csv', 'results.csv'):
    with open(results_path, 'a') as fd:
        fd.write('two_chain_classic_clr_sl' + '\n')
Classification(two_chain_classic_clr_sl)
|
601 |
+
|
602 |
+
|