yjwtheonly commited on
Commit
ac7c391
1 Parent(s): bdc453c
DiseaseSpecific/KG_extractor.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import torch
3
+ import numpy as np
4
+ from torch.autograd import Variable
5
+ from sklearn import metrics
6
+
7
+ import datetime
8
+ from typing import Dict, Tuple, List
9
+ import logging
10
+ import os
11
+ import utils
12
+ import pickle as pkl
13
+ import json
14
+ import torch.backends.cudnn as cudnn
15
+
16
+ from tqdm import tqdm
17
+
18
+ import sys
19
+ sys.path.append("..")
20
+ import Parameters
21
+
22
+ parser = utils.get_argument_parser()
23
+ parser = utils.add_attack_parameters(parser)
24
+ parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART')
25
+ parser.add_argument('--action', type=str, default='parse', help='parse or extract')
26
+ parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
27
+ args = parser.parse_args()
28
+ args = utils.set_hyperparams(args)
29
+
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+
32
+ utils.seed_all(args.seed)
33
+ np.set_printoptions(precision=5)
34
+ cudnn.benchmark = False
35
+
36
+ data_path = os.path.join('processed_data', args.data)
37
+ target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
38
+ attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}.txt'.format(args.model,
39
+ args.target_split,
40
+ args.target_size,
41
+ 'exists:'+str(args.target_existed),
42
+ args.neighbor_num,
43
+ args.candidate_mode,
44
+ args.attack_goal,
45
+ str(args.reasonable_rate)))
46
+ modified_attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}.txt'.format(args.model,
47
+ args.target_split,
48
+ args.target_size,
49
+ 'exists:'+str(args.target_existed),
50
+ args.neighbor_num,
51
+ args.candidate_mode,
52
+ args.attack_goal,
53
+ str(args.reasonable_rate),
54
+ args.mode))
55
+ attack_data = utils.load_data(attack_path, drop=False)
56
+ #%%
57
+ with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
58
+ id_to_meshid = json.load(fl)
59
+ with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
60
+ meshid_to_id = json.load(fl)
61
+ with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
62
+ entity_raw_name = pkl.load(fl)
63
+ with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
64
+ retieve_sentence_through_edgetype = pkl.load(fl)
65
+ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
66
+ raw_text_sen = pkl.load(fl)
67
+ with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
68
+ full_entity_raw_name = pkl.load(fl)
69
+ for k, v in entity_raw_name.items():
70
+ assert v in full_entity_raw_name[k]
71
+
72
+ #find unique
73
+ once_set = set()
74
+ twice_set = set()
75
+
76
+ with open('generate_abstract/valid_entity.json', 'r') as fl:
77
+ valid_entity = json.load(fl)
78
+ valid_entity = set(valid_entity)
79
+
80
+ good_name = set()
81
+ for k, v, in full_entity_raw_name.items():
82
+ names = list(v)
83
+ for name in names:
84
+ # if name == 'in a':
85
+ # print(names)
86
+ good_name.add(name)
87
+ # if name not in once_set:
88
+ # once_set.add(name)
89
+ # else:
90
+ # twice_set.add(name)
91
+ # assert 'WNK4' in once_set
92
+ # good_name = set.difference(once_set, twice_set)
93
+ # assert 'in a' not in good_name
94
+ # assert 'STE20' not in good_name
95
+ # assert 'STE20' not in valid_entity
96
+ # assert 'STE20-related proline-alanine-rich kinase' not in good_name
97
+ # assert 'STE20-related proline-alanine-rich kinase' not in valid_entity
98
+ # raise Exception
99
+
100
+ name_to_type = {}
101
+ name_to_meshid = {}
102
+
103
+ for k, v, in full_entity_raw_name.items():
104
+ names = list(v)
105
+ for name in names:
106
+ if name in good_name:
107
+ name_to_type[name] = k.split('_')[0]
108
+ name_to_meshid[name] = k
109
+
110
+ import spacy
111
+ import networkx as nx
112
+ import pprint
113
+
114
def check(p, s):
    """Return True if position ``p`` in string ``s`` is a word boundary.

    A position counts as a boundary when it lies outside [1, len(s)) or when
    the character there is not an ASCII letter or digit.
    """
    if p < 1 or p >= len(s):
        return True
    ch = s[p]
    # ASCII alphanumerics only -- isascii() gates out Unicode letters/digits,
    # matching the original explicit range comparisons.
    return not (ch.isascii() and ch.isalnum())
119
+
120
def raw_to_format(sen):
    """Rewrite a raw sentence, joining every known entity mention with underscores.

    Scans left to right; at each non-space position it tries the longest
    substring first (longest-match-wins) against the module-level ``good_name``
    and ``valid_entity`` sets, requiring word boundaries on both sides via
    ``check``. Matched mentions have their inner spaces replaced by ``_``;
    everything else is copied through unchanged.
    """
    text = sen
    pieces = []
    pos = 0
    total = len(text)
    while pos < total:
        matched = False
        if text[pos] != ' ':
            # Longest candidate first -- the reversed scan is essential so a
            # multi-word entity beats any of its prefixes.
            for end in range(total, pos, -1):
                cand = text[pos:end]
                if (cand in good_name or cand in valid_entity) \
                        and check(pos - 1, text) and check(end, text):
                    pieces.append(cand.replace(' ', '_'))
                    pos = end
                    matched = True
                    break
        if not matched:
            pieces.append(text[pos])
            pos += 1
    return ''.join(pieces)
139
+
140
+ if args.mode == 'sentence':
141
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_chat.json', 'r') as fl:
142
+ draft = json.load(fl)
143
+ elif args.mode == 'finetune':
144
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence_finetune.json', 'r') as fl:
145
+ draft = json.load(fl)
146
+ elif args.mode == 'bioBART':
147
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'r') as fl:
148
+ draft = json.load(fl)
149
+ elif args.mode == 'biogpt':
150
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_biogpt.json', 'r') as fl:
151
+ draft = json.load(fl)
152
+ else:
153
+ raise Exception('No!!!')
154
+
155
+ nlp = spacy.load("en_core_web_sm")
156
+
157
+ type_set = set()
158
+ for aa in range(36):
159
+ dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
160
+ tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
161
+ dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
162
+ for dependency in dependencys:
163
+ dep_list = dependency.split(' ')
164
+ for sub_dep in dep_list:
165
+ sub_dep_list = sub_dep.split('|')
166
+ assert(len(sub_dep_list) == 3)
167
+ type_set.add(sub_dep_list[1])
168
+ # print('Type:', type_set)
169
+
170
+ if args.action == 'parse':
171
+ # dp_path, sen_list = list(dependency_sen_dict.items())[0]
172
+ # check
173
+ # paper_id, sen_id = sen_list[0]
174
+ # sen = raw_text_sen[paper_id][sen_id]
175
+ # doc = nlp(sen['text'])
176
+ # print(dp_path, '\n')
177
+ # pprint.pprint(sen)
178
+ # print()
179
+ # for token in doc:
180
+ # print((token.head.text, token.text, token.dep_))
181
+
182
+ out = ''
183
+ for k, v_dict in draft.items():
184
+ input = v_dict['in']
185
+ output = v_dict['out']
186
+ if input == '':
187
+ continue
188
+ output = output.replace('\n', ' ')
189
+ doc = nlp(output)
190
+ for sen in doc.sents:
191
+ out += raw_to_format(sen.text) + '\n'
192
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
193
+ fl.write(out)
194
+ elif args.action == 'extract':
195
+
196
+ # dependency_to_type_id = {}
197
+ # for k, v in Parameters.edge_type_to_id.items():
198
+ # dependency_to_type_id[k] = {}
199
+ # for type in v:
200
+ # LL = list(retieve_sentence_through_edgetype[type]['manual'].keys()) + list(retieve_sentence_through_edgetype[type]['auto'].keys())
201
+ # for dp in LL:
202
+ # dependency_to_type_id[k][dp] = type
203
+ if os.path.exists('generate_abstract/dependency_to_type_id.pickle'):
204
+ with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl:
205
+ dependency_to_type_id = pkl.load(fl)
206
+ else:
207
+ dependency_to_type_id = {}
208
+ print('Loading path data ...')
209
+ for k in Parameters.edge_type_to_id.keys():
210
+ start, end = k.split('-')
211
+ dependency_to_type_id[k] = {}
212
+ inner_edge_type_to_id = Parameters.edge_type_to_id[k]
213
+ inner_edge_type_dict = Parameters.edge_type_dict[k]
214
+ cal_manual_num = [0] * len(inner_edge_type_to_id)
215
+ with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl:
216
+ for i, line in tqdm(list(enumerate(fl.readlines()))):
217
+ tmp = line.split('\t')
218
+ if i == 0:
219
+ head = [tmp[i] for i in range(1, len(tmp), 2)]
220
+ assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
221
+ continue
222
+ probability = [float(tmp[i]) for i in range(1, len(tmp), 2)]
223
+ flag_list = [int(tmp[i]) for i in range(2, len(tmp), 2)]
224
+ indices = np.where(np.asarray(flag_list) == 1)[0]
225
+ if len(indices) >= 1:
226
+ tmp_p = [cal_manual_num[i] for i in indices]
227
+ p = indices[np.argmin(tmp_p)]
228
+ cal_manual_num[p] += 1
229
+ else:
230
+ p = np.argmax(probability)
231
+ assert tmp[0].lower() not in dependency_to_type_id.keys()
232
+ dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p]
233
+ with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl:
234
+ pkl.dump(dependency_to_type_id, fl)
235
+
236
+ # record = []
237
+ # with open(f'generate_abstract/par_parseout.txt', 'r') as fl:
238
+ # Tmp = []
239
+ # tmp = []
240
+ # for i,line in enumerate(fl.readlines()):
241
+ # # print(len(line), line)
242
+ # line = line.replace('\n', '')
243
+ # if len(line) > 1:
244
+ # tmp.append(line)
245
+ # else:
246
+ # Tmp.append(tmp)
247
+ # tmp = []
248
+ # if len(Tmp) == 3:
249
+ # record.append(Tmp)
250
+ # Tmp = []
251
+
252
+ # print(len(record))
253
+ # record_index = 0
254
+ # add = 0
255
+ # Attack = []
256
+ # for ii in range(100):
257
+
258
+ # # input = v_dict['in']
259
+ # # output = v_dict['out']
260
+ # # output = output.replace('\n', ' ')
261
+ # s, r, o = attack_data[ii]
262
+ # dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
263
+
264
+ # target_dp = set()
265
+ # for dp_path, sen_list in dependency_sen_dict.items():
266
+ # target_dp.add(dp_path)
267
+ # DP_list = []
268
+ # for _ in range(1):
269
+ # dp_dict = {}
270
+ # data = record[record_index]
271
+ # record_index += 1
272
+ # dp_paths = data[2]
273
+ # nodes_list = []
274
+ # edges_list = []
275
+ # for line in dp_paths:
276
+ # ttp, tmp = line.split('(')
277
+ # assert tmp[-1] == ')'
278
+ # tmp = tmp[:-1]
279
+ # e1, e2 = tmp.split(', ')
280
+ # if not ttp in type_set and ':' in ttp:
281
+ # ttp = ttp.split(':')[0]
282
+ # dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
283
+ # dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
284
+ # nodes_list.append(e1)
285
+ # nodes_list.append(e2)
286
+ # edges_list.append((e1, e2))
287
+ # nodes_list = list(set(nodes_list))
288
+ # pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
289
+ # graph = nx.Graph(edges_list)
290
+
291
+ # type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
292
+ # # print(type_list)
293
+ # # for i in range(len(type_list)):
294
+ # # print(pure_name[i], type_list[i])
295
+ # for i in range(len(nodes_list)):
296
+ # if type_list[i] != '':
297
+ # for j in range(len(nodes_list)):
298
+ # if i != j and type_list[j] != '':
299
+ # if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
300
+ # # print(f'{type_list[i]}_{type_list[j]}')
301
+ # ret_path = []
302
+ # sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
303
+ # start = sp[0]
304
+ # end = sp[-1]
305
+ # for k in range(len(sp)-1):
306
+ # e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
307
+ # if e1 == start:
308
+ # e1 = 'start_entity-x'
309
+ # if e2 == start:
310
+ # e2 = 'start_entity-x'
311
+ # if e1 == end:
312
+ # e1 = 'end_entity-x'
313
+ # if e2 == end:
314
+ # e2 = 'end_entity-x'
315
+ # ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
316
+ # dependency_P = ' '.join(ret_path)
317
+ # DP_list.append((f'{type_list[i]}-{type_list[j]}',
318
+ # name_to_meshid[pure_name[i]],
319
+ # name_to_meshid[pure_name[j]],
320
+ # dependency_P))
321
+
322
+ # boo = False
323
+ # modified_attack = []
324
+ # for k, ss, tt, dp in DP_list:
325
+ # if dp in dependency_to_type_id[k].keys():
326
+ # tp = str(dependency_to_type_id[k][dp])
327
+ # id_ss = str(meshid_to_id[ss])
328
+ # id_tt = str(meshid_to_id[tt])
329
+ # modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
330
+ # if int(dependency_to_type_id[k][dp]) == int(r):
331
+ # # if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
332
+ # boo = True
333
+ # modified_attack = list(set(modified_attack))
334
+ # modified_attack = [k.split('*') for k in modified_attack]
335
+ # if boo:
336
+ # add += 1
337
+ # # else:
338
+ # # print(ii)
339
+
340
+ # # for i in range(len(type_list)):
341
+ # # if type_list[i]:
342
+ # # print(pure_name[i], type_list[i])
343
+ # # for k, ss, tt, dp in DP_list:
344
+ # # print(k, dp)
345
+ # # print(record[record_index - 1])
346
+ # # raise Exception('No!!')
347
+ # Attack.append(modified_attack)
348
+
349
+ record = []
350
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl:
351
+ Tmp = []
352
+ tmp = []
353
+ for i,line in enumerate(fl.readlines()):
354
+ # print(len(line), line)
355
+ line = line.replace('\n', '')
356
+ if len(line) > 1:
357
+ tmp.append(line)
358
+ else:
359
+ if len(Tmp) == 2:
360
+ if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]:
361
+ Tmp.append([])
362
+ record.append(Tmp)
363
+ Tmp = []
364
+ Tmp.append(tmp)
365
+ if len(Tmp) == 2 and tmp[0][:5] != '(ROOT':
366
+ print(record[-1][2])
367
+ raise Exception('??')
368
+ tmp = []
369
+ if len(Tmp) == 3:
370
+ record.append(Tmp)
371
+ Tmp = []
372
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
373
+ parsin = fl.readlines()
374
+
375
+ print('Record len', len(record), 'Parsin len:', len(parsin))
376
+ record_index = 0
377
+ add = 0
378
+
379
+ Attack = []
380
+ for ii, (k, v_dict) in enumerate(tqdm(draft.items())):
381
+
382
+ input = v_dict['in']
383
+ output = v_dict['out']
384
+ output = output.replace('\n', ' ')
385
+ s, r, o = attack_data[ii]
386
+ assert ii == int(k.split('_')[-1])
387
+
388
+ DP_list = []
389
+ if input != '':
390
+
391
+ dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
392
+ target_dp = set()
393
+ for dp_path, sen_list in dependency_sen_dict.items():
394
+ target_dp.add(dp_path)
395
+ doc = nlp(output)
396
+
397
+ for sen in doc.sents:
398
+ dp_dict = {}
399
+ if record_index >= len(record):
400
+ break
401
+ data = record[record_index]
402
+ record_index += 1
403
+ dp_paths = data[2]
404
+ nodes_list = []
405
+ edges_list = []
406
+ for line in dp_paths:
407
+ aa = line.split('(')
408
+ if len(aa) == 1:
409
+ print(ii)
410
+ print(sen)
411
+ print(data)
412
+ raise Exception
413
+ ttp, tmp = aa[0], aa[1]
414
+ assert tmp[-1] == ')'
415
+ tmp = tmp[:-1]
416
+ e1, e2 = tmp.split(', ')
417
+ if not ttp in type_set and ':' in ttp:
418
+ ttp = ttp.split(':')[0]
419
+ dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
420
+ dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
421
+ nodes_list.append(e1)
422
+ nodes_list.append(e2)
423
+ edges_list.append((e1, e2))
424
+ nodes_list = list(set(nodes_list))
425
+ pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
426
+ graph = nx.Graph(edges_list)
427
+
428
+ type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
429
+ # print(type_list)
430
+ for i in range(len(nodes_list)):
431
+ if type_list[i] != '':
432
+ for j in range(len(nodes_list)):
433
+ if i != j and type_list[j] != '':
434
+ if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
435
+ # print(f'{type_list[i]}_{type_list[j]}')
436
+ ret_path = []
437
+ sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
438
+ start = sp[0]
439
+ end = sp[-1]
440
+ for k in range(len(sp)-1):
441
+ e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
442
+ if e1 == start:
443
+ e1 = 'start_entity-x'
444
+ if e2 == start:
445
+ e2 = 'start_entity-x'
446
+ if e1 == end:
447
+ e1 = 'end_entity-x'
448
+ if e2 == end:
449
+ e2 = 'end_entity-x'
450
+ ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
451
+ dependency_P = ' '.join(ret_path)
452
+ DP_list.append((f'{type_list[i]}-{type_list[j]}',
453
+ name_to_meshid[pure_name[i]],
454
+ name_to_meshid[pure_name[j]],
455
+ dependency_P))
456
+
457
+ boo = False
458
+ modified_attack = []
459
+ for k, ss, tt, dp in DP_list:
460
+ if dp in dependency_to_type_id[k].keys():
461
+ tp = str(dependency_to_type_id[k][dp])
462
+ id_ss = str(meshid_to_id[ss])
463
+ id_tt = str(meshid_to_id[tt])
464
+ modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
465
+ if int(dependency_to_type_id[k][dp]) == int(r):
466
+ if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
467
+ boo = True
468
+ modified_attack = list(set(modified_attack))
469
+ modified_attack = [k.split('*') for k in modified_attack]
470
+ if boo:
471
+ # print(DP_list)
472
+ add += 1
473
+ Attack.append(modified_attack)
474
+ print(add)
475
+ print('End record_index:', record_index)
476
+ with open(modified_attack_path, 'wb') as fl:
477
+ pkl.dump(Attack, fl)
478
+ else:
479
+ raise Exception('Wrong action !!')
DiseaseSpecific/attack.py ADDED
@@ -0,0 +1,841 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import pickle as pkl
3
+ from typing import Dict, Tuple, List
4
+ import os
5
+ import numpy as np
6
+ import json
7
+ import dill
8
+ import logging
9
+ import argparse
10
+ import math
11
+ from pprint import pprint
12
+ import pandas as pd
13
+ from collections import defaultdict
14
+ import copy
15
+ import time
16
+ from tqdm import tqdm
17
+
18
+ import torch
19
+ from torch.utils.data import DataLoader
20
+ import torch.backends.cudnn as cudnn
21
+ import torch.autograd as autograd
22
+ from torch.nn import functional as F
23
+ from torch.nn.modules.loss import CrossEntropyLoss
24
+
25
+ from model import Distmult, Complex, Conve
26
+ import utils
27
+
28
+ import sys
29
+
30
+ import dill
31
+
32
+ sys.path.append("..")
33
+ import Parameters
34
+
35
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
36
+
37
+ logger = None
38
def generate_nghbrs_single_entity(x, edge_nghbrs, bound):
    """Breadth-first expansion from entity ``x`` over ``edge_nghbrs``.

    Args:
        x: entity id (string) to start from.
        edge_nghbrs: dict mapping entity id -> iterable of neighbor ids.
        bound: maximum number of entities to return (including ``x``).

    Returns:
        List of distinct entity ids in BFS order, starting with ``x``.
    """
    # Fix: the original seeded the visited set with `set(x)`, which builds a
    # set of the *characters* of the id string, not {x}; the start entity
    # could therefore be re-visited and appended a second time.
    visited = {x}
    order = [x]
    head = 0
    while head < len(order):
        current = order[head]
        for nb in edge_nghbrs.get(current, ()):
            if nb not in visited:
                visited.add(nb)
                order.append(nb)
                if len(order) == bound:
                    return order
        head += 1
    return order
