Ramon Meffert commited on
Commit
83870cc
1 Parent(s): 8bbe3aa

Add base model retriever

Browse files
README.md CHANGED
@@ -25,3 +25,51 @@ De meeste QA systemen bestaan uit twee onderdelen:
25
 
26
  - Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
27
  - Overview van open-domain question answering technieken: <https://lilianweng.github.io/posts/2020-10-29-odqa/>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  - Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
27
  - Overview van open-domain question answering technieken: <https://lilianweng.github.io/posts/2020-10-29-odqa/>
28
+
29
+ ## Base model
30
+
31
+ Tot nu toe alleen een retriever die aan de hand van een vraag de top-k
32
+ relevante documents ophaalt. Haalt voor veel vragen wel hoge similarity
33
+ scores, maar de documents die hij ophaalt zijn meestal niet erg relevant.
34
+
35
+ ```bash
36
+ poetry shell
37
+ cd base_model
38
+ poetry run python main.py
39
+ ```
40
+
41
+ ### Voorbeeld
42
+
43
+ "What is the perplexity of a language model?"
44
+
45
+ > Result 1 (score: 74.10):
46
+ > Figure 10 .17 A sample alignment between sentences in English and French, with
47
+ > sentences extracted from Antoine de Saint-Exupery's Le Petit Prince and a
48
+ > hypothetical translation. Sentence alignment takes sentences e 1 , ..., e n ,
49
+ > and f 1 , ..., f n and finds minimal > sets of sentences that are translations
50
+ > of each other, including single sentence mappings like (e 1 ,f 1 ), (e 4 -f 3
51
+ > ), (e 5 -f 4 ), (e 6 -f 6 ) as well as 2-1 alignments (e 2 /e 3 ,f 2 ), (e 7
52
+ > /e 8 -f 7 ), and null alignments (f 5 ).
53
+ >
54
+ > Result 2 (score: 74.23):
55
+ > Character or word overlap-based metrics like chrF (or BLEU, or etc.) are
56
+ > mainly used to compare two systems, with the goal of answering questions like:
57
+ > did the new algorithm we just invented improve our MT system? To know if the
58
+ > difference between the chrF scores of two > MT systems is a significant
59
+ > difference, we use the paired bootstrap test, or the similar randomization
60
+ > test.
61
+ >
62
+ > Result 3 (score: 74.43):
63
+ > The model thus predicts the class negative for the test sentence.
64
+ >
65
+ > Result 4 (score: 74.95):
66
+ > Translating from languages with extensive pro-drop, like Chinese or Japanese,
67
+ > to non-pro-drop languages like English can be difficult since the model must
68
+ > somehow identify each zero and recover who or what is being talked about in
69
+ > order to insert the proper pronoun.
70
+ >
71
+ > Result 5 (score: 76.22):
72
+ > Similarly, a recent challenge set, the WinoMT dataset (Stanovsky et al., 2019)
73
+ > shows that MT systems perform worse when they are asked to translate sentences
74
+ > that describe people with non-stereotypical gender roles, like "The doctor
75
+ > asked the nurse to help her in the > operation".
main.py → base_model/main.py RENAMED
@@ -1,14 +1,15 @@
1
- from base_model.retriever import Retriever
 
2
 
3
  if __name__ == '__main__':
4
  # Initialize retriever
5
  r = Retriever()
6
 
7
  # Retrieve example
8
- retrieved = r.retrieve(
9
- "When is a stochastic process said to be stationary?")
10
 
11
- for i, (score, result) in enumerate(retrieved):
12
- print(f"Result {i+1} (score: {score * 100:.02f}:")
13
- print(result['text'][0])
14
  print() # Newline
 
1
+ from retriever import Retriever
2
+
3
 
4
  if __name__ == '__main__':
5
  # Initialize retriever
6
  r = Retriever()
7
 
8
  # Retrieve example
9
+ scores, result = r.retrieve(
10
+ "What is the perplexity of a language model?")
11
 
12
+ for i, score in enumerate(scores):
13
+ print(f"Result {i+1} (score: {score:.02f}):")
14
+ print(result['text'][i])
15
  print() # Newline
