# hra_qa_bot_v1/app.py
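#
# A minimal Streamlit QA chatbot over HRA leadership documents. The pipeline,
# as implemented below: embed the user's question, retrieve the most similar
# document from a precomputed JSON vector store by cosine similarity, then run
# an extractive QA model over (question, document) to pull out an answer span.
# ModelWrapper (from utils.py) is assumed to expose get_embeddings(), a
# tokenizer, and an extractive QA model model_qa.
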
import streamlit as st
import os
import json
import torch
from utils import ModelWrapper
from sklearn.metrics.pairwise import cosine_similarity

st.title('HRA Document QA')

with st.spinner("Loading the models, please wait..."):
    model_loader = ModelWrapper()
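# Note: Streamlit re-runs this whole script on every interaction, so the
# models are reloaded each time; caching the wrapper (e.g. with
# st.cache_resource) would avoid that, but is not done here.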

with st.chat_message("assistant"):
    st.write("Hello 👋 I am an HRA chatbot.")
    st.write("I know everything about the leadership of HRA.")
    st.write("Please ask your questions about the leadership of HRA. "
             "For example: 'Where did Robert Kauffman graduate?' or "
             "'What's the position for Fred Danback?'")

question = st.chat_input("Ask me a question about the leadership of HRA:")

if question:
    with st.chat_message("assistant"):
        st.write("You asked:")
    with st.chat_message("user"):
        st.write(question)
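
    # The second argument to get_embeddings (0 here, 1 for documents below)
    # presumably switches ModelWrapper between question- and document-mode
    # encoding; its exact meaning lives in utils.py.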
    # Get the embeddings for the question.
    question_embeddings = model_loader.get_embeddings(question, 0)

    # One-off preprocessing: get the embeddings of all the documents and
    # persist them as a JSON vector store. Flip REBUILD_EMBEDDINGS to True
    # to regenerate the store.
    REBUILD_EMBEDDINGS = False
    if REBUILD_EMBEDDINGS:
        with st.spinner("Computing the embeddings, please wait..."):
            document_embeddings = {}
            for file in os.listdir("./documents"):
                with open("./documents/" + file, "r", encoding="utf-8") as doc:
                    text = doc.read()
                # Get the embedding of the document.
                document_embeddings[file] = model_loader.get_embeddings(text, 1).tolist()
            # Save the embeddings of all the documents as a vector database.
            with open("./vectors/embeddings.json", "w") as outfile:
                outfile.write(json.dumps(document_embeddings, indent=4))
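
    # The JSON store maps filename -> embedding, e.g. {"kauffman.txt": [[...]]};
    # the nested list assumes get_embeddings returns a 2D (1, dim) array, which
    # is also the shape cosine_similarity expects below.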
    with open("./vectors/embeddings.json", "r") as embeddings_file:
        document_embeddings = json.load(embeddings_file)

    # Linear search for the most relevant document.
    max_similarity = -1
    most_relevant_document = None
    for document in document_embeddings:
        # cosine_similarity returns a (1, 1) matrix here, so take the scalar.
        cur_similarity = cosine_similarity(question_embeddings,
                                           document_embeddings[document])[0][0]
        if cur_similarity > max_similarity:
            most_relevant_document = document
            max_similarity = cur_similarity
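
    # 0.35 below is a relevance cutoff on cosine similarity, presumably tuned
    # by hand for this corpus; questions scoring under it get no answer.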
    with st.chat_message("assistant"):
        if max_similarity < 0.35:
            st.write("Sorry, we couldn't find a relevant document.")
        else:
            st.write("The most relevant document is:")
            st.write(most_relevant_document)
            st.write("And the cosine similarity is: " + str(max_similarity))

    if max_similarity >= 0.35:
        with open("./documents/" + most_relevant_document, "r", encoding="utf-8") as doc:
            context = doc.read()
        # Extractive QA: the model scores every token of (question, context)
        # as a potential answer start/end; the argmax span is decoded as text.
        inputs = model_loader.tokenizer(question, context, return_tensors="pt")
        with torch.no_grad():
            outputs = model_loader.model_qa(**inputs)
        answer_start_index = outputs.start_logits.argmax()
        answer_end_index = outputs.end_logits.argmax()
        predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
        predict_answer = model_loader.tokenizer.decode(predict_answer_tokens, skip_special_tokens=True)
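
        # Taking independent argmaxes can yield end < start, in which case the
        # slice above is empty and predict_answer decodes to ""; the fallback
        # below then shows the whole document instead.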
        with st.chat_message("assistant"):
            st.write("Answer:")
            if predict_answer:
                st.write(predict_answer)
            else:
                # Fall back to showing the retrieved document in full.
                st.write(context)