"""Streamlit front-end for a Vietnamese legal question-answering system.

A question is embedded with a bi-encoder and used to retrieve the nearest
articles from a FAISS index. An extractive QA pipeline then runs over each
retrieved article, and the highest-scoring span is shown in a chat-style UI.
"""

import html
import logging
import pickle
import re

import faiss
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline

logger = logging.getLogger(__name__)

# Number of nearest articles retrieved from the FAISS index per question.
TOP_K = 200
# Maximum answer-span length passed to the QA pipeline.
MAX_ANSWER_LEN = 512
# Best scores above CONFIDENT_SCORE are returned as-is. Scores above
# UNSURE_SCORE (up to CONFIDENT_SCORE) are returned with a hedge. Anything
# lower falls back to NO_ANSWER_MESSAGE.
CONFIDENT_SCORE = 0.7
UNSURE_SCORE = 0.3
NO_ANSWER_MESSAGE = "Xin lỗi tôi không biết câu trả lời cho câu hỏi này, vui lòng hỏi lại câu hỏi khác"

ARTICLES_PATH = 'articles.pkl'
INDEX_PATH = "sentence_embeddings_index_no_citation.faiss"
EMBEDDING_MODEL_NAME = 'bkai-foundation-models/vietnamese-bi-encoder'
# Replace this with your own checkpoint
model_checkpoint = "model"

# transformers.pipeline device convention: 0 = first CUDA device, -1 = CPU.
device = 0 if torch.cuda.is_available() else -1

# Matches the trailing run of characters that are NOT Vietnamese letters or
# ASCII digits (punctuation, whitespace, symbols) at the end of a string.
_TRAILING_SPECIAL_RE = re.compile(r'[^aAàÀảẢáÁạẠăĂằẰẳẲẵẴắẮặẶâÂầẦẩẨẫẪấẤậẬbBcCdDđĐeEèÈẻẺẽẼéÉẹẸêÊềỀểỂễỄếẾệỆfFgGhHiIìÌỉỈĩĨíÍịỊjJkKlLmMnNoOòÒỏỎõÕóÓọỌôÔồỒổỔỗỖốỐộỘơƠờỜởỞỡỠớỚợỢpPqQrRsStTuUùÙủỦũŨúÚụỤưƯừỪửỬữỮứỨựỰvVwWxXyYỳỲỷỶỹỸýÝỵỴzZ0-9]+$')

# Per-role presentation values: (avatar CSS class, message CSS class, avatar URL).
# NOTE(review): the chat markup rendered below does not reference these. The
# HTML wrapper that used them appears to be missing; confirm against
# ./static/styles.css before deleting.
_MESSAGE_STYLES = {
    'assistant': ("assistant-avatar", "assistant-message", './app/static/AI.png'),
    'user': ("user-avatar", "user-message", './app/static/human.jpg'),
}

st.set_page_config(page_title = "Vietnamese Legal Question Answering System", page_icon= "./app/static/Law.png", layout="centered", initial_sidebar_state="collapsed")

# Inject the app stylesheet's rules into the page.
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
# NOTE(review): this renders nothing; the intended markup may be missing — confirm.
st.markdown("\n", unsafe_allow_html=True)
st.markdown("\n\nVietnamese Legal Question Answering System\n\n", unsafe_allow_html=True)


# st.cache_resource keeps one shared instance per process. Without it, these
# heavy resources would be reloaded on every rerun (i.e. every chat message).
@st.cache_resource
def _load_articles(path=ARTICLES_PATH):
    """Load the article texts, indexed by FAISS row id."""
    # pickle.load can execute arbitrary code: only point this at a trusted,
    # locally built file.
    with open(path, 'rb') as file:
        return pickle.load(file)


@st.cache_resource
def _load_index(path=INDEX_PATH):
    """Load the FAISS index of article embeddings."""
    return faiss.read_index(path)


@st.cache_resource
def _load_embedding_model(name=EMBEDDING_MODEL_NAME):
    """Load the sentence-embedding model on the GPU when available, else the CPU."""
    # SentenceTransformer expects a torch device string; "cuda:-1" is invalid.
    return SentenceTransformer(name, device=f"cuda:{device}" if device >= 0 else "cpu")


@st.cache_resource
def _load_qa_pipeline(checkpoint=model_checkpoint):
    """Load the extractive question-answering pipeline."""
    return pipeline("question-answering", model=checkpoint, device=device)


articles = _load_articles()
index_loaded = _load_index()
question_answerer = _load_qa_pipeline()
if 'model_embedding' not in st.session_state:
    st.session_state.model_embedding = _load_embedding_model()


def question_answering(question: str) -> str:
    """Answer *question* from the articles retrieved for it.

    The question is embedded and up to TOP_K nearest articles are fetched from
    the FAISS index. The QA pipeline runs on each article, and the single
    highest-scoring span is kept.

    Returns:
        The span itself when its score exceeds CONFIDENT_SCORE. A hedged
        Vietnamese sentence containing the span when the score exceeds
        UNSURE_SCORE. Otherwise NO_ANSWER_MESSAGE, which is also returned when
        nothing could be retrieved.
    """
    logger.debug("Question: %s", question)
    query_embedding = st.session_state.model_embedding.encode([question])

    # FAISS pads results with label -1 when k exceeds the number of indexed
    # vectors, so clamp k and skip those labels below.
    k = min(TOP_K, index_loaded.ntotal)
    if k <= 0:
        return NO_ANSWER_MESSAGE
    _distances, indices = index_loaded.search(query_embedding.astype('float32'), k)

    candidates = [
        question_answerer(question=question, context=articles[idx], max_answer_len=MAX_ANSWER_LEN)
        for idx in indices[0]
        if idx >= 0
    ]
    if not candidates:
        return NO_ANSWER_MESSAGE

    best_answer = max(candidates, key=lambda candidate: candidate['score'])
    logger.debug("Best answer: %s", best_answer)
    if best_answer['score'] > CONFIDENT_SCORE:
        return best_answer['answer']
    if best_answer['score'] > UNSURE_SCORE:
        return f"Tôi không chắc lắm nhưng có lẽ câu trả lời là: \n{best_answer['answer']}"
    return NO_ANSWER_MESSAGE


def clean_answer(s: str) -> str:
    """Remove the trailing run of non-letter, non-digit characters from *s*."""
    return _TRAILING_SPECIAL_RE.sub('', s)


def _render_message(content: str) -> None:
    """Render one chat message.

    The content comes from the user or the model, so it is HTML-escaped before
    being passed to st.markdown with unsafe_allow_html=True.
    """
    st.markdown(f"\n{html.escape(content)}\n", unsafe_allow_html=True)


if 'messages' not in st.session_state:
    st.session_state.messages = []

# Streamlit reruns the script on every interaction, so replay the history.
for message in st.session_state.messages:
    _render_message(message['content'])

if prompt := st.chat_input(placeholder='Xin chào, tôi có thể giúp được gì cho bạn?'):
    _render_message(prompt)
    st.session_state.messages.append({'role': 'user', 'content': prompt})
    respond = clean_answer(question_answering(prompt))
    _render_message(respond)
    st.session_state.messages.append({'role': 'assistant', 'content': respond})