base_model/retriever.py CHANGED
@@ -1,10 +1,21 @@
1
- from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, \
2
- DPRQuestionEncoder, DPRQuestionEncoderTokenizer
 
 
 
 
3
  from datasets import load_dataset
4
  import torch
 
5
 
 
 
 
6
 
7
- class Retriever():
 
 
 
8
  """A class used to retrieve relevant documents based on some query.
9
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
10
  """
@@ -21,47 +32,64 @@ class Retriever():
21
 
22
  # Context encoding and tokenization
23
  self.ctx_encoder = DPRContextEncoder.from_pretrained(
24
- "facebook/dpr-ctx_encoder-single-nq-base")
 
25
  self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
26
- "facebook/dpr-ctx_encoder-single-nq-base")
 
27
 
28
  # Question encoding and tokenization
29
  self.q_encoder = DPRQuestionEncoder.from_pretrained(
30
- "facebook/dpr-question_encoder-single-nq-base")
 
31
  self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
32
- "facebook/dpr-question_encoder-single-nq-base")
 
33
 
34
  # Dataset building
35
  self.dataset = self.__init_dataset(dataset)
36
 
37
- def __init_dataset(self, dataset: str):
 
 
38
  """Loads the dataset and adds FAISS embeddings.
39
 
40
  Args:
41
  dataset (str): A HuggingFace dataset name.
 
 
42
 
43
  Returns:
44
  Dataset: A dataset with a new column 'embeddings' containing FAISS
45
  embeddings.
46
  """
47
- # TODO: save ds w/ embeddings to disk and retrieve it if it already exists
48
-
49
  # Load dataset
50
- ds = load_dataset(dataset, name='paragraphs')['train']
51
 
52
- def embed(row):
53
- # Inline helper function to perform embedding
54
- p = row['text']
55
- tok = self.ctx_tokenizer(p, return_tensors='pt', truncation=True)
56
- enc = self.ctx_encoder(**tok)[0][0].numpy()
57
- return {'embeddings': enc}
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
- # Add FAISS embeddings
60
- ds_with_embeddings = ds.map(embed)
 
61
 
62
- # Todo: this throws a weird error.
63
- ds_with_embeddings.add_faiss_index(column='embeddings')
64
- return ds_with_embeddings
65
 
66
  def retrieve(self, query: str, k: int = 5):
67
  """Retrieve the top k matches for a search query.
@@ -77,10 +105,11 @@ class Retriever():
77
 
78
  def embed(q):
79
  # Inline helper function to perform embedding
80
- tok = self.q_tokenizer(q, return_tensors='pt', truncation=True)
81
  return self.q_encoder(**tok)[0][0].numpy()
82
 
83
  question_embedding = embed(query)
84
  scores, results = self.dataset.get_nearest_examples(
85
- 'embeddings', question_embedding, k=k)
 
86
  return scores, results
 
1
+ from transformers import (
2
+ DPRContextEncoder,
3
+ DPRContextEncoderTokenizer,
4
+ DPRQuestionEncoder,
5
+ DPRQuestionEncoderTokenizer,
6
+ )
7
  from datasets import load_dataset
8
  import torch
9
+ import os.path
10
 
11
+ # Hacky fix for FAISS error on macOS
12
+ # See https://stackoverflow.com/a/63374568/4545692
13
+ import os
14
 
15
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
16
+
17
+
18
+ class Retriever:
19
  """A class used to retrieve relevant documents based on some query.
20
  based on https://huggingface.co/docs/datasets/faiss_es#faiss.
21
  """
 
32
 
33
  # Context encoding and tokenization
34
  self.ctx_encoder = DPRContextEncoder.from_pretrained(
35
+ "facebook/dpr-ctx_encoder-single-nq-base"
36
+ )
37
  self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
38
+ "facebook/dpr-ctx_encoder-single-nq-base"
39
+ )
40
 
41
  # Question encoding and tokenization
42
  self.q_encoder = DPRQuestionEncoder.from_pretrained(
43
+ "facebook/dpr-question_encoder-single-nq-base"
44
+ )
45
  self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
46
+ "facebook/dpr-question_encoder-single-nq-base"
47
+ )
48
 
49
  # Dataset building
50
  self.dataset = self.__init_dataset(dataset)
51
 
52
+ def __init_dataset(self,
53
+ dataset: str,
54
+ fname: str = "./models/paragraphs_embedding.faiss"):
55
  """Loads the dataset and adds FAISS embeddings.