54
+
55
def generate_nghbrs(target_data, edge_nghbrs, args):
    """For each target triple, collect the merged, sorted neighborhoods of s and o.

    Returns a dict mapping the triple's index in ``target_data`` to a sorted
    list of entity ids reachable (within args.neighbor_num) from either end.
    """
    neighborhood = {}
    for idx, (s, _r, o) in enumerate(target_data):
        from_s = generate_nghbrs_single_entity(s, edge_nghbrs, args.neighbor_num)
        from_o = generate_nghbrs_single_entity(o, edge_nghbrs, args.neighbor_num)
        neighborhood[idx] = sorted(set(from_s + from_o))
    return neighborhood
63
+ #%%
64
def check_edge(s, r, o, used_trip = None, args = None):
    """Double check: sanity-assert a candidate edge; no-op (True) when args is None.

    For non-existing targets, verifies the presence of (s, o) in ``used_trip``
    matches ``args.target_existed``. Otherwise verifies the node types of s and
    o agree with the relation's declared head/tail types. Relies on the
    module-level globals ``entityid_to_nodetype`` and ``Parameters``.
    """
    if args is None:
        return True
    if not args.target_existed:
        assert (s + '_' + o in used_trip) == args.target_existed
    else:
        s_type = entityid_to_nodetype[s]
        o_type = entityid_to_nodetype[o]
        # Relation name looks like "head-tail[:...]"; keep only the type pair.
        rel_parts = Parameters.edge_id_to_type[int(r)].split(':')[0].split('-')
        assert s_type == rel_parts[0] and o_type == rel_parts[1]
77
+
78
def get_model_loss(batch, model, device, args = None):
    """Cross-entropy loss for a batch of (s, r, o) triples, summed over both
    prediction directions (rhs: s,r -> o and lhs: o,r_rev -> s).

    Args:
        batch: LongTensor of shape (B, 3) with columns s, r, o.
        model: KGE model exposing emb_e, emb_rel, forward(..., mode=...), loss.
        device: unused here; kept for signature compatibility with callers.
        args: optional namespace; when args.add_reciprocals is truthy, the
            reverse direction uses relation id r + n_rel (module-level global).

    Returns:
        Scalar tensor: loss_sr + loss_or.
    """
    s, r, o = batch[:, 0], batch[:, 1], batch[:, 2]

    emb_s = model.emb_e(s).squeeze(dim=1)
    emb_r = model.emb_rel(r).squeeze(dim=1)
    emb_o = model.emb_e(o).squeeze(dim=1)

    # Fix: callers in this file invoke get_model_loss without `args`
    # (e.g. before_global_attack), and the original then crashed on
    # `None.add_reciprocals`. Treat args=None as "no reciprocals".
    if args is not None and args.add_reciprocals:
        r_rev = r + n_rel  # n_rel is a module-level global
        emb_rrev = model.emb_rel(r_rev).squeeze(dim=1)
    else:
        emb_rrev = emb_r

    pred_sr = model.forward(emb_s, emb_r, mode='rhs')
    loss_sr = model.loss(pred_sr, o)  # cross-entropy loss

    pred_or = model.forward(emb_o, emb_rrev, mode='lhs')
    loss_or = model.loss(pred_or, s)

    return loss_sr + loss_or
100
+
101
def get_model_loss_without_softmax(batch, model, device=None):
    """Negative raw model score of each gold tail, computed without gradients.

    Args:
        batch: LongTensor of shape (B, 3) with columns s, r, o.
        model: KGE model exposing emb_e, emb_rel and forward(emb_s, emb_r).
        device: unused; kept for signature compatibility.

    Returns:
        Tensor of shape (B,): -score(s, r, o) for each triple.
    """
    with torch.no_grad():
        subj, relation, obj = batch[:, 0], batch[:, 1], batch[:, 2]
        emb_s = model.emb_e(subj).squeeze(dim=1)
        emb_r = model.emb_rel(relation).squeeze(dim=1)
        scores = model.forward(emb_s, emb_r)
        return -scores[torch.arange(obj.shape[0]), obj]
111
+
112
def lp_regularizer(model, weight, p):
    """Weighted L_p penalty over the entity and relation embedding tables."""
    penalty = 0
    for table in (model.emb_e.weight, model.emb_rel.weight):
        penalty = penalty + weight * torch.sum(torch.abs(table) ** p)
    return penalty
118
+
119
def n3_regularizer(factors, weight, p):
    """Weighted N3-style penalty over `factors`, averaged over the batch
    (number of rows of the first factor)."""
    total = sum(weight * torch.sum(torch.abs(f) ** p) for f in factors)
    return total / factors[0].shape[0]
124
+
125
def get_train_loss(batch, model, device, args):
    """Bidirectional cross-entropy training loss plus optional regularization.

    Args:
        batch: LongTensor of shape (B, 3) with columns s, r, o.
        model: KGE model exposing emb_e, emb_rel, forward(..., mode=...), loss.
        device: unused here; kept for signature compatibility.
        args: namespace with add_reciprocals, reg_weight, reg_norm,
            embedding_dim and model (model-name string).

    Returns:
        Scalar tensor: loss_sr + loss_or (+ N3 or L2 regularization terms).
    """
    s, r, o = batch[:, 0], batch[:, 1], batch[:, 2]

    emb_s = model.emb_e(s).squeeze(dim=1)
    emb_r = model.emb_rel(r).squeeze(dim=1)
    emb_o = model.emb_e(o).squeeze(dim=1)

    if args.add_reciprocals:
        r_rev = r + n_rel  # n_rel is a module-level global
        emb_rrev = model.emb_rel(r_rev).squeeze(dim=1)
    else:
        emb_rrev = emb_r

    pred_sr = model.forward(emb_s, emb_r, mode='rhs')
    loss_sr = model.loss(pred_sr, o)  # cross-entropy

    pred_or = model.forward(emb_o, emb_rrev, mode='lhs')
    loss_or = model.loss(pred_or, s)

    train_loss = loss_sr + loss_or

    if (args.reg_weight != 0.0 and args.reg_norm == 3):
        # Fix: the original tested `model == 'complex'`, comparing the model
        # *object* against a string, so the complex-factor branch could never
        # run. The model name lives on args (cf. args.model used elsewhere in
        # this repo) -- NOTE(review): confirm args carries `model`.
        if args.model == 'complex':
            emb_dim = args.embedding_dim  # real/imag halves of each embedding
            lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
            rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
            rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
            rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

            factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                          torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                          torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
            factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                          torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                          torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
        else:
            factors_sr = (emb_s, emb_r, emb_o)
            factors_or = (emb_s, emb_rrev, emb_o)

        train_loss += n3_regularizer(factors_sr, args.reg_weight, p=3)
        train_loss += n3_regularizer(factors_or, args.reg_weight, p=3)

    if (args.reg_weight != 0.0 and args.reg_norm == 2):
        train_loss += lp_regularizer(model, args.reg_weight, p=2)

    return train_loss
177
def hv(loss, model_params, v):
    """Hessian-vector product: returns H(loss) @ v with respect to model_params.

    Computes first-order gradients with the graph retained, then differentiates
    their inner product with `v` a second time.
    """
    first_grads = autograd.grad(loss, model_params, create_graph=True, retain_graph=True)
    return autograd.grad(first_grads, model_params, grad_outputs=v)
181
def gather_flat_grad(grads):
    """Flatten a sequence of gradient tensors into a single 1-D tensor.

    Sparse gradients are densified before flattening; order follows `grads`.
    """
    flat = []
    for g in grads:
        data = g.data
        if data.is_sparse:
            data = data.to_dense()
        flat.append(data.view(-1))
    return torch.cat(flat, 0)
190
+
191
def get_inverse_hvp_lissa(v, model, device, param_influence, train_data, args):
    """Approximate (H^-1) v with the LiSSA stochastic recursion.

    Runs args.lissa_repeat independent recursions, each of depth
    ceil(N / args.lissa_batch_size) * args.lissa_depth, over shuffled
    minibatches of train_data, then averages the rescaled estimates.

    Args:
        v: sequence of tensors (gradients w.r.t. param_influence).
        model: KGE model; training loss comes from get_train_loss.
        device: device minibatches are moved to.
        param_influence: parameters the Hessian is taken over.
        train_data: numpy array of training triples, shape (N, 3).
        args: namespace with damping, lissa_repeat, scale, lissa_batch_size,
            lissa_depth.

    Returns:
        1-D tensor: flattened, sample-averaged inverse-HVP estimate.
    """
    damping = args.damping
    num_samples = args.lissa_repeat
    scale = args.scale
    train_batch_size = args.lissa_batch_size
    lissa_num_batches = math.ceil(train_data.shape[0] / train_batch_size)
    recursion_depth = int(lissa_num_batches * args.lissa_depth)

    def _shuffled_examples():
        # Fresh random permutation of the training triples.
        data = torch.from_numpy(train_data.astype('int64'))
        return data[torch.randperm(data.shape[0]), :]

    ihvp = None
    for _ in range(num_samples):
        cur_estimate = v
        actual_examples = _shuffled_examples()
        b_begin = 0
        for _depth in range(recursion_depth):
            model.zero_grad()  # same effect as optimizer.zero_grad()
            if b_begin >= actual_examples.shape[0]:
                # Exhausted the epoch: reshuffle and restart.
                b_begin = 0
                actual_examples = _shuffled_examples()

            input_batch = actual_examples[b_begin: b_begin + train_batch_size]
            input_batch = input_batch.to(device)

            train_loss = get_train_loss(input_batch, model, device, args)
            hvp = hv(train_loss, param_influence, cur_estimate)
            # LiSSA update: x_{t+1} = v + (1 - damping) x_t - (H/scale) x_t
            cur_estimate = [_a + (1 - damping) * _b - _c / scale
                            for _a, _b, _c in zip(v, cur_estimate, hvp)]
            b_begin += train_batch_size

        if ihvp is None:  # idiom fix: was `ihvp == None`
            ihvp = [_a / scale for _a in cur_estimate]
        else:
            ihvp = [_a + _b / scale for _a, _b in zip(ihvp, cur_estimate)]

    return_ihvp = gather_flat_grad(ihvp)
    return_ihvp /= num_samples
    return return_ihvp
240
+
241
+ #%%
242
def before_global_attack(device, n_rel, data, target_data, neighbors, model,
                         filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                         entityid_to_nodetype, batch_size, args, lissa_path, target_disease):
    """Precompute (and cache to `lissa_path` via dill) the inverse-HVP row for
    every (target, relation=10, disease) triple.

    If the cache exists and args.update_lissa is falsy, the pickled dict is
    returned directly. Otherwise each triple's loss gradient is pushed through
    the LiSSA inverse-HVP estimator and the result stored under the
    's_r_o'-joined triple name. Relies on the module-level global
    `param_influence` (same as the other attack routines here).

    Returns:
        Dict mapping triple name -> (1, D) CPU tensor of the inverse HVP.
    """
    if os.path.exists(lissa_path) and not args.update_lissa:
        with open(lissa_path, 'rb') as fl:
            return dill.load(fl)

    ret = {}

    test_data = []
    for i in target_disease:
        tp = entityid_to_nodetype[str(i)]
        assert tp == 'disease'
        if tp == 'disease':
            for target in target_data:
                # relation id 10 is hard-coded here, matching the caller's usage
                test_data.append([str(target), str(10), str(i)])
    test_data = np.array(test_data)

    for target_trip in tqdm(test_data):
        trip_name = '_'.join(list(target_trip))
        batch = target_trip[None, :]  # add a batch dimension
        batch = torch.from_numpy(batch.astype('int64')).to(device)

        model.eval()
        model.zero_grad()
        # Fix: the original called get_model_loss without `args`, which raised
        # AttributeError inside get_model_loss (its args parameter defaults to
        # None but args.add_reciprocals is read). Passing args is safe and
        # preserves the intended reciprocal-relation handling.
        target_loss = get_model_loss(batch, model, device, args)
        target_grads = autograd.grad(target_loss, param_influence)

        model.train()
        inverse_hvp = get_inverse_hvp_lissa(target_grads, model, device,
                                            param_influence, data, args)
        model.eval()
        ret[trip_name] = inverse_hvp.detach().cpu().unsqueeze(0)

    with open(lissa_path, 'wb') as fl:
        dill.dump(ret, fl)
    return ret
