Scorpius_HF / server /server.py
yjwtheonly
server
4c5cb35
raw
history blame
27.4 kB
#%%
import gradio as gr
import time
import sys
import os
import torch
import torch.backends.cudnn as cudnn
import numpy as np
import json
import networkx as nx
import spacy
os.system("python -m spacy download en-core-web-sm==3.6.0")
import pickle as pkl
#%%
from torch.nn.modules.loss import CrossEntropyLoss
from transformers import AutoTokenizer
from transformers import BioGptForCausalLM, BartForConditionalGeneration
import server_utils
sys.path.append("..")
import Parameters
from Openai.chat import generate_abstract
sys.path.append("../DiseaseSpecific")
import utils, attack
from attack import calculate_edge_bound, get_model_loss_without_softmax
# Placeholder for the knowledge-graph embedding model; it is assigned by
# utils.load_model(...) during start-up further down in this module.
specific_model = None


def capitalize_the_first_letter(s):
    """Return *s* with its first character upper-cased and the rest unchanged.

    An empty string is returned as-is. (Indexing with ``s[0]`` raised IndexError
    for ``""``; slicing with ``s[:1]`` yields ``""`` instead and gives the same
    result as before for every non-empty string.)
    """
    return s[:1].upper() + s[1:]
# ---- Command-line arguments and runtime configuration ----------------------
# Parser, attack options and hyper-parameter defaults come from the project's
# DiseaseSpecific/utils module.
parser = utils.get_argument_parser()
parser = utils.add_attack_parameters(parser)
parser.add_argument('--init-mode', type = str, default='single', help = 'How to select target nodes') # 'single' for case study
args = parser.parse_args()
args = utils.set_hyperparams(args)
# Use the GPU when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
args.device = device
args.device1 = device
# With two or more GPUs, split the work: args.device holds the filter tensors and
# the freshly loaded KG model, args.device1 is where BioBART/BioGPT run.
if torch.cuda.device_count() >= 2:
    args.device = "cuda:0"
    args.device1 = "cuda:1"
utils.seed_all(args.seed)
np.set_printoptions(precision=5)
# NOTE(review): presumably disabled for reproducibility alongside seed_all — confirm.
cudnn.benchmark = False
# ---- Paths and preprocessed data -------------------------------------------
# The saved model's file name encodes the model type and its main hyper-parameters.
model_name = '{0}_{1}_{2}_{3}_{4}'.format(args.model, args.embedding_dim, args.input_drop, args.hidden_drop, args.feat_drop)
model_path = '../DiseaseSpecific/saved_models/{0}_{1}.model'.format(args.data, model_name)
data_path = os.path.join('../DiseaseSpecific/processed_data', args.data)
# All knowledge-graph triples; iterated later as (s, r, o).
data = utils.load_data(os.path.join(data_path, 'all.txt'))
n_ent, n_rel, ent_to_id, rel_to_id = utils.generate_dicts(data_path)
# NOTE(review): pickle.load can execute arbitrary code from a crafted file; the
# pickled artifacts below are assumed to be trusted local preprocessing outputs.
with open(os.path.join(data_path, 'filter.pickle'), 'rb') as fl:
    filters = pkl.load(fl)
with open(os.path.join(data_path, 'entityid_to_nodetype.json'), 'r') as fl:
    entityid_to_nodetype = json.load(fl)
with open(os.path.join(data_path, 'edge_nghbrs.pickle'), 'rb') as fl:
    edge_nghbrs = pkl.load(fl)
with open(os.path.join(data_path, 'disease_meshid.pickle'), 'rb') as fl:
    disease_meshid = pkl.load(fl)
# Entity key (e.g. 'chemical_mesh:...') -> entity id.
with open(os.path.join(data_path, 'entities_dict.json'), 'r') as fl:
    entity_to_id = json.load(fl)
# Entity key -> human-readable name.
with open(Parameters.GNBRfile+'entity_raw_name', 'rb') as fl:
    entity_raw_name = pkl.load(fl)
# Entity id (string) -> entity key; id_to_meshid is a separate copy of the same mapping.
with open(os.path.join(data_path, 'entities_reverse_dict.json'), 'r') as fl:
    id_to_entity = json.load(fl)
id_to_meshid = id_to_entity.copy()
# Indexed later as [int(r)]['manual'] -> {dependency path: [(paper_id, sen_id), ...]}.
with open(Parameters.GNBRfile+'retieve_sentence_through_edgetype', 'rb') as fl:
    retieve_sentence_through_edgetype = pkl.load(fl)
# raw_text_sen[paper_id][sen_id] is a record with 'text', 'start_formatted', 'end_formatted'.
with open(Parameters.GNBRfile+'raw_text_of_each_sentence', 'rb') as fl:
    raw_text_sen = pkl.load(fl)
# Drug-name collection, membership-tested with nm.lower() (presumably lower-cased names).
with open(Parameters.UMLSfile+'drug_term', 'rb') as fl:
    drug_term = pkl.load(fl)