56
 
57
  Args:
58
  dataset (str): A HuggingFace dataset name.
59
+ fname (str): The name to use to save the embeddings to disk for
60
+ faster loading after the first run.
61
 
62
  Returns:
63
  Dataset: A dataset with a new column 'embeddings' containing FAISS
64
  embeddings.
65
  """
 
 
66
  # Load dataset
67
+ ds = load_dataset(dataset, name="paragraphs")["train"]
68
 
69
+ if os.path.exists(fname):
70
+ # If we already have FAISS embeddings, load them from disk
71
+ ds.load_faiss_index('embeddings', fname)
72
+ return ds
73
+ else:
74
+ # If there are no FAISS embeddings, generate them
75
+ def embed(row):
76
+ # Inline helper function to perform embedding
77
+ p = row["text"]
78
+ tok = self.ctx_tokenizer(
79
+ p, return_tensors="pt", truncation=True)
80
+ enc = self.ctx_encoder(**tok)[0][0].numpy()
81
+ return {"embeddings": enc}
82
+
83
+ # Add FAISS embeddings
84
+ ds_with_embeddings = ds.map(embed)
85
+
86
+ ds_with_embeddings.add_faiss_index(column="embeddings")
87
 
88
+ # save dataset w/ embeddings
89
+ os.makedirs("./models/", exist_ok=True)
90
+ ds_with_embeddings.save_faiss_index("embeddings", fname)
91
 
92
+ return ds_with_embeddings
 
 
93
 
94
  def retrieve(self, query: str, k: int = 5):
95
  """Retrieve the top k matches for a search query.
 
105
 
106
  def embed(q):
107
  # Inline helper function to perform embedding
108
+ tok = self.q_tokenizer(q, return_tensors="pt", truncation=True)
109
  return self.q_encoder(**tok)[0][0].numpy()
110
 
111
  question_embedding = embed(query)
112
  scores, results = self.dataset.get_nearest_examples(
113
+ "embeddings", question_embedding, k=k
114
+ )
115
  return scores, results
poetry.lock CHANGED
@@ -51,6 +51,18 @@ docs = ["furo", "sphinx", "zope.interface", "sphinx-notfound-page"]
51
  tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
52
  tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
53
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  [[package]]
55
  name = "certifi"
56
  version = "2021.10.8"
@@ -460,6 +472,14 @@ python-versions = "*"
460
  docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
461
  testing = ["pytest", "requests", "numpy", "datasets"]
462
 
 
 
 
 
 
 
 
 
463
  [[package]]
464
  name = "torch"
465
  version = "1.11.0"
@@ -590,7 +610,7 @@ multidict = ">=4.0"
590
  [metadata]
591
  lock-version = "1.1"
592
  python-versions = "^3.8"
593
- content-hash = "9f99ff0196acf862c585450123952a4d10e93ce9dddd7222ca43dd8076451fb3"
594
 
595
  [metadata.files]
