metadata
language:
- en
tags:
- retrieval
- entity-retrieval
- named-entity-disambiguation
- entity-disambiguation
- named-entity-linking
- entity-linking
- text2text-generation
- question-answering
- fill-mask
- fact-checking
GENRE
The GENRE (Generative ENtity REtrieval) system, as presented in Autoregressive Entity Retrieval, implemented in PyTorch.
In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on a fine-tuned BART architecture. GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search to generate only valid identifiers. The model was first released in the facebookresearch/GENRE repository using fairseq (the transformers models were obtained with a conversion script).
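To make the idea of constrained generation concrete, here is a minimal sketch of a prefix tree over token-id sequences. It is an illustration only: `SimpleTrie` and the toy token ids are hypothetical and are not the `Trie` class shipped with the model. The only thing taken from the usage example in this card is the shape of the lookup, where `trie.get(prefix)` returns the token ids allowed to follow the tokens generated so far.

```python
from typing import Dict, Iterable, List


class SimpleTrie:
    """Hypothetical minimal prefix tree over token-id sequences (not the GENRE Trie)."""

    def __init__(self, sequences: Iterable[List[int]] = ()):
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            self.add(seq)

    def add(self, sequence: List[int]) -> None:
        # Walk down the tree, creating a child node for each unseen token.
        node = self.root
        for token in sequence:
            node = node.setdefault(token, {})

    def get(self, prefix: List[int]) -> List[int]:
        # Return the token ids that may follow `prefix`; empty if the prefix is unknown.
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return list(node.keys())


# Toy token ids, made up for illustration: 2 = start token, 1 = end token.
trie = SimpleTrie([
    [2, 10, 11, 1],  # stands in for one entity name
    [2, 10, 12, 1],  # a second name sharing the first token
    [2, 20, 1],      # a third, unrelated name
])

allowed_after_start = trie.get([2])       # continuations of the start token
allowed_after_shared = trie.get([2, 10])  # continuations of the shared prefix
unknown_prefix = trie.get([2, 99])        # prefix that matches no entity name
```

At each decoding step, beam search would be restricted to the returned ids, so every finished hypothesis spells out one of the stored names.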
BibTeX entry and citation info
Please consider citing our work if you use code from this repository.
@inproceedings{decao2020autoregressive,
  title={Autoregressive Entity Retrieval},
  author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
  booktitle={International Conference on Learning Representations},
  url={https://openreview.net/forum?id=5k8F6UU39V},
  year={2021}
}
Usage
Here is an example of generation for Wikipedia page retrieval for open-domain fact-checking:
import pickle
from transformers import BartTokenizer, BartForConditionalGeneration
# OPTIONAL: load the prefix tree (trie). This requires additionally downloading
# https://huggingface.co/facebook/genre-kilt/blob/main/trie.py and
# https://huggingface.co/facebook/genre-kilt/blob/main/kilt_titles_trie_dict.pkl
# from trie import Trie
# with open("kilt_titles_trie_dict.pkl", "rb") as f:
#     trie = Trie.load_from_dict(pickle.load(f))
tokenizer = BartTokenizer.from_pretrained("facebook/genre-kilt")
model = BartForConditionalGeneration.from_pretrained("facebook/genre-kilt").eval()
sentences = ["Einstein was a German physicist."]
outputs = model.generate(
    **tokenizer(sentences, return_tensors="pt"),
    num_beams=5,
    num_return_sequences=5,
    # OPTIONAL: use constrained beam search
    # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
tokenizer.batch_decode(outputs, skip_special_tokens=True)
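With `num_return_sequences=5`, `generate` returns one flat batch of `len(sentences) * 5` sequences, ordered input by input, so `batch_decode` yields a flat list of strings. The sketch below shows one way to regroup that list into per-sentence candidate lists. The helper `group_by_input` is a hypothetical name introduced here, and the strings in the example are placeholders rather than real model output.

```python
from typing import List


def group_by_input(decoded: List[str], num_return_sequences: int) -> List[List[str]]:
    """Split a flat list of decoded strings into one candidate list per input."""
    if num_return_sequences <= 0:
        raise ValueError("num_return_sequences must be positive")
    if len(decoded) % num_return_sequences != 0:
        raise ValueError("decoded length is not a multiple of num_return_sequences")
    return [
        decoded[start:start + num_return_sequences]
        for start in range(0, len(decoded), num_return_sequences)
    ]


# Placeholder strings standing in for decoded candidates of two input sentences.
flat = ["a1", "a2", "a3", "b1", "b2", "b3"]
grouped = group_by_input(flat, num_return_sequences=3)
```

In the example above the call would be `group_by_input(tokenizer.batch_decode(outputs, skip_special_tokens=True), 5)`, giving one list of five candidates for the single input sentence.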