Spaces:
Sleeping
Sleeping
#basics | |
from http import server | |
import time | |
import pandas as pd | |
import numpy as np | |
import pickle | |
from PIL import Image | |
#DL | |
import torch | |
from transformers import T5ForConditionalGeneration, T5TokenizerFast | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.util import cos_sim | |
#streamlit | |
import streamlit as st | |
# from streamlit_server_state import server_state, server_state_lock | |
# import SessionState | |
from load_css import local_css | |
# Apply the app's custom stylesheet. local_css comes from the project-local
# load_css module; presumably it injects ./style.css into the page (confirm).
local_css("./style.css")
#text preprocess | |
import re | |
from pyvi import ViTokenizer | |
from rank_bm25 import BM25Okapi | |
#helper functions | |
from inspect import getsourcefile | |
import os.path as path, sys | |
from pathlib import Path | |
# Absolute directory of this source file (getsourcefile on a throwaway lambda
# resolves to the file the lambda was defined in).
current_dir = path.dirname(path.abspath(getsourcefile(lambda:0)))
# Prepend the parent directory (last path component sliced off) to sys.path so
# project packages living there are importable. The only import that used it,
# src.clean_dataset, is commented out just below; the entry is removed again by
# a later sys.path.pop(0).
sys.path.insert(0, current_dir[:current_dir.rfind(path.sep)])
# import src.clean_dataset as clean
def preprocess(sentence):
    """Normalise raw text and word-segment it with pyvi's ViTokenizer.

    The input is coerced to ``str`` and lower-cased. The literal ``{html}``
    marker is dropped, then ``<...>`` tags and ``http...`` URLs are stripped.
    Runs of whitespace are collapsed to single spaces before tokenizing.

    Args:
        sentence: any value; it is converted with ``str()`` first.

    Returns:
        The string produced by ``ViTokenizer.tokenize`` on the cleaned text.
    """
    text = str(sentence).lower().replace('{html}', "")
    # Order matters: strip markup first, then URLs.
    for pattern in ('<.*?>', r'http\S+'):
        text = re.sub(pattern, '', text)
    return ViTokenizer.tokenize(" ".join(text.split()))
# Placeholder option shown first in selectbox_with_default widgets.
DEFAULT = '< PICK A VALUE >'
def selectbox_with_default(text, values, default=DEFAULT, sidebar=False):
    """Render a Streamlit selectbox whose first option is a placeholder.

    Args:
        text: label displayed for the widget.
        values: the real options; they follow ``default`` in the list.
        default: placeholder inserted at position 0.
        sidebar: when true, render in ``st.sidebar`` instead of the main page.

    Returns:
        The option currently selected in the widget.
    """
    options = np.insert(np.array(values, object), 0, default)
    if sidebar:
        return st.sidebar.selectbox(text, options)
    return st.selectbox(text, options)
def loadmodels(qa_model_name="wanderer2k1/T5-LawsQA",
               encoder_name='wanderer2k1/BertCondenser_LawsQA'):
    """Load the answer-generation model, its tokenizer and the bi-encoder.

    Both ids were previously hard-coded (the T5 id twice). Taking them as
    defaulted parameters keeps the model and tokenizer in sync and lets callers
    point at other checkpoints; calling with no arguments behaves as before.

    Args:
        qa_model_name: Hugging Face Hub id (or local path) of the T5 model.
            The tokenizer is loaded from the same id.
        encoder_name: id/path of the SentenceTransformer bi-encoder.

    Returns:
        ``(tokenizer, model, bi_encoder)``. Note that the tokenizer comes first.
    """
    model = T5ForConditionalGeneration.from_pretrained(qa_model_name)
    tokenizer = T5TokenizerFast.from_pretrained(qa_model_name)
    bi_encoder = SentenceTransformer(encoder_name)
    return tokenizer, model, bi_encoder
