yjwtheonly committed on
Commit
3bdbad7
1 Parent(s): 1943a4d

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +98 -69
server.py CHANGED
@@ -12,6 +12,7 @@ import spacy
12
  # os.system("python -m spacy download en-core-web-sm")
13
  import pickle as pkl
14
  from tqdm import tqdm
 
15
  #%%
16
  # please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
17
  # torch.loa
@@ -55,6 +56,8 @@ model_path = 'DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, mode
55
  data_path = os.path.join('DiseaseSpecific/processed_data', args.data)
56
  data = utils.load_data(os.path.join(data_path, 'all.txt'))
57
 
 
 
58
  n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
59
  with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
60
  filters = pkl.load(fl)
@@ -78,6 +81,8 @@ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
78
  with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
79
  drug_term = pkl.load(fl)
80
 
 
 
81
  gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
82
  gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
83
  gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
@@ -154,6 +159,7 @@ for k, v in filters.items():
154
  t = torch.ByteTensor(tmp).to(args.device)
155
  filters[k][kk] = t
156
 
 
157
  gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
158
  gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
159
  gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
@@ -162,6 +168,7 @@ gpt_model.eval()
162
  specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
163
  specific_model.eval()
164
  divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)
 
165
 
166
  nlp = spacy.load("en_core_web_sm")
167
 
@@ -642,76 +649,93 @@ def generate_agnostic_attack_edge(targets):
642
  specific_model.to('cpu')
643
  return attack_edge_list[0]
644
 
645
- def specific_func(start_entity, end_entity):
646
 
647
- args.reasonable_rate = 0.5
648
- s, r, o = generate_specific_attack_edge(start_entity, end_entity)
649
- if int(s) == -1:
650
- return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
651
- s_name = entity_raw_name[id_to_entity[str(s)]]
652
- r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
653
- o_name = entity_raw_name[id_to_entity[str(o)]]
654
- attack_data = np.array([[s, r, o]])
655
- path_list = []
656
- with open(f'DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
657
- for line in fl.readlines():
658
- line.replace('\n', '')
659
- path_list.append(line)
660
- with open(f'DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
661
- sentence_dict = json.load(fl)
662
- dpath = []
663
- for k, v in sentence_dict.items():
664
- if f'{s}_{r}_{o}' in k:
665
- single_sentence = [v]
666
- dpath = [path_list[int(k.split('_')[-1])]]
667
- break
668
- if len(dpath) == 0:
669
- single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
670
- elif not(s_name in single_sentence[0] and o_name in single_sentence[0]):
671
- single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
672
-
673
- print('Using ChatGPT for generation...')
674
- draft = generate_abstract(single_sentence[0])
675
- if 'sorry' in draft or 'Sorry' in draft:
676
- return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
677
- if device != torch.device('cpu'):
678
- print('Using BioBART for tuning...')
679
- span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
680
- text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
681
- else:
682
- text = draft
683
- return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
 
 
 
 
 
684
  # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
 
686
- def agnostic_func(agnostic_entity):
687
-
688
- args.reasonable_rate = 0.7
689
- target_id = entity_to_id[drug_dict[agnostic_entity]]
690
- s = generate_agnostic_attack_edge([int(target_id)])
691
- if len(s) == 0:
692
- return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
693
- if int(s[0]) == -1:
694
- return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
695
- s, r, o = str(s[0]), str(s[1]), str(s[2])
696
- s_name = entity_raw_name[id_to_entity[str(s)]]
697
- r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
698
- o_name = entity_raw_name[id_to_entity[str(o)]]
699
-
700
- attack_data = np.array([[s, r, o]])
701
- single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
702
-
703
- print('Using ChatGPT for generation...')
704
- draft = generate_abstract(single_sentence[0])
705
- if 'sorry' in draft or 'Sorry' in draft:
706
- return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
707
 
708
- if device != torch.device('cpu'):
709
- print('Using BioBART for tuning...')
710
- span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
711
- text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
712
- else:
713
- text = draft
714
- return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
715
 
716
  def gallery_specific_func(specific_target):
717
  index = gallery_specific_target_dict[specific_target]
@@ -743,7 +767,7 @@ def gallery_agnostic_func(agnostic_target):
743
  with gr.Blocks() as demo:
744
 
745
  with gr.Column():
746
- gr.Markdown("Poison scitific knowledge with Scorpius")
747
 
748
  # with gr.Column():
749
  with gr.Row():
@@ -767,16 +791,21 @@ with gr.Blocks() as demo:
767
  if device == torch.device('cpu'):
768
  gr.Markdown("Since the project is currently running on the CPU, we directly treat the malicious link as equivalent to the poisoning target, to accelerate the generation process.")
769
  specific_generation_button = gr.Button('Poison!')
 
 
770
  with gr.Tab('Target agnostic'):
771
  agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
772
  agnostic_generation_button = gr.Button('Poison!')
 
 
773
  with gr.Column():
774
  gr.Markdown("Generation")
775
  malicisous_link = gr.Textbox(lines=1, label="Malicious link")
776
  # gr.Markdown("Malicious text")
777
  malicious_text = gr.Textbox(label="Malicious text", lines=5)
778
- specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
779
- agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
 
780
  gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicisous_link, malicious_text])
