from haystack.nodes.retriever import EmbeddingRetriever
from haystack.nodes import TableReader, FARMReader, RouteDocuments, JoinAnswers
from haystack import Pipeline

# Short aliases for extractive text readers -> model ids passed to FARMReader.
# Entries whose value starts with "implement" are placeholders with no model yet.
text_reader_types = {
    "minilm": "deepset/minilm-uncased-squad2",
    "distilroberta": "deepset/tinyroberta-squad2",
    "electra-base": "deepset/electra-base-squad2",
    "bert-base": "deepset/bert-base-cased-squad2",
    "deberta-large": "deepset/deberta-v3-large-squad2",
    "gpt3": "implement openai answer generator",
}

# Short aliases for table readers -> model ids passed to TableReader.
# Entries whose value starts with "implement" are placeholders with no model yet.
table_reader_types = {
    "tapas": "deepset/tapas-large-nq-hn-reader",
    "text": "implement changing tables to text",
}

# Marks registry entries above that are TODO notes rather than model ids.
_PLACEHOLDER_PREFIX = "implement"


def _resolve_model_name(name, registry):
    """Return the model id/path to load for *name*.

    If *name* is a key of *registry* (e.g. ``"minilm"``), the mapped model id is
    returned. Any other value, such as a model id like
    ``"deepset/roberta-base-squad2"`` or a local path, is returned unchanged.

    Raises:
        NotImplementedError: if *name* is a registry key whose entry is still a
            placeholder (its value starts with ``"implement"``).
    """
    if name not in registry:
        return name
    model = registry[name]
    if model.startswith(_PLACEHOLDER_PREFIX):
        raise NotImplementedError(
            f"Reader type {name!r} is not implemented yet ({model})."
        )
    return model


def create_retriever(document_store, embedding_model="deepset/all-mpnet-base-v2-table"):
    """Build an EmbeddingRetriever over *document_store* and embed its documents.

    Args:
        document_store: Haystack document store to retrieve from. Its
            ``update_embeddings`` is called with the new retriever.
        embedding_model: Model the retriever uses to embed queries and
            documents. The default is the model this module previously
            hard-coded.

    Returns:
        Tuple ``(document_store, retriever)``.
    """
    retriever = EmbeddingRetriever(
        document_store=document_store, embedding_model=embedding_model
    )
    # Stored document embeddings must come from the same model that will embed
    # queries, otherwise similarity search compares incompatible vectors.
    document_store.update_embeddings(retriever=retriever)
    return document_store, retriever


def create_readers_and_pipeline(
    retriever,
    text_reader_type="deepset/roberta-base-squad2",
    table_reader_type="deepset/tapas-large-nq-hn-reader",
    use_table=True,
    use_text=True,
):
    """Assemble a retriever -> reader(s) question-answering Pipeline.

    Resulting topology:
      * text only:  Query -> EmbeddingRetriever -> TextReader
      * table only: Query -> EmbeddingRetriever -> TableReader
      * both:       Query -> EmbeddingRetriever -> RouteDocuments
                    -> TextReader (output_1) and TableReader (output_2)
                    -> JoinAnswers
      * neither:    Query -> EmbeddingRetriever (retrieval only, no readers)

    Args:
        retriever: Retriever node, registered under the name "EmbeddingRetriever".
        text_reader_type: Model id/path for FARMReader, or a key of
            ``text_reader_types``.
        table_reader_type: Model id/path for TableReader, or a key of
            ``table_reader_types``.
        use_table: Whether to add a table reader.
        use_text: Whether to add a text reader.

    Returns:
        The configured ``Pipeline``.

    Raises:
        NotImplementedError: if a reader alias maps to a placeholder registry entry.
    """
    # Load only the reader models that will actually be wired in; each one is
    # expensive to initialize.
    if use_text:
        print("Initializing Text reader..")
        text_reader = FARMReader(_resolve_model_name(text_reader_type, text_reader_types))
    if use_table:
        print("Initializing table reader..")
        table_reader = TableReader(_resolve_model_name(table_reader_type, table_reader_types))

    pipeline = Pipeline()
    pipeline.add_node(component=retriever, name="EmbeddingRetriever", inputs=["Query"])

    if use_text and use_table:
        # NOTE(review): relies on RouteDocuments' default content-type split sending
        # text documents to output_1 and tables to output_2 — confirm for the
        # installed Haystack version.
        pipeline.add_node(
            component=RouteDocuments(), name="RouteDocuments", inputs=["EmbeddingRetriever"]
        )
        pipeline.add_node(
            component=text_reader, name="TextReader", inputs=["RouteDocuments.output_1"]
        )
        pipeline.add_node(
            component=table_reader, name="TableReader", inputs=["RouteDocuments.output_2"]
        )
        # Merge the answers from both readers into a single result list.
        pipeline.add_node(
            component=JoinAnswers(), name="JoinAnswers", inputs=["TextReader", "TableReader"]
        )
    elif use_table:
        pipeline.add_node(
            component=table_reader, name="TableReader", inputs=["EmbeddingRetriever"]
        )
    elif use_text:
        pipeline.add_node(
            component=text_reader, name="TextReader", inputs=["EmbeddingRetriever"]
        )
    # With neither reader enabled the pipeline intentionally stays retrieval-only.

    return pipeline