def hf_run_model(tokenizer, model, input_string, **generator_args):
    """Generate text for ``input_string`` and split each output on ``<sep>``.

    Before encoding, the input is wrapped as
    ``"generate questions: " + input_string + " </s>"``.
    NOTE(review): this prefix/suffix is kept verbatim because it presumably
    matches the checkpoint's fine-tuning format. Confirm before changing it.

    Args:
        tokenizer: object with ``encode(text, return_tensors="pt")`` and
            ``batch_decode(ids, skip_special_tokens=True)`` (a T5 tokenizer
            in this app).
        model: object with ``generate(input_ids, **kwargs)``.
        input_string: text appended after the fixed prefix.
        **generator_args: extra ``model.generate`` keyword arguments. They
            override the defaults below. Previously they were accepted but
            silently discarded, because the name was reassigned to a new dict.

    Returns:
        A list with one entry per generated sequence. Each entry is the
        decoded text split on ``"<sep>"``.
    """
    # Beam-search defaults. NOTE(review): in transformers, temperature only
    # affects sampling-based generation, so 0.0 is likely ignored here.
    settings = {
        "max_length": 256,
        "temperature": 0.0,
        "num_beams": 4,
        "length_penalty": 0.1,
        "no_repeat_ngram_size": 8,
        "early_stopping": True,
    }
    settings.update(generator_args)  # caller-supplied options win
    prompt = "generate questions: " + input_string + " </s>"
    input_ids = tokenizer.encode(prompt, return_tensors="pt")
    generated = model.generate(input_ids, **settings)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return [text.split("<sep>") for text in decoded]
#%% | |
# Undo the sys.path.insert(0, <parent dir>) made near the top of the file.
# NOTE(review): pop(0) assumes nothing else was inserted at index 0 since then.
sys.path.pop(0)
#1. load in complete transformed and processed dataset
# Guarded by st.session_state so this runs once per session, not on every rerun.
if 'df' not in st.session_state:
    # NOTE(review): despite the .pkl extension this file is parsed as
    # tab-separated text; confirm it really is a TSV with 'text' and 'title'.
    st.session_state['df'] = pd.read_csv('./data/corpus.pkl', sep = '\t')
    # Parallel lists built from the same rows: passages[i] is a passage's text
    # and passage_id[i] is its title.
    st.session_state['passages'] = st.session_state['df']['text'].values.tolist()
    st.session_state['passage_id'] = st.session_state['df']['title'].values.tolist()
#2 load corpus embeddings for neural QA:
if 'embedded_passages' not in st.session_state:
    # Precomputed passage embeddings. deploy() indexes these rows with the same
    # ids it uses for st.session_state['passages'], so row order presumably
    # matches the corpus order (verify when regenerating the file).
    # NOTE(review): pickle.load can execute arbitrary code; acceptable only if
    # this file is a trusted, bundled artifact.
    with open("./data/embedded_corpus_BertCondenser_tuples.pkl", 'rb') as inp:
        embedded_passages = pickle.load(inp)
    # torch.Tensor(...) builds a tensor of torch's default float dtype.
    st.session_state['embedded_passages'] = torch.Tensor(embedded_passages)
#3 load BM25:
# Unpickle the prebuilt BM25 index once per session; deploy() calls its
# get_scores() on the segmented query tokens.
# NOTE(review): pickle.load can execute arbitrary code; acceptable only if
# this file is a trusted, bundled artifact.
if 'bm25' not in st.session_state:
    with open("models/BM25_pyvi_segmented_splitted.pkl", 'rb') as inp:
        st.session_state['bm25'] = pickle.load(inp)
#4: model
# Load the generator, its tokenizer and the bi-encoder once per session.
# Only 'model' is checked because all three are assigned together from
# loadmodels() in the single statement below.
# NOTE(review): st.session_state is per browser session, so each new session
# loads its own copy of these models; a process-wide cache (e.g.
# st.cache_resource, if the pinned Streamlit version has it) would share them.
if 'model' not in st.session_state:
    st.session_state['tokenizer'], st.session_state['model'], st.session_state['bi_encoder'] = loadmodels()
#%% | |
def _top_k_indices(scores, k):
    """Return the indices of the ``k`` largest values in ``scores``, largest first.

    ``k`` is clamped to ``len(scores)`` so a short score vector cannot make
    ``np.argpartition`` raise on an out-of-range ``kth``.
    """
    scores = np.asarray(scores)
    k = min(int(k), scores.shape[0])
    if k <= 0:
        return np.empty(0, dtype=np.intp)
    # argpartition only guarantees *which* k items are largest, not their
    # order, so sort that small slice explicitly (descending, stable).
    top = np.argpartition(scores, -k)[-k:]
    return top[np.argsort(-scores[top], kind='stable')]


