Spaces:
Running
Running
yjwtheonly
committed on
Commit
•
ac7c391
1
Parent(s):
bdc453c
specific
Browse files- DiseaseSpecific/KG_extractor.py +479 -0
- DiseaseSpecific/attack.py +841 -0
- DiseaseSpecific/edge_to_abstract.py +530 -0
- DiseaseSpecific/evaluation.py +499 -0
- DiseaseSpecific/main.py +377 -0
- DiseaseSpecific/main_multiprocess.py +391 -0
- DiseaseSpecific/model.py +504 -0
- DiseaseSpecific/utils.py +195 -0
DiseaseSpecific/KG_extractor.py
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from torch.autograd import Variable
|
5 |
+
from sklearn import metrics
|
6 |
+
|
7 |
+
import datetime
|
8 |
+
from typing import Dict, Tuple, List
|
9 |
+
import logging
|
10 |
+
import os
|
11 |
+
import utils
|
12 |
+
import pickle as pkl
|
13 |
+
import json
|
14 |
+
import torch.backends.cudnn as cudnn
|
15 |
+
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
import sys
|
19 |
+
sys.path.append("..")
|
20 |
+
import Parameters
|
21 |
+
|
22 |
+
parser = utils.get_argument_parser()
|
23 |
+
parser = utils.add_attack_parameters(parser)
|
24 |
+
parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART')
|
25 |
+
parser.add_argument('--action', type=str, default='parse', help='parse or extract')
|
26 |
+
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
|
27 |
+
args = parser.parse_args()
|
28 |
+
args = utils.set_hyperparams(args)
|
29 |
+
|
30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
31 |
+
|
32 |
+
utils.seed_all(args.seed)
|
33 |
+
np.set_printoptions(precision=5)
|
34 |
+
cudnn.benchmark = False
|
35 |
+
|
36 |
+
data_path = os.path.join('processed_data', args.data)
|
37 |
+
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
|
38 |
+
attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}.txt'.format(args.model,
|
39 |
+
args.target_split,
|
40 |
+
args.target_size,
|
41 |
+
'exists:'+str(args.target_existed),
|
42 |
+
args.neighbor_num,
|
43 |
+
args.candidate_mode,
|
44 |
+
args.attack_goal,
|
45 |
+
str(args.reasonable_rate)))
|
46 |
+
modified_attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}.txt'.format(args.model,
|
47 |
+
args.target_split,
|
48 |
+
args.target_size,
|
49 |
+
'exists:'+str(args.target_existed),
|
50 |
+
args.neighbor_num,
|
51 |
+
args.candidate_mode,
|
52 |
+
args.attack_goal,
|
53 |
+
str(args.reasonable_rate),
|
54 |
+
args.mode))
|
55 |
+
attack_data = utils.load_data(attack_path, drop=False)
|
56 |
+
#%%
|
57 |
+
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
|
58 |
+
id_to_meshid = json.load(fl)
|
59 |
+
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
|
60 |
+
meshid_to_id = json.load(fl)
|
61 |
+
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
|
62 |
+
entity_raw_name = pkl.load(fl)
|
63 |
+
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
|
64 |
+
retieve_sentence_through_edgetype = pkl.load(fl)
|
65 |
+
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
|
66 |
+
raw_text_sen = pkl.load(fl)
|
67 |
+
with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl:
|
68 |
+
full_entity_raw_name = pkl.load(fl)
|
69 |
+
for k, v in entity_raw_name.items():
|
70 |
+
assert v in full_entity_raw_name[k]
|
71 |
+
|
72 |
+
#find unique
|
73 |
+
once_set = set()
|
74 |
+
twice_set = set()
|
75 |
+
|
76 |
+
with open('generate_abstract/valid_entity.json', 'r') as fl:
|
77 |
+
valid_entity = json.load(fl)
|
78 |
+
valid_entity = set(valid_entity)
|
79 |
+
|
80 |
+
good_name = set()
|
81 |
+
for k, v, in full_entity_raw_name.items():
|
82 |
+
names = list(v)
|
83 |
+
for name in names:
|
84 |
+
# if name == 'in a':
|
85 |
+
# print(names)
|
86 |
+
good_name.add(name)
|
87 |
+
# if name not in once_set:
|
88 |
+
# once_set.add(name)
|
89 |
+
# else:
|
90 |
+
# twice_set.add(name)
|
91 |
+
# assert 'WNK4' in once_set
|
92 |
+
# good_name = set.difference(once_set, twice_set)
|
93 |
+
# assert 'in a' not in good_name
|
94 |
+
# assert 'STE20' not in good_name
|
95 |
+
# assert 'STE20' not in valid_entity
|
96 |
+
# assert 'STE20-related proline-alanine-rich kinase' not in good_name
|
97 |
+
# assert 'STE20-related proline-alanine-rich kinase' not in valid_entity
|
98 |
+
# raise Exception
|
99 |
+
|
100 |
+
name_to_type = {}
|
101 |
+
name_to_meshid = {}
|
102 |
+
|
103 |
+
for k, v, in full_entity_raw_name.items():
|
104 |
+
names = list(v)
|
105 |
+
for name in names:
|
106 |
+
if name in good_name:
|
107 |
+
name_to_type[name] = k.split('_')[0]
|
108 |
+
name_to_meshid[name] = k
|
109 |
+
|
110 |
+
import spacy
|
111 |
+
import networkx as nx
|
112 |
+
import pprint
|
113 |
+
|
114 |
+
def check(p, s):
    """Return True when position p of string s is a token boundary.

    A position is a boundary when it lies outside the usable index range
    [1, len(s)) or when the character there is not an ASCII letter or
    digit (only ASCII is tested, deliberately — entity names are ASCII).
    """
    if not (1 <= p < len(s)):
        return True
    ch = s[p]
    is_word_char = ('a' <= ch <= 'z') or ('A' <= ch <= 'Z') or ('0' <= ch <= '9')
    return not is_word_char
|
119 |
+
|
120 |
+
def raw_to_format(sen):
    """Rewrite a raw sentence so that every known entity mention becomes a
    single underscore-joined token (e.g. "heart failure" -> "heart_failure").

    Scans left to right; at each non-space position it tries the LONGEST
    substring first (hence the reversed range) and accepts it when it is a
    known name (module-level `good_name` or `valid_entity`) and both ends
    sit on token boundaries per `check`. Unmatched characters are copied
    through unchanged.
    """
    text = sen
    l = 0          # current scan position in text
    ret = []       # output fragments, joined at the end
    while(l < len(text)):
        bo =False  # did an entity match start at position l?
        if text[l] != ' ':
            # Longest-match-first: shorter candidates must not shadow a
            # longer entity name starting at the same position.
            for i in range(len(text), l, -1): # reversing is important !!!
                cc = text[l:i]
                # Accept only whole-word matches: the characters just
                # before l and at i must be non-alphanumeric boundaries.
                if (cc in good_name or cc in valid_entity) and check(l-1, text) and check(i, text):
                    ret.append(cc.replace(' ', '_'))
                    l = i
                    bo = True
                    break
        if not bo:
            # No entity starts here — emit the single character verbatim.
            ret.append(text[l])
            l += 1
    return ''.join(ret)
|
139 |
+
|
140 |
+
if args.mode == 'sentence':
|
141 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_chat.json', 'r') as fl:
|
142 |
+
draft = json.load(fl)
|
143 |
+
elif args.mode == 'finetune':
|
144 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence_finetune.json', 'r') as fl:
|
145 |
+
draft = json.load(fl)
|
146 |
+
elif args.mode == 'bioBART':
|
147 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'r') as fl:
|
148 |
+
draft = json.load(fl)
|
149 |
+
elif args.mode == 'biogpt':
|
150 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_biogpt.json', 'r') as fl:
|
151 |
+
draft = json.load(fl)
|
152 |
+
else:
|
153 |
+
raise Exception('No!!!')
|
154 |
+
|
155 |
+
nlp = spacy.load("en_core_web_sm")
|
156 |
+
|
157 |
+
type_set = set()
|
158 |
+
for aa in range(36):
|
159 |
+
dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual']
|
160 |
+
tmp_dict = retieve_sentence_through_edgetype[aa]['auto']
|
161 |
+
dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys())
|
162 |
+
for dependency in dependencys:
|
163 |
+
dep_list = dependency.split(' ')
|
164 |
+
for sub_dep in dep_list:
|
165 |
+
sub_dep_list = sub_dep.split('|')
|
166 |
+
assert(len(sub_dep_list) == 3)
|
167 |
+
type_set.add(sub_dep_list[1])
|
168 |
+
# print('Type:', type_set)
|
169 |
+
|
170 |
+
if args.action == 'parse':
|
171 |
+
# dp_path, sen_list = list(dependency_sen_dict.items())[0]
|
172 |
+
# check
|
173 |
+
# paper_id, sen_id = sen_list[0]
|
174 |
+
# sen = raw_text_sen[paper_id][sen_id]
|
175 |
+
# doc = nlp(sen['text'])
|
176 |
+
# print(dp_path, '\n')
|
177 |
+
# pprint.pprint(sen)
|
178 |
+
# print()
|
179 |
+
# for token in doc:
|
180 |
+
# print((token.head.text, token.text, token.dep_))
|
181 |
+
|
182 |
+
out = ''
|
183 |
+
for k, v_dict in draft.items():
|
184 |
+
input = v_dict['in']
|
185 |
+
output = v_dict['out']
|
186 |
+
if input == '':
|
187 |
+
continue
|
188 |
+
output = output.replace('\n', ' ')
|
189 |
+
doc = nlp(output)
|
190 |
+
for sen in doc.sents:
|
191 |
+
out += raw_to_format(sen.text) + '\n'
|
192 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl:
|
193 |
+
fl.write(out)
|
194 |
+
elif args.action == 'extract':
|
195 |
+
|
196 |
+
# dependency_to_type_id = {}
|
197 |
+
# for k, v in Parameters.edge_type_to_id.items():
|
198 |
+
# dependency_to_type_id[k] = {}
|
199 |
+
# for type in v:
|
200 |
+
# LL = list(retieve_sentence_through_edgetype[type]['manual'].keys()) + list(retieve_sentence_through_edgetype[type]['auto'].keys())
|
201 |
+
# for dp in LL:
|
202 |
+
# dependency_to_type_id[k][dp] = type
|
203 |
+
if os.path.exists('generate_abstract/dependency_to_type_id.pickle'):
    # Cached mapping from (edge-type-pair, dependency path) -> relation id.
    with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl:
        dependency_to_type_id = pkl.load(fl)
else:
    dependency_to_type_id = {}
    print('Loading path data ...')
    for k in Parameters.edge_type_to_id.keys():
        # k looks like 'chemical-gene'; the GNBR distribution file name is
        # derived from the two node types.
        start, end = k.split('-')
        dependency_to_type_id[k] = {}
        inner_edge_type_to_id = Parameters.edge_type_to_id[k]
        inner_edge_type_dict = Parameters.edge_type_dict[k]
        # How many paths have been manually assigned to each theme so far —
        # used to balance flagged (curated) paths across themes.
        cal_manual_num = [0] * len(inner_edge_type_to_id)
        with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl:
            for i, line in tqdm(list(enumerate(fl.readlines()))):
                tmp = line.split('\t')
                if i == 0:
                    # Header row: alternating theme-name / flag columns.
                    head = [tmp[i] for i in range(1, len(tmp), 2)]
                    assert ' '.join(head) == ' '.join(inner_edge_type_dict[0])
                    continue
                # Odd columns are theme probabilities, even columns are
                # curated-support flags for the same themes.
                probability = [float(tmp[i]) for i in range(1, len(tmp), 2)]
                flag_list = [int(tmp[i]) for i in range(2, len(tmp), 2)]
                indices = np.where(np.asarray(flag_list) == 1)[0]
                if len(indices) >= 1:
                    # Among curated themes, pick the least-used one so far.
                    tmp_p = [cal_manual_num[i] for i in indices]
                    p = indices[np.argmin(tmp_p)]
                    cal_manual_num[p] += 1
                else:
                    # No curated flag — fall back to the most probable theme.
                    p = np.argmax(probability)
                # Bug fix: the original asserted membership in the OUTER
                # dict (whose keys are edge-type names), which can never
                # collide with a dependency path, so duplicate paths were
                # silently accepted. Check the inner per-edge-type dict.
                assert tmp[0].lower() not in dependency_to_type_id[k].keys()
                dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p]
    with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl:
        pkl.dump(dependency_to_type_id, fl)
|
235 |
+
|
236 |
+
# record = []
|
237 |
+
# with open(f'generate_abstract/par_parseout.txt', 'r') as fl:
|
238 |
+
# Tmp = []
|
239 |
+
# tmp = []
|
240 |
+
# for i,line in enumerate(fl.readlines()):
|
241 |
+
# # print(len(line), line)
|
242 |
+
# line = line.replace('\n', '')
|
243 |
+
# if len(line) > 1:
|
244 |
+
# tmp.append(line)
|
245 |
+
# else:
|
246 |
+
# Tmp.append(tmp)
|
247 |
+
# tmp = []
|
248 |
+
# if len(Tmp) == 3:
|
249 |
+
# record.append(Tmp)
|
250 |
+
# Tmp = []
|
251 |
+
|
252 |
+
# print(len(record))
|
253 |
+
# record_index = 0
|
254 |
+
# add = 0
|
255 |
+
# Attack = []
|
256 |
+
# for ii in range(100):
|
257 |
+
|
258 |
+
# # input = v_dict['in']
|
259 |
+
# # output = v_dict['out']
|
260 |
+
# # output = output.replace('\n', ' ')
|
261 |
+
# s, r, o = attack_data[ii]
|
262 |
+
# dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
|
263 |
+
|
264 |
+
# target_dp = set()
|
265 |
+
# for dp_path, sen_list in dependency_sen_dict.items():
|
266 |
+
# target_dp.add(dp_path)
|
267 |
+
# DP_list = []
|
268 |
+
# for _ in range(1):
|
269 |
+
# dp_dict = {}
|
270 |
+
# data = record[record_index]
|
271 |
+
# record_index += 1
|
272 |
+
# dp_paths = data[2]
|
273 |
+
# nodes_list = []
|
274 |
+
# edges_list = []
|
275 |
+
# for line in dp_paths:
|
276 |
+
# ttp, tmp = line.split('(')
|
277 |
+
# assert tmp[-1] == ')'
|
278 |
+
# tmp = tmp[:-1]
|
279 |
+
# e1, e2 = tmp.split(', ')
|
280 |
+
# if not ttp in type_set and ':' in ttp:
|
281 |
+
# ttp = ttp.split(':')[0]
|
282 |
+
# dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
|
283 |
+
# dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
|
284 |
+
# nodes_list.append(e1)
|
285 |
+
# nodes_list.append(e2)
|
286 |
+
# edges_list.append((e1, e2))
|
287 |
+
# nodes_list = list(set(nodes_list))
|
288 |
+
# pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
|
289 |
+
# graph = nx.Graph(edges_list)
|
290 |
+
|
291 |
+
# type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
|
292 |
+
# # print(type_list)
|
293 |
+
# # for i in range(len(type_list)):
|
294 |
+
# # print(pure_name[i], type_list[i])
|
295 |
+
# for i in range(len(nodes_list)):
|
296 |
+
# if type_list[i] != '':
|
297 |
+
# for j in range(len(nodes_list)):
|
298 |
+
# if i != j and type_list[j] != '':
|
299 |
+
# if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
|
300 |
+
# # print(f'{type_list[i]}_{type_list[j]}')
|
301 |
+
# ret_path = []
|
302 |
+
# sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
|
303 |
+
# start = sp[0]
|
304 |
+
# end = sp[-1]
|
305 |
+
# for k in range(len(sp)-1):
|
306 |
+
# e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
|
307 |
+
# if e1 == start:
|
308 |
+
# e1 = 'start_entity-x'
|
309 |
+
# if e2 == start:
|
310 |
+
# e2 = 'start_entity-x'
|
311 |
+
# if e1 == end:
|
312 |
+
# e1 = 'end_entity-x'
|
313 |
+
# if e2 == end:
|
314 |
+
# e2 = 'end_entity-x'
|
315 |
+
# ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
|
316 |
+
# dependency_P = ' '.join(ret_path)
|
317 |
+
# DP_list.append((f'{type_list[i]}-{type_list[j]}',
|
318 |
+
# name_to_meshid[pure_name[i]],
|
319 |
+
# name_to_meshid[pure_name[j]],
|
320 |
+
# dependency_P))
|
321 |
+
|
322 |
+
# boo = False
|
323 |
+
# modified_attack = []
|
324 |
+
# for k, ss, tt, dp in DP_list:
|
325 |
+
# if dp in dependency_to_type_id[k].keys():
|
326 |
+
# tp = str(dependency_to_type_id[k][dp])
|
327 |
+
# id_ss = str(meshid_to_id[ss])
|
328 |
+
# id_tt = str(meshid_to_id[tt])
|
329 |
+
# modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
|
330 |
+
# if int(dependency_to_type_id[k][dp]) == int(r):
|
331 |
+
# # if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
|
332 |
+
# boo = True
|
333 |
+
# modified_attack = list(set(modified_attack))
|
334 |
+
# modified_attack = [k.split('*') for k in modified_attack]
|
335 |
+
# if boo:
|
336 |
+
# add += 1
|
337 |
+
# # else:
|
338 |
+
# # print(ii)
|
339 |
+
|
340 |
+
# # for i in range(len(type_list)):
|
341 |
+
# # if type_list[i]:
|
342 |
+
# # print(pure_name[i], type_list[i])
|
343 |
+
# # for k, ss, tt, dp in DP_list:
|
344 |
+
# # print(k, dp)
|
345 |
+
# # print(record[record_index - 1])
|
346 |
+
# # raise Exception('No!!')
|
347 |
+
# Attack.append(modified_attack)
|
348 |
+
|
349 |
+
record = []
|
350 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl:
|
351 |
+
Tmp = []
|
352 |
+
tmp = []
|
353 |
+
for i,line in enumerate(fl.readlines()):
|
354 |
+
# print(len(line), line)
|
355 |
+
line = line.replace('\n', '')
|
356 |
+
if len(line) > 1:
|
357 |
+
tmp.append(line)
|
358 |
+
else:
|
359 |
+
if len(Tmp) == 2:
|
360 |
+
if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]:
|
361 |
+
Tmp.append([])
|
362 |
+
record.append(Tmp)
|
363 |
+
Tmp = []
|
364 |
+
Tmp.append(tmp)
|
365 |
+
if len(Tmp) == 2 and tmp[0][:5] != '(ROOT':
|
366 |
+
print(record[-1][2])
|
367 |
+
raise Exception('??')
|
368 |
+
tmp = []
|
369 |
+
if len(Tmp) == 3:
|
370 |
+
record.append(Tmp)
|
371 |
+
Tmp = []
|
372 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl:
|
373 |
+
parsin = fl.readlines()
|
374 |
+
|
375 |
+
print('Record len', len(record), 'Parsin len:', len(parsin))
|
376 |
+
record_index = 0
|
377 |
+
add = 0
|
378 |
+
|
379 |
+
Attack = []
|
380 |
+
for ii, (k, v_dict) in enumerate(tqdm(draft.items())):
|
381 |
+
|
382 |
+
input = v_dict['in']
|
383 |
+
output = v_dict['out']
|
384 |
+
output = output.replace('\n', ' ')
|
385 |
+
s, r, o = attack_data[ii]
|
386 |
+
assert ii == int(k.split('_')[-1])
|
387 |
+
|
388 |
+
DP_list = []
|
389 |
+
if input != '':
|
390 |
+
|
391 |
+
dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
|
392 |
+
target_dp = set()
|
393 |
+
for dp_path, sen_list in dependency_sen_dict.items():
|
394 |
+
target_dp.add(dp_path)
|
395 |
+
doc = nlp(output)
|
396 |
+
|
397 |
+
for sen in doc.sents:
|
398 |
+
dp_dict = {}
|
399 |
+
if record_index >= len(record):
|
400 |
+
break
|
401 |
+
data = record[record_index]
|
402 |
+
record_index += 1
|
403 |
+
dp_paths = data[2]
|
404 |
+
nodes_list = []
|
405 |
+
edges_list = []
|
406 |
+
for line in dp_paths:
|
407 |
+
aa = line.split('(')
|
408 |
+
if len(aa) == 1:
|
409 |
+
print(ii)
|
410 |
+
print(sen)
|
411 |
+
print(data)
|
412 |
+
raise Exception
|
413 |
+
ttp, tmp = aa[0], aa[1]
|
414 |
+
assert tmp[-1] == ')'
|
415 |
+
tmp = tmp[:-1]
|
416 |
+
e1, e2 = tmp.split(', ')
|
417 |
+
if not ttp in type_set and ':' in ttp:
|
418 |
+
ttp = ttp.split(':')[0]
|
419 |
+
dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2]
|
420 |
+
dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2]
|
421 |
+
nodes_list.append(e1)
|
422 |
+
nodes_list.append(e2)
|
423 |
+
edges_list.append((e1, e2))
|
424 |
+
nodes_list = list(set(nodes_list))
|
425 |
+
pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list]
|
426 |
+
graph = nx.Graph(edges_list)
|
427 |
+
|
428 |
+
type_list = [name_to_type[name] if name in good_name else '' for name in pure_name]
|
429 |
+
# print(type_list)
|
430 |
+
for i in range(len(nodes_list)):
|
431 |
+
if type_list[i] != '':
|
432 |
+
for j in range(len(nodes_list)):
|
433 |
+
if i != j and type_list[j] != '':
|
434 |
+
if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys():
|
435 |
+
# print(f'{type_list[i]}_{type_list[j]}')
|
436 |
+
ret_path = []
|
437 |
+
sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j])
|
438 |
+
start = sp[0]
|
439 |
+
end = sp[-1]
|
440 |
+
for k in range(len(sp)-1):
|
441 |
+
e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}']
|
442 |
+
if e1 == start:
|
443 |
+
e1 = 'start_entity-x'
|
444 |
+
if e2 == start:
|
445 |
+
e2 = 'start_entity-x'
|
446 |
+
if e1 == end:
|
447 |
+
e1 = 'end_entity-x'
|
448 |
+
if e2 == end:
|
449 |
+
e2 = 'end_entity-x'
|
450 |
+
ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower())
|
451 |
+
dependency_P = ' '.join(ret_path)
|
452 |
+
DP_list.append((f'{type_list[i]}-{type_list[j]}',
|
453 |
+
name_to_meshid[pure_name[i]],
|
454 |
+
name_to_meshid[pure_name[j]],
|
455 |
+
dependency_P))
|
456 |
+
|
457 |
+
boo = False
|
458 |
+
modified_attack = []
|
459 |
+
for k, ss, tt, dp in DP_list:
|
460 |
+
if dp in dependency_to_type_id[k].keys():
|
461 |
+
tp = str(dependency_to_type_id[k][dp])
|
462 |
+
id_ss = str(meshid_to_id[ss])
|
463 |
+
id_tt = str(meshid_to_id[tt])
|
464 |
+
modified_attack.append(f'{id_ss}*{tp}*{id_tt}')
|
465 |
+
if int(dependency_to_type_id[k][dp]) == int(r):
|
466 |
+
if id_to_meshid[s] == ss and id_to_meshid[o] == tt:
|
467 |
+
boo = True
|
468 |
+
modified_attack = list(set(modified_attack))
|
469 |
+
modified_attack = [k.split('*') for k in modified_attack]
|
470 |
+
if boo:
|
471 |
+
# print(DP_list)
|
472 |
+
add += 1
|
473 |
+
Attack.append(modified_attack)
|
474 |
+
print(add)
|
475 |
+
print('End record_index:', record_index)
|
476 |
+
with open(modified_attack_path, 'wb') as fl:
|
477 |
+
pkl.dump(Attack, fl)
|
478 |
+
else:
|
479 |
+
raise Exception('Wrong action !!')
|
DiseaseSpecific/attack.py
ADDED
@@ -0,0 +1,841 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import pickle as pkl
|
3 |
+
from typing import Dict, Tuple, List
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import json
|
7 |
+
import dill
|
8 |
+
import logging
|
9 |
+
import argparse
|
10 |
+
import math
|
11 |
+
from pprint import pprint
|
12 |
+
import pandas as pd
|
13 |
+
from collections import defaultdict
|
14 |
+
import copy
|
15 |
+
import time
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch.utils.data import DataLoader
|
20 |
+
import torch.backends.cudnn as cudnn
|
21 |
+
import torch.autograd as autograd
|
22 |
+
from torch.nn import functional as F
|
23 |
+
from torch.nn.modules.loss import CrossEntropyLoss
|
24 |
+
|
25 |
+
from model import Distmult, Complex, Conve
|
26 |
+
import utils
|
27 |
+
|
28 |
+
import sys
|
29 |
+
|
30 |
+
import dill
|
31 |
+
|
32 |
+
sys.path.append("..")
|
33 |
+
import Parameters
|
34 |
+
|
35 |
+
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
36 |
+
|
37 |
+
logger = None
|
38 |
+
def generate_nghbrs_single_entity(x, edge_nghbrs, bound):
    """Breadth-first expansion of entity id `x` through `edge_nghbrs`.

    Parameters
    ----------
    x : hashable entity id (in this codebase, a string id).
    edge_nghbrs : dict mapping an entity id to a list of neighbour ids.
    bound : maximum number of entities to return.

    Returns the visited entities in BFS order, starting with `x` itself,
    truncated as soon as `bound` entities have been collected.
    """
    # Bug fix: the original used `set(x)`, which for a string id builds a
    # set of its CHARACTERS, so `x` itself was never recognised as visited
    # and could be re-appended when reached via a neighbour.
    ret_S = {x}
    ret_L = [x]
    b = 0
    while b < len(ret_L):
        s = ret_L[b]
        if s in edge_nghbrs.keys():
            for v in edge_nghbrs[s]:
                if v not in ret_S:
                    ret_S.add(v)
                    ret_L.append(v)
                    if len(ret_L) == bound:
                        return ret_L
        b += 1
    return ret_L
