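# KGEditor / app.py
# Author: ChancesYuan -- commit aaf6558 ("Update app.py"), 12.3 kB.
# (Header recovered from a Hugging Face file-view page; "raw / history / blame"
#  and "No virus" were page controls/badges, not part of the program.)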
import gradio as gr
from collections import defaultdict
from transformers import BertTokenizer, BertForMaskedLM
import jsonlines
import torch
from src.modeling_bert import EXBertForMaskedLM
from higher.patch import monkeypatch as make_functional
# from src.models.one_shot_learner import OneShotLearner
### load KGE models (a plain and an EX variant per task) and the learner checkpoints
def _load_model_pair(checkpoint):
    """Load ``checkpoint`` twice: as a BertForMaskedLM and as an EXBertForMaskedLM."""
    origin = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path=checkpoint)
    extended = EXBertForMaskedLM.from_pretrained(pretrained_model_name_or_path=checkpoint)
    return origin, extended


edit_origin_model, edit_ex_model = _load_model_pair("ChancesYuan/KGEditor_Edit_Test")
add_origin_model, add_ex_model = _load_model_pair("ChancesYuan/KGEditor_Add_Test")

# Learner checkpoints are deserialized directly onto the CPU.
# NOTE(review): these appear to be whole pickled objects (they are invoked like
# functions elsewhere in this file); torch>=2.6 defaults torch.load to
# weights_only=True, which may reject them -- confirm the pinned torch version.
_cpu = torch.device('cpu')
edit_learner = torch.load("./learner_checkpoint/edit/learner_params.pt", map_location=_cpu)
add_learner = torch.load("./learner_checkpoint/add/learner_params.pt", map_location=_cpu)
### init inputs
# Name <-> id and id -> text lookup tables, populated by init_triple_input().
# Being defaultdict(str), a lookup of a missing key returns "" (and inserts it)
# instead of raising KeyError.
ent_name2id, id2ent_name, rel_name2id, id2ent_text, id2rel_text = (
    defaultdict(str) for _ in range(5)
)
# Space-joined original triple -> its corrupted counterpart; missing keys give [].
corrupt_triple = defaultdict(list)
### init tokenizers
# Base uncased BERT tokenizer; the model inputs are tokenized with it.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')
# Tokenizer from the "E-FB15k237" subfolder of zjunlp/KGEditor; its added-token
# table is used to spell out entity tokens by id.
add_tokenizer = BertTokenizer.from_pretrained(
    pretrained_model_name_or_path='zjunlp/KGEditor',
    subfolder="E-FB15k237",
)
def _read_two_column_tsv(path):
    """Yield ``(key, value)`` pairs from a tab-separated UTF-8 text file.

    The trailing newline is removed, blank lines are skipped, and each line is
    split on its first tab only, so a value that itself contains tabs is kept
    intact. A non-blank line without any tab raises ValueError.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.rstrip('\n')
            if not line:
                continue
            key, value = line.split('\t', 1)
            yield key, value


def init_triple_input():
    """Populate the FB15k-237 lookup tables from ./dataset/fb15k237.

    Fills the module-level tables ``ent_name2id``, ``id2ent_name``,
    ``rel_name2id``, ``id2rel_text``, ``id2ent_text`` and ``corrupt_triple``
    in place, and (re)binds these globals:

    * ``rel2token``: relation id -> ``"[RELATION_i]"``, i = line index in relations.txt
    * ``ent2token``: entity id -> ``"[ENTITY_i]"``
    * ``ent2id`` / ``id2ent``: entity id <-> i

    For entities, i is the order in which ids first appear in entity2textlong.txt.
    """
    global ent2token
    global ent2id
    global id2ent
    global rel2token

    # Every line of relations.txt (even a blank one) consumes a relation index,
    # so the [RELATION_i] numbering stays aligned with the file.
    with open("./dataset/fb15k237/relations.txt", "r", encoding="utf-8") as f:
        relations = [line.strip().split('\t')[0] for line in f]
    rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(relations)}

    for ent_id, name in _read_two_column_tsv("./dataset/fb15k237/entity2text.txt"):
        ent_name2id[name] = ent_id
        id2ent_name[ent_id] = name

    for rel_id, name in _read_two_column_tsv("./dataset/fb15k237/relation2text.txt"):
        rel_name2id[name] = rel_id
        id2rel_text[rel_id] = name

    for ent_id, text in _read_two_column_tsv("./dataset/fb15k237/entity2textlong.txt"):
        # Replace literal backslash-n sequences with spaces, then drop any
        # remaining backslashes.
        id2ent_text[ent_id] = text.replace("\\n", " ").replace("\\", "")

    entities = list(id2ent_text)
    ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
    ent2id = {ent: i for i, ent in enumerate(entities)}
    id2ent = dict(enumerate(entities))

    # Each record maps an original triple ("ori", a list of ids) to its
    # corrupted version ("cor"), keyed by the space-joined original triple.
    with jsonlines.open("./dataset/fb15k237/edit_test.jsonl") as reader:
        for record in reader:
            corrupt_triple[" ".join(record["ori"])] = record["cor"]
def _batch_encoding(encoding):
    """Turn an unbatched tokenizer output into a dict of (1, seq_len) tensors.

    Only ``input_ids``, ``attention_mask`` and ``token_type_ids`` are kept.
    """
    return {
        key: torch.tensor(encoding[key]).unsqueeze(dim=0)
        for key in ("input_ids", "attention_mask", "token_type_ids")
    }


def solve(triple, alter_label, edit_task):
    """Build the tokenized inputs for one edit/add request.

    Args:
        triple: ``"head|relation|tail"``. If ``head`` is the literal
            ``"[MASK]"`` the head is queried; otherwise the tail is queried
            and the ``tail`` field is not read. Head/tail are entity names
            (keys of ``ent_name2id``); relation is a relation id (key of
            ``rel2token``). Whitespace around each field is ignored.
        alter_label: entity name of the target entity.
        edit_task: if True, the "original" label is taken from
            ``corrupt_triple`` (slot 0 for head queries, slot 2 for tail
            queries); if False, it is ``alter_label`` itself.

    Returns:
        ``(inputs, cond_inputs, edit_inputs)``: dicts of (1, seq_len) tensors
        holding ``input_ids``, ``attention_mask`` and ``token_type_ids``.

    Raises:
        ValueError: malformed triple, unknown relation or entity name, or no
            corrupted triple recorded for an edit request.
    """
    print(triple, alter_label)

    parts = triple.split("|")
    if len(parts) != 3:
        raise ValueError(f"expected 'head|relation|tail', got {triple!r}")
    h, r, t = (part.strip() for part in parts)

    if r not in rel2token:
        raise ValueError(f"unknown relation {r!r}")

    def entity_id(name):
        # .get() avoids inserting a "" entry into the defaultdict on a miss.
        ent_id = ent_name2id.get(name)
        if ent_id is None or ent_id not in ent2id:
            raise ValueError(f"unknown entity name {name!r}")
        return ent_id

    def corrupted_entity(key_ids, slot):
        # .get() avoids inserting an empty list into the defaultdict on a miss.
        key = " ".join(key_ids)
        candidates = corrupt_triple.get(key)
        if not candidates:
            raise ValueError(f"no corrupted triple recorded for {key!r}")
        return candidates[slot]

    rel_text = id2rel_text[r]
    text_b = rel_text + " " + rel2token[r]

    if h == "[MASK]":
        # Head prediction: the known tail provides the description.
        tail_id = entity_id(t)
        alter_id = entity_id(alter_label)
        text_a = "[MASK]"
        text_c = ent2token[tail_id] + " " + id2ent_text[tail_id]
        origin_label = corrupted_entity([alter_id, r, tail_id], 0) if edit_task else alter_id
        input_text_a = tokenizer.sep_token.join(["[MASK]", rel_text + "[PAD]"])
        input_text_b = "[PAD]" + " " + id2ent_text[tail_id]
    else:
        # Tail prediction: the known head provides the description.
        head_id = entity_id(h)
        alter_id = entity_id(alter_label)
        text_a = ent2token[head_id]
        text_c = "[MASK]" + " " + id2ent_text[head_id]
        origin_label = corrupted_entity([head_id, r, alter_id], 2) if edit_task else alter_id
        input_text_a = "[PAD] "
        input_text_b = tokenizer.sep_token.join(
            [rel_text + "[PAD]", "[MASK]" + " " + id2ent_text[head_id]]
        )

    # Entity tokens are looked up in add_tokenizer's added-token table at
    # offset len(tokenizer).
    base_vocab_size = len(tokenizer)
    cond_inputs_text = "{} >> {} || {}".format(
        add_tokenizer.added_tokens_decoder[ent2id[origin_label] + base_vocab_size],
        add_tokenizer.added_tokens_decoder[ent2id[alter_id] + base_vocab_size],
        input_text_a + input_text_b,
    )

    inputs = tokenizer(
        f"{text_a} [SEP] {text_b} [SEP] {text_c}",
        truncation="longest_first",
        max_length=64,
        padding="longest",
        add_special_tokens=True,
    )
    edit_inputs = tokenizer(
        input_text_a,
        input_text_b,
        truncation="longest_first",
        max_length=64,
        padding="longest",
        add_special_tokens=True,
    )
    cond_inputs = tokenizer(
        cond_inputs_text,
        truncation=True,
        max_length=64,
        padding="max_length",
        add_special_tokens=True,
    )
    return _batch_encoding(inputs), _batch_encoding(cond_inputs), _batch_encoding(edit_inputs)
def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
    """Ask ``learner`` for a parameter update aimed at ``alter_label``.

    Runs ``ex_model`` (switched to eval mode) on ``inputs`` and takes the
    cross-entropy between the entity logits at the last [MASK] position and
    ``alter_label``, an index into the entity logit columns. That loss is
    differentiated w.r.t. every parameter of ``ex_model``. The gradients,
    keyed by parameter name, are passed to ``learner`` together with the last
    row of ``cond_inputs``.

    Returns:
        Whatever ``learner`` returns. Callers in this file use it as a mapping
        from parameter name to an additive update.
    """
    with torch.enable_grad():
        model = ex_model.eval()
        token_ids = inputs['input_ids']
        vocab_logits = model(
            input_ids=token_ids,
            attention_mask=inputs["attention_mask"],
        ).logits

        mask_positions = (token_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
        # NOTE(review): columns 30522:45473 are the vocab range read as entity
        # scores throughout this file; confirm it matches the checkpoint.
        entity_logits = vocab_logits[:, mask_positions, 30522:45473].squeeze(dim=0)

        loss = torch.nn.functional.cross_entropy(
            entity_logits[-1:, :],
            torch.tensor([alter_label]),
            reduction="none",
        ).mean(-1)
        param_names, param_tensors = zip(*ex_model.named_parameters())
        grads = dict(zip(param_names, torch.autograd.grad(loss, param_tensors)))

        # cond_inputs contains padding; only its last row goes to the learner.
        params_dict = learner(
            cond_inputs["input_ids"][-1:],
            cond_inputs["attention_mask"][-1:],
            grads=grads,
        )
    return params_dict
# Logit columns read as entity scores: column j of this slice is mapped back
# to an entity through ``id2ent[j]``.
# NOTE(review): presumably 30522 is the base bert-base-uncased vocab size and
# the remaining columns are the added entity tokens -- confirm vs. checkpoints.
_ENTITY_LOGIT_SLICE = slice(30522, 45473)


def _top_entity_names(entity_logits, k=3):
    """Return the names of the ``k`` highest-scoring entities, best first.

    ``entity_logits`` holds the entity scores of a single [MASK] position,
    either as a 1-D tensor or with a leading dimension of size 1.
    """
    _, top_indices = torch.topk(entity_logits.squeeze(dim=0), k)
    return [id2ent_name[id2ent[index]] for index in top_indices.tolist()]


def _predictions_before_and_after(triple, alter_label, edit_task, origin_model, ex_model, learner):
    """Return the top-3 entity names before and after applying the learned update.

    Args:
        triple: ``"head|relation|tail"`` query with ``[MASK]`` in one slot.
        alter_label: name of the target entity.
        edit_task: forwarded to ``solve`` (True for editing, False for adding).
        origin_model: model used for the "before" prediction.
        ex_model: model whose parameters receive the learner's update.
        learner: callable producing that update (via ``get_logits_orig_params_dict``).

    Returns:
        ``(before, after)``: newline-joined entity names.

    Raises:
        ValueError: if the tokenized query contains no [MASK] token.
    """
    inputs, cond_inputs, _ = solve(triple, alter_label, edit_task=edit_task)
    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
    if mask_idx.numel() == 0:
        raise ValueError("the tokenized query contains no [MASK] token")

    # "Before": the unmodified model. Inference only, so no autograd graph.
    with torch.no_grad():
        origin_logits = origin_model(**inputs).logits[:, :, _ENTITY_LOGIT_SLICE].squeeze()
    origin_top3 = _top_entity_names(origin_logits[mask_idx, :])

    # "After": add the learner's per-parameter update to ex_model's weights.
    # The update itself needs gradients, so it is computed outside no_grad.
    fmodel = make_functional(ex_model).eval()
    params_dict = get_logits_orig_params_dict(
        inputs, cond_inputs, ent2id[ent_name2id[alter_label]], ex_model, learner
    )
    with torch.no_grad():
        updated_logits = fmodel(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            params=[
                params_dict.get(name, 0) + param
                for name, param in ex_model.named_parameters()
            ],
        ).logits[:, :, _ENTITY_LOGIT_SLICE].squeeze()
    updated_top3 = _top_entity_names(updated_logits[mask_idx, :])

    return "\n".join(origin_top3), "\n".join(updated_top3)


def edit_process(edit_input, alter_label):
    """Gradio handler for the E-FB15k237 tab: top-3 predictions before/after editing."""
    return _predictions_before_and_after(
        edit_input, alter_label, edit_task=True,
        origin_model=edit_origin_model, ex_model=edit_ex_model, learner=edit_learner,
    )


def add_process(edit_input, alter_label):
    """Gradio handler for the A-FB15k237 tab: top-3 predictions before/after adding."""
    return _predictions_before_and_after(
        edit_input, alter_label, edit_task=False,
        origin_model=add_origin_model, ex_model=add_ex_model, learner=add_learner,
    )
# Fill the lookup tables before building the UI: every handler, and the
# example caching below, depends on them.
init_triple_input()

with gr.Blocks() as demo:
    gr.Markdown("# KGE Editing")
    # One tab per task.
    with gr.Tabs():
        with gr.TabItem("E-FB15k237"):
            with gr.Row():
                with gr.Column():
                    edit_input = gr.Textbox(label="Input", lines=1, placeholder="Mask triple input")
                    alter_label = gr.Textbox(label="Alter Entity", lines=1, placeholder="Entity Name")
                    edit_button = gr.Button("Edit")
                with gr.Column():
                    origin_output = gr.Textbox(label="Before Edit", lines=3, placeholder="")
                    edit_output = gr.Textbox(label="After Edit", lines=3, placeholder="")
            gr.Examples(
                examples=[
                    ["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"],
                    ["Jay-Z|/people/person/spouse_s./people/marriage/type_of_union|[MASK]", "Sydney Pollack"],
                ],
                inputs=[edit_input, alter_label],
                outputs=[origin_output, edit_output],
                fn=edit_process,
                cache_examples=True,
            )
        with gr.TabItem("A-FB15k237"):
            with gr.Row():
                with gr.Column():
                    add_input = gr.Textbox(label="Input", lines=1, placeholder="New triple input")
                    inductive_entity = gr.Textbox(label="Inductive Entity", lines=1, placeholder="Entity Name")
                    add_button = gr.Button("Add")
                with gr.Column():
                    add_origin_output = gr.Textbox(label="Origin Results", lines=3, placeholder="")
                    add_output = gr.Textbox(label="Add Results", lines=3, placeholder="")
            # cache_examples=True also runs these examples, so this serves as a
            # start-up smoke test of add_process.
            gr.Examples(
                examples=[
                    ["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"],
                    ["Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"],
                ],
                inputs=[add_input, inductive_entity],
                outputs=[add_origin_output, add_output],
                fn=add_process,
                cache_examples=True,
            )
    edit_button.click(fn=edit_process, inputs=[edit_input, alter_label], outputs=[origin_output, edit_output])
    add_button.click(fn=add_process, inputs=[add_input, inductive_entity], outputs=[add_origin_output, add_output])

if __name__ == "__main__":
    demo.launch()