781
  gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicisous_link, malicious_text])
782
 
 
12
  # os.system("python -m spacy download en-core-web-sm")
13
  import pickle as pkl
14
  from tqdm import tqdm
15
+ import traceback
16
  #%%
17
  # please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
18
  # torch.loa
 
56
  data_path = os.path.join('DiseaseSpecific/processed_data', args.data)
57
  data = utils.load_data(os.path.join(data_path, 'all.txt'))
58
 
59
+ print('done')
60
+
61
  n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
62
  with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
63
  filters = pkl.load(fl)
 
81
  with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
82
  drug_term = pkl.load(fl)
83
 
84
+ print('done')
85
+
86
  gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
87
  gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
88
  gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
 
159
  t = torch.ByteTensor(tmp).to(args.device)
160
  filters[k][kk] = t
161
 
162
+ print('done')
163
  gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
164
  gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
165
  gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
 
168
  specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
169
  specific_model.eval()
170
  divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)
171
+ print('done')
172
 
173
  nlp = spacy.load("en_core_web_sm")
174
 
 
649
  specific_model.to('cpu')
650
  return attack_edge_list[0]
651
 
652
def specific_func(start_entity, end_entity, API_key=''):
    """Generate a target-specific malicious link and its malicious abstract.

    Handler for the target-specific "Poison!" button: its two return values
    are displayed in the "Malicious link" and "Malicious text" textboxes.

    Args:
        start_entity: Start entity selected in the UI; forwarded unchanged to
            ``generate_specific_attack_edge``.
        end_entity: End entity selected in the UI; forwarded unchanged to
            ``generate_specific_attack_edge``.
        API_key: Optional OpenAI API key typed by the user. When empty,
            whitespace-only or ``None``, ``generate_abstract`` is called
            without a key so that it uses its own default.

    Returns:
        A ``(link, text)`` tuple of strings. On success ``link`` is
        ``"<Subject> - <Relation> - <Object>"`` and ``text`` is the processed
        abstract. If no usable link or draft exists, a fixed explanatory
        message pair is returned. If any exception is raised, the tuple is
        ``('Error :(', <formatted traceback>)``.
    """
    # Message pair shown whenever no link / abstract can be produced.
    no_result = (
        'All candidate links are filterd out by defender, so no malicious link can be generated',
        'No malicious abstract can be generated',
    )

    try:
        # Side effect: this overwrites the module-wide setting that the
        # pre-computed path/sentence file names below are derived from.
        args.reasonable_rate = 0.5
        s, r, o = generate_specific_attack_edge(start_entity, end_entity)
        if int(s) == -1:
            return no_result

        s_name = entity_raw_name[id_to_entity[str(s)]]
        r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
        o_name = entity_raw_name[id_to_entity[str(o)]]
        attack_data = np.array([[s, r, o]])

        # NOTE(review): the previous code called line.replace('\n', '') and
        # discarded the result, so every entry kept its trailing newline.
        # That behaviour is preserved here; confirm whether the consumers of
        # ``dpath`` expect stripped lines before changing it.
        with open(f'DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
            path_list = fl.readlines()
        with open(f'DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
            sentence_dict = json.load(fl)

        # Look for a pre-computed sentence whose key contains this triplet;
        # the last '_'-separated field of the key indexes into path_list.
        dpath = []
        triplet_key = f'{s}_{r}_{o}'
        for k, v in sentence_dict.items():
            if triplet_key in k:
                single_sentence = [v]
                dpath = [path_list[int(k.split('_')[-1])]]
                break

        # Fall back to the template generator when nothing was cached, or
        # when the cached sentence does not mention both entity names.
        if not dpath or not (s_name in single_sentence[0] and o_name in single_sentence[0]):
            single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

        print('Using ChatGPT for generation...')
        api_key = (API_key or '').strip()  # tolerate None from the UI
        if api_key != '':
            draft = generate_abstract(single_sentence[0], api_key)
        else:
            draft = generate_abstract(single_sentence[0])
        # A draft containing "sorry"/"Sorry" is treated as a failed generation.
        if 'sorry' in draft or 'Sorry' in draft:
            return no_result

        if device != torch.device('cpu'):
            print('Using BioBART for tuning...')
            chat_pair = {'in': single_sentence[0], 'out': draft}
            span, prompt, sen_list, BART_in, Assist = tune_chatgpt([chat_pair], attack_data, dpath)
            text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, chat_pair)
        else:
            # On CPU the tuning stage is skipped and the draft is used as-is.
            text = draft

        link = ' - '.join(capitalize_the_first_letter(name) for name in (s_name, r_name, o_name))
        return link, server_utils.process(text)
    except Exception:
        # Report the failure in the UI instead of crashing the handler.
        # KeyboardInterrupt / SystemExit are deliberately not caught.
        return 'Error :(', traceback.format_exc()
700
+
701
def agnostic_func(agnostic_entity, API_key=''):
    """Generate a target-agnostic malicious link and its malicious abstract.

    Handler for the target-agnostic "Poison!" button: its two return values
    are displayed in the "Malicious link" and "Malicious text" textboxes.

    Args:
        agnostic_entity: Drug selected in the UI; looked up through
            ``drug_dict`` and then ``entity_to_id`` to obtain the target id.
        API_key: Optional OpenAI API key typed by the user. When empty,
            whitespace-only or ``None``, ``generate_abstract`` is called
            without a key so that it uses its own default.

    Returns:
        A ``(link, text)`` tuple of strings. On success ``link`` is
        ``"<Subject> - <Relation> - <Object>"`` and ``text`` is the processed
        abstract. If no usable link or draft exists, a fixed explanatory
        message pair is returned. If any exception is raised, the tuple is
        ``('Error :(', <formatted traceback>)``.
    """
    # Message pair shown whenever no link / abstract can be produced.
    no_result = (
        'All candidate links are filterd out by defender, so no malicious link can be generated',
        'No malicious abstract can be generated',
    )

    try:
        # Side effect: this overwrites the module-wide setting.
        args.reasonable_rate = 0.7
        target_id = entity_to_id[drug_dict[agnostic_entity]]
        edge = generate_agnostic_attack_edge([int(target_id)])
        # An empty edge, or one whose first element is -1, means no link.
        if len(edge) == 0 or int(edge[0]) == -1:
            return no_result

        s, r, o = str(edge[0]), str(edge[1]), str(edge[2])
        s_name = entity_raw_name[id_to_entity[s]]
        r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
        o_name = entity_raw_name[id_to_entity[o]]

        attack_data = np.array([[s, r, o]])
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)

        print('Using ChatGPT for generation...')
        api_key = (API_key or '').strip()  # tolerate None from the UI
        if api_key != '':
            draft = generate_abstract(single_sentence[0], api_key)
        else:
            draft = generate_abstract(single_sentence[0])
        # A draft containing "sorry"/"Sorry" is treated as a failed generation.
        if 'sorry' in draft or 'Sorry' in draft:
            return no_result

        if device != torch.device('cpu'):
            print('Using BioBART for tuning...')
            chat_pair = {'in': single_sentence[0], 'out': draft}
            span, prompt, sen_list, BART_in, Assist = tune_chatgpt([chat_pair], attack_data, dpath)
            text = score_and_select(s, r, o, span, prompt, sen_list, BART_in, Assist, dpath, chat_pair)
        else:
            # On CPU the tuning stage is skipped and the draft is used as-is.
            text = draft

        link = ' - '.join(capitalize_the_first_letter(name) for name in (s_name, r_name, o_name))
        return link, server_utils.process(text)
    except Exception:
        # Report the failure in the UI instead of crashing the handler.
        # KeyboardInterrupt / SystemExit are deliberately not caught.
        return 'Error :(', traceback.format_exc()
 
 