|
54 |
+
|
55 |
+
def generate_nghbrs(target_data, edge_nghbrs, args):
    """For each target triple, collect the sorted, de-duplicated union of
    the BFS neighbourhoods of its subject and object entities.

    Returns a dict mapping the triple's index in `target_data` to that
    sorted list of entity ids.
    """
    neighborhood = {}
    for idx, (subj, _rel, obj) in enumerate(target_data):
        reachable = set(generate_nghbrs_single_entity(subj, edge_nghbrs, args.neighbor_num))
        reachable.update(generate_nghbrs_single_entity(obj, edge_nghbrs, args.neighbor_num))
        neighborhood[idx] = sorted(reachable)
    return neighborhood
|
63 |
+
#%%
|
64 |
+
def check_edge(s, r, o, used_trip = None, args = None):
    """Sanity-check a candidate (s, r, o) edge via assertions.

    Returns True immediately when `args` is None (check disabled);
    otherwise returns None and relies on the asserts below to fail loudly:
    - when args.target_existed is False, the pair must NOT appear in
      `used_trip` (a collection of 's_o' keys);
    - otherwise, the node types of s and o must match the relation's
      declared 'srcType-dstType' pattern.

    NOTE(review): relies on the module-level `entityid_to_nodetype` dict
    and `Parameters.edge_id_to_type` — assumed loaded before this is
    called; confirm against the surrounding script.
    """
    if args is None:
        return True
    if not args.target_existed:
        # args.target_existed is False here, so this asserts the pair is absent.
        assert (s+'_'+o in used_trip) == args.target_existed
    else:
        # Map entity ids to node types, e.g. 'disease', 'gene'.
        s = entityid_to_nodetype[s]
        o = entityid_to_nodetype[o]
        # edge_id_to_type[r] looks like 'srcType-dstType:theme'.
        r_tp = Parameters.edge_id_to_type[int(r)]
        r_tp = r_tp.split(':')[0]
        r_tp = r_tp.split('-')
        assert s == r_tp[0] and o == r_tp[1]
|
77 |
+
|
78 |
+
def get_model_loss(batch, model, device, args = None):
    """Bi-directional cross-entropy loss for a batch of (s, r, o) triples.

    Scores o given (s, r) in 'rhs' mode and s given (o, r-or-reciprocal)
    in 'lhs' mode, and returns the sum of the two cross-entropy losses
    (model.loss). `batch` is an integer tensor of shape (B, 3).

    NOTE(review): `n_rel` is read from module scope when
    args.add_reciprocals is set — confirm it is defined before use.
    """
    s,r,o = batch[:,0], batch[:,1], batch[:,2]

    emb_s = model.emb_e(s).squeeze(dim=1)
    emb_r = model.emb_rel(r).squeeze(dim=1)
    emb_o = model.emb_e(o).squeeze(dim=1)

    if args.add_reciprocals:
        # Reciprocal relations are stored after the n_rel forward relations.
        r_rev = r + n_rel
        emb_rrev = model.emb_rel(r_rev).squeeze(dim=1)
    else:
        r_rev = r
        emb_rrev = emb_r

    pred_sr = model.forward(emb_s, emb_r, mode='rhs')
    loss_sr = model.loss(pred_sr, o) # Cross entropy loss

    pred_or = model.forward(emb_o, emb_rrev, mode='lhs')
    loss_or = model.loss(pred_or, s)

    train_loss = loss_sr + loss_or
    return train_loss
|
100 |
+
|
101 |
+
def get_model_loss_without_softmax(batch, model, device=None):
    """Negative raw model score of each (s, r, o) triple, without gradients.

    `batch` is an integer tensor of shape (B, 3); the return value is a
    length-B tensor whose i-th entry is minus the score the model assigns
    to the i-th triple's object column.
    """
    with torch.no_grad():
        subj, rel, obj = batch[:, 0], batch[:, 1], batch[:, 2]

        subj_emb = model.emb_e(subj).squeeze(dim=1)
        rel_emb = model.emb_rel(rel).squeeze(dim=1)

        scores = model.forward(subj_emb, rel_emb)
        # Pick each row's score at its own object index, then negate.
        return -scores[range(obj.shape[0]), obj]
|
111 |
+
|
112 |
+
def lp_regularizer(model, weight, p):
    """Weighted element-wise L_p penalty over both embedding tables.

    Computes weight * sum(|w|**p) for the entity and relation embedding
    weight matrices and returns the accumulated scalar.
    """
    penalty = 0
    for table in (model.emb_e.weight, model.emb_rel.weight):
        penalty = penalty + weight * (table.abs() ** p).sum()
    return penalty
|
118 |
+
|
119 |
+
def n3_regularizer(factors, weight, p):
    """Weighted |f|**p penalty summed over every factor tensor,
    normalised by the batch size (first factor's leading dimension)."""
    total = sum(weight * (f.abs() ** p).sum() for f in factors)
    return total / factors[0].shape[0]
|
124 |
+
|
125 |
+
def get_train_loss(batch, model, device, args):
    """Bi-directional training loss for a batch of (s, r, o) triples.

    Sums the cross-entropy of predicting o from (s, r) ('rhs' mode) and
    s from (o, r or its reciprocal) ('lhs' mode), plus optional N3 or Lp
    regularisation selected by args.reg_norm / args.reg_weight.
    """
    s, r, o = batch[:, 0], batch[:, 1], batch[:, 2]

    emb_s = model.emb_e(s).squeeze(dim=1)
    emb_r = model.emb_rel(r).squeeze(dim=1)
    emb_o = model.emb_e(o).squeeze(dim=1)

    if args.add_reciprocals:
        # NOTE(review): n_rel is read from module scope — confirm it is set.
        r_rev = r + n_rel
        emb_rrev = model.emb_rel(r_rev).squeeze(dim=1)
    else:
        r_rev = r
        emb_rrev = emb_r

    pred_sr = model.forward(emb_s, emb_r, mode='rhs')
    loss_sr = model.loss(pred_sr, o)  # cross-entropy loss

    pred_or = model.forward(emb_o, emb_rrev, mode='lhs')
    loss_or = model.loss(pred_or, s)

    train_loss = loss_sr + loss_or

    if (args.reg_weight != 0.0 and args.reg_norm == 3):
        # Bug fix: the original tested `model == 'complex'`, comparing the
        # model OBJECT to a string (always False), so the ComplEx-specific
        # factorisation below was dead code. Compare the configured model
        # name instead (args.model, set by the argument parser).
        if args.model == 'complex':
            emb_dim = args.embedding_dim
            # Split each embedding into real and imaginary halves and use
            # the complex modulus of each half as the regularised factor.
            lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
            rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
            rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
            rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

            factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                          torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                          torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
            factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                          torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                          torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
        else:
            factors_sr = (emb_s, emb_r, emb_o)
            factors_or = (emb_s, emb_rrev, emb_o)

        train_loss += n3_regularizer(factors_sr, args.reg_weight, p=3)
        train_loss += n3_regularizer(factors_or, args.reg_weight, p=3)

    if (args.reg_weight != 0.0 and args.reg_norm == 2):
        train_loss += lp_regularizer(model, args.reg_weight, p=2)

    return train_loss
|
177 |
+
def hv(loss, model_params, v):
    """Hessian-vector product of `loss` w.r.t. `model_params` with `v`.

    Differentiates once with a retained graph, then differentiates the
    resulting gradient contracted against v (grad_outputs=v).
    """
    first_grad = autograd.grad(loss, model_params,
                               create_graph=True, retain_graph=True)
    return autograd.grad(first_grad, model_params, grad_outputs=v)
|
181 |
+
def gather_flat_grad(grads):
    """Flatten a sequence of gradient tensors into one 1-D tensor,
    densifying any sparse entries before concatenation."""
    flat_pieces = [
        (g.data.to_dense() if g.data.is_sparse else g.data).view(-1)
        for g in grads
    ]
    return torch.cat(flat_pieces, 0)
|
190 |
+
|
191 |
+
def get_inverse_hvp_lissa(v, model, device, param_influence, train_data, args):
    """Approximate the inverse-Hessian-vector product H^{-1} v via the
    LiSSA stochastic recursion.

    Repeats `args.lissa_repeat` independent estimates; each one iterates
    cur <- v + (1 - damping) * cur - H_batch(cur) / scale
    over shuffled mini-batches of `train_data` for
    lissa_num_batches * args.lissa_depth steps, then averages the
    scaled estimates and returns them flattened into a single 1-D tensor.

    NOTE(review): uses torch's global RNG (randperm), so results depend
    on the seed set elsewhere in the script.
    """
    damping = args.damping
    num_samples = args.lissa_repeat
    scale = args.scale
    train_batch_size = args.lissa_batch_size
    lissa_num_batches = math.ceil(train_data.shape[0]/train_batch_size)
    recursion_depth = int(lissa_num_batches*args.lissa_depth)

    ihvp = None  # running sum of scaled per-sample estimates
    # print('inversing hvp...')
    for i in range(num_samples):
        cur_estimate = v
        #lissa_data_iterator = iter(train_loader)
        # Shuffle the full training set for this estimate.
        input_data = torch.from_numpy(train_data.astype('int64'))
        actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
        del input_data

        b_begin = 0
        for j in range(recursion_depth):
            model.zero_grad() # same as optimizer.zero_grad()
            if b_begin >= actual_examples.shape[0]:
                # Exhausted one pass — reshuffle and start over.
                b_begin = 0
                input_data = torch.from_numpy(train_data.astype('int64'))
                actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
                del input_data

            input_batch = actual_examples[b_begin: b_begin + train_batch_size]
            input_batch = input_batch.to(device)

            train_loss = get_train_loss(input_batch, model, device, args)

            # One LiSSA step: hvp is the Hessian of this batch's loss
            # applied to the current estimate.
            hvp = hv(train_loss, param_influence, cur_estimate)
            cur_estimate = [_a + (1-damping)*_b - _c / scale for _a, _b, _c in zip(v, cur_estimate, hvp)]
            # if (j%200 == 0) or (j == recursion_depth -1 ):
            #     logger.info("Recursion at depth %s: norm is %f" % (j, np.linalg.norm(gather_flat_grad(cur_estimate).cpu().numpy())))

            b_begin += train_batch_size

        # Accumulate this sample's estimate, undoing the 1/scale factor
        # folded into the recursion.
        if ihvp == None:
            ihvp = [_a / scale for _a in cur_estimate]
        else:
            ihvp = [_a + _b / scale for _a, _b in zip(ihvp, cur_estimate)]

    # logger.info("Final ihvp norm is %f" % (np.linalg.norm(gather_flat_grad(ihvp).cpu().numpy())))
    return_ihvp = gather_flat_grad(ihvp)
    # Average over the independent samples.
    return_ihvp /= num_samples

    return return_ihvp
|
240 |
+
|
241 |
+
#%%
|
242 |
+
def before_global_attack(device, n_rel, data, target_data, neighbors, model,
                filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                entityid_to_nodetype, batch_size, args, lissa_path, target_disease):
    """Precompute (and cache) LiSSA inverse-HVP vectors for every
    (target gene, relation 10, target disease) triple.

    Returns a dict mapping "s_r_o" triple names to detached CPU tensors of
    shape (1, num_params). Results are cached at `lissa_path` with dill and
    reused unless `args.update_lissa` is set.

    NOTE(review): this function reads the module-level globals
    `param_influence` and `get_model_loss` rather than taking them as
    arguments; it also calls `get_model_loss` with 3 args while
    `addition_attack` passes 4 — confirm the intended signature.
    """
    # Reuse the cached inverse-HVPs when present and no refresh requested.
    if os.path.exists(lissa_path) and not args.update_lissa:
        with open(lissa_path, 'rb') as fl:
            ret = dill.load(fl)
        return ret
    ret = {}

    # Build every (target, relation 10, disease) triple as strings.
    test_data = []
    for i in target_disease:
        tp = entityid_to_nodetype[str(i)]
        # r = torch.LongTensor([[10]]).to(device)
        assert tp == 'disease'
        if tp == 'disease':
            for target in target_data:
                test_data.append([str(target), str(10), str(i)])
    test_data = np.array(test_data)

    for target_trip in tqdm(test_data):

        target_trip_ori = target_trip
        trip_name = '_'.join(list(target_trip_ori))  # cache key "s_r_o"
        target_trip = target_trip[None, :] # add a batch dimension
        target_trip = torch.from_numpy(target_trip.astype('int64')).to(device)
        # target_s, target_r, target_o = target_trip[:,0], target_trip[:,1], target_trip[:,2]
        # target_vec = model.score_triples_vec(target_s, target_r, target_o)

        # Gradient of the target loss w.r.t. the influence parameters.
        model.eval()
        model.zero_grad()
        target_loss = get_model_loss(target_trip, model, device)
        target_grads = autograd.grad(target_loss, param_influence)

        # LiSSA needs train mode for the stochastic recursion.
        model.train()
        inverse_hvp = get_inverse_hvp_lissa(target_grads, model, device,
                                            param_influence, data, args)
        model.eval()
        inverse_hvp = inverse_hvp.detach().cpu().unsqueeze(0)
        ret[trip_name] = inverse_hvp
    with open(lissa_path, 'wb') as fl:
        dill.dump(ret, fl)
    return ret
|
285 |
+
|
286 |
+
def global_addtion_attack(device, n_rel, data, target_data, neighbors, model,
                filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                entityid_to_nodetype, batch_size, args, lissa, target_disease):
    """For each target gene, pick one adversarial edge (toward any neighbor
    gene, any relation passing the type filter) that maximizes the influence
    on the (target, 10, disease) triples.

    Returns (list of chosen triples, score records). Reads the module-level
    globals `score_path`, `param_influence`, `logger`, `get_model_loss`,
    `get_model_loss_without_softmax`.

    NOTE(review): `get_model_loss` is called here with 3 positional args while
    `addition_attack` passes `args` as a 4th — confirm which is intended.
    """
    logger.info('------ Generating edits per target triple ------')
    start_time = time.time()
    logger.info('Start time: {0}'.format(str(start_time)))

    # Set of already-existing (subject, object) pairs, to avoid proposing them.
    used_trip = set()
    print("Processing used triples ...")
    for s, r, o in tqdm(data):
        used_trip.add(s+'_'+o)
        # used_trip.add(o+'_'+s)
    print('Size of used triples:', len(used_trip))
    logger.info('Size of used triples: {0}'.format(len(used_trip)))

    ret_trip = []
    score_record = []
    real_add_rank_ratio = 0

    with open(score_path, 'rb') as fl:
        score_record = pkl.load(fl)
    for i, target in enumerate(target_data):

        print('\n\n------ Attacking target tripid:', i, 'tot:', len(target_data), ' ------')
        # Every (target, relation 10, disease) triple is attacked jointly.
        target_trip = []
        for disease in target_disease:
            target_trip.append([target, str(10), disease])
            # nm = '{}_{}_{}'.format(target, 10, disease)
            # lissa_hvp.append(lissa[nm])
        # lissa_hvp = torch.cat(lissa_hvp, dim = 0).to(device)

        target_trip = np.array(target_trip)
        target_trip = torch.from_numpy(target_trip.astype('int64')).to(device)

        # Gradient of the joint target loss, then its LiSSA inverse-HVP.
        model.eval()
        model.zero_grad()
        target_loss = get_model_loss(target_trip, model, device)
        target_grads = autograd.grad(target_loss, param_influence)

        model.train()
        inverse_hvp = get_inverse_hvp_lissa(target_grads, model, device,
                                            param_influence, data, args)

        model.eval()

        # Candidate edges: target -> any neighbor entity, any relation that
        # the per-node-type rhs filter allows, skipping existing pairs.
        nghbr_trip = []
        s = str(target)
        tp = entityid_to_nodetype[s]
        for nghbr in tqdm(neighbors):
            o = str(nghbr)
            if s!=o and s+'_'+o not in used_trip:
                for r in range(n_rel):
                    if (tp, r) in filters["rhs"].keys() and filters["rhs"][(tp, r)][int(o)] == True:
                        nghbr_trip.append([s, str(r), o])

        nghbr_trip = np.asarray(nghbr_trip)
        influences = []
        edge_losses = []

        # nghbr_cos_log_prob, nghbr_LM_log_prob = score_record[i]
        # assert nghbr_cos_log_prob.shape[0] == nghbr_trip.shape[0]

        for train_trip in tqdm(nghbr_trip):
            #model.train() #batch norm cannot be used here
            train_trip = train_trip[None, :] # add batch dim
            train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
            #### L-train gradient ####
            edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
            edge_losses.append(edge_loss.unsqueeze(0).detach())
            model.zero_grad()
            train_loss = get_model_loss(train_trip, model, device, args)
            train_grads = autograd.grad(train_loss, param_influence)
            train_grads = gather_flat_grad(train_grads)
            influence = torch.dot(inverse_hvp, train_grads) #default dim=1
            influences.append(influence.unsqueeze(0).detach())

        edge_losses = torch.cat(edge_losses, dim = -1)
        influences = torch.cat(influences, dim = -1)
        # Log-softmax scores: low loss => plausible edge; high influence => effective edge.
        edge_losses_log_prob = torch.log(F.softmax(-edge_losses, dim = -1))
        influences_log_prob = torch.log(F.softmax(influences, dim = -1))

        inf_score_sorted, influences_sort = torch.sort(influences_log_prob, -1, descending=True)
        edge_score_sorted, edge_sort = torch.sort(edge_losses_log_prob, -1, descending=True)
        influences_sort = influences_sort.cpu().numpy()
        edge_sort = edge_sort.cpu().numpy()
        inf_score_sorted = inf_score_sorted.cpu().numpy()
        edge_score_sorted = edge_score_sorted.cpu().numpy()

        logger.info('')
        logger.info('Top 8 inf_score: {}'.format(" ".join(map(str, list(inf_score_sorted[:8])))))
        logger.info('Top 8 edge_score: {}'.format(" ".join(map(str, list(edge_score_sorted[:8])))))

        nghbr_cos_log_prob = influences_log_prob.detach().cpu().numpy()
        nghbr_LM_log_prob = edge_losses_log_prob.detach().cpu().numpy()
        max_sim = np.max(nghbr_cos_log_prob)
        min_sim = np.min(nghbr_cos_log_prob)
        max_LM = np.max(nghbr_LM_log_prob)
        min_LM = np.min(nghbr_LM_log_prob)

        # final_score = nghbr_cos_log_prob + nghbr_LM_log_prob
        final_score = nghbr_cos_log_prob

        index = np.argmax(final_score[:-1])
        # FIX: `p` was never assigned (its computation was commented out),
        # so `real_add_rank_ratio += p` raised NameError. Restore the rank
        # ratio of the chosen edge, as in `addition_attack`.
        p = np.where(index == edge_sort)[0][0]
        logger.info('Added edge\'s edge rank ratio: {}'.format(p / edge_sort.shape[0]))
        real_add_rank_ratio += p / edge_sort.shape[0]
        add_trip = nghbr_trip[index]
        logger.info('max_inf: {0:.8}, min_inf: {1:.8}, max_edge: {2:.8}, min_edge: {3:.8}'.format(max_sim, min_sim, max_LM, min_LM))
        logger.info('Attack trip: {0}_{1}_{2}.\n Influnce score: {3:.8}. Edge score: {4:.8}.'.format(add_trip[0], add_trip[1], add_trip[2],
                                                                    nghbr_cos_log_prob[index], nghbr_LM_log_prob[index]))
        ret_trip.append(add_trip)
        score_record.append((nghbr_cos_log_prob, nghbr_LM_log_prob))
    real_add_rank_ratio = real_add_rank_ratio / target_data.shape[0]
    logger.info('Mean real ratio: {}.'.format(real_add_rank_ratio))
    return ret_trip, score_record