def _retrieve(query, tokenized_query, top_k, n_candidates):
    """Return corpus indices of the best passages for the query, best first.

    Stage 1 keeps the ``n_candidates`` passages with the highest BM25 scores.
    Stage 2 re-ranks them by cosine similarity between the bi-encoder query
    embedding and the precomputed passage embeddings, and keeps ``top_k``.
    """
    bm25_scores = st.session_state['bm25'].get_scores(tokenized_query)
    candidate_ids = _top_k_indices(bm25_scores, n_candidates)
    # One advanced-indexing gather instead of concatenating rows in a loop.
    rows = torch.as_tensor(candidate_ids, dtype=torch.long)
    emb_candidates = st.session_state['embedded_passages'][rows]
    emb_query = st.session_state['bi_encoder'].encode(query)
    cosine_sim = cos_sim(emb_query, emb_candidates)
    best = _top_k_indices(cosine_sim.numpy()[0], top_k)
    return candidate_ids[best]


def _answer(query, context):
    """Generate an answer for ``query`` conditioned on a single passage ``context``."""
    q = "Trả lời câu hỏi: "+query + " Trong ngữ cảnh: "+context
    return hf_run_model(st.session_state['tokenizer'], st.session_state['model'], q)[0][0]


def deploy(question, top_k=None, n_candidates=50):
    """Answer ``question`` from the law corpus and render a results table.

    Retrieval is two-stage. BM25 over the pyvi-segmented query shortlists
    ``n_candidates`` passages, and the bi-encoder re-ranks them by cosine
    similarity, keeping the best ``top_k``. The T5 model then generates one
    answer per kept passage. Rows are displayed best-first; previously,
    ``np.argpartition`` left them in arbitrary order.

    Args:
        question: raw user question.
        top_k: number of passages to answer from. Defaults to the module-level
            ``returns`` (the sidebar slider value), as before. Clamped to the
            shortlist size.
        n_candidates: size of the BM25 shortlist (previously hard-coded to 50).
            Clamped to the corpus size, so small corpora no longer raise.

    Reads the ``st.session_state`` keys ``bm25``, ``bi_encoder``,
    ``embedded_passages``, ``passages``, ``passage_id``, ``tokenizer`` and
    ``model``. Renders a header and a table via Streamlit and returns ``None``.
    """
    if top_k is None:
        top_k = returns  # sidebar slider value, assigned at module level
    tokenized_query = preprocess(question).split()
    query = ' '.join(tokenized_query)
    ranked_ids = _retrieve(query, tokenized_query, top_k, n_candidates)

    passages = st.session_state['passages']
    titles = st.session_state['passage_id']
    # Underscores are presumably pyvi word-segmentation joiners; show as spaces.
    matches = [passages[i].replace('_', ' ') for i in ranked_ids]
    ids = [titles[i].replace('_', ' ') for i in ranked_ids]
    answers = [_answer(query, context) for context in matches]

    df_results = pd.DataFrame({
        'Title': ids,
        'Answer': answers,
        'Retrieved': matches,
    })
    st.header("Results:")
    st.table(df_results)
#%% | |
#title start page
st.title('Closed Domain QA System on Vietnamese Laws')
sdg = Image.open('./logo.jpg')  # logo displayed at the top of the sidebar
st.sidebar.image(sdg, width=300)
st.sidebar.title('Settings')
st.caption("by HoangNV - on custom laws QA data set")
# deploy() reads `returns` as a global: how many passages to answer from.
returns = st.sidebar.slider(
    'Number of answer suggestions:',
    min_value=1,
    max_value=3,
    value=2,
)
question = st.text_input(label='Type in your legal question:')
# Run the pipeline only once the user has typed something.
if question:
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump when the system clock is adjusted.
    t0 = time.perf_counter()
    with st.spinner('Finding best answers...'):
        deploy(question)
    st.write("Runtime: "+str(time.perf_counter()-t0))
#%% | |
# Path object for the current working directory.
# NOTE(review): not referenced anywhere in the visible file; possibly dead.
p = Path()