---
tags:
- feature-extraction
pipeline_tag: feature-extraction
---

This model is the query encoder of the MS MARCO UniCOIL Lexical Model (Λ) from the SPAR paper:

[Salient Phrase Aware Dense Retrieval: Can a Dense Retriever Imitate a Sparse One?](https://arxiv.org/abs/2110.06918)
<br>
Xilun Chen, Kushal Lakhotia, Barlas Oğuz, Anchit Gupta, Patrick Lewis, Stan Peshterliev, Yashar Mehdad, Sonal Gupta and Wen-tau Yih
<br>
**Meta AI**

The associated github repo is available here: https://github.com/facebookresearch/dpr-scale/tree/main/spar

This model is a BERT-base sized dense retriever trained on the MS MARCO corpus to imitate the behavior of [UniCOIL](https://arxiv.org/abs/2106.14807), a sparse retriever.
The following models are also available:

| Pretrained Model | Corpus | Teacher | Architecture | Query Encoder Path | Context Encoder Path |
|---|---|---|---|---|---|
| Wiki BM25 Λ | Wikipedia | BM25 | BERT-base | facebook/spar-wiki-bm25-lexmodel-query-encoder | facebook/spar-wiki-bm25-lexmodel-context-encoder |
| PAQ BM25 Λ | PAQ | BM25 | BERT-base | facebook/spar-paq-bm25-lexmodel-query-encoder | facebook/spar-paq-bm25-lexmodel-context-encoder |
| MARCO BM25 Λ | MS MARCO | BM25 | BERT-base | facebook/spar-marco-bm25-lexmodel-query-encoder | facebook/spar-marco-bm25-lexmodel-context-encoder |
| MARCO UniCOIL Λ | MS MARCO | UniCOIL | BERT-base | facebook/spar-marco-unicoil-lexmodel-query-encoder | facebook/spar-marco-unicoil-lexmodel-context-encoder |
# Using the Lexical Model (Λ) Alone

This model should be used together with the associated context encoder, similar to the [DPR](https://huggingface.co/docs/transformers/v4.22.1/en/model_doc/dpr) model.

```python
import torch
from transformers import AutoTokenizer, AutoModel

# The tokenizer is the same for the query and context encoder
tokenizer = AutoTokenizer.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
query_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
context_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-context-encoder')

query = "Where was Marie Curie born?"
contexts = [
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

# Apply tokenizer
query_input = tokenizer(query, return_tensors='pt')
ctx_input = tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')

# Compute embeddings: take the last-layer hidden state of the [CLS] token
query_emb = query_encoder(**query_input).last_hidden_state[:, 0, :]
ctx_emb = context_encoder(**ctx_input).last_hidden_state[:, 0, :]

# Compute similarity scores using dot product
score1 = query_emb @ ctx_emb[0]  # 341.3268
score2 = query_emb @ ctx_emb[1]  # 340.1626
```
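The example above uses the Wiki BM25 Λ encoders. This card's own model, the MARCO UniCOIL Λ, exposes the same bi-encoder interface, so the query/context encoder paths from the table can be swapped in directly. A minimal sketch, assuming the shared tokenizer and [CLS] pooling work the same way as in the example above (the query and passage strings are illustrative, and the scores will differ from the numbers shown above):

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Paths taken from the table above; interface assumed to match the Wiki BM25 Λ example
tokenizer = AutoTokenizer.from_pretrained('facebook/spar-marco-unicoil-lexmodel-query-encoder')
query_encoder = AutoModel.from_pretrained('facebook/spar-marco-unicoil-lexmodel-query-encoder')
context_encoder = AutoModel.from_pretrained('facebook/spar-marco-unicoil-lexmodel-context-encoder')

query = "what is the capital of france"            # hypothetical example query
passages = ["Paris is the capital and most populous city of France."]

with torch.no_grad():  # inference only, no gradients needed
    query_emb = query_encoder(**tokenizer(query, return_tensors='pt')).last_hidden_state[:, 0, :]
    ctx_emb = context_encoder(**tokenizer(passages, padding=True, truncation=True, return_tensors='pt')).last_hidden_state[:, 0, :]

score = query_emb @ ctx_emb[0]  # dot-product relevance score
```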

# Using the Lexical Model (Λ) with a Base Dense Retriever as in SPAR

As Λ learns lexical matching from a sparse teacher retriever, it can be used in combination with a standard dense retriever (e.g. [DPR](https://huggingface.co/docs/transformers/v4.22.1/en/model_doc/dpr#dpr), [Contriever](https://huggingface.co/facebook/contriever-msmarco)) to build a dense retriever that excels at both lexical and semantic matching.

In the following example, we show how to build the SPAR-Wiki model for Open-Domain Question Answering by concatenating the embeddings of DPR and the Wiki BM25 Λ.
```python
import torch
from transformers import AutoTokenizer, AutoModel
from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# DPR model
dpr_ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
dpr_ctx_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-multiset-base")
dpr_query_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-multiset-base")
dpr_query_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-multiset-base")

# Wiki BM25 Λ model
lexmodel_tokenizer = AutoTokenizer.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
lexmodel_query_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-query-encoder')
lexmodel_context_encoder = AutoModel.from_pretrained('facebook/spar-wiki-bm25-lexmodel-context-encoder')

query = "Where was Marie Curie born?"
contexts = [
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

# Compute DPR embeddings
dpr_query_input = dpr_query_tokenizer(query, return_tensors='pt')['input_ids']
dpr_query_emb = dpr_query_encoder(dpr_query_input).pooler_output
dpr_ctx_input = dpr_ctx_tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
dpr_ctx_emb = dpr_ctx_encoder(**dpr_ctx_input).pooler_output

# Compute Λ embeddings: last-layer hidden state of the [CLS] token
lexmodel_query_input = lexmodel_tokenizer(query, return_tensors='pt')
lexmodel_query_emb = lexmodel_query_encoder(**lexmodel_query_input).last_hidden_state[:, 0, :]
lexmodel_ctx_input = lexmodel_tokenizer(contexts, padding=True, truncation=True, return_tensors='pt')
lexmodel_ctx_emb = lexmodel_context_encoder(**lexmodel_ctx_input).last_hidden_state[:, 0, :]

# Form SPAR embeddings via concatenation

# The concatenation weight is only applied to query embeddings
# Refer to the SPAR paper for details
concat_weight = 0.7

spar_query_emb = torch.cat(
    [dpr_query_emb, concat_weight * lexmodel_query_emb],
    dim=-1,
)
spar_ctx_emb = torch.cat(
    [dpr_ctx_emb, lexmodel_ctx_emb],
    dim=-1,
)

# Compute similarity scores
score1 = spar_query_emb @ spar_ctx_emb[0]  # 317.6931
score2 = spar_query_emb @ spar_ctx_emb[1]  # 314.6144
```
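With the concatenated embeddings in hand, ranking a larger set of contexts is a single matrix product followed by a sort. A minimal sketch continuing from the variables above (torch.topk is used purely for illustration):

```python
# Score all contexts against the query in one matmul: shape [1, num_contexts]
scores = spar_query_emb @ spar_ctx_emb.T

# Rank contexts by score (k capped at the number of contexts)
top = torch.topk(scores, k=min(2, scores.shape[-1]), dim=-1)
for score, idx in zip(top.values[0].tolist(), top.indices[0].tolist()):
    print(f"{score:.4f}\t{contexts[idx]}")
```

In practice, the context embeddings would typically be precomputed offline and stored in a vector index (e.g. FAISS) rather than encoded for every query.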