|
403 |
+
|
404 |
+
def addition_attack(param_influence, device, n_rel, data, target_data, neighbors, model,
                filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                entityid_to_nodetype, batch_size, args, load_Record = False, divide_bound = None, data_mean = None, data_std = None, cache_intermidiate = True):
    """Single-target poisoning: for each target triple, score candidate edges
    among its neighbors by influence (via LiSSA inverse-HVP) combined with an
    edge-plausibility score, filter implausible edges by `args.reasonable_rate`,
    and return the top edge(s) per target plus per-target score records.

    When `load_Record` is True, the expensive per-target influence/loss
    computations are loaded from the module-level `intermidiate_path` cache;
    otherwise they are computed and (if `cache_intermidiate`) dumped there.

    NOTE(review): also reads module globals `logger`, `intermidiate_path`,
    and the helpers `check_edge`, `get_model_loss`,
    `get_model_loss_without_softmax` defined elsewhere in this file.
    """
    if logger:
        logger.info('------ Generating edits per target triple ------')
    start_time = time.time()
    if logger:
        logger.info('Start time: {0}'.format(str(start_time)))

    # Existing (subject, object) pairs; candidates must be novel pairs.
    used_trip = set()
    print("Processing used triples ...")
    for s, r, o in tqdm(data):
        used_trip.add(s+'_'+o)
        # used_trip.add(o+'_'+s)
    print('Size of used triples:', len(used_trip))
    if logger:
        logger.info('Size of used triples: {0}'.format(len(used_trip)))

    nghbr_trip_len = []
    ret_trip = []
    score_record = []
    direct_add_rank_ratio = 0   # rank ratio of the pure-influence pick
    real_add_rank_ratio = 0     # rank ratio of the final (bounded) pick
    bad_ratio = 0               # fraction of candidates rejected as implausible

    RRcord = []
    print('****'*10)
    if load_Record:
        print('Load intermidiate file')
        with open(intermidiate_path, 'rb') as fl:
            RRcord = dill.load(fl)
    else:
        print('Donnot load intermidiate file')

    for i, target_trip in enumerate(target_data):

        print('\n\n------ Attacking target tripid:', i, ' ------')
        target_nghbrs = neighbors[i]
        for a in target_nghbrs:
            if str(a) == '-1':
                raise Exception('pppp')

        target_trip_ori = target_trip
        check_edge(target_trip[0], target_trip[1], target_trip[2], used_trip)
        target_trip = target_trip[None, :] # add a batch dimension
        target_trip = torch.from_numpy(target_trip.astype('int64')).to(device)
        # target_s, target_r, target_o = target_trip[:,0], target_trip[:,1], target_trip[:,2]
        # target_vec = model.score_triples_vec(target_s, target_r, target_o)

        model.eval()

        if load_Record:
            # Reuse cached scores for this target; sanity-check the triple matches.
            o_target_trip, nghbr_trip, edge_losses, influences, edge_losses_log_prob, influences_log_prob = RRcord[i]
            assert (o_target_trip.cpu() == target_trip.cpu()).sum().item() == 3
        else:
            # Influence direction: gradient of target loss, then inverse-HVP.
            model.zero_grad()
            target_loss = get_model_loss(target_trip, model, device, args)
            target_grads = autograd.grad(target_loss, param_influence)

            model.train()
            inverse_hvp = get_inverse_hvp_lissa(target_grads, model, device,
                                                param_influence, data, args)

            model.eval()
            # Build candidate (s, o) pairs around the target's neighborhood.
            nghbr_trip = []
            valid_trip = 0
            if args.candidate_mode == 'quadratic':
                s_o_list = [(i, j) for i in target_nghbrs for j in target_nghbrs]
            elif args.candidate_mode == 'linear':
                s_o_list = [(j, i) for i in target_nghbrs for j in [target_trip_ori[0], target_trip_ori[2]]] \
                            + [(i, j) for i in target_nghbrs for j in [target_trip_ori[0], target_trip_ori[2]]]
            else:
                raise Exception('Wrong candidate_mode: '+args.candidate_mode)
            for s, o in tqdm(s_o_list):
                tp = entityid_to_nodetype[s]
                if s!=o and s+'_'+o not in used_trip:
                    for r in range(n_rel):
                        # Keep only relations valid for this node type / object.
                        if (tp, r) in filters["rhs"].keys() and filters["rhs"][(tp, r)][int(o)] == True:
                            # check_edge(s, r, o)
                            valid_trip += 1
                            nghbr_trip.append([s, str(r), o])
                            # logger.info('{0}_{1}_{2}'.format(s, str(r), o))
            nghbr_trip_len.append(len(nghbr_trip))
            print('Valid trip:', valid_trip)

            # Append the target triple itself (if novel) as a reference last entry.
            if target_trip_ori[0]+'_'+target_trip_ori[2] not in used_trip:
                nghbr_trip.append(target_trip_ori)
            nghbr_trip = np.asarray(nghbr_trip)
            print("Edge scoring ...")

            influences = []
            edge_losses = []

            for train_trip in tqdm(nghbr_trip):
                #model.train() #batch norm cannot be used here
                train_trip = train_trip[None, :] # add batch dim
                train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                #### L-train gradient ####
                edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze()
                edge_losses.append(edge_loss.unsqueeze(0).detach())
                model.zero_grad()
                train_loss = get_model_loss(train_trip, model, device, args)
                train_grads = autograd.grad(train_loss, param_influence)
                train_grads = gather_flat_grad(train_grads)
                influence = torch.dot(inverse_hvp, train_grads) #default dim=1
                influences.append(influence.unsqueeze(0).detach())

            edge_losses = torch.cat(edge_losses, dim = -1)
            influences = torch.cat(influences, dim = -1)
            # Log-softmax scores; influence scores are rescaled to the
            # spread/location of the edge scores before being combined.
            edge_losses_log_prob = torch.log(F.softmax(-edge_losses, dim = -1))
            influences_log_prob = torch.log(F.softmax(influences, dim = -1))
            std_scale = torch.std(edge_losses_log_prob) / torch.std(influences_log_prob)
            influences_log_prob = (influences_log_prob - influences_log_prob.mean()) * std_scale + edge_losses_log_prob.mean()

            RRcord.append([target_trip.detach(), nghbr_trip, edge_losses, influences, edge_losses_log_prob, influences_log_prob])

        inf_score_sorted, influences_sort = torch.sort(influences_log_prob, -1, descending=True)
        edge_score_sorted, edge_sort = torch.sort(edge_losses_log_prob, -1, descending=True)

        influences_sort = influences_sort.cpu().numpy()
        edge_sort = edge_sort.cpu().numpy()
        inf_score_sorted = inf_score_sorted.cpu().numpy()
        edge_score_sorted = edge_score_sorted.cpu().numpy()
        edge_losses = edge_losses.cpu().numpy()

        # Where the most-influential edge ranks by plausibility.
        p = np.where(influences_sort[0] == edge_sort)[0][0]
        direct_add_rank_ratio += p / edge_sort.shape[0]
        if logger:
            logger.info('Top 8 inf_score: {}'.format(" ".join(map(str, list(inf_score_sorted[:8])))))
            logger.info('Top 8 edge_score: {}'.format(" ".join(map(str, list(edge_score_sorted[:8])))))

        nghbr_cos_log_prob = influences_log_prob.detach().cpu().numpy()
        nghbr_LM_log_prob = edge_losses_log_prob.detach().cpu().numpy()
        max_sim = nghbr_cos_log_prob[influences_sort[0]]
        min_sim = nghbr_cos_log_prob[influences_sort[-1]]
        max_LM = nghbr_LM_log_prob[edge_sort[0]]
        min_LM = nghbr_LM_log_prob[edge_sort[-1]]
        direct_score_0 = 0
        direct_score_1 = 0
        if target_trip_ori[0]+'_'+target_trip_ori[2] not in used_trip:
            # Scores of the target triple itself (last candidate entry).
            direct_score_0 = nghbr_cos_log_prob[-1]
            direct_score_1 = nghbr_LM_log_prob[-1]

        # bound = math.log(1 / nghbr_LM_log_prob.shape[0])
        # Reject candidates whose sigmoid plausibility falls below the bound.
        bound = 1 - args.reasonable_rate
        edge_losses = (edge_losses - data_mean) / data_std
        edge_losses_prob = 1 / ( 1 + np.exp(edge_losses - divide_bound) )
        nghbr_LM_log_prob[edge_losses_prob < bound] = -(1e20)

        final_score = nghbr_cos_log_prob + nghbr_LM_log_prob

        # Exclude the appended target triple (last slot) from selection.
        index = np.argmax(final_score[:-1])
        sort_index = [(i, final_score[i])for i in range(len(final_score) - 1)]
        sort_index = sorted(sort_index, key=lambda x: x[1], reverse=True)
        assert sort_index[0][0] == index

        p = np.where(index == edge_sort)[0][0]
        if logger:
            logger.info('Bad edge ratio: {}'.format((edge_losses_prob < bound).mean()))
            logger.info('Bounded edge\'s edge rank ratio: {}'.format(p / edge_sort.shape[0]))
        real_add_rank_ratio += p / edge_sort.shape[0]
        bad_ratio += (edge_losses_prob < bound).mean()

        add_trip = nghbr_trip[index]

        if (int(add_trip[0]) == int(-1)):
            add_trip[0], add_trip[1], add_trip[2] = -1, -1, -1
            print(final_score.shape, index, edge_losses_prob[index], bound)
            raise Exception('??')

        if logger:
            logger.info('max_inf: {0:.8}, min_inf: {1:.8}, max_edge: {2:.8}, min_edge: {3:.8}'.format(max_sim, min_sim, max_LM, min_LM))
            logger.info('Target trip: {0}_{1}_{2}. Attack trip: {3}_{4}_{5}.\n Influnce score: {6:.8}. Edge score: {7:.8}. Direct score: {8:.8} + {9:.8}'.format(target_trip_ori[0],target_trip_ori[1], target_trip_ori[2],
                                                                add_trip[0], add_trip[1], add_trip[2],
                                                                nghbr_cos_log_prob[index], nghbr_LM_log_prob[index],
                                                                direct_score_0, direct_score_1))
        if (args.added_edge_num == '' or int(args.added_edge_num) == 1):
            ret_trip.append(add_trip)
        else:
            # Emit the top-k candidates instead of just the best one.
            edge_num = int(args.added_edge_num)
            for i in range(edge_num):
                ret_trip.append(nghbr_trip[sort_index[i][0]])
        score_record.append((nghbr_cos_log_prob, nghbr_LM_log_prob))

    if not load_Record and cache_intermidiate:
        with open(intermidiate_path, 'wb') as fl:
            dill.dump(RRcord, fl)
    direct_add_rank_ratio = direct_add_rank_ratio / target_data.shape[0]
    real_add_rank_ratio = real_add_rank_ratio / target_data.shape[0]
    bad_ratio = bad_ratio / target_data.shape[0]
    if logger:
        logger.info('Mean direct ratio: {}. Mean real ratio: {}. Mean bad ratio: {}'.format(direct_add_rank_ratio, real_add_rank_ratio, bad_ratio))
    return ret_trip, score_record
|
598 |
+
|
599 |
+
def calculate_edge_bound(data, model, device, n_ent):
    """Fit a plausibility threshold separating real edges from random ones.

    Samples 10% of `data` as positives, corrupts head or tail at random for an
    equal number of negatives, normalizes the model's per-edge losses, fits a
    Gaussian to each group, and takes the intersection point of the two
    densities as the decision boundary.

    Returns
    -------
    (x, tot_mean, tot_std): the boundary in normalized-loss space and the
    normalization statistics to apply to future losses.

    Raises
    ------
    Exception: if no density intersection lies between the two group means.
    """
    tmp = np.random.choice(a = data.shape[0], size = data.shape[0] // 10, replace=False)
    existed_data = data[tmp, :]

    print('calculating edge bound ...')
    print(existed_data.shape)

    existed_edge = set()
    for src_trip in existed_data:
        existed_edge.add('_'.join(list(src_trip)))

    # Negative sampling: corrupt head or tail until an unseen triple appears.
    not_existed = []
    for s, r, o in existed_data:

        if np.random.randint(0, n_ent) % 2 == 0:
            while True:
                oo = np.random.randint(0, n_ent)
                if '_'.join([s, r, str(oo)]) not in existed_edge:
                    not_existed.append([s, r, str(oo)])
                    break
        else:
            while True:
                ss = np.random.randint(0, n_ent)
                if '_'.join([str(ss), r, o]) not in existed_edge:
                    not_existed.append([str(ss), r, o])
                    break
    existed_data = np.array(existed_data)
    not_existed = np.array(not_existed)
    existed_data = torch.from_numpy(existed_data.astype('int64')).to(device)
    not_existed = torch.from_numpy(not_existed.astype('int64')).to(device)
    loss_existed = get_model_loss_without_softmax(existed_data, model).cpu().numpy()
    loss_not_existed = get_model_loss_without_softmax(not_existed, model).cpu().numpy()
    tot_loss = np.hstack((loss_existed, loss_not_existed))
    tot_mean, tot_std = np.mean(tot_loss), np.std(tot_loss)

    # Normalize both groups with the pooled statistics.
    loss_existed = (loss_existed - tot_mean) / tot_std
    loss_not_existed = (loss_not_existed - tot_mean) / tot_std

    print('Tot mean: {}, Tot std: {}'.format(tot_mean, tot_std))

    l_mean, l_std = np.mean(loss_existed), np.std(loss_existed)
    r_mean, r_std = np.mean(loss_not_existed), np.std(loss_not_existed)

    # Intersection of the two Gaussian log-densities: A*x^2 + B*x + C = 0.
    A = -1/(l_std**2) + 1/(r_std**2)
    B = 2 * (-r_mean/(r_std**2) + l_mean/(l_std**2))
    C = (r_mean**2)/(r_std**2)-(l_mean**2)/(l_std**2) + np.log((r_std**2)/(l_std**2))

    delta = B**2 - 4*A*C

    x_1 = ( -B + math.sqrt(delta) ) / (2*A)
    x_2 = ( -B - math.sqrt(delta) ) / (2*A)

    # Keep whichever root lies strictly between the two group means.
    x = None
    if (x_1 > l_mean and x_1 < r_mean):
        x = x_1
    if (x_2 > l_mean and x_2 < r_mean):
        x = x_2
    # FIX: the original `if not x:` raised spuriously when the valid boundary
    # was exactly 0.0 (a falsy float); test for None explicitly.
    if x is None:
        raise Exception('Bad model!!!!')
    TP = (loss_existed < x).mean()
    TN = (loss_not_existed > x).mean()
    FP = (loss_not_existed < x).mean()
    FN = (loss_existed > x).mean()
    print('X:{}, TP:{}, TN:{}, FP:{}, FN{}'.format(x, TP, TN, FP, FN))

    sig_existed = 1 / ( 1 + np.exp(loss_existed- x) ) # negtive important
    sig_not_existed = 1 / ( 1 + np.exp(loss_not_existed - x) )

    print('Positive mean score:', sig_existed.mean(),'Negetive mean score:', sig_not_existed.mean())

    return x, tot_mean, tot_std
|
673 |
+
|
674 |
+
|
675 |
+
#%%
|
676 |
+
if __name__ == '__main__':
    # Entry point: parse hyperparameters, load the trained KG model and data,
    # then run either the single-target or the global addition attack and
    # persist the chosen adversarial triples plus their scores.
    parser = utils.get_argument_parser()
    parser = utils.add_attack_parameters(parser)
    args = parser.parse_args()
    args = utils.set_hyperparams(args)

    # Prefer two GPUs when available; otherwise share a single device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    args.device = device
    args.device1 = device
    if torch.cuda.device_count() >= 2:
        args.device = "cuda:0"
        args.device1 = "cuda:1"

    utils.seed_all(args.seed)
    np.set_printoptions(precision=5)
    cudnn.benchmark = False

    # Derived file-system locations for the model, targets, caches, and logs.
    model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
    model_path = 'saved_models/{0}_{1}.model'.format(args.data, model_name)
    data_path = os.path.join('processed_data', args.data)
    target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
    lissa_path = 'lissa/{0}_{1}_{2}'.format(args.model,
                                            args.data,
                                            args.target_size)
    intermidiate_path = 'intermidiate/{0}_{1}_{2}_{3}_{4}_{5}_{6}'.format(args.model,
                                            args.target_split,
                                            args.target_size,
                                            'exists:'+str(args.target_existed),
                                            args.neighbor_num,
                                            args.candidate_mode,
                                            args.attack_goal)
    log_path = 'logs/attack_logs/cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}'.format(args.model,
                                            args.target_split,
                                            args.target_size,
                                            'exists:'+str(args.target_existed),
                                            args.neighbor_num,
                                            args.candidate_mode,
                                            args.attack_goal,
                                            str(args.reasonable_rate))
    print(log_path)
    attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}.txt'.format(args.model,
                                            args.target_split,
                                            args.target_size,
                                            'exists:'+str(args.target_existed),
                                            args.neighbor_num,
                                            args.candidate_mode,
                                            args.attack_goal,
                                            str(args.reasonable_rate),
                                            str(args.added_edge_num)))

    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO,
                        filename = log_path
                        )
    logger = logging.getLogger(__name__)
    logger.info(vars(args))
    #%%
    # Load the knowledge graph, type filters, and entity metadata.
    n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
    data = utils.load_data(os.path.join(data_path, 'all.txt'))
    with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
        filters = pkl.load(fl)
    with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
        entityid_to_nodetype = json.load(fl)
    with open(os.path.join(data_path, 'edge_nghbrs.pickle'), 'rb') as fl:
        edge_nghbrs = pkl.load(fl)
    with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
        disease_meshid = pkl.load(fl)
    with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
        entity_to_id = json.load(fl)
    with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
        entity_raw_name = pkl.load(fl)
    #%%
    # Convert the per-(nodetype, relation) filter entity lists into boolean
    # masks over all entities, stored as ByteTensors on the device.
    init_mask = np.asarray([0] * n_ent).astype('int64')
    init_mask = (init_mask == 1)
    for k, v in filters.items():
        for kk, vv in v.items():
            tmp = init_mask.copy()
            tmp[np.asarray(vv)] = True
            t = torch.ByteTensor(tmp).to(args.device)
            filters[k][kk] = t
    #%%
    model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
    # Plausibility boundary used to reject unrealistic candidate edges.
    divide_bound, data_mean, data_std = calculate_edge_bound(data, model, args.device, n_ent)
    # index = torch.LongTensor([0, 1]).to(device)
    # print(model.emb_rel(index)[:, :32])
    # print(model.emb_e(index)[:, :32])
    # raise Exception
    #%%
    target_data = utils.load_data(target_path)
    if args.attack_goal == 'single':
        neighbors = generate_nghbrs(target_data, edge_nghbrs, args)
    elif args.attack_goal == 'global':
        # Global goal: targets are the unique subjects; candidate neighbors
        # are all gene entities; pick the 50 most frequent diseases with
        # names longer than 4 characters as attack objects.
        s_set = set()
        for s, r, o in target_data:
            s_set.add(s)
        target_data = list(s_set)
        target_data.sort()
        target_data = np.array(target_data, dtype=str)
        neighbors = []
        for i in list(range(n_ent)):
            tp = entityid_to_nodetype[str(i)]
            # r = torch.LongTensor([[10]]).to(device)
            if tp == 'gene':
                neighbors.append(str(i))
        target_disease = []
        tid = 1
        bound = 50
        while True:
            meshid = disease_meshid[tid][0]
            fre = disease_meshid[tid][1]
            if len(entity_raw_name[meshid]) > 4:
                target_disease.append(entity_to_id[meshid])
                bound -= 1
                if bound == 0:
                    break
            tid += 1
    else:
        raise Exception('Wrong attack_goal: '+args.attack_goal)

    # All model parameters participate in the influence computation.
    param_optimizer = list(model.named_parameters())
    param_influence = []
    for n,p in param_optimizer:
        param_influence.append(p)
    if args.attack_goal == 'single':
        len_list = []
        for v in neighbors.values():
            len_list.append(len(v))
        mean_len = np.mean(len_list)
    else:
        mean_len = len(neighbors)
    print('Mean length of neighbors:', mean_len)
    logger.info("Mean length of neighbors: {0}".format(mean_len))

    # GPT_LM = LMscore_calculator(data_path, args)
    lissa_num_batches = math.ceil(data.shape[0]/args.lissa_batch_size)
    logger.info('-------- Lissa Params for IHVP --------')
    logger.info('Damping: {0}'.format(args.damping))
    logger.info('Lissa_repeat: {0}'.format(args.lissa_repeat))
    logger.info('Lissa_depth: {0}'.format(args.lissa_depth))
    logger.info('Scale: {0}'.format(args.scale))
    logger.info('Lissa batch size: {0}'.format(args.lissa_batch_size))
    logger.info('Lissa num bacthes: {0}'.format(lissa_num_batches))

    score_path = os.path.join('attack_results', args.data, 'score_cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}.txt'.format(args.model,
                                            args.target_split,
                                            args.target_size,
                                            'exists:'+str(args.target_existed),
                                            args.neighbor_num,
                                            args.candidate_mode,
                                            args.attack_goal,
                                            str(args.reasonable_rate),
                                            str(args.added_edge_num)))

    if args.attack_goal == 'single':
        attack_trip, score_record = addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record = args.load_existed, divide_bound = divide_bound, data_mean = data_mean, data_std = data_std)
    else:
        # lissa = before_global_attack(args.device, n_rel, data, target_data, neighbors, model, filters, entityid_to_nodetype, args.attack_batch_size, args, lissa_path, target_disease)

        attack_trip, score_record = global_addtion_attack(args.device, n_rel, data, target_data, neighbors, model, filters, entityid_to_nodetype, args.attack_batch_size, args, None, target_disease)

    utils.save_data(attack_path, attack_trip)

    logger.info("Attack triples are saved in " + attack_path)
    with open(score_path, 'wb') as fl:
        pkl.dump(score_record, fl)
