Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	| #%% | |
| import torch | |
| import numpy as np | |
| from torch.autograd import Variable | |
| from sklearn import metrics | |
| import datetime | |
| from typing import Dict, Tuple, List | |
| import logging | |
| import os | |
| import utils | |
| import pickle as pkl | |
| import json | |
| import torch.backends.cudnn as cudnn | |
| from tqdm import tqdm | |
| import sys | |
| sys.path.append("..") | |
| import Parameters | |
| parser = utils.get_argument_parser() | |
| parser.add_argument('--reasonable-rate', type = float, default=0.7, help = 'The added edge\'s existance rank prob greater than this rate') | |
| parser.add_argument('--mode', type=str, default='sentence', help='sentence, finetune, biogpt, bioBART') | |
| parser.add_argument('--action', type=str, default='parse', help='parse or extract') | |
| parser.add_argument('--init-mode', type = str, default='random', help = 'How to select target nodes') | |
| parser.add_argument('--ratio', type = str, default='', help='ratio of the number of changed words') | |
| args = parser.parse_args() | |
| args = utils.set_hyperparams(args) | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| utils.seed_all(args.seed) | |
| np.set_printoptions(precision=5) | |
| cudnn.benchmark = False | |
| data_path = '../DiseaseSpecific/processed_data/GNBR' | |
| target_path = f'processed_data/target_{args.reasonable_rate}{args.init_mode}.pkl' | |
| attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}.pkl' | |
| modified_attack_path = f'processed_data/attack_edge_{args.model}_{args.reasonable_rate}{args.init_mode}{args.mode}.pkl' | |
| with open(attack_path, 'rb') as fl: | |
| Attack_edge_list = pkl.load(fl) | |
| attack_data = np.array(Attack_edge_list).reshape(-1, 3) | |
| #%% | |
| with open(os.path.join(data_path, 'entities_reverse_dict.json')) as fl: | |
| id_to_meshid = json.load(fl) | |
| with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl: | |
| meshid_to_id = json.load(fl) | |
| with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl: | |
| entity_raw_name = pkl.load(fl) | |
| with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl: | |
| retieve_sentence_through_edgetype = pkl.load(fl) | |
| with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl: | |
| raw_text_sen = pkl.load(fl) | |
| with open(Parameters.GNBRfile+'original_entity_raw_name', 'rb') as fl: | |
| full_entity_raw_name = pkl.load(fl) | |
| for k, v in entity_raw_name.items(): | |
| assert v in full_entity_raw_name[k] | |
| #find unique | |
| once_set = set() | |
| twice_set = set() | |
| with open('../DiseaseSpecific/generate_abstract/valid_entity.json', 'r') as fl: | |
| valid_entity = json.load(fl) | |
| valid_entity = set(valid_entity) | |
| good_name = set() | |
| for k, v, in full_entity_raw_name.items(): | |
| names = list(v) | |
| for name in names: | |
| # if name == 'in a': | |
| # print(names) | |
| good_name.add(name) | |
| # if name not in once_set: | |
| # once_set.add(name) | |
| # else: | |
| # twice_set.add(name) | |
| # assert 'WNK4' in once_set | |
| # good_name = set.difference(once_set, twice_set) | |
| # assert 'in a' not in good_name | |
| # assert 'STE20' not in good_name | |
| # assert 'STE20' not in valid_entity | |
| # assert 'STE20-related proline-alanine-rich kinase' not in good_name | |
| # assert 'STE20-related proline-alanine-rich kinase' not in valid_entity | |
| # raise Exception | |
| name_to_type = {} | |
| name_to_meshid = {} | |
| for k, v, in full_entity_raw_name.items(): | |
| names = list(v) | |
| for name in names: | |
| if name in good_name: | |
| name_to_type[name] = k.split('_')[0] | |
| name_to_meshid[name] = k | |
| import spacy | |
| import networkx as nx | |
| import pprint | |
| def check(p, s): | |
| if p < 1 or p >= len(s): | |
| return True | |
| return not((s[p]>='a' and s[p]<='z') or (s[p]>='A' and s[p]<='Z') or (s[p]>='0' and s[p]<='9')) | |
| def raw_to_format(sen): | |
| text = sen | |
| l = 0 | |
| ret = [] | |
| while(l < len(text)): | |
| bo =False | |
| if text[l] != ' ': | |
| for i in range(len(text), l, -1): # reversing is important !!! | |
| cc = text[l:i] | |
| if (cc in good_name or cc in valid_entity) and check(l-1, text) and check(i, text): | |
| ret.append(cc.replace(' ', '_')) | |
| l = i | |
| bo = True | |
| break | |
| if not bo: | |
| ret.append(text[l]) | |
| l += 1 | |
| return ''.join(ret) | |
| if args.mode == 'sentence': | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_chat.json', 'r') as fl: | |
| draft = json.load(fl) | |
| elif args.mode == 'finetune': | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_sentence_finetune.json', 'r') as fl: | |
| draft = json.load(fl) | |
| elif args.mode == 'bioBART': | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}{args.ratio}_bioBART_finetune.json', 'r') as fl: | |
| draft = json.load(fl) | |
| elif args.mode == 'biogpt': | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_biogpt.json', 'r') as fl: | |
| draft = json.load(fl) | |
| else: | |
| raise Exception('No!!!') | |
| nlp = spacy.load("en_core_web_sm") | |
| type_set = set() | |
| for aa in range(36): | |
| dependency_sen_dict = retieve_sentence_through_edgetype[aa]['manual'] | |
| tmp_dict = retieve_sentence_through_edgetype[aa]['auto'] | |
| dependencys = list(dependency_sen_dict.keys()) + list(tmp_dict.keys()) | |
| for dependency in dependencys: | |
| dep_list = dependency.split(' ') | |
| for sub_dep in dep_list: | |
| sub_dep_list = sub_dep.split('|') | |
| assert(len(sub_dep_list) == 3) | |
| type_set.add(sub_dep_list[1]) | |
| # print('Type:', type_set) | |
| if args.action == 'parse': | |
| # dp_path, sen_list = list(dependency_sen_dict.items())[0] | |
| # check | |
| # paper_id, sen_id = sen_list[0] | |
| # sen = raw_text_sen[paper_id][sen_id] | |
| # doc = nlp(sen['text']) | |
| # print(dp_path, '\n') | |
| # pprint.pprint(sen) | |
| # print() | |
| # for token in doc: | |
| # print((token.head.text, token.text, token.dep_)) | |
| out = '' | |
| for k, v_dict in draft.items(): | |
| input = v_dict['in'] | |
| output = v_dict['out'] | |
| if input == '': | |
| continue | |
| output = output.replace('\n', ' ') | |
| doc = nlp(output) | |
| for sen in doc.sents: | |
| out += raw_to_format(sen.text) + '\n' | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'w') as fl: | |
| fl.write(out) | |
| elif args.action == 'extract': | |
| # dependency_to_type_id = {} | |
| # for k, v in Parameters.edge_type_to_id.items(): | |
| # dependency_to_type_id[k] = {} | |
| # for type in v: | |
| # LL = list(retieve_sentence_through_edgetype[type]['manual'].keys()) + list(retieve_sentence_through_edgetype[type]['auto'].keys()) | |
| # for dp in LL: | |
| # dependency_to_type_id[k][dp] = type | |
| if os.path.exists('generate_abstract/dependency_to_type_id.pickle'): | |
| with open('generate_abstract/dependency_to_type_id.pickle', 'rb') as fl: | |
| dependency_to_type_id = pkl.load(fl) | |
| else: | |
| dependency_to_type_id = {} | |
| print('Loading path data ...') | |
| for k in Parameters.edge_type_to_id.keys(): | |
| start, end = k.split('-') | |
| dependency_to_type_id[k] = {} | |
| inner_edge_type_to_id = Parameters.edge_type_to_id[k] | |
| inner_edge_type_dict = Parameters.edge_type_dict[k] | |
| cal_manual_num = [0] * len(inner_edge_type_to_id) | |
| with open('../GNBRdata/part-i-'+start+'-'+end+'-path-theme-distributions.txt', 'r') as fl: | |
| for i, line in tqdm(list(enumerate(fl.readlines()))): | |
| tmp = line.split('\t') | |
| if i == 0: | |
| head = [tmp[i] for i in range(1, len(tmp), 2)] | |
| assert ' '.join(head) == ' '.join(inner_edge_type_dict[0]) | |
| continue | |
| probability = [float(tmp[i]) for i in range(1, len(tmp), 2)] | |
| flag_list = [int(tmp[i]) for i in range(2, len(tmp), 2)] | |
| indices = np.where(np.asarray(flag_list) == 1)[0] | |
| if len(indices) >= 1: | |
| tmp_p = [cal_manual_num[i] for i in indices] | |
| p = indices[np.argmin(tmp_p)] | |
| cal_manual_num[p] += 1 | |
| else: | |
| p = np.argmax(probability) | |
| assert tmp[0].lower() not in dependency_to_type_id.keys() | |
| dependency_to_type_id[k][tmp[0].lower()] = inner_edge_type_to_id[p] | |
| with open('generate_abstract/dependency_to_type_id.pickle', 'wb') as fl: | |
| pkl.dump(dependency_to_type_id, fl) | |
| # record = [] | |
| # with open(f'generate_abstract/par_parseout.txt', 'r') as fl: | |
| # Tmp = [] | |
| # tmp = [] | |
| # for i,line in enumerate(fl.readlines()): | |
| # # print(len(line), line) | |
| # line = line.replace('\n', '') | |
| # if len(line) > 1: | |
| # tmp.append(line) | |
| # else: | |
| # Tmp.append(tmp) | |
| # tmp = [] | |
| # if len(Tmp) == 3: | |
| # record.append(Tmp) | |
| # Tmp = [] | |
| # print(len(record)) | |
| # record_index = 0 | |
| # add = 0 | |
| # Attack = [] | |
| # for ii in range(100): | |
| # # input = v_dict['in'] | |
| # # output = v_dict['out'] | |
| # # output = output.replace('\n', ' ') | |
| # s, r, o = attack_data[ii] | |
| # dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual'] | |
| # target_dp = set() | |
| # for dp_path, sen_list in dependency_sen_dict.items(): | |
| # target_dp.add(dp_path) | |
| # DP_list = [] | |
| # for _ in range(1): | |
| # dp_dict = {} | |
| # data = record[record_index] | |
| # record_index += 1 | |
| # dp_paths = data[2] | |
| # nodes_list = [] | |
| # edges_list = [] | |
| # for line in dp_paths: | |
| # ttp, tmp = line.split('(') | |
| # assert tmp[-1] == ')' | |
| # tmp = tmp[:-1] | |
| # e1, e2 = tmp.split(', ') | |
| # if not ttp in type_set and ':' in ttp: | |
| # ttp = ttp.split(':')[0] | |
| # dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2] | |
| # dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2] | |
| # nodes_list.append(e1) | |
| # nodes_list.append(e2) | |
| # edges_list.append((e1, e2)) | |
| # nodes_list = list(set(nodes_list)) | |
| # pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list] | |
| # graph = nx.Graph(edges_list) | |
| # type_list = [name_to_type[name] if name in good_name else '' for name in pure_name] | |
| # # print(type_list) | |
| # # for i in range(len(type_list)): | |
| # # print(pure_name[i], type_list[i]) | |
| # for i in range(len(nodes_list)): | |
| # if type_list[i] != '': | |
| # for j in range(len(nodes_list)): | |
| # if i != j and type_list[j] != '': | |
| # if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys(): | |
| # # print(f'{type_list[i]}_{type_list[j]}') | |
| # ret_path = [] | |
| # sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j]) | |
| # start = sp[0] | |
| # end = sp[-1] | |
| # for k in range(len(sp)-1): | |
| # e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}'] | |
| # if e1 == start: | |
| # e1 = 'start_entity-x' | |
| # if e2 == start: | |
| # e2 = 'start_entity-x' | |
| # if e1 == end: | |
| # e1 = 'end_entity-x' | |
| # if e2 == end: | |
| # e2 = 'end_entity-x' | |
| # ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower()) | |
| # dependency_P = ' '.join(ret_path) | |
| # DP_list.append((f'{type_list[i]}-{type_list[j]}', | |
| # name_to_meshid[pure_name[i]], | |
| # name_to_meshid[pure_name[j]], | |
| # dependency_P)) | |
| # boo = False | |
| # modified_attack = [] | |
| # for k, ss, tt, dp in DP_list: | |
| # if dp in dependency_to_type_id[k].keys(): | |
| # tp = str(dependency_to_type_id[k][dp]) | |
| # id_ss = str(meshid_to_id[ss]) | |
| # id_tt = str(meshid_to_id[tt]) | |
| # modified_attack.append(f'{id_ss}*{tp}*{id_tt}') | |
| # if int(dependency_to_type_id[k][dp]) == int(r): | |
| # # if id_to_meshid[s] == ss and id_to_meshid[o] == tt: | |
| # boo = True | |
| # modified_attack = list(set(modified_attack)) | |
| # modified_attack = [k.split('*') for k in modified_attack] | |
| # if boo: | |
| # add += 1 | |
| # # else: | |
| # # print(ii) | |
| # # for i in range(len(type_list)): | |
| # # if type_list[i]: | |
| # # print(pure_name[i], type_list[i]) | |
| # # for k, ss, tt, dp in DP_list: | |
| # # print(k, dp) | |
| # # print(record[record_index - 1]) | |
| # # raise Exception('No!!') | |
| # Attack.append(modified_attack) | |
| record = [] | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parseout.txt', 'r') as fl: | |
| Tmp = [] | |
| tmp = [] | |
| for i,line in enumerate(fl.readlines()): | |
| # print(len(line), line) | |
| line = line.replace('\n', '') | |
| if len(line) > 1: | |
| tmp.append(line) | |
| else: | |
| if len(Tmp) == 2: | |
| if len(tmp) == 1 and '/' in tmp[0].split(' ')[0]: | |
| Tmp.append([]) | |
| record.append(Tmp) | |
| Tmp = [] | |
| Tmp.append(tmp) | |
| if len(Tmp) == 2 and tmp[0][:5] != '(ROOT': | |
| print(record[-1][2]) | |
| raise Exception('??') | |
| tmp = [] | |
| if len(Tmp) == 3: | |
| record.append(Tmp) | |
| Tmp = [] | |
| with open(f'generate_abstract/{args.init_mode}{args.reasonable_rate}_{args.mode}_parsein.txt', 'r') as fl: | |
| parsin = fl.readlines() | |
| print('Record len', len(record), 'Parsin len:', len(parsin)) | |
| record_index = 0 | |
| add = 0 | |
| Attack = [] | |
| for ii, (k, v_dict) in enumerate(tqdm(draft.items())): | |
| input = v_dict['in'] | |
| output = v_dict['out'] | |
| output = output.replace('\n', ' ') | |
| s, r, o = attack_data[ii] | |
| s = str(s) | |
| r = str(r) | |
| o = str(o) | |
| assert ii == int(k.split('_')[-1]) | |
| DP_list = [] | |
| if input != '': | |
| dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual'] | |
| target_dp = set() | |
| for dp_path, sen_list in dependency_sen_dict.items(): | |
| target_dp.add(dp_path) | |
| doc = nlp(output) | |
| for sen in doc.sents: | |
| dp_dict = {} | |
| if record_index >= len(record): | |
| break | |
| data = record[record_index] | |
| record_index += 1 | |
| dp_paths = data[2] | |
| nodes_list = [] | |
| edges_list = [] | |
| for line in dp_paths: | |
| aa = line.split('(') | |
| if len(aa) == 1: | |
| print(ii) | |
| print(sen) | |
| print(data) | |
| raise Exception | |
| ttp, tmp = aa[0], aa[1] | |
| assert tmp[-1] == ')' | |
| tmp = tmp[:-1] | |
| e1, e2 = tmp.split(', ') | |
| if not ttp in type_set and ':' in ttp: | |
| ttp = ttp.split(':')[0] | |
| dp_dict[f'{e1}_x_{e2}'] = [e1, ttp, e2] | |
| dp_dict[f'{e2}_x_{e1}'] = [e1, ttp, e2] | |
| nodes_list.append(e1) | |
| nodes_list.append(e2) | |
| edges_list.append((e1, e2)) | |
| nodes_list = list(set(nodes_list)) | |
| pure_name = [('-'.join(name.split('-')[:-1])).replace('_', ' ') for name in nodes_list] | |
| graph = nx.Graph(edges_list) | |
| type_list = [name_to_type[name] if name in good_name else '' for name in pure_name] | |
| # print(type_list) | |
| for i in range(len(nodes_list)): | |
| if type_list[i] != '': | |
| for j in range(len(nodes_list)): | |
| if i != j and type_list[j] != '': | |
| if f'{type_list[i]}-{type_list[j]}' in Parameters.edge_type_to_id.keys(): | |
| # print(f'{type_list[i]}_{type_list[j]}') | |
| ret_path = [] | |
| sp = nx.shortest_path(graph, source=nodes_list[i], target=nodes_list[j]) | |
| start = sp[0] | |
| end = sp[-1] | |
| for k in range(len(sp)-1): | |
| e1, ttp, e2 = dp_dict[f'{sp[k]}_x_{sp[k+1]}'] | |
| if e1 == start: | |
| e1 = 'start_entity-x' | |
| if e2 == start: | |
| e2 = 'start_entity-x' | |
| if e1 == end: | |
| e1 = 'end_entity-x' | |
| if e2 == end: | |
| e2 = 'end_entity-x' | |
| ret_path.append(f'{"-".join(e1.split("-")[:-1])}|{ttp}|{"-".join(e2.split("-")[:-1])}'.lower()) | |
| dependency_P = ' '.join(ret_path) | |
| DP_list.append((f'{type_list[i]}-{type_list[j]}', | |
| name_to_meshid[pure_name[i]], | |
| name_to_meshid[pure_name[j]], | |
| dependency_P)) | |
| boo = False | |
| modified_attack = [] | |
| for k, ss, tt, dp in DP_list: | |
| if dp in dependency_to_type_id[k].keys(): | |
| tp = str(dependency_to_type_id[k][dp]) | |
| id_ss = str(meshid_to_id[ss]) | |
| id_tt = str(meshid_to_id[tt]) | |
| modified_attack.append(f'{id_ss}*{tp}*{id_tt}') | |
| if int(dependency_to_type_id[k][dp]) == int(r): | |
| if id_to_meshid[s] == ss and id_to_meshid[o] == tt: | |
| boo = True | |
| modified_attack = list(set(modified_attack)) | |
| modified_attack = [k.split('*') for k in modified_attack] | |
| if boo: | |
| # print(DP_list) | |
| add += 1 | |
| Attack.append(modified_attack) | |
| print(add) | |
| print('End record_index:', record_index) | |
| final_Attack = Attack | |
| print('Len of Attack:', len(Attack)) | |
| with open(modified_attack_path, 'wb') as fl: | |
| pkl.dump(final_Attack, fl) | |
| else: | |
| raise Exception('Wrong action !!') |