Ramon Meffert
committed on
Commit
•
8bbe3aa
1
Parent(s):
7177a08
Add retriever based on DPR (WIP)
Browse files- .gitignore +3 -0
- README.md +20 -7
- base_model/reader.py +2 -0
- base_model/retriever.py +86 -0
- main.py +14 -0
- poetry.lock +0 -0
- pyproject.toml +5 -0
.gitignore
CHANGED
@@ -1,3 +1,6 @@
|
|
|
|
|
|
|
|
1 |
# Byte-compiled / optimized / DLL files
|
2 |
__pycache__/
|
3 |
*.py[cod]
|
|
|
1 |
+
# VS Code stuff
|
2 |
+
.vscode/
|
3 |
+
|
4 |
# Byte-compiled / optimized / DLL files
|
5 |
__pycache__/
|
6 |
*.py[cod]
|
README.md
CHANGED
@@ -1,14 +1,27 @@
|
|
1 |
# nlp-flashcard-project
|
2 |
|
3 |
-
|
4 |
## Todo voor progress meeting
|
5 |
|
6 |
-
- Data inlezen/Repo klaarmaken
|
7 |
-
- Proof of concept met UnifiedQA
|
8 |
-
- Standaard QA model met de dataset
|
9 |
-
- Papers verzamelen/lezen
|
10 |
-
- Eerder werk bekijken, inspiratie opdoen voor research richting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
## Handige info
|
13 |
|
14 |
-
Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
|
|
|
|
1 |
# nlp-flashcard-project
|
2 |
|
|
|
3 |
## Todo voor progress meeting
|
4 |
|
5 |
+
- [ ] Data inlezen/Repo klaarmaken
|
6 |
+
- [ ] Proof of concept met UnifiedQA
|
7 |
+
- [ ] Standaard QA model met de dataset
|
8 |
+
- [ ] Papers verzamelen/lezen
|
9 |
+
- [ ] Eerder werk bekijken, inspiratie opdoen voor research richting
|
10 |
+
|
11 |
+
## Overview
|
12 |
+
|
13 |
+
De meeste QA systemen bestaan uit twee onderdelen:
|
14 |
+
|
15 |
+
- Een retriever. Die haalt adhv de vraag _k_ relevante stukken context op, bv.
|
16 |
+
met `tf-idf`.
|
17 |
+
- Een model dat het antwoord genereert. Wat je hier precies gebruikt hangt af
|
18 |
+
van de manier van question answering:
|
19 |
+
- Voor **extractive QA** gebruik je een reader;
|
20 |
+
- Voor **generative QA** gebruik je een generator.
|
21 |
+
|
22 |
+
Beide werken op basis van een language model.
|
23 |
|
24 |
## Handige info
|
25 |
|
26 |
+
- Huggingface QA tutorial: <https://huggingface.co/docs/transformers/tasks/question_answering#finetune-with-tensorflow>
|
27 |
+
- Overview van open-domain question answering technieken: <https://lilianweng.github.io/posts/2020-10-29-odqa/>
|
base_model/reader.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
class Reader:
    """Placeholder for the reader component of the QA pipeline.

    Not yet implemented — will extract answers from retrieved context.
    """