|
DiseaseSpecific/edge_to_abstract.py
ADDED
@@ -0,0 +1,530 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#%%
import torch
import numpy as np
from torch.autograd import Variable
# from sklearn import metrics

import datetime
from typing import Dict, Tuple, List
import logging
import os
import utils
import pickle as pkl
import json
import torch.backends.cudnn as cudnn

from tqdm import tqdm

import sys
sys.path.append("..")
import Parameters

# Command-line setup: reuse the shared attack argument parser and add the
# two knobs specific to this extraction script.
parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--mode', type=str, default='sentence', help='sentence, biogpt or finetune')
parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words')
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Deterministic runs: seed every RNG and disable cudnn autotuning.
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

# Input locations are derived from the attack configuration so that each
# (model, split, size, goal, ...) combination maps to a unique file.
data_path = os.path.join('processed_data', args.data)
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))
attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}.txt'.format(args.model,
                                                    args.target_split,
                                                    args.target_size,
                                                    'exists:'+str(args.target_existed),
                                                    args.neighbor_num,
                                                    args.candidate_mode,
                                                    args.attack_goal,
                                                    str(args.reasonable_rate)))
# target_data = utils.load_data(target_path)
# Adversarial triples (s, r, o); a subject of -1 marks a target that was skipped.
attack_data = utils.load_data(attack_path, drop=False)
# assert target_data.shape == attack_data.shape
#%%

# Lookup tables produced by the GNBR preprocessing pipeline (usage below):
#   id_to_meshid: entity id (str) -> MeSH id
#   entity_raw_name: MeSH id -> human-readable entity name
#   retieve_sentence_through_edgetype: relation id -> {'manual': {dependency path -> [(paper_id, sen_id), ...]}}
#   raw_text_sen: paper_id -> {sen_id -> sentence record with 'text',
#                              'start_formatted', 'end_formatted'}
with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl:
    id_to_meshid = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
59 |
+
# Build the set of multi-word entity surface forms once and cache it on disk.
# Bug fix: the existence check and the json.dump previously used different
# paths ('generate_abstract/valid_entity.json' vs 'valid_entity.json'), so the
# cache was rebuilt on every run and never found where it was looked for.
valid_entity_path = 'generate_abstract/valid_entity.json'
if not os.path.exists(valid_entity_path):
    valid_entity = set()
    for paper_id, paper in raw_text_sen.items():
        for sen_id, sen in paper.items():
            # Tokens containing '_' are formatted multi-word entities;
            # restore the spaces to recover the surface form.
            for token in sen['text'].split(' '):
                if '_' in token:
                    valid_entity.add(token.replace('_', ' '))
    with open(valid_entity_path, 'w') as fl:
        json.dump(list(valid_entity), fl, indent=4)
    print('Valid entity saved!!')
|
70 |
+
|
71 |
+
if args.mode == 'sentence':
|
72 |
+
import torch
|
73 |
+
from torch.nn.modules.loss import CrossEntropyLoss
|
74 |
+
from transformers import AutoTokenizer
|
75 |
+
from transformers import BioGptForCausalLM
|
76 |
+
criterion = CrossEntropyLoss(reduction="none")
|
77 |
+
|
78 |
+
print('Generating GPT input ...')
|
79 |
+
|
80 |
+
tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
|
81 |
+
tokenizer.pad_token = tokenizer.eos_token
|
82 |
+
model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
|
83 |
+
model.to(device)
|
84 |
+
model.eval()
|
85 |
+
GPT_batch_size = 32
|
86 |
+
single_sentence = {}
|
87 |
+
test_text = []
|
88 |
+
test_dp = []
|
89 |
+
test_parse = []
|
90 |
+
for i, (s, r, o) in enumerate(tqdm(attack_data)):
|
91 |
+
|
92 |
+
if int(s) != -1:
|
93 |
+
|
94 |
+
dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
|
95 |
+
candidate_sen = []
|
96 |
+
Dp_path = []
|
97 |
+
L = len(dependency_sen_dict.keys())
|
98 |
+
bound = 500 // L
|
99 |
+
if bound == 0:
|
100 |
+
bound = 1
|
101 |
+
for dp_path, sen_list in dependency_sen_dict.items():
|
102 |
+
if len(sen_list) > bound:
|
103 |
+
index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
|
104 |
+
sen_list = [sen_list[aa] for aa in index]
|
105 |
+
candidate_sen += sen_list
|
106 |
+
Dp_path += [dp_path] * len(sen_list)
|
107 |
+
|
108 |
+
text_s = entity_raw_name[id_to_meshid[s]]
|
109 |
+
text_o = entity_raw_name[id_to_meshid[o]]
|
110 |
+
candidate_text_sen = []
|
111 |
+
candidate_ori_sen = []
|
112 |
+
candidate_parse_sen = []
|
113 |
+
|
114 |
+
for paper_id, sen_id in candidate_sen:
|
115 |
+
sen = raw_text_sen[paper_id][sen_id]
|
116 |
+
text = sen['text']
|
117 |
+
candidate_ori_sen.append(text)
|
118 |
+
ss = sen['start_formatted']
|
119 |
+
oo = sen['end_formatted']
|
120 |
+
text = text.replace('-LRB-', '(')
|
121 |
+
text = text.replace('-RRB-', ')')
|
122 |
+
text = text.replace('-LSB-', '[')
|
123 |
+
text = text.replace('-RSB-', ']')
|
124 |
+
text = text.replace('-LCB-', '{')
|
125 |
+
text = text.replace('-RCB-', '}')
|
126 |
+
parse_text = text
|
127 |
+
parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
|
128 |
+
parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
|
129 |
+
text = text.replace(ss, text_s)
|
130 |
+
text = text.replace(oo, text_o)
|
131 |
+
text = text.replace('_', ' ')
|
132 |
+
candidate_text_sen.append(text)
|
133 |
+
candidate_parse_sen.append(parse_text)
|
134 |
+
tokens = tokenizer( candidate_text_sen,
|
135 |
+
truncation = True,
|
136 |
+
padding = True,
|
137 |
+
max_length = 300,
|
138 |
+
return_tensors="pt")
|
139 |
+
target_ids = tokens['input_ids'].to(device)
|
140 |
+
attention_mask = tokens['attention_mask'].to(device)
|
141 |
+
|
142 |
+
L = len(candidate_text_sen)
|
143 |
+
assert L > 0
|
144 |
+
ret_log_L = []
|
145 |
+
for l in range(0, L, GPT_batch_size):
|
146 |
+
R = min(L, l + GPT_batch_size)
|
147 |
+
target = target_ids[l:R, :]
|
148 |
+
attention = attention_mask[l:R, :]
|
149 |
+
outputs = model(input_ids = target,
|
150 |
+
attention_mask = attention,
|
151 |
+
labels = target)
|
152 |
+
logits = outputs.logits
|
153 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
154 |
+
shift_labels = target[..., 1:].contiguous()
|
155 |
+
Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
|
156 |
+
Loss = Loss.view(-1, shift_logits.shape[1])
|
157 |
+
attention = attention[..., 1:].contiguous()
|
158 |
+
log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
|
159 |
+
ret_log_L.append(log_Loss.detach())
|
160 |
+
|
161 |
+
|
162 |
+
ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
|
163 |
+
sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
|
164 |
+
sen_score.sort(key = lambda x: x[1])
|
165 |
+
test_text.append(sen_score[0][2])
|
166 |
+
test_dp.append(sen_score[0][3])
|
167 |
+
test_parse.append(sen_score[0][4])
|
168 |
+
single_sentence.update({f'{s}_{r}_{o}_{i}': sen_score[0][0]})
|
169 |
+
|
170 |
+
else:
|
171 |
+
single_sentence.update({f'{s}_{r}_{o}_{i}': ''})
|
172 |
+
|
173 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_sentence.json', 'w') as fl:
|
174 |
+
json.dump(single_sentence, fl, indent=4)
|
175 |
+
# with open('generate_abstract/test.txt', 'w') as fl:
|
176 |
+
# fl.write('\n'.join(test_text))
|
177 |
+
# with open('generate_abstract/dp.txt', 'w') as fl:
|
178 |
+
# fl.write('\n'.join(test_dp))
|
179 |
+
with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'w') as fl:
|
180 |
+
fl.write('\n'.join(test_dp))
|
181 |
+
with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_temp.json', 'w') as fl:
|
182 |
+
fl.write('\n'.join(test_text))
|
183 |
+
|
184 |
+
elif args.mode == 'finetune':
|
185 |
+
|
186 |
+
import spacy
|
187 |
+
import pprint
|
188 |
+
from transformers import AutoModel, AutoTokenizer,BartForConditionalGeneration
|
189 |
+
|
190 |
+
print('Finetuning ...')
|
191 |
+
|
192 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}_chat.json', 'r') as fl:
|
193 |
+
draft = json.load(fl)
|
194 |
+
with open (f'generate_abstract/path/{args.target_split}_{args.reasonable_rate}_path.json', 'r') as fl:
|
195 |
+
dpath = fl.readlines()
|
196 |
+
|
197 |
+
nlp = spacy.load("en_core_web_sm")
|
198 |
+
if os.path.exists(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json'):
|
199 |
+
with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json', 'r') as fl:
|
200 |
+
ret_candidates = json.load(fl)
|
201 |
+
# if False:
|
202 |
+
# pass
|
203 |
+
else:
|
204 |
+
|
205 |
+
def find_mini_span(vec, words, check_set):
    """Locate the shortest window of `words` that starts and ends on flagged
    positions and still contains every member of `check_set` that the full
    text contains, then flag the whole window in `vec`.

    Mutates `vec` in place and returns (vec, span_text); span_text is ''
    when no qualifying window exists.
    """

    def hit_count(text, candidates):
        # How many candidate strings occur (as substrings) in `text`.
        return sum(1 for cand in candidates if cand in text)

    # The best any window can do is what the whole text achieves.
    best_attainable = hit_count(' '.join(words), check_set)

    best_len = 10000000
    best_span = ''
    best_range = None
    for start in range(len(vec)):
        if vec[start] == True:
            end = -1
            for stop in range(start + 1, len(vec) + 1):
                if vec[stop - 1] == True:
                    window = ' '.join(words[start:stop])
                    if hit_count(window, check_set) == best_attainable:
                        end = stop
                        break
            if end > 0 and (end - start) < best_len:
                best_len = end - start
                best_span = ' '.join(words[start:end])
                best_range = (start, end)
    if best_range:
        for idx in range(best_range[0], best_range[1]):
            vec[idx] = True
    return vec, best_span
|
238 |
+
|
239 |
+
def mask_func(tokenized_sen):
    """Randomly replace token spans with '<mask>' for BART-style infilling.

    tokenized_sen: list of spaCy sentence spans (anything exposing .text).
    Returns a one-element list holding the masked text, or [] for empty input.
    Uses module-level `args.ratio` for the masking probability and numpy's
    global RNG, so output is nondeterministic unless seeded.
    """
    if len(tokenized_sen) == 0:
        return []
    token_list = []
    # for sen in tokenized_sen:
    #     for token in sen:
    #         token_list.append(token)
    for sen in tokenized_sen:
        token_list += sen.text.split(' ')
    # Masking probability: default 0.3, overridable through --ratio.
    if args.ratio == '':
        P = 0.3
    else:
        P = float(args.ratio)

    ret_list = []
    i = 0
    mask_num = 0
    while i < len(token_list):
        t = token_list[i]
        # Keep punctuation / bracket tokens verbatim and reset the
        # consecutive-mask counter so masked spans never straddle them.
        if '.' in t or '(' in t or ')' in t or '[' in t or ']' in t:
            ret_list.append(t)
            i += 1
            mask_num = 0
        else:
            # Span length ~ Poisson(3), mirroring BART's span-masking scheme.
            length = np.random.poisson(3)
            if np.random.rand() < P and length > 0:
                # Cap at 8 consecutive <mask> tokens; once the cap is hit the
                # token is retried on the next iteration with a fresh draw.
                # NOTE(review): original indentation was lost in this view —
                # assumes all three statements sit under the cap check; confirm
                # against the upstream file.
                if mask_num < 8:
                    ret_list.append('<mask>')
                    mask_num += 1
                    i += length
            else:
                ret_list.append(t)
                i += 1
                mask_num = 0
    return [' '.join(ret_list)]
|
275 |
+
|
276 |
+
model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
|
277 |
+
model.eval()
|
278 |
+
model.to(device)
|
279 |
+
tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
|
280 |
+
|
281 |
+
ret_candidates = {}
|
282 |
+
dpath_i = 0
|
283 |
+
|
284 |
+
for i,(k, v) in enumerate(tqdm(draft.items())):
|
285 |
+
|
286 |
+
input = v['in'].replace('\n', '')
|
287 |
+
output = v['out'].replace('\n', '')
|
288 |
+
s, r, o = attack_data[i]
|
289 |
+
|
290 |
+
if int(s) == -1:
|
291 |
+
ret_candidates[str(i)] = {'span': '', 'prompt' : '', 'out' : [], 'in': [], 'assist': []}
|
292 |
+
continue
|
293 |
+
|
294 |
+
path_text = dpath[dpath_i].replace('\n', '')
|
295 |
+
dpath_i += 1
|
296 |
+
text_s = entity_raw_name[id_to_meshid[s]]
|
297 |
+
text_o = entity_raw_name[id_to_meshid[o]]
|
298 |
+
|
299 |
+
doc = nlp(output)
|
300 |
+
words= input.split(' ')
|
301 |
+
tokenized_sens = [sen for sen in doc.sents]
|
302 |
+
sens = np.array([sen.text for sen in doc.sents])
|
303 |
+
|
304 |
+
checkset = set([text_s, text_o])
|
305 |
+
e_entity = set(['start_entity', 'end_entity'])
|
306 |
+
for path in path_text.split(' '):
|
307 |
+
a, b, c = path.split('|')
|
308 |
+
if a not in e_entity:
|
309 |
+
checkset.add(a)
|
310 |
+
if c not in e_entity:
|
311 |
+
checkset.add(c)
|
312 |
+
vec = []
|
313 |
+
l = 0
|
314 |
+
while(l < len(words)):
|
315 |
+
bo =False
|
316 |
+
for j in range(len(words), l, -1): # reversing is important !!!
|
317 |
+
cc = ' '.join(words[l:j])
|
318 |
+
if (cc in checkset):
|
319 |
+
vec += [True] * (j-l)
|
320 |
+
l = j
|
321 |
+
bo = True
|
322 |
+
break
|
323 |
+
if not bo:
|
324 |
+
vec.append(False)
|
325 |
+
l += 1
|
326 |
+
vec, span = find_mini_span(vec, words, checkset)
|
327 |
+
# vec = np.vectorize(lambda x: x in checkset)(words)
|
328 |
+
vec[-1] = True
|
329 |
+
prompt = []
|
330 |
+
mask_num = 0
|
331 |
+
for j, bo in enumerate(vec):
|
332 |
+
if not bo:
|
333 |
+
mask_num += 1
|
334 |
+
else:
|
335 |
+
if mask_num > 0:
|
336 |
+
# mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
|
337 |
+
mask_num = max(mask_num, 1)
|
338 |
+
mask_num= min(8, mask_num)
|
339 |
+
prompt += ['<mask>'] * mask_num
|
340 |
+
prompt.append(words[j])
|
341 |
+
mask_num = 0
|
342 |
+
prompt = ' '.join(prompt)
|
343 |
+
Text = []
|
344 |
+
Assist = []
|
345 |
+
|
346 |
+
for j in range(len(sens)):
|
347 |
+
Bart_input = list(sens[:j]) + [prompt] +list(sens[j+1:])
|
348 |
+
assist = list(sens[:j]) + [input] +list(sens[j+1:])
|
349 |
+
Text.append(' '.join(Bart_input))
|
350 |
+
Assist.append(' '.join(assist))
|
351 |
+
|
352 |
+
for j in range(len(sens)):
|
353 |
+
Bart_input = mask_func(tokenized_sens[:j]) + [input] + mask_func(tokenized_sens[j+1:])
|
354 |
+
assist = list(sens[:j]) + [input] +list(sens[j+1:])
|
355 |
+
Text.append(' '.join(Bart_input))
|
356 |
+
Assist.append(' '.join(assist))
|
357 |
+
|
358 |
+
batch_size = len(Text) // 2
|
359 |
+
Outs = []
|
360 |
+
for l in range(2):
|
361 |
+
A = tokenizer(Text[batch_size * l:batch_size * (l+1)],
|
362 |
+
truncation = True,
|
363 |
+
padding = True,
|
364 |
+
max_length = 1024,
|
365 |
+
return_tensors="pt")
|
366 |
+
input_ids = A['input_ids'].to(device)
|
367 |
+
attention_mask = A['attention_mask'].to(device)
|
368 |
+
aaid = model.generate(input_ids, attention_mask = attention_mask, num_beams = 5, max_length = 1024)
|
369 |
+
outs = tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
370 |
+
Outs += outs
|
371 |
+
ret_candidates[str(i)] = {'span': span, 'prompt' : prompt, 'out' : Outs, 'in': Text, 'assist': Assist}
|
372 |
+
with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_candidates.json', 'w') as fl:
|
373 |
+
json.dump(ret_candidates, fl, indent = 4)
|
374 |
+
|
375 |
+
from torch.nn.modules.loss import CrossEntropyLoss
|
376 |
+
from transformers import BioGptForCausalLM
|
377 |
+
criterion = CrossEntropyLoss(reduction="none")
|
378 |
+
|
379 |
+
tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
|
380 |
+
tokenizer.pad_token = tokenizer.eos_token
|
381 |
+
model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=tokenizer.eos_token_id)
|
382 |
+
model.to(device)
|
383 |
+
model.eval()
|
384 |
+
|
385 |
+
scored = {}
|
386 |
+
ret = {}
|
387 |
+
dpath_i = 0
|
388 |
+
for i,(k, v) in enumerate(tqdm(draft.items())):
|
389 |
+
|
390 |
+
span = ret_candidates[str(i)]['span']
|
391 |
+
prompt = ret_candidates[str(i)]['prompt']
|
392 |
+
sen_list = ret_candidates[str(i)]['out']
|
393 |
+
BART_in = ret_candidates[str(i)]['in']
|
394 |
+
Assist = ret_candidates[str(i)]['assist']
|
395 |
+
|
396 |
+
s, r, o = attack_data[i]
|
397 |
+
|
398 |
+
if int(s) == -1:
|
399 |
+
ret[k] = {'prompt': '', 'in':'', 'out': ''}
|
400 |
+
continue
|
401 |
+
|
402 |
+
text_s = entity_raw_name[id_to_meshid[s]]
|
403 |
+
text_o = entity_raw_name[id_to_meshid[o]]
|
404 |
+
|
405 |
+
def process(text):
    """Re-open sentence boundaries that lost their space: turn '.X' into
    '. X' for every uppercase letter X, leaving lowercase abbreviations
    untouched. Returns the repaired text."""
    for letter_code in range(ord('A'), ord('Z') + 1):
        capital = chr(letter_code)
        text = text.replace('.' + capital, '. ' + capital)
    return text
|
410 |
+
|
411 |
+
sen_list = [process(text) for text in sen_list]
|
412 |
+
path_text = dpath[dpath_i].replace('\n', '')
|
413 |
+
dpath_i += 1
|
414 |
+
|
415 |
+
checkset = set([text_s, text_o])
|
416 |
+
e_entity = set(['start_entity', 'end_entity'])
|
417 |
+
for path in path_text.split(' '):
|
418 |
+
a, b, c = path.split('|')
|
419 |
+
if a not in e_entity:
|
420 |
+
checkset.add(a)
|
421 |
+
if c not in e_entity:
|
422 |
+
checkset.add(c)
|
423 |
+
|
424 |
+
input = v['in'].replace('\n', '')
|
425 |
+
output = v['out'].replace('\n', '')
|
426 |
+
|
427 |
+
doc = nlp(output)
|
428 |
+
gpt_sens = [sen.text for sen in doc.sents]
|
429 |
+
assert len(gpt_sens) == len(sen_list) // 2
|
430 |
+
|
431 |
+
word_sets = []
|
432 |
+
for sen in gpt_sens:
|
433 |
+
word_sets.append(set(sen.split(' ')))
|
434 |
+
|
435 |
+
def sen_align(word_sets, modified_word_sets):
    """Align original sentences against a modified rendering of the same
    passage. Two sentences 'match' when their shared vocabulary exceeds 80%
    of the original sentence's word set.

    Returns (l1, r1, l2, r2): the matching prefix ends at index l1 == l2,
    originals resynchronize at r1 and the modified text at r2. Returns
    (-1, -1, -1, -1) when every modified sentence matches its counterpart.
    """

    def matches(orig_set, mod_set):
        return len(orig_set.intersection(mod_set)) > len(orig_set) * 0.8

    # Walk the longest matching prefix.
    left = 0
    while left < len(modified_word_sets):
        if matches(word_sets[left], modified_word_sets[left]):
            left += 1
        else:
            break
    if left == len(modified_word_sets):
        return -1, -1, -1, -1

    # Find the earliest re-synchronization point after the divergence.
    start = left + 1
    right_orig = None
    right_mod = None
    for pos1 in range(start, len(word_sets)):
        for pos2 in range(start, len(modified_word_sets)):
            if matches(word_sets[pos1], modified_word_sets[pos2]):
                right_orig = pos1
                right_mod = pos2
                break
        if right_orig is not None:
            break
    if right_orig is None:
        # No resync: the divergence runs to the end of both texts.
        right_orig = len(word_sets)
        right_mod = len(modified_word_sets)
    return left, right_orig, left, right_mod
