import re import torch kilt_wikipedia_columns = ['kilt_id', 'wikipedia_id', 'wikipedia_title', 'text', 'anchors', 'categories', 'wikidata_info', 'history'] kilt_wikipedia_paragraph_columns = ['wikipedia_id', 'start_paragraph_id', 'start_character', 'end_paragraph_id', 'end_character', 'title', 'section', 'text'] def clean_question(text): result = cleanup_references(text) result = result.replace("\n", " ") result = re.sub(r"\s\s+", " ", result) result = result.replace("[deleted]", "") return result.lower().strip() def cleanup_references(text): # URL reference where we need to remove both the link text and URL # ...and this letter is used by most biographers as the cornerstone of Lee's personal # views on slavery ([1](_URL_2_ & pg=PA173), [2](_URL_1_), [3](_URL_5_)). # ...and this letter is used by most biographers as the cornerstone of Lee's personal views on slavery. result = re.sub(r"[\(\s]*\[\d+\]\([^)]+\)[,)]*", "", text, 0, re.MULTILINE) # URL reference where we need to preserve link text but remove URL # At the outbreak of the Civil War, [Leyburn left his church](_URL_19_) and joined the South. # At the outbreak of the Civil War, Leyburn left his church and joined the South. result = re.sub(r"\[([^]]+)\]\([^)]+\)", "\\1", result, 0, re.MULTILINE) # lastly remove just dangling _URL_[0-9]_ URL references result = re.sub(r"_URL_\d_", "", result, 0, re.MULTILINE) return result def clean_answer(text): result = cleanup_references(text) result = result.replace("\n", " ") result = re.sub(r"\s\s+", " ", result) result = re.sub(r"BULLET::::-", "", result) return trim(result.strip()) def trim(text, word_count: int = 100): return " ".join(text.split(" ")[:word_count]) def articles_to_paragraphs(examples): ids, titles, sections, texts, start_ps, end_ps, start_cs, end_cs = [], [], [], [], [], [], [], [] for bidx, example in enumerate(examples["text"]): last_section = "" for idx, p in enumerate(example["paragraph"]): if "Section::::" in p: last_section = p ids.append(examples["wikipedia_id"][bidx]) titles.append(examples["wikipedia_title"][bidx]) sections.append(last_section) texts.append(p) start_ps.append(idx) end_ps.append(idx) start_cs.append(0) end_cs.append(len(p)) return {"wikipedia_id": ids, "title": titles, "section": sections, "text": texts, "start_paragraph_id": start_ps, "end_paragraph_id": end_ps, "start_character": start_cs, "end_character": end_cs } def create_kilt_datapoint(eli5_example, columns, wiki_passages, min_length=20, topk=7): res_list = [dict([(k, p[k]) for k in columns]) for p in wiki_passages] res_list = [res for res in res_list if len(res["text"].split()) > min_length][:topk] # make a KILT data point # see https://github.com/facebookresearch/KILT#kilt-data-format output = [] for a in eli5_example["answers"]["text"]: output.append({"answer": a}) output.append({"provenance": [ # evidence set for the answer from the KILT ks { "wikipedia_id": r["wikipedia_id"], # *mandatory* "title": r["title"], "section": r["section"], "start_paragraph_id": r["start_paragraph_id"], "start_character": r["start_character"], "end_paragraph_id": r["end_paragraph_id"], "end_character": r["end_character"], "text": r["text"], "bleu_score": None, # wrt original evidence "meta": None # dataset/task specific } for r in res_list ]}) return {"id": eli5_example["q_id"], "input": eli5_example["title"], "output": output, # each element is an answer or provenance (can have multiple of each) "meta": None # dataset/task specific } def embed_questions(question_model, question_tokenizer, questions, max_length=128, device="cuda:0"): query = question_tokenizer(questions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt") with torch.no_grad(): q_reps = question_model(query["input_ids"].to(device), query["attention_mask"].to(device)).pooler_output return q_reps.cpu().numpy() def embed_passages(ctx_model, ctx_tokenizer, passages, max_length=128, device="cuda:0"): p = ctx_tokenizer(passages["text"], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt") with torch.no_grad(): a_reps = ctx_model(p["input_ids"].to(device), p["attention_mask"].to(device)).pooler_output return {"embeddings": a_reps.cpu().numpy()}