# Mengmeng Liu
# initial build
# 36c7297
import gradio
import os
import json
import torch
import numpy as np
from utils import ModelWrapper
from sklearn.metrics.pairwise import cosine_similarity
# Build the shared ModelWrapper once, when the module is imported. The
# inference code reads its embedding helper (get_embeddings), its tokenizer
# and its question-answering model (model_qa) from this single instance.
model_loader = ModelWrapper()
# Folder holding the raw document texts, and the JSON file holding their
# embeddings. The JSON maps each document file name to its embedding.
DOCUMENTS_DIR = "./documents"
EMBEDDINGS_PATH = "./vectors/embeddings.json"
# Minimum cosine similarity a document needs before the QA model is run on it.
SIMILARITY_THRESHOLD = 0.35


def _build_document_embeddings(documents_dir=DOCUMENTS_DIR, output_path=EMBEDDINGS_PATH):
    """Embed every file in ``documents_dir`` and save the vectors as JSON.

    This helper is not called anywhere at the moment, because document
    embedding is disabled at request time. It is kept so the embedding file
    can be regenerated offline.

    Args:
        documents_dir: Directory whose files are read as UTF-8 text.
        output_path: Where to write the ``{file name: embedding list}`` JSON.
    """
    document_embeddings = {}
    for file_name in os.listdir(documents_dir):
        with open(os.path.join(documents_dir, file_name), "r", encoding="utf-8") as document_file:
            text = document_file.read()
        # Documents are embedded with flag 1 and questions with flag 0. The
        # meaning of the flag is defined in utils.ModelWrapper (confirm there).
        document_embeddings[file_name] = model_loader.get_embeddings(text, 1).tolist()
    with open(output_path, "w", encoding="utf-8") as outfile:
        json.dump(document_embeddings, outfile, indent=4)


def _load_document_embeddings(path=EMBEDDINGS_PATH):
    """Return the ``{document file name: embedding}`` mapping stored at ``path``.

    A context manager is used so the file handle is always closed.
    """
    with open(path, "r", encoding="utf-8") as embeddings_file:
        return json.load(embeddings_file)


def _find_most_relevant_document(question_embeddings, document_embeddings):
    """Linearly scan ``document_embeddings`` for the best cosine match.

    Returns:
        ``(document_name, similarity)`` for the highest-scoring document. If
        several documents tie, the first one seen is kept. If the mapping is
        empty, ``(None, -1)`` is returned.
    """
    best_document, best_similarity = None, -1
    for document, embedding in document_embeddings.items():
        similarity = cosine_similarity(question_embeddings, embedding)
        if similarity > best_similarity:
            best_document, best_similarity = document, similarity
    return best_document, best_similarity


def _extract_answer(question, context):
    """Return the answer span that the QA model predicts for ``question`` in ``context``."""
    # NOTE(review): the tokenizer is called without truncation, so very long
    # documents may exceed the QA model's maximum input length. Confirm the
    # document sizes against the model limit.
    inputs = model_loader.tokenizer(question, context, return_tensors="pt")
    with torch.no_grad():
        outputs = model_loader.model_qa(**inputs)
    answer_start_index = outputs.start_logits.argmax()
    answer_end_index = outputs.end_logits.argmax()
    # The start and end indices are chosen independently, so the end can come
    # before the start. In that case the slice is empty and the decoded
    # answer is an empty string.
    answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
    return model_loader.tokenizer.decode(answer_tokens, skip_special_tokens=True)


def my_inference_function(question):
    """Answer ``question`` using the most similar stored document.

    How it works:
    1. Embed the question.
    2. Compare the question embedding with every document embedding in
       ``EMBEDDINGS_PATH``, using cosine similarity.
    3. If the best match reaches ``SIMILARITY_THRESHOLD``, run the extractive
       QA model on that document's text to get the answer.

    Args:
        question: The user's question text.

    Returns:
        A dict with the keys ``"answer"``, ``"most_relevant_document"`` and
        ``"cosine_similarity"``.
    """
    question_embeddings = model_loader.get_embeddings(question, 0)

    # Documents are not re-embedded per request. Run
    # _build_document_embeddings() offline to refresh the stored vectors.
    # TODO: replace the JSON file with a vector database.
    document_embeddings = _load_document_embeddings()
    most_relevant_document, max_similarity = _find_most_relevant_document(
        question_embeddings, document_embeddings
    )

    if most_relevant_document is None or max_similarity < SIMILARITY_THRESHOLD:
        return {
            "answer": "Sorry we can't find the relevant document",
            "most_relevant_document": "None",
            "cosine_similarity": str(-1),
        }

    with open(os.path.join(DOCUMENTS_DIR, most_relevant_document), "r", encoding="utf-8") as document_file:
        context = document_file.read()

    predict_answer = _extract_answer(question, context)
    # decode() returns a string, never None. An empty or whitespace-only span
    # means the model found no usable answer, so fall back to the message.
    if predict_answer is None or not predict_answer.strip():
        predict_answer = "I can't answer your question right now. I am evolving ..."

    return {
        "answer": predict_answer,
        "most_relevant_document": most_relevant_document,
        "cosine_similarity": str(max_similarity),
    }
# Web UI: one free-text question box in, the result dict rendered as JSON out.
# The example questions are offered as one-click inputs in the interface.
gradio_interface = gradio.Interface(
    fn=my_inference_function,
    inputs="text",
    outputs="json",
    examples=[
        "Where did Robert Kauffman graduate?",
        "What's the position of Fred Danback?",
    ],
    title="HRA Leadership QA Bot",
)

# Start the Gradio server. This runs at module level, so it also runs on import.
gradio_interface.launch()