|
460 |
+
|
461 |
+
replace_sen_list = []
|
462 |
+
boundary = []
|
463 |
+
assert len(sen_list) % 2 == 0
|
464 |
+
for j in range(len(sen_list) // 2):
|
465 |
+
doc = nlp(sen_list[j])
|
466 |
+
sens = [sen.text for sen in doc.sents]
|
467 |
+
modified_word_sets = [set(sen.split(' ')) for sen in sens]
|
468 |
+
l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
|
469 |
+
boundary.append((l1, r1, l2, r2))
|
470 |
+
if l1 == -1:
|
471 |
+
replace_sen_list.append(sen_list[j])
|
472 |
+
continue
|
473 |
+
check_text = ' '.join(sens[l2: r2])
|
474 |
+
replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
|
475 |
+
sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
|
476 |
+
|
477 |
+
old_L = len(sen_list)
|
478 |
+
sen_list.append(output)
|
479 |
+
sen_list += Assist
|
480 |
+
tokens = tokenizer( sen_list,
|
481 |
+
truncation = True,
|
482 |
+
padding = True,
|
483 |
+
max_length = 1024,
|
484 |
+
return_tensors="pt")
|
485 |
+
target_ids = tokens['input_ids'].to(device)
|
486 |
+
attention_mask = tokens['attention_mask'].to(device)
|
487 |
+
L = len(sen_list)
|
488 |
+
ret_log_L = []
|
489 |
+
for l in range(0, L, 5):
|
490 |
+
R = min(L, l + 5)
|
491 |
+
target = target_ids[l:R, :]
|
492 |
+
attention = attention_mask[l:R, :]
|
493 |
+
outputs = model(input_ids = target,
|
494 |
+
attention_mask = attention,
|
495 |
+
labels = target)
|
496 |
+
logits = outputs.logits
|
497 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
498 |
+
shift_labels = target[..., 1:].contiguous()
|
499 |
+
Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
|
500 |
+
Loss = Loss.view(-1, shift_logits.shape[1])
|
501 |
+
attention = attention[..., 1:].contiguous()
|
502 |
+
log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
|
503 |
+
ret_log_L.append(log_Loss.detach())
|
504 |
+
log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
|
505 |
+
|
506 |
+
real_log_Loss = log_Loss.copy()
|
507 |
+
|
508 |
+
log_Loss = log_Loss[:old_L]
|
509 |
+
|
510 |
+
p = np.argmin(log_Loss)
|
511 |
+
content = []
|
512 |
+
for i in range(len(real_log_Loss)):
|
513 |
+
content.append([sen_list[i], str(real_log_Loss[i])])
|
514 |
+
scored[k] = {'path':path_text, 'prompt': prompt, 'in':input, 's':text_s, 'o':text_o, 'out': content, 'bound': boundary}
|
515 |
+
p_p = p
|
516 |
+
# print('Old_L:', old_L)
|
517 |
+
|
518 |
+
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
|
519 |
+
p_p = p+1+old_L
|
520 |
+
|
521 |
+
if real_log_Loss[p] > real_log_Loss[old_L]:
|
522 |
+
if real_log_Loss[p] > real_log_Loss[p+1+old_L]:
|
523 |
+
p = p+1+old_L
|
524 |
+
ret[k] = {'prompt': prompt, 'in':input, 'out': sen_list[p]}
|
525 |
+
with open(f'generate_abstract/{args.target_split}_{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'w') as fl:
|
526 |
+
json.dump(ret, fl, indent=4)
|
527 |
+
with open(f'generate_abstract/bioBART/{args.target_split}_{args.reasonable_rate}{args.ratio}_scored.json', 'w') as fl:
|
528 |
+
json.dump(scored, fl, indent=4)
|
529 |
+
else:
|
530 |
+
raise Exception('Wrong mode !!')
|
DiseaseSpecific/evaluation.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from sklearn import metrics
|
5 |
+
|
6 |
+
import datetime
|
7 |
+
from typing import Dict, Tuple, List
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import utils
|
11 |
+
import pickle as pkl
|
12 |
+
import json
|
13 |
+
from tqdm import tqdm
|
14 |
+
import torch.backends.cudnn as cudnn
|
15 |
+
|
16 |
+
import sys
|
17 |
+
sys.path.append("..")
|
18 |
+
import Parameters
|
19 |
+
|
20 |
+
logger = logging.getLogger(__name__)
|
21 |
+
|
22 |
+
def get_model_loss_without_softmax(batch, model, device=None):
    """Raw (un-normalized) per-triple loss: the negated model score of the
    gold object, computed with gradients disabled.

    batch: LongTensor of shape (B, 3) holding (subject, relation, object) ids.
    Returns a tensor of shape (B,); higher values mean the model considers
    the triple less plausible. `device` is accepted but unused here.
    """
    with torch.no_grad():
        s,r,o = batch[:,0], batch[:,1], batch[:,2]

        emb_s = model.emb_e(s).squeeze(dim=1)
        emb_r = model.emb_rel(r).squeeze(dim=1)

        # model.forward scores every candidate object; select the gold column.
        pred = model.forward(emb_s, emb_r)
        return -pred[range(o.shape[0]), o]
|
32 |
+
|
33 |
+
def check(trip, model, reasonable_rate, device, data_mean = -4.008113861083984, data_std = 5.153779983520508, divide_bound = 0.05440050354114886):
    """Decide whether a candidate triple is plausible enough to keep.

    trip: (s, r, o) id triple. The default data_mean/data_std/divide_bound
    are the distmult calibration; they are overwritten for conve below.
    Returns True when the calibrated plausibility exceeds 1 - reasonable_rate.

    NOTE(review): reads the module-level `args` to pick the calibration
    rather than taking the model name as a parameter — confirm `args` is
    defined before this is called.
    """
    if args.model == 'distmult':
        pass
    elif args.model == 'conve':
        data_mean = 13.890259742
        data_std = 12.396190643
        divide_bound = -0.1986345871
    else:
        raise Exception('Wrong model!!')
    trip = np.array(trip)
    # Promote to a (1, 3) batch for the loss helper.
    train_trip = trip[None, :]
    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
    edge_loss = get_model_loss_without_softmax(train_trip, model, device).squeeze().item()

    bound = 1 - reasonable_rate
    # Standardize the loss and squash it through a shifted sigmoid so that
    # low loss maps to probability near 1.
    edge_loss = (edge_loss - data_mean) / data_std
    edge_loss_prob = 1 / ( 1 + np.exp(edge_loss - divide_bound))
    return edge_loss_prob > bound
|
52 |
+
|
53 |
+
|
54 |
+
def get_ranking(model, queries,
                valid_filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
                device, batch_size, entityid_to_nodetype, exists_edge):
    """
    Ranking for target generation.

    For each (s, r, o) query: score every candidate subject against (r, o),
    mask out candidates that are invalid for the object's node type (and,
    when targets must be novel, subjects already linked to o), and record
    the filtered rank of the true subject.

    Returns (ranks, total_nums): per-query 0-based rank of the true subject
    and the number of surviving candidate subjects.

    NOTE(review): reads module-level `args` and `n_ent` instead of taking
    them as parameters; `batch_size` is accepted but scoring runs one query
    at a time.
    """
    ranks = []
    total_nums = []
    b_begin = 0

    for b_begin in range(0, len(queries), 1):
        b_queries = queries[b_begin : b_begin+1]
        s,r,o = b_queries[:,0], b_queries[:,1], b_queries[:,2]
        r_rev = r
        lhs_score = model.score_or(o, r_rev, sigmoid=False) #this gives scores not probabilities
        # print(b_queries.shape)
        for i, query in enumerate(b_queries):

            if not args.target_existed:
                tp1 = entityid_to_nodetype[str(query[0].item())]
                tp2 = entityid_to_nodetype[str(query[2].item())]
                # Start from the type-compatibility mask, then additionally
                # drop subjects that already share an edge with this object.
                filter = valid_filters['lhs'][(tp2, query[1].item())].clone()
                filter[exists_edge['lhs'][str(query[2].item())]] = False
                # Invert: True now marks candidates to exclude.
                filter = (filter == False)
            else:
                tp1 = entityid_to_nodetype[str(query[0].item())]
                tp2 = entityid_to_nodetype[str(query[2].item())]
                filter = valid_filters['lhs'][(tp2, query[1].item())]
                filter = (filter == False)

            # if (str(query[2].item())) == '16566':
            #     print('16566', filter.sum(), valid_filters['lhs'][(tp2, query[1].item())].sum(), tp2, query[1].item())
            #     raise Exception('??')

            score = lhs_score
            # target_value = rhs_score[i, query[0].item()].item()
            # zero all known cases (this are not interesting)
            # this corresponds to the filtered setting
            # Push excluded candidates to the bottom of an ascending sort.
            score[i][filter] = 1e6
            total_nums.append(n_ent - filter.sum().item())
            # write base the saved values
            # if b_begin < len(queries) // 2:
            #     score[i][query[2].item()] = target_value
            # else:
            #     score[i][query[0].item()] = target_value

        # sort and rank
        min_values, sort_v = torch.sort(score, dim=1, descending=False) #low scores get low number ranks

        sort_v = sort_v.cpu().numpy()

        for i, query in enumerate(b_queries):
            # find the rank of the target entities
            rank = np.where(sort_v[i]==query[0].item())[0][0]

            # rank+1, since the lowest rank is rank 1 not rank 0
            ranks.append(rank)

    #logger.info('Ranking done for all queries')
    return ranks, total_nums
|
114 |
+
|
115 |
+
|
116 |
+
def evaluation(model, queries,
        valid_filters:Dict[str, Dict[Tuple[str, int], torch.Tensor]],
        device, batch_size, entityid_to_nodetype, exists_edge, eval_type = '', attack_data = None, ori_ranks = None, ori_totals = None):
    """Rank the target triples under `model` and aggregate ranking metrics.

    Parameters
    ----------
    model : KGE model used for scoring.
    queries : tensor of (s, r, o) target triples.
    valid_filters : per-side, per-(nodetype, relation) boolean candidate masks.
    device, batch_size : scoring device and batch size, forwarded to get_ranking.
    entityid_to_nodetype, exists_edge : candidate filtering info for get_ranking.
    eval_type : label used in the log (defaults to 'after attack').
    attack_data : optional per-target attack edges; targets whose attack was
        unusable fall back to the supplied original ranks.
    ori_ranks, ori_totals : pre-attack ranks / candidate counts (used only
        when attack_data is given).

    Returns
    -------
    (results dict, list of ranks, list of candidate-set sizes)
    """
    # Rank every query among its (filtered) candidate set.
    ranks, total_nums = get_ranking(model, queries, valid_filters, device, batch_size, entityid_to_nodetype, exists_edge)
    ranks, total_nums = np.array(ranks), np.array(total_nums)

    # get_ranking counts from the low-score end; convert so that a *small*
    # value means a good rank among `total_nums` candidates.
    ranks = total_nums - ranks

    if (attack_data is not None):
        # Targets for which no usable adversarial edge exists keep their
        # pre-attack rank ("-1" sentinel triples / empty edge lists).
        for i, tri in enumerate(attack_data):
            if args.mode == '':
                if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                    if int(tri[0]) == -1:
                        ranks[i] = ori_ranks[i]
                        total_nums[i] = ori_totals[i]
                else:
                    if int(tri[0][0]) == -1:
                        ranks[i] = ori_ranks[i]
                        total_nums[i] = ori_totals[i]
            else:
                if len(tri) == 0:
                    ranks[i] = ori_ranks[i]
                    total_nums[i] = ori_totals[i]

    # Proportion = rank normalised by candidate-set size.
    mean = (ranks / total_nums).mean()
    std = (ranks / total_nums).std()
    hits_at = np.arange(1,11)
    hits_at_both = list(map(lambda x: np.mean((ranks <= x), dtype=np.float64).item(),
                            hits_at))
    mr = np.mean(ranks, dtype=np.float64).item()
    mrr = np.mean(1. / ranks, dtype=np.float64).item()

    logger.info('')
    logger.info('-'*50)
    logger.info('')
    if eval_type:
        logger.info(eval_type)
    else:
        logger.info('after attack')  # BUGFIX: message was misspelled 'after attck'

    for i in hits_at:
        logger.info('Hits @{0}: {1}'.format(i, hits_at_both[i-1]))
    logger.info('Mean rank: {0}'.format( mr))
    logger.info('Mean reciprocal rank lhs: {0}'.format(mrr))
    logger.info('Mean proportion: {0}'.format(mean))
    logger.info('Std proportion: {0}'.format(std))
    logger.info('Mean candidate num: {0}'.format(np.mean(total_nums)))

    results = {}
    for i in hits_at:
        results['hits @{}'.format(i)] = hits_at_both[i-1]
    results['mrr'] = mrr
    results['mr'] = mr
    results['proportion'] = mean
    results['std'] = std

    return results, list(ranks), list(total_nums)
|
199 |
+
|
200 |
+
|
201 |
+
# ---------------------------------------------------------------------------
# Script setup: CLI arguments, device, RNG seeding and output paths.
# ---------------------------------------------------------------------------
parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser = utils.add_eval_parameters(parser)
args = parser.parse_args()
args = utils.set_hyperparams(args)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

utils.seed_all(args.seed)
np.set_printoptions(precision=5)
cudnn.benchmark = False

# Where the processed KG and the disease-specific target triples live.
data_path = os.path.join('processed_data', args.data)
target_path = os.path.join(data_path, 'DD_target_{0}_{1}_{2}_{3}_{4}_{5}.txt'.format(args.model, args.data, args.target_split, args.target_size, 'exists:'+str(args.target_existed), args.attack_goal))

# Log-file / result-record paths encode the full experiment configuration so
# different attack settings never overwrite each other.
log_path = 'logs/evaluation_logs/cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}'.format(args.model,
                        args.target_split,
                        args.target_size,
                        'exists:'+str(args.target_existed),
                        args.neighbor_num,
                        args.candidate_mode,
                        args.attack_goal,
                        str(args.reasonable_rate),
                        args.mode)
record_path = 'eval_record/{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}{9}{10}'.format(args.model,
                        args.target_split,
                        args.target_size,
                        'exists:'+str(args.target_existed),
                        args.neighbor_num,
                        args.candidate_mode,
                        args.attack_goal,
                        str(args.reasonable_rate),
                        args.mode,
                        str(args.added_edge_num),
                        args.mask_ratio)
# Record of the pre-attack ("init") evaluation, shared across attack variants.
init_record_path = 'eval_record/{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}'.format(args.model,
                        args.target_split,
                        args.target_size,
                        'exists:'+str(args.target_existed),
                        args.neighbor_num,
                        args.candidate_mode,
                        args.attack_goal,
                        str(args.reasonable_rate),
                        'init')

# NOTE(review): log_path receives no '_batch' suffix in the non-seperate
# branch (only record_path does) — presumably intentional; confirm the log
# files of the two modes do not collide.
if args.seperate:
    record_path += '_seperate'
    log_path += '_seperate'
else:
    record_path += '_batch'

if args.direct:
    log_path += '_direct'
    record_path += '_direct'
else:
    log_path += '_nodirect'
    record_path += '_nodirect'

# Directory where temporary "disturbed" (poisoned) KG dumps are written.
dis_turbrbed_path_pre = os.path.join(data_path, 'evaluation')
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO,
                    filename = log_path
                    )
logger = logging.getLogger(__name__)
logger.info(vars(args))
|
267 |
+
|
268 |
+
# Load the victim KGE model trained on the clean KG.
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = 'saved_models/{0}_{1}.model'.format(args.data, model_name)
model = utils.load_model(model_path, args, n_ent, n_rel, device)

ori_data = utils.load_data(os.path.join(data_path, 'all.txt'))  # clean KG triples
target_data = utils.load_data(target_path)                      # attacked target triples

# Shuffle the targets once; `index` is reused below so the attack edges stay
# aligned with their targets after the permutation.
index = range(len(target_data))
index = np.random.permutation(index)
target_data = target_data[index]

if args.direct:
    assert args.attack_goal == 'single'
    raise Exception('This option is abandoned in this version .')
    # disturbed_data = list(ori_data) + list(target_data)
else:
    # Adversarial edges produced by the attack stage for this configuration.
    attack_path = os.path.join('attack_results', args.data, 'cos_{0}_{1}_{2}_{3}_{4}_{5}_{6}_{7}{8}{9}{10}.txt'.format(args.model,
                        args.target_split,
                        args.target_size,
                        'exists:'+str(args.target_existed),
                        args.neighbor_num,
                        args.candidate_mode,
                        args.attack_goal,
                        str(args.reasonable_rate),
                        args.mode,
                        str(args.added_edge_num),
                        args.mask_ratio))
    if args.mode == '':
        # Plain triple file: one (or `added_edge_num`) attack edges per target.
        attack_data = utils.load_data(attack_path, drop=False)
        if not(args.added_edge_num == '' or int(args.added_edge_num) == 1):
            assert int(args.added_edge_num) * len(target_data) == len(attack_data)
            attack_data = attack_data.reshape((len(target_data), int(args.added_edge_num), 3))
            attack_data = attack_data[index]
        else:
            assert len(target_data) == len(attack_data)
            attack_data = attack_data[index]
    else:
        # 'sentence' / 'bioBART' modes: pickled per-target edge lists, kept
        # only if they pass the plausibility check.
        with open(attack_path, 'rb') as fl:
            attack_data = pkl.load(fl)

        tmp_attack_data = []
        for vv in attack_data:
            a_attack = []
            for v in vv:
                if check(v, model, args.reasonable_rate, device):
                    a_attack.append(v)
            tmp_attack_data.append(a_attack)
        attack_data = tmp_attack_data
        attack_data = [attack_data[i] for i in index]  # same permutation as targets
|
321 |
+
|
322 |
+
# if not args.seperate:
|
323 |
+
# disturbed_data = list(ori_data)
|
324 |
+
# if args.mode == '':
|
325 |
+
# for aa in list(attack_data):
|
326 |
+
# if int(aa[0]) != -1:
|
327 |
+
# disturbed_data.append(aa)
|
328 |
+
# else:
|
329 |
+
# for vv in attack_data:
|
330 |
+
# for v in vv:
|
331 |
+
# disturbed_data.append(v)
|
332 |
+
|
333 |
+
# Auxiliary data: candidate filters, node types and entity metadata.
with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    valid_filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)

# The 'global' attack goal is handled by the pagerank pipeline, not here.
# (The old inline implementation was removed as dead commented-out code.)
if args.attack_goal == 'global':
    raise Exception('Please refer to pagerank method in global setting.')

# Convert each filter's entity-id list into a boolean mask tensor over all
# entities, so ranking can select candidate sets on the device.
init_mask = np.asarray([0] * n_ent).astype('int64')
init_mask = (init_mask == 1)
for k, v in valid_filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        tmp[np.asarray(vv)] = True
        t = torch.ByteTensor(tmp).to(device)
        valid_filters[k][kk] = t

# Adjacency lists of the clean KG, used to exclude already-existing edges
# from the candidate sets.
exists_edge = {'lhs':{}, 'rhs':{}}
for s, r, o in ori_data:
    if s not in exists_edge['rhs'].keys():
        exists_edge['rhs'][s] = []
    if o not in exists_edge['lhs'].keys():
        exists_edge['lhs'][o] = []
    exists_edge['rhs'][s].append(int(o))
    exists_edge['lhs'][o].append(int(s))
target_data = torch.from_numpy(target_data.astype('int64')).to(device)
# Baseline: rank the targets under the clean (pre-attack) model and persist
# the result so per-target fallbacks can reuse it.
ori_results, ori_ranks, ori_totals = evaluation(model, target_data, valid_filters, device, args.test_batch_size, entityid_to_nodetype, exists_edge, 'original')
print('Original:', ori_results)
with open(init_record_path, 'wb') as fl:
    pkl.dump([ori_results, ori_ranks, ori_totals], fl)
|
394 |
+
|
395 |
+
# raise Exception('Check Original Rank!!!')
|
396 |
+
|
397 |
+
# Build a unique run name so concurrent evaluation runs do not clash on the
# retraining artefacts (disturbed-data file, retrained model checkpoint).
thread_name = args.model+'_'+args.target_split+'_'+args.attack_goal+'_'+str(args.reasonable_rate)+str(args.added_edge_num)+str(args.mask_ratio)
if args.direct:
    thread_name += '_direct'
else:
    thread_name += '_nodirect'
if args.seperate:
    thread_name += '_seperate'
else:
    thread_name += '_batch'
thread_name += args.mode

disturbed_data_path = os.path.join(dis_turbrbed_path_pre, 'all_{}.txt'.format(thread_name))

if args.seperate:
    # Seperate mode: retrain and evaluate once per target triple.
    assert len(attack_data) == len(target_data)
    # BUGFIX: final_result must be initialised before the loop — the first
    # iteration reads it (`if not final_result` / `final_result['proportion']`);
    # the original initialisation had been commented out, causing a NameError.
    final_result = None
    Ranks = []
    Totals = []
    print('Training model {}...'.format(thread_name))
    for i in tqdm(range(len(attack_data))):
        attack_trip = attack_data[i]
        if args.mode == '':
            attack_trip = [attack_trip]
        target = target_data[i: i+1, :]
        if len(attack_trip) > 0 and int(attack_trip[0][0]) != -1:
            # Poisoned KG = clean triples + the adversarial additions.
            disturbed_data = list(ori_data) + attack_trip
            disturbed_data = np.array(disturbed_data)
            utils.save_data(disturbed_data_path, disturbed_data)

            # Retrain a KGE model on the poisoned KG in a subprocess.
            cmd = 'CUDA_VISIBLE_DEVICES={} python main_multiprocess.py --data {} --model {} --thread-name {}'.format(args.cuda_name,args.data, args.model, thread_name)
            os.system(cmd)
            model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop, thread_name)
            model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, model_name)
            model = utils.load_model(model_path, args, n_ent, n_rel, device)
            a_results, a_ranks, a_total_nums = evaluation(model, target, valid_filters, device, args.test_batch_size, entityid_to_nodetype, exists_edge)
            assert len(a_ranks) == 1
            if not final_result:
                final_result = a_results
            else:
                for k in final_result.keys():
                    final_result[k] += a_results[k]
            Ranks += a_ranks
            Totals += a_total_nums
        else:
            # No usable attack edge for this target: fall back to the
            # pre-attack rank. NOTE(review): this branch assumes final_result
            # is already a dict, i.e. an attacked target was seen first.
            Ranks += [ori_ranks[i]]
            Totals += [ori_totals[i]]
            final_result['proportion'] += ori_ranks[i] / ori_totals[i]
    for k in final_result.keys():
        # BUGFIX: use len() so this also works when attack_data is a plain
        # Python list (mode != ''), not only a numpy array with .shape.
        final_result[k] /= len(attack_data)
    print('Final !!!')
    print(final_result)
    logger.info('Final !!!!')
    for k, v in final_result.items():
        logger.info('{} : {}'.format(k, v))
    tmp = np.array(Ranks) / np.array(Totals)
    print('Std:', np.std(tmp))
    with open(record_path, 'wb') as fl:
        pkl.dump([final_result, Ranks, Totals], fl)

