import re

import torch

# Column names of the article-level KILT Wikipedia dataset.
kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories',
                          'wikidata_info', 'history']

# Column names of the paragraph-level passages produced by articles_to_paragraphs below.
kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id',
                                    'end_character', 'title', 'section', 'text']


def clean_question(text):
    """Normalize an ELI5 question: strip markdown references, newlines, extra whitespace and '[deleted]' markers."""
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = result.replace("[deleted]", "")
    return result.lower().strip()


def cleanup_references(text):
    """Remove Reddit-style markdown link references from ELI5 text."""
    # remove numbered reference markers together with their URL, e.g. "([1](_URL_2_), [2](_URL_1_))"
    result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE)

    # keep the link text but drop the URL, e.g. "[left his church](_URL_19_)" -> "left his church"
    result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE)

    # finally drop any dangling _URL_<n>_ placeholders
    result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE)
    return result


def clean_answer(text):
    """Normalize an ELI5 answer: strip references, newlines and KILT bullet markers, then truncate."""
    result = cleanup_references(text)
    result = result.replace("\n", " ")
    result = re.sub(r"\s\s+", " ", result)
    result = re.sub(r"BULLET::::-", "", result)
    return trim(result.strip())


def trim(text, word_count: int = 100):
    """Keep at most the first `word_count` whitespace-separated tokens."""
    return " ".join(text.split(" ")[:word_count])


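# Illustrative sketch (not part of the original module): how the cleaning helpers above
# behave on an ELI5-style record. The record below is made up for demonstration; real
# examples carry the same "q_id" / "title" / "answers" fields used by create_kilt_datapoint.
def _example_clean_eli5_record():
    record = {
        "q_id": "abc123",
        "title": "Why do we [dream](_URL_0_)?\n\n[deleted]",
        "answers": {"text": ["BULLET::::- Because ([1](_URL_1_)) the brain consolidates memories."]},
    }
    question = clean_question(record["title"])  # -> "why do we dream?"
    answers = [clean_answer(a) for a in record["answers"]["text"]]
    return question, answers

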
def articles_to_paragraphs(examples):
    """Batched `datasets.map` function that splits KILT Wikipedia articles into paragraph-level passages."""
    ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], []
    for bidx, example in enumerate(examples["text"]):
        last_section = ""
        for idx, p in enumerate(example["paragraph"]):
            # paragraphs starting with "Section::::" mark a section heading; remember it
            # so that the following paragraphs carry their section title
            if "Section::::" in p:
                last_section = p
            ids.append(examples["wikipedia_id"][bidx])
            titles.append(examples["wikipedia_title"][bidx])
            sections.append(last_section)
            texts.append(p)
            start_ps.append(idx)
            end_ps.append(idx)
            start_cs.append(0)
            end_cs.append(len(p))

    return {"wikipedia_id": ids, "title": titles,
            "section": sections, "text": texts,
            "start_paragraph_id": start_ps, "end_paragraph_id": end_ps,
            "start_character": start_cs,
            "end_character": end_cs
            }


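# Illustrative sketch (assumptions: the `datasets` library is installed and the
# "kilt_wikipedia" dataset with its "full" split is available on the Hugging Face Hub):
# turn the article-level KILT Wikipedia dump into paragraph-level passages.
def _example_build_passages():
    from datasets import load_dataset

    kilt_wikipedia = load_dataset("kilt_wikipedia", split="full")
    passages = kilt_wikipedia.map(articles_to_paragraphs,
                                  batched=True,
                                  remove_columns=kilt_wikipedia_columns)
    return passages

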
def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7):
    """Assemble a KILT-format data point from an ELI5 example and its retrieved wiki passages."""
    res_list = [dict([(k, p[k]) for k in columns]) for p in wiki_passages]
    # drop very short passages and keep the top-k remaining ones
    res_list = [res for res in res_list if len(res["text"].split()) > min_length][:topk]

    # KILT output format: one entry per gold answer, plus a provenance entry
    # listing the supporting passages
    output = []
    for a in eli5_example["answers"]["text"]:
        output.append({"answer": a})

    output.append({"provenance": [
        {
            "wikipedia_id": r["wikipedia_id"],
            "title": r["title"],
            "section": r["section"],
            "start_paragraph_id": r["start_paragraph_id"],
            "start_character": r["start_character"],
            "end_paragraph_id": r["end_paragraph_id"],
            "end_character": r["end_character"],
            "text": r["text"],
            "bleu_score": None,
            "meta": None
        } for r in res_list
    ]})
    return {"id": eli5_example["q_id"],
            "input": eli5_example["title"],
            "output": output,
            "meta": None
            }


def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"):
    """Encode a list of questions with a question encoder (e.g. DPR) and return the pooled embeddings as a numpy array."""
    query = question_tokenizer(questions, max_length=max_length, padding="max_length", truncation=True,
                               return_tensors="pt")
    with torch.no_grad():
        q_reps = question_model(query["input_ids"].to(device),
                                query["attention_mask"].to(device)).pooler_output
    return q_reps.cpu().numpy()


def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"):
    """Batched `datasets.map` function: encode passage texts with a context encoder (e.g. DPR) into an 'embeddings' column."""
    p = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length",
                      truncation=True, return_tensors="pt")
    with torch.no_grad():
        a_reps = ctx_model(p["input_ids"].to(device),
                           p["attention_mask"].to(device)).pooler_output
    return {"embeddings": a_reps.cpu().numpy()}
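

# Illustrative end-to-end sketch (assumptions: `transformers`, `datasets` and `faiss` are
# installed, a GPU is available at "cuda:0", the "facebook/dpr-*-single-nq-base" checkpoints
# are used, and `passages` is the paragraph-level dataset from _example_build_passages above).
# It embeds the passages, builds a FAISS index, retrieves support passages for one
# ELI5-style question, and packages the result with create_kilt_datapoint.
def _example_retrieve_and_build_kilt(passages, eli5_example, device="cuda:0"):
    from functools import partial

    from transformers import (DPRContextEncoder, DPRContextEncoderTokenizerFast,
                              DPRQuestionEncoder, DPRQuestionEncoderTokenizerFast)

    ctx_model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base").to(device)
    ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
    q_model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
    q_tokenizer = DPRQuestionEncoderTokenizerFast.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

    # add an "embeddings" column and index it with FAISS
    passages = passages.map(partial(embed_passages, ctx_model, ctx_tokenizer, device=device),
                            batched=True, batch_size=64)
    passages.add_faiss_index(column="embeddings")

    # embed the cleaned question and fetch its nearest passages
    question = clean_question(eli5_example["title"])
    q_rep = embed_questions(q_model, q_tokenizer, [question], device=device)
    _, retrieved = passages.get_nearest_examples("embeddings", q_rep[0], k=10)

    # get_nearest_examples returns a dict of columns; convert it to a list of passage dicts
    wiki_passages = [{k: retrieved[k][i] for k in retrieved} for i in range(len(retrieved["text"]))]
    return create_kilt_datapoint(eli5_example, kilt_wikipedia_paragraph_columns, wiki_passages)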