# Demo_final/app.py — Hugging Face Space source file.
# Last commit: f1ec914 "Update app.py" by Linhz (verified).
import html
import pickle
import re

import faiss
import streamlit as st
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# Page chrome: browser tab config, custom CSS, logo banner and title.
st.set_page_config(
    page_title="Vietnamese Legal Question Answering System",
    page_icon="./app/static/Law.png",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Inject the app's stylesheet. Decode it as UTF-8 explicitly so the result does
# not depend on the host's locale encoding.
with open("./static/styles.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Logo banner. The "./app/static/..." URL presumably relies on Streamlit static
# file serving of the local ./static directory (server.enableStaticServing) —
# NOTE(review): confirm in the Streamlit config. The HTML is kept at column 0 so
# Markdown does not treat it as an indented code block.
st.markdown(
    """
<div class=logo_area>
<img src="./app/static/Law.png"/>
</div>
""",
    unsafe_allow_html=True,
)

st.markdown(
    """
<h1 style="text-align: center;">Vietnamese Legal Question Answering System</h1>
""",
    unsafe_allow_html=True,
)
# Load the retrieval corpus, the FAISS index and the models.
#
# Streamlit re-executes this whole script on every user interaction, so the
# expensive loaders are wrapped in st.cache_resource: each resource is built
# once per server process and reused by later reruns instead of being reloaded
# for every chat message.


@st.cache_resource
def _load_articles():
    """Return the article texts used as QA contexts.

    Entries are addressed by the integer ids the FAISS index returns, so their
    order must match the order in which the index was built.

    NOTE(review): pickle.load can execute arbitrary code; this is acceptable only
    because articles.pkl ships with the app. Never load user-supplied files here.
    """
    with open("articles.pkl", "rb") as file:
        return pickle.load(file)


@st.cache_resource
def _load_index():
    """Return the FAISS index of article embeddings."""
    return faiss.read_index("sentence_embeddings_index_no_citation.faiss")


@st.cache_resource
def _load_embedding_model(device_name):
    """Return the Vietnamese bi-encoder used to embed questions."""
    return SentenceTransformer(
        "bkai-foundation-models/vietnamese-bi-encoder", device=device_name
    )


@st.cache_resource
def _load_question_answerer(checkpoint, pipeline_device):
    """Return the Hugging Face question-answering pipeline for ``checkpoint``."""
    return pipeline("question-answering", model=checkpoint, device=pipeline_device)


articles = _load_articles()
index_loaded = _load_index()

# Hugging Face pipelines take a GPU ordinal (0) or -1 for CPU.
device = 0 if torch.cuda.is_available() else -1
# SentenceTransformer takes a torch device string instead; "cuda:-1" is not a
# valid device, so the CPU case must map to "cpu".
embedding_device = f"cuda:{device}" if device >= 0 else "cpu"

if "model_embedding" not in st.session_state:
    st.session_state.model_embedding = _load_embedding_model(embedding_device)

# Replace this with your own checkpoint
model_checkpoint = "model"
question_answerer = _load_question_answerer(model_checkpoint, device)
def question_answering(question, k=200, confident_threshold=0.7, unsure_threshold=0.3):
    """Answer a question by retrieving articles and reading each one.

    The question is embedded with ``st.session_state.model_embedding``. The ``k``
    nearest articles are fetched from ``index_loaded``, ``question_answerer`` is
    run on each article, and the highest-scoring answer span is kept.

    Args:
        question: The user's question text.
        k: Number of articles to retrieve and read.
        confident_threshold: A best score above this returns the bare answer.
        unsure_threshold: A best score above this (but not above
            ``confident_threshold``) returns the answer behind a hedging prefix.

    Returns:
        A user-facing Vietnamese string: the answer, a hedged answer, or an
        apology when no candidate scores above ``unsure_threshold``.
    """
    print(question)
    query_embedding = st.session_state.model_embedding.encode([question])
    # The FAISS Python API takes float32 query arrays; only the ids are needed.
    _, neighbour_ids = index_loaded.search(query_embedding.astype("float32"), k)
    # FAISS pads the result with -1 when the index holds fewer than k vectors.
    # articles[-1] would silently read the *last* article, so drop those slots.
    contexts = [articles[idx] for idx in neighbour_ids[0] if idx >= 0]
    candidates = [
        question_answerer(question=question, context=context, max_answer_len=512)
        for context in contexts
    ]
    # default=None keeps an empty retrieval from raising inside max().
    best_answer = max(candidates, key=lambda c: c["score"], default=None)
    print(best_answer)
    if best_answer is not None:
        if best_answer["score"] > confident_threshold:
            return best_answer["answer"]
        if best_answer["score"] > unsure_threshold:
            return f"Tôi không chắc lắm nhưng có lẽ câu trả lời là: \n{best_answer['answer']}"
    return "Xin lỗi tôi không biết câu trả lời cho câu hỏi này, vui lòng hỏi lại câu hỏi khác"
# if "messages" not in st.session_state:
# st.session_state.messages = []
# for message in st.session_state.messages:
# with st.chat_message(message["role"]):
# st.markdown(message["content"])
# Characters that count as real content at the end of an answer: every ASCII
# letter plus every precomposed Vietnamese letter (all vowel/tone-mark
# combinations, lower and upper case). The previous hand-written class was
# missing "ãÃ", which cut the last letter off words such as "xã" and "đã".
_VIETNAMESE_LETTERS = (
    "aAàÀảẢãÃáÁạẠ"
    "ăĂằẰẳẲẵẴắẮặẶ"
    "âÂầẦẩẨẫẪấẤậẬ"
    "bBcCdDđĐ"
    "eEèÈẻẺẽẼéÉẹẸ"
    "êÊềỀểỂễỄếẾệỆ"
    "fFgGhH"
    "iIìÌỉỈĩĨíÍịỊ"
    "jJkKlLmMnN"
    "oOòÒỏỎõÕóÓọỌ"
    "ôÔồỒổỔỗỖốỐộỘ"
    "ơƠờỜởỞỡỠớỚợỢ"
    "pPqQrRsStT"
    "uUùÙủỦũŨúÚụỤ"
    "ưƯừỪửỬữỮứỨựỰ"
    "vVwWxX"
    "yYỳỲỷỶỹỸýÝỵỴ"
    "zZ"
)
# A trailing run of anything that is not a letter, a digit, or a combining
# diacritical mark (U+0300-U+036F). The combining marks are allowed so that
# decomposed (NFD) text keeps the accent on its final letter.
_TRAILING_NON_CONTENT = re.compile(
    "[^" + _VIETNAMESE_LETTERS + "0-9" + "\u0300-\u036f" + "]+$"
)


def clean_answer(s):
    """Strip trailing punctuation, whitespace and other symbols from ``s``.

    Only the end of the string is touched; leading characters and everything up
    to the last letter/digit are returned unchanged.
    """
    return _TRAILING_NON_CONTENT.sub("", s)
# if prompt := st.chat_input("What is up?"):
# st.session_state.messages.append({"role": "user", "content": prompt})
# with st.chat_message("user"):
# st.markdown(prompt)
# response = clean_answer(question_answering(prompt))
# with st.chat_message("assistant"):
# st.markdown(response)
# st.session_state.messages.append({"role": "assistant", "content": response})
# Chat transcript, kept in session state so it survives Streamlit reruns.
# Each entry is {"role": "user" | "assistant", "content": str}. Content is
# stored raw and escaped only when rendered.
if "messages" not in st.session_state:
    st.session_state.messages = []


def _render_message(role, content):
    """Render one chat bubble with the app's CSS classes and avatar image.

    ``content`` is HTML-escaped before it is embedded, because it includes raw
    user input. Under ``unsafe_allow_html=True`` an unescaped "<" or "</div>"
    would otherwise be interpreted as markup. Newlines become "<br>": inside a
    raw HTML block they would collapse to spaces, and a blank line would end the
    block early.
    """
    if role == "assistant":
        avatar_class = "assistant-avatar"
        message_class = "assistant-message"
        avatar = "./app/static/AI.png"
    else:
        avatar_class = "user-avatar"
        message_class = "user-message"
        avatar = "./app/static/human.jpg"
    body = html.escape(content).replace("\n", "<br>")
    # The HTML stays at column 0 so Markdown does not read it as a code block.
    st.markdown(
        f"""
<div class="{message_class}">
<img src="{avatar}" class="{avatar_class}" />
<div class="stMarkdown">{body}</div>
</div>
""",
        unsafe_allow_html=True,
    )


# Replay the transcript so far (the script reruns from the top on every input).
for message in st.session_state.messages:
    _render_message(message["role"], message["content"])

if prompt := st.chat_input(placeholder="Xin chào, tôi có thể giúp được gì cho bạn?"):
    _render_message("user", prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    respond = clean_answer(question_answering(prompt))
    _render_message("assistant", respond)
    st.session_state.messages.append({"role": "assistant", "content": respond})