# ---- Display-name lookup tables --------------------------------------------
# Map capitalised readable names to entity keys (e.g. 'chemical_mesh:c050048')
# for chemicals and diseases. Names of two characters or fewer are skipped, and
# for duplicate names the last key seen wins.
drug_dict = {}
disease_dict = {}
_name_tables = {'chemical': drug_dict, 'disease': disease_dict}
for k, v in entity_raw_name.items():
    tp = k.split('_')[0]
    v = capitalize_the_first_letter(v)
    table = _name_tables.get(tp)
    if table is not None and len(v) > 2:
        table[v] = k
# Alphabetical dropdown choices.
drug_list = sorted(drug_dict)
disease_list = sorted(disease_dict)
# ---- Filter masks -----------------------------------------------------------
# Replace every filters[k][kk] id collection by a length-n_ent torch.ByteTensor on
# args.device that is 1 at the listed entity ids and 0 elsewhere.
init_mask = np.asarray([0] * n_ent).astype('int64')
init_mask = (init_mask == 1)  # all-False boolean template
for k, v in filters.items():
    for kk, vv in v.items():
        tmp = init_mask.copy()
        idx = np.asarray(vv)
        if idx.size == 0:
            # np.asarray([]) has dtype float64, which NumPy rejects as an index
            # array (IndexError). An empty integer index selects nothing, leaving
            # an all-zero mask, which is what an empty id list means.
            idx = idx.astype(np.int64)
        tmp[idx] = True
        t = torch.ByteTensor(tmp).to(args.device)
        filters[k][kk] = t