285
+
286
+ def global_addtion_attack(device, n_rel, data, target_data, neighbors, model,
287
+ filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
288
+ entityid_to_nodetype, batch_size, args, lissa, target_disease):
289
+
290
+ logger.info('------ Generating edits per target triple ------')
291
+ start_time = time.time()
292
+ logger.info('Start time: {0}'.format(str(start_time)))
293
+
294
+ used_trip = set()
295
+ print("Processing used triples ...")
296
+ for s, r, o in tqdm(data):
297
+ used_trip.add(s+'_'+o)
298
+ # used_trip.add(o+'_'+s)
299
+ print('Size of used triples:', len(used_trip))
300
+ logger.info('Size of used triples: {0}'.format(len(used_trip)))
301
+
302
+ ret_trip = []
303
+ score_record = []
304
+ real_add_rank_ratio = 0
305
+
306
+ with open(score_path, 'rb') as fl:
307
+ score_record = pkl.load(fl)
308
+ for i, target in enumerate(target_data):
309
+
310
+ print('\n\n------ Attacking target tripid:', i, 'tot:', len(target_data), ' ------')
311
+ # lissa_hvp = []
312
+ target_trip = []
313
+ for disease in target_disease:
314
+ target_trip.append([target, str(10), disease])
315
+ # nm = '{}_{}_{}'.format(target, 10, disease)
316
+ # lissa_hvp.append(lissa[nm])
317
+ # lissa_hvp = torch.cat(lissa_hvp, dim = 0).to(device)
318
+
319
+ target_trip = np.array(target_trip)
320
+ target_trip = torch.from_numpy(target_trip.astype('int64')).to(device)
321
+
322
+ model.eval()
323
+ model.zero_grad()
324
+ target_loss = get_model_loss(target_trip, model, device)
325
+ target_grads = autograd.grad(target_loss, param_influence)
326
+
327
+ model.train()
328
+ inverse_hvp = get_inverse_hvp_lissa(target_grads, model, device,
329
+ param_influence, data, args)
330
+
331
+ model.eval()
332
+
333
+ nghbr_trip = []
334
+ s = str(target)
335
+ tp = entityid_to_nodetype[s]
336
+ for nghbr in tqdm(neighbors):
337
+ o = str(nghbr)
338
+ if s!=o and s+'_'+o not in used_trip:
339
+ for r in range(n_rel):
340
+ if (tp, r) in filters["rhs"].keys() and filters["rhs"][(tp, r)][int(o)] == True:
341
+ nghbr_trip.append([s, str(r), o])
342
+
343
+ nghbr_trip = np.asarray(nghbr_trip)
344
+ influences = []
345
+ edge_losses = []
346
+
347
+ # nghbr_cos_log_prob, nghbr_LM_log_prob = score_record[i]
348
+ # assert nghbr_cos_log_prob.shape[0] == nghbr_trip.shape[0]
349
+
350
+ for train_trip in tqdm(nghbr_trip):
351
+ #model.train() #batch norm cannot be used here
352
+ train_trip = train_trip[None, :] # add batch dim
353
+ train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
354
+ #### L-train gradient ####
355
+ edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
356
+ edge_losses.append(edge_loss.unsqueeze(0).detach())
357
+ model.zero_grad()
358
+ train_loss = get_model_loss(train_trip, model, device, args)
359
+ train_grads = autograd.grad(train_loss, param_influence)
360
+ train_grads = gather_flat_grad(train_grads)
361
+ influence = torch.dot(inverse_hvp, train_grads) #default dim=1
362
+ influences.append(influence.unsqueeze(0).detach())
363
+
364
+ edge_losses = torch.cat(edge_losses, dim = -1)
365
+ influences = torch.cat(influences, dim = -1)
366
+ edge_losses_log_prob = torch.log(F.softmax(-edge_losses, dim = -1))
367
+ influences_log_prob = torch.log(F.softmax(influences, dim = -1))
368
+
369
+ inf_score_sorted, influences_sort = torch.sort(influences_log_prob, -1, descending=True)
370
+ edge_score_sorted, edge_sort = torch.sort(edge_losses_log_prob, -1, descending=True)
371
+ influences_sort = influences_sort.cpu().numpy()
372
+ edge_sort = edge_sort.cpu().numpy()
373
+ inf_score_sorted = inf_score_sorted.cpu().numpy()
374
+ edge_score_sorted = edge_score_sorted.cpu().numpy()
375
+
376
+ logger.info('')
377
+ logger.info('Top 8 inf_score: {}'.format(" ".join(map(str, list(inf_score_sorted[:8])))))
378
+ logger.info('Top 8 edge_score: {}'.format(" ".join(map(str, list(edge_score_sorted[:8])))))
379
+
380
+ nghbr_cos_log_prob = influences_log_prob.detach().cpu().numpy()
381
+ nghbr_LM_log_prob = edge_losses_log_prob.detach().cpu().numpy()
382
+ max_sim = np.max(nghbr_cos_log_prob)
383
+ min_sim = np.min(nghbr_cos_log_prob)
384
+ max_LM = np.max(nghbr_LM_log_prob)
385
+ min_LM = np.min(nghbr_LM_log_prob)
386
+
387
+ # final_score = nghbr_cos_log_prob + nghbr_LM_log_prob
388
+ final_score = nghbr_cos_log_prob
389
+
390
+ index = np.argmax(final_score[:-1])
391
+ # p = np.where(index == edge_sort)[0][0]
392
+ # logger.info('Added edge\'s edge rank ratio: {}'.format(p / edge_sort.shape[0]))
393
+ real_add_rank_ratio += p
394
+ add_trip = nghbr_trip[index]
395
+ logger.info('max_inf: {0:.8}, min_inf: {1:.8}, max_edge: {2:.8}, min_edge: {3:.8}'.format(max_sim, min_sim, max_LM, min_LM))
396
+ logger.info('Attack trip: {0}_{1}_{2}.\n Influnce score: {3:.8}. Edge score: {4:.8}.'.format(add_trip[0], add_trip[1], add_trip[2],
397
+ nghbr_cos_log_prob[index], nghbr_LM_log_prob[index]))
398
+ ret_trip.append(add_trip)
399
+ score_record.append((nghbr_cos_log_prob, nghbr_LM_log_prob))
400
+ real_add_rank_ratio = real_add_rank_ratio / target_data.shape[0]
401
+ logger.info('Mean real ratio: {}.'.format(real_add_rank_ratio))
402
+ return ret_trip, score_record
403
+
404
def addition_attack(param_influence, device, n_rel, data, target_data, neighbors, model,
                    filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                    entityid_to_nodetype, batch_size, args, load_Record = False, divide_bound = None, data_mean = None, data_std = None, cache_intermidiate = True):
    """Select, for each target triple, the adversarial edge(s) to add to the KG.

    For every target triple this routine:
      1. computes the gradient of the target loss and its inverse
         Hessian-vector product via LiSSA,
      2. scores every candidate neighbour edge by influence (dot product with
         the IHVP) and by the model's own edge loss,
      3. discards candidates whose plausibility sigmoid falls below
         ``1 - args.reasonable_rate``, and
      4. keeps the best candidate(s) (``args.added_edge_num``).

    Returns:
        (ret_trip, score_record): chosen attack triples, and per-target
        (influence log-prob, edge log-prob) numpy arrays.

    NOTE(review): reads module-level globals `logger`, `intermidiate_path`,
    `time`, `dill` and helpers (`check_edge`, `get_model_loss`,
    `get_inverse_hvp_lissa`, `get_model_loss_without_softmax`,
    `gather_flat_grad`, `autograd`, `F`) defined elsewhere in this file.
    """
    if logger:
        logger.info('------ Generating edits per target triple ------')
    start_time = time.time()
    if logger:
        logger.info('Start time: {0}'.format(str(start_time)))

    # Edges already present in the training graph, keyed "s_o"; candidate
    # edges that duplicate an existing pair are skipped below.
    used_trip = set()
    print("Processing used triples ...")
    for s, r, o in tqdm(data):
        used_trip.add(s+'_'+o)
        # used_trip.add(o+'_'+s)
    print('Size of used triples:', len(used_trip))
    if logger:
        logger.info('Size of used triples: {0}'.format(len(used_trip)))

    nghbr_trip_len = []
    ret_trip = []        # chosen attack triples (returned)
    score_record = []    # per-target score arrays (returned)
    direct_add_rank_ratio = 0
    real_add_rank_ratio = 0
    bad_ratio = 0

    # Cache of expensive per-target intermediates so re-runs can skip scoring.
    RRcord = []
    print('****'*10)
    if load_Record:
        print('Load intermidiate file')
        with open(intermidiate_path, 'rb') as fl:
            RRcord = dill.load(fl)
    else:
        print('Donnot load intermidiate file')

    for i, target_trip in enumerate(target_data):

        print('\n\n------ Attacking target tripid:', i, ' ------')
        target_nghbrs = neighbors[i]
        for a in target_nghbrs:
            # Sanity check: no placeholder ids in the neighbour list.
            if str(a) == '-1':
                raise Exception('pppp')

        target_trip_ori = target_trip
        check_edge(target_trip[0], target_trip[1], target_trip[2], used_trip)
        target_trip = target_trip[None, :] # add a batch dimension
        target_trip = torch.from_numpy(target_trip.astype('int64')).to(device)
        # target_s, target_r, target_o = target_trip[:,0], target_trip[:,1], target_trip[:,2]
        # target_vec = model.score_triples_vec(target_s, target_r, target_o)

        model.eval()

        if load_Record:
            # Reuse cached scores; assert the cache entry matches this target.
            o_target_trip, nghbr_trip, edge_losses, influences, edge_losses_log_prob, influences_log_prob = RRcord[i]
            assert (o_target_trip.cpu() == target_trip.cpu()).sum().item() == 3
        else:
            # Gradient of the target loss w.r.t. all influence parameters.
            model.zero_grad()
            target_loss = get_model_loss(target_trip, model, device, args)
            target_grads = autograd.grad(target_loss, param_influence)

            # LiSSA approximation of H^{-1} @ target_grads.
            model.train()
            inverse_hvp = get_inverse_hvp_lissa(target_grads, model, device,
                                                param_influence, data, args)

            model.eval()
            # Build candidate triples from the target's neighbourhood.
            nghbr_trip = []
            valid_trip = 0
            if args.candidate_mode == 'quadratic':
                s_o_list = [(i, j) for i in target_nghbrs for j in target_nghbrs]
            elif args.candidate_mode == 'linear':
                s_o_list = [(j, i) for i in target_nghbrs for j in [target_trip_ori[0], target_trip_ori[2]]] \
                            + [(i, j) for i in target_nghbrs for j in [target_trip_ori[0], target_trip_ori[2]]]
            else:
                raise Exception('Wrong candidate_mode: '+args.candidate_mode)
            for s, o in tqdm(s_o_list):
                tp = entityid_to_nodetype[s]
                if s!=o and s+'_'+o not in used_trip:
                    for r in range(n_rel):
                        # Keep only (node-type, relation, object) combinations
                        # the dataset's rhs filter marks legal.
                        if (tp, r) in filters["rhs"].keys() and filters["rhs"][(tp, r)][int(o)] == True:
                            # check_edge(s, r, o)
                            valid_trip += 1
                            nghbr_trip.append([s, str(r), o])
                            # logger.info('{0}_{1}_{2}'.format(s, str(r), o))
            nghbr_trip_len.append(len(nghbr_trip))
            print('Valid trip:', valid_trip)

            # Append the target edge itself (if unseen) as a "direct" baseline;
            # it is excluded from the argmax below via final_score[:-1].
            if target_trip_ori[0]+'_'+target_trip_ori[2] not in used_trip:
                nghbr_trip.append(target_trip_ori)
            nghbr_trip = np.asarray(nghbr_trip)
            print("Edge scoring ...")

            influences = []
            edge_losses = []

            for train_trip in tqdm(nghbr_trip):
                #model.train() #batch norm cannot be used here
                train_trip = train_trip[None, :] # add batch dim
                train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                #### L-train gradient ####
                edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                edge_losses.append(edge_loss.unsqueeze(0).detach())
                model.zero_grad()
                train_loss = get_model_loss(train_trip, model, device, args)
                train_grads = autograd.grad(train_loss, param_influence)
                train_grads = gather_flat_grad(train_grads)
                influence = torch.dot(inverse_hvp, train_grads) #default dim=1
                influences.append(influence.unsqueeze(0).detach())

            edge_losses = torch.cat(edge_losses, dim = -1)
            influences = torch.cat(influences, dim = -1)
            # Softmax over candidates: low loss / high influence -> high log-prob.
            edge_losses_log_prob = torch.log(F.softmax(-edge_losses, dim = -1))
            influences_log_prob = torch.log(F.softmax(influences, dim = -1))
            # Rescale the influence distribution to the edge-score distribution
            # so the two log-probs are comparable when summed later.
            std_scale = torch.std(edge_losses_log_prob) / torch.std(influences_log_prob)
            influences_log_prob = (influences_log_prob - influences_log_prob.mean()) * std_scale + edge_losses_log_prob.mean()

            RRcord.append([target_trip.detach(), nghbr_trip, edge_losses, influences, edge_losses_log_prob, influences_log_prob])

        inf_score_sorted, influences_sort = torch.sort(influences_log_prob, -1, descending=True)
        edge_score_sorted, edge_sort = torch.sort(edge_losses_log_prob, -1, descending=True)

        influences_sort = influences_sort.cpu().numpy()
        edge_sort = edge_sort.cpu().numpy()
        inf_score_sorted = inf_score_sorted.cpu().numpy()
        edge_score_sorted = edge_score_sorted.cpu().numpy()
        edge_losses = edge_losses.cpu().numpy()

        # Rank (under the edge score) of the influence-optimal candidate.
        p = np.where(influences_sort[0] == edge_sort)[0][0]
        direct_add_rank_ratio += p / edge_sort.shape[0]
        if logger:
            logger.info('Top 8 inf_score: {}'.format(" ".join(map(str, list(inf_score_sorted[:8])))))
            logger.info('Top 8 edge_score: {}'.format(" ".join(map(str, list(edge_score_sorted[:8])))))

        nghbr_cos_log_prob = influences_log_prob.detach().cpu().numpy()
        nghbr_LM_log_prob = edge_losses_log_prob.detach().cpu().numpy()
        max_sim = nghbr_cos_log_prob[influences_sort[0]]
        min_sim = nghbr_cos_log_prob[influences_sort[-1]]
        max_LM = nghbr_LM_log_prob[edge_sort[0]]
        min_LM = nghbr_LM_log_prob[edge_sort[-1]]
        direct_score_0 = 0
        direct_score_1 = 0
        if target_trip_ori[0]+'_'+target_trip_ori[2] not in used_trip:
            # Scores of the appended direct target edge (last row).
            direct_score_0 = nghbr_cos_log_prob[-1]
            direct_score_1 = nghbr_LM_log_prob[-1]

        # Veto candidates whose plausibility sigmoid (in the normalised loss
        # space from calculate_edge_bound) falls below the bound.
        # bound = math.log(1 / nghbr_LM_log_prob.shape[0])
        bound = 1 - args.reasonable_rate
        edge_losses = (edge_losses - data_mean) / data_std
        edge_losses_prob = 1 / ( 1 + np.exp(edge_losses - divide_bound) )
        nghbr_LM_log_prob[edge_losses_prob < bound] = -(1e20)

        final_score = nghbr_cos_log_prob + nghbr_LM_log_prob

        # [:-1] excludes the appended direct target edge from selection.
        index = np.argmax(final_score[:-1])
        sort_index = [(i, final_score[i])for i in range(len(final_score) - 1)]
        sort_index = sorted(sort_index, key=lambda x: x[1], reverse=True)
        assert sort_index[0][0] == index

        p = np.where(index == edge_sort)[0][0]
        if logger:
            logger.info('Bad edge ratio: {}'.format((edge_losses_prob < bound).mean()))
            logger.info('Bounded edge\'s edge rank ratio: {}'.format(p / edge_sort.shape[0]))
        real_add_rank_ratio += p / edge_sort.shape[0]
        bad_ratio += (edge_losses_prob < bound).mean()

        add_trip = nghbr_trip[index]

        if (int(add_trip[0]) == int(-1)):
            # Should be unreachable: a placeholder triple won the argmax.
            add_trip[0], add_trip[1], add_trip[2] = -1, -1, -1
            print(final_score.shape, index, edge_losses_prob[index], bound)
            raise Exception('??')

        if logger:
            logger.info('max_inf: {0:.8}, min_inf: {1:.8}, max_edge: {2:.8}, min_edge: {3:.8}'.format(max_sim, min_sim, max_LM, min_LM))
            logger.info('Target trip: {0}_{1}_{2}. Attack trip: {3}_{4}_{5}.\n Influnce score: {6:.8}. Edge score: {7:.8}. Direct score: {8:.8} + {9:.8}'.format(target_trip_ori[0],target_trip_ori[1], target_trip_ori[2],
                        add_trip[0], add_trip[1], add_trip[2],
                        nghbr_cos_log_prob[index], nghbr_LM_log_prob[index],
                        direct_score_0, direct_score_1))
        if (args.added_edge_num == '' or int(args.added_edge_num) == 1):
            ret_trip.append(add_trip)
        else:
            # Keep the top-k candidates instead of only the single best one.
            edge_num = int(args.added_edge_num)
            for i in range(edge_num):
                ret_trip.append(nghbr_trip[sort_index[i][0]])
        score_record.append((nghbr_cos_log_prob, nghbr_LM_log_prob))

    if not load_Record and cache_intermidiate:
        with open(intermidiate_path, 'wb') as fl:
            dill.dump(RRcord, fl)
    # Mean rank ratios / bad-candidate ratio over all targets, for logging.
    direct_add_rank_ratio = direct_add_rank_ratio / target_data.shape[0]
    real_add_rank_ratio = real_add_rank_ratio / target_data.shape[0]
    bad_ratio = bad_ratio / target_data.shape[0]
    if logger:
        logger.info('Mean direct ratio: {}. Mean real ratio: {}. Mean bad ratio: {}'.format(direct_add_rank_ratio, real_add_rank_ratio, bad_ratio))
    return ret_trip, score_record
598
+
599
def calculate_edge_bound(data, model, device, n_ent):
    """Estimate a decision boundary separating existing from non-existing edges.

    Samples 10% of the graph's triples, corrupts one side of each sample to
    build an equal-sized negative set, normalises the model's (pre-softmax)
    losses over both sets jointly, then fits one Gaussian per class and solves
    for the crossing point of the two densities.

    Args:
        data: (N, 3) array of string-typed (s, r, o) triples.
        model: trained KG embedding model (scored via
            get_model_loss_without_softmax, defined elsewhere in this file).
        device: torch device for scoring.
        n_ent: number of entities, used for random corruption.

    Returns:
        (x, tot_mean, tot_std): the boundary in normalised loss space plus the
        normalisation statistics so callers can map new losses into it.

    Raises:
        Exception: if neither quadratic root lies between the class means,
            i.e. the model does not separate real from fake edges.
    """
    # Random 10% sample of existing triples.
    tmp = np.random.choice(a = data.shape[0], size = data.shape[0] // 10, replace=False)
    existed_data = data[tmp, :]

    print('calculating edge bound ...')
    print(existed_data.shape)

    existed_edge = set()
    for src_trip in existed_data:
        existed_edge.add('_'.join(list(src_trip)))

    # One negative per positive: corrupt head or tail (coin flip), resampling
    # until the corrupted triple is genuinely unseen in the sample.
    not_existed = []
    for s, r, o in existed_data:
        if np.random.randint(0, n_ent) % 2 == 0:
            while True:
                oo = np.random.randint(0, n_ent)
                if '_'.join([s, r, str(oo)]) not in existed_edge:
                    not_existed.append([s, r, str(oo)])
                    break
        else:
            while True:
                ss = np.random.randint(0, n_ent)
                if '_'.join([str(ss), r, o]) not in existed_edge:
                    not_existed.append([str(ss), r, o])
                    break
    existed_data = np.array(existed_data)
    not_existed = np.array(not_existed)
    existed_data = torch.from_numpy(existed_data.astype('int64')).to(device)
    not_existed = torch.from_numpy(not_existed.astype('int64')).to(device)
    loss_existed = get_model_loss_without_softmax(existed_data, model).cpu().numpy()
    loss_not_existed = get_model_loss_without_softmax(not_existed, model).cpu().numpy()
    # Joint normalisation so both populations share one loss scale.
    tot_loss = np.hstack((loss_existed, loss_not_existed))
    tot_mean, tot_std = np.mean(tot_loss), np.std(tot_loss)

    loss_existed = (loss_existed - tot_mean) / tot_std
    loss_not_existed = (loss_not_existed - tot_mean) / tot_std

    print('Tot mean: {}, Tot std: {}'.format(tot_mean, tot_std))

    # Fit N(l_mean, l_std) to positives and N(r_mean, r_std) to negatives and
    # solve A x^2 + B x + C = 0 for the equal-density point.
    l_mean, l_std = np.mean(loss_existed), np.std(loss_existed)
    r_mean, r_std = np.mean(loss_not_existed), np.std(loss_not_existed)

    A = -1/(l_std**2) + 1/(r_std**2)
    B = 2 * (-r_mean/(r_std**2) + l_mean/(l_std**2))
    C = (r_mean**2)/(r_std**2)-(l_mean**2)/(l_std**2) + np.log((r_std**2)/(l_std**2))

    delta = B**2 - 4*A*C

    x_1 = ( -B + math.sqrt(delta) ) / (2*A)
    x_2 = ( -B - math.sqrt(delta) ) / (2*A)

    # Keep whichever root lies strictly between the two class means.
    x = None
    if (x_1 > l_mean and x_1 < r_mean):
        x = x_1
    if (x_2 > l_mean and x_2 < r_mean):
        x = x_2
    # BUG FIX: the original tested `if not x:`, which also rejects a perfectly
    # valid boundary at exactly 0.0 (0.0 is falsy). Only None means "no root".
    if x is None:
        raise Exception('Bad model!!!!')
    # Confusion-matrix diagnostics at the chosen threshold.
    TP = (loss_existed < x).mean()
    TN = (loss_not_existed > x).mean()
    FP = (loss_not_existed < x).mean()
    FN = (loss_existed > x).mean()
    print('X:{}, TP:{}, TN:{}, FP:{}, FN{}'.format(x, TP, TN, FP, FN))

    # Sigmoid plausibility scores centred on the boundary (lower loss -> ~1).
    sig_existed = 1 / ( 1 + np.exp(loss_existed- x) ) # negtive important
    sig_not_existed = 1 / ( 1 + np.exp(loss_not_existed - x) )

    print('Positive mean score:', sig_existed.mean(),'Negetive mean score:', sig_not_existed.mean())

    return x, tot_mean, tot_std
673
+
674
+
675
+ #%%
676
if __name__ == '__main__':
    # Entry point: generate adversarial edge additions against a trained KG
    # embedding model, score candidates, and persist triples + score arrays.

    # ---- CLI / hyper-parameters ----
    parser = utils.get_argument_parser()
    parser = utils.add_attack_parameters(parser)
    args = parser.parse_args()
    args = utils.set_hyperparams(args)

    # ---- Device selection (spread over two GPUs when available) ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    args.device1 = device
    if torch.cuda.device_count() >= 2:
        args.device = "cuda:0"
        args.device1 = "cuda:1"

    utils.seed_all(args.seed)        # reproducibility
    np.set_printoptions(precision=5)
    cudnn.benchmark = False          # deterministic cuDNN behaviour

    # ---- Derived file paths (checkpoint, targets, caches, logs, outputs) ----
    model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
    model_path = 'saved_models/{0}_{1}.model'.format(args.data, model_name)
    data_path = os.path.join('processed_data', args.data)
    target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
    lissa_path = 'lissa/{0}_{1}_{2}'.format(args.model,
                                            args.data,
                                            args.target_size)
    # NOTE: intermidiate_path and logger are read as globals by addition_attack.
    intermidiate_path = 'intermidiate/{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(args.model,
                                                                         args.target_split,
                                                                         args.target_size,
                                                                         'exists:'+str(args.target_existed),
                                                                         args.neighbor_num,
                                                                         args.candidate_mode,
                                                                         args.attack_goal)
    log_path = 'logs/attack_logs/cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}'.format(args.model,
                                                                            args.target_split,
                                                                            args.target_size,
                                                                            'exists:'+str(args.target_existed),
                                                                            args.neighbor_num,
                                                                            args.candidate_mode,
                                                                            args.attack_goal,
                                                                            str(args.reasonable_rate))
    print(log_path)
    attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}.txt'.format(args.model,
                                                                                                                args.target_split,
                                                                                                                args.target_size,
                                                                                                                'exists:'+str(args.target_existed),
                                                                                                                args.neighbor_num,
                                                                                                                args.candidate_mode,
                                                                                                                args.attack_goal,
                                                                                                                str(args.reasonable_rate),
                                                                                                                str(args.added_edge_num)))

    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO,
                        filename = log_path
                        )
    logger = logging.getLogger(__name__)
    logger.info(vars(args))
    #%%
    # ---- Load processed KG plus the offline-built lookup tables ----
    n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
    data = utils.load_data(os.path.join(data_path, 'all.txt'))
    with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
        filters = pkl.load(fl)
    with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
        entityid_to_nodetype = json.load(fl)
    with open(os.path.join(data_path, 'edge_nghbrs.pickle'), 'rb') as fl:
        edge_nghbrs = pkl.load(fl)
    with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
        disease_meshid = pkl.load(fl)
    with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
        entity_to_id = json.load(fl)
    with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
        entity_raw_name = pkl.load(fl)
    #%%
    # ---- Convert per-(type, relation) candidate filters into device byte masks ----
    init_mask = np.asarray([0] * n_ent).astype('int64')
    init_mask = (init_mask == 1)
    for k, v in filters.items():
        for kk, vv in v.items():
            tmp = init_mask.copy()
            tmp[np.asarray(vv)] = True
            t = torch.ByteTensor(tmp).to(args.device)
            filters[k][kk] = t
    #%%
    model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
    # Plausibility boundary used to veto implausible candidate edges.
    divide_bound, data_mean, data_std = calculate_edge_bound(data, model, args.device, n_ent)
    # index = torch.LongTensor([0, 1]).to(device)
    # print(model.emb_rel(index)[:, :32])
    # print(model.emb_e(index)[:, :32])
    # raise Exception
    #%%
    # ---- Build attack candidates: per-triple neighbourhoods ('single') or
    # all gene entities plus the top target diseases ('global').
    target_data = utils.load_data(target_path)
    if args.attack_goal == 'single':
        neighbors = generate_nghbrs(target_data, edge_nghbrs, args)
    elif args.attack_goal == 'global':
        # Deduplicate subjects and sort for a deterministic target order.
        s_set = set()
        for s, r, o in target_data:
            s_set.add(s)
        target_data = list(s_set)
        target_data.sort()
        target_data = np.array(target_data, dtype=str)
        neighbors = []
        for i in list(range(n_ent)):
            tp = entityid_to_nodetype[str(i)]
            # r = torch.LongTensor([[10]]).to(device)
            if tp == 'gene':
                neighbors.append(str(i))
        # Keep the 50 most frequent diseases whose raw names are usable
        # (longer than 4 characters).
        target_disease = []
        tid = 1
        bound = 50
        while True:
            meshid = disease_meshid[tid][0]
            fre = disease_meshid[tid][1]   # frequency; currently unused
            if len(entity_raw_name[meshid]) > 4:
                target_disease.append(entity_to_id[meshid])
                bound -= 1
                if bound == 0:
                    break
            tid += 1
    else:
        raise Exception('Wrong attack_goal: '+args.attack_goal)

    # All model parameters participate in the influence computation.
    param_optimizer = list(model.named_parameters())
    param_influence = []
    for n,p in param_optimizer:
        param_influence.append(p)
    if args.attack_goal == 'single':
        len_list = []
        for v in neighbors.values():
            len_list.append(len(v))
        mean_len = np.mean(len_list)
    else:
        mean_len = len(neighbors)
    print('Mean length of neighbors:', mean_len)
    logger.info("Mean length of neighbors: {0}".format(mean_len))

    # GPT_LM = LMscore_calculator(data_path, args)
    # ---- LiSSA hyper-parameters for the inverse-HVP estimation ----
    lissa_num_batches = math.ceil(data.shape[0]/args.lissa_batch_size)
    logger.info('-------- Lissa Params for IHVP --------')
    logger.info('Damping: {0}'.format(args.damping))
    logger.info('Lissa_repeat: {0}'.format(args.lissa_repeat))
    logger.info('Lissa_depth: {0}'.format(args.lissa_depth))
    logger.info('Scale: {0}'.format(args.scale))
    logger.info('Lissa batch size: {0}'.format(args.lissa_batch_size))
    logger.info('Lissa num bacthes: {0}'.format(lissa_num_batches))

    score_path = os.path.join('attack_results', args.data, 'score_cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}.txt'.format(args.model,
                                                                                                                     args.target_split,
                                                                                                                     args.target_size,
                                                                                                                     'exists:'+str(args.target_existed),
                                                                                                                     args.neighbor_num,
                                                                                                                     args.candidate_mode,
                                                                                                                     args.attack_goal,
                                                                                                                     str(args.reasonable_rate),
                                                                                                                     str(args.added_edge_num)))

    # ---- Run the attack and persist triples + score arrays ----
    if args.attack_goal == 'single':
        attack_trip, score_record = addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record = args.load_existed, divide_bound = divide_bound, data_mean = data_mean, data_std = data_std)
    else:
        # lissa = before_global_attack(args.device, n_rel, data, target_data, neighbors, model, filters, entityid_to_nodetype, args.attack_batch_size, args, lissa_path, target_disease)

        attack_trip, score_record = global_addtion_attack(args.device, n_rel, data, target_data, neighbors, model, filters, entityid_to_nodetype, args.attack_batch_size, args, None, target_disease)

    utils.save_data(attack_path, attack_trip)

    logger.info("Attack triples are saved in " + attack_path)
    with open(score_path, 'wb') as fl:
        pkl.dump(score_record, fl)
