ChancesYuan committed c32018d (parent: 9339e05)
Update app.py

app.py CHANGED
@@ -5,7 +5,6 @@ import jsonlines
 import torch
 from src.modeling_bert import EXBertForMaskedLM
 from higher.patch import monkeypatch as make_functional
-# from src.models.one_shot_learner import OneShotLearner
 
 ### load KGE model
 edit_origin_model = BertForMaskedLM.from_pretrained(pretrained_model_name_or_path="ChancesYuan/KGEditor_Edit_Test")
@@ -23,7 +22,6 @@ id2ent_name = defaultdict(str)
 rel_name2id = defaultdict(str)
 id2ent_text = defaultdict(str)
 id2rel_text = defaultdict(str)
-corrupt_triple = defaultdict(list)
 
 ### init tokenizer
 tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
@@ -34,6 +32,7 @@ def init_triple_input():
     global ent2id
     global id2ent
     global rel2token
+    global rel2id
 
     with open("./dataset/fb15k237/relations.txt", "r") as f:
         lines = f.readlines()
@@ -65,10 +64,10 @@ def init_triple_input():
     ent2id = {ent: i for i, ent in enumerate(entities)}
     id2ent = {i: ent for i, ent in enumerate(entities)}
 
-
-
-    for
-
+    rel2id = {
+        w: i + len(entities)
+        for i, w in enumerate(rel2token.keys())
+    }
 
 def solve(triple, alter_label, edit_task):
     print(triple, alter_label)
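The new rel2id block gives relations indices that continue after the entity indices; elsewhere in the file these local indices are shifted by 30522 (the vocabulary size of bert-base-uncased) to address the KG tokens appended to the tokenizer, and logits are sliced with [:, :, 30522:45473] to score only those appended tokens. A minimal sketch of that indexing, using made-up entity and relation lists rather than the real FB15k-237 vocabulary:

# Sketch of the indexing scheme implied by the diff (toy data, not the real FB15k-237 vocab).
BERT_VOCAB_SIZE = 30522  # size of bert-base-uncased; appended KG tokens start here

entities = ["/m/01", "/m/02", "/m/03"]        # placeholder entity IDs
relations = ["/people/person/profession"]     # placeholder relation IDs

ent2id = {ent: i for i, ent in enumerate(entities)}
rel2id = {rel: i + len(entities) for i, rel in enumerate(relations)}

# Local KG index -> position in the extended tokenizer vocabulary.
def to_vocab_id(local_id: int) -> int:
    return local_id + BERT_VOCAB_SIZE

print(to_vocab_id(ent2id["/m/02"]))                      # 30523
print(to_vocab_id(rel2id["/people/person/profession"]))  # 30525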
@@ -77,13 +76,12 @@ def solve(triple, alter_label, edit_task):
         text_a = "[MASK]"
         text_b = id2rel_text[r] + " " + rel2token[r]
         text_c = ent2token[ent_name2id[t]] + " " + id2ent_text[ent_name2id[t]]
-
+        replace_token = [rel2id[r], ent2id[ent_name2id[t]]]
     else:
         text_a = ent2token[ent_name2id[h]]
-        # text_b = id2rel_text[r] + "[PAD]"
         text_b = id2rel_text[r] + " " + rel2token[r]
         text_c = "[MASK]" + " " + id2ent_text[ent_name2id[h]]
-
+        replace_token = [ent2id[ent_name2id[h]], rel2id[r]]
 
     if text_a == "[MASK]":
         input_text_a = tokenizer.sep_token.join(["[MASK]", id2rel_text[r] + "[PAD]"])
@@ -91,12 +89,6 @@ def solve(triple, alter_label, edit_task):
     else:
         input_text_a = "[PAD] "
         input_text_b = tokenizer.sep_token.join([id2rel_text[r] + "[PAD]", "[MASK]" + " " + id2ent_text[ent_name2id[h]]])
-
-    cond_inputs_text = "{} >> {} || {}".format(
-        add_tokenizer.added_tokens_decoder[ent2id[origin_label] + len(tokenizer)],
-        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[alter_label]] + len(tokenizer)],
-        input_text_a + input_text_b
-    )
 
     inputs = tokenizer(
         f"{text_a} [SEP] {text_b} [SEP] {text_c}",
@@ -115,14 +107,6 @@ def solve(triple, alter_label, edit_task):
         add_special_tokens=True,
     )
 
-    cond_inputs = tokenizer(
-        cond_inputs_text,
-        truncation=True,
-        max_length=64,
-        padding="max_length",
-        add_special_tokens=True,
-    )
-
     inputs = {
         "input_ids": torch.tensor(inputs["input_ids"]).unsqueeze(dim=0),
         "attention_mask": torch.tensor(inputs["attention_mask"]).unsqueeze(dim=0),
@@ -135,13 +119,46 @@ def solve(triple, alter_label, edit_task):
         "token_type_ids": torch.tensor(edit_inputs["token_type_ids"]).unsqueeze(dim=0)
     }
 
+    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
+    logits = edit_origin_model(**inputs).logits[:, :, 30522:45473].squeeze() if edit_task else add_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
+    logits = logits[mask_idx, :]
+
+    ### origin output
+    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
+    origin_entity_order = origin_entity_order.squeeze(dim=0)
+    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+
+    origin_label = origin_top3[0] if edit_task else alter_label
+
+    cond_inputs_text = "{} >> {} || {}".format(
+        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[origin_label]] + len(tokenizer)],
+        add_tokenizer.added_tokens_decoder[ent2id[ent_name2id[alter_label]] + len(tokenizer)],
+        input_text_a + input_text_b
+    )
+
+    cond_inputs = tokenizer(
+        cond_inputs_text,
+        truncation=True,
+        max_length=64,
+        padding="max_length",
+        add_special_tokens=True,
+    )
+
     cond_inputs = {
         "input_ids": torch.tensor(cond_inputs["input_ids"]).unsqueeze(dim=0),
         "attention_mask": torch.tensor(cond_inputs["attention_mask"]).unsqueeze(dim=0),
         "token_type_ids": torch.tensor(cond_inputs["token_type_ids"]).unsqueeze(dim=0)
     }
 
-
+    flag = 0
+    for idx, i in enumerate(edit_inputs["input_ids"][0, :].tolist()):
+        if i == tokenizer.pad_token_id and flag == 0:
+            edit_inputs["input_ids"][0, idx] = replace_token[0] + 30522
+            flag = 1
+        elif i == tokenizer.pad_token_id and flag != 0:
+            edit_inputs["input_ids"][0, idx] = replace_token[1] + 30522
+
+    return inputs, cond_inputs, edit_inputs, origin_top3
 
 def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
     with torch.enable_grad():
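The block added above builds the conditioning text for the hyper-network as "<current prediction token> >> <target entity token> || <masked input>". A small sketch of the resulting string, with placeholder token names standing in for the real entries of add_tokenizer.added_tokens_decoder:

# Sketch of the conditioning string built in solve() (toy token names, not real added tokens).
origin_token = "[ENT_123]"   # token for the entity the unedited model currently predicts
alter_token  = "[ENT_456]"   # token for the desired target entity
masked_input = "[MASK] [SEP] profession [PAD]"

cond_inputs_text = "{} >> {} || {}".format(origin_token, alter_token, masked_input)
print(cond_inputs_text)
# [ENT_123] >> [ENT_456] || [MASK] [SEP] profession [PAD]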
@@ -149,12 +166,7 @@ def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
             input_ids=inputs["input_ids"],
             attention_mask=inputs["attention_mask"],
         ).logits
-
-        # logits_orig, logit_for_grad, _ = logits.split([
-        #     len(inputs["input_ids"]) - 1,
-        #     1,
-        #     0,
-        # ])
+
         input_ids = inputs['input_ids']
         _, mask_idx = (input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)
         mask_logits = logits[:, mask_idx, 30522:45473].squeeze(dim=0)
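get_logits_orig_params_dict locates the [MASK] position with nonzero(as_tuple=True) and keeps only the logit columns of the appended KG tokens (30522:45473). A self-contained toy illustration of those two operations, using random logits and hard-coded ids rather than real model output:

import torch

# Toy illustration of locating [MASK] and slicing the entity-token logits.
MASK_TOKEN_ID = 103                                  # [MASK] id in bert-base-uncased
input_ids = torch.tensor([[101, 103, 2003, 102]])    # [CLS] [MASK] is [SEP]

_, mask_idx = (input_ids == MASK_TOKEN_ID).nonzero(as_tuple=True)
print(mask_idx)                                      # tensor([1]) -> [MASK] sits at position 1

logits = torch.randn(1, 4, 45473)                    # (batch, seq_len, extended vocab)
entity_logits = logits[:, mask_idx, 30522:45473]     # keep only the appended entity tokens
print(entity_logits.shape)                           # torch.Size([1, 1, 14951])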
@@ -174,7 +186,6 @@ def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
             for (name, _), grad in zip(ex_model.named_parameters(), grads)
         }
 
-        # cond_inputs contains [PAD] tokens
         params_dict = learner(
             cond_inputs["input_ids"][-1:],
             cond_inputs["attention_mask"][-1:],
@@ -184,30 +195,22 @@ def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
     return params_dict
 
 def edit_process(edit_input, alter_label):
-
-
-    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
-    logits = edit_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
-    logits = logits[mask_idx, :]
-
-    ### origin output
-    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
-    origin_entity_order = origin_entity_order.squeeze(dim=0)
-    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+    _, cond_inputs, edit_inputs, origin_top3 = solve(edit_input, alter_label, edit_task=True)
 
     ### edit output
     fmodel = make_functional(edit_ex_model).eval()
-    params_dict = get_logits_orig_params_dict(
+    params_dict = get_logits_orig_params_dict(edit_inputs, cond_inputs, ent2id[ent_name2id[alter_label]], edit_ex_model, edit_learner)
     edit_logits = fmodel(
-        input_ids=
-        attention_mask=
+        input_ids=edit_inputs["input_ids"],
+        attention_mask=edit_inputs["attention_mask"],
         # add delta theta
         params=[
             params_dict.get(n, 0) + p
             for n, p in edit_ex_model.named_parameters()
         ],
     ).logits[:, :, 30522:45473].squeeze()
-
+
+    _, mask_idx = (edit_inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
     edit_logits = edit_logits[mask_idx, :]
     _, edit_entity_order = torch.sort(edit_logits, dim=1, descending=True)
     edit_entity_order = edit_entity_order.squeeze(dim=0)
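edit_process now runs the functional copy of the model with params = params_dict.get(n, 0) + p, i.e. the learner's per-parameter deltas added on top of the original weights without mutating the model. A minimal sketch of that pattern on a toy module, assuming higher's monkeypatch accepts the params= override the same way it is used here; the delta values are invented for illustration:

import torch
import torch.nn as nn
from higher.patch import monkeypatch as make_functional

# Sketch of the delta-parameter trick: evaluate a functional copy of the model
# with params = original weights + per-parameter offsets.
model = nn.Linear(4, 2)
fmodel = make_functional(model).eval()

# Hypothetical deltas keyed by parameter name (the app obtains these from the learner).
params_dict = {"weight": 0.01 * torch.ones_like(model.weight)}

x = torch.randn(1, 4)
out = fmodel(
    x,
    params=[params_dict.get(n, 0) + p for n, p in model.named_parameters()],
)
print(out.shape)  # torch.Size([1, 2])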
@@ -216,23 +219,14 @@ def edit_process(edit_input, alter_label):
     return "\n".join(origin_top3), "\n".join(edit_top3)
 
 def add_process(edit_input, alter_label):
-
-
-    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
-    logits = add_origin_model(**inputs).logits[:, :, 30522:45473].squeeze()
-    logits = logits[mask_idx, :]
-
-    ### origin output
-    _, origin_entity_order = torch.sort(logits, dim=1, descending=True)
-    origin_entity_order = origin_entity_order.squeeze(dim=0)
-    origin_top3 = [id2ent_name[id2ent[origin_entity_order[i].item()]] for i in range(3)]
+    _, cond_inputs, add_inputs, origin_top3 = solve(edit_input, alter_label, edit_task=False)
 
     ### add output
     fmodel = make_functional(add_ex_model).eval()
-    params_dict = get_logits_orig_params_dict(
+    params_dict = get_logits_orig_params_dict(add_inputs, cond_inputs, ent2id[ent_name2id[alter_label]], add_ex_model, add_learner)
     add_logits = fmodel(
-        input_ids=
-        attention_mask=
+        input_ids=add_inputs["input_ids"],
+        attention_mask=add_inputs["attention_mask"],
         # add delta theta
         params=[
             params_dict.get(n, 0) + p
@@ -240,6 +234,7 @@ def add_process(edit_input, alter_label):
         ],
     ).logits[:, :, 30522:45473].squeeze()
 
+    _, mask_idx = (add_inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
     add_logits = add_logits[mask_idx, :]
     _, add_entity_order = torch.sort(add_logits, dim=1, descending=True)
     add_entity_order = add_entity_order.squeeze(dim=0)
@@ -250,9 +245,6 @@ def add_process(edit_input, alter_label):
 
 with gr.Blocks() as demo:
     init_triple_input()
-    ### example
-    # edit_process("[MASK]|/people/person/profession|Jack Black", "Kellie Martin")
-    add_process("Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs")
     gr.Markdown("# KGE Editing")
 
     # multiple tabs
@@ -270,7 +262,12 @@ with gr.Blocks() as demo:
             edit_output = gr.Textbox(label="After Edit", lines=3, placeholder="")
 
         gr.Examples(
-            examples=[["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"],
+            examples=[["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"],
+                      ["[MASK]|/people/person/nationality|United States of America", "Mark Mothersbaugh"],
+                      ["[MASK]|/people/person/gender|Male", "Iggy Pop"],
+                      ["Rachel Weisz|/people/person/nationality|[MASK]", "J.J. Abrams"],
+                      ["Jeff Goldblum|/people/person/spouse_s./people/marriage/type_of_union|[MASK]", "Sydney Pollack"],
+                      ],
             inputs=[edit_input, alter_label],
             outputs=[origin_output, edit_output],
             fn=edit_process,
@@ -290,7 +287,12 @@ with gr.Blocks() as demo:
             add_output = gr.Textbox(label="Add Results", lines=3, placeholder="")
 
         gr.Examples(
-            examples=[["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"],
+            examples=[["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"],
+                      ["Darryl F. Zanuck|/people/deceased_person/place_of_death|[MASK]", "Palm Springs"],
+                      ["[MASK]|/location/location/contains|Antigua and Barbuda", "Americas"],
+                      ["Hard rock|/music/genre/artists|[MASK]", "Social Distortion"],
+                      ["[MASK]|/people/person/nationality|United States of America", "Serj Tankian"]
+                      ],
            inputs=[add_input, inductive_entity],
            outputs=[add_origin_output, add_output],
            fn=add_process,
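The example triples are pipe-separated strings with [MASK] marking the slot to re-predict. How solve() actually parses them is not visible in this diff; the split below is only a plausible illustration with one of the example inputs:

# Hypothetical illustration of decomposing an example triple string
# (the real parsing inside solve() is not shown in this diff).
triple = "[MASK]|/people/person/profession|Jack Black"
alter_label = "Kellie Martin"

h, r, t = triple.split("|")
masked_head = (h == "[MASK]")
print((h, r, t), "masked slot:", "head" if masked_head else "tail")
# ('[MASK]', '/people/person/profession', 'Jack Black') masked slot: head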