|
import streamlit as st |
|
|
|
from utils.config import document_store_configs, model_configs |
|
from haystack import Pipeline |
|
from haystack.schema import Answer |
|
from haystack.document_stores import BaseDocumentStore |
|
from haystack.document_stores import InMemoryDocumentStore, OpenSearchDocumentStore, WeaviateDocumentStore |
|
from haystack.nodes import EmbeddingRetriever, FARMReader, PromptNode, PreProcessor |
|
|
|
from milvus_haystack import MilvusDocumentStore |
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def start_preprocessor_node():
    """Build and cache the document pre-processing node.

    The node cleans whitespace, empty lines and headers/footers, then
    splits documents into 100-word chunks without breaking sentences.

    Returns:
        A configured ``PreProcessor`` instance (cached by Streamlit).
    """
    print('initializing preprocessor node')
    return PreProcessor(
        clean_empty_lines=True,
        clean_whitespace=True,
        clean_header_footer=True,
        split_by="word",
        split_length=100,
        split_respect_sentence_boundary=True,
    )
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def start_document_store(type: str):
    """Create and cache the document store backend selected by ``type``.

    Args:
        type: Backend identifier; one of ``'inmemory'``, ``'opensearch'``,
            ``'weaviate'`` or ``'milvus'``. Connection settings for the
            remote backends come from ``document_store_configs``.

    Returns:
        The initialized document store instance (cached by Streamlit).

    Raises:
        ValueError: If ``type`` is not one of the supported backends.
            (Previously an unknown value crashed with ``UnboundLocalError``.)
    """
    print('initializing document store')
    if type == 'inmemory':
        # BM25 enabled so the in-memory store also supports sparse retrieval;
        # 384 matches the embedding size of the configured embedding model.
        document_store = InMemoryDocumentStore(use_bm25=True, embedding_dim=384)
    elif type == 'opensearch':
        document_store = OpenSearchDocumentStore(
            scheme=document_store_configs['OPENSEARCH_SCHEME'],
            username=document_store_configs['OPENSEARCH_USERNAME'],
            password=document_store_configs['OPENSEARCH_PASSWORD'],
            host=document_store_configs['OPENSEARCH_HOST'],
            port=document_store_configs['OPENSEARCH_PORT'],
            index=document_store_configs['OPENSEARCH_INDEX'],
            embedding_dim=document_store_configs['OPENSEARCH_EMBEDDING_DIM'])
    elif type == 'weaviate':
        document_store = WeaviateDocumentStore(
            host=document_store_configs['WEAVIATE_HOST'],
            port=document_store_configs['WEAVIATE_PORT'],
            index=document_store_configs['WEAVIATE_INDEX'],
            embedding_dim=document_store_configs['WEAVIATE_EMBEDDING_DIM'])
    elif type == 'milvus':
        document_store = MilvusDocumentStore(
            uri=document_store_configs['MILVUS_URI'],
            index=document_store_configs['MILVUS_INDEX'],
            embedding_dim=document_store_configs['MILVUS_EMBEDDING_DIM'],
            return_embedding=True)
    else:
        raise ValueError(
            f"Unknown document store type: {type!r} "
            "(expected 'inmemory', 'opensearch', 'weaviate' or 'milvus')")
    return document_store
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def start_retriever(_document_store: BaseDocumentStore):
    """Build and cache the dense retriever bound to the given document store.

    Args:
        _document_store: Store the retriever searches over (leading underscore
            tells Streamlit not to hash it for caching).

    Returns:
        An ``EmbeddingRetriever`` using the configured embedding model,
        returning the top 5 matches per query.
    """
    print('initializing retriever')
    return EmbeddingRetriever(
        document_store=_document_store,
        embedding_model=model_configs['EMBEDDING_MODEL'],
        top_k=5,
    )
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def start_reader():
    """Build and cache the extractive QA reader.

    Returns:
        A ``FARMReader`` loaded with the configured extractive model.
    """
    print('initializing reader')
    return FARMReader(model_name_or_path=model_configs['EXTRACTIVE_MODEL'])
|
|
|
|
|
|
|
|
|
@st.cache_resource(show_spinner=False)
def start_haystack_extractive(_document_store: BaseDocumentStore, _retriever: EmbeddingRetriever, _reader: FARMReader):
    """Assemble and cache the extractive QA pipeline (Retriever -> Reader).

    Args:
        _document_store: Backing store (unused directly; kept in the cache key
            so the pipeline is rebuilt when the store changes).
        _retriever: Dense retriever feeding candidate documents to the reader.
        _reader: Extractive reader producing span answers.

    Returns:
        The assembled Haystack ``Pipeline``.
    """
    print('initializing pipeline')
    pipeline = Pipeline()
    pipeline.add_node(component=_retriever, name="Retriever", inputs=["Query"])
    pipeline.add_node(component=_reader, name="Reader", inputs=["Retriever"])
    return pipeline
|
|
|
@st.cache_resource(show_spinner=False)
def start_haystack_rag(_document_store: BaseDocumentStore, _retriever: EmbeddingRetriever, openai_key):
    """Assemble and cache the RAG pipeline (Retriever -> PromptNode).

    Args:
        _document_store: Backing store (kept in the cache key so the pipeline
            is rebuilt when the store changes).
        _retriever: Dense retriever supplying context documents.
        openai_key: API key forwarded to the generative model.

    Returns:
        The assembled Haystack ``Pipeline``.
    """
    generator = PromptNode(
        default_prompt_template="deepset/question-answering",
        model_name_or_path=model_configs['GENERATIVE_MODEL'],
        api_key=openai_key,
        max_length=500,
    )
    rag_pipeline = Pipeline()
    rag_pipeline.add_node(component=_retriever, name="Retriever", inputs=["Query"])
    rag_pipeline.add_node(component=generator, name="PromptNode", inputs=["Retriever"])
    return rag_pipeline
|
|
|
|
|
def query(_pipeline, question):
    """Run ``question`` through ``_pipeline`` and return the raw results dict."""
    return _pipeline.run(question, params={})
|
|
|
def initialize_pipeline(task, document_store, retriever, reader, openai_key=""):
    """Build the pipeline appropriate for ``task``.

    Args:
        task: Either ``'extractive'`` (Retriever + Reader) or ``'rag'``
            (Retriever + PromptNode).
        document_store: Document store backing the retriever.
        retriever: Dense retriever feeding the downstream node.
        reader: Extractive reader (used only for the 'extractive' task).
        openai_key: API key for the generative model (used only for 'rag').

    Returns:
        The assembled pipeline for the requested task.

    Raises:
        ValueError: If ``task`` is not recognized. (Previously an unknown
            task silently returned ``None``, deferring the failure to the
            first query.)
    """
    if task == 'extractive':
        return start_haystack_extractive(document_store, retriever, reader)
    if task == 'rag':
        return start_haystack_rag(document_store, retriever, openai_key)
    raise ValueError(f"Unknown task: {task!r} (expected 'extractive' or 'rag')")
|
|
|
|
|
|