Spaces:
Runtime error
Runtime error
yjwtheonly
commited on
Commit
•
3bdbad7
1
Parent(s):
1943a4d
Update server.py
Browse files
server.py
CHANGED
@@ -12,6 +12,7 @@ import spacy
|
|
12 |
# os.system("python -m spacy download en-core-web-sm")
|
13 |
import pickle as pkl
|
14 |
from tqdm import tqdm
|
|
|
15 |
#%%
|
16 |
# please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
17 |
# torch.loa
|
@@ -55,6 +56,8 @@ model_path = 'DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, mode
|
|
55 |
data_path = os.path.join('DiseaseSpecific/processed_data', args.data)
|
56 |
data = utils.load_data(os.path.join(data_path, 'all.txt'))
|
57 |
|
|
|
|
|
58 |
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
|
59 |
with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
|
60 |
filters = pkl.load(fl)
|
@@ -78,6 +81,8 @@ with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
|
|
78 |
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
|
79 |
drug_term = pkl.load(fl)
|
80 |
|
|
|
|
|
81 |
gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
|
82 |
gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
|
83 |
gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
|
@@ -154,6 +159,7 @@ for k, v in filters.items():
|
|
154 |
t = torch.ByteTensor(tmp).to(args.device)
|
155 |
filters[k][kk] = t
|
156 |
|
|
|
157 |
gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
|
158 |
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
|
159 |
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
|
@@ -162,6 +168,7 @@ gpt_model.eval()
|
|
162 |
specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
|
163 |
specific_model.eval()
|
164 |
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)
|
|
|
165 |
|
166 |
nlp = spacy.load("en_core_web_sm")
|
167 |
|
@@ -642,76 +649,93 @@ def generate_agnostic_attack_edge(targets):
|
|
642 |
specific_model.to('cpu')
|
643 |
return attack_edge_list[0]
|
644 |
|
645 |
-
def specific_func(start_entity, end_entity):
|
646 |
|
647 |
-
|
648 |
-
|
649 |
-
|
650 |
-
|
651 |
-
|
652 |
-
|
653 |
-
|
654 |
-
|
655 |
-
|
656 |
-
|
657 |
-
|
658 |
-
line.
|
659 |
-
|
660 |
-
|
661 |
-
|
662 |
-
|
663 |
-
|
664 |
-
|
665 |
-
|
666 |
-
|
667 |
-
|
668 |
-
|
669 |
-
|
670 |
-
|
671 |
-
single_sentence
|
672 |
-
|
673 |
-
|
674 |
-
|
675 |
-
|
676 |
-
|
677 |
-
|
678 |
-
|
679 |
-
|
680 |
-
|
681 |
-
|
682 |
-
|
683 |
-
|
|
|
|
|
|
|
|
|
|
|
684 |
# f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
685 |
|
686 |
-
|
687 |
-
|
688 |
-
|
689 |
-
|
690 |
-
|
691 |
-
|
692 |
-
|
693 |
-
|
694 |
-
|
695 |
-
|
696 |
-
|
697 |
-
|
698 |
-
|
699 |
-
|
700 |
-
|
701 |
-
|
702 |
-
|
703 |
-
|
704 |
-
|
705 |
-
if 'sorry' in draft or 'Sorry' in draft:
|
706 |
-
return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
|
707 |
|
708 |
-
if device != torch.device('cpu'):
|
709 |
-
print('Using BioBART for tuning...')
|
710 |
-
span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
|
711 |
-
text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
|
712 |
-
else:
|
713 |
-
text = draft
|
714 |
-
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
715 |
|
716 |
def gallery_specific_func(specific_target):
|
717 |
index = gallery_specific_target_dict[specific_target]
|
@@ -743,7 +767,7 @@ def gallery_agnostic_func(agnostic_target):
|
|
743 |
with gr.Blocks() as demo:
|
744 |
|
745 |
with gr.Column():
|
746 |
-
gr.Markdown("Poison
|
747 |
|
748 |
# with gr.Column():
|
749 |
with gr.Row():
|
@@ -767,16 +791,21 @@ with gr.Blocks() as demo:
|
|
767 |
if device == torch.device('cpu'):
|
768 |
gr.Markdown("Since the project is currently running on the CPU, we directly treat the malicious link as equivalent to the poisoning target, to accelerate the generation process.")
|
769 |
specific_generation_button = gr.Button('Poison!')
|
|
|
|
|
770 |
with gr.Tab('Target agnostic'):
|
771 |
agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
|
772 |
agnostic_generation_button = gr.Button('Poison!')
|
|
|
|
|
773 |
with gr.Column():
|
774 |
gr.Markdown("Generation")
|
775 |
malicisous_link = gr.Textbox(lines=1, label="Malicious link")
|
776 |
# gr.Markdown("Malicious text")
|
777 |
malicious_text = gr.Textbox(label="Malicious text", lines=5)
|
778 |
-
specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
|
779 |
-
agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
|
|
|
780 |
gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicisous_link, malicious_text])
|
781 |
gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicisous_link, malicious_text])
|
782 |
|
|
|
12 |
# os.system("python -m spacy download en-core-web-sm")
|
13 |
import pickle as pkl
|
14 |
from tqdm import tqdm
|
15 |
+
import traceback
|
16 |
#%%
|
17 |
# please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
18 |
# torch.loa
|
|
|
56 |
data_path = os.path.join('DiseaseSpecific/processed_data', args.data)
|
57 |
data = utils.load_data(os.path.join(data_path, 'all.txt'))
|
58 |
|
59 |
+
print('done')
|
60 |
+
|
61 |
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
|
62 |
with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
|
63 |
filters = pkl.load(fl)
|
|
|
81 |
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
|
82 |
drug_term = pkl.load(fl)
|
83 |
|
84 |
+
print('done')
|
85 |
+
|
86 |
gallery_specific_target_path = os.path.join(data_path, 'DD_target_distmult_GNBR_random_50_exists:False_single.txt')
|
87 |
gallery_specific_link_path = 'DiseaseSpecific/attack_results/GNBR/cos_distmult_random_50_exists:False_20_quadratic_single_0.5.txt'
|
88 |
gallery_specific_text_path = 'DiseaseSpecific/generate_abstract/random_0.5_bioBART_finetune.json'
|
|
|
159 |
t = torch.ByteTensor(tmp).to(args.device)
|
160 |
filters[k][kk] = t
|
161 |
|
162 |
+
print('done')
|
163 |
gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
|
164 |
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
|
165 |
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
|
|
|
168 |
specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
|
169 |
specific_model.eval()
|
170 |
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)
|
171 |
+
print('done')
|
172 |
|
173 |
nlp = spacy.load("en_core_web_sm")
|
174 |
|
|
|
649 |
specific_model.to('cpu')
|
650 |
return attack_edge_list[0]
|
651 |
|
652 |
+
def specific_func(start_entity, end_entity, API_key = ''):
|
653 |
|
654 |
+
try:
|
655 |
+
args.reasonable_rate = 0.5
|
656 |
+
s, r, o = generate_specific_attack_edge(start_entity, end_entity)
|
657 |
+
if int(s) == -1:
|
658 |
+
return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
|
659 |
+
s_name = entity_raw_name[id_to_entity[str(s)]]
|
660 |
+
r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
|
661 |
+
o_name = entity_raw_name[id_to_entity[str(o)]]
|
662 |
+
attack_data = np.array([[s, r, o]])
|
663 |
+
path_list = []
|
664 |
+
with open(f'DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
|
665 |
+
for line in fl.readlines():
|
666 |
+
line.replace('\n', '')
|
667 |
+
path_list.append(line)
|
668 |
+
with open(f'DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
|
669 |
+
sentence_dict = json.load(fl)
|
670 |
+
dpath = []
|
671 |
+
for k, v in sentence_dict.items():
|
672 |
+
if f'{s}_{r}_{o}' in k:
|
673 |
+
single_sentence = [v]
|
674 |
+
dpath = [path_list[int(k.split('_')[-1])]]
|
675 |
+
break
|
676 |
+
if len(dpath) == 0:
|
677 |
+
single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
|
678 |
+
elif not(s_name in single_sentence[0] and o_name in single_sentence[0]):
|
679 |
+
single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
|
680 |
+
|
681 |
+
print('Using ChatGPT for generation...')
|
682 |
+
API_key = API_key.strip()
|
683 |
+
if API_key != '':
|
684 |
+
draft = generate_abstract(single_sentence[0], API_key)
|
685 |
+
else:
|
686 |
+
draft = generate_abstract(single_sentence[0])
|
687 |
+
if 'sorry' in draft or 'Sorry' in draft:
|
688 |
+
return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
|
689 |
+
if device != torch.device('cpu'):
|
690 |
+
print('Using BioBART for tuning...')
|
691 |
+
span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
|
692 |
+
text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
|
693 |
+
else:
|
694 |
+
text = draft
|
695 |
+
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
696 |
# f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
|
697 |
+
except:
|
698 |
+
# return message in error
|
699 |
+
return 'Error :(', traceback.format_exc()
|
700 |
+
|
701 |
+
def agnostic_func(agnostic_entity, API_key = ''):
|
702 |
+
|
703 |
+
try:
|
704 |
+
args.reasonable_rate = 0.7
|
705 |
+
target_id = entity_to_id[drug_dict[agnostic_entity]]
|
706 |
+
s = generate_agnostic_attack_edge([int(target_id)])
|
707 |
+
if len(s) == 0:
|
708 |
+
return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
|
709 |
+
if int(s[0]) == -1:
|
710 |
+
return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
|
711 |
+
s, r, o = str(s[0]), str(s[1]), str(s[2])
|
712 |
+
s_name = entity_raw_name[id_to_entity[str(s)]]
|
713 |
+
r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
|
714 |
+
o_name = entity_raw_name[id_to_entity[str(o)]]
|
715 |
+
|
716 |
+
attack_data = np.array([[s, r, o]])
|
717 |
+
single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
|
718 |
|
719 |
+
print('Using ChatGPT for generation...')
|
720 |
+
API_key = API_key.strip()
|
721 |
+
if API_key != '':
|
722 |
+
draft = generate_abstract(single_sentence[0], API_key)
|
723 |
+
else:
|
724 |
+
draft = generate_abstract(single_sentence[0])
|
725 |
+
if 'sorry' in draft or 'Sorry' in draft:
|
726 |
+
return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
|
727 |
+
|
728 |
+
if device != torch.device('cpu'):
|
729 |
+
print('Using BioBART for tuning...')
|
730 |
+
span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
|
731 |
+
text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
|
732 |
+
else:
|
733 |
+
text = draft
|
734 |
+
return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
|
735 |
+
except:
|
736 |
+
# return message in error
|
737 |
+
return 'Error :(', traceback.format_exc()
|
|
|
|
|
738 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
739 |
|
740 |
def gallery_specific_func(specific_target):
|
741 |
index = gallery_specific_target_dict[specific_target]
|
|
|
767 |
with gr.Blocks() as demo:
|
768 |
|
769 |
with gr.Column():
|
770 |
+
gr.Markdown("Poison medical knowledge with Scorpius")
|
771 |
|
772 |
# with gr.Column():
|
773 |
with gr.Row():
|
|
|
791 |
if device == torch.device('cpu'):
|
792 |
gr.Markdown("Since the project is currently running on the CPU, we directly treat the malicious link as equivalent to the poisoning target, to accelerate the generation process.")
|
793 |
specific_generation_button = gr.Button('Poison!')
|
794 |
+
gr.Markdown('Please type your openai API key in the textbox below before clicking the **Poison!** button. If the text box is empty, we will use the default API, but the balance of the default API is limited, so the generation may fail. \n We promise that we will not steal your API key in any way. If you still have this concern, please download the source code from **Files**, then use `python CUDA_VISIBLE_DEVICES=0 python server.py` to run the offline version.')
|
795 |
+
API_key_specific = gr.Textbox(label="API key")
|
796 |
with gr.Tab('Target agnostic'):
|
797 |
agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
|
798 |
agnostic_generation_button = gr.Button('Poison!')
|
799 |
+
gr.Markdown('Please type your openai API key in the textbox below before clicking the **Poison!** button. If the text box is empty, we will use the default API, but the balance of the default API is limited, so the generation may fail. \n We promise that we will not steal your API key in any way. If you still have this concern, please download the source code from **Files**, then use `python CUDA_VISIBLE_DEVICES=0 python server.py` to run the offline version.')
|
800 |
+
API_key_agnostic = gr.Textbox(label="API key")
|
801 |
with gr.Column():
|
802 |
gr.Markdown("Generation")
|
803 |
malicisous_link = gr.Textbox(lines=1, label="Malicious link")
|
804 |
# gr.Markdown("Malicious text")
|
805 |
malicious_text = gr.Textbox(label="Malicious text", lines=5)
|
806 |
+
specific_generation_button.click(specific_func, inputs=[start_entity, end_entity, API_key_specific], outputs=[malicisous_link, malicious_text])
|
807 |
+
agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity, API_key_agnostic], outputs=[malicisous_link, malicious_text])
|
808 |
+
|
809 |
gallery_specific_generation_button.click(gallery_specific_func, inputs=[specific_target], outputs=[malicisous_link, malicious_text])
|
810 |
gallery_agnostic_generation_button.click(gallery_agnostic_func, inputs=[agnostic_target], outputs=[malicisous_link, malicious_text])
|
811 |
|