nicoladecao commited on
Commit
bd735c2
1 Parent(s): 0c3005a

Initial commit

Browse files
.gitattributes CHANGED
@@ -25,3 +25,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
25
  *.zip filter=lfs diff=lfs merge=lfs -text
26
  *.zstandard filter=lfs diff=lfs merge=lfs -text
27
  *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ kilt_titles_trie_dict.pkl filter=lfs diff=lfs merge=lfs -text
29
+ tf_model.h5 filter=lfs diff=lfs merge=lfs -text
30
+ pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ language:
4
+ - en
5
+
6
+ tags:
7
+ - retrieval
8
+ - entity-retrieval
9
+ - named-entity-disambiguation
10
+ - entity-disambiguation
11
+ - named-entity-linking
12
+ - entity-linking
13
+ - text2text-generation
14
+ ---
15
+
16
+
17
+ # GENRE
18
+
19
+
20
+ This is the GENRE (Generative ENtity REtrieval) system, as presented in [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904), implemented in PyTorch.
21
+
22
+ In a nutshell, GENRE uses a sequence-to-sequence approach to entity retrieval (e.g., linking), based on a fine-tuned [BART](https://arxiv.org/abs/1910.13461) architecture. GENRE performs retrieval by generating the unique entity name conditioned on the input text, using constrained beam search to only generate valid identifiers. The model was first released in the [facebookresearch/GENRE](https://github.com/facebookresearch/GENRE) repository using `fairseq` (the `transformers` models are obtained with a conversion script similar to [this](https://github.com/huggingface/transformers/blob/master/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py)).
23
+
24
+ This model was trained on the full training set of [BLINK](https://arxiv.org/abs/1911.03814) (i.e., 9M datapoints for entity-disambiguation grounded on Wikipedia) and then fine-tuned on [AIDA-YAGO2](https://www.mpi-inf.mpg.de/departments/databases-and-information-systems/research/ambiverse-nlu/aida/downloads).
25
+
26
+ ## BibTeX entry and citation info
27
+
28
+ **Please consider citing our work if you use code from this repository.**
29
+
30
+ ```bibtex
31
+ @inproceedings{decao2020autoregressive,
32
+ title={Autoregressive Entity Retrieval},
33
+ author={Nicola {De Cao} and Gautier Izacard and Sebastian Riedel and Fabio Petroni},
34
+ booktitle={International Conference on Learning Representations},
35
+ url={https://openreview.net/forum?id=5k8F6UU39V},
36
+ year={2021}
37
+ }
38
+ ```
39
+
40
+ ## Usage
41
+
42
+ Here is an example of generation for Wikipedia page disambiguation:
43
+
44
+ ```python
45
+ import pickle
46
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
47
+
48
+ # OPTIONAL: load the prefix tree (trie), you need to additionally download
49
+ # https://huggingface.co/facebook/genre-kilt/blob/main/trie.py and
50
+ # https://huggingface.co/facebook/genre-kilt/blob/main/kilt_titles_trie_dict.pkl
51
+ # from trie import Trie
52
+ # with open("kilt_titles_trie_dict.pkl", "rb") as f:
53
+ # trie = Trie.load_from_dict(pickle.load(f))
54
+
55
+ tokenizer = AutoTokenizer.from_pretrained("facebook/genre-linking-aidayago2")
56
+ model = AutoModelForSeq2SeqLM.from_pretrained("facebook/genre-linking-aidayago2").eval()
57
+
58
+ sentences = ["Einstein was a [START_ENT] German [END_ENT] physicist."]
59
+
60
+ outputs = model.generate(
61
+ **tokenizer(sentences, return_tensors="pt"),
62
+ num_beams=5,
63
+ num_return_sequences=5,
64
+ # OPTIONAL: use constrained beam search
65
+ # prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
66
+ )
67
+
68
+ tokenizer.batch_decode(outputs, skip_special_tokens=True)
69
+ ```
70
+ which outputs the following top-5 predictions (using constrained beam search):
71
+ ```
72
+ ['Germany',
73
+ 'German Empire',
74
+ 'Nazi Germany',
75
+ 'German language',
76
+ 'France']
77
+ ```
config.json ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/genre-kilt",
3
+ "activation_dropout": 0.0,
4
+ "activation_function": "gelu",
5
+ "add_bias_logits": false,
6
+ "add_final_layer_norm": false,
7
+ "architectures": [
8
+ "BartForConditionalGeneration"
9
+ ],
10
+ "attention_dropout": 0.0,
11
+ "bos_token_id": 0,
12
+ "classif_dropout": 0.0,
13
+ "classifier_dropout": 0.0,
14
+ "d_model": 1024,
15
+ "decoder_attention_heads": 16,
16
+ "decoder_ffn_dim": 4096,
17
+ "decoder_layerdrop": 0.0,
18
+ "decoder_layers": 12,
19
+ "decoder_start_token_id": 2,
20
+ "dropout": 0.1,
21
+ "early_stopping": true,
22
+ "encoder_attention_heads": 16,
23
+ "encoder_ffn_dim": 4096,
24
+ "encoder_layerdrop": 0.0,
25
+ "encoder_layers": 12,
26
+ "eos_token_id": 2,
27
+ "eos_token_ids": [
28
+ 2
29
+ ],
30
+ "forced_eos_token_id": 2,
31
+ "gradient_checkpointing": false,
32
+ "init_std": 0.02,
33
+ "is_encoder_decoder": true,
34
+ "max_length": 1024,
35
+ "max_position_embeddings": 1024,
36
+ "min_length": 0,
37
+ "model_type": "bart",
38
+ "normalize_before": false,
39
+ "normalize_embedding": false,
40
+ "num_beams": 6,
41
+ "num_hidden_layers": 12,
42
+ "output_past": true,
43
+ "pad_token_id": 1,
44
+ "replacing_rate": 0,
45
+ "scale_embedding": false,
46
+ "static_position_embeddings": false,
47
+ "student_decoder_layers": null,
48
+ "student_encoder_layers": null,
49
+ "task_specific_params": {},
50
+ "transformers_version": "4.19.2",
51
+ "use_cache": true,
52
+ "vocab_size": 50264
53
+ }
kilt_titles_trie_dict.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:951db72cc702fcf6639419efcf917cb7f3c67cc6202ebe3ae3ca399c30614da2
3
+ size 215214973
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:76b234106f6f3bb58db45b34230262d6b913967ea5c374ffdf61802d9eb6b9fe
3
+ size 1625526529
tf_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94d810ac75433e2ab4280ef644fd5a714b33a441d3aa06eef329a141a039c51b
3
+ size 1625921384
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"model_max_length": 1024}
vocab.json ADDED
The diff for this file is too large to render. See raw diff