import os
import streamlit as st
from pymilvus import MilvusClient
import torch
from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader
# App title, used for both the browser tab and the page header.
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
# Question pre-filled in the input box on first load.
INITIAL = "What is the population of NYC"
# set_page_config must be the first Streamlit command executed by the script.
st.set_page_config(page_title=TITLE)
st.header(TITLE)
# Short usage hint shown under the header.
st.markdown('''
Ask a short-answer question that can be found in Wikipedia data.
''', unsafe_allow_html=True)
@st.cache_resource
def load_models():
    """Build the three pipeline models.

    Decorated with ``st.cache_resource``, so the factories run once and the
    resulting dict is reused across reruns instead of reloading the models.

    Returns:
        dict with keys 'encoder', 'summarizer' and 'reader', each holding
        whatever the matching ``get_*`` factory from ``model`` returns
        (``main`` indexes each entry as ``[0]`` / ``[1]``).
    """
    # Dict literal values are evaluated left to right, so the factories are
    # still called in encoder -> summarizer -> reader order.
    return {
        'encoder': get_dpr_encoder(),
        'summarizer': get_summarizer(),
        'reader': get_reader(),
    }
@st.cache_resource
def load_client():
    """Open the Milvus connection used for passage retrieval.

    Reads ``MILVUS_PW`` (password for user ``resrer``) and ``MILVUS_HOST``
    from the environment and connects to database ``psgs_w100`` on port
    19530. Raises KeyError if either environment variable is missing.
    Cached with ``st.cache_resource`` so one client is reused across reruns.
    """
    password = os.environ['MILVUS_PW']
    endpoint = f"http://{os.environ['MILVUS_HOST']}:19530"
    return MilvusClient(
        user='resrer',
        password=password,
        uri=endpoint,
        db_name='psgs_w100',
    )
# Shared, cached resources used by main().
client = load_client()
models = load_models()
# NOTE(review): this style string is empty, so the st.markdown call below
# renders nothing; it looks like a placeholder for injected CSS — confirm or remove.
styl = """
"""
st.markdown(styl, unsafe_allow_html=True)
# The input is the user's question (it is encoded, retrieved against and
# answered), not text to be summarized, so label it accordingly.
question = st.text_area("Question", INITIAL)
@torch.inference_mode()
def main(question: str):
    """Answer *question* with the retrieve -> summarize -> read pipeline and render it.

    Results are memoized per question string in ``st.session_state``, so a
    Streamlit rerun with the same question skips the models and the Milvus
    search. Runs under ``torch.inference_mode`` (no autograd tracking).

    Args:
        question: the user's question text.
    """
    import html  # local import: used to escape text rendered with raw HTML enabled

    if question in st.session_state:
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")
        # Embedding: encode the question with the DPR question encoder.
        question_vectors = encode_dpr_question(
            models['encoder'][0], models['encoder'][1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
        # Retriever: 10 nearest passages from the 'dpr_nq' collection.
        results = client.search(collection_name='dpr_nq', data=[query_vector],
                                limit=10, output_fields=['title', 'text'])
        texts = [result['entity']['text'] for result in results[0]]
        ctx = '\n'.join(texts)
        # Summarizer: condense the retrieved passages in ctx.
        # NOTE(review): summarize_text's return type is not visible here; if it
        # returns a list (as ask_reader does), index [0] before displaying it.
        summary = summarize_text(models['summarizer'][0],
                                 models['summarizer'][1], [ctx])
        # Reader: extract the answer.
        # NOTE(review): the reader is fed the raw passages (ctx), not the
        # summary — confirm which input the pipeline design intends.
        answers = ask_reader(models['reader'][0],
                             models['reader'][1], [question], [ctx])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")
        st.session_state[question] = (ctx, summary, answer)
    # Render the results.
    st.markdown(answer)
    st.write("## Summary")
    # summary and question are rendered with raw HTML enabled, so escape them
    # to keep user/model text from being interpreted as markup.
    st.markdown(f"{html.escape(str(summary))}\n", unsafe_allow_html=True)
    st.markdown(ctx)
    st.write(html.escape(question), unsafe_allow_html=True)
# Run the pipeline only for a non-blank question; whitespace-only input would
# otherwise trigger encoding, retrieval, summarization and reading for nothing.
if question and question.strip():
    main(question)