Stefano Fiorucci committed on
Commit
aabdf81
1 Parent(s): 418ba7e

start refactoring

Files changed (3)
  1. app.py +8 -44
  2. config.py +9 -0
  3. haystack_utils.py +42 -0
app.py CHANGED
@@ -6,18 +6,11 @@ from json import JSONDecodeError
 from markdown import markdown
 import random
 from typing import List, Dict, Any, Tuple, Optional
-
-from haystack.document_stores import FAISSDocumentStore
-from haystack.nodes import EmbeddingRetriever
-from haystack.pipelines import ExtractiveQAPipeline
-from haystack.nodes import FARMReader
-from haystack.pipelines import ExtractiveQAPipeline
 from annotated_text import annotation
-import shutil
 from urllib.parse import unquote
 
+from haystack_utils import start_haystack, set_state_if_absent, load_questions
 
-# FAISS index directory
 INDEX_DIR = 'data/index'
 QUESTIONS_PATH = 'data/questions.txt'
 RETRIEVER_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
@@ -26,53 +19,24 @@ READER_MODEL = "deepset/roberta-base-squad2"
 READER_CONFIG_THRESHOLD = 0.15
 RETRIEVER_TOP_K = 10
 READER_TOP_K = 5
-# pipe=None
-
-# the following function is cached to make index and models load only at start
-
 
+# the following function is a wrapper for start_haystack,
+# which loads document store, retriever, reader and creates pipeline.
+# cached to make index and models load only at start
 @st.cache(hash_funcs={"builtins.SwigPyObject": lambda _: None},
           allow_output_mutation=True)
-def start_haystack():
-    """
-    load document store, retriever, reader and create pipeline
-    """
-    shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')
-    document_store = FAISSDocumentStore(
-        faiss_index_path=f'{INDEX_DIR}/my_faiss_index.faiss',
-        faiss_config_path=f'{INDEX_DIR}/my_faiss_index.json')
-    print(f'Index size: {document_store.get_document_count()}')
-    retriever = EmbeddingRetriever(
-        document_store=document_store,
-        embedding_model=RETRIEVER_MODEL,
-        model_format=RETRIEVER_MODEL_FORMAT
-    )
-    reader = FARMReader(model_name_or_path=READER_MODEL,
-                        use_gpu=False,
-                        confidence_threshold=READER_CONFIG_THRESHOLD)
-    pipe = ExtractiveQAPipeline(reader, retriever)
-    return pipe
+def start_app():
+    return start_haystack()
 
 
 @st.cache()
-def load_questions():
-    with open(QUESTIONS_PATH) as fin:
-        questions = [line.strip() for line in fin.readlines()
-                     if not line.startswith('#')]
-    return questions
-
-
-def set_state_if_absent(key, value):
-    if key not in st.session_state:
-        st.session_state[key] = value
-
+def load_questions_wrapper():
+    return load_questions()
 
 pipe = start_haystack()
 
 # the pipeline is not included as parameter of the following function,
 # because it is difficult to cache
-
-
 @st.cache(persist=True, allow_output_mutation=True)
 def query(question: str, retriever_top_k: int = 10, reader_top_k: int = 5):
     """Run query and get answers"""
config.py ADDED
@@ -0,0 +1,9 @@
+
+INDEX_DIR = 'data/index'
+QUESTIONS_PATH = 'data/questions.txt'
+RETRIEVER_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
+RETRIEVER_MODEL_FORMAT = "sentence_transformers"
+READER_MODEL = "deepset/roberta-base-squad2"
+READER_CONFIG_THRESHOLD = 0.15
+RETRIEVER_TOP_K = 10
+READER_TOP_K = 5
haystack_utils.py ADDED
@@ -0,0 +1,42 @@
+import shutil
+from haystack.document_stores import FAISSDocumentStore
+from haystack.nodes import EmbeddingRetriever
+from haystack.pipelines import ExtractiveQAPipeline
+from haystack.nodes import FARMReader
+import streamlit as st
+
+from config import (INDEX_DIR, RETRIEVER_MODEL, RETRIEVER_MODEL_FORMAT,
+                    READER_MODEL, READER_CONFIG_THRESHOLD, QUESTIONS_PATH)
+
+def start_haystack():
+    """
+    load document store, retriever, reader and create pipeline
+    """
+    shutil.copy(f'{INDEX_DIR}/faiss_document_store.db', '.')
+    document_store = FAISSDocumentStore(
+        faiss_index_path=f'{INDEX_DIR}/my_faiss_index.faiss',
+        faiss_config_path=f'{INDEX_DIR}/my_faiss_index.json')
+    print(f'Index size: {document_store.get_document_count()}')
+
+    retriever = EmbeddingRetriever(
+        document_store=document_store,
+        embedding_model=RETRIEVER_MODEL,
+        model_format=RETRIEVER_MODEL_FORMAT
+    )
+
+    reader = FARMReader(model_name_or_path=READER_MODEL,
+                        use_gpu=False,
+                        confidence_threshold=READER_CONFIG_THRESHOLD)
+
+    pipe = ExtractiveQAPipeline(reader, retriever)
+    return pipe
+
+def set_state_if_absent(key, value):
+    if key not in st.session_state:
+        st.session_state[key] = value
+
+def load_questions():
+    with open(QUESTIONS_PATH) as fin:
+        questions = [line.strip() for line in fin.readlines()
+                     if not line.startswith('#')]
+    return questions
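For reference, a sketch of how the query() function shown in the app.py context lines could run the pipeline built by start_haystack(). The params layout follows Haystack 1.x's ExtractiveQAPipeline.run (nodes named "Retriever" and "Reader"); the answer post-processing is an assumption, not something this commit contains.

import streamlit as st

@st.cache(persist=True, allow_output_mutation=True)
def query(question: str, retriever_top_k: int = 10, reader_top_k: int = 5):
    """Run query and get answers"""
    # `pipe` is read from the enclosing module scope, since passing the
    # pipeline as an argument would make it part of the cache key
    results = pipe.run(query=question,
                       params={"Retriever": {"top_k": retriever_top_k},
                               "Reader": {"top_k": reader_top_k}})
    # Haystack 1.x returns Answer objects; keep only the fields a UI typically needs
    return [{"answer": a.answer, "context": a.context, "score": a.score}
            for a in results["answers"]]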