|
from haystack.nodes.retriever import EmbeddingRetriever |
|
from haystack.nodes import TableReader, FARMReader, RouteDocuments, JoinAnswers |
|
from haystack import Pipeline |
|
|
|
# Shorthand reader name -> model name string for the extractive text reader.
# NOTE(review): the "distilroberta" key points at a *tiny*roberta checkpoint;
# confirm the key name is intended.
# TODO: "gpt3" maps to a placeholder description, not a model name.
text_reader_types = {
    "minilm": "deepset/minilm-uncased-squad2",
    "distilroberta": "deepset/tinyroberta-squad2",
    "electra-base": "deepset/electra-base-squad2",
    "bert-base": "deepset/bert-base-cased-squad2",
    "deberta-large": "deepset/deberta-v3-large-squad2",
    "gpt3": "implement openai answer generator",  # placeholder
}

# Shorthand table-reader name -> model name string.
# TODO: "text" maps to a placeholder description, not a model name.
table_reader_types = {
    "tapas": "deepset/tapas-large-nq-hn-reader",
    "text": "implement changing tables to text",  # placeholder
}
|
|
|
|
|
def create_retriever(document_store, *, embedding_model="deepset/all-mpnet-base-v2-table"):
    """Create an ``EmbeddingRetriever`` and embed the store's documents with it.

    Args:
        document_store: Haystack document store the retriever searches. This
            function calls ``document_store.update_embeddings`` on it, so the
            store is modified as a side effect.
        embedding_model: Model name or path handed to ``EmbeddingRetriever``.
            Defaults to the model this helper always used, so existing callers
            see no change.

    Returns:
        A ``(document_store, retriever)`` tuple; ``document_store`` is the same
        object that was passed in.
    """
    retriever = EmbeddingRetriever(document_store=document_store, embedding_model=embedding_model)
    # Embed stored documents with the same model the retriever uses for queries,
    # so query and document vectors are comparable.
    document_store.update_embeddings(retriever=retriever)
    return document_store, retriever
|
|
|
def create_readers_and_pipeline(
    retriever,
    text_reader_type="deepset/roberta-base-squad2",
    table_reader_type="deepset/tapas-large-nq-hn-reader",
    use_table=True,
    use_text=True,
):
    """Build a question-answering pipeline over text and/or table documents.

    The pipeline always begins with ``retriever`` as node
    ``"EmbeddingRetriever"`` fed from ``"Query"``. What follows depends on the
    flags:

    * text only:  ``FARMReader`` node ``"TextReader"``.
    * table only: ``TableReader`` node ``"TableReader"``.
    * both:       ``RouteDocuments`` node whose ``output_1`` feeds
      ``"TextReader"`` and whose ``output_2`` feeds ``"TableReader"``, followed
      by a ``JoinAnswers`` node combining both readers' answers.

    Args:
        retriever: Retriever component placed first in the pipeline.
        text_reader_type: Model name or path passed to ``FARMReader``.
        table_reader_type: Model name or path passed to ``TableReader``.
        use_table: Whether to include a table reader.
        use_text: Whether to include a text reader.

    Returns:
        The assembled ``Pipeline``.

    Raises:
        ValueError: If neither ``use_text`` nor ``use_table`` is true; such a
            pipeline would contain no reader and could not produce answers.
    """
    if not (use_text or use_table):
        raise ValueError("At least one of use_text or use_table must be True.")

    # Only load the readers that are actually requested.
    text_reader = None
    table_reader = None
    if use_text:
        print("Initializing Text reader..")
        text_reader = FARMReader(text_reader_type)
    if use_table:
        print("Initializing table reader..")
        table_reader = TableReader(table_reader_type)

    qa_pipeline = Pipeline()
    qa_pipeline.add_node(component=retriever, name="EmbeddingRetriever", inputs=["Query"])

    if use_text and use_table:
        # Presumably RouteDocuments' default routing sends text documents to
        # output_1 and table documents to output_2 -- confirm against the
        # installed Haystack version.
        qa_pipeline.add_node(component=RouteDocuments(), name="RouteDocuments", inputs=["EmbeddingRetriever"])
        qa_pipeline.add_node(component=text_reader, name="TextReader", inputs=["RouteDocuments.output_1"])
        qa_pipeline.add_node(component=table_reader, name="TableReader", inputs=["RouteDocuments.output_2"])
        qa_pipeline.add_node(component=JoinAnswers(), name="JoinAnswers", inputs=["TextReader", "TableReader"])
    elif use_text:
        qa_pipeline.add_node(component=text_reader, name="TextReader", inputs=["EmbeddingRetriever"])
    else:
        qa_pipeline.add_node(component=table_reader, name="TableReader", inputs=["EmbeddingRetriever"])

    return qa_pipeline
|
|