|
import gradio as gr |
|
import os |
|
from langchain.retrievers import EnsembleRetriever |
|
from utils import * |
|
import requests |
|
from pyvi import ViTokenizer, ViPosTagger |
|
import time |
|
from transformers import AutoTokenizer, AutoModelForQuestionAnswering |
|
import torch |
|
|
|
# Two retrievers are loaded through helpers star-imported from utils, both
# configured with k=3.
retriever = load_the_embedding_retrieve(is_ready=False, k=3)
bm25_retriever = load_the_bm25_retrieve(k=3)

# Both retrievers are combined with equal (0.5 / 0.5) weights.
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever],
    weights=[0.5, 0.5],
)
|
|
|
# Tokenizer and question-answering model come from the same checkpoint.
# The `token` argument is read from the HF_TOKEN environment variable and is
# None when that variable is unset.
tokenizer = AutoTokenizer.from_pretrained(
    "ShynBui/vie_qa",
    token=os.environ.get("HF_TOKEN"),
)
model = AutoModelForQuestionAnswering.from_pretrained(
    "ShynBui/vie_qa",
    token=os.environ.get("HF_TOKEN"),
)
|
|
|
# HTTP headers sent by query() with every request.
# The Authorization entry is included only when HF_TOKEN is set: concatenating
# "Bearer " with a missing (None) token would raise TypeError at import time,
# taking the whole app down even when the remote endpoint is never called.
_hf_token = os.environ.get("HF_TOKEN")

headers = {
    "Accept": "application/json",
    **({"Authorization": "Bearer " + _hf_token} if _hf_token else {}),
    "Content-Type": "application/json",
}
|
|
|
|
|
def query(payload, timeout=60):
    """POST *payload* as JSON to the URL in the API_URL environment variable.

    Args:
        payload: JSON-serialisable object used as the request body.
        timeout: Value passed to ``requests.post`` as ``timeout`` (seconds).
            Without it a stalled server would block the caller forever.

    Returns:
        The response body decoded from JSON.

    Raises:
        requests.exceptions.Timeout: If the server does not respond within
            *timeout*.
        requests.exceptions.RequestException: For other transport failures
            raised by ``requests``.
    """
    response = requests.post(
        os.environ.get("API_URL"),
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()
|
|
|
|
|
def greet(quote):
    """Answer *quote* once per retrieved document using the remote endpoint.

    Documents are fetched with ``ensemble_retriever``; for each one the
    word-segmented question and the first 256 characters of the word-segmented
    document text are sent through ``query``. A reply containing an ``error``
    key is retried after a short pause, up to a fixed number of retries.

    Args:
        quote: The question text.

    Returns:
        A list holding the ``answer`` field of each reply, in retrieval order.
        Empty when no documents are retrieved.

    Raises:
        RuntimeError: If the endpoint still reports an error after all
            retries, instead of looping forever on a permanent failure.
    """
    max_retries = 60
    retry_delay = 1  # seconds between attempts

    # The question is the same for every document, so segment it only once.
    question = ViTokenizer.tokenize(quote)

    answers = []
    for doc in ensemble_retriever.get_relevant_documents(quote):
        print("source:", doc.metadata['source'])

        payload = {
            "inputs": {
                "question": question,
                "context": ViTokenizer.tokenize(doc.page_content)[:256],
            },
        }

        output = query(payload)
        for _ in range(max_retries):
            if "error" not in output:
                break
            time.sleep(retry_delay)
            output = query(payload)

        if "error" in output:
            raise RuntimeError(
                f"QA endpoint still failing after {max_retries} retries: {output!r}"
            )

        answers.append(output['answer'])

    return answers
|
|
|
|
|
def _extract_answers(quote, docs):
    """Run the locally loaded QA model over *docs*, one answer per document.

    This is the extraction step that was written after greet2's early
    ``return`` and therefore could never run. greet2 does not call it; it is
    kept as a callable helper so the logic is reachable without changing what
    greet2 returns.

    Args:
        quote: The question text.
        docs: Iterable of objects exposing ``page_content``.

    Returns:
        A list of decoded answer strings, in the order of *docs*.
    """
    # The question is identical for every document, so segment it only once.
    question = ViTokenizer.tokenize(quote)

    answers = []
    for doc in docs:
        context = ViTokenizer.tokenize(doc.page_content)
        inputs = tokenizer(question, context, return_tensors="pt")

        # Inference only, so gradient tracking is not needed.
        with torch.no_grad():
            outputs = model(**inputs)

        # Highest-scoring start and end positions; +1 makes the end inclusive
        # when used as a slice bound.
        start_index = torch.argmax(outputs.start_logits)
        end_index = torch.argmax(outputs.end_logits) + 1

        span_ids = inputs["input_ids"][0][start_index:end_index]
        answers.append(
            tokenizer.convert_tokens_to_string(
                tokenizer.convert_ids_to_tokens(span_ids)
            )
        )

    return answers


def greet2(quote):
    """Return the documents that ``ensemble_retriever`` finds for *quote*.

    Args:
        quote: The question text.

    Returns:
        Whatever ``ensemble_retriever.get_relevant_documents`` returns for
        *quote*; no answer extraction is applied.
    """
    return ensemble_retriever.get_relevant_documents(quote)
|
|
|
|
|
if __name__ == "__main__":
    # Serve greet2 through a text-in / text-out Gradio interface.
    iface = gr.Interface(fn=greet2, inputs="text", outputs="text")
    iface.launch()
|
|