# ---- Models ----------------------------------------------------------------
# BioGPT (causal LM) scores candidate texts by mean per-token loss.
gpt_tokenizer = AutoTokenizer.from_pretrained('microsoft/biogpt')
# Pad with EOS so the batched tokenizer calls (padding=True) can pad sequences.
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
gpt_model = BioGptForCausalLM.from_pretrained('microsoft/biogpt', pad_token_id=gpt_tokenizer.eos_token_id)
gpt_model.eval()
# Knowledge-graph embedding model used to score triples; loaded onto args.device.
specific_model = utils.load_model(model_path, args, n_ent, n_rel, args.device)
specific_model.eval()
# Edge-loss statistics over all triples; check_reasonable() uses them to normalise losses.
divide_bound, data_mean, data_std = attack.calculate_edge_bound(data, specific_model, args.device, n_ent)
nlp = spacy.load("en_core_web_sm")
# BioBART fills '<mask>' spans in tune_chatgpt().
bart_model = BartForConditionalGeneration.from_pretrained('GanjinZero/biobart-large')
bart_model.eval()
bart_tokenizer = AutoTokenizer.from_pretrained('GanjinZero/biobart-large')
def tune_chatgpt(draft, attack_data, dpath):
    """Build masked rewrites of a drafted text and regenerate them with BioBART.

    Args:
        draft: list of dicts with 'in' (template sentence) and 'out' (drafted
            multi-sentence text). NOTE(review): callers in this file pass a
            single-element list; with more entries only the last entry's values
            are returned.
        attack_data: sequence of (s, r, o) triples, indexed in parallel with draft.
        dpath: list of dependency-path strings, consumed one per draft entry; each
            is space-separated 'a|b|c' segments where 'start_entity'/'end_entity'
            act as placeholders.

    Returns:
        (span, prompt, Outs, Text, Assist) for the last draft entry:
        span - second value returned by server_utils.find_mini_span;
        prompt - the template sentence with runs of non-kept words replaced by
            1..8 '<mask>' tokens;
        Text - BioBART inputs (first half: draft with sentence j replaced by the
            masked prompt; second half: template inserted at j with the other
            sentences passed through server_utils.mask_func);
        Outs - BioBART generations, one per entry of Text;
        Assist - for each Text entry, the draft with sentence j replaced by the
            unmasked template.

    Side effects: moves bart_model to args.device1, then back to the CPU.
    NOTE(review): the move back is not in a finally block, so an exception leaves
    the model on args.device1.
    """
    dpath_i = 0
    bart_model.to(args.device1)
    for i, v in enumerate(draft):
        input = v['in'].replace('\n', '')
        output = v['out'].replace('\n', '')
        s, r, o = attack_data[i]
        path_text = dpath[dpath_i].replace('\n', '')
        dpath_i += 1
        text_s = entity_raw_name[id_to_meshid[s]]
        text_o = entity_raw_name[id_to_meshid[o]]
        doc = nlp(output)
        words= input.split(' ')
        tokenized_sens = [sen for sen in doc.sents]
        sens = np.array([sen.text for sen in doc.sents])
        # Phrases that must survive masking: both endpoint names plus every
        # concrete (non-placeholder) node on the dependency path.
        checkset = set([text_s, text_o])
        e_entity = set(['start_entity', 'end_entity'])
        for path in path_text.split(' '):
            a, b, c = path.split('|')
            if a not in e_entity:
                checkset.add(a)
            if c not in e_entity:
                checkset.add(c)
        # Greedy longest-match scan over the template words: at each position l,
        # try the longest phrase first; words covered by a phrase in checkset are
        # marked True, other words False.
        vec = []
        l = 0
        while(l < len(words)):
            bo =False
            for j in range(len(words), l, -1): # reversing is important !!!
                cc = ' '.join(words[l:j])
                if (cc in checkset):
                    vec += [True] * (j-l)
                    l = j
                    bo = True
                    break
            if not bo:
                vec.append(False)
                l += 1
        vec, span = server_utils.find_mini_span(vec, words, checkset)
        # vec = np.vectorize(lambda x: x in checkset)(words)
        # Keep the last word so the loop below, which only emits masks when it
        # reaches a kept word, never drops a trailing masked run.
        # NOTE(review): raises IndexError if vec comes back empty.
        vec[-1] = True
        # Collapse each run of unkept words into min(run length, 8) '<mask>' tokens.
        prompt = []
        mask_num = 0
        for j, bo in enumerate(vec):
            if not bo:
                mask_num += 1
            else:
                if mask_num > 0:
                    # mask_num = mask_num // 3 # span length ~ poisson distribution (lambda = 3)
                    mask_num = max(mask_num, 1)
                    mask_num= min(8, mask_num)
                    prompt += ['<mask>'] * mask_num
                prompt.append(words[j])
                mask_num = 0
        prompt = ' '.join(prompt)
        Text = []
        Assist = []
        # Variant 1: sentence j of the draft replaced by the masked prompt.
        for j in range(len(sens)):
            Bart_input = list(sens[:j]) + [prompt] +list(sens[j+1:])
            assist = list(sens[:j]) + [input] +list(sens[j+1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))
        # Variant 2: unmasked template at position j, surrounding sentences masked.
        for j in range(len(sens)):
            Bart_input = server_utils.mask_func(tokenized_sens[:j]) + [input] + server_utils.mask_func(tokenized_sens[j+1:])
            assist = list(sens[:j]) + [input] +list(sens[j+1:])
            Text.append(' '.join(Bart_input))
            Assist.append(' '.join(assist))
        # Regenerate every input with beam search, 8 inputs per batch.
        batch_size = 8
        Outs = []
        for l in range(0, len(Text), batch_size):
            R = min(len(Text), l + batch_size)
            A = bart_tokenizer(Text[l:R],
                               truncation = True,
                               padding = True,
                               max_length = 1024,
                               return_tensors="pt")
            input_ids = A['input_ids'].to(args.device1)
            attention_mask = A['attention_mask'].to(args.device1)
            aaid = bart_model.generate(input_ids, attention_mask = attention_mask, num_beams = 5, max_length = 1024)
            outs = bart_tokenizer.batch_decode(aaid, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            Outs += outs
    bart_model.to('cpu')
    return span, prompt, Outs, Text, Assist
def score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, v):
    """Return the candidate text with the lowest mean per-token BioGPT loss.

    Args:
        s, o: entity ids looked up in id_to_meshid / entity_raw_name.
        r: relation id (not read here; the name is shadowed inside sen_align).
        span, prompt, BART_in, Assist: outputs of tune_chatgpt(); not read here.
        sen_list: BioBART generations from tune_chatgpt(). Its length must be even,
            with the first half holding one entry per sentence of v['out'] (both
            conditions are asserted).
        dpath: list whose first element is the dependency-path string.
        v: dict with 'in' (template sentence) and 'out' (drafted text).

    Every entry of sen_list first goes through server_utils.process. The
    candidates scored are: each first-half generation with its rewritten region
    spliced back into the draft sentences, each second-half generation as is,
    and the draft v['out'] itself.

    Side effects: moves gpt_model to args.device1 and back to the CPU (not in a
    finally block).
    """
    criterion = CrossEntropyLoss(reduction="none")
    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    sen_list = [server_utils.process(text) for text in sen_list]
    path_text = dpath[0].replace('\n', '')
    # NOTE(review): checkset (and `input` below) are built but never read later
    # in this function.
    checkset = set([text_s, text_o])
    e_entity = set(['start_entity', 'end_entity'])
    for path in path_text.split(' '):
        a, b, c = path.split('|')
        if a not in e_entity:
            checkset.add(a)
        if c not in e_entity:
            checkset.add(c)
    input = v['in'].replace('\n', '')
    output = v['out'].replace('\n', '')
    doc = nlp(output)
    gpt_sens = [sen.text for sen in doc.sents]
    assert len(gpt_sens) == len(sen_list) // 2
    # Space-split word set of each draft sentence, used for alignment.
    word_sets = []
    for sen in gpt_sens:
        word_sets.append(set(sen.split(' ')))
    def sen_align(word_sets, modified_word_sets):
        """Locate where a regenerated text diverges from the draft and where it re-syncs.

        Two sentences "match" when more than 80% of the draft sentence's words
        appear in the other. Returns (l, r1, l, r2), where:
        - l is the first index whose sentences do not match;
        - r1 is the first draft position > l that matches some regenerated
          position > l, and r2 is the first such regenerated position;
        - if no such positions exist, r1 and r2 are the list lengths.
        Returns (-1, -1, -1, -1) when all len(modified_word_sets) leading
        sentences match.
        NOTE(review): word_sets[l] raises IndexError if the regenerated text has
        more sentences than the draft and all draft sentences match.
        """
        l = 0
        while(l < len(modified_word_sets)):
            if len(word_sets[l].intersection(modified_word_sets[l])) > len(word_sets[l]) * 0.8:
                l += 1
            else:
                break
        if l == len(modified_word_sets):
            return -1, -1, -1, -1
        r = l + 1
        r1 = None
        r2 = None
        for pos1 in range(r, len(word_sets)):
            for pos2 in range(r, len(modified_word_sets)):
                if len(word_sets[pos1].intersection(modified_word_sets[pos2])) > len(word_sets[pos1]) * 0.8:
                    r1 = pos1
                    r2 = pos2
                    break
            if r1 is not None:
                break
        if r1 is None:
            r1 = len(word_sets)
            r2 = len(modified_word_sets)
        return l, r1, l, r2
    # For each first-half generation, keep draft sentences outside [l1, r1) and
    # take regenerated sentences [l2, r2); fully matching generations are kept as is.
    replace_sen_list = []
    boundary = []
    assert len(sen_list) % 2 == 0
    for j in range(len(sen_list) // 2):
        doc = nlp(sen_list[j])
        sens = [sen.text for sen in doc.sents]
        modified_word_sets = [set(sen.split(' ')) for sen in sens]
        l1, r1, l2, r2 = sen_align(word_sets, modified_word_sets)
        boundary.append((l1, r1, l2, r2))
        if l1 == -1:
            replace_sen_list.append(sen_list[j])
            continue
        check_text = ' '.join(sens[l2: r2])
        replace_sen_list.append(' '.join(gpt_sens[:l1] + [check_text] + gpt_sens[r1:]))
    sen_list = replace_sen_list + sen_list[len(sen_list) // 2:]
    gpt_model.to(args.device1)
    # The unmodified draft is also a candidate.
    sen_list.append(output)
    tokens = gpt_tokenizer( sen_list,
                            truncation = True,
                            padding = True,
                            max_length = 1024,
                            return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)
    L = len(sen_list)
    ret_log_L = []
    # Score in batches of 5. NOTE(review): no torch.no_grad() is visible here, so
    # autograd state is built for each forward pass (only .detach() is applied).
    for l in range(0, L, 5):
        R = min(L, l + 5)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids = target,
                            attention_mask = attention,
                            labels = target)
        logits = outputs.logits
        # Next-token prediction: logits at t are scored against the token at t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        # Per-row mean loss over non-padding positions (the ratio of means cancels the length).
        log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
        ret_log_L.append(log_Loss.detach())
    log_Loss = torch.cat(ret_log_L, -1).cpu().numpy()
    gpt_model.to('cpu')
    p = np.argmin(log_Loss)
    return sen_list[p]
def generate_template_for_triplet(attack_data):
    """Choose a corpus sentence to serve as the template for triple attack_data[0].

    Candidates come from retieve_sentence_through_edgetype[int(r)]['manual'], a
    mapping of dependency path -> list of (paper_id, sen_id). Per dependency path
    at most max(1, 500 // number_of_paths) sentences are sampled with np.random.
    Sentences whose 'start_formatted' and 'end_formatted' placeholders are equal
    are dropped. The placeholders are then replaced by the names of s and o, each
    candidate is scored with BioGPT (mean per-token loss), and the lowest loss wins.

    Returns:
        Four single-element lists:
        - the filled-in sentence (with every '_' turned into a space);
        - the original raw sentence text;
        - its dependency path;
        - a parse variant in which the entity names have spaces replaced by
          underscores.

    Raises:
        ZeroDivisionError if relation r has no dependency paths.
        AssertionError if no candidate sentence survives filtering.

    Side effects: moves gpt_model to args.device1 and back to the CPU.
    """
    criterion = CrossEntropyLoss(reduction="none")
    gpt_model.to(args.device1)
    print('Generating template ...')
    GPT_batch_size = 8
    single_sentence = []
    test_text = []
    test_dp = []
    test_parse = []
    s, r, o = attack_data[0]
    dependency_sen_dict = retieve_sentence_through_edgetype[int(r)]['manual']
    candidate_sen = []
    Dp_path = []
    # Per-path sampling budget so that roughly 500 sentences are scored in total.
    L = len(dependency_sen_dict.keys())
    bound = 500 // L
    if bound == 0:
        bound = 1
    for dp_path, sen_list in dependency_sen_dict.items():
        if len(sen_list) > bound:
            index = np.random.choice(np.array(range(len(sen_list))), bound, replace=False)
            sen_list = [sen_list[aa] for aa in index]
        # Skip sentences whose start and end placeholders are the same string.
        ssen_list = []
        for aa in range(len(sen_list)):
            paper_id, sen_id = sen_list[aa]
            if raw_text_sen[paper_id][sen_id]['start_formatted'] == raw_text_sen[paper_id][sen_id]['end_formatted']:
                continue
            ssen_list.append(sen_list[aa])
        sen_list = ssen_list
        candidate_sen += sen_list
        Dp_path += [dp_path] * len(sen_list)
    text_s = entity_raw_name[id_to_meshid[s]]
    text_o = entity_raw_name[id_to_meshid[o]]
    candidate_text_sen = []
    candidate_ori_sen = []
    candidate_parse_sen = []
    for paper_id, sen_id in candidate_sen:
        sen = raw_text_sen[paper_id][sen_id]
        text = sen['text']
        candidate_ori_sen.append(text)
        ss = sen['start_formatted']
        oo = sen['end_formatted']
        # Undo the tokenizer's bracket escapes (-LRB- -> '(' etc.).
        text = text.replace('-LRB-', '(')
        text = text.replace('-RRB-', ')')
        text = text.replace('-LSB-', '[')
        text = text.replace('-RSB-', ']')
        text = text.replace('-LCB-', '{')
        text = text.replace('-RCB-', '}')
        parse_text = text
        parse_text = parse_text.replace(ss, text_s.replace(' ', '_'))
        parse_text = parse_text.replace(oo, text_o.replace(' ', '_'))
        text = text.replace(ss, text_s)
        text = text.replace(oo, text_o)
        text = text.replace('_', ' ')
        candidate_text_sen.append(text)
        candidate_parse_sen.append(parse_text)
    tokens = gpt_tokenizer( candidate_text_sen,
                            truncation = True,
                            padding = True,
                            max_length = 300,
                            return_tensors="pt")
    target_ids = tokens['input_ids'].to(args.device1)
    attention_mask = tokens['attention_mask'].to(args.device1)
    L = len(candidate_text_sen)
    assert L > 0
    ret_log_L = []
    # NOTE(review): no torch.no_grad() is visible here; outputs are only .detach()ed.
    for l in range(0, L, GPT_batch_size):
        R = min(L, l + GPT_batch_size)
        target = target_ids[l:R, :]
        attention = attention_mask[l:R, :]
        outputs = gpt_model(input_ids = target,
                            attention_mask = attention,
                            labels = target)
        logits = outputs.logits
        # Next-token prediction: logits at t are scored against the token at t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = target[..., 1:].contiguous()
        Loss = criterion(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1))
        Loss = Loss.view(-1, shift_logits.shape[1])
        attention = attention[..., 1:].contiguous()
        # Per-row mean loss over non-padding positions.
        log_Loss = (torch.mean(Loss * attention.float(), dim = 1) / torch.mean(attention.float(), dim = 1))
        ret_log_L.append(log_Loss.detach())
    ret_log_L = list(torch.cat(ret_log_L, -1).cpu().numpy())
    # Rank candidates by loss (ascending) and keep the best one.
    sen_score = list(zip(candidate_text_sen, ret_log_L, candidate_ori_sen, Dp_path, candidate_parse_sen))
    sen_score.sort(key = lambda x: x[1])
    test_text.append(sen_score[0][2])
    test_dp.append(sen_score[0][3])
    test_parse.append(sen_score[0][4])
    single_sentence.append(sen_score[0][0])
    gpt_model.to('cpu')
    return single_sentence, test_text, test_dp, test_parse
