# Residue from the web file viewer this file was copied from (not program code):
# Vector_db / app.py
# ShynBui's picture
# Update app.py
# aab3dcd verified
# raw
# history blame
# No virus
# 2.61 kB
import gradio as gr
import os
from langchain.retrievers import EnsembleRetriever
from utils import *
import requests
from pyvi import ViTokenizer, ViPosTagger
import time
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch
# Hybrid retrieval: a BM25 retriever and an embedding-based retriever (both built
# by helpers from utils), fused by LangChain's EnsembleRetriever with equal weights.
# k=3 presumably caps the documents each retriever returns — confirm in utils.
retriever = load_the_embedding_retrieve(is_ready=False, k=3)
bm25_retriever = load_the_bm25_retrieve(k=3)
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever], weights=[0.5, 0.5]
)

# Hugging Face access token, read once. It may be unset (None).
_hf_token = os.environ.get("HF_TOKEN")

# Extractive question-answering checkpoint and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("ShynBui/vie_qa", token=_hf_token)
model = AutoModelForQuestionAnswering.from_pretrained("ShynBui/vie_qa", token=_hf_token)

# HTTP headers sent by query() to the inference endpoint.
# Authorization is added only when a token is configured. Concatenating
# "Bearer " + None raised TypeError at import time, so the app could not start at
# all without HF_TOKEN, even for code paths that never call the endpoint.
headers = {
    "Accept": "application/json",
    "Content-Type": "application/json",
}
if _hf_token:
    headers["Authorization"] = f"Bearer {_hf_token}"
def query(payload, *, timeout=60):
    """POST *payload* as JSON to the inference endpoint and return the decoded body.

    The endpoint URL is read from the ``API_URL`` environment variable, and the
    module-level ``headers`` are sent with the request.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before ``requests`` raises a
            timeout exception. Previously there was no timeout, so a stalled
            endpoint could block the caller indefinitely.

    Returns:
        The response body parsed as JSON. HTTP error statuses are not raised
        here; the decoded body is returned so callers can inspect it (``greet``
        checks it for an ``"error"`` key).
    """
    response = requests.post(
        os.environ.get("API_URL"), headers=headers, json=payload, timeout=timeout
    )
    return response.json()
def greet(quote, *, max_retries=120):
    """Answer *quote* against each retrieved document via the hosted QA endpoint.

    Documents come from ``ensemble_retriever``. For each document, the question
    and the document text are passed through pyvi's ``ViTokenizer.tokenize``.
    They are then sent to ``query`` as ``{"inputs": {"question", "context"}}``.
    While the endpoint's reply contains an ``"error"`` key, the request is
    retried once per second.

    Args:
        quote: The user's question.
        max_retries: Maximum number of retries per document after the first
            attempt. The original loop retried forever, so a persistent error
            (e.g. a bad token or URL) hung the request indefinitely.

    Returns:
        list: One ``output["answer"]`` value per retrieved document, in
        retrieval order.

    Raises:
        RuntimeError: If the endpoint still replies with an ``"error"`` key
            after ``max_retries`` retries.
    """
    answers = []
    docs = ensemble_retriever.get_relevant_documents(quote)
    # The question is the same for every document, so segment it only once.
    question = ViTokenizer.tokenize(quote)
    for doc in docs:
        context = ViTokenizer.tokenize(doc.page_content)
        print("source:", doc.metadata['source'])
        # NOTE: this keeps the first 256 *characters* of the segmented text,
        # not 256 model tokens.
        payload = {
            "inputs": {
                "question": question,
                "context": context[:256],
            },
        }
        output = query(payload)
        retries = 0
        while "error" in output:
            if retries >= max_retries:
                raise RuntimeError(
                    f"QA endpoint still failing after {max_retries} retries: {output!r}"
                )
            time.sleep(1)
            retries += 1
            output = query(payload)
        answers.append(output['answer'])
    return answers
def greet2(quote, *, extract_answers=False):
    """Retrieve documents for *quote*, optionally extracting answers with the local model.

    By default this returns the documents from
    ``ensemble_retriever.get_relevant_documents`` unchanged, which is what the
    Gradio UI currently displays. The original had an unconditional early
    ``return docs`` that made the local-inference loop unreachable. That loop is
    now opt-in via ``extract_answers=True``.

    Args:
        quote: The user's question.
        extract_answers: When True, run the local extractive QA ``model`` over
            each retrieved document instead of returning the documents.

    Returns:
        The retrieved documents when ``extract_answers`` is False. Otherwise a
        list of answer strings, one per document, in retrieval order.
    """
    docs = ensemble_retriever.get_relevant_documents(quote)
    if not extract_answers:
        return docs

    # The question is the same for every document, so segment it only once.
    question = ViTokenizer.tokenize(quote)
    answers = []
    for doc in docs:
        context = ViTokenizer.tokenize(doc.page_content)
        # Truncate only the context (second sequence) so question + context fit
        # the model's maximum input length. Without truncation, a long document
        # typically makes the forward pass fail. This relies on
        # tokenizer.model_max_length being set for this checkpoint (TODO confirm).
        inputs = tokenizer(
            question, context, return_tensors="pt", truncation="only_second"
        )
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        start_index = torch.argmax(outputs.start_logits)
        end_index = torch.argmax(outputs.end_logits) + 1
        # If end <= start, the slice is empty and the answer is "".
        answer_tokens = tokenizer.convert_ids_to_tokens(
            inputs["input_ids"][0][start_index:end_index]
        )
        answers.append(tokenizer.convert_tokens_to_string(answer_tokens))
    return answers
if __name__ == "__main__":
    # Sample question ("Địa chỉ nhà trường?" ≈ "What is the school's address?").
    # It is only assigned here; nothing below reads it.
    quote = "Địa chỉ nhà trường?"

    # One text box in, one text box out, backed by greet2.
    iface = gr.Interface(
        fn=greet2,
        inputs="text",
        outputs="text",
    )
    iface.launch()