DiseaseSpecific/edge_to_abstract.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#%%
# Setup for edge_to_abstract.py: load the attack triples produced by
# KG_extractor.py and the GNBR sentence corpus, and build (once) the cache of
# multi-word entity surface forms.
import torch
import numpy as np
from torch.autograd import Variable
# from sklearn import metrics

import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn

from tqdm import tqdm

import sys
sys.path.append("..")
import Parameters

# ---- CLI ----
parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)      # reproducibility
np.set_printoptions(precision=5)
cudnn.benchmark = False

# ---- Input locations derived from the attack configuration ----
data_path = os.path.join('processed_data', args.data)
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}.txt'.format(args.model,
                                                                                                         args.target_split,
                                                                                                         args.target_size,
                                                                                                         'exists:'+str(args.target_existed),
                                                                                                         args.neighbor_num,
                                                                                                         args.candidate_mode,
                                                                                                         args.attack_goal,
                                                                                                         str(args.reasonable_rate)))
# target_data = utils.load_data(target_path)
attack_data = utils.load_data(attack_path, drop=False)
# assert target_data.shape == attack_data.shape
#%%
# ---- Lookup tables: entity names/ids and the GNBR sentence corpus ----
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)

# One-time cache of every multi-word entity surface form in the corpus
# ('_'-joined tokens are entity mentions; store them with spaces restored).
if not os.path.exists('generate_abstract/valid_entity.json'):
    valid_entity = set()
    for paper_id, paper in raw_text_sen.items():
        for sen_id, sen in paper.items():
            text = sen['text'].split(' ')
            for a in text:
                if '_' in a:
                    valid_entity.add(a.replace('_', ' '))
    # BUG FIX: write to the same path the existence check above reads.
    # The original wrote to './valid_entity.json', so the guard never found
    # the cache and this scan was repeated on every run.
    with open('generate_abstract/valid_entity.json', 'w') as fl:
        json.dump(list(valid_entity), fl, indent=4)
    print('Valid entity saved!!')
70
+
71
if args.mode == 'sentence':
    # For every attack triple, choose the single corpus sentence that (a) was
    # extracted for the triple's relation via a GNBR dependency path and
    # (b) achieves the lowest BioGPT causal-LM loss after the original entity
    # mentions are replaced with the attack subject/object names.
    import torch
    from torch.nn.modules.loss import CrossEntropyLoss
    from transformers import AutoTokenizer
    from transformers import BioGptForCausalLM
    criterion = CrossEntropyLoss(reduction="none")   # per-token losses

    print('Generating GPT input ...')

    tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
    tokenizer.pad_token = tokenizer.eos_token
    model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
    model.to(device)
    model.eval()
    GPT_batch_size = 32
    single_sentence = {}   # '{s}_{r}_{o}_{i}' -> best substituted sentence
    test_text = []         # original corpus text of each chosen sentence
    test_dp = []           # dependency path of each chosen sentence
    test_parse = []        # chosen sentence with '_'-joined entity names
    for i, (s, r, o) in enumerate(tqdm(attack_data)):

        if int(s) != -1:   # -1 marks "no attack edge produced" placeholders

            # Sample up to ~500 candidate sentences for this relation, spread
            # evenly across its dependency paths.
            dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
            candidate_sen = []
            Dp_path = []
            L = len(dependency_sen_dict.keys())
            bound = 500 // L
            if bound == 0:
                bound = 1
            for dp_path, sen_list in dependency_sen_dict.items():
                if len(sen_list) > bound:
                    index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
                    sen_list = [sen_list[aa] for aa in index]
                candidate_sen += sen_list
                Dp_path += [dp_path] * len(sen_list)

            text_s = entity_raw_name[id_to_meshid[s]]
            text_o = entity_raw_name[id_to_meshid[o]]
            candidate_text_sen = []
            candidate_ori_sen = []
            candidate_parse_sen = []

            for paper_id, sen_id in candidate_sen:
                sen = raw_text_sen[paper_id][sen_id]
                text = sen['text']
                candidate_ori_sen.append(text)
                ss = sen['start_formatted']   # subject placeholder token
                oo = sen['end_formatted']     # object placeholder token
                # Restore Penn-Treebank bracket tokens to literal brackets.
                text = text.replace('-LRB-', '(')
                text = text.replace('-RRB-', ')')
                text = text.replace('-LSB-', '[')
                text = text.replace('-RSB-', ']')
                text = text.replace('-LCB-', '{')
                text = text.replace('-RCB-', '}')
                # Substitute the attack entities ('_'-joined for the parse
                # variant, plain spaces for the LM-scored variant).
                parse_text = text
                parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
                parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
                text = text.replace(ss, text_s)
                text = text.replace(oo, text_o)
                text = text.replace('_', ' ')
                candidate_text_sen.append(text)
                candidate_parse_sen.append(parse_text)
            tokens = tokenizer( candidate_text_sen,
                                truncation = True,
                                padding = True,
                                max_length = 300,
                                return_tensors="pt")
            target_ids = tokens['input_ids'].to(device)
            attention_mask = tokens['attention_mask'].to(device)

            # Score candidates in mini-batches: mean per-token causal-LM loss
            # with shifted labels, masked to real (non-pad) tokens.
            L = len(candidate_text_sen)
            assert L > 0
            ret_log_L = []
            for l in range(0, L, GPT_batch_size):
                R = min(L, l + GPT_batch_size)
                target = target_ids[l:R, :]
                attention = attention_mask[l:R, :]
                outputs = model(input_ids = target,
                                attention_mask = attention,
                                labels = target)
                logits = outputs.logits
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = target[..., 1:].contiguous()
                Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
                Loss = Loss.view(-1, shift_logits.shape[1])
                attention = attention[..., 1:].contiguous()
                log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
                ret_log_L.append(log_Loss.detach())


            # Keep the lowest-loss (most fluent) candidate sentence.
            ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
            sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
            sen_score.sort(key = lambda x: x[1])
            test_text.append(sen_score[0][2])
            test_dp.append(sen_score[0][3])
            test_parse.append(sen_score[0][4])
            single_sentence.update({f'{s}_{r}_{o}_{i}': sen_score[0][0]})

        else:
            single_sentence.update({f'{s}_{r}_{o}_{i}': ''})

    # Persist the chosen sentences and their dependency paths for later stages.
    with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'w') as fl:
        json.dump(single_sentence, fl, indent=4)
    # with open('generate_abstract/test.txt', 'w') as fl:
    #     fl.write('\n'.join(test_text))
    # with open('generate_abstract/dp.txt', 'w') as fl:
    #     fl.write('\n'.join(test_dp))
    with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'w') as fl:
        fl.write('\n'.join(test_dp))
    with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_temp.json', 'w') as fl:
        fl.write('\n'.join(test_text))
