# resrer-demo / app.py
# (Hugging Face Spaces page chrome below, preserved as comments:)
# seonglae's picture
# fix: remove question mark for better inference
# f12e339
# raw
# history blame contribute delete
# No virus
# 3.08 kB
import os
import streamlit as st
from pymilvus import MilvusClient
import torch
from model import encode_dpr_question, get_dpr_encoder
from model import summarize_text, get_summarizer
from model import ask_reader, get_reader
# Title shown both in the browser tab and as the page header.
TITLE = 'ReSRer: Retriever-Summarizer-Reader'
# Default question pre-filled into the text input on first load.
INITIAL = "What is the population of NYC"
st.set_page_config(page_title=TITLE)
st.header(TITLE)
st.markdown('''
<h5>Ask short-answer question that can be find in Wikipedia data.</h5>
''', unsafe_allow_html=True)
st.markdown(
    'This demo searches through 21,000,000 Wikipedia passages in real-time under the hood.')
@st.cache_resource
def load_models():
    """Build and cache the three model pairs used by the pipeline.

    Returns a dict mapping stage name ('encoder', 'summarizer', 'reader')
    to whatever each factory returns (indexed as [0]/[1] by callers).
    Cached by Streamlit so the models load only once per process.
    """
    return {
        'encoder': get_dpr_encoder(),
        'summarizer': get_summarizer(),
        'reader': get_reader(),
    }
@st.cache_resource
def load_client():
    """Connect to the Milvus passage index and cache the client.

    Credentials and host come from the MILVUS_PW / MILVUS_HOST environment
    variables; a missing variable raises KeyError at startup (fail fast).
    """
    host = os.environ['MILVUS_HOST']
    return MilvusClient(
        user='resrer',
        password=os.environ['MILVUS_PW'],
        uri=f"http://{host}:19530",
        db_name='psgs_w100',
    )
# Shared singletons (both cached by @st.cache_resource across reruns).
client = load_client()
models = load_models()
# CSS tweak: center the Streamlit status widget on screen and hide its
# button while the pipeline runs.
styl = """
<style>
.StatusWidget-enter-done{
  position: fixed;
  left: 50%;
  top: 50%;
  transform: translate(-50%, -50%);
}
.StatusWidget-enter-done button{
  display: none;
}
</style>
"""
st.markdown(styl, unsafe_allow_html=True)
# Free-form question input plus three one-click example questions;
# clicking a button overrides whatever is currently typed.
question = st.text_input("Question", INITIAL)
col1, col2, col3 = st.columns(3)
if col1.button("What is the capital of South Korea"):
    question = "What is the capital of South Korea"
if col2.button("What is the most famous building in Paris"):
    question = "What is the most famous building in Paris"
if col3.button("Who is the actor of Harry Potter"):
    question = "Who is the actor of Harry Potter"
@torch.inference_mode()
def main(question: str):
    """Answer *question* with the retrieve -> summarize -> read pipeline.

    Results are memoized in st.session_state keyed by the question text,
    so re-asking the same question within a session skips the models.
    Renders the answer, the summarized context, and the original context.
    """
    if question in st.session_state:
        # Seen this exact question before: reuse the cached triple.
        print("Cache hit!")
        ctx, summary, answer = st.session_state[question]
    else:
        print(f"Input: {question}")
        # Embedding: encode the question into a dense query vector.
        encoder_pair = models['encoder']
        question_vectors = encode_dpr_question(
            encoder_pair[0], encoder_pair[1], [question])
        query_vector = question_vectors.detach().cpu().numpy().tolist()[0]
        # Retriever: top-10 nearest passages from the Milvus index.
        hits = client.search(
            collection_name='dpr_nq',
            data=[query_vector],
            limit=10,
            output_fields=['title', 'text'],
        )
        ctx = '\n'.join(hit['entity']['text'] for hit in hits[0])
        # Summarizer: condense the retrieved passages into one summary.
        summarizer_pair = models['summarizer']
        [summary] = summarize_text(
            summarizer_pair[0], summarizer_pair[1], [ctx])
        # Reader: extract the answer span from the summary.
        reader_pair = models['reader']
        answers = ask_reader(
            reader_pair[0], reader_pair[1], [question], [summary])
        answer = answers[0]['answer']
        print(f"\nAnswer: {answer}")
        st.session_state[question] = (ctx, summary, answer)
    # Render answer first, then both context views.
    st.write(f"### Answer: {answer}")
    st.markdown('<h5>Summarized Context</h5>', unsafe_allow_html=True)
    st.markdown(
        f"<h6 style='padding: 0'>{summary}</h6><hr style='margin: 1em 0px'>", unsafe_allow_html=True)
    st.markdown('<h5>Original Context</h5>', unsafe_allow_html=True)
    st.markdown(ctx)
# Run the pipeline whenever the input is non-empty (every Streamlit rerun).
if question:
    main(question)