Stefano Fiorucci
commited on
Commit
•
aabdf81
1
Parent(s):
418ba7e
start refactoring
Browse files- app.py +8 -44
- config.py +9 -0
- haystack_utils.py +42 -0
app.py
CHANGED
@@ -6,18 +6,11 @@ from json import JSONDecodeError
|
|
6 |
from markdown import markdown
|
7 |
import random
|
8 |
from typing import List, Dict, Any, Tuple, Optional
|
9 |
-
|
10 |
-
from haystack.document_stores import FAISSDocumentStore
|
11 |
-
from haystack.nodes import EmbeddingRetriever
|
12 |
-
from haystack.pipelines import ExtractiveQAPipeline
|
13 |
-
from haystack.nodes import FARMReader
|
14 |
-
from haystack.pipelines import ExtractiveQAPipeline
|
15 |
from annotated_text import annotation
|
16 |
-
import shutil
|
17 |
from urllib.parse import unquote
|
18 |
|
|
|
19 |
|
20 |
-
# FAISS index directory
|
21 |
INDEX_DIR = 'data/index'
|
22 |
QUESTIONS_PATH = 'data/questions.txt'
|
23 |
RETRIEVER_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
@@ -26,53 +19,24 @@ READER_MODEL = "deepset/roberta-base-squad2"
|
|
26 |
READER_CONFIG_THRESHOLD = 0.15
|
27 |
RETRIEVER_TOP_K = 10
|
28 |
READER_TOP_K = 5
|
29 |
-
# pipe=None
|
30 |
-
|
31 |
-
# the following function is cached to make index and models load only at start
|
32 |
-
|
33 |
|
|
|
|
|
|
|
34 |
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},
|
35 |
allow_output_mutation=True)
|
36 |
-
def
|
37 |
-
|
38 |
-
load document store, retriever, reader and create pipeline
|
39 |
-
"""
|
40 |
-
shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')
|
41 |
-
document_store = FAISSDocumentStore(
|
42 |
-
faiss_index_path=f'{INDEX_DIR}/my_faiss_index.faiss',
|
43 |
-
faiss_config_path=f'{INDEX_DIR}/my_faiss_index.json')
|
44 |
-
print(f'Index size: {document_store.get_document_count()}')
|
45 |
-
retriever = EmbeddingRetriever(
|
46 |
-
document_store=document_store,
|
47 |
-
embedding_model=RETRIEVER_MODEL,
|
48 |
-
model_format=RETRIEVER_MODEL_FORMAT
|
49 |
-
)
|
50 |
-
reader = FARMReader(model_name_or_path=READER_MODEL,
|
51 |
-
use_gpu=False,
|
52 |
-
confidence_threshold=READER_CONFIG_THRESHOLD)
|
53 |
-
pipe = ExtractiveQAPipeline(reader, retriever)
|
54 |
-
return pipe
|
55 |
|
56 |
|
57 |
@st.cache()
|
58 |
-
def
|
59 |
-
|
60 |
-
questions = [line.strip() for line in fin.readlines()
|
61 |
-
if not line.startswith('#')]
|
62 |
-
return questions
|
63 |
-
|
64 |
-
|
65 |
-
def set_state_if_absent(key, value):
|
66 |
-
if key not in st.session_state:
|
67 |
-
st.session_state[key] = value
|
68 |
-
|
69 |
|
70 |
pipe = start_haystack()
|
71 |
|
72 |
# the pipeline is not included as parameter of the following function,
|
73 |
# because it is difficult to cache
|
74 |
-
|
75 |
-
|
76 |
@st.cache(persist=True, allow_output_mutation=True)
|
77 |
def query(question: str, retriever_top_k: int = 10, reader_top_k: int = 5):
|
78 |
"""Run query and get answers"""
|
|
|
6 |
from markdown import markdown
|
7 |
import random
|
8 |
from typing import List, Dict, Any, Tuple, Optional
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
from annotated_text import annotation
|
|
|
10 |
from urllib.parse import unquote
|
11 |
|
12 |
+
from haystack_utils import start_haystack, set_state_if_absent, load_questions
|
13 |
|
|
|
14 |
INDEX_DIR = 'data/index'
|
15 |
QUESTIONS_PATH = 'data/questions.txt'
|
16 |
RETRIEVER_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
|
|
19 |
READER_CONFIG_THRESHOLD = 0.15
|
20 |
RETRIEVER_TOP_K = 10
|
21 |
READER_TOP_K = 5
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
# the following function is a wrapper for start_haystack,
|
24 |
+
# which loads document store, retriever, reader and creates pipeline.
|
25 |
+
# cached to make index and models load only at start
|
26 |
@st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},
|
27 |
allow_output_mutation=True)
|
28 |
+
def start_app():
|
29 |
+
return start_haystack()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
|
32 |
@st.cache()
|
33 |
+
def load_questions_wrapper():
|
34 |
+
return load_questions()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
pipe = start_haystack()
|
37 |
|
38 |
# the pipeline is not included as parameter of the following function,
|
39 |
# because it is difficult to cache
|
|
|
|
|
40 |
@st.cache(persist=True, allow_output_mutation=True)
|
41 |
def query(question: str, retriever_top_k: int = 10, reader_top_k: int = 5):
|
42 |
"""Run query and get answers"""
|
config.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
INDEX_DIR = 'data/index'
|
3 |
+
QUESTIONS_PATH = 'data/questions.txt'
|
4 |
+
RETRIEVER_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
|
5 |
+
RETRIEVER_MODEL_FORMAT = "sentence_transformers"
|
6 |
+
READER_MODEL = "deepset/roberta-base-squad2"
|
7 |
+
READER_CONFIG_THRESHOLD = 0.15
|
8 |
+
RETRIEVER_TOP_K = 10
|
9 |
+
READER_TOP_K = 5
|
haystack_utils.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import shutil
|
2 |
+
from haystack.document_stores import FAISSDocumentStore
|
3 |
+
from haystack.nodes import EmbeddingRetriever
|
4 |
+
from haystack.pipelines import ExtractiveQAPipeline
|
5 |
+
from haystack.nodes import FARMReader
|
6 |
+
import streamlit as st
|
7 |
+
|
8 |
+
from config import (INDEX_DIR, RETRIEVER_MODEL, RETRIEVER_MODEL_FORMAT,
|
9 |
+
READER_MODEL, READER_CONFIG_THRESHOLD, QUESTIONS_PATH)
|
10 |
+
|
11 |
+
def start_haystack():
|
12 |
+
"""
|
13 |
+
load document store, retriever, reader and create pipeline
|
14 |
+
"""
|
15 |
+
shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')
|
16 |
+
document_store = FAISSDocumentStore(
|
17 |
+
faiss_index_path=f'{INDEX_DIR}/my_faiss_index.faiss',
|
18 |
+
faiss_config_path=f'{INDEX_DIR}/my_faiss_index.json')
|
19 |
+
print(f'Index size: {document_store.get_document_count()}')
|
20 |
+
|
21 |
+
retriever = EmbeddingRetriever(
|
22 |
+
document_store=document_store,
|
23 |
+
embedding_model=RETRIEVER_MODEL,
|
24 |
+
model_format=RETRIEVER_MODEL_FORMAT
|
25 |
+
)
|
26 |
+
|
27 |
+
reader = FARMReader(model_name_or_path=READER_MODEL,
|
28 |
+
use_gpu=False,
|
29 |
+
confidence_threshold=READER_CONFIG_THRESHOLD)
|
30 |
+
|
31 |
+
pipe = ExtractiveQAPipeline(reader, retriever)
|
32 |
+
return pipe
|
33 |
+
|
34 |
+
def set_state_if_absent(key, value):
|
35 |
+
if key not in st.session_state:
|
36 |
+
st.session_state[key] = value
|
37 |
+
|
38 |
+
def load_questions():
|
39 |
+
with open(QUESTIONS_PATH) as fin:
|
40 |
+
questions = [line.strip() for line in fin.readlines()
|
41 |
+
if not line.startswith('#')]
|
42 |
+
return questions
|