183
+
184
+ elif args.mode == 'finetune':
185
+
186
+ import spacy
187
+ import pprint
188
+ from transformers import AutoModel, AutoTokenizer,BartForConditionalGeneration
189
+
190
+ print('Finetuning ...')
191
+
192
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_chat.json', 'r') as fl:
193
+ draft = json.load(fl)
194
+ with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'r') as fl:
195
+ dpath = fl.readlines()
196
+
197
+ nlp = spacy.load("en_core_web_sm")
198
+ if os.path.exists(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json'):
199
+ with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
200
+ ret_candidates = json.load(fl)
201
+ # if False:
202
+ # pass
203
+ else:
204
+
205
def find_mini_span(vec, words, check_set):
    """Locate the shortest window of *words* containing all reachable entities.

    A window must start and end on positions already flagged True in *vec* and
    must contain (as substrings of the joined text) as many members of
    *check_set* as the full text does. The winning window's positions are
    flagged True in *vec* (mutated in place).

    Returns:
        (vec, span): the updated flag list and the winning window's text
        (empty string when no window qualifies).
    """
    def hit_count(text, targets):
        # Number of target strings appearing (as substrings) in *text*.
        return sum(1 for t in targets if t in text)

    best_total = hit_count(' '.join(words), check_set)

    best_len = 10000000
    best_span = ''
    best_range = None
    n = len(vec)
    for start in range(n):
        if not vec[start]:
            continue
        end = -1
        # Scan forward for the first flagged end position whose window
        # already covers every reachable entity.
        for stop in range(start + 1, n + 1):
            if not vec[stop - 1]:
                continue
            if hit_count(' '.join(words[start:stop]), check_set) == best_total:
                end = stop
                break
        if end > 0 and (end - start) < best_len:
            best_len = end - start
            best_span = ' '.join(words[start:end])
            best_range = (start, end)

    if best_range:
        # Flag the whole winning window so it survives later masking.
        for k in range(best_range[0], best_range[1]):
            vec[k] = True
    return vec, best_span
238
+
239
def mask_func(tokenized_sen):
    """Randomly replace token spans with '<mask>' for BART-style infilling.

    Args:
        tokenized_sen: iterable of sentence objects exposing ``.text``
            (spaCy sentence spans here); only the text is used.

    Returns:
        A single-element list holding the whitespace-joined masked token
        stream, or [] when there are no sentences.

    Notes:
        - Masking probability comes from the module-level ``args.ratio``
          (0.3 when the flag is empty); masked span length ~ Poisson(3).
        - Tokens containing '.', '(', ')', '[' or ']' are never masked and
          reset the consecutive-mask counter.
        - At most 8 consecutive '<mask>' tokens are emitted; further masked
          spans are skipped without emitting a marker.
        - Non-deterministic: draws from the global numpy RNG.
    """
    if len(tokenized_sen) == 0:
        return []
    token_list = []
    # for sen in tokenized_sen:
    #     for token in sen:
    #         token_list.append(token)
    for sen in tokenized_sen:
        token_list += sen.text.split(' ')
    if args.ratio == '':
        P = 0.3   # default masking probability
    else:
        P = float(args.ratio)

    ret_list = []
    i = 0
    mask_num = 0   # consecutive '<mask>' tokens emitted so far
    while i < len(token_list):
        t = token_list[i]
        if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
            # Keep punctuation-bearing tokens verbatim (sentence structure).
            ret_list.append(t)
            i += 1
            mask_num = 0
        else:
            length = np.random.poisson(3)   # how many tokens the mask swallows
            if np.random.rand() < P and length > 0:
                if mask_num < 8:
                    ret_list.append('<mask>')
                    mask_num += 1
                i += length
            else:
                ret_list.append(t)
                i += 1
                mask_num = 0
    return [' '.join(ret_list)]
275
+
276
+ model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
277
+ model.eval()
278
+ model.to(device)
279
+ tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
280
+
281
+ ret_candidates = {}
282
+ dpath_i = 0
283
+
284
+ for i,(k, v) in enumerate(tqdm(draft.items())):
285
+
286
+ input = v['in'].replace('\n', '')
287
+ output = v['out'].replace('\n', '')
288
+ s, r, o = attack_data[i]
289
+
290
+ if int(s) == -1:
291
+ ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
292
+ continue
293
+
294
+ path_text = dpath[dpath_i].replace('\n', '')
295
+ dpath_i += 1
296
+ text_s = entity_raw_name[id_to_meshid[s]]
297
+ text_o = entity_raw_name[id_to_meshid[o]]
298
+
299
+ doc = nlp(output)
300
+ words= input.split(' ')
301
+ tokenized_sens = [sen for sen in doc.sents]
302
+ sens = np.array([sen.text for sen in doc.sents])
303
+
304
+ checkset = set([text_s, text_o])
305
+ e_entity = set(['start_entity', 'end_entity'])
306
+ for path in path_text.split(' '):
307
+ a, b, c = path.split('|')
308
+ if a not in e_entity:
309
+ checkset.add(a)
310
+ if c not in e_entity:
311
+ checkset.add(c)
312
+ vec = []
313
+ l = 0
314
+ while(l < len(words)):
315
+ bo =False
316
+ for j in range(len(words), l, -1): # reversing is important !!!
317
+ cc = ' '.join(words[l:j])
318
+ if (cc in checkset):
319
+ vec += [True] * (j-l)
320
+ l = j
321
+ bo = True
322
+ break
323
+ if not bo:
324
+ vec.append(False)
325
+ l += 1
326
+ vec, span = find_mini_span(vec, words, checkset)
327
+ # vec = np.vectorize(lambda x: x in checkset)(words)
328
+ vec[-1] = True
329
+ prompt = []
330
+ mask_num = 0
331
+ for j, bo in enumerate(vec):
332
+ if not bo:
333
+ mask_num += 1
334
+ else:
335
+ if mask_num > 0:
336
+ # mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
337
+ mask_num = max(mask_num, 1)
338
+ mask_num= min(8, mask_num)
339
+ prompt += ['<mask>'] * mask_num
340
+ prompt.append(words[j])
341
+ mask_num = 0
342
+ prompt = ' '.join(prompt)
343
+ Text = []
344
+ Assist = []
345
+
346
+ for j in range(len(sens)):
347
+ Bart_input = list(sens[:j]) + [prompt] +list(sens[j+1:])
348
+ assist = list(sens[:j]) + [input] +list(sens[j+1:])
349
+ Text.append(' '.join(Bart_input))
350
+ Assist.append(' '.join(assist))
351
+
352
+ for j in range(len(sens)):
353
+ Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j+1:])
354
+ assist = list(sens[:j]) + [input] +list(sens[j+1:])
355
+ Text.append(' '.join(Bart_input))
356
+ Assist.append(' '.join(assist))
357
+
358
+ batch_size = len(Text) // 2
359
+ Outs = []
360
+ for l in range(2):
361
+ A = tokenizer(Text[batch_size * l:batch_size * (l+1)],
362
+ truncation = True,
363
+ padding = True,
364
+ max_length = 1024,
365
+ return_tensors="pt")
366
+ input_ids = A['input_ids'].to(device)
367
+ attention_mask = A['attention_mask'].to(device)
368
+ aaid = model.generate(input_ids, attention_mask = attention_mask, num_beams = 5, max_length = 1024)
369
+ outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
370
+ Outs += outs
371
+ ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}
372
+ with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
373
+ json.dump(ret_candidates, fl, indent = 4)
374
+
375
+ from torch.nn.modules.loss import CrossEntropyLoss
376
+ from transformers import BioGptForCausalLM
377
+ criterion = CrossEntropyLoss(reduction="none")
378
+
379
+ tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
380
+ tokenizer.pad_token = tokenizer.eos_token
381
+ model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
382
+ model.to(device)
383
+ model.eval()
384
+
385
+ scored = {}
386
+ ret = {}
387
+ dpath_i = 0
388
+ for i,(k, v) in enumerate(tqdm(draft.items())):
389
+
390
+ span = ret_candidates[str(i)]['span']
391
+ prompt = ret_candidates[str(i)]['prompt']
392
+ sen_list = ret_candidates[str(i)]['out']
393
+ BART_in = ret_candidates[str(i)]['in']
394
+ Assist = ret_candidates[str(i)]['assist']
395
+
396
+ s, r, o = attack_data[i]
397
+
398
+ if int(s) == -1:
399
+ ret[k] = {'prompt': '', 'in':'', 'out': ''}
400
+ continue
401
+
402
+ text_s = entity_raw_name[id_to_meshid[s]]
403
+ text_o = entity_raw_name[id_to_meshid[o]]
404
+
405
+ def process(text):
406
+
407
+ for i in range(ord('A'), ord('Z')+1):
408
+ text = text.replace(f'.{chr(i)}', f'. {chr(i)}')
409
+ return text
410
+
411
+ sen_list = [process(text) for text in sen_list]
412
+ path_text = dpath[dpath_i].replace('\n', '')
413
+ dpath_i += 1
414
+
415
+ checkset = set([text_s, text_o])
416
+ e_entity = set(['start_entity', 'end_entity'])
417
+ for path in path_text.split(' '):
418
+ a, b, c = path.split('|')
419
+ if a not in e_entity:
420
+ checkset.add(a)
421
+ if c not in e_entity:
422
+ checkset.add(c)
423
+
424
+ input = v['in'].replace('\n', '')
425
+ output = v['out'].replace('\n', '')
426
+
427
+ doc = nlp(output)
428
+ gpt_sens = [sen.text for sen in doc.sents]
429
+ assert len(gpt_sens) == len(sen_list) // 2
430
+
431
+ word_sets = []
432
+ for sen in gpt_sens:
433
+ word_sets.append(set(sen.split(' ')))
434
+
435
+ def sen_align(word_sets, modified_word_sets):
436
+
437
+ l = 0
438
+ while(l < len(modified_word_sets)):
439
+ if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
440
+ l += 1
441
+ else:
442
+ break
443
+ if l == len(modified_word_sets):
444
+ return -1, -1, -1, -1
445
+ r = l + 1
446
+ r1 = None
447
+ r2 = None
448
+ for pos1 in range(r, len(word_sets)):
449
+ for pos2 in range(r, len(modified_word_sets)):
450
+ if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
451
+ r1 = pos1
452
+ r2 = pos2
453
+ break
454
+ if r1 is not None:
455
+ break
456
+ if r1 is None:
457
+ r1 = len(word_sets)
458
+ r2 = len(modified_word_sets)
459
+ return l, r1, l, r2
460
+
461
+ replace_sen_list = []
462
+ boundary = []
463
+ assert len(sen_list) % 2 == 0
464
+ for j in range(len(sen_list) // 2):
465
+ doc = nlp(sen_list[j])
466
+ sens = [sen.text for sen in doc.sents]
467
+ modified_word_sets = [set(sen.split(' ')) for sen in sens]
468
+ l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
469
+ boundary.append((l1, r1, l2, r2))
470
+ if l1 == -1:
471
+ replace_sen_list.append(sen_list[j])
472
+ continue
473
+ check_text = ' '.join(sens[l2: r2])
474
+ replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
475
+ sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
476
+
477
+ old_L = len(sen_list)
478
+ sen_list.append(output)
479
+ sen_list += Assist
480
+ tokens = tokenizer( sen_list,
481
+ truncation = True,
482
+ padding = True,
483
+ max_length = 1024,
484
+ return_tensors="pt")
485
+ target_ids = tokens['input_ids'].to(device)
486
+ attention_mask = tokens['attention_mask'].to(device)
487
+ L = len(sen_list)
488
+ ret_log_L = []
489
+ for l in range(0, L, 5):
490
+ R = min(L, l + 5)
491
+ target = target_ids[l:R, :]
492
+ attention = attention_mask[l:R, :]
493
+ outputs = model(input_ids = target,
494
+ attention_mask = attention,
495
+ labels = target)
496
+ logits = outputs.logits
497
+ shift_logits = logits[..., :-1, :].contiguous()
498
+ shift_labels = target[..., 1:].contiguous()
499
+ Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
500
+ Loss = Loss.view(-1, shift_logits.shape[1])
501
+ attention = attention[..., 1:].contiguous()
502
+ log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
503
+ ret_log_L.append(log_Loss.detach())
504
+ log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
505
+
506
+ real_log_Loss = log_Loss.copy()
507
+
508
+ log_Loss = log_Loss[:old_L]
509
+
510
+ p = np.argmin(log_Loss)
511
+ content = []
512
+ for i in range(len(real_log_Loss)):
513
+ content.append([sen_list[i], str(real_log_Loss[i])])
514
+ scored[k] = {'path':path_text, 'prompt': prompt, 'in':input, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}
515
+ p_p = p
516
+ # print('Old_L:', old_L)
517
+
518
+ if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
519
+ p_p = p+1+old_L
520
+
521
+ if real_log_Loss[p] > real_log_Loss[old_L]:
522
+ if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
523
+ p = p+1+old_L
524
+ ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p]}
525
+ with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
526
+ json.dump(ret, fl, indent=4)
527
+ with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
528
+ json.dump(scored, fl, indent=4)
529
+ else:
530
+ raise Exception('Wrong mode !!')
DiseaseSpecific/evaluation.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torch.autograd import Variable
4
+ from sklearn import metrics
5
+
6
+ import datetime
7
+ from typing import Dict, Tuple, List
8
+ import logging
9
+ import os
10
+ import utils
11
+ import pickle as pkl
12
+ import json
13
+ from tqdm import tqdm
14
+ import torch.backends.cudnn as cudnn
15
+
16
+ import sys
17
+ sys.path.append("..")
18
+ import Parameters
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ def get_model_loss_without_softmax(batch, model, device=None):
23
+
24
+ with torch.no_grad():
25
+ s,r,o = batch[:,0], batch[:,1], batch[:,2]
26
+
27
+ emb_s = model.emb_e(s).squeeze(dim=1)
28
+ emb_r = model.emb_rel(r).squeeze(dim=1)
29
+
30
+ pred = model.forward(emb_s, emb_r)
31
+ return -pred[range(o.shape[0]), o]
32
+
33
def check(trip, model, reasonable_rate, device, data_mean = -4.008113861083984, data_std = 5.153779983520508, divide_bound = 0.05440050354114886):
    """Decide whether an attack triple looks "reasonable" to the KGE model.

    The triple's raw loss is standardized with model-specific statistics
    (defaults are for distmult; conve overrides them), squashed through a
    shifted sigmoid, and compared against 1 - reasonable_rate.

    Relies on the module-level `args` for the model name.
    Returns True when the triple's plausibility exceeds the bound.
    """
    if args.model == 'conve':
        # conve has its own calibration statistics.
        data_mean, data_std, divide_bound = 13.890259742, 12.396190643, -0.1986345871
    elif args.model != 'distmult':
        raise Exception('Wrong model!!')
    triple = np.array(trip)[None, :]
    triple = torch.from_numpy(triple.astype('int64')).to(device)
    raw_loss = get_model_loss_without_softmax(triple, model, device).squeeze().item()

    bound = 1 - reasonable_rate
    normed = (raw_loss - data_mean) / data_std
    plausibility = 1 / (1 + np.exp(normed - divide_bound))
    return plausibility > bound
52
+
53
+
54
def get_ranking(model, queries,
                valid_filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                device, batch_size, entityid_to_nodetype, exists_edge):
    """
    Ranking for target generation.

    For each query triple (s, r, o), scores every entity as a candidate
    subject for (r, o) and records where the true subject lands.

    Returns (ranks, total_nums):
      ranks[i]      -- 0-based position of the query subject when candidates
                       are sorted by ascending score,
      total_nums[i] -- number of candidate entities surviving the filter.

    NOTE(review): relies on module-level globals `args` (target_existed) and
    `n_ent`; `batch_size` is accepted but queries are processed one at a time.
    """
    ranks = []
    total_nums = []
    b_begin = 0

    # Effective batch size is 1: each loop iteration handles a single query.
    for b_begin in range(0, len(queries), 1):
        b_queries = queries[b_begin : b_begin+1]
        s,r,o = b_queries[:,0], b_queries[:,1], b_queries[:,2]
        r_rev = r
        lhs_score = model.score_or(o, r_rev, sigmoid=False) #this gives scores not probabilities
        for i, query in enumerate(b_queries):

            # Build a boolean mask of entities to EXCLUDE from ranking.
            # (`filter` shadows the builtin -- kept for byte-compat.)
            if not args.target_existed:
                tp1 = entityid_to_nodetype[str(query[0].item())]
                tp2 = entityid_to_nodetype[str(query[2].item())]
                # Clone before mutating: also drop subjects already linked to o.
                filter = valid_filters['lhs'][(tp2, query[1].item())].clone()
                filter[exists_edge['lhs'][str(query[2].item())]] = False
                filter = (filter == False)
            else:
                tp1 = entityid_to_nodetype[str(query[0].item())]
                tp2 = entityid_to_nodetype[str(query[2].item())]
                filter = valid_filters['lhs'][(tp2, query[1].item())]
                filter = (filter == False)

            score = lhs_score
            # zero all known cases (this are not interesting);
            # this corresponds to the filtered setting: excluded entities get a
            # huge score so they sort to the end.
            score[i][filter] = 1e6
            total_nums.append(n_ent - filter.sum().item())

        # sort and rank
        min_values, sort_v = torch.sort(score, dim=1, descending=False) #low scores get low number ranks

        sort_v = sort_v.cpu().numpy()

        for i, query in enumerate(b_queries):
            # find the rank of the target entities
            rank = np.where(sort_v[i]==query[0].item())[0][0]

            # rank+1, since the lowest rank is rank 1 not rank 0
            ranks.append(rank)

    #logger.info('Ranking done for all queries')
    return ranks, total_nums
114
+
115
+
116
def evaluation(model, queries,
               valid_filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
               device, batch_size, entityid_to_nodetype, exists_edge, eval_type = '', attack_data = None, ori_ranks = None, ori_totals = None):
    """Rank the query triples and aggregate Hits@k / MR / MRR / proportion metrics.

    When `attack_data` is given, queries whose attack was invalid (encoded as
    a -1 subject, or an empty edge list in sentence/bioBART mode) fall back to
    their pre-attack `ori_ranks` / `ori_totals` values.

    Returns (results_dict, ranks_list, total_nums_list).
    NOTE(review): depends on module-level `args` and `logger`.
    """
    #get ranking
    ranks, total_nums = get_ranking(model, queries, valid_filters, device, batch_size, entityid_to_nodetype, exists_edge)
    ranks, total_nums = np.array(ranks), np.array(total_nums)

    # Convert 0-based "position from the bottom" into a rank counted from the
    # top of the candidate list.
    ranks = total_nums - ranks

    if (attack_data is not None):
        for i, tri in enumerate(attack_data):
            if args.mode == '':
                if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                    # Single added edge per target: -1 subject marks "no attack".
                    if int(tri[0]) == -1:
                        ranks[i] = ori_ranks[i]
                        total_nums[i] = ori_totals[i]
                else:
                    # Multiple added edges: check the first triple's subject.
                    if int(tri[0][0]) == -1:
                        ranks[i] = ori_ranks[i]
                        total_nums[i] = ori_totals[i]
            else:
                # sentence/bioBART mode: an empty list means no valid attack.
                if len(tri) == 0:
                    ranks[i] = ori_ranks[i]
                    total_nums[i] = ori_totals[i]

    # "proportion" = rank relative to candidate-set size, averaged over queries.
    mean = (ranks / total_nums).mean()
    std = (ranks / total_nums).std()
    #final logging
    hits_at = np.arange(1,11)
    hits_at_both = list(map(lambda x: np.mean((ranks <= x), dtype=np.float64).item(),
                            hits_at))
    mr = np.mean(ranks, dtype=np.float64).item()

    # NOTE(review): 1./ranks divides by zero if any rank is 0 -- confirm ranks
    # are always >= 1 after the total_nums - ranks conversion.
    mrr = np.mean(1. / ranks, dtype=np.float64).item()

    logger.info('')
    logger.info('-'*50)
    logger.info('')
    if eval_type:
        logger.info(eval_type)
    else:
        logger.info('after attck')

    for i in hits_at:
        logger.info('Hits @{0}: {1}'.format(i, hits_at_both[i-1]))
    logger.info('Mean rank: {0}'.format( mr))
    logger.info('Mean reciprocal rank lhs: {0}'.format(mrr))
    logger.info('Mean proportion: {0}'.format(mean))
    logger.info('Std proportion: {0}'.format(std))
    logger.info('Mean candidate num: {0}'.format(np.mean(total_nums)))

    # (dead code removed: legacy commented-out text-file reporting block)

    results = {}
    for i in hits_at:
        results['hits @{}'.format(i)] = hits_at_both[i-1]
    results['mrr'] = mrr
    results['mr'] = mr
    results['proportion'] = mean
    results['std'] = std

    return results, list(ranks), list(total_nums)
199
+
200
+
201
# ---------------------------------------------------------------------------
# Script setup: parse CLI arguments, pick the device, fix random seeds, and
# build the file paths (target triples, log file, result records) whose names
# encode the full experimental configuration.
# ---------------------------------------------------------------------------
parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser = utils.add_eval_parameters(parser)
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

data_path = os.path.join('processed_data', args.data)
# Target triples produced by the attack-target generation step.
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))

log_path = 'logs/evaluation_logs/cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}'.format(args.model,
                                args.target_split,
                                args.target_size,
                                'exists:'+str(args.target_existed),
                                args.neighbor_num,
                                args.candidate_mode,
                                args.attack_goal,
                                str(args.reasonable_rate),
                                args.mode)
# Post-attack evaluation record.
record_path = 'eval_record/{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}{9}{10}'.format(args.model,
                                args.target_split,
                                args.target_size,
                                'exists:'+str(args.target_existed),
                                args.neighbor_num,
                                args.candidate_mode,
                                args.attack_goal,
                                str(args.reasonable_rate),
                                args.mode,
                                str(args.added_edge_num),
                                args.mask_ratio)
# Pre-attack ("original") evaluation record.
init_record_path = 'eval_record/{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}'.format(args.model,
                                args.target_split,
                                args.target_size,
                                'exists:'+str(args.target_existed),
                                args.neighbor_num,
                                args.candidate_mode,
                                args.attack_goal,
                                str(args.reasonable_rate),
                                'init')

if args.seperate:
    record_path += '_seperate'
    log_path += '_seperate'
else:
    # NOTE(review): log_path gets no '_batch' suffix here while record_path
    # does -- asymmetric with the branch above; confirm this is intended.
    record_path += '_batch'

if args.direct:
    log_path += '_direct'
    record_path += '_direct'
else:
    log_path += '_nodirect'
    record_path += '_nodirect'

# Directory where the poisoned ("disturbed") training files are written.
dis_turbrbed_path_pre = os.path.join(data_path, 'evaluation')
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO,
                    filename = log_path
                    )
logger = logging.getLogger(__name__)
logger.info(vars(args))
267
+
268
# ---------------------------------------------------------------------------
# Load the pretrained KGE model, the original graph, the target triples, and
# the attack triples; targets and attacks are shuffled with the same
# permutation so they stay aligned index-for-index.
# ---------------------------------------------------------------------------
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = 'saved_models/{0}_{1}.model'.format(args.data, model_name)
model = utils.load_model(model_path, args, n_ent, n_rel, device)

ori_data = utils.load_data(os.path.join(data_path, 'all.txt'))
target_data = utils.load_data(target_path)

# Shared shuffle permutation (deterministic via utils.seed_all above).
index = range(len(target_data))
index = np.random.permutation(index)
target_data = target_data[index]

if args.direct:
    assert args.attack_goal == 'single'
    raise Exception('This option is abandoned in this version .')
    # disturbed_data = list(ori_data) + list(target_data)
else:

    attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}{9}{10}.txt'.format(args.model,
                                args.target_split,
                                args.target_size,
                                'exists:'+str(args.target_existed),
                                args.neighbor_num,
                                args.candidate_mode,
                                args.attack_goal,
                                str(args.reasonable_rate),
                                args.mode,
                                str(args.added_edge_num),
                                args.mask_ratio))
    if args.mode == '':
        # Plain triple attacks stored as a text file (possibly several added
        # edges per target, reshaped to (targets, edges, 3)).
        attack_data = utils.load_data(attack_path, drop=False)
        if not(args.added_edge_num == '' or int(args.added_edge_num) == 1):
            assert int(args.added_edge_num) * len(target_data) == len(attack_data)
            attack_data = attack_data.reshape((len(target_data), int(args.added_edge_num), 3))
            attack_data = attack_data[index]
        else:
            assert len(target_data) == len(attack_data)
            attack_data = attack_data[index]
    else:
        # sentence/bioBART modes: attacks pickled as a list of edge lists;
        # keep only edges that pass the plausibility check().
        with open(attack_path, 'rb') as fl:
            attack_data = pkl.load(fl)

        tmp_attack_data = []
        for vv in attack_data:
            a_attack = []
            for v in vv:
                if check(v, model, args.reasonable_rate, device):
                    a_attack.append(v)
            tmp_attack_data.append(a_attack)
        attack_data = tmp_attack_data
        # Apply the same permutation used for target_data.
        attack_data = [attack_data[i] for i in index]
332
+
333
# ---------------------------------------------------------------------------
# Load candidate filters and metadata, densify the filters into per-(type,
# relation) boolean tensors, index the existing edges, then evaluate the
# UNATTACKED model as the baseline and persist that record.
# ---------------------------------------------------------------------------
with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    valid_filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)

if args.attack_goal == 'global':
    raise Exception('Please refer to pagerank method in global setting.')
    # (dead code removed: legacy global-goal target construction block)

# Turn each filter's entity-id list into a dense ByteTensor mask over all
# n_ent entities, kept on the compute device.
init_mask = np.asarray([0] * n_ent).astype('int64')
init_mask = (init_mask == 1)
for k, v in valid_filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        t = torch.ByteTensor(tmp).to(device)
        valid_filters[k][kk] = t

