Ramon Meffert committed on
Commit
8bbe3aa
1 Parent(s): 7177a08

Add retriever based on DPR (WIP)

Browse files
Files changed (7)
  1. .gitignore +3 -0
  2. README.md +20 -7
  3. base_model/reader.py +2 -0
  4. base_model/retriever.py +86 -0
  5. main.py +14 -0
  6. poetry.lock +0 -0
  7. pyproject.toml +5 -0
.gitignore CHANGED
@@ -1,3 +1,6 @@
+ # VS Code stuff
+ .vscode/
+
  # Byte-compiled / optimized / DLL files
  __pycache__/
  *.py[cod]
README.md CHANGED
@@ -1,14 +1,27 @@
  # nlp-flashcard-project

-
  ## Todo for progress meeting

- - Read in the data / prepare the repo
- - Proof of concept with UnifiedQA
- - Standard QA model with the dataset
- - Collect and read papers
- - Look at earlier work, get inspiration for a research direction
+ - [ ] Read in the data / prepare the repo
+ - [ ] Proof of concept with UnifiedQA
+ - [ ] Standard QA model with the dataset
+ - [ ] Collect and read papers
+ - [ ] Look at earlier work, get inspiration for a research direction
+
+ ## Overview
+
+ Most QA systems consist of two components:
+
+ - A retriever. Based on the question, it fetches _k_ relevant pieces of
+   context, e.g. with `tf-idf`.
+ - A model that generates the answer. What exactly you use here depends on
+   the type of question answering:
+   - For **extractive QA** you use a reader;
+   - For **generative QA** you use a generator.
+
+ Both build on a language model.

  ## Useful info

- - Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
+ - Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
+ - Overview of open-domain question answering techniques: <https://lilianweng.github.io/posts/2020-10-29-odqa/>
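The `tf-idf` retriever mentioned in the README overview can be sketched in a few lines. The snippet below is purely illustrative and not part of this commit; it assumes `scikit-learn` is available (the commit itself uses DPR instead, see `base_model/retriever.py` below).

```python
# Illustrative tf-idf retriever sketch (not part of this commit; assumes scikit-learn).
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

docs = [
    "A stochastic process is stationary if its statistics do not change over time.",
    "tf-idf weighs a term by its frequency and its inverse document frequency.",
]

vectorizer = TfidfVectorizer()
doc_vectors = vectorizer.fit_transform(docs)

def retrieve(query: str, k: int = 1):
    # Embed the query in the same tf-idf space and rank docs by cosine similarity
    query_vector = vectorizer.transform([query])
    scores = cosine_similarity(query_vector, doc_vectors)[0]
    top_k = scores.argsort()[::-1][:k]
    return [(scores[i], docs[i]) for i in top_k]

print(retrieve("When is a stochastic process stationary?"))
```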
base_model/reader.py ADDED
@@ -0,0 +1,2 @@
+ class Reader():
+     pass
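`Reader` is an empty stub in this commit. For the extractive QA route described in the README, it would presumably wrap a span-prediction model. One possible sketch, using the `transformers` question-answering pipeline (the model name is an assumption, not something this commit specifies):

```python
# Hypothetical Reader sketch; not part of this commit.
from transformers import pipeline

class Reader():
    def __init__(self, model: str = "distilbert-base-cased-distilled-squad"):
        # Extractive QA pipeline; the default model here is an assumption
        self.pipeline = pipeline("question-answering", model=model)

    def read(self, question: str, context: str):
        # Returns a dict with 'answer', 'score', 'start', and 'end'
        return self.pipeline(question=question, context=context)
```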
base_model/retriever.py ADDED
@@ -0,0 +1,86 @@
+ from transformers import DPRContextEncoder, DPRContextEncoderTokenizer, \
+     DPRQuestionEncoder, DPRQuestionEncoderTokenizer
+ from datasets import load_dataset
+ import torch
+
+
+ class Retriever():
+     """A class used to retrieve relevant documents based on some query,
+     based on https://huggingface.co/docs/datasets/faiss_es#faiss.
+     """
+
+     def __init__(self, dataset: str = "GroNLP/ik-nlp-22_slp") -> None:
+         """Initialize the retriever.
+
+         Args:
+             dataset (str, optional): The dataset to train on. Assumes the
+                 information is stored in a column named 'text'. Defaults to
+                 "GroNLP/ik-nlp-22_slp".
+         """
+         torch.set_grad_enabled(False)
+
+         # Context encoding and tokenization
+         self.ctx_encoder = DPRContextEncoder.from_pretrained(
+             "facebook/dpr-ctx_encoder-single-nq-base")
+         self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
+             "facebook/dpr-ctx_encoder-single-nq-base")
+
+         # Question encoding and tokenization
+         self.q_encoder = DPRQuestionEncoder.from_pretrained(
+             "facebook/dpr-question_encoder-single-nq-base")
+         self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
+             "facebook/dpr-question_encoder-single-nq-base")
+
+         # Dataset building
+         self.dataset = self.__init_dataset(dataset)
+
+     def __init_dataset(self, dataset: str):
+         """Loads the dataset and adds FAISS embeddings.
+
+         Args:
+             dataset (str): A HuggingFace dataset name.
+
+         Returns:
+             Dataset: A dataset with a new column 'embeddings' containing
+                 FAISS embeddings.
+         """
+         # TODO: save ds w/ embeddings to disk and retrieve it if it already exists
+
+         # Load dataset
+         ds = load_dataset(dataset, name='paragraphs')['train']
+
+         def embed(row):
+             # Inline helper function to perform embedding
+             p = row['text']
+             tok = self.ctx_tokenizer(p, return_tensors='pt', truncation=True)
+             enc = self.ctx_encoder(**tok)[0][0].numpy()
+             return {'embeddings': enc}
+
+         # Add FAISS embeddings
+         ds_with_embeddings = ds.map(embed)
+
+         # TODO: this throws a weird error.
+         ds_with_embeddings.add_faiss_index(column='embeddings')
+         return ds_with_embeddings
+
+     def retrieve(self, query: str, k: int = 5):
+         """Retrieve the top k matches for a search query.
+
+         Args:
+             query (str): A search query.
+             k (int, optional): The number of documents to retrieve.
+                 Defaults to 5.
+
+         Returns:
+             tuple: A tuple of lists of scores and results.
+         """
+
+         def embed(q):
+             # Inline helper function to perform embedding
+             tok = self.q_tokenizer(q, return_tensors='pt', truncation=True)
+             return self.q_encoder(**tok)[0][0].numpy()
+
+         question_embedding = embed(query)
+         scores, results = self.dataset.get_nearest_examples(
+             'embeddings', question_embedding, k=k)
+         return scores, results
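The TODO about persisting the embedded dataset could be addressed with the FAISS helpers that `datasets` provides (`save_faiss_index`/`load_faiss_index` together with `save_to_disk`/`load_from_disk`). A rough sketch of how `__init_dataset` might cache its work; the paths and overall structure are assumptions, not part of this commit:

```python
# Sketch of caching the embedded dataset and its FAISS index (paths are assumptions).
import os
from datasets import load_dataset, load_from_disk

def init_dataset_cached(self, dataset: str, cache_dir: str = "./ds_cache"):
    index_file = os.path.join(cache_dir, "embeddings.faiss")
    data_dir = os.path.join(cache_dir, "dataset")

    if os.path.exists(index_file):
        # Reuse the previously embedded dataset and index from disk
        ds = load_from_disk(data_dir)
        ds.load_faiss_index('embeddings', index_file)
        return ds

    # First run: embed as in __init_dataset, then persist the results
    ds = load_dataset(dataset, name='paragraphs')['train']

    def embed(row):
        tok = self.ctx_tokenizer(row['text'], return_tensors='pt',
                                 truncation=True)
        return {'embeddings': self.ctx_encoder(**tok)[0][0].numpy()}

    ds = ds.map(embed)
    ds.add_faiss_index(column='embeddings')

    os.makedirs(cache_dir, exist_ok=True)
    ds.save_faiss_index('embeddings', index_file)
    # save_to_disk refuses datasets with attached indexes, so drop it first
    ds.drop_index('embeddings')
    ds.save_to_disk(data_dir)
    ds.load_faiss_index('embeddings', index_file)
    return ds
```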
main.py ADDED
@@ -0,0 +1,14 @@
+ from base_model.retriever import Retriever
+
+ if __name__ == '__main__':
+     # Initialize retriever
+     r = Retriever()
+
+     # Retrieve example
+     scores, results = r.retrieve(
+         "When is a stochastic process said to be stationary?")
+
+     for i, (score, text) in enumerate(zip(scores, results['text'])):
+         print(f"Result {i+1} (score: {score * 100:.02f}):")
+         print(text)
+         print()  # Newline
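Note that `get_nearest_examples` returns the scores as an array and the examples as a dict of columns rather than a list of rows, which is why the loop above zips the scores with `results['text']` instead of iterating over the returned tuple directly.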
poetry.lock CHANGED
The diff for this file is too large to render.
 
pyproject.toml CHANGED
@@ -7,8 +7,13 @@ authors = ["Your Name <you@example.com>"]
  [tool.poetry.dependencies]
  python = "^3.8"
  numpy = "^1.22.3"
+ transformers = "^4.17.0"
+ torch = "^1.11.0"
+ datasets = "^1.18.4"
+ faiss-cpu = "^1.7.2"

  [tool.poetry.dev-dependencies]
+ flake8 = "^4.0.1"

  [build-system]
  requires = ["poetry-core>=1.0.0"]
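With these dependencies declared, `poetry install` should pull in everything the retriever needs. `faiss-cpu` keeps the setup CPU-only; on a machine with a CUDA-capable GPU it could presumably be swapped for `faiss-gpu`.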