596
  aiohttp = [
@@ -679,6 +699,10 @@ attrs = [
679
  {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
680
  {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
681
  ]
 
 
 
 
682
  certifi = [
683
  {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"},
684
  {file = "certifi-2021.10.8.tar.gz", hash = "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872"},
@@ -1161,6 +1185,10 @@ tokenizers = [
1161
  {file = "tokenizers-0.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b28966c68a2cdecd5120f4becea159eebe0335b8202e21e292eb381031026edc"},
1162
  {file = "tokenizers-0.11.6.tar.gz", hash = "sha256:562b2022faf0882586c915385620d1f11798fc1b32bac55353a530132369a6d0"},
1163
  ]
 
 
 
 
1164
  torch = [
1165
  {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
1166
  {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
 
51
  tests = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "zope.interface", "cloudpickle"]
52
  tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "mypy", "pytest-mypy-plugins", "cloudpickle"]
53
 
54
+ [[package]]
55
+ name = "autopep8"
56
+ version = "1.6.0"
57
+ description = "A tool that automatically formats Python code to conform to the PEP 8 style guide"
58
+ category = "dev"
59
+ optional = false
60
+ python-versions = "*"
61
+
62
+ [package.dependencies]
63
+ pycodestyle = ">=2.8.0"
64
+ toml = "*"
65
+
66
  [[package]]
67
  name = "certifi"
68
  version = "2021.10.8"
 
472
  docs = ["sphinx", "sphinx-rtd-theme", "setuptools-rust"]
473
  testing = ["pytest", "requests", "numpy", "datasets"]
474
 
475
+ [[package]]
476
+ name = "toml"
477
+ version = "0.10.2"
478
+ description = "Python Library for Tom's Obvious, Minimal Language"
479
+ category = "dev"
480
+ optional = false
481
+ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
482
+
483
  [[package]]
484
  name = "torch"
485
  version = "1.11.0"
 
610
  [metadata]
611
  lock-version = "1.1"
612
  python-versions = "^3.8"
613
+ content-hash = "227b922ee14abf36ca75bb238d239d712bed9213d54c567996566d465e465733"
614
 
615
  [metadata.files]
616
  aiohttp = [
 
699
  {file = "attrs-21.4.0-py2.py3-none-any.whl", hash = "sha256:2d27e3784d7a565d36ab851fe94887c5eccd6a463168875832a1be79c82828b4"},
700
  {file = "attrs-21.4.0.tar.gz", hash = "sha256:626ba8234211db98e869df76230a137c4c40a12d72445c45d5f5b716f076e2fd"},
701
  ]
702
+ autopep8 = [
703
+ {file = "autopep8-1.6.0-py2.py3-none-any.whl", hash = "sha256:ed77137193bbac52d029a52c59bec1b0629b5a186c495f1eb21b126ac466083f"},
704
+ {file = "autopep8-1.6.0.tar.gz", hash = "sha256:44f0932855039d2c15c4510d6df665e4730f2b8582704fa48f9c55bd3e17d979"},
705
+ ]
706
  certifi = [
707
  {file = "certifi-2021.10.8-py2.py3-none-any.whl", hash = "sha256:d62a0163eb4c2344ac042ab2bdf75399a71a2d8c7d47eac2e2ee91b9d6339569"},
708
  {file = "certifi-2021.10.8.tar.gz", hash = "sha256:78884e7c1d4b00ce3cea67b44566851c4343c120abd683433ce934a68ea58872"},
 
1185
  {file = "tokenizers-0.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b28966c68a2cdecd5120f4becea159eebe0335b8202e21e292eb381031026edc"},
1186
  {file = "tokenizers-0.11.6.tar.gz", hash = "sha256:562b2022faf0882586c915385620d1f11798fc1b32bac55353a530132369a6d0"},
1187
  ]
1188
+ toml = [
1189
+ {file = "toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b"},
1190
+ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
1191
+ ]
1192
  torch = [
1193
  {file = "torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:62052b50fffc29ca7afc0c04ef8206b6f1ca9d10629cb543077e12967e8d0398"},
1194
  {file = "torch-1.11.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:866bfba29ac98dec35d893d8e17eaec149d0ac7a53be7baae5c98069897db667"},
pyproject.toml CHANGED
@@ -14,6 +14,7 @@ faiss-cpu = "^1.7.2"
14
 
15
  [tool.poetry.dev-dependencies]
16
  flake8 = "^4.0.1"
 
17
 
18
  [build-system]
19
  requires = ["poetry-core>=1.0.0"]
 
14
 
15
  [tool.poetry.dev-dependencies]
16
  flake8 = "^4.0.1"
17
+ autopep8 = "^1.6.0"
18
 
19
  [build-system]
20
  requires = ["poetry-core>=1.0.0"]