else:
    # Batch mode: retrain once per chunk of 50 targets with all their attack
    # edges injected together.
    assert len(target_data) == len(attack_data)
    print('Attack shape:' , len(attack_data))
    Results = []
    Ranks = []
    Totals = []
    for l in range(0, len(target_data), 50):
        r = min(l+50, len(target_data))
        t_target_data = target_data[l:r]
        t_attack_data = attack_data[l:r]
        t_ori_ranks = ori_ranks[l:r]
        t_ori_totals = ori_totals[l:r]
        if args.mode == '':
            if not(args.added_edge_num == '' or int(args.added_edge_num) == 1):
                # Flatten (chunk, edges-per-target, 3) into a flat triple list.
                tt_attack_data = []
                for vv in t_attack_data:
                    tt_attack_data += list(vv)
                t_attack_data = tt_attack_data
        else:
            assert args.mode == 'sentence' or args.mode == 'bioBART'
            tt_attack_data = []
            for vv in t_attack_data:
                tt_attack_data += vv
            t_attack_data = tt_attack_data
        disturbed_data = list(ori_data) + list(t_attack_data)

        utils.save_data(disturbed_data_path, disturbed_data)
        cmd = 'CUDA_VISIBLE_DEVICES={} python main_multiprocess.py --data {} --model {} --thread-name {}'.format(args.cuda_name,args.data, args.model, thread_name)
        print('Training model {}...'.format(thread_name))
        os.system(cmd)
        model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop, thread_name)
        model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, model_name)
        model = utils.load_model(model_path, args, n_ent, n_rel, device)
        a_results, a_ranks, a_totals = evaluation(model, t_target_data, valid_filters, device, args.test_batch_size, entityid_to_nodetype, exists_edge, attack_data = attack_data[l:r], ori_ranks = t_ori_ranks, ori_totals = t_ori_totals)
        print(f'************Current l: {l}\n', a_results)
        assert len(a_ranks) == t_target_data.shape[0]
        Results += [a_results]
        Ranks += list(a_ranks)
        Totals += list(a_totals)
        # Checkpoint after every chunk so a crash loses at most one chunk.
        with open(record_path, 'wb') as fl:
            pkl.dump([Results, Ranks, Totals, index], fl)
|
DiseaseSpecific/main.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#%%
|
2 |
+
import pickle as pkl
|
3 |
+
from typing import Dict, Tuple, List
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import json
|
7 |
+
import logging
|
8 |
+
import argparse
|
9 |
+
import math
|
10 |
+
from pprint import pprint
|
11 |
+
import pandas as pd
|
12 |
+
from collections import defaultdict
|
13 |
+
import copy
|
14 |
+
import time
|
15 |
+
from tqdm import tqdm
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
import torch.backends.cudnn as cudnn
|
20 |
+
import torch.autograd as autograd
|
21 |
+
|
22 |
+
from model import Distmult, Complex, Conve
|
23 |
+
import utils
|
24 |
+
|
25 |
+
# from evaluation import evaluation
|
26 |
+
|
27 |
+
#%%
|
28 |
+
class Main(object):
    """Train a KGE model (DistMult / ComplEx / ConvE) on the processed KG.

    Handles data loading, training with optional N3/Lp regularisation, model
    checkpointing, and (optionally) recording per-triple influence values for
    Gradient Rollback.
    """

    def __init__(self, args):
        self.args = args

        self.model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
        # Batch size / epochs are left out of the model name since they do not
        # change the model architecture; kernel size and filters are fixed.
        self.model_path = 'saved_models/{0}_{1}.model'.format(args.data, self.model_name)

        self.log_path = 'logs/{0}_{1}_{2}_{3}.log'.format(args.data, self.model_name, args.epochs, args.train_batch_size)
        self.loss_path = 'losses/{0}_{1}_{2}_{3}.pickle'.format(args.data, self.model_name, args.epochs, args.train_batch_size)

        logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                            datefmt = '%m/%d/%Y %H:%M:%S',
                            level = logging.INFO,
                            filename = self.log_path)
        self.logger = logging.getLogger(__name__)
        self.logger.info(vars(self.args))
        self.logger.info('\n')

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.load_data()
        self.model = self.add_model()
        self.optimizer = self.add_optimizer(self.model.parameters())

        if self.args.save_influence_map:
            self.logger.info('-------- Argument save_influence_map is set. Will use GR to compute and save influence maps ----------\n')
            # When saving influence during training, keep things simple by
            # disabling reciprocal relations.
            self.args.add_reciprocals = False
            # Empty influence map; keys are "<s>_<r>_<o>_{s|r|o}".
            self.influence_map = defaultdict(float)
            self.influence_path = 'influence_maps/{0}_{1}.pickle'.format(args.data, self.model_name)
            # Copy of the model params to track per-batch weight deltas.
            self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]
            self.logger.info('Shape for previous weights: {}, {}'.format(self.previous_weights[0].shape, self.previous_weights[1].shape))

    def load_data(self):
        '''
        Load the train, valid datasets
        '''
        data_path = os.path.join('processed_data', self.args.data)
        n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
        self.n_ent = n_ent
        self.n_rel = n_rel

        self.train_data = utils.load_data(os.path.join(data_path, 'all.txt'))
        # Validation set is a random subsample of the training triples.
        tmp = np.random.choice(a = self.train_data.shape[0], size = int(self.train_data.shape[0] * self.args.KG_valid_rate), replace=False)
        self.valid_data= self.train_data[tmp, :]

    def add_model(self):
        """Instantiate the KGE model selected by args.model (default: DistMult)."""
        if self.args.model is None:
            model = Distmult(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'distmult':
            model = Distmult(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'complex':
            model = Complex(self.args, self.n_ent, self.n_rel)
        elif self.args.model == 'conve':
            model = Conve(self.args, self.n_ent, self.n_rel)
        else:
            # BUGFIX: logging uses lazy %-style formatting; the original
            # '{0}'-style placeholder was never interpolated.
            self.logger.info('Unknown model: %s', self.args.model)
            raise Exception("Unknown model!")
        model.to(self.device)
        return model

    def add_optimizer(self, parameters):
        """Adam optimizer; lr_decay is applied as weight decay."""
        return torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)

    def save_model(self):
        """Checkpoint model + optimizer state (and args) to self.model_path."""
        state = {
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'args': vars(self.args)
        }
        torch.save(state, self.model_path)
        self.logger.info('Saving model to {0}'.format(self.model_path))

    def load_model(self):
        """Restore model + optimizer state from self.model_path."""
        self.logger.info('Loading saved model from {0}'.format(self.model_path))
        state = torch.load(self.model_path)
        model_params = state['state_dict']
        params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
        for key, size, count in params:
            # BUGFIX: the original `self.logger.info(key, size, count)` treated
            # `key` as a %-format string with extra args, which raises a
            # formatting error inside logging; log explicitly instead.
            self.logger.info('%s %s %s', key, size, count)
        self.model.load_state_dict(model_params)
        self.optimizer.load_state_dict(state['optimizer'])

    def lp_regularizer(self):
        """Lp-norm regulariser over entity and relation embedding tables."""
        weight = self.args.reg_weight
        p = self.args.reg_norm

        trainable_params = [self.model.emb_e.weight, self.model.emb_rel.weight]
        norm = 0
        for i in range(len(trainable_params)):
            norm += weight * torch.sum( torch.abs(trainable_params[i]) ** p)

        return norm

    def n3_regularizer(self, factors):
        """N3-style regulariser over the batch embeddings (lhs, rel, rhs)."""
        weight = self.args.reg_weight
        p = self.args.reg_norm

        norm = 0
        for f in factors:
            norm += weight * torch.sum(torch.abs(f) ** p)

        return norm / factors[0].shape[0]  # scale by number of triples in batch

    def get_influence_map(self):
        """
        Turns the influence map into a list, ready to be written to disc. (before: numpy)
        :return: the influence map with lists as values
        """
        assert self.args.save_influence_map == True

        for key in self.influence_map:
            self.influence_map[key] = self.influence_map[key].tolist()
        return self.influence_map

    def evaluate(self, split, batch_size, epoch):
        """Compute the (regularised) loss on the validation subsample.

        Mirrors run_epoch() but without gradient updates. `split` and the
        `batch_size` parameter are currently unused (args.valid_batch_size
        is used instead).
        """
        self.model.eval()
        losses = []

        with torch.no_grad():
            input_data = torch.from_numpy(self.valid_data.astype('int64'))
            actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
            del input_data

            batch_size = self.args.valid_batch_size
            for b_begin in tqdm(range(0, actual_examples.shape[0], batch_size)):

                input_batch = actual_examples[b_begin: b_begin + batch_size]
                input_batch = input_batch.to(self.device)

                s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]

                emb_s = self.model.emb_e(s).squeeze(dim=1)
                emb_r = self.model.emb_rel(r).squeeze(dim=1)
                emb_o = self.model.emb_e(o).squeeze(dim=1)

                if self.args.add_reciprocals:
                    r_rev = r + self.n_rel
                    emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
                else:
                    r_rev = r
                    emb_rrev = emb_r

                # Score both directions: predict object given (s, r) and
                # subject given (o, r[_rev]).
                pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
                loss_sr = self.model.loss(pred_sr, o)  # cross entropy loss

                pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
                loss_or = self.model.loss(pred_or, s)

                total_loss = loss_sr + loss_or

                if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
                    if self.args.model == 'complex':
                        # ComplEx stores (real, imaginary) halves; regularise
                        # their moduli.
                        emb_dim = self.args.embedding_dim
                        lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
                        rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
                        rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
                        rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

                        factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                    torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                                    torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
                                    )
                        factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                    torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                                    torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
                                    )
                    else:
                        factors_sr = (emb_s, emb_r, emb_o)
                        factors_or = (emb_s, emb_rrev, emb_o)

                    total_loss += self.n3_regularizer(factors_sr)
                    total_loss += self.n3_regularizer(factors_or)

                if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
                    total_loss += self.lp_regularizer()

                losses.append(total_loss.item())

        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Validating Loss:{:.6}\n'.format(epoch, loss))
        return loss

    def run_epoch(self, epoch):
        """One training epoch over the shuffled training triples; returns mean loss."""
        self.model.train()
        losses = []

        # Shuffle the train dataset.
        input_data = torch.from_numpy(self.train_data.astype('int64'))
        actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
        del input_data

        batch_size = self.args.train_batch_size

        for b_begin in tqdm(range(0, actual_examples.shape[0], batch_size)):
            self.optimizer.zero_grad()
            input_batch = actual_examples[b_begin: b_begin + batch_size]
            input_batch = input_batch.to(self.device)

            s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]

            emb_s = self.model.emb_e(s).squeeze(dim=1)
            emb_r = self.model.emb_rel(r).squeeze(dim=1)
            emb_o = self.model.emb_e(o).squeeze(dim=1)

            if self.args.add_reciprocals:
                r_rev = r + self.n_rel
                emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
            else:
                r_rev = r
                emb_rrev = emb_r

            pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
            loss_sr = self.model.loss(pred_sr, o)  # loss is cross entropy loss

            pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
            loss_or = self.model.loss(pred_or, s)

            total_loss = loss_sr + loss_or

            if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
                if self.args.model == 'complex':
                    emb_dim = self.args.embedding_dim
                    lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
                    rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
                    rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
                    rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])

                    factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
                                torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
                    factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
                                torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
                                torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
                else:
                    factors_sr = (emb_s, emb_r, emb_o)
                    factors_or = (emb_s, emb_rrev, emb_o)

                total_loss += self.n3_regularizer(factors_sr)
                total_loss += self.n3_regularizer(factors_or)

            if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
                total_loss += self.lp_regularizer()

            total_loss.backward()
            self.optimizer.step()
            losses.append(total_loss.item())

            if self.args.save_influence_map:  # for gradient rollback
                with torch.no_grad():
                    prev_emb_e = self.previous_weights[0]
                    prev_emb_rel = self.previous_weights[1]
                    # Influence value is computed per-triple: the embedding
                    # delta caused by this optimizer step.
                    for idx in range(input_batch.shape[0]):
                        head, rel, tail = s[idx], r[idx], o[idx]
                        inf_head = (emb_s[idx] - prev_emb_e[head]).cpu().detach().numpy()
                        inf_tail = (emb_o[idx] - prev_emb_e[tail]).cpu().detach().numpy()
                        inf_rel = (emb_r[idx] - prev_emb_rel[rel]).cpu().detach().numpy()

                        # Write the influences keyed by triple and position.
                        key_trip = '{0}_{1}_{2}'.format(head.item(), rel.item(), tail.item())
                        key = '{0}_s'.format(key_trip)
                        self.influence_map[key] += inf_head
                        key = '{0}_r'.format(key_trip)
                        self.influence_map[key] += inf_rel
                        key = '{0}_o'.format(key_trip)
                        self.influence_map[key] += inf_tail

                    # Update the previous weights to be tracked.
                    self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]

            if (b_begin%5000 == 0) or (b_begin== (actual_examples.shape[0]-1)):
                self.logger.info('[E:{} | {}]: Train Loss:{:.6}'.format(epoch, b_begin, np.mean(losses)))

        loss = np.mean(losses)
        self.logger.info('[Epoch:{}]: Training Loss:{:.6}\n'.format(epoch, loss))
        return loss

    def fit(self):
        """Full training loop: checkpoints on best train loss, saves losses
        and (optionally) the influence map."""
        self.model.init()
        self.logger.info(self.model)

        self.logger.info('------ Start the model training ------')
        start_time = time.time()
        self.logger.info('Start time: {0}'.format(str(start_time)))

        train_losses = []
        valid_losses = []  # NOTE(review): never filled — evaluate() is not called here.
        best_val = 10000000000.
        for epoch in range(self.args.epochs):

            print("="*15,'epoch:',epoch,'='*15)
            train_loss = self.run_epoch(epoch)
            train_losses.append(train_loss)

            # Checkpoint on best *training* loss (no validation pass here).
            if train_loss < best_val:
                best_val = train_loss
                self.save_model()
            print("Train loss: {0}, Best loss: {1}\n\n".format(train_loss, best_val))

        with open(self.loss_path, "wb") as fl:
            pkl.dump({"train loss":train_losses, "valid loss":valid_losses}, fl)
        self.logger.info('Time taken to train the model: {0}'.format(str(time.time() - start_time)))
        start_time = time.time()

        if self.args.save_influence_map:  # save the influence map
            with open(self.influence_path, "wb") as fl:
                pkl.dump(self.get_influence_map(), fl)
            self.logger.info('Finished saving influence map')
            self.logger.info('Time taken to save the influence map: {0}'.format(str(time.time() - start_time)))