# Adjacency index of the clean graph: subject -> objects and object -> subjects.
exists_edge = {'lhs':{}, 'rhs':{}}
for s, r, o in ori_data:
    if s not in exists_edge['rhs'].keys():
        exists_edge['rhs'][s] = []
    if o not in exists_edge['lhs'].keys():
        exists_edge['lhs'][o] = []
    exists_edge['rhs'][s].append(int(o))
    exists_edge['lhs'][o].append(int(s))
target_data = torch.from_numpy(target_data.astype('int64')).to(device)
# Baseline evaluation on the clean model; reused later as fallback ranks.
ori_results, ori_ranks, ori_totals = evaluation(model, target_data, valid_filters, device, args.test_batch_size, entityid_to_nodetype, exists_edge, 'original')
print('Original:', ori_results)
with open(init_record_path, 'wb') as fl:
    pkl.dump([ori_results, ori_ranks, ori_totals], fl)

# Unique tag for this configuration; used to name the retraining thread,
# the disturbed-data file, and the retrained model checkpoints.
thread_name = args.model+'_'+args.target_split+'_'+args.attack_goal+'_'+str(args.reasonable_rate)+str(args.added_edge_num)+str(args.mask_ratio)
if args.direct:
    thread_name += '_direct'
else:
    thread_name += '_nodirect'
if args.seperate:
    thread_name += '_seperate'
else:
    thread_name += '_batch'
thread_name += args.mode

disturbed_data_path = os.path.join(dis_turbrbed_path_pre, 'all_{}.txt'.format(thread_name))
409
+
410
# ---------------------------------------------------------------------------
# Retrain-and-evaluate under the poisoned graph.
#   seperate: retrain one model per target triple (only that target's edges);
#   batch   : retrain one model per chunk of 50 targets with all their edges.
# Both paths write the poisoned training file, shell out to
# main_multiprocess.py to retrain, reload the checkpoint, and re-rank.
# ---------------------------------------------------------------------------
if args.seperate:
    assert len(attack_data) == len(target_data)
    # FIX: this initialization was commented out in the original, so the
    # `if not final_result` / `final_result['proportion']` uses below raised
    # a NameError on the first iteration.
    final_result = None
    Ranks = []
    Totals = []
    print('Training model {}...'.format(thread_name))
    for i in tqdm(range(len(attack_data))):
        attack_trip = attack_data[i]
        if args.mode == '':
            # Plain mode stores one triple per target; wrap to a list of edges.
            attack_trip = [attack_trip]
        target = target_data[i: i+1, :]
        if len(attack_trip) > 0 and int(attack_trip[0][0]) != -1:
            # Poison the training graph with this target's attack edges.
            disturbed_data = list(ori_data) + attack_trip
            disturbed_data = np.array(disturbed_data)
            utils.save_data(disturbed_data_path, disturbed_data)

            cmd = 'CUDA_VISIBLE_DEVICES={} python main_multiprocess.py --data {} --model {} --thread-name {}'.format(args.cuda_name,args.data, args.model, thread_name)
            os.system(cmd)
            model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop, thread_name)
            model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, model_name)
            model = utils.load_model(model_path, args, n_ent, n_rel, device)
            a_results, a_ranks, a_total_nums = evaluation(model, target, valid_filters, device, args.test_batch_size, entityid_to_nodetype, exists_edge)
            assert len(a_ranks) == 1
            # Accumulate metric sums; normalized to means after the loop.
            if not final_result:
                final_result = a_results
            else:
                for k in final_result.keys():
                    final_result[k] += a_results[k]
            Ranks += a_ranks
            Totals += a_total_nums
        else:
            # No valid attack for this target: fall back to the clean-model rank.
            Ranks += [ori_ranks[i]]
            Totals += [ori_totals[i]]
            # FIX: guard -- previously crashed when the first targets had no
            # valid attack (final_result still None at that point).
            if final_result is not None:
                final_result['proportion'] += ori_ranks[i] / ori_totals[i]
    # NOTE(review): if every target lacked a valid attack, final_result stays
    # None and the normalization below would fail -- confirm that cannot occur.
    for k in final_result.keys():
        # FIX: was attack_data.shape[0]; attack_data is a plain list in
        # sentence/bioBART modes, which has no .shape attribute.
        final_result[k] /= len(attack_data)
    print('Final !!!')
    print(final_result)
    logger.info('Final !!!!')
    for k, v in final_result.items():
        logger.info('{} : {}'.format(k, v))
    tmp = np.array(Ranks) / np.array(Totals)
    print('Std:', np.std(tmp))
    with open(record_path, 'wb') as fl:
        pkl.dump([final_result, Ranks, Totals], fl)

else:
    assert len(target_data) == len(attack_data)
    print('Attack shape:' , len(attack_data))
    Results = []
    Ranks = []
    Totals = []
    for l in range(0, len(target_data), 50):
        r = min(l+50, len(target_data))
        t_target_data = target_data[l:r]
        t_attack_data = attack_data[l:r]
        t_ori_ranks = ori_ranks[l:r]
        t_ori_totals = ori_totals[l:r]
        # Flatten per-target attack collections into one edge list per chunk.
        if args.mode == '':
            if not(args.added_edge_num == '' or int(args.added_edge_num) == 1):
                tt_attack_data = []
                for vv in t_attack_data:
                    tt_attack_data += list(vv)
                t_attack_data = tt_attack_data
        else:
            assert args.mode == 'sentence' or args.mode == 'bioBART'
            tt_attack_data = []
            for vv in t_attack_data:
                tt_attack_data += vv
            t_attack_data = tt_attack_data
        disturbed_data = list(ori_data) + list(t_attack_data)

        utils.save_data(disturbed_data_path, disturbed_data)
        cmd = 'CUDA_VISIBLE_DEVICES={} python main_multiprocess.py --data {} --model {} --thread-name {}'.format(args.cuda_name,args.data, args.model, thread_name)
        print('Training model {}...'.format(thread_name))
        os.system(cmd)
        model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop, thread_name)
        model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, model_name)
        model = utils.load_model(model_path, args, n_ent, n_rel, device)
        # Pass the chunk's attack slice so invalid attacks fall back to the
        # baseline ranks inside evaluation().
        a_results, a_ranks, a_totals = evaluation(model, t_target_data, valid_filters, device, args.test_batch_size, entityid_to_nodetype, exists_edge, attack_data = attack_data[l:r], ori_ranks = t_ori_ranks, ori_totals = t_ori_totals)
        print(f'************Current l: {l}\n', a_results)
        assert len(a_ranks) == t_target_data.shape[0]
        Results += [a_results]
        Ranks += list(a_ranks)
        Totals += list(a_totals)
    with open(record_path, 'wb') as fl:
        pkl.dump([Results, Ranks, Totals, index], fl)
DiseaseSpecific/main.py ADDED
@@ -0,0 +1,377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ import pickle as pkl
3
+ from typing import Dict, Tuple, List
4
+ import os
5
+ import numpy as np
6
+ import json
7
+ import logging
8
+ import argparse
9
+ import math
10
+ from pprint import pprint
11
+ import pandas as pd
12
+ from collections import defaultdict
13
+ import copy
14
+ import time
15
+ from tqdm import tqdm
16
+
17
+ import torch
18
+ from torch.utils.data import DataLoader
19
+ import torch.backends.cudnn as cudnn
20
+ import torch.autograd as autograd
21
+
22
+ from model import Distmult, Complex, Conve
23
+ import utils
24
+
25
+ # from evaluation import evaluation
26
+
27
+ #%%
28
class Main(object):
    """End-to-end trainer for a KG embedding model (Distmult / Complex / Conve).

    Builds the model from CLI args, trains on the full triple file
    (``processed_data/<data>/all.txt``), and checkpoints whenever the training
    loss improves.  When ``args.save_influence_map`` is set, per-triple
    embedding deltas are accumulated during training (gradient rollback, GR)
    and written to disk after the final epoch.
    """

    def __init__(self, args):
        self.args = args

        self.model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
        # Batch sizes are left out of model_name since they do not depend on the
        # model architecture; kernel size and filter count are left out too,
        # since we don't intend to change those.
        self.model_path = 'saved_models/{0}_{1}.model'.format(args.data, self.model_name)

        self.log_path = 'logs/{0}_{1}_{2}_{3}.log'.format(args.data, self.model_name, args.epochs, args.train_batch_size)
        self.loss_path = 'losses/{0}_{1}_{2}_{3}.pickle'.format(args.data, self.model_name, args.epochs, args.train_batch_size)

        logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt = '%m/%d/%Y %H:%M:%S',
                            level = logging.INFO,
                            filename = self.log_path)
        self.logger = logging.getLogger(__name__)
        self.logger.info(vars(self.args))
        self.logger.info('\n')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.load_data()
        self.model = self.add_model()
        self.optimizer = self.add_optimizer(self.model.parameters())

        if self.args.save_influence_map:
            self.logger.info('-------- Argument save_influence_map is set. Will use GR to compute and save influence maps ----------\n')
            # When saving influence during training, reciprocal relations are
            # disabled to keep the bookkeeping simple.
            self.args.add_reciprocals = False # to keep things simple
            # Init an empty influence map (defaultdict so += works on first touch).
            self.influence_map = defaultdict(float)
            #self.influence_path = 'influence_maps/{0}_{1}.json'.format(args.data, self.model_name)
            self.influence_path = 'influence_maps/{0}_{1}.pickle'.format(args.data, self.model_name)
            # Initialize a copy of the model params to track previous weights in an epoch.
            self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]
            self.logger.info('Shape for previous weights: {}, {}'.format(self.previous_weights[0].shape, self.previous_weights[1].shape))

    def load_data(self):
        '''
        Load the train, valid datasets.

        The full triple file is used for training; a random subset of size
        ``KG_valid_rate * |train|`` (sampled without replacement) is kept as
        validation triples.
        '''
        data_path = os.path.join('processed_data', self.args.data)
        n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
        self.n_ent = n_ent
        self.n_rel = n_rel

        self.train_data = utils.load_data(os.path.join(data_path, 'all.txt'))
        # print(type(self.train_data), self.train_data.shape) #(1996432, 3)
        tmp = np.random.choice(a = self.train_data.shape[0], size = int(self.train_data.shape[0] * self.args.KG_valid_rate), replace=False)
        self.valid_data = self.train_data[tmp, :]

    def add_model(self):
        """Instantiate the embedding model selected by ``args.model`` and move it to the device."""
        if self.args.model is None:
            model = Distmult(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'distmult':
            model = Distmult(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'complex':
            model = Complex(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'conve':
            model = Conve(self.args, self.n_ent, self.n_rel)
        else:
            # BUGFIX: logging uses lazy %-style placeholders; the original
            # '{0}'-style brace was never interpolated by the logger.
            self.logger.info('Unknown model: %s', self.args.model)
            raise Exception("Unknown model!")
        model.to(self.device)
        return model

    def add_optimizer(self, parameters):
        """Return the Adam optimizer used for training."""
        return torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)

    def save_model(self):
        """Checkpoint model weights, optimizer state and args to ``self.model_path``."""
        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'args': vars(self.args)
        }
        torch.save(state, self.model_path)
        self.logger.info('Saving model to {0}'.format(self.model_path))

    def load_model(self):
        """Restore model and optimizer state from ``self.model_path``."""
        self.logger.info('Loading saved model from {0}'.format(self.model_path))
        state = torch.load(self.model_path)
        model_params = state['state_dict']
        params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
        for key, size, count in params:
            # BUGFIX: Logger.info takes a format string first; passing three
            # bare positionals treated `key` as the format string (garbled or
            # raised whenever the key contained a '%').
            self.logger.info('%s %s %s', key, size, count)
        self.model.load_state_dict(model_params)
        self.optimizer.load_state_dict(state['optimizer'])

    def lp_regularizer(self):
        """L-p norm regularizer over the full entity and relation embedding tables."""
        # Apply p-norm regularization; assign weights to each param.
        weight = self.args.reg_weight
        p = self.args.reg_norm

        trainable_params = [self.model.emb_e.weight, self.model.emb_rel.weight]
        norm = 0
        for i in range(len(trainable_params)):
            #norm += weight * trainable_params[i].norm(p = p)**p
            norm += weight * torch.sum( torch.abs(trainable_params[i]) ** p)

        return norm

    def n3_regularizer(self, factors):
        """N3-style regularizer over the per-batch embedding factors (lhs, rel, rhs)."""
        weight = self.args.reg_weight
        p = self.args.reg_norm

        norm = 0
        for f in factors:
            norm += weight * torch.sum(torch.abs(f) ** p)

        return norm / factors[0].shape[0] # scale by number of triples in batch

    def get_influence_map(self):
        """
        Turns the influence map into a list, ready to be written to disc. (before: numpy)
        :return: the influence map with lists as values
        """
        assert self.args.save_influence_map == True

        for key in self.influence_map:
            self.influence_map[key] = self.influence_map[key].tolist()
        #self.logger.info('get_influence_map passed')
        return self.influence_map

    def evaluate(self, split, batch_size, epoch):
        """
        Compute the validation loss (same loss as self.run_epoch(), no gradients).
        NOTE(review): `split` and `batch_size` are currently unused — the method
        always reads self.valid_data and self.args.valid_batch_size; kept for
        interface compatibility.
        """
        self.model.eval()
        losses = []

        with torch.no_grad():
            input_data = torch.from_numpy(self.valid_data.astype('int64'))
            actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
            del input_data

            batch_size = self.args.valid_batch_size
            for b_begin in tqdm(range(0, actual_examples.shape[0], batch_size)):

                input_batch = actual_examples[b_begin: b_begin + batch_size]
                input_batch = input_batch.to(self.device)

                s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]

                emb_s = self.model.emb_e(s).squeeze(dim=1)
                emb_r = self.model.emb_rel(r).squeeze(dim=1)
                emb_o = self.model.emb_e(o).squeeze(dim=1)

                if self.args.add_reciprocals:
                    r_rev = r + self.n_rel
                    emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
                else:
                    r_rev = r
                    emb_rrev = emb_r

                # Score both directions: (s, r, ?) and (o, r_rev, ?).
                pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
                loss_sr = self.model.loss(pred_sr, o) # cross entropy loss

                pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
                loss_or = self.model.loss(pred_or, s)

                total_loss = loss_sr + loss_or

                if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
                    #self.logger.info('Computing regularizer weight')
                    if self.args.model == 'complex':
                        # Complex embeddings store (real, imaginary) halves side by side.
                        emb_dim = self.args.embedding_dim #int(self.args.embedding_dim/2)
                        lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
                        rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
                        rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
                        rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

                        factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                      torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                                      torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
                                      )
                        factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                      torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                                      torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
                                      )
                    else:
                        factors_sr = (emb_s, emb_r, emb_o)
                        factors_or = (emb_s, emb_rrev, emb_o)

                    total_loss += self.n3_regularizer(factors_sr)
                    total_loss += self.n3_regularizer(factors_or)

                if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
                    total_loss += self.lp_regularizer()

                losses.append(total_loss.item())

        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Validating Loss:{:.6}\n'.format(epoch, loss))
        return loss


    def run_epoch(self, epoch):
        """Run one training epoch over the shuffled training triples; return the mean batch loss."""
        self.model.train()
        losses = []

        # Shuffle the train dataset.
        input_data = torch.from_numpy(self.train_data.astype('int64'))
        actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
        del input_data

        batch_size = self.args.train_batch_size

        for b_begin in tqdm(range(0, actual_examples.shape[0], batch_size)):
            self.optimizer.zero_grad()
            input_batch = actual_examples[b_begin: b_begin + batch_size]
            input_batch = input_batch.to(self.device)

            s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]

            emb_s = self.model.emb_e(s).squeeze(dim=1)
            emb_r = self.model.emb_rel(r).squeeze(dim=1)
            emb_o = self.model.emb_e(o).squeeze(dim=1)

            if self.args.add_reciprocals:
                r_rev = r + self.n_rel
                emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
            else:
                r_rev = r
                emb_rrev = emb_r

            pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
            loss_sr = self.model.loss(pred_sr, o) # loss is cross entropy loss

            pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
            loss_or = self.model.loss(pred_or, s)

            total_loss = loss_sr + loss_or

            if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
                #self.logger.info('Computing regularizer weight')
                if self.args.model == 'complex':
                    emb_dim = self.args.embedding_dim #int(self.args.embedding_dim/2)
                    lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
                    rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
                    rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
                    rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

                    factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                  torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                                  torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
                    factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                  torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                                  torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
                else:
                    factors_sr = (emb_s, emb_r, emb_o)
                    factors_or = (emb_s, emb_rrev, emb_o)

                total_loss += self.n3_regularizer(factors_sr)
                total_loss += self.n3_regularizer(factors_or)

            if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
                total_loss += self.lp_regularizer()


            total_loss.backward()
            self.optimizer.step()
            losses.append(total_loss.item())

            if self.args.save_influence_map: #for gradient rollback
                with torch.no_grad():
                    prev_emb_e = self.previous_weights[0]
                    prev_emb_rel = self.previous_weights[1]
                    # Need to compute the influence value per-triple.
                    for idx in range(input_batch.shape[0]):
                        head, rel, tail = s[idx], r[idx], o[idx]
                        inf_head = (emb_s[idx] - prev_emb_e[head]).cpu().detach().numpy()
                        inf_tail = (emb_o[idx] - prev_emb_e[tail]).cpu().detach().numpy()
                        inf_rel = (emb_r[idx] - prev_emb_rel[rel]).cpu().detach().numpy()
                        #print(inf_head.shape, inf_tail.shape, inf_rel.shape)

                        # Write the influences to the dictionary, keyed by triple and slot.
                        key_trip = '{0}_{1}_{2}'.format(head.item(), rel.item(), tail.item())
                        key = '{0}_s'.format(key_trip)
                        self.influence_map[key] += inf_head
                        #self.logger.info('Written to influence map. Key: {0}, Value shape: {1}'.format(key, inf_head.shape))
                        key = '{0}_r'.format(key_trip)
                        self.influence_map[key] += inf_rel
                        key = '{0}_o'.format(key_trip)
                        self.influence_map[key] += inf_tail

                    # Update the previous weights to be tracked.
                    self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]

            if (b_begin%5000 == 0) or (b_begin== (actual_examples.shape[0]-1)):
                self.logger.info('[E:{} | {}]: Train Loss:{:.6}'.format(epoch, b_begin, np.mean(losses)))

        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Training Loss:{:.6}\n'.format(epoch, loss))
        return loss

    def fit(self):
        """Train for ``args.epochs`` epochs, checkpointing on improved train loss,
        then dump the loss history and (optionally) the influence map."""
        self.model.init()
        self.logger.info(self.model)

        self.logger.info('------ Start the model training ------')
        start_time = time.time()
        self.logger.info('Start time: {0}'.format(str(start_time)))


        train_losses = []
        valid_losses = []   # NOTE(review): never populated — validation is not run in this loop.
        best_val = 10000000000.
        for epoch in range(self.args.epochs):

            print("="*15,'epoch:',epoch,'='*15)
            train_loss = self.run_epoch(epoch)
            train_losses.append(train_loss)

            if train_loss < best_val:
                best_val = train_loss
                self.save_model()
            print("Train loss: {0}, Best loss: {1}\n\n".format(train_loss, best_val))


        with open(self.loss_path, "wb") as fl:
            pkl.dump({"train loss":train_losses, "valid loss":valid_losses}, fl)
        self.logger.info('Time taken to train the model: {0}'.format(str(time.time() - start_time)))
        start_time = time.time()

        if self.args.save_influence_map: #save the influence map
            with open(self.influence_path, "wb") as fl: #Pickling
                pkl.dump(self.get_influence_map(), fl)
            self.logger.info('Finished saving influence map')
            self.logger.info('Time taken to save the influence map: {0}'.format(str(time.time() - start_time)))
