"""Vietnamese retrieval-augmented extractive QA demo served through Gradio.

A user question is matched against a document store with a hybrid
(BM25 + embedding) retriever; an extractive QA model then pulls an answer
span out of each retrieved document.
"""
# NOTE: import order is kept exactly as in the original file so that any names
# re-exported by the ``utils`` star import shadow/are shadowed the same way.
import gradio as gr
import os
from langchain.retrievers import EnsembleRetriever
from utils import *
import requests
from pyvi import ViTokenizer, ViPosTagger
import time
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import torch

# Hybrid retrieval: BM25 and the embedding retriever, equally weighted, each
# returning k=3 documents. Both loaders come from the project-local ``utils``.
retriever = load_the_embedding_retrieve(is_ready=True, k=3)
bm25_retriever = load_the_bm25_retrieve(k=3)
ensemble_retriever = EnsembleRetriever(
    retrievers=[bm25_retriever, retriever], weights=[0.5, 0.5]
)

# ``None`` when HF_TOKEN is unset; ``from_pretrained`` accepts ``token=None``.
_HF_TOKEN = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained("ShynBui/vie_qa", token=_HF_TOKEN)
model = AutoModelForQuestionAnswering.from_pretrained("ShynBui/vie_qa", token=_HF_TOKEN)

# Headers for the hosted inference endpoint used by ``greet``. The original
# concatenated ``"Bearer " + None`` and crashed at import time when HF_TOKEN
# was unset, even though the local-model path (``greet2``) never needs it.
headers = {
    "Accept": "application/json",
    "Content-Type": "application/json",
}
if _HF_TOKEN:
    headers["Authorization"] = "Bearer " + _HF_TOKEN

# Upper bound on how many times ``greet`` re-polls the endpoint after an
# ``"error"`` response (one second apart) before giving up.
_MAX_QUERY_RETRIES = 60


def query(payload, timeout=60):
    """POST ``payload`` as JSON to the inference endpoint and return the decoded JSON.

    ``API_URL`` is not defined in this file; presumably it is provided by the
    ``utils`` star import — TODO confirm.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before ``requests`` raises;
            without it a stalled connection would block forever.
    """
    response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)
    return response.json()


def greet(quote, max_retries=_MAX_QUERY_RETRIES):
    """Answer ``quote`` against each retrieved document via the remote endpoint.

    Each document's text is word-segmented with pyvi and truncated to its
    first 256 characters before being sent as the QA context.

    Args:
        quote: The user's question.
        max_retries: How many times to re-query after a response containing
            ``"error"`` before raising.

    Returns:
        A list with one answer string per retrieved document.

    Raises:
        RuntimeError: If the endpoint still reports an error after
            ``max_retries`` retries.
    """
    answers = []
    # The question does not depend on the document, so segment it only once.
    question = ViTokenizer.tokenize(quote)
    docs = ensemble_retriever.get_relevant_documents(quote)
    for doc in docs:
        context = ViTokenizer.tokenize(doc.page_content)
        print("source:", doc.metadata['source'])
        payload = {
            "inputs": {
                "question": question,
                "context": context[:256],
            },
        }
        output = query(payload)
        retries = 0
        # The endpoint reports errors in the body (e.g. while the model is
        # still loading); poll again, but never forever.
        while "error" in output:
            if retries >= max_retries:
                raise RuntimeError(
                    f"QA endpoint still failing after {max_retries} retries: {output['error']}"
                )
            retries += 1
            time.sleep(1)
            output = query(payload)
        answers.append(output['answer'])
    return answers


def greet2(quote):
    """Answer ``quote`` against each retrieved document using the local QA model.

    Question and document text are word-segmented with pyvi. The model picks
    the arg-max start/end logits, and the tokens between them are decoded back
    to text.

    Args:
        quote: The user's question.

    Returns:
        A list with one answer string per retrieved document.
    """
    answers = []
    # The question does not depend on the document, so segment it only once.
    question = ViTokenizer.tokenize(quote)
    docs = ensemble_retriever.get_relevant_documents(quote)
    for doc in docs:
        context = ViTokenizer.tokenize(doc.page_content)
        # Truncate only the context so long documents fit the model's maximum
        # sequence length instead of overflowing its position embeddings.
        inputs = tokenizer(question, context, return_tensors="pt", truncation="only_second")
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(**inputs)
        start_index = torch.argmax(outputs.start_logits)
        end_index = torch.argmax(outputs.end_logits) + 1
        answer = tokenizer.convert_tokens_to_string(
            tokenizer.convert_ids_to_tokens(inputs["input_ids"][0][start_index:end_index])
        )
        answers.append(answer)
    return answers


if __name__ == "__main__":
    iface = gr.Interface(fn=greet2, inputs="text", outputs="text")
    iface.launch()