|
365 |
+
|
366 |
+
#%%
|
367 |
+
parser = utils.get_argument_parser()
|
368 |
+
|
369 |
+
args = parser.parse_args()
|
370 |
+
args = utils.set_hyperparams(args)
|
371 |
+
|
372 |
+
utils.seed_all(args.seed)
|
373 |
+
np.set_printoptions(precision=5)
|
374 |
+
cudnn.benchmark = False
|
375 |
+
|
376 |
+
model = Main(args)
|
377 |
+
model.fit()
|
DiseaseSpecific/main_multiprocess.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Multiprocess for main.py"""
|
2 |
+
#%%
|
3 |
+
import pickle as pkl
|
4 |
+
from typing import Dict, Tuple, List
|
5 |
+
import os
|
6 |
+
import numpy as np
|
7 |
+
import json
|
8 |
+
import logging
|
9 |
+
import argparse
|
10 |
+
import math
|
11 |
+
from pprint import pprint
|
12 |
+
import pandas as pd
|
13 |
+
from collections import defaultdict
|
14 |
+
import copy
|
15 |
+
import time
|
16 |
+
|
17 |
+
import torch
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
import torch.backends.cudnn as cudnn
|
20 |
+
import torch.autograd as autograd
|
21 |
+
|
22 |
+
from model import Distmult, Complex, Conve
|
23 |
+
import utils
|
24 |
+
|
25 |
+
# from evaluation import evaluation
|
26 |
+
|
27 |
+
#%%
|
28 |
+
class Main(object):
|
29 |
+
def __init__(self, args):
|
30 |
+
self.args = args
|
31 |
+
|
32 |
+
self.model_name = '{0}_{1}_{2}_{3}_{4}_{5}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop, args.thread_name)
|
33 |
+
#leaving batches from the model_name since they do not depend on model_architecture
|
34 |
+
# also leaving kernel size and filters, siinice don't intend to change those
|
35 |
+
self.model_path = 'saved_models/evaluation/{0}_{1}.model'.format(args.data, self.model_name)
|
36 |
+
|
37 |
+
self.log_path = 'logs/evaluation_logs/{0}_{1}_{2}_{3}_{4}.log'.format(args.data, self.model_name, args.epochs, args.train_batch_size, args.thread_name)
|
38 |
+
self.loss_path = 'losses/evaluation_losses/{0}_{1}_{2}_{3}_{4}.pickle'.format(args.data, self.model_name, args.epochs, args.train_batch_size, args.thread_name)
|
39 |
+
|
40 |
+
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
41 |
+
datefmt = '%m/%d/%Y %H:%M:%S',
|
42 |
+
level = logging.INFO,
|
43 |
+
filename = self.log_path)
|
44 |
+
self.logger = logging.getLogger(__name__)
|
45 |
+
self.logger.info(vars(self.args))
|
46 |
+
self.logger.info('\n')
|
47 |
+
|
48 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
+
|
50 |
+
self.load_data()
|
51 |
+
self.model = self.add_model()
|
52 |
+
self.optimizer = self.add_optimizer(self.model.parameters())
|
53 |
+
|
54 |
+
if self.args.save_influence_map:
|
55 |
+
self.logger.info('-------- Argument save_influence_map is set. Will use GR to compute and save influence maps ----------\n')
|
56 |
+
# when we want to save influence during training
|
57 |
+
self.args.add_reciprocals = False # to keep things simple
|
58 |
+
# init an empty influence map
|
59 |
+
self.influence_map = defaultdict(float)
|
60 |
+
#self.influence_path = 'influence_maps/{0}_{1}.json'.format(args.data, self.model_name)
|
61 |
+
self.influence_path = 'influence_maps/{0}_{1}.pickle'.format(args.data, self.model_name)
|
62 |
+
# Initialize a copy of the model prams to track previous weights in an epoch
|
63 |
+
self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]
|
64 |
+
self.logger.info('Shape for previous weights: {}, {}'.format(self.previous_weights[0].shape, self.previous_weights[1].shape))
|
65 |
+
|
66 |
+
def load_data(self):
|
67 |
+
'''
|
68 |
+
Load the train, valid datasets
|
69 |
+
'''
|
70 |
+
data_path = os.path.join('processed_data', self.args.data)
|
71 |
+
|
72 |
+
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
|
73 |
+
self.n_ent = n_ent
|
74 |
+
self.n_rel = n_rel
|
75 |
+
|
76 |
+
self.train_data = utils.load_data(os.path.join(data_path, 'evaluation', 'all_{}.txt'.format(self.args.thread_name)))
|
77 |
+
|
78 |
+
self.valid_data= self.train_data[-100:, :].copy()
|
79 |
+
# self.train_data = utils.load_data()
|
80 |
+
|
81 |
+
def add_model(self):
|
82 |
+
|
83 |
+
if self.args.model is None:
|
84 |
+
model = Distmult(self.args, self.n_ent, self.n_rel)
|
85 |
+
elif self.args.model == 'distmult':
|
86 |
+
model = Distmult(self.args, self.n_ent, self.n_rel)
|
87 |
+
elif self.args.model == 'complex':
|
88 |
+
model = Complex(self.args, self.n_ent, self.n_rel)
|
89 |
+
elif self.args.model == 'conve':
|
90 |
+
model = Conve(self.args, self.n_ent, self.n_rel)
|
91 |
+
else:
|
92 |
+
self.logger.info('Unknown model: {0}', self.args.model)
|
93 |
+
raise Exception("Unknown model!")
|
94 |
+
model.to(self.device)
|
95 |
+
return model
|
96 |
+
|
97 |
+
def add_optimizer(self, parameters):
|
98 |
+
#if self.args.optimizer == 'adam' : return torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)
|
99 |
+
#else : return torch.optim.SGD(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)
|
100 |
+
return torch.optim.Adam(parameters, lr=self.args.lr, weight_decay=self.args.lr_decay)
|
101 |
+
|
102 |
+
def save_model(self):
|
103 |
+
state = {
|
104 |
+
'state_dict': self.model.state_dict(),
|
105 |
+
'optimizer': self.optimizer.state_dict(),
|
106 |
+
'args': vars(self.args)
|
107 |
+
}
|
108 |
+
torch.save(state, self.model_path)
|
109 |
+
self.logger.info('Saving model to {0}'.format(self.model_path))
|
110 |
+
|
111 |
+
def load_model(self):
|
112 |
+
self.logger.info('Loading saved model from {0}'.format(self.model_path))
|
113 |
+
state = torch.load(self.model_path)
|
114 |
+
model_params = state['state_dict']
|
115 |
+
params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
|
116 |
+
for key, size, count in params:
|
117 |
+
self.logger.info(key, size, count)
|
118 |
+
self.model.load_state_dict(model_params)
|
119 |
+
self.optimizer.load_state_dict(state['optimizer'])
|
120 |
+
|
121 |
+
def lp_regularizer(self):
|
122 |
+
# Apply p-norm regularization; assign weights to each param
|
123 |
+
weight = self.args.reg_weight
|
124 |
+
p = self.args.reg_norm
|
125 |
+
|
126 |
+
trainable_params = [self.model.emb_e.weight, self.model.emb_rel.weight]
|
127 |
+
norm = 0
|
128 |
+
for i in range(len(trainable_params)):
|
129 |
+
#norm += weight * trainable_params[i].norm(p = p)**p
|
130 |
+
norm += weight * torch.sum( torch.abs(trainable_params[i]) ** p)
|
131 |
+
|
132 |
+
return norm
|
133 |
+
|
134 |
+
def n3_regularizer(self, factors):
|
135 |
+
# factors are the embeddings for lhs, rel, rhs for triples in a batch
|
136 |
+
weight = self.args.reg_weight
|
137 |
+
p = self.args.reg_norm
|
138 |
+
|
139 |
+
norm = 0
|
140 |
+
for f in factors:
|
141 |
+
norm += weight * torch.sum(torch.abs(f) ** p)
|
142 |
+
|
143 |
+
return norm / factors[0].shape[0] # scale by number of triples in batch
|
144 |
+
|
145 |
+
def get_influence_map(self):
|
146 |
+
"""
|
147 |
+
Turns the influence map into a list, ready to be written to disc. (before: numpy)
|
148 |
+
:return: the influence map with lists as values
|
149 |
+
"""
|
150 |
+
assert self.args.save_influence_map == True
|
151 |
+
|
152 |
+
for key in self.influence_map:
|
153 |
+
self.influence_map[key] = self.influence_map[key].tolist()
|
154 |
+
#self.logger.info('get_influence_map passed')
|
155 |
+
return self.influence_map
|
156 |
+
|
157 |
+
def evaluate(self, split, batch_size, epoch):
|
158 |
+
"""
|
159 |
+
The same as self.run_epoch()
|
160 |
+
"""
|
161 |
+
|
162 |
+
self.model.eval()
|
163 |
+
losses = []
|
164 |
+
|
165 |
+
with torch.no_grad():
|
166 |
+
input_data = torch.from_numpy(self.valid_data.astype('int64'))
|
167 |
+
actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
|
168 |
+
del input_data
|
169 |
+
|
170 |
+
batch_size = self.args.valid_batch_size
|
171 |
+
for b_begin in range(0, actual_examples.shape[0], batch_size):
|
172 |
+
|
173 |
+
input_batch = actual_examples[b_begin: b_begin + batch_size]
|
174 |
+
input_batch = input_batch.to(self.device)
|
175 |
+
|
176 |
+
s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]
|
177 |
+
|
178 |
+
emb_s = self.model.emb_e(s).squeeze(dim=1)
|
179 |
+
emb_r = self.model.emb_rel(r).squeeze(dim=1)
|
180 |
+
emb_o = self.model.emb_e(o).squeeze(dim=1)
|
181 |
+
|
182 |
+
if self.args.add_reciprocals:
|
183 |
+
r_rev = r + self.n_rel
|
184 |
+
emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
|
185 |
+
else:
|
186 |
+
r_rev = r
|
187 |
+
emb_rrev = emb_r
|
188 |
+
|
189 |
+
pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
|
190 |
+
loss_sr = self.model.loss(pred_sr, o) # cross entropy loss
|
191 |
+
|
192 |
+
pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
|
193 |
+
loss_or = self.model.loss(pred_or, s)
|
194 |
+
|
195 |
+
total_loss = loss_sr + loss_or
|
196 |
+
|
197 |
+
if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
|
198 |
+
#self.logger.info('Computing regularizer weight')
|
199 |
+
if self.args.model == 'complex':
|
200 |
+
emb_dim = self.args.embedding_dim #int(self.args.embedding_dim/2)
|
201 |
+
lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
|
202 |
+
rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
|
203 |
+
rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
|
204 |
+
rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])
|
205 |
+
|
206 |
+
#print(lhs[0].shape, lhs[1].shape)
|
207 |
+
factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
|
208 |
+
torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
|
209 |
+
torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
|
210 |
+
)
|
211 |
+
factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
|
212 |
+
torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
|
213 |
+
torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2)
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
factors_sr = (emb_s, emb_r, emb_o)
|
217 |
+
factors_or = (emb_s, emb_rrev, emb_o)
|
218 |
+
|
219 |
+
total_loss += self.n3_regularizer(factors_sr)
|
220 |
+
total_loss += self.n3_regularizer(factors_or)
|
221 |
+
|
222 |
+
if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
|
223 |
+
total_loss += self.lp_regularizer()
|
224 |
+
|
225 |
+
losses.append(total_loss.item())
|
226 |
+
|
227 |
+
loss = np.mean(losses)
|
228 |
+
self.logger.info('[Epoch:{}]: Validating Loss:{:.6}\n'.format(epoch, loss))
|
229 |
+
return loss
|
230 |
+
|
231 |
+
|
232 |
+
def run_epoch(self, epoch):
|
233 |
+
self.model.train()
|
234 |
+
losses = []
|
235 |
+
|
236 |
+
#shuffle the train dataset
|
237 |
+
input_data = torch.from_numpy(self.train_data.astype('int64'))
|
238 |
+
actual_examples = input_data[torch.randperm(input_data.shape[0]), :]
|
239 |
+
del input_data
|
240 |
+
|
241 |
+
batch_size = self.args.train_batch_size
|
242 |
+
|
243 |
+
for b_begin in range(0, actual_examples.shape[0], batch_size):
|
244 |
+
self.optimizer.zero_grad()
|
245 |
+
input_batch = actual_examples[b_begin: b_begin + batch_size]
|
246 |
+
input_batch = input_batch.to(self.device)
|
247 |
+
|
248 |
+
s,r,o = input_batch[:,0], input_batch[:,1], input_batch[:,2]
|
249 |
+
|
250 |
+
emb_s = self.model.emb_e(s).squeeze(dim=1)
|
251 |
+
emb_r = self.model.emb_rel(r).squeeze(dim=1)
|
252 |
+
emb_o = self.model.emb_e(o).squeeze(dim=1)
|
253 |
+
|
254 |
+
if self.args.add_reciprocals:
|
255 |
+
r_rev = r + self.n_rel
|
256 |
+
emb_rrev = self.model.emb_rel(r_rev).squeeze(dim=1)
|
257 |
+
else:
|
258 |
+
r_rev = r
|
259 |
+
emb_rrev = emb_r
|
260 |
+
|
261 |
+
pred_sr = self.model.forward(emb_s, emb_r, mode='rhs')
|
262 |
+
loss_sr = self.model.loss(pred_sr, o) # loss is cross entropy loss
|
263 |
+
|
264 |
+
pred_or = self.model.forward(emb_o, emb_rrev, mode='lhs')
|
265 |
+
loss_or = self.model.loss(pred_or, s)
|
266 |
+
|
267 |
+
total_loss = loss_sr + loss_or
|
268 |
+
|
269 |
+
if (self.args.reg_weight != 0.0 and self.args.reg_norm == 3):
|
270 |
+
#self.logger.info('Computing regularizer weight')
|
271 |
+
if self.args.model == 'complex':
|
272 |
+
emb_dim = self.args.embedding_dim #int(self.args.embedding_dim/2)
|
273 |
+
lhs = (emb_s[:, :emb_dim], emb_s[:, emb_dim:])
|
274 |
+
rel = (emb_r[:, :emb_dim], emb_r[:, emb_dim:])
|
275 |
+
rel_rev = (emb_rrev[:, :emb_dim], emb_rrev[:, emb_dim:])
|
276 |
+
rhs = (emb_o[:, :emb_dim], emb_o[:, emb_dim:])
|
277 |
+
|
278 |
+
#print(lhs[0].shape, lhs[1].shape)
|
279 |
+
factors_sr = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
|
280 |
+
torch.sqrt(rel[0] ** 2 + rel[1] ** 2),
|
281 |
+
torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
|
282 |
+
factors_or = (torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
|
283 |
+
torch.sqrt(rel_rev[0] ** 2 + rel_rev[1] ** 2),
|
284 |
+
torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2))
|
285 |
+
else:
|
286 |
+
factors_sr = (emb_s, emb_r, emb_o)
|
287 |
+
factors_or = (emb_s, emb_rrev, emb_o)
|
288 |
+
|
289 |
+
total_loss += self.n3_regularizer(factors_sr)
|
290 |
+
total_loss += self.n3_regularizer(factors_or)
|
291 |
+
|
292 |
+
if (self.args.reg_weight != 0.0 and self.args.reg_norm == 2):
|
293 |
+
total_loss += self.lp_regularizer()
|
294 |
+
|
295 |
+
|
296 |
+
total_loss.backward()
|
297 |
+
self.optimizer.step()
|
298 |
+
losses.append(total_loss.item())
|
299 |
+
|
300 |
+
if self.args.save_influence_map: #for gradient rollback
|
301 |
+
with torch.no_grad():
|
302 |
+
prev_emb_e = self.previous_weights[0]
|
303 |
+
prev_emb_rel = self.previous_weights[1]
|
304 |
+
# need to compute the influence value per-triple
|
305 |
+
for idx in range(input_batch.shape[0]):
|
306 |
+
head, rel, tail = s[idx], r[idx], o[idx]
|
307 |
+
inf_head = (emb_s[idx] - prev_emb_e[head]).cpu().detach().numpy()
|
308 |
+
inf_tail = (emb_o[idx] - prev_emb_e[tail]).cpu().detach().numpy()
|
309 |
+
inf_rel = (emb_r[idx] - prev_emb_rel[rel]).cpu().detach().numpy()
|
310 |
+
#print(inf_head.shape, inf_tail.shape, inf_rel.shape)
|
311 |
+
|
312 |
+
#write the influences to dictionary
|
313 |
+
key_trip = '{0}_{1}_{2}'.format(head.item(), rel.item(), tail.item())
|
314 |
+
key = '{0}_s'.format(key_trip)
|
315 |
+
self.influence_map[key] += inf_head
|
316 |
+
#self.logger.info('Written to influence map. Key: {0}, Value shape: {1}'.format(key, inf_head.shape))
|
317 |
+
key = '{0}_r'.format(key_trip)
|
318 |
+
self.influence_map[key] += inf_rel
|
319 |
+
key = '{0}_o'.format(key_trip)
|
320 |
+
self.influence_map[key] += inf_tail
|
321 |
+
|
322 |
+
# update the previous weights to be tracked
|
323 |
+
self.previous_weights = [copy.deepcopy(param) for param in self.model.parameters()]
|
324 |
+
|
325 |
+
if (b_begin%5000 == 0) or (b_begin== (actual_examples.shape[0]-1)):
|
326 |
+
self.logger.info('[E:{} | {}]: Train Loss:{:.6}'.format(epoch, b_begin, np.mean(losses)))
|
327 |
+
|
328 |
+
loss = np.mean(losses)
|
329 |
+
self.logger.info('[Epoch:{}]: Training Loss:{:.6}\n'.format(epoch, loss))
|
330 |
+
return loss
|
331 |
+
|
332 |
+
def fit(self):
|
333 |
+
# if self.args.resume:
|
334 |
+
# self.load_model()
|
335 |
+
# results = self.evaluate(split=self.args.resume_split, batch_size = self.args.test_batch_size, epoch = -1)
|
336 |
+
# pprint(results)
|
337 |
+
|
338 |
+
# else:
|
339 |
+
self.model.init()
|
340 |
+
self.logger.info(self.model)
|
341 |
+
|
342 |
+
self.logger.info('------ Start the model training ------')
|
343 |
+
start_time = time.time()
|
344 |
+
self.logger.info('Start time: {0}'.format(str(start_time)))
|
345 |
+
|
346 |
+
|
347 |
+
train_losses = []
|
348 |
+
valid_losses = []
|
349 |
+
best_val = 10000000000.
|
350 |
+
for epoch in range(self.args.epochs):
|
351 |
+
|
352 |
+
train_loss = self.run_epoch(epoch)
|
353 |
+
train_losses.append(train_loss)
|
354 |
+
|
355 |
+
# Don't use valid_data here !!!!!!!!!
|
356 |
+
|
357 |
+
# valid_loss = self.evaluate(split='valid', batch_size = self.args.valid_batch_size, epoch = epoch)
|
358 |
+
# valid_losses.append(valid_loss)
|
359 |
+
# results_test = self.evaluate(split='test', batch_size = self.args.test_batch_size, epoch = epoch)
|
360 |
+
if train_loss < best_val:
|
361 |
+
best_val = train_loss
|
362 |
+
self.save_model()
|
363 |
+
self.logger.info("Train loss: {0}, Best loss: {1}\n\n".format(train_loss, best_val))
|
364 |
+
# print("Valid loss: {0}, Best loss: {1}\n\n".format(valid_loss, best_val))
|
365 |
+
|
366 |
+
with open(self.loss_path, "wb") as fl:
|
367 |
+
pkl.dump({"train loss":train_losses, "valid loss":valid_losses}, fl)
|
368 |
+
self.logger.info('Time taken to train the model: {0}'.format(str(time.time() - start_time)))
|
369 |
+
start_time = time.time()
|
370 |
+
|
371 |
+
if self.args.save_influence_map: #save the influence map
|
372 |
+
with open(self.influence_path, "wb") as fl: #Pickling
|
373 |
+
pkl.dump(self.get_influence_map(), fl)
|
374 |
+
self.logger.info('Finished saving influence map')
|
375 |
+
self.logger.info('Time taken to save the influence map: {0}'.format(str(time.time() - start_time)))
|
376 |
+
|
377 |
+
#%%
|
378 |
+
parser = utils.get_argument_parser()
|
379 |
+
parser.add_argument('--thread-name', type = str, required=True, help = "This parameter will be automatically determined.")
|
380 |
+
|
381 |
+
args = parser.parse_args()
|
382 |
+
args = utils.set_hyperparams(args)
|
383 |
+
# if args.reproduce_results:
|
384 |
+
# args = utils.set_hyperparams(args)
|
385 |
+
|
386 |
+
utils.seed_all(args.seed)
|
387 |
+
np.set_printoptions(precision=5)
|
388 |
+
cudnn.benchmark = False
|
389 |
+
|
390 |
+
model = Main(args)
|
391 |
+
model.fit()
|
DiseaseSpecific/model.py
ADDED
@@ -0,0 +1,504 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.nn import functional as F, Parameter
|
3 |
+
from torch.autograd import Variable
|
4 |
+
from torch.nn.init import xavier_normal_, xavier_uniform_
|
5 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
6 |
+
|
7 |
+
class Distmult(torch.nn.Module):
|
8 |
+
def __init__(self, args, num_entities, num_relations):
|
9 |
+
super(Distmult, self).__init__()
|
10 |
+
|
11 |
+
if args.max_norm:
|
12 |
+
self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
|
13 |
+
self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
|
14 |
+
else:
|
15 |
+
self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
|
16 |
+
self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)
|
17 |
+
|
18 |
+
self.inp_drop = torch.nn.Dropout(args.input_drop)
|
19 |
+
self.loss = torch.nn.CrossEntropyLoss()
|
20 |
+
|
21 |
+
self.init()
|
22 |
+
|
23 |
+
def init(self):
|
24 |
+
xavier_normal_(self.emb_e.weight)
|
25 |
+
xavier_normal_(self.emb_rel.weight)
|
26 |
+
|
27 |
+
def score_sr(self, sub, rel, sigmoid = False):
|
28 |
+
sub_emb = self.emb_e(sub).squeeze(dim=1)
|
29 |
+
rel_emb = self.emb_rel(rel).squeeze(dim=1)
|
30 |
+
|
31 |
+
#sub_emb = self.inp_drop(sub_emb)
|
32 |
+
#rel_emb = self.inp_drop(rel_emb)
|
33 |
+
|
34 |
+
pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
|
35 |
+
if sigmoid:
|
36 |
+
pred = torch.sigmoid(pred)
|
37 |
+
return pred
|
38 |
+
|
39 |
+
def score_or(self, obj, rel, sigmoid = False):
|
40 |
+
obj_emb = self.emb_e(obj).squeeze(dim=1)
|
41 |
+
rel_emb = self.emb_rel(rel).squeeze(dim=1)
|
42 |
+
|
43 |
+
#obj_emb = self.inp_drop(obj_emb)
|
44 |
+
#rel_emb = self.inp_drop(rel_emb)
|
45 |
+
|
46 |
+
pred = torch.mm(obj_emb*rel_emb, self.emb_e.weight.transpose(1,0))
|
47 |
+
if sigmoid:
|
48 |
+
pred = torch.sigmoid(pred)
|
49 |
+
return pred
|
50 |
+
|
51 |
+
|
52 |
+
def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
|
53 |
+
'''
|
54 |
+
When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
|
55 |
+
For distmult, computations for both modes are equivalent, so we do not need if-else block
|
56 |
+
'''
|
57 |
+
sub_emb = self.inp_drop(sub_emb)
|
58 |
+
rel_emb = self.inp_drop(rel_emb)
|
59 |
+
|
60 |
+
pred = torch.mm(sub_emb*rel_emb, self.emb_e.weight.transpose(1,0))
|
61 |
+
|
62 |
+
if sigmoid:
|
63 |
+
pred = torch.sigmoid(pred)
|
64 |
+
|
65 |
+
return pred
|
66 |
+
|
67 |
+
def score_triples(self, sub, rel, obj, sigmoid=False):
|
68 |
+
'''
|
69 |
+
Inputs - subject, relation, object
|
70 |
+
Return - score
|
71 |
+
'''
|
72 |
+
sub_emb = self.emb_e(sub).squeeze(dim=1)
|
73 |
+
rel_emb = self.emb_rel(rel).squeeze(dim=1)
|
74 |
+
obj_emb = self.emb_e(obj).squeeze(dim=1)
|
75 |
+
|
76 |
+
pred = torch.sum(sub_emb*rel_emb*obj_emb, dim=-1)
|
77 |
+
|
78 |
+
if sigmoid:
|
79 |
+
pred = torch.sigmoid(pred)
|
80 |
+
|
81 |
+
return pred
|
82 |
+
|
83 |
+
def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
|
84 |
+
'''
|
85 |
+
Inputs - embeddings of subject, relation, object
|
86 |
+
Return - score
|
87 |
+
'''
|
88 |
+
pred = torch.sum(emb_s*emb_r*emb_o, dim=-1)
|
89 |
+
|
90 |
+
if sigmoid:
|
91 |
+
pred = torch.sigmoid(pred)
|
92 |
+
|
93 |
+
return pred
|
94 |
+
|
95 |
+
def score_triples_vec(self, sub, rel, obj, sigmoid=False):
|
96 |
+
'''
|
97 |
+
Inputs - subject, relation, object
|
98 |
+
Return - a vector score for the triple instead of reducing over the embedding dimension
|
99 |
+
'''
|
100 |
+
sub_emb = self.emb_e(sub).squeeze(dim=1)
|
101 |
+
rel_emb = self.emb_rel(rel).squeeze(dim=1)
|
102 |
+
obj_emb = self.emb_e(obj).squeeze(dim=1)
|
103 |
+
|
104 |
+
pred = sub_emb*rel_emb*obj_emb
|
105 |
+
|
106 |
+
if sigmoid:
|
107 |
+
pred = torch.sigmoid(pred)
|
108 |
+
|
109 |
+
return pred
|
110 |
+
|
111 |
+
class Complex(torch.nn.Module):
|
112 |
+
def __init__(self, args, num_entities, num_relations):
|
113 |
+
super(Complex, self).__init__()
|
114 |
+
|
115 |
+
if args.max_norm:
|
116 |
+
self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, max_norm=1.0)
|
117 |
+
self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim)
|
118 |
+
else:
|
119 |
+
self.emb_e = torch.nn.Embedding(num_entities, 2*args.embedding_dim, padding_idx=None)
|
120 |
+
self.emb_rel = torch.nn.Embedding(num_relations, 2*args.embedding_dim, padding_idx=None)
|
121 |
+
|
122 |
+
self.inp_drop = torch.nn.Dropout(args.input_drop)
|
123 |
+
self.loss = torch.nn.CrossEntropyLoss()
|
124 |
+
|
125 |
+
self.init()
|
126 |
+
|
127 |
+
def init(self):
|
128 |
+
xavier_normal_(self.emb_e.weight)
|
129 |
+
xavier_normal_(self.emb_rel.weight)
|
130 |
+
|
131 |
+
def score_sr(self, sub, rel, sigmoid = False):
|
132 |
+
sub_emb = self.emb_e(sub).squeeze(dim=1)
|
133 |
+
rel_emb = self.emb_rel(rel).squeeze(dim=1)
|
134 |
+
|
135 |
+
s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
|
136 |
+
rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
|
137 |
+
emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
|
138 |
+
|
139 |
+
realo_realreal = s_real*rel_real
|
140 |
+
realo_imgimg = s_img*rel_img
|
141 |
+
realo = realo_realreal - realo_imgimg
|
142 |
+
real = torch.mm(realo, emb_e_real.transpose(1,0))
|
143 |
+
|
144 |
+
imgo_realimg = s_real*rel_img
|
145 |
+
imgo_imgreal = s_img*rel_real
|
146 |
+
imgo = imgo_realimg + imgo_imgreal
|
147 |
+
img = torch.mm(imgo, emb_e_img.transpose(1,0))
|
148 |
+
|
149 |
+
pred = real + img
|
150 |
+
|
151 |
+
if sigmoid:
|
152 |
+
pred = torch.sigmoid(pred)
|
153 |
+
return pred
|
154 |
+
|
155 |
+
|
156 |
+
def score_or(self, obj, rel, sigmoid = False):
|
157 |
+
obj_emb = self.emb_e(obj).squeeze(dim=1)
|
158 |
+
rel_emb = self.emb_rel(rel).squeeze(dim=1)
|
159 |
+
|
160 |
+
rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
|
161 |
+
o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)
|
162 |
+
emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
|
163 |
+
|
164 |
+
#rel_real = self.inp_drop(rel_real)
|
165 |
+
#rel_img = self.inp_drop(rel_img)
|
166 |
+
#o_real = self.inp_drop(o_real)
|
167 |
+
#o_img = self.inp_drop(o_img)
|
168 |
+
|
169 |
+
# complex space bilinear product (equivalent to HolE)
|
170 |
+
# realrealreal = torch.mm(rel_real*o_real, emb_e_real.transpose(1,0))
|
171 |
+
# realimgimg = torch.mm(rel_img*o_img, emb_e_real.transpose(1,0))
|
172 |
+
# imgrealimg = torch.mm(rel_real*o_img, emb_e_img.transpose(1,0))
|
173 |
+
# imgimgreal = torch.mm(rel_img*o_real, emb_e_img.transpose(1,0))
|
174 |
+
# pred = realrealreal + realimgimg + imgrealimg - imgimgreal
|
175 |
+
|
176 |
+
reals_realreal = rel_real*o_real
|
177 |
+
reals_imgimg = rel_img*o_img
|
178 |
+
reals = reals_realreal + reals_imgimg
|
179 |
+
real = torch.mm(reals, emb_e_real.transpose(1,0))
|
180 |
+
|
181 |
+
imgs_realimg = rel_real*o_img
|
182 |
+
imgs_imgreal = rel_img*o_real
|
183 |
+
imgs = imgs_realimg - imgs_imgreal
|
184 |
+
img = torch.mm(imgs, emb_e_img.transpose(1,0))
|
185 |
+
|
186 |
+
pred = real + img
|
187 |
+
|
188 |
+
if sigmoid:
|
189 |
+
pred = torch.sigmoid(pred)
|
190 |
+
return pred
|
191 |
+
|
192 |
+
|
193 |
+
def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
|
194 |
+
'''
|
195 |
+
When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
|
196 |
+
|
197 |
+
'''
|
198 |
+
if mode == 'lhs':
|
199 |
+
rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
|
200 |
+
o_real, o_img = torch.chunk(sub_emb, 2, dim=-1)
|
201 |
+
emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
|
202 |
+
|
203 |
+
rel_real = self.inp_drop(rel_real)
|
204 |
+
rel_img = self.inp_drop(rel_img)
|
205 |
+
o_real = self.inp_drop(o_real)
|
206 |
+
o_img = self.inp_drop(o_img)
|
207 |
+
|
208 |
+
reals_realreal = rel_real*o_real
|
209 |
+
reals_imgimg = rel_img*o_img
|
210 |
+
reals = reals_realreal + reals_imgimg
|
211 |
+
real = torch.mm(reals, emb_e_real.transpose(1,0))
|
212 |
+
|
213 |
+
imgs_realimg = rel_real*o_img
|
214 |
+
imgs_imgreal = rel_img*o_real
|
215 |
+
imgs = imgs_realimg - imgs_imgreal
|
216 |
+
img = torch.mm(imgs, emb_e_img.transpose(1,0))
|
217 |
+
|
218 |
+
pred = real + img
|
219 |
+
|
220 |
+
else:
|
221 |
+
s_real, s_img = torch.chunk(rel_emb, 2, dim=-1)
|
222 |
+
rel_real, rel_img = torch.chunk(sub_emb, 2, dim=-1)
|
223 |
+
emb_e_real, emb_e_img = torch.chunk(self.emb_e.weight, 2, dim=-1)
|
224 |
+
|
225 |
+
s_real = self.inp_drop(s_real)
|
226 |
+
s_img = self.inp_drop(s_img)
|
227 |
+
rel_real = self.inp_drop(rel_real)
|
228 |
+
rel_img = self.inp_drop(rel_img)
|
229 |
+
|
230 |
+
realo_realreal = s_real*rel_real
|
231 |
+
realo_imgimg = s_img*rel_img
|
232 |
+
realo = realo_realreal - realo_imgimg
|
233 |
+
real = torch.mm(realo, emb_e_real.transpose(1,0))
|
234 |
+
|
235 |
+
imgo_realimg = s_real*rel_img
|
236 |
+
imgo_imgreal = s_img*rel_real
|
237 |
+
imgo = imgo_realimg + imgo_imgreal
|
238 |
+
img = torch.mm(imgo, emb_e_img.transpose(1,0))
|
239 |
+
|
240 |
+
pred = real + img
|
241 |
+
|
242 |
+
if sigmoid:
|
243 |
+
pred = torch.sigmoid(pred)
|
244 |
+
|
245 |
+
return pred
|
246 |
+
|
247 |
+
def score_triples(self, sub, rel, obj, sigmoid=False):
    '''
    Inputs - subject, relation, object (index tensors)
    Return - score, one scalar per triple

    Computes the ComplEx score Re(<s, r, conj(o)>) for each triple.
    '''
    sub_emb = self.emb_e(sub).squeeze(dim=1)
    rel_emb = self.emb_rel(rel).squeeze(dim=1)
    obj_emb = self.emb_e(obj).squeeze(dim=1)

    s_real, s_img = torch.chunk(sub_emb, 2, dim=-1)
    rel_real, rel_img = torch.chunk(rel_emb, 2, dim=-1)
    o_real, o_img = torch.chunk(obj_emb, 2, dim=-1)

    # Four interaction terms of the complex tri-linear product.
    # All reductions use dim= consistently (original mixed dim= with the
    # numpy-style axis= alias).
    realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
    realimgimg = torch.sum(s_real*rel_img*o_img, dim=-1)
    imgrealimg = torch.sum(s_img*rel_real*o_img, dim=-1)
    imgimgreal = torch.sum(s_img*rel_img*o_real, dim=-1)

    pred = realrealreal + realimgimg + imgrealimg - imgimgreal

    if sigmoid:
        pred = torch.sigmoid(pred)

    return pred
|
271 |
+
|
272 |
+
def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
    '''
    Inputs - embeddings of subject, relation, object
    Return - score, one scalar per row

    Same ComplEx score as score_triples, but operating directly on
    already-looked-up embeddings.
    '''
    s_real, s_img = torch.chunk(emb_s, 2, dim=-1)
    rel_real, rel_img = torch.chunk(emb_r, 2, dim=-1)
    o_real, o_img = torch.chunk(emb_o, 2, dim=-1)

    # Four interaction terms of Re(<s, r, conj(o)>).
    # All reductions use dim= consistently (original mixed dim= with the
    # numpy-style axis= alias).
    realrealreal = torch.sum(s_real*rel_real*o_real, dim=-1)
    realimgimg = torch.sum(s_real*rel_img*o_img, dim=-1)
    imgrealimg = torch.sum(s_img*rel_real*o_img, dim=-1)
    imgimgreal = torch.sum(s_img*rel_img*o_real, dim=-1)

    pred = realrealreal + realimgimg + imgrealimg - imgimgreal

    if sigmoid:
        pred = torch.sigmoid(pred)

    return pred
|
293 |
+
|
294 |
+
def score_triples_vec(self, sub, rel, obj, sigmoid=False):
    '''
    Inputs - subject, relation, object
    Return - a vector score for the triple instead of reducing over the embedding dimension
    '''
    sub_emb = self.emb_e(sub).squeeze(dim=1)
    rel_emb = self.emb_rel(rel).squeeze(dim=1)
    obj_emb = self.emb_e(obj).squeeze(dim=1)

    s_re, s_im = torch.chunk(sub_emb, 2, dim=-1)
    r_re, r_im = torch.chunk(rel_emb, 2, dim=-1)
    o_re, o_im = torch.chunk(obj_emb, 2, dim=-1)

    # Element-wise ComplEx interaction, left unreduced over the embedding axis.
    pred = (s_re * r_re * o_re
            + s_re * r_im * o_im
            + s_im * r_re * o_im
            - s_im * r_im * o_re)

    return torch.sigmoid(pred) if sigmoid else pred
|
318 |
+
|
319 |
+
class Conve(torch.nn.Module):
    '''
    ConvE-style scoring model: subject and relation embeddings are reshaped into
    2D "images", stacked, run through a small convolutional network, and
    projected back to the embedding dimension before being matched against
    entity embeddings.
    '''

    #Too slow !!!!

    def __init__(self, args, num_entities, num_relations):
        super(Conve, self).__init__()

        # Optionally constrain entity embeddings to unit max-norm.
        if args.max_norm:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, max_norm=1.0)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim)
        else:
            self.emb_e = torch.nn.Embedding(num_entities, args.embedding_dim, padding_idx=None)
            self.emb_rel = torch.nn.Embedding(num_relations, args.embedding_dim, padding_idx=None)

        self.inp_drop = torch.nn.Dropout(args.input_drop)
        self.hidden_drop = torch.nn.Dropout(args.hidden_drop)
        self.feature_drop = torch.nn.Dropout2d(args.feat_drop)

        self.embedding_dim = args.embedding_dim #default is 200
        self.num_filters = args.num_filters # default is 32
        self.kernel_size = args.kernel_size # default is 3
        self.stack_width = args.stack_width # default is 20
        self.stack_height = args.embedding_dim // self.stack_width

        self.bn0 = torch.nn.BatchNorm2d(1)
        self.bn1 = torch.nn.BatchNorm2d(self.num_filters)
        self.bn2 = torch.nn.BatchNorm1d(args.embedding_dim)

        self.conv1 = torch.nn.Conv2d(1, out_channels=self.num_filters,
                                     kernel_size=(self.kernel_size, self.kernel_size),
                                     stride=1, padding=0, bias=args.use_bias)
        #self.conv1 = torch.nn.Conv2d(1, 32, (3, 3), 1, 0, bias=args.use_bias) # <-- default

        # Spatial size of the valid-padding convolution output over the
        # (2*stack_width x stack_height) stacked input image.
        flat_sz_h = int(2*self.stack_width) - self.kernel_size + 1
        flat_sz_w = self.stack_height - self.kernel_size + 1
        self.flat_sz = flat_sz_h*flat_sz_w*self.num_filters
        self.fc = torch.nn.Linear(self.flat_sz, args.embedding_dim)

        # Per-entity bias added to every score.
        self.register_parameter('b', Parameter(torch.zeros(num_entities)))
        self.loss = torch.nn.CrossEntropyLoss()

        self.init()

    def init(self):
        # Xavier initialisation for both embedding tables.
        xavier_normal_(self.emb_e.weight)
        xavier_normal_(self.emb_rel.weight)

    def concat(self, e1_embed, rel_embed, form='plain'):
        '''
        Stack entity and relation embeddings into a single
        (batch, 1, 2*stack_width, stack_height) image for the conv layer.
        '''
        if form == 'plain':
            # Reshape each embedding into an image and stack them along the
            # width axis.
            e1_embed = e1_embed.view(-1, 1, self.stack_width, self.stack_height)
            rel_embed = rel_embed.view(-1, 1, self.stack_width, self.stack_height)
            stack_inp = torch.cat([e1_embed, rel_embed], 2)

        elif form == 'alternate':
            # Interleave the two embeddings row-wise before reshaping.
            e1_embed = e1_embed.view(-1, 1, self.embedding_dim)
            rel_embed = rel_embed.view(-1, 1, self.embedding_dim)
            stack_inp = torch.cat([e1_embed, rel_embed], 1)
            stack_inp = torch.transpose(stack_inp, 2, 1).reshape((-1, 1, 2*self.stack_width, self.stack_height))

        else: raise NotImplementedError
        return stack_inp

    def conve_architecture(self, sub_emb, rel_emb):
        '''
        Shared conv pipeline: stack -> BN -> dropout -> conv -> BN -> ReLU ->
        feature dropout -> flatten -> FC -> dropout -> BN -> ReLU.
        Returns a (batch, embedding_dim) feature vector.
        '''
        stacked_inputs = self.concat(sub_emb, rel_emb)
        stacked_inputs = self.bn0(stacked_inputs)
        x = self.inp_drop(stacked_inputs)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.feature_drop(x)
        #x = x.view(x.shape[0], -1)
        x = x.view(-1, self.flat_sz)
        x = self.fc(x)
        x = self.hidden_drop(x)
        x = self.bn2(x)
        x = F.relu(x)

        return x

    def score_sr(self, sub, rel, sigmoid = False):
        '''Score (subject, relation) against every entity as object.'''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)

        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred

    def score_or(self, obj, rel, sigmoid = False):
        '''Score (object, relation) against every entity as subject.'''
        obj_emb = self.emb_e(obj)
        rel_emb = self.emb_rel(rel)

        x = self.conve_architecture(obj_emb, rel_emb)
        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)
        return pred


    def forward(self, sub_emb, rel_emb, mode='rhs', sigmoid=False):
        '''
        When mode is 'rhs' we expect (s,r); for 'lhs', we expect (o,r)
        For conve, computations for both modes are equivalent, so we do not need if-else block
        '''
        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, self.emb_e.weight.transpose(1,0))
        pred += self.b.expand_as(pred)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - score
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj)
        x = self.conve_architecture(sub_emb, rel_emb)

        pred = torch.mm(x, obj_emb.transpose(1,0))
        #print(pred.shape)
        pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding
        # above works fine for single input triples;
        # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores
        # so use torch.diagonal() after calling this function
        pred = torch.diagonal(pred)
        # or could have used : pred= torch.sum(x*obj_emb, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_emb(self, emb_s, emb_r, emb_o, sigmoid=False):
        '''
        Inputs - embeddings of subject, relation, object
        Return - score
        '''
        x = self.conve_architecture(emb_s, emb_r)

        pred = torch.mm(x, emb_o.transpose(1,0))
        #pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding - don't know which obj
        # above works fine for single input triples;
        # but if input is batch of triples, then this is a matrix of (num_trip x num_trip) where diagonal is scores
        # so use torch.diagonal() after calling this function
        pred = torch.diagonal(pred)
        # or could have used : pred= torch.sum(x*obj_emb, dim=-1)

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred

    def score_triples_vec(self, sub, rel, obj, sigmoid=False):
        '''
        Inputs - subject, relation, object
        Return - a vector score for the triple instead of reducing over the embedding dimension
        '''
        sub_emb = self.emb_e(sub)
        rel_emb = self.emb_rel(rel)
        obj_emb = self.emb_e(obj)

        x = self.conve_architecture(sub_emb, rel_emb)

        #pred = torch.mm(x, obj_emb.transpose(1,0))
        pred = x*obj_emb
        #print(pred.shape, self.b[obj].shape) #shapes are [7,200] and [7]
        #pred += self.b[obj].expand_as(pred) #taking the bias value for object embedding - can't add scalar to vector

        #pred = sub_emb*rel_emb*obj_emb

        if sigmoid:
            pred = torch.sigmoid(pred)

        return pred
|
DiseaseSpecific/utils.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
A file modified on https://github.com/PeruBhardwaj/AttributionAttack/blob/main/KGEAttack/ConvE/utils.py
|
3 |
+
'''
|
4 |
+
#%%
|
5 |
+
import logging
|
6 |
+
import time
|
7 |
+
from tqdm import tqdm
|
8 |
+
import io
|
9 |
+
import pandas as pd
|
10 |
+
import numpy as np
|
11 |
+
import os
|
12 |
+
import json
|
13 |
+
|
14 |
+
import argparse
|
15 |
+
import torch
|
16 |
+
import random
|
17 |
+
|
18 |
+
from yaml import parse
|
19 |
+
|
20 |
+
from model import Conve, Distmult, Complex
|
21 |
+
|
22 |
+
logger = logging.getLogger(__name__)
|
23 |
+
#%%
|
24 |
+
def generate_dicts(data_path):
    '''Load the entity/relation id maps from data_path and return their sizes and the maps.'''
    ent_file = os.path.join(data_path, 'entities_dict.json')
    rel_file = os.path.join(data_path, 'relations_dict.json')
    with open(ent_file, 'r') as f:
        ent_to_id = json.load(f)
    with open(rel_file, 'r') as f:
        rel_to_id = json.load(f)

    return len(ent_to_id), len(rel_to_id), ent_to_id, rel_to_id
|
33 |
+
|
34 |
+
def save_data(file_name, data):
    '''Write triples to file_name, one tab-separated triple per line.'''
    lines = ["\t".join(str(field) for field in item) + "\n" for item in data]
    with open(file_name, 'w') as fl:
        fl.writelines(lines)
|
38 |
+
|
39 |
+
def load_data(file_name, drop = True):
    '''
    Read a tab-separated triple file into a numpy array of strings.
    drop - when True (default), duplicate rows are removed first.
    '''
    frame = pd.read_csv(file_name, sep='\t', header=None, names=None, dtype=str)
    if drop:
        frame = frame.drop_duplicates()
    return frame.values
|
46 |
+
|
47 |
+
def seed_all(seed=1):
    '''Seed every RNG in use (python, numpy, torch CPU+CUDA) for reproducibility.'''
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Force deterministic cuDNN kernels (slower, reproducible).
    torch.backends.cudnn.deterministic = True
|
54 |
+
|
55 |
+
def add_model(args, n_ent, n_rel):
    '''Instantiate the KGE model named by args.model (None falls back to DistMult).'''
    name = args.model
    if name is None or name == 'distmult':
        return Distmult(args, n_ent, n_rel)
    if name == 'complex':
        return Complex(args, n_ent, n_rel)
    if name == 'conve':
        return Conve(args, n_ent, n_rel)
    raise Exception("Unknown model!")
|
68 |
+
|
69 |
+
def load_model(model_path, args, n_ent, n_rel, device):
    '''
    Build a model via add_model, load pre-trained parameters from model_path,
    and return it in eval mode on `device`.
    '''
    # add a model and load the pre-trained params
    model = add_model(args, n_ent, n_rel)
    model.to(device)
    logger.info('Loading saved model from {0}'.format(model_path))
    # map_location keeps checkpoints saved on GPU loadable on CPU-only hosts
    # (and vice versa); without it torch.load restores onto the saving device.
    state = torch.load(model_path, map_location=device)
    model_params = state['state_dict']
    params = [(key, value.size(), value.numel()) for key, value in model_params.items()]
    for key, size, count in params:
        logger.info('Key:{0}, Size:{1}, Count:{2}'.format(key, size, count))

    model.load_state_dict(model_params)
    model.eval()
    logger.info(model)

    return model
|
85 |
+
|
86 |
+
def add_eval_parameters(parser):
    '''Attach evaluation-time command line options to the given parser and return it.'''
    # parser.add_argument('--eval-mode', type = str, default = 'all', help = 'Method to evaluate the attack performance. Default: all. (all or single)')
    parser.add_argument('--cuda-name', type=str, required=True,
                        help='Start a main thread on each cuda.')
    parser.add_argument('--direct', action='store_true',
                        help='Directly add edge or not.')
    parser.add_argument('--seperate', action='store_true',
                        help='Evaluate seperatly or not')
    parser.add_argument('--mode', type=str, default='',
                        help='  or  ')
    parser.add_argument('--mask-ratio', type=str, default='',
                        help='Mask ratio for Fig4b')
    return parser
|
95 |
+
|
96 |
+
def add_attack_parameters(parser):
    '''Attach attack-generation command line options to the given parser and return it.'''
    # --- target triple selection ---
    parser.add_argument('--target-split', type=str, default='min',
                        help='Methods for target triple selection. Default: min. (min or top_?, top means top_0.1)')
    parser.add_argument('--target-size', type=int, default=50,
                        help='Number of target triples. Default: 50')
    parser.add_argument('--target-existed', action='store_true',
                        help='Whether the targeted s_?_o already exists.')

    # --- candidate edge generation ---
    parser.add_argument('--attack-goal', type=str, default='single',
                        help='Attack goal. Default: single. (single or global)')
    parser.add_argument('--neighbor-num', type=int, default=20,
                        help='Max neighbor num for each side. Default: 20')
    parser.add_argument('--candidate-mode', type=str, default='quadratic',
                        help='The method to generate candidate edge. Default: quadratic. (quadratic or linear)')
    parser.add_argument('--reasonable-rate', type=float, default=0.7,
                        help='The added edge\'s existance rank prob greater than this rate')
    parser.add_argument('--added-edge-num', type=str, default='',
                        help='How many edges to add for each target edge. Default:  means 1.')
    parser.add_argument('--attack-batch-size', type=int, default=256,
                        help='Batch size for processing neighbours of target')
    parser.add_argument('--template-mode', type=str, default='manual',
                        help='Template mode for transforming edge to single sentense. Default: manual. (manual or auto)')

    parser.add_argument('--update-lissa', action='store_true',
                        help='Update lissa cache or not.')

    # --- language-model scoring ---
    parser.add_argument('--GPT-batch-size', type=int, default=64,
                        help='Batch size for GPT2 when calculating LM score. Default: 64')
    parser.add_argument('--LM-softmax', action='store_true',
                        help='Use a softmax head on LM prob or not.')
    parser.add_argument('--LMprob-mode', type=str, default='relative',
                        help='Use the absolute LM score or calculate the destruction score when target word is replaced. Default: absolute. (absolute or relative)')

    parser.add_argument('--load-existed', action='store_true',
                        help='Use cached intermidiate results or not, when only --reasonable-rate changed, set this param to True')

    return parser
|
124 |
+
|
125 |
+
def get_argument_parser():
    '''Generate an argument parser'''
    parser = argparse.ArgumentParser(description='Graph embedding')

    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='Random seed (default: 1)')

    # --- data / model selection ---
    parser.add_argument('--data', type=str, default='GNBR',
                        help='Dataset to use: { GNBR }')
    parser.add_argument('--model', type=str, default='distmult',
                        help='Choose from: {distmult, conve, complex}')

    # --- TransE-specific options ---
    parser.add_argument('--transe-margin', type=float, default=0.0,
                        help='Margin value for TransE scoring function. Default:0.0')
    parser.add_argument('--transe-norm', type=int, default=2,
                        help='P-norm value for TransE scoring function. Default:2')

    # --- optimisation ---
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.001,
                        help='Learning rate (default: 0.001)')
    parser.add_argument('--lr-decay', type=float, default=0.0,
                        help='Weight decay value to use in the optimizer. Default: 0.0')
    parser.add_argument('--max-norm', action='store_true',
                        help='Option to add unit max norm constraint to entity embeddings')

    # --- batching ---
    parser.add_argument('--train-batch-size', type=int, default=64,
                        help='Batch size for train split (default: 128)')
    parser.add_argument('--test-batch-size', type=int, default=128,
                        help='Batch size for test split (default: 128)')
    parser.add_argument('--valid-batch-size', type=int, default=128,
                        help='Batch size for valid split (default: 128)')
    parser.add_argument('--KG-valid-rate', type=float, default=0.1,
                        help='Validation rate during KG embedding training. (default: 0.1)')

    parser.add_argument('--save-influence-map', action='store_true',
                        help='Save the influence map during training for gradient rollback.')
    parser.add_argument('--add-reciprocals', action='store_true')

    # --- architecture ---
    parser.add_argument('--embedding-dim', type=int, default=128,
                        help='The embedding dimension (1D). Default: 128')
    parser.add_argument('--stack-width', type=int, default=16,
                        help='The first dimension of the reshaped/stacked 2D embedding. Second dimension is inferred. Default: 20')
    #parser.add_argument('--stack_height', type=int, default=10, help='The second dimension of the reshaped/stacked 2D embedding. Default: 10')
    parser.add_argument('--hidden-drop', type=float, default=0.3,
                        help='Dropout for the hidden layer. Default: 0.3.')
    parser.add_argument('--input-drop', type=float, default=0.2,
                        help='Dropout for the input embeddings. Default: 0.2.')
    parser.add_argument('--feat-drop', type=float, default=0.3,
                        help='Dropout for the convolutional features. Default: 0.2.')
    parser.add_argument('-num-filters', default=32, type=int,
                        help='Number of filters for convolution')
    parser.add_argument('-kernel-size', default=3, type=int,
                        help='Kernel Size for convolution')

    parser.add_argument('--use-bias', action='store_true',
                        help='Use a bias in the convolutional layer. Default: True')

    # --- regularization ---
    parser.add_argument('--reg-weight', type=float, default=5e-2,
                        help='Weight for regularization. Default: 5e-2')
    parser.add_argument('--reg-norm', type=int, default=3,
                        help='Norm for regularization. Default: 2')

    return parser
|
168 |
+
|
169 |
+
def set_hyperparams(args):
    '''Overwrite training hyper-parameters with per-model presets, then set LiSSA settings.'''
    per_model = {
        'distmult': {'lr': 0.005, 'train_batch_size': 1024, 'reg_norm': 3},
        'complex': {'lr': 0.005, 'reg_norm': 3, 'input_drop': 0.4, 'train_batch_size': 1024},
        'conve': {'lr': 0.005, 'train_batch_size': 1024, 'reg_weight': 0.0},
    }
    for key, value in per_model.get(args.model, {}).items():
        setattr(args, key, value)

    # LiSSA (influence-function approximation) settings, shared by all models.
    args.damping = 0.01
    args.lissa_repeat = 1
    args.lissa_depth = 1
    args.scale = 400
    args.lissa_batch_size = 300
    return args
|