365
+
366
#%%
# Script entry point: build the CLI parser, resolve hyper-parameters,
# seed every RNG for reproducibility, then train the embedding model.
args = utils.set_hyperparams(utils.get_argument_parser().parse_args())

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False  # keep cuDNN deterministic-friendly

Main(args).fit()
DiseaseSpecific/main_multiprocess.py ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Multiprocess for main.py"""
2
+ #%%
3
+ import pickle as pkl
4
+ from typing import Dict, Tuple, List
5
+ import os
6
+ import numpy as np
7
+ import json
8
+ import logging
9
+ import argparse
10
+ import math
11
+ from pprint import pprint
12
+ import pandas as pd
13
+ from collections import defaultdict
14
+ import copy
15
+ import time
16
+
17
+ import torch
18
+ from torch.utils.data import DataLoader
19
+ import torch.backends.cudnn as cudnn
20
+ import torch.autograd as autograd
21
+
22
+ from model import Distmult, Complex, Conve
23
+ import utils
24
+
25
+ # from evaluation import evaluation
26
+
27
+ #%%
28
class Main(object):
    """Per-thread trainer used by main.py's evaluation pipeline.

    Identical training loop to DiseaseSpecific/main.py, but every path
    (model checkpoint, logs, losses, training data) is suffixed with
    ``args.thread_name`` so multiple worker processes can train on their own
    perturbed triple files (``processed_data/<data>/evaluation/all_<thread>.txt``)
    without clobbering each other.
    """

    def __init__(self, args):
        self.args = args

        self.model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop, args.thread_name)
        # Batch sizes are left out of model_name since they do not depend on the
        # model architecture; kernel size and filter count are left out too,
        # since we don't intend to change those.
        self.model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, self.model_name)

        self.log_path = 'logs/evaluation_logs/{0}_{1}_{2}_{3}_{4}.log'.format(args.data, self.model_name, args.epochs, args.train_batch_size, args.thread_name)
        self.loss_path = 'losses/evaluation_losses/{0}_{1}_{2}_{3}_{4}.pickle'.format(args.data, self.model_name, args.epochs, args.train_batch_size, args.thread_name)

        logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt = '%m/%d/%Y %H:%M:%S',
                            level = logging.INFO,
                            filename = self.log_path)
        self.logger = logging.getLogger(__name__)
        self.logger.info(vars(self.args))
        self.logger.info('\n')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.load_data()
        self.model = self.add_model()
        self.optimizer = self.add_optimizer(self.model.parameters())

        if self.args.save_influence_map:
            self.logger.info('-------- Argument save_influence_map is set. Will use GR to compute and save influence maps ----------\n')
            # When saving influence during training, reciprocal relations are
            # disabled to keep the bookkeeping simple.
            self.args.add_reciprocals = False # to keep things simple
            # Init an empty influence map (defaultdict so += works on first touch).
            self.influence_map = defaultdict(float)
            #self.influence_path = 'influence_maps/{0}_{1}.json'.format(args.data, self.model_name)
            self.influence_path = 'influence_maps/{0}_{1}.pickle'.format(args.data, self.model_name)
            # Initialize a copy of the model params to track previous weights in an epoch.
            self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]
            self.logger.info('Shape for previous weights: {}, {}'.format(self.previous_weights[0].shape, self.previous_weights[1].shape))

    def load_data(self):
        '''
        Load the train, valid datasets.

        Reads this worker's perturbed triple file (suffixed with
        ``args.thread_name``); the last 100 triples are copied out as a
        nominal validation split.
        '''
        data_path = os.path.join('processed_data', self.args.data)

        n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
        self.n_ent = n_ent
        self.n_rel = n_rel

        self.train_data = utils.load_data(os.path.join(data_path, 'evaluation', 'all_{}.txt'.format(self.args.thread_name)))

        self.valid_data = self.train_data[-100:, :].copy()
        # self.train_data = utils.load_data()

    def add_model(self):
        """Instantiate the embedding model selected by ``args.model`` and move it to the device."""
        if self.args.model is None:
            model = Distmult(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'distmult':
            model = Distmult(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'complex':
            model = Complex(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'conve':
            model = Conve(self.args, self.n_ent, self.n_rel)
        else:
            # BUGFIX: logging uses lazy %-style placeholders; the original
            # '{0}'-style brace was never interpolated by the logger.
            self.logger.info('Unknown model: %s', self.args.model)
            raise Exception("Unknown model!")
        model.to(self.device)
        return model

    def add_optimizer(self, parameters):
        """Return the Adam optimizer used for training."""
        #if self.args.optimizer == 'adam' : return torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)
        #else : return torch.optim.SGD(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)
        return torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)

    def save_model(self):
        """Checkpoint model weights, optimizer state and args to ``self.model_path``."""
        state = {
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'args': vars(self.args)
        }
        torch.save(state, self.model_path)
        self.logger.info('Saving model to {0}'.format(self.model_path))

    def load_model(self):
        """Restore model and optimizer state from ``self.model_path``."""
        self.logger.info('Loading saved model from {0}'.format(self.model_path))
        state = torch.load(self.model_path)
        model_params = state['state_dict']
        params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
        for key, size, count in params:
            # BUGFIX: Logger.info takes a format string first; passing three
            # bare positionals treated `key` as the format string (garbled or
            # raised whenever the key contained a '%').
            self.logger.info('%s %s %s', key, size, count)
        self.model.load_state_dict(model_params)
        self.optimizer.load_state_dict(state['optimizer'])

    def lp_regularizer(self):
        """L-p norm regularizer over the full entity and relation embedding tables."""
        # Apply p-norm regularization; assign weights to each param.
        weight = self.args.reg_weight
        p = self.args.reg_norm

        trainable_params = [self.model.emb_e.weight, self.model.emb_rel.weight]
        norm = 0
        for i in range(len(trainable_params)):
            #norm += weight * trainable_params[i].norm(p = p)**p
            norm += weight * torch.sum( torch.abs(trainable_params[i]) ** p)

        return norm

    def n3_regularizer(self, factors):
        """N3-style regularizer over the per-batch embedding factors (lhs, rel, rhs)."""
        weight = self.args.reg_weight
        p = self.args.reg_norm

        norm = 0
        for f in factors:
            norm += weight * torch.sum(torch.abs(f) ** p)

        return norm / factors[0].shape[0] # scale by number of triples in batch

    def get_influence_map(self):
        """
        Turns the influence map into a list, ready to be written to disc. (before: numpy)
        :return: the influence map with lists as values
        """
        assert self.args.save_influence_map == True

        for key in self.influence_map:
            self.influence_map[key] = self.influence_map[key].tolist()
        #self.logger.info('get_influence_map passed')
        return self.influence_map

    def evaluate(self, split, batch_size, epoch):
        """
        Compute the validation loss (same loss as self.run_epoch(), no gradients).
        NOTE(review): `split` and `batch_size` are currently unused — the method
        always reads self.valid_data and self.args.valid_batch_size; kept for
        interface compatibility.
        """
        self.model.eval()
        losses = []

        with torch.no_grad():
            input_data = torch.from_numpy(self.valid_data.astype('int64'))
            actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
            del input_data

            batch_size = self.args.valid_batch_size
            for b_begin in range(0, actual_examples.shape[0], batch_size):

                input_batch = actual_examples[b_begin: b_begin + batch_size]
                input_batch = input_batch.to(self.device)

                s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]

                emb_s = self.model.emb_e(s).squeeze(dim=1)
                emb_r = self.model.emb_rel(r).squeeze(dim=1)
                emb_o = self.model.emb_e(o).squeeze(dim=1)

                if self.args.add_reciprocals:
                    r_rev = r + self.n_rel
                    emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
                else:
                    r_rev = r
                    emb_rrev = emb_r

                # Score both directions: (s, r, ?) and (o, r_rev, ?).
                pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
                loss_sr = self.model.loss(pred_sr, o) # cross entropy loss

                pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
                loss_or = self.model.loss(pred_or, s)

                total_loss = loss_sr + loss_or

                if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
                    #self.logger.info('Computing regularizer weight')
                    if self.args.model == 'complex':
                        # Complex embeddings store (real, imaginary) halves side by side.
                        emb_dim = self.args.embedding_dim #int(self.args.embedding_dim/2)
                        lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
                        rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
                        rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
                        rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

                        factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                      torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                                      torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
                                      )
                        factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                      torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                                      torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
                                      )
                    else:
                        factors_sr = (emb_s, emb_r, emb_o)
                        factors_or = (emb_s, emb_rrev, emb_o)

                    total_loss += self.n3_regularizer(factors_sr)
                    total_loss += self.n3_regularizer(factors_or)

                if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
                    total_loss += self.lp_regularizer()

                losses.append(total_loss.item())

        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Validating Loss:{:.6}\n'.format(epoch, loss))
        return loss


    def run_epoch(self, epoch):
        """Run one training epoch over the shuffled training triples; return the mean batch loss."""
        self.model.train()
        losses = []

        # Shuffle the train dataset.
        input_data = torch.from_numpy(self.train_data.astype('int64'))
        actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
        del input_data

        batch_size = self.args.train_batch_size

        for b_begin in range(0, actual_examples.shape[0], batch_size):
            self.optimizer.zero_grad()
            input_batch = actual_examples[b_begin: b_begin + batch_size]
            input_batch = input_batch.to(self.device)

            s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]

            emb_s = self.model.emb_e(s).squeeze(dim=1)
            emb_r = self.model.emb_rel(r).squeeze(dim=1)
            emb_o = self.model.emb_e(o).squeeze(dim=1)

            if self.args.add_reciprocals:
                r_rev = r + self.n_rel
                emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
            else:
                r_rev = r
                emb_rrev = emb_r

            pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
            loss_sr = self.model.loss(pred_sr, o) # loss is cross entropy loss

            pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
            loss_or = self.model.loss(pred_or, s)

            total_loss = loss_sr + loss_or

            if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
                #self.logger.info('Computing regularizer weight')
                if self.args.model == 'complex':
                    emb_dim = self.args.embedding_dim #int(self.args.embedding_dim/2)
                    lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
                    rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
                    rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
                    rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

                    factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                  torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                                  torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
                    factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                  torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                                  torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
                else:
                    factors_sr = (emb_s, emb_r, emb_o)
                    factors_or = (emb_s, emb_rrev, emb_o)

                total_loss += self.n3_regularizer(factors_sr)
                total_loss += self.n3_regularizer(factors_or)

            if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
                total_loss += self.lp_regularizer()


            total_loss.backward()
            self.optimizer.step()
            losses.append(total_loss.item())

            if self.args.save_influence_map: #for gradient rollback
                with torch.no_grad():
                    prev_emb_e = self.previous_weights[0]
                    prev_emb_rel = self.previous_weights[1]
                    # Need to compute the influence value per-triple.
                    for idx in range(input_batch.shape[0]):
                        head, rel, tail = s[idx], r[idx], o[idx]
                        inf_head = (emb_s[idx] - prev_emb_e[head]).cpu().detach().numpy()
                        inf_tail = (emb_o[idx] - prev_emb_e[tail]).cpu().detach().numpy()
                        inf_rel = (emb_r[idx] - prev_emb_rel[rel]).cpu().detach().numpy()
                        #print(inf_head.shape, inf_tail.shape, inf_rel.shape)

                        # Write the influences to the dictionary, keyed by triple and slot.
                        key_trip = '{0}_{1}_{2}'.format(head.item(), rel.item(), tail.item())
                        key = '{0}_s'.format(key_trip)
                        self.influence_map[key] += inf_head
                        #self.logger.info('Written to influence map. Key: {0}, Value shape: {1}'.format(key, inf_head.shape))
                        key = '{0}_r'.format(key_trip)
                        self.influence_map[key] += inf_rel
                        key = '{0}_o'.format(key_trip)
                        self.influence_map[key] += inf_tail

                    # Update the previous weights to be tracked.
                    self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]

            if (b_begin%5000 == 0) or (b_begin== (actual_examples.shape[0]-1)):
                self.logger.info('[E:{} | {}]: Train Loss:{:.6}'.format(epoch, b_begin, np.mean(losses)))

        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Training Loss:{:.6}\n'.format(epoch, loss))
        return loss

    def fit(self):
        """Train for ``args.epochs`` epochs, checkpointing on improved train loss,
        then dump the loss history and (optionally) the influence map."""
        # if self.args.resume:
        #     self.load_model()
        #     results = self.evaluate(split=self.args.resume_split, batch_size = self.args.test_batch_size, epoch = -1)
        #     pprint(results)
        # else:
        self.model.init()
        self.logger.info(self.model)

        self.logger.info('------ Start the model training ------')
        start_time = time.time()
        self.logger.info('Start time: {0}'.format(str(start_time)))


        train_losses = []
        valid_losses = []   # NOTE(review): never populated — validation is deliberately skipped here.
        best_val = 10000000000.
        for epoch in range(self.args.epochs):

            train_loss = self.run_epoch(epoch)
            train_losses.append(train_loss)

            # Don't use valid_data here !!!!!!!!!
            # valid_loss = self.evaluate(split='valid', batch_size = self.args.valid_batch_size, epoch = epoch)
            # valid_losses.append(valid_loss)
            # results_test = self.evaluate(split='test', batch_size = self.args.test_batch_size, epoch = epoch)
            if train_loss < best_val:
                best_val = train_loss
                self.save_model()
            self.logger.info("Train loss: {0}, Best loss: {1}\n\n".format(train_loss, best_val))
            # print("Valid loss: {0}, Best loss: {1}\n\n".format(valid_loss, best_val))

        with open(self.loss_path, "wb") as fl:
            pkl.dump({"train loss":train_losses, "valid loss":valid_losses}, fl)
        self.logger.info('Time taken to train the model: {0}'.format(str(time.time() - start_time)))
        start_time = time.time()

        if self.args.save_influence_map: #save the influence map
            with open(self.influence_path, "wb") as fl: #Pickling
                pkl.dump(self.get_influence_map(), fl)
            self.logger.info('Finished saving influence map')
            self.logger.info('Time taken to save the influence map: {0}'.format(str(time.time() - start_time)))
376
+
377
#%%
# Script entry point for one evaluation worker: the parent process launches
# this file with a unique --thread-name so each worker trains on its own
# perturbed triple file.
parser = utils.get_argument_parser()
parser.add_argument('--thread-name', type = str, required=True, help = "This parameter will be automatically determined.")

args = utils.set_hyperparams(parser.parse_args())
# if args.reproduce_results:
#     args = utils.set_hyperparams(args)

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False  # keep cuDNN deterministic-friendly

Main(args).fit()
DiseaseSpecific/model.py ADDED
@@ -0,0 +1,504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F, Parameter
3
+ from torch.autograd import Variable
4
+ from torch.nn.init import xavier_normal_, xavier_uniform_
5
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
6
+
7
class Distmult(torch.nn.Module):
    """DistMult KG-embedding scorer: score(s, r, o) = <e_s, w_r, e_o>.

    Because the trilinear product is fully symmetric, subject-side and
    object-side scoring share one code path.
    """

    def __init__(self, args, num_entities, num_relations):
        super(Distmult, self).__init__()

        # Creation order (emb_e before emb_rel) is kept identical to the
        # original so seeded random initialisation matches exactly.
        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        """Xavier-initialise both embedding tables (entities first)."""
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def _scores_against_all(self, left_emb, right_emb, sigmoid):
        # Shared kernel: (left * right) dotted against every entity embedding.
        scores = torch.mm(left_emb * right_emb, self.emb_e.weight.transpose(1, 0))
        return torch.sigmoid(scores) if sigmoid else scores

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (sub, rel, ?) against all entities. Shape: (batch, num_entities)."""
        e_s = self.emb_e(sub).squeeze(dim=1)
        e_r = self.emb_rel(rel).squeeze(dim=1)
        return self._scores_against_all(e_s, e_r, sigmoid)

    def score_or(self, obj, rel, sigmoid = False):
        """Score (?, rel, obj) against all entities (symmetric to score_sr)."""
        e_o = self.emb_e(obj).squeeze(dim=1)
        e_r = self.emb_rel(rel).squeeze(dim=1)
        return self._scores_against_all(e_o, e_r, sigmoid)

    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
        For DistMult both modes compute the same thing, so mode is ignored.
        Input dropout is applied here (training-time path).
        '''
        dropped_left = self.inp_drop(sub_emb)
        dropped_right = self.inp_drop(rel_emb)
        return self._scores_against_all(dropped_left, dropped_right, sigmoid)

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object index tensors.
        Return - one scalar score per triple.
        '''
        product = (self.emb_e(sub).squeeze(dim=1)
                   * self.emb_rel(rel).squeeze(dim=1)
                   * self.emb_e(obj).squeeze(dim=1))
        score = torch.sum(product, dim=-1)
        return torch.sigmoid(score) if sigmoid else score

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - pre-looked-up embeddings of subject, relation, object.
        Return - one scalar score per triple.
        '''
        score = torch.sum(emb_s * emb_r * emb_o, dim=-1)
        return torch.sigmoid(score) if sigmoid else score

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object index tensors.
        Return - per-dimension (un-reduced) score vector for each triple.
        '''
        vec = (self.emb_e(sub).squeeze(dim=1)
               * self.emb_rel(rel).squeeze(dim=1)
               * self.emb_e(obj).squeeze(dim=1))
        return torch.sigmoid(vec) if sigmoid else vec