base_model/retriever.py
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Retriever():
    """A class used to retrieve relevant documents based on some query,
    using DPR context/question encoders and a FAISS index.

    Based on https://huggingface.co/docs/datasets/faiss_es#faiss.
    """

    def __init__(self,
                 dataset: str = "GroNLP/ik-nlp-22_slp",
                 config: str = "paragraphs") -> None:
        """Initialize the retriever.

        Args:
            dataset (str, optional): The dataset to train on. Assumes the
                information is stored in a column named 'text'. Defaults to
                "GroNLP/ik-nlp-22_slp".
            config (str, optional): The dataset config (subset) to load.
                Defaults to "paragraphs".
        """
        # Inference only — gradients are never needed for encoding.
        torch.set_grad_enabled(False)

        # Context encoding and tokenization
        self.ctx_encoder = DPRContextEncoder.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")
        self.ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(
            "facebook/dpr-ctx_encoder-single-nq-base")

        # Question encoding and tokenization
        self.q_encoder = DPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")
        self.q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")

        # Dataset building
        self.dataset = self.__init_dataset(dataset, config)

    def __init_dataset(self, dataset: str, config: str = "paragraphs"):
        """Loads the dataset and adds a FAISS index over its embeddings.

        Args:
            dataset (str): A HuggingFace dataset name.
            config (str, optional): The dataset config to load. Defaults to
                "paragraphs".

        Returns:
            Dataset: A dataset with a new column 'embeddings' containing FAISS
                embeddings.
        """
        # TODO: save ds w/ embeddings to disk and retrieve it if it already
        # exists.

        # Load dataset (only the 'train' split is used)
        ds = load_dataset(dataset, name=config)['train']

        def embed(row):
            # Inline helper function to perform embedding of one row's text.
            p = row['text']
            tok = self.ctx_tokenizer(p, return_tensors='pt', truncation=True)
            # FAISS requires float32 vectors; cast explicitly so the index
            # build does not fail on a dtype mismatch.
            enc = self.ctx_encoder(**tok)[0][0].numpy().astype('float32')
            return {'embeddings': enc}

        # Add DPR embeddings for every row
        ds_with_embeddings = ds.map(embed)

        # Build the FAISS index over the embeddings column
        ds_with_embeddings.add_faiss_index(column='embeddings')
        return ds_with_embeddings

    def retrieve(self, query: str, k: int = 5):
        """Retrieve the top k matches for a search query.

        Args:
            query (str): A search query.
            k (int, optional): The number of documents to retrieve. Defaults
                to 5.

        Returns:
            tuple: A tuple (scores, results) where `scores` is a list of
                retrieval scores and `results` is a columnar dict of the
                matching examples.
        """

        def embed(q):
            # Inline helper function to embed the question; float32 to match
            # the dtype of the indexed context embeddings.
            tok = self.q_tokenizer(q, return_tensors='pt', truncation=True)
            return self.q_encoder(**tok)[0][0].numpy().astype('float32')

        question_embedding = embed(query)
        scores, results = self.dataset.get_nearest_examples(
            'embeddings', question_embedding, k=k)
        return scores, results
main.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from base_model.retriever import Retriever

if __name__ == '__main__':
    # Initialize retriever
    r = Retriever()

    # Retrieve example. `retrieve` returns (scores, results) where `results`
    # is a columnar dict; unpack instead of iterating the tuple directly.
    scores, results = r.retrieve(
        "When is a stochastic process said to be stationary?")

    # Zip scores with the matching paragraphs so each iteration yields one
    # (score, text) pair.
    for i, (score, text) in enumerate(zip(scores, results['text'])):
        print(f"Result {i+1} (score: {score * 100:.02f}):")
        print(text)
        print()  # Newline
poetry.lock
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
pyproject.toml
CHANGED
@@ -7,8 +7,13 @@ authors = ["Your Name <you@example.com>"]
|
|
7 |
[tool.poetry.dependencies]
|
8 |
python = "^3.8"
|
9 |
numpy = "^1.22.3"
|
|
|
|
|
|
|
|
|
10 |
|
11 |
[tool.poetry.dev-dependencies]
|
|
|
12 |
|
13 |
[build-system]
|
14 |
requires = ["poetry-core>=1.0.0"]
|
|
|
7 |
[tool.poetry.dependencies]
|
8 |
python = "^3.8"
|
9 |
numpy = "^1.22.3"
|
10 |
+
transformers = "^4.17.0"
|
11 |
+
torch = "^1.11.0"
|
12 |
+
datasets = "^1.18.4"
|
13 |
+
faiss-cpu = "^1.7.2"
|
14 |
|
15 |
[tool.poetry.dev-dependencies]
|
16 |
+
flake8 = "^4.0.1"
|
17 |
|
18 |
[build-system]
|
19 |
requires = ["poetry-core>=1.0.0"]
|