# Count entities per type prefix (the part of the key before the first '_').
# Only 'chemical', 'disease' and 'gene' are expected; any other prefix raises KeyError.
meshids = list(id_to_meshid.values())
cal = dict.fromkeys(('chemical', 'disease', 'gene'), 0)
for meshid in meshids:
    entity_type = meshid.split('_')[0]
    cal[entity_type] += 1
def check_reasonable(s, r, o):
    """Judge whether the KG model finds the triple (s, r, o) plausible enough.

    The model's raw loss for the triple is standardised with data_mean/data_std.
    It is then mapped through 1 / (1 + exp(z - divide_bound)), which gives a
    score that shrinks as the loss grows.

    Returns:
        (passes, prob): prob is that score, and passes is prob > 1 - args.reasonable_rate.
    """
    triple = torch.from_numpy(np.asarray([[s, r, o]]).astype('int64')).to(device)
    raw_loss = get_model_loss_without_softmax(triple, specific_model, device).squeeze().item()
    z = (raw_loss - data_mean) / data_std
    prob = 1 / ( 1 + np.exp(z - divide_bound) )
    threshold = 1 - args.reasonable_rate
    return (prob > threshold), prob
# Flatten Parameters.edge_type_to_id into per-edge-id lookups keyed by str(edge id):
# the edge type name, and its reverse flag from Parameters.reverse_mask (flag 1
# means the triple is stored as an o -> s edge when building G below).
edgeid_to_edgetype, edgeid_to_reversemask = {}, {}
for k, id_list in Parameters.edge_type_to_id.items():
    masks_for_type = Parameters.reverse_mask[k]
    for iid, mask in zip(id_list, masks_for_type):
        edgeid_to_edgetype[str(iid)], edgeid_to_reversemask[str(iid)] = k, mask