+ return pred
110
+
111
class Complex(torch.nn.Module):
    """ComplEx scorer: each embedding row stores [real | imaginary] halves
    of width 2*embedding_dim, split at use time with torch.chunk.
    Score(s, r, o) = Re(<e_s, w_r, conj(e_o)>)."""

    def __init__(self, args, num_entities, num_relations):
        super(Complex, self).__init__()

        # 2*embedding_dim: first half holds the real part, second half the
        # imaginary part of each complex embedding.
        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        # Xavier-initialise both embedding tables.
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (sub, rel, ?) against every entity.

        NOTE(review): the s_*/rel_* names below are swapped relative to
        their contents (s_* is chunked from rel_emb and vice versa). Every
        use is a commutative elementwise product, so the score is
        unaffected — but the naming is misleading.
        """
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
        emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

        # Real component of (s * r), dotted against all entity real parts.
        realo_realreal = s_real*rel_real
        realo_imgimg = s_img*rel_img
        realo = realo_realreal - realo_imgimg
        real = torch.mm(realo, emb_e_real.transpose(1,0))

        # Imaginary component of (s * r), dotted against entity imaginary parts.
        imgo_realimg = s_real*rel_img
        imgo_imgreal = s_img*rel_real
        imgo = imgo_realimg + imgo_imgreal
        img = torch.mm(imgo, emb_e_img.transpose(1,0))

        pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def score_or(self, obj, rel, sigmoid = False):
        """Score (?, rel, obj) against every entity (object-side direction)."""
        obj_emb = self.emb_e(obj).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)

        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
        emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

        #rel_real = self.inp_drop(rel_real)
        #rel_img = self.inp_drop(rel_img)
        #o_real = self.inp_drop(o_real)
        #o_img = self.inp_drop(o_img)

        # complex space bilinear product (equivalent to HolE)
        # realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
        # realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
        # imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
        # imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
        # pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        # Note the sign flip versus score_sr: scoring over subjects uses the
        # conjugate direction, so the imaginary combination is r_real*o_img
        # minus r_img*o_real.
        reals_realreal = rel_real*o_real
        reals_imgimg = rel_img*o_img
        reals = reals_realreal + reals_imgimg
        real = torch.mm(reals, emb_e_real.transpose(1,0))

        imgs_realimg = rel_real*o_img
        imgs_imgreal = rel_img*o_real
        imgs = imgs_realimg - imgs_imgreal
        img = torch.mm(imgs, emb_e_img.transpose(1,0))

        pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r).
        Applies input dropout to each real/imaginary chunk, then computes
        the ComplEx bilinear score against every entity. The 'lhs' branch
        mirrors score_or; the 'rhs' branch mirrors score_sr (including its
        swapped-but-harmless s_*/rel_* naming).
        '''
        if mode == 'lhs':
            rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
            o_real, o_img = torch.chunk(sub_emb, 2, dim=-1)
            emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

            rel_real = self.inp_drop(rel_real)
            rel_img = self.inp_drop(rel_img)
            o_real = self.inp_drop(o_real)
            o_img = self.inp_drop(o_img)

            reals_realreal = rel_real*o_real
            reals_imgimg = rel_img*o_img
            reals = reals_realreal + reals_imgimg
            real = torch.mm(reals, emb_e_real.transpose(1,0))

            imgs_realimg = rel_real*o_img
            imgs_imgreal = rel_img*o_real
            imgs = imgs_realimg - imgs_imgreal
            img = torch.mm(imgs, emb_e_img.transpose(1,0))

            pred = real + img

        else:
            s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
            rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
            emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)

            s_real = self.inp_drop(s_real)
            s_img = self.inp_drop(s_img)
            rel_real = self.inp_drop(rel_real)
            rel_img = self.inp_drop(rel_img)

            realo_realreal = s_real*rel_real
            realo_imgimg = s_img*rel_img
            realo = realo_realreal - realo_imgimg
            real = torch.mm(realo, emb_e_real.transpose(1,0))

            imgo_realimg = s_real*rel_img
            imgo_imgreal = s_img*rel_real
            imgo = imgo_realimg + imgo_imgreal
            img = torch.mm(imgo, emb_e_img.transpose(1,0))

            pred = real + img

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score

        NOTE(review): mixes the `dim=` and `axis=` keyword spellings of
        torch.sum; both are accepted aliases, behaviour is identical.
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

        # Re(<s, r, conj(o)>) expanded into its four real-valued terms.
        realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
        realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1)
        imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1)
        imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1)

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''

        s_real, s_img = torch.chunk(emb_s, 2, dim=-1)
        rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1)
        o_real, o_img = torch.chunk(emb_o, 2, dim=-1)

        realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
        realimgimg = torch.sum(s_real*rel_img*o_img, axis=-1)
        imgrealimg = torch.sum(s_img*rel_real*o_img, axis=-1)
        imgimgreal = torch.sum(s_img*rel_img*o_real, axis=-1)

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub).squeeze(dim=1)
        rel_emb = self.emb_rel(rel).squeeze(dim=1)
        obj_emb = self.emb_e(obj).squeeze(dim=1)

        s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
        rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
        o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

        # Same four terms as score_triples but left un-reduced.
        realrealreal = s_real*rel_real*o_real
        realimgimg = s_real*rel_img*o_img
        imgrealimg = s_img*rel_real*o_img
        imgimgreal = s_img*rel_img*o_real

        pred = realrealreal + realimgimg + imgrealimg - imgimgreal

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred
+ return pred
318
+
319
class Conve(torch.nn.Module):
    """ConvE scorer: subject and relation embeddings are reshaped into a 2D
    "image", stacked, passed through a Conv2d + BatchNorm + FC pipeline, and
    the result is dotted against entity embeddings."""

    #Too slow !!!!

    def __init__(self, args, num_entities, num_relations):
        super(Conve, self).__init__()

        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
        self.feature_drop = torch.nn.Dropout2d(args.feat_drop)

        self.embedding_dim = args.embedding_dim #default is 200
        self.num_filters = args.num_filters # default is 32
        self.kernel_size = args.kernel_size # default is 3
        self.stack_width = args.stack_width # default is 20
        # Height of the reshaped 2D embedding; assumes embedding_dim is an
        # exact multiple of stack_width — TODO confirm upstream validation.
        self.stack_height = args.embedding_dim // self.stack_width

        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
        self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)

        self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
                                     kernel_size=(self.kernel_size, self.kernel_size),
                                     stride=1, padding=0, bias=args.use_bias)
        #self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default

        # Flattened size after an unpadded convolution over the stacked
        # (2*stack_width) x stack_height input.
        flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
        flat_sz_w = self.stack_height - self.kernel_size + 1
        self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
        self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)

        # Per-entity scalar bias added to every score.
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        # Xavier-initialise both embedding tables.
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def concat(self, e1_embed, rel_embed, form='plain'):
        """Stack entity and relation embeddings into one 2D 'image' of shape
        (batch, 1, 2*stack_width, stack_height)."""
        if form == 'plain':
            # Stack the two reshaped embeddings vertically.
            e1_embed = e1_embed. view(-1, 1, self.stack_width, self.stack_height)
            rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
            stack_inp = torch.cat([e1_embed, rel_embed], 2)

        elif form == 'alternate':
            # Interleave rows of the two embeddings before reshaping.
            e1_embed = e1_embed. view(-1, 1, self.embedding_dim)
            rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
            stack_inp = torch.cat([e1_embed, rel_embed], 1)
            stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height))

        else: raise NotImplementedError
        return stack_inp

    def conve_architecture(self, sub_emb, rel_emb):
        """Shared conv pipeline: stack -> BN -> dropout -> conv -> BN -> ReLU
        -> feature dropout -> FC -> dropout -> BN -> ReLU. Returns a
        (batch, embedding_dim) feature tensor."""
        stacked_inputs = self.concat(sub_emb, rel_emb)
        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        #x = x.view(x.shape[0], -1)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        return x

    def score_sr(self, sub, rel, sigmoid = False):
        """Score (sub, rel, ?) against every entity, plus per-entity bias."""
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)

        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_or(self, obj, rel, sigmoid = False):
        """Score (?, rel, obj) against every entity, plus per-entity bias."""
        obj_emb = self.emb_e(obj)
        rel_emb = self.emb_rel(rel)

        x = self.conve_architecture(obj_emb, rel_emb)
        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
        For conve, computations for both modes are equivalent, so we do not need if-else block
        '''
        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj)
        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, obj_emb.transpose(1,0))
        #print(pred.shape)
        pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding
        # above works fine for single input triples;
        # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores
        # so use torch.diagonal() after calling this function
        pred = torch.diagonal(pred)
        # or could have used : pred= torch.sum(x*obj_emb, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''
        x = self.conve_architecture(emb_s, emb_r)

        pred = torch.mm(x, emb_o.transpose(1,0))
        #pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding - don't know which obj
        # above works fine for single input triples;
        # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores
        # so use torch.diagonal() after calling this function
        pred = torch.diagonal(pred)
        # or could have used : pred= torch.sum(x*obj_emb, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj)

        x = self.conve_architecture(sub_emb, rel_emb)

        #pred = torch.mm(x, obj_emb.transpose(1,0))
        pred = x*obj_emb
        #print(pred.shape, self.b[obj].shape) #shapes are [7,200] and [7]
        #pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding - can't add scalar to vector

        #pred = sub_emb*rel_emb*obj_emb

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred
+ return pred
DiseaseSpecific/utils.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ A file modified on https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
3
+ '''
4
+ #%%
5
+ import logging
6
+ import time
7
+ from tqdm import tqdm
8
+ import io
9
+ import pandas as pd
10
+ import numpy as np
11
+ import os
12
+ import json
13
+
14
+ import argparse
15
+ import torch
16
+ import random
17
+
18
+ from yaml import parse
19
+
20
+ from model import Conve, Distmult, Complex
21
+
22
+ logger = logging.getLogger(__name__)
23
+ #%%
24
def generate_dicts(data_path):
    """Load the entity and relation id maps stored under *data_path*.

    Returns (num_entities, num_relations, ent_to_id, rel_to_id).
    """
    def _read_json(name):
        # Both dicts are plain {label: id} JSON files next to the dataset.
        with open(os.path.join(data_path, name), 'r') as handle:
            return json.load(handle)

    ent_to_id = _read_json('entities_dict.json')
    rel_to_id = _read_json('relations_dict.json')
    return len(ent_to_id), len(rel_to_id), ent_to_id, rel_to_id
33
+
34
def save_data(file_name, data):
    """Write *data* (an iterable of triples/rows) to *file_name*, one
    tab-separated line per row."""
    rendered = ["\t".join(str(field) for field in row) for row in data]
    with open(file_name, 'w') as out:
        out.write("".join(line + "\n" for line in rendered))
38
+
39
def load_data(file_name, drop = True):
    """Read a tab-separated triple file into a numpy array of strings.

    When *drop* is true, duplicate rows are removed first.
    """
    frame = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    return (frame.drop_duplicates() if drop else frame).values
46
+
47
def seed_all(seed=1):
    """Seed every RNG in play (python, numpy, torch CPU+CUDA) and force
    deterministic cudnn kernels, for reproducible runs."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed_all):
        seeder(seed)
    # Affects hash randomisation only for subprocesses spawned after this.
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
54
+
55
def add_model(args, n_ent, n_rel):
    """Instantiate the KGE model selected by args.model.

    None falls back to DistMult; unknown names raise Exception.
    """
    dispatch = {
        None: Distmult,
        'distmult': Distmult,
        'complex': Complex,
        'conve': Conve,
    }
    try:
        model_cls = dispatch[args.model]
    except KeyError:
        raise Exception("Unknown model!")
    return model_cls(args, n_ent, n_rel)
68
+
69
def load_model(model_path, args, n_ent, n_rel, device):
    """Build the model for *args*, restore its weights from *model_path*,
    move it to *device*, and return it in eval mode."""
    model = add_model(args, n_ent, n_rel)
    model.to(device)
    logger.info('Loading saved model from {0}'.format(model_path))

    checkpoint = torch.load(model_path)
    model_params = checkpoint['state_dict']
    # Log every restored tensor for debugging checkpoint mismatches.
    for key, value in model_params.items():
        logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, value.size(), value.numel()))

    model.load_state_dict(model_params)
    model.eval()  # inference only — disable dropout/batchnorm updates
    logger.info(model)

    return model
85
+
86
def add_eval_parameters(parser):
    """Attach evaluation-time CLI options to *parser* and return it."""
    # parser.add_argument('--eval-mode', type = str, default = 'all', help = 'Method to evaluate the attack performance. Default: all. (all or single)')
    parser.add_argument('--cuda-name', type = str, required = True, help = 'Start a main thread on each cuda.')
    parser.add_argument('--direct', action='store_true', help = 'Directly add edge or not.')
    parser.add_argument('--seperate', action='store_true', help = 'Evaluate seperatly or not')
    # NOTE(review): the help literal below collapses to '  or  ' because of
    # adjacent string concatenation — it probably meant to name two modes.
    parser.add_argument('--mode', type = str, default = '', help = ' '' or '' ')
    parser.add_argument('--mask-ratio', type=str, default='', help='Mask ratio for Fig4b')
    return parser
95
+
96
def add_attack_parameters(parser):
    """Attach attack-generation CLI options to *parser* and return it."""
    # --- target triple selection ---
    # parser.add_argument('--target-split', type=str, default='0_100_1', help='Ranks to use for target set. Values are 0 for ranks==1; 1 for ranks <=10; 2 for ranks>10 and ranks<=100. Default: 1')
    parser.add_argument('--target-split', type=str, default='min', help='Methods for target triple selection. Default: min. (min or top_?, top means top_0.1)')
    parser.add_argument('--target-size', type=int, default=50, help='Number of target triples. Default: 50')
    parser.add_argument('--target-existed', action='store_true', help='Whether the targeted s_?_o already exists.')

    # parser.add_argument('--budget', type=int, default=1, help='Budget for each target triple for each corruption side')

    # --- candidate / perturbation generation ---
    parser.add_argument('--attack-goal', type = str, default='single', help='Attack goal. Default: single. (single or global)')
    parser.add_argument('--neighbor-num', type = int, default=20, help='Max neighbor num for each side. Default: 20')
    parser.add_argument('--candidate-mode', type = str, default='quadratic', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate')
    parser.add_argument('--added-edge-num', type = str, default='', help = 'How many edges to add for each target edge. Default: '' means 1.')
    # parser.add_argument('--neighbor-num', type = int, default=200, help='Max neighbor num for each side. Default: 200')
    # parser.add_argument('--candidate-mode', type = str, default='linear', help = 'The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--attack-batch-size', type=int, default=256, help='Batch size for processing neighbours of target')
    parser.add_argument('--template-mode', type=str, default = 'manual', help = 'Template mode for transforming edge to single sentense. Default: manual. (manual or auto)')

    # --- influence-function (LiSSA) cache ---
    parser.add_argument('--update-lissa', action='store_true', help = 'Update lissa cache or not.')

    # --- language-model scoring ---
    parser.add_argument('--GPT-batch-size', type=int, default = 64, help = 'Batch size for GPT2 when calculating LM score. Default: 64')
    parser.add_argument('--LM-softmax', action='store_true', help = 'Use a softmax head on LM prob or not.')
    parser.add_argument('--LMprob-mode', type=str, default='relative', help = 'Use the absolute LM score or calculate the destruction score when target word is replaced. Default: absolute. (absolute or relative)')

    parser.add_argument('--load-existed', action='store_true', help = 'Use cached intermidiate results or not, when only --reasonable-rate changed, set this param to True')

    return parser
124
+
125
def get_argument_parser():
    '''Generate an argument parser for dataset, model, training and
    regularization options shared by every script in this package.

    Fix over the original: several help strings advertised defaults that
    contradicted the actual ``default=`` values (--train-batch-size,
    --stack-width, --feat-drop, --reg-norm); the help text now matches.
    The defaults themselves are unchanged.
    '''
    parser = argparse.ArgumentParser(description='Graph embedding')

    parser.add_argument('--seed', type=int, default=1, metavar='S', help='Random seed (default: 1)')

    # Dataset / model selection
    parser.add_argument('--data', type=str, default='GNBR', help='Dataset to use: { GNBR }')
    parser.add_argument('--model', type=str, default='distmult', help='Choose from: {distmult, conve, complex}')

    parser.add_argument('--transe-margin', type=float, default=0.0, help='Margin value for TransE scoring function. Default:0.0')
    parser.add_argument('--transe-norm', type=int, default=2, help='P-norm value for TransE scoring function. Default:2')

    # Optimisation
    parser.add_argument('--epochs', type=int, default=100, help='Number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--lr-decay', type=float, default=0.0, help='Weight decay value to use in the optimizer. Default: 0.0')
    parser.add_argument('--max-norm', action='store_true', help='Option to add unit max norm constraint to entity embeddings')

    parser.add_argument('--train-batch-size', type=int, default=64, help='Batch size for train split (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=128, help='Batch size for test split (default: 128)')
    parser.add_argument('--valid-batch-size', type=int, default=128, help='Batch size for valid split (default: 128)')
    parser.add_argument('--KG-valid-rate', type = float, default=0.1, help='Validation rate during KG embedding training. (default: 0.1)')

    parser.add_argument('--save-influence-map', action='store_true', help='Save the influence map during training for gradient rollback.')
    parser.add_argument('--add-reciprocals', action='store_true')

    # Embedding / ConvE architecture
    parser.add_argument('--embedding-dim', type=int, default=128, help='The embedding dimension (1D). Default: 128')
    parser.add_argument('--stack-width', type=int, default=16, help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 16')
    #parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
    parser.add_argument('--hidden-drop', type=float, default=0.3, help='Dropout for the hidden layer. Default: 0.3.')
    parser.add_argument('--input-drop', type=float, default=0.2, help='Dropout for the input embeddings. Default: 0.2.')
    parser.add_argument('--feat-drop', type=float, default=0.3, help='Dropout for the convolutional features. Default: 0.3.')
    # NOTE(review): the two flags below use a single leading dash, unlike
    # every other long option; kept as-is for CLI backward compatibility.
    parser.add_argument('-num-filters', default=32, type=int, help='Number of filters for convolution')
    parser.add_argument('-kernel-size', default=3, type=int, help='Kernel Size for convolution')

    parser.add_argument('--use-bias', action='store_true', help='Use a bias in the convolutional layer. Default: True')

    # Regularization
    parser.add_argument('--reg-weight', type=float, default=5e-2, help='Weight for regularization. Default: 5e-2')
    parser.add_argument('--reg-norm', type=int, default=3, help='Norm for regularization. Default: 3')
    # parser.add_argument('--resume', action='store_true', help='Restore a saved model.')
    # parser.add_argument('--resume-split', type=str, default='test', help='Split to evaluate a restored model')
    # parser.add_argument('--reproduce-results', action='store_true', help='Use the hyperparameters to reproduce the results.')
    # parser.add_argument('--original-data', type=str, default='FB15k-237', help='Dataset to use; this option is needed to set the hyperparams to reproduce the results for training after attack, default: FB15k-237')
    return parser
168
+
169
def set_hyperparams(args):
    """Override CLI defaults with per-model tuned values, then attach the
    LiSSA (influence-approximation) constants shared by every model.

    Mutates and returns *args*.
    """
    per_model = {
        'distmult': {'lr': 0.005, 'train_batch_size': 1024, 'reg_norm': 3},
        'complex': {'lr': 0.005, 'reg_norm': 3, 'input_drop': 0.4,
                    'train_batch_size': 1024},
        'conve': {'lr': 0.005, 'train_batch_size': 1024, 'reg_weight': 0.0},
    }
    for attr, value in per_model.get(args.model, {}).items():
        setattr(args, attr, value)

    # args.damping = 0.01
    # args.lissa_repeat = 1
    # args.lissa_depth = 1
    # args.scale = 500
    # args.lissa_batch_size = 100

    # LiSSA constants (model-independent).
    args.damping = 0.01
    args.lissa_repeat = 1
    args.lissa_depth = 1
    args.scale = 400
    args.lissa_batch_size = 300
    return args