Ramon Meffert commited on
Commit
be1f224
1 Parent(s): b06298d

Add longformer

Browse files
.gitattributes CHANGED
@@ -28,3 +28,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
  *.zip filter=lfs diff=lfs merge=lfs -text
29
  *.zstandard filter=lfs diff=lfs merge=lfs -text
30
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
28
  *.zip filter=lfs diff=lfs merge=lfs -text
29
  *.zstandard filter=lfs diff=lfs merge=lfs -text
30
  *tfevents* filter=lfs diff=lfs merge=lfs -text
31
+ src/models/dpr.faiss filter=lfs diff=lfs merge=lfs -text
32
+ src/models/longformer.faiss filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -75,7 +75,10 @@ By default, the best answer along with its location in the book will be
75
  returned. If you want to generate more answers (say, a top-5), you can supply
76
  the `--top=5` option. The default retriever uses [FAISS](https://faiss.ai/), but
77
  you can also use [ElasticSearch](https://www.elastic.co/elastic-stack/) using
78
- the `--retriever=es` option.
 
 
 
79
 
80
  ### CLI overview
81
 
@@ -83,7 +86,7 @@ To get an overview of all available options, run `python query.py --help`. The
83
  options are also printed below.
84
 
85
  ```sh
86
- usage: query.py [-h] [--top int] [--retriever {faiss,es}] str
87
 
88
  positional arguments:
89
  str The question to feed to the QA system
@@ -93,6 +96,8 @@ options:
93
  --top int, -t int The number of answers to retrieve
94
  --retriever {faiss,es}, -r {faiss,es}
95
  The retrieval method to use
 
 
96
  ```
97
 
98
 
75
  returned. If you want to generate more answers (say, a top-5), you can supply
76
  the `--top=5` option. The default retriever uses [FAISS](https://faiss.ai/), but
77
  you can also use [ElasticSearch](https://www.elastic.co/elastic-stack/) using
78
+ the `--retriever=es` option. You can also pick a language model using the
79
+ `--lm` option, which accepts either `dpr` (Dense Passage Retrieval) or
80
+ `longformer`. The language model is used to generate embeddings for FAISS, and
81
+ is used to generate the answer.
82
 
83
  ### CLI overview
84
 
86
  options are also printed below.
87
 
88
  ```sh
89
+ usage: query.py [-h] [--top int] [--retriever {faiss,es}] [--lm {dpr,longformer}] str
90
 
91
  positional arguments:
92
  str The question to feed to the QA system
96
  --top int, -t int The number of answers to retrieve
97
  --retriever {faiss,es}, -r {faiss,es}
98
  The retrieval method to use
99
+ --lm {dpr,longformer}, -l {dpr,longformer}
100
+ The language model to use for the FAISS retriever
101
  ```
102
 
103
 
query.py CHANGED
@@ -2,21 +2,48 @@ import argparse
2
  import torch
3
  import transformers
4
 
5
- from typing import List, Literal, Union, cast
6
  from datasets import load_dataset, DatasetDict
7
  from dotenv import load_dotenv
8
 
 
 
9
  from src.readers.dpr_reader import DprReader
10
  from src.retrievers.base_retriever import Retriever
11
  from src.retrievers.es_retriever import ESRetriever
12
- from src.retrievers.faiss_retriever import FaissRetriever
 
 
 
13
  from src.utils.preprocessing import context_to_reader_input
14
  from src.utils.log import get_logger
15
 
16
 
17
- def get_retriever(r: Union[Literal["es"], Literal["fais"]], paragraphs: DatasetDict) -> Retriever:
18
- retriever = ESRetriever if r == "es" else FaissRetriever
19
- return retriever(paragraphs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
 
22
  def print_name(contexts: dict, section: str, id: int):
@@ -51,7 +78,11 @@ def print_answers(answers: List[tuple], scores: List[float], contexts: dict):
51
  print()
52
 
53
 
54
- def probe(query: str, retriever: Retriever, reader: DprReader, num_answers: int = 5):
 
 
 
 
55
  scores, contexts = retriever.retrieve(query)
56
  reader_input = context_to_reader_input(contexts)
57
  answers = reader.read(query, reader_input, num_answers)
@@ -63,7 +94,7 @@ def default_probe(query: str):
63
  # default probe is a probe that prints 5 answers with faiss
64
  paragraphs = cast(DatasetDict, load_dataset(
65
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
66
- retriever = get_retriever("faiss", paragraphs)
67
  reader = DprReader()
68
 
69
  return probe(query, retriever, reader)
@@ -75,13 +106,20 @@ def main(args: argparse.Namespace):
75
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
76
 
77
  # Retrieve
78
- retriever = get_retriever(args.retriever, paragraphs)
79
- reader = DprReader()
80
  answers, scores, contexts = probe(
81
- args.query, retriever, reader, args.num_answers)
82
 
83
  # Print output
84
- print_answers(answers, scores, contexts)
 
 
 
 
 
 
 
85
 
86
 
87
  if __name__ == "__main__":
@@ -94,13 +132,18 @@ if __name__ == "__main__":
94
  parser = argparse.ArgumentParser(
95
  formatter_class=argparse.MetavarTypeHelpFormatter
96
  )
97
- parser.add_argument("query", type=str,
98
- help="The question to feed to the QA system")
99
- parser.add_argument("--top", "-t", type=int, default=1,
100
- help="The number of answers to retrieve")
101
- parser.add_argument("--retriever", "-r", type=str.lower,
102
- choices=["faiss", "es"], default="faiss",
103
- help="The retrieval method to use")
 
 
 
 
 
104
 
105
  args = parser.parse_args()
106
  main(args)
2
  import torch
3
  import transformers
4
 
5
+ from typing import Dict, List, Literal, Tuple, cast
6
  from datasets import load_dataset, DatasetDict
7
  from dotenv import load_dotenv
8
 
9
+ from src.readers.base_reader import Reader
10
+ from src.readers.longformer_reader import LongformerReader
11
  from src.readers.dpr_reader import DprReader
12
  from src.retrievers.base_retriever import Retriever
13
  from src.retrievers.es_retriever import ESRetriever
14
+ from src.retrievers.faiss_retriever import (
15
+ FaissRetriever,
16
+ FaissRetrieverOptions
17
+ )
18
  from src.utils.preprocessing import context_to_reader_input
19
  from src.utils.log import get_logger
20
 
21
 
22
+ def get_retriever(paragraphs: DatasetDict,
23
+ r: Literal["es", "faiss"],
24
+ lm: Literal["dpr", "longformer"]) -> Retriever:
25
+ match (r, lm):
26
+ case "es", _:
27
+ return ESRetriever()
28
+ case "faiss", "dpr":
29
+ options = FaissRetrieverOptions.dpr("./src/models/dpr.faiss")
30
+ return FaissRetriever(paragraphs, options)
31
+ case "faiss", "longformer":
32
+ options = FaissRetrieverOptions.longformer(
33
+ "./src/models/longformer.faiss")
34
+ return FaissRetriever(paragraphs, options)
35
+ case _:
36
+ raise ValueError("Retriever options not recognized")
37
+
38
+
39
+ def get_reader(lm: Literal["dpr", "longformer"]) -> Reader:
40
+ match lm:
41
+ case "dpr":
42
+ return DprReader()
43
+ case "longformer":
44
+ return LongformerReader()
45
+ case _:
46
+ raise ValueError("Language model not recognized")
47
 
48
 
49
  def print_name(contexts: dict, section: str, id: int):
78
  print()
79
 
80
 
81
+ def probe(query: str,
82
+ retriever: Retriever,
83
+ reader: Reader,
84
+ num_answers: int = 5) \
85
+ -> Tuple[List[tuple], List[float], Dict[str, List[str]]]:
86
  scores, contexts = retriever.retrieve(query)
87
  reader_input = context_to_reader_input(contexts)
88
  answers = reader.read(query, reader_input, num_answers)
94
  # default probe is a probe that prints 5 answers with faiss
95
  paragraphs = cast(DatasetDict, load_dataset(
96
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
97
+ retriever = get_retriever(paragraphs, "faiss", "dpr")
98
  reader = DprReader()
99
 
100
  return probe(query, retriever, reader)
106
  "GroNLP/ik-nlp-22_slp", "paragraphs"))
107
 
108
  # Retrieve
109
+ retriever = get_retriever(paragraphs, args.retriever, args.lm)
110
+ reader = get_reader(args.lm)
111
  answers, scores, contexts = probe(
112
+ args.query, retriever, reader, args.top)
113
 
114
  # Print output
115
+ print("Question: " + args.query)
116
+ print("Answer(s):")
117
+ if args.lm == "dpr":
118
+ print_answers(answers, scores, contexts)
119
+ else:
120
+ answers = filter(lambda a: len(a[0].strip()) > 0, answers)
121
+ for pos, answer in enumerate(answers, start=1):
122
+ print(f" - {answer[0].strip()}")
123
 
124
 
125
  if __name__ == "__main__":
132
  parser = argparse.ArgumentParser(
133
  formatter_class=argparse.MetavarTypeHelpFormatter
134
  )
135
+ parser.add_argument(
136
+ "query", type=str, help="The question to feed to the QA system")
137
+ parser.add_argument(
138
+ "--top", "-t", type=int, default=1,
139
+ help="The number of answers to retrieve")
140
+ parser.add_argument(
141
+ "--retriever", "-r", type=str.lower, choices=["faiss", "es"],
142
+ default="faiss", help="The retrieval method to use")
143
+ parser.add_argument(
144
+ "--lm", "-l", type=str.lower,
145
+ choices=["dpr", "longformer"], default="dpr",
146
+ help="The language model to use for the FAISS retriever")
147
 
148
  args = parser.parse_args()
149
  main(args)
src/models/{paragraphs_embedding.faiss → dpr.faiss} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4fddf97865d5b1a967df90b7e2808bd27510cce633d55ed2af8328619828b168
3
  size 5213229
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bc0e5c38ddeb0a6a4daaf3ae98cd3e564f22ff9a263bc8867d0b363e828ccce
3
  size 5213229
src/models/longformer.faiss ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56b2616392540f4d2d8fa34d313a59c41572dca3ef5a683c7a8dbd2691418ea6
3
+ size 5213229
src/readers/base_reader.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Tuple
2
+
3
+
4
+ class Reader():
5
+ def read(self,
6
+ query: str,
7
+ context: Dict[str, List[str]],
8
+ num_answers: int) -> List[Tuple]:
9
+ raise NotImplementedError()
src/readers/dpr_reader.py CHANGED
@@ -1,8 +1,10 @@
1
  from transformers import DPRReader, DPRReaderTokenizer
2
  from typing import List, Dict, Tuple
3
 
 
4
 
5
- class DprReader():
 
6
  def __init__(self) -> None:
7
  self._tokenizer = DPRReaderTokenizer.from_pretrained(
8
  "facebook/dpr-reader-single-nq-base")
1
  from transformers import DPRReader, DPRReaderTokenizer
2
  from typing import List, Dict, Tuple
3
 
4
+ from src.readers.base_reader import Reader
5
 
6
+
7
+ class DprReader(Reader):
8
  def __init__(self) -> None:
9
  self._tokenizer = DPRReaderTokenizer.from_pretrained(
10
  "facebook/dpr-reader-single-nq-base")
src/readers/longformer_reader.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import (
3
+ LongformerTokenizerFast,
4
+ LongformerForQuestionAnswering
5
+ )
6
+ from typing import List, Dict, Tuple
7
+
8
+ from src.readers.base_reader import Reader
9
+
10
+
11
+ class LongformerReader(Reader):
12
+ def __init__(self) -> None:
13
+ checkpoint = "valhalla/longformer-base-4096-finetuned-squadv1"
14
+ self.tokenizer = LongformerTokenizerFast.from_pretrained(checkpoint)
15
+ self.model = LongformerForQuestionAnswering.from_pretrained(checkpoint)
16
+
17
+ def read(self,
18
+ query: str,
19
+ context: Dict[str, List[str]],
20
+ num_answers=5) -> List[Tuple]:
21
+ answers = []
22
+
23
+ for text in context['texts']:
24
+ encoding = self.tokenizer(
25
+ query, text, return_tensors="pt")
26
+ input_ids = encoding["input_ids"]
27
+ attention_mask = encoding["attention_mask"]
28
+ outputs = self.model(input_ids, attention_mask=attention_mask)
29
+
30
+ start_logits = outputs.start_logits
31
+ end_logits = outputs.end_logits
32
+ all_tokens = self.tokenizer.convert_ids_to_tokens(
33
+ input_ids[0].tolist())
34
+ answer_tokens = all_tokens[
35
+ torch.argmax(start_logits):torch.argmax(end_logits) + 1]
36
+ answer = self.tokenizer.decode(
37
+ self.tokenizer.convert_tokens_to_ids(answer_tokens)
38
+ )
39
+ answers.append([answer, [], []])
40
+
41
+ return answers
src/retrievers/faiss_retriever.py CHANGED
@@ -1,14 +1,19 @@
1
  import os
2
  import os.path
3
-
4
  import torch
5
- from datasets import DatasetDict, load_dataset
 
 
6
  from transformers import (
7
  DPRContextEncoder,
8
- DPRContextEncoderTokenizer,
9
  DPRQuestionEncoder,
10
- DPRQuestionEncoderTokenizer,
 
 
11
  )
 
 
12
 
13
  from src.retrievers.base_retriever import RetrieveType, Retriever
14
  from src.utils.log import get_logger
@@ -23,35 +28,99 @@ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
23
  logger = get_logger()
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  class FaissRetriever(Retriever):
27
  """A class used to retrieve relevant documents based on some query.
28
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
29
  """
30
 
31
- def __init__(self, paragraphs: DatasetDict, embedding_path: str = "./src/models/paragraphs_embedding.faiss") -> None:
 
32
  torch.set_grad_enabled(False)
33
 
 
 
34
  # Context encoding and tokenization
35
- self.ctx_encoder = DPRContextEncoder.from_pretrained(
36
- "facebook/dpr-ctx_encoder-single-nq-base"
37
- )
38
- self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
39
- "facebook/dpr-ctx_encoder-single-nq-base"
40
- )
41
 
42
  # Question encoding and tokenization
43
- self.q_encoder = DPRQuestionEncoder.from_pretrained(
44
- "facebook/dpr-question_encoder-single-nq-base"
45
- )
46
- self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
47
- "facebook/dpr-question_encoder-single-nq-base"
48
- )
49
 
50
  self.paragraphs = paragraphs
51
- self.embedding_path = embedding_path
52
 
53
  self.index = self._init_index()
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def _init_index(
56
  self,
57
  force_new_embedding: bool = False):
@@ -64,16 +133,8 @@ class FaissRetriever(Retriever):
64
  'embeddings', self.embedding_path) # type: ignore
65
  return ds
66
  else:
67
- def embed(row):
68
- # Inline helper function to perform embedding
69
- p = row["text"]
70
- tok = self.ctx_tokenizer(
71
- p, return_tensors="pt", truncation=True)
72
- enc = self.ctx_encoder(**tok)[0][0].numpy()
73
- return {"embeddings": enc}
74
-
75
  # Add FAISS embeddings
76
- index = ds.map(embed) # type: ignore
77
 
78
  index.add_faiss_index(column="embeddings")
79
 
@@ -86,12 +147,7 @@ class FaissRetriever(Retriever):
86
 
87
  @timeit("faissretriever.retrieve")
88
  def retrieve(self, query: str, k: int = 5) -> RetrieveType:
89
- def embed(q):
90
- # Inline helper function to perform embedding
91
- tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
92
- return self.q_encoder(**tok)[0][0].numpy()
93
-
94
- question_embedding = embed(query)
95
  scores, results = self.index.get_nearest_examples(
96
  "embeddings", question_embedding, k=k
97
  )
1
  import os
2
  import os.path
 
3
  import torch
4
+
5
+ from datasets import DatasetDict
6
+ from dataclasses import dataclass
7
  from transformers import (
8
  DPRContextEncoder,
9
+ DPRContextEncoderTokenizerFast,
10
  DPRQuestionEncoder,
11
+ DPRQuestionEncoderTokenizerFast,
12
+ LongformerModel,
13
+ LongformerTokenizerFast
14
  )
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
17
 
18
  from src.retrievers.base_retriever import RetrieveType, Retriever
19
  from src.utils.log import get_logger
28
  logger = get_logger()
29
 
30
 
31
+ @dataclass
32
+ class FaissRetrieverOptions:
33
+ ctx_encoder: PreTrainedModel
34
+ ctx_tokenizer: PreTrainedTokenizerFast
35
+ q_encoder: PreTrainedModel
36
+ q_tokenizer: PreTrainedTokenizerFast
37
+ embedding_path: str
38
+ lm: str
39
+
40
+ @staticmethod
41
+ def dpr(embedding_path: str):
42
+ return FaissRetrieverOptions(
43
+ ctx_encoder=DPRContextEncoder.from_pretrained(
44
+ "facebook/dpr-ctx_encoder-single-nq-base"
45
+ ),
46
+ ctx_tokenizer=DPRContextEncoderTokenizerFast.from_pretrained(
47
+ "facebook/dpr-ctx_encoder-single-nq-base"
48
+ ),
49
+ q_encoder=DPRQuestionEncoder.from_pretrained(
50
+ "facebook/dpr-question_encoder-single-nq-base"
51
+ ),
52
+ q_tokenizer=DPRQuestionEncoderTokenizerFast.from_pretrained(
53
+ "facebook/dpr-question_encoder-single-nq-base"
54
+ ),
55
+ embedding_path=embedding_path,
56
+ lm="dpr"
57
+ )
58
+
59
+ @staticmethod
60
+ def longformer(embedding_path: str):
61
+ encoder = LongformerModel.from_pretrained(
62
+ "allenai/longformer-base-4096"
63
+ )
64
+ tokenizer = LongformerTokenizerFast.from_pretrained(
65
+ "allenai/longformer-base-4096"
66
+ )
67
+ return FaissRetrieverOptions(
68
+ ctx_encoder=encoder,
69
+ ctx_tokenizer=tokenizer,
70
+ q_encoder=encoder,
71
+ q_tokenizer=tokenizer,
72
+ embedding_path=embedding_path,
73
+ lm="longformer"
74
+ )
75
+
76
+
77
  class FaissRetriever(Retriever):
78
  """A class used to retrieve relevant documents based on some query.
79
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
80
  """
81
 
82
+ def __init__(self, paragraphs: DatasetDict,
83
+ options: FaissRetrieverOptions) -> None:
84
  torch.set_grad_enabled(False)
85
 
86
+ self.lm = options.lm
87
+
88
  # Context encoding and tokenization
89
+ self.ctx_encoder = options.ctx_encoder
90
+ self.ctx_tokenizer = options.ctx_tokenizer
 
 
 
 
91
 
92
  # Question encoding and tokenization
93
+ self.q_encoder = options.q_encoder
94
+ self.q_tokenizer = options.q_tokenizer
 
 
 
 
95
 
96
  self.paragraphs = paragraphs
97
+ self.embedding_path = options.embedding_path
98
 
99
  self.index = self._init_index()
100
 
101
+ def _embed_question(self, q):
102
+ match self.lm:
103
+ case "dpr":
104
+ tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
105
+ return self.q_encoder(**tok)[0][0].numpy()
106
+ case "longformer":
107
+ tok = self.q_tokenizer(q, return_tensors="pt")
108
+ return self.q_encoder(**tok).last_hidden_state[0][0].numpy()
109
+
110
+ def _embed_context(self, row):
111
+ p = row["text"]
112
+
113
+ match self.lm:
114
+ case "dpr":
115
+ tok = self.ctx_tokenizer(
116
+ p, return_tensors="pt", truncation=True)
117
+ enc = self.ctx_encoder(**tok)[0][0].numpy()
118
+ return {"embeddings": enc}
119
+ case "longformer":
120
+ tok = self.ctx_tokenizer(p, return_tensors="pt")
121
+ enc = self.ctx_encoder(**tok).last_hidden_state[0][0].numpy()
122
+ return {"embeddings": enc}
123
+
124
  def _init_index(
125
  self,
126
  force_new_embedding: bool = False):
133
  'embeddings', self.embedding_path) # type: ignore
134
  return ds
135
  else:
 
 
 
 
 
 
 
 
136
  # Add FAISS embeddings
137
+ index = ds.map(self._embed_context) # type: ignore
138
 
139
  index.add_faiss_index(column="embeddings")
140
 
147
 
148
  @timeit("faissretriever.retrieve")
149
  def retrieve(self, query: str, k: int = 5) -> RetrieveType:
150
+ question_embedding = self._embed_question(query)
 
 
 
 
 
151
  scores, results = self.index.get_nearest_examples(
152
  "embeddings", question_embedding, k=k
153
  )