Spaces:
Runtime error
Runtime error
File size: 2,621 Bytes
75128dd 8ab7d9a 75128dd 4e2b533 75128dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import shutil
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes.retriever import EmbeddingRetriever, MultiModalRetriever
from haystack.nodes.reader import FARMReader
from haystack import Pipeline
from utils.config import (INDEX_DIR)
from typing import List
from haystack import BaseComponent, Answer
import streamlit as st
class AnswerToQuery(BaseComponent):
    """Pipeline node that turns the reader's first answer into the next query.

    In this app it sits between the text reader and the image retriever, so
    images are retrieved for the extracted answer text rather than for the
    user's original question.
    """

    # A single output edge: the (possibly rewritten) query always flows on.
    outgoing_edges = 1

    def run(self, query: str, answers: List[Answer]):
        """Replace the query with the text of the first answer.

        Args:
            query: The query that reached this node.
            answers: Answers produced by the upstream reader; the first one is
                presumably the highest-scoring (reader ordering — confirm).

        Returns:
            ``({"query": new_query}, "output_1")``. ``new_query`` is the first
            answer's text. If ``answers`` is empty, or the first answer's
            text is empty, the incoming ``query`` is passed through unchanged.
            Previously an empty list raised ``IndexError`` and aborted the
            whole pipeline run.
        """
        top_answer_text = answers[0].answer if answers else None
        # Fall back to the original query so the image retriever still gets a
        # meaningful search string when the reader produced nothing usable.
        new_query = top_answer_text if top_answer_text else query
        return {"query": new_query}, "output_1"

    def run_batch(self):
        """Batch mode is not supported by this node."""
        raise NotImplementedError()
# cached to make index and models load only at start
@st.cache(
    hash_funcs={"builtins.SwigPyObject": lambda _: None}, allow_output_mutation=True
)
def start_haystack():
    """Load the document stores and models, and build the query pipeline.

    Two FAISS document stores are opened from ``INDEX_DIR``: one for text and
    one for images. The resulting pipeline runs:

        Query -> text_retriever -> text_reader -> answer2query -> image_retriever

    The text retriever and extractive reader find an answer in the text
    corpus. ``AnswerToQuery`` turns that answer into a new query. The CLIP
    multimodal retriever then finds images matching it.

    Returns:
        The assembled haystack ``Pipeline``.
    """
    # Copy the SQLite databases next to the app. Presumably the sql_url stored
    # in each FAISS config JSON is a relative path to these files.
    # NOTE(review): confirm against text.json / images.json.
    for db_name in ("text.db", "images.db"):
        shutil.copy(f"{INDEX_DIR}/{db_name}", ".")

    document_store_text = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/text.faiss",
        faiss_config_path=f"{INDEX_DIR}/text.json",
    )
    document_store_images = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/images.faiss",
        faiss_config_path=f"{INDEX_DIR}/images.json",
    )

    retriever_text = EmbeddingRetriever(
        document_store=document_store_text,
        embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
        model_format="sentence_transformers",
    )
    reader = FARMReader(model_name_or_path="deepset/roberta-base-squad2", use_gpu=True)

    # The same CLIP model embeds both the text query and the indexed images,
    # so both live in one shared embedding space.
    clip_model = "sentence-transformers/clip-ViT-B-32"
    retriever_images = MultiModalRetriever(
        document_store=document_store_images,
        query_embedding_model=clip_model,
        query_type="text",
        document_embedding_models={"image": clip_model},
    )

    answer_to_query = AnswerToQuery()

    pipe = Pipeline()
    pipe.add_node(retriever_text, name="text_retriever", inputs=["Query"])
    pipe.add_node(reader, name="text_reader", inputs=["text_retriever"])
    pipe.add_node(answer_to_query, name="answer2query", inputs=["text_reader"])
    pipe.add_node(retriever_images, name="image_retriever", inputs=["answer2query"])
    return pipe
# Build the pipeline at module level so query() below can use it as a global.
# start_haystack is wrapped in st.cache, so the heavy index/model loading
# happens only on the first script run; later reruns reuse the cached object.
pipe = start_haystack()
# allow_output_mutation=True tells st.cache not to check the returned
# pipeline output for mutation between runs.
@st.cache(allow_output_mutation=True)
def query(
    statement: str,
    text_retriever_top_k: int = 5,
    image_retriever_top_k: int = 1,
):
    """Run *statement* through the module-level pipeline and return its output.

    Args:
        statement: The user's text query.
        text_retriever_top_k: ``top_k`` passed to the ``text_retriever`` node.
        image_retriever_top_k: ``top_k`` passed to the ``image_retriever`` node.

    Returns:
        The raw result dict of ``pipe.run``.
    """
    # The pipeline is read from the module-level ``pipe`` instead of being a
    # parameter, presumably so st.cache never has to hash it (review: confirm).
    params = {
        "image_retriever": {"top_k": image_retriever_top_k},
        "text_retriever": {"top_k": text_retriever_top_k},
    }
    return pipe.run(statement, params=params)