# ---- Directed KG graph and PageRank ----------------------------------------
# Triples whose relation has reverse mask 1 become o -> s edges; all others s -> o.
# reverse_tot counts the reversed ones.
reverse_tot = 0
G = nx.DiGraph()
for s, r, o in data:
    # Sanity check: the subject's type prefix equals the relation type's part before '-'.
    assert id_to_meshid[s].split('_')[0] == edgeid_to_edgetype[r].split('-')[0]
    if edgeid_to_reversemask[r] == 1:
        reverse_tot += 1
        G.add_edge(int(o), int(s))
    else:
        G.add_edge(int(s), int(o))
print('Page ranking ...')
pagerank_value_1 = nx.pagerank(G, max_iter = 200, tol=1.0e-7)
# Drug dropdown: chemical entities whose lower-cased name is in drug_term. This
# rebinds drug_list, replacing the list derived from drug_dict earlier.
# NOTE(review): drug_dict skips names of <= 2 characters, so such a drug could be
# listed here but missing from drug_dict (KeyError when selected) — confirm.
drug_meshid = set()
_drug_names = set()
for meshid, nm in entity_raw_name.items():
    if nm.lower() not in drug_term:
        continue
    if meshid.split('_')[0] != 'chemical':
        continue
    drug_meshid.add(meshid)
    _drug_names.add(capitalize_the_first_letter(nm))