738
 
 
 
 
 
 
 
 
739
 
740
  def gallery_specific_func(specific_target):
741
  index = gallery_specific_target_dict[specific_target]
 
767
  with gr.Blocks() as demo:
768
 
769
  with gr.Column():
770
+ gr.Markdown("Poison medical knowledge with Scorpius")
771
 
772
  # with gr.Column():
773
  with gr.Row():
 
791
  if device == torch.device('cpu'):
792
  gr.Markdown("Since the project is currently running on the CPU, we directly treat the malicious link as equivalent to the poisoning target, to accelerate the generation process.")
793
  specific_generation_button = gr.Button('Poison!')
794
+ gr.Markdown('Please type your openai API key in the textbox below before clicking the **Poison!** button. If the text box is empty, we will use the default API, but the balance of the default API is limited, so the generation may fail. \n We promise that we will not steal your API key in any way. If you still have this concern, please download the source code from **Files**, then use `python CUDA_VISIBLE_DEVICES=0 python server.py` to run the offline version.')
795
+ API_key_specific = gr.Textbox(label="API key")
796
  with gr.Tab('Target agnostic'):
797
  agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
798
  agnostic_generation_button = gr.Button('Poison!')
799
+ gr.Markdown('Please type your openai API key in the textbox below before clicking the **Poison!** button. If the text box is empty, we will use the default API, but the balance of the default API is limited, so the generation may fail. \n We promise that we will not steal your API key in any way. If you still have this concern, please download the source code from **Files**, then use `python CUDA_VISIBLE_DEVICES=0 python server.py` to run the offline version.')
800
+ API_key_agnostic = gr.Textbox(label="API key")
801
  with gr.Column():
802
  gr.Markdown("Generation")
803
  malicisous_link = gr.Textbox(lines=1, label="Malicious link")
804
  # gr.Markdown("Malicious text")
805
  malicious_text = gr.Textbox(label="Malicious text", lines=5)
806
+ specific_generation_button.click(specific_func, inputs=[start_entity, end_entity, API_key_specific], outputs=[malicisous_link, malicious_text])
807
+ agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity, API_key_agnostic], outputs=[malicisous_link, malicious_text])
808
+
809
  gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicisous_link, malicious_text])
810
  gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicisous_link, malicious_text])
811