import gradio as gr
from haystack.nodes import BM25Retriever, FARMReader
from haystack.document_stores import ElasticsearchDocumentStore
from haystack.pipelines import ExtractiveQAPipeline
from abc import ABC, abstractmethod
import certifi

# Use certifi's CA bundle so the HTTPS connection to Elastic Cloud can be verified.
ca_certs = certifi.where()


class Config:
    es_host = "ask2democracy.es.us-central1.gcp.cloud.es.io"
    es_user = "elastic"
    es_password = "siKAHmmk2flwEaKNqQVZwp49"
    proposals_index = "petrolfo"
    # reader_model_name_or_path = "deepset/roberta-base-squad2"
    reader_model_name_or_path = "deepset/xlm-roberta-large-squad2"
    use_gpu = True


class DocumentQueries(ABC):
    @abstractmethod
    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: str):
        pass


class ExtractiveProposalQueries(DocumentQueries):
    def __init__(self, es_host: str, es_index: str, es_user, es_password, reader_name_or_path: str, use_gpu=False) -> None:
        reader = FARMReader(model_name_or_path=reader_name_or_path, use_gpu=use_gpu, num_processes=1)
        self._initialize_pipeline(es_host, es_index, es_user, es_password, reader=reader)

    def _initialize_pipeline(self, es_host, es_index, es_user, es_password, reader=None):
        # Keep the (expensive) reader loaded across re-initializations; only the
        # document store, retriever, and pipeline are rebuilt when the index changes.
        if reader is not None:
            self.reader = reader
        self.es_host = es_host
        self.es_user = es_user
        self.es_password = es_password
        self.document_store = ElasticsearchDocumentStore(
            host=es_host, username=es_user, password=es_password, index=es_index,
            port=443, scheme="https", verify_certs=True, ca_certs=ca_certs,
        )
        self.retriever = BM25Retriever(document_store=self.document_store)
        self.pipe = ExtractiveQAPipeline(self.reader, self.retriever)

    def search_by_query(self, query: str, retriever_top_k: int, reader_top_k: int, es_index: str = None):
        # Passing a different index re-points the pipeline at that index before querying.
        if es_index is not None:
            self._initialize_pipeline(self.es_host, es_index, self.es_user, self.es_password)
        params = {"Retriever": {"top_k": retriever_top_k}, "Reader": {"top_k": reader_top_k}}
        prediction = self.pipe.run(query=query, params=params)
        return prediction["answers"]


query = ExtractiveProposalQueries(
    es_host=Config.es_host,
    es_index=Config.proposals_index,
    es_user=Config.es_user,
    es_password=Config.es_password,
    reader_name_or_path=Config.reader_model_name_or_path,
    use_gpu=Config.use_gpu,
)


def update(query):
    return f"{query}", f"{query}", f"{query}", f"{query}"


def search(question):
    # Ask the same question against each candidate's index and return the top
    # answer plus its surrounding context for both.
    p1_result = query.search_by_query(query=question, retriever_top_k=5, reader_top_k=3, es_index="petro")
    p2_result = query.search_by_query(query=question, retriever_top_k=5, reader_top_k=3, es_index="rodolfo")
    return [p1_result[0].answer,
            p1_result[0].context,
            p2_result[0].answer,
            p2_result[0].context]


demo = gr.Blocks()

with demo:
    gr.Markdown(
        """
        # Ask2Democracy
        Pregúntale a los candidatos
        """)
    inp = gr.Textbox(placeholder="Haz tu pregunta aquí")
    search_button = gr.Button("Buscar")
    with gr.Row():
        response = gr.Label(value="Petro")
        context = gr.Label(value="El viejo")
    with gr.Row():
        with gr.Column():
            # resp_1 = gr.Markdown("<b>Respuesta</b>")
            # context_1 = gr.Markdown("<b>Contexto</b>")
            resp_1 = gr.Textbox(lines=1, label="respuesta")
            context_1 = gr.Textbox(lines=5, label="contexto")
        with gr.Column():
            resp_2 = gr.Textbox(lines=1, label="respuesta")
            context_2 = gr.Textbox(lines=5, label="contexto")
    search_button.click(search, inputs=inp, outputs=[resp_1, context_1, resp_2, context_2])

demo.launch(debug=True)
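
# Sketch: exercising the QA pipeline directly, without the Gradio UI. The sample
# question below is an assumption (not part of the original Space), and because
# demo.launch(debug=True) blocks, this only runs once the server is stopped; in
# practice it would live in a separate script or run before launch.
sample_answers = query.search_by_query(
    query="¿Qué proponen los candidatos sobre educación?",  # hypothetical question
    retriever_top_k=5,
    reader_top_k=3,
    es_index="petro",
)
for ans in sample_answers:
    # Haystack Answer objects expose the extracted span and its context,
    # the same fields the search() callback reads above.
    print(ans.answer, "|", ans.context)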