drug_list = sorted(_drug_names)
# Bucket (node id, PageRank) pairs by entity type in ascending score order.
# Chemicals are bucketed only if they are in drug_meshid; 'merged' receives every
# node and is then cut to its last len - len*3//4 entries (the highest-ranked quarter).
pr = sorted(pagerank_value_1.items(), key = lambda item: item[1])
sorted_rank = { 'chemical' : [],
                'gene' : [],
                'disease': [],
                'merged' : []}
for iid, score in pr:
    tp = id_to_meshid[str(iid)].split('_')[0]
    keep = tp != 'chemical' or id_to_meshid[str(iid)] in drug_meshid
    if keep:
        sorted_rank[tp].append((iid, score))
    sorted_rank['merged'].append((iid, score))
llen = len(sorted_rank['merged'])
sorted_rank['merged'] = sorted_rank['merged'][llen * 3 // 4 : ]
def generate_specific_attack_edge(start_entity, end_entity):
    """Select one triple to add for the (drug, disease) target pair via attack.addition_attack.

    Args:
        start_entity: drug display name (a key of drug_dict).
        end_entity: disease display name (a key of disease_dict).

    Returns:
        (s, r, o): the first triple returned by attack.addition_attack.

    Side effects: moves specific_model to `device` and back to the CPU.
    """
    global specific_model  # NOTE(review): never reassigned here, so `global` is not required.
    specific_model.to(device)
    strat_meshid = drug_dict[start_entity]
    end_meshid = disease_dict[end_entity]
    start_entity = entity_to_id[strat_meshid]
    end_entity = entity_to_id[end_meshid]
    # Relation id '10'; presumably the "treatment" relation named in `ret` below — confirm in Parameters.
    target_data = np.array([[start_entity, '10', end_entity]])
    neighbors = attack.generate_nghbrs(target_data, edge_nghbrs, args)
    # NOTE(review): `ret`, `mean_len` and `score_record` are computed but not used.
    ret = f'Generating malicious link for {strat_meshid}_treatment_{end_meshid}', 'Generation malicious text ...'
    param_optimizer = list(specific_model.named_parameters())
    param_influence = []
    for n,p in param_optimizer:
        param_influence.append(p)
    len_list = []
    for v in neighbors.values():
        len_list.append(len(v))
    mean_len = np.mean(len_list)
    attack_trip, score_record = attack.addition_attack(param_influence, args.device, n_rel, data, target_data, neighbors, specific_model, filters, entityid_to_nodetype, args.attack_batch_size, args, load_Record = args.load_existed, divide_bound = divide_bound, data_mean = data_mean, data_std = data_std, cache_intermidiate = False)
    s, r, o = attack_trip[0]
    specific_model.to('cpu')
    return s, r, o
def generate_agnostic_attack_edge(targets):
    """Pick link(s) to attach to each target node, ranked by PageRank and KG loss.

    For each target:
    - Candidate nodes come from sorted_rank['merged'] (the top PageRank quarter)
      and must have no existing G edge to the target.
    - For each candidate node, the relation with the lowest model loss is picked
      among the relation types that connect the node's type with 'chemical'.
    - The resulting triple is kept only if check_reasonable() accepts it.
    - Kept candidates are ranked by (normalised PageRank / (out_degree + 1))
      multiplied by a softmax over the negative losses.

    Args:
        targets: list of entity ids (agnostic_func passes [int(target_id)]).

    Returns:
        Only the result for targets[0]. When args.added_edge_num is '' or 1, this
        is the best (s, r, o) triple, or (-1, -1, -1) if there is none. Otherwise
        it is a list of the top args.added_edge_num triples, or [] if there are none.

    Side effects: moves specific_model to `device` and back to the CPU.
    """
    specific_model.to(device)
    attack_edge_list = []
    for target in targets:
        candidate_list = []
        score_list = []
        loss_list = []
        main_dict = {}  # NOTE(review): unused
        for iid, score in sorted_rank['merged']:
            # Skip nodes already linked to the target; afterwards a is always 1.
            a = G.number_of_edges(iid, target) + 1
            if a != 1:
                continue
            b = G.out_degree(iid) + 1
            tp = id_to_meshid[str(iid)].split('_')[0]
            edge_losses = []
            r_list = []
            # NOTE(review): assumes edge ids are exactly 0..len-1 (str(r) lookups).
            for r in range(len(edgeid_to_edgetype)):
                r_tp = edgeid_to_edgetype[str(r)]
                if (edgeid_to_reversemask[str(r)] == 0 and r_tp.split('-')[0] == tp and r_tp.split('-')[1] == 'chemical'):
                    train_trip = np.array([[iid, r, target]])
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
                elif(edgeid_to_reversemask[str(r)] == 1 and r_tp.split('-')[0] == 'chemical' and r_tp.split('-')[1] == tp):
                    # NOTE(review): loss is scored on (iid, r, target), while the
                    # reverse branch below checks and stores (target, r, iid) — confirm intended.
                    train_trip = np.array([[iid, r, target]]) # add batch dim
                    train_trip = torch.from_numpy(train_trip.astype('int64')).to(device)
                    edge_loss = get_model_loss_without_softmax(train_trip, specific_model, device).squeeze()
                    edge_losses.append(edge_loss.unsqueeze(0).detach())
                    r_list.append(r)
            if len(edge_losses)==0:
                continue
            min_index = torch.argmin(torch.cat(edge_losses, dim = 0))
            r = r_list[min_index]
            r_tp = edgeid_to_edgetype[str(r)]
            old_len = len(candidate_list)
            if (edgeid_to_reversemask[str(r)] == 0):
                bo, prob = check_reasonable(iid, r, target)
                if bo:
                    candidate_list.append((iid, r, target))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
            if (edgeid_to_reversemask[str(r)] == 1):
                bo, prob = check_reasonable(target, r, iid)
                if bo:
                    candidate_list.append((target, r, iid))
                    score_list.append(score * a / b)
                    loss_list.append(edge_losses[min_index].item())
        if len(candidate_list) == 0:
            if args.added_edge_num == '' or int(args.added_edge_num) == 1:
                attack_edge_list.append((-1,-1,-1))
            else:
                attack_edge_list.append([])
            continue
        # Combined score: normalised structural score times softmax(-loss).
        norm_score = np.array(score_list) / np.sum(score_list)
        norm_loss = np.exp(-np.array(loss_list)) / np.sum(np.exp(-np.array(loss_list)))
        total_score = norm_score * norm_loss
        total_score_index = list(zip(range(len(total_score)), total_score))
        total_score_index.sort(key = lambda x: x[1], reverse = True)
        total_index = np.argsort(total_score)[::-1]
        assert total_index[0] == total_score_index[0][0]
        # find rank of main index
        max_index = np.argmax(total_score)
        assert max_index == total_score_index[0][0]
        tmp_add = []
        add_num = 1
        if args.added_edge_num == '' or int(args.added_edge_num) == 1:
            attack_edge_list.append(candidate_list[max_index])
        else:
            # NOTE(review): IndexError if fewer than add_num candidates survive.
            add_num = int(args.added_edge_num)
            for i in range(add_num):
                tmp_add.append(candidate_list[total_score_index[i][0]])
            attack_edge_list.append(tmp_add)
    specific_model.to('cpu')
    return attack_edge_list[0]
def specific_func(start_entity, end_entity):
    """Gradio handler for the 'Target specific' tab.

    Args:
        start_entity: drug display name chosen in the dropdown.
        end_entity: disease display name chosen in the dropdown.

    Returns:
        (link_text, abstract_text) for the two output boxes. When no link is
        produced, two explanatory messages are returned instead.
    """
    # NOTE(review): mutates the shared args object; agnostic_func sets 0.7, so
    # concurrent requests from both tabs could interfere.
    args.reasonable_rate = 0.5
    s, r, o = generate_specific_attack_edge(start_entity, end_entity)
    if int(s) == -1:
        return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])
    # Load cached dependency paths and template sentences for this reasonable_rate.
    path_list = []
    with open(f'../DiseaseSpecific/generate_abstract/path/random_{args.reasonable_rate}_path.json', 'r') as fl:
        for line in fl.readlines():
            # NOTE(review): the result of replace() is discarded, so each line keeps
            # its '\n'; the consumers below strip newlines themselves.
            line.replace('\n', '')
            path_list.append(line)
    with open(f'../DiseaseSpecific/generate_abstract/random_{args.reasonable_rate}_sentence.json', 'r') as fl:
        sentence_dict = json.load(fl)
    # Reuse a cached template whose key contains 's_r_o'; the key's last '_' field
    # indexes path_list. NOTE(review): substring matching could also hit keys for
    # other ids (e.g. '1_2_3' inside '11_2_34_...') — confirm the key format.
    dpath = []
    for k, v in sentence_dict.items():
        if f'{s}_{r}_{o}' in k:
            single_sentence = [v]
            dpath = [path_list[int(k.split('_')[-1])]]
            break
    # Fall back to selecting a template from the corpus when there is no cached
    # template, or when it does not mention both entity names.
    if len(dpath) == 0:
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    elif not(s_name in single_sentence[0] and o_name in single_sentence[0]):
        single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
    # f'The sentence is: {single_sentence[0]}\n The path is: {dpath[0]}'
def agnostic_func(agnostic_entity):
    """Gradio handler for the 'Target agnostic' tab.

    Args:
        agnostic_entity: drug display name chosen in the dropdown (a key of drug_dict).

    Returns:
        (link_text, abstract_text) for the two output boxes. When no link is
        produced, two explanatory messages are returned instead.
    """
    # NOTE(review): mutates the shared args object (specific_func sets 0.5).
    args.reasonable_rate = 0.7
    target_id = entity_to_id[drug_dict[agnostic_entity]]
    s = generate_agnostic_attack_edge([int(target_id)])
    # [] and (-1, -1, -1) are the two "no candidate" results.
    # NOTE(review): when args.added_edge_num > 1 a list of triples is returned,
    # so int(s[0]) below would fail on a tuple — confirm that mode is unused here.
    if len(s) == 0:
        return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    if int(s[0]) == -1:
        return 'All candidate links are filterd out by defender, so no malicious link can be generated', 'No malicious abstract can be generated'
    s, r, o = str(s[0]), str(s[1]), str(s[2])
    s_name = entity_raw_name[id_to_entity[str(s)]]
    r_name = Parameters.edge_id_to_type[int(r)].split(':')[1]
    o_name = entity_raw_name[id_to_entity[str(o)]]
    attack_data = np.array([[s, r, o]])
    single_sentence, _, dpath, _ = generate_template_for_triplet(attack_data)
    print('Using ChatGPT for generation...')
    draft = generate_abstract(single_sentence[0])
    print('Using BioBART for tuning...')
    span , prompt , sen_list, BART_in, Assist = tune_chatgpt([{'in':single_sentence[0], 'out': draft}], attack_data, dpath)
    text = score_and_select(s, r, o, span , prompt , sen_list, BART_in, Assist, dpath, {'in':single_sentence[0], 'out': draft})
    return f'{capitalize_the_first_letter(s_name)} - {capitalize_the_first_letter(r_name)} - {capitalize_the_first_letter(o_name)}', server_utils.process(text)
#%%
# ---- Gradio UI -------------------------------------------------------------
# Two input tabs (specific drug/disease pair, or drug only) share one pair of
# output text boxes.
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("Poison scitific knowledge with Scorpius")
        # with gr.Column():
        with gr.Row():
            # Center
            with gr.Column():
                gr.Markdown("Select your poison target")
                with gr.Tab('Target specific'):
                    with gr.Column():
                        with gr.Row():
                            start_entity = gr.Dropdown(drug_list, label="Promoting drug")
                            end_entity = gr.Dropdown(disease_list, label="Target disease")
                        specific_generation_button = gr.Button('Poison!')
                with gr.Tab('Target agnostic'):
                    agnostic_entity = gr.Dropdown(drug_list, label="Promoting drug")
                    agnostic_generation_button = gr.Button('Poison!')
            with gr.Column():
                gr.Markdown("Malicious link")
                malicisous_link = gr.Textbox(lines=1, label="Malicious link")
                gr.Markdown("Malicious text")
                malicious_text = gr.Textbox(label="Malicious text", lines=5)
    # Both handlers return (link_text, abstract_text) for the shared output boxes.
    specific_generation_button.click(specific_func, inputs=[start_entity, end_entity], outputs=[malicisous_link, malicious_text])
    agnostic_generation_button.click(agnostic_func, inputs=[agnostic_entity], outputs=[malicisous_link, malicious_text])
# Binds to 0.0.0.0 (all interfaces) on port 8000.
demo.launch(server_name="0.0.0.0", server_port=8000, debug=False)