Phobert_Law / app.py
minhdang14902's picture
Update app.py
66aca7d verified
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
import nltk
from transformers.models.roberta.modeling_roberta import *
from transformers import RobertaForQuestionAnswering
from nltk import word_tokenize
import json
import pandas as pd
# import re
import base64
print("===================================================================")


@st.cache_data
def download_nltk_punkt():
    """Fetch NLTK's ``punkt`` tokenizer data.

    Wrapped in a Streamlit cache so the download runs once rather than on
    every script rerun.
    """
    nltk.download(info_or_id='punkt')
# Cache loading PhoBert model and tokenizer
@st.cache_resource
def load_phoBert(model_id='minhdang14902/Phobert_Law'):
    """Load the fine-tuned PhoBERT sequence classifier and its tokenizer.

    Uses ``st.cache_resource`` (Streamlit's cache for ML models and other
    global resources), so the objects are created once and the same instances
    are reused on later calls. ``st.cache_data`` would instead pickle the model
    and return a freshly deserialized copy on every call, which means every
    rerun of the script.

    Args:
        model_id: Hugging Face Hub repo id (or local directory) passed to
            ``from_pretrained`` for both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return model, tokenizer
# One-time setup: both calls below go through Streamlit-cached functions.
download_nltk_punkt()
phoBert_model, phoBert_tokenizer = load_phoBert()

# Wrap the already-loaded classifier and tokenizer in a transformers pipeline
# ("sentiment-analysis" is an alias of the "text-classification" task).
chatbot_pipeline = pipeline(
    "sentiment-analysis",
    tokenizer=phoBert_tokenizer,
    model=phoBert_model,
)
# Load spaCy Vietnamese model
# nlp = spacy.load('vi_core_news_lg')
# Load intents from json file
def load_json_file(filename):
    """Read *filename* and return the parsed JSON value.

    The file is decoded explicitly instead of relying on the platform's
    default text encoding, which is locale-dependent (e.g. cp1252 on Windows)
    and can raise ``UnicodeDecodeError`` on non-ASCII content such as
    Vietnamese text. The ``utf-8-sig`` codec reads plain UTF-8 and also
    strips a leading byte-order mark, which ``json.load`` would otherwise
    reject.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The decoded JSON document (for the intents file, presumably a dict
        with an ``'intents'`` list — confirm against the data file).
    """
    with open(filename, encoding="utf-8-sig") as f:
        return json.load(f)
# Intents document used below both to build the label mapping and to look up
# canned responses. The relative path is resolved against the process's
# current working directory.
filename = "./Law_2907.json"
intents = load_json_file(filename=filename)
def create_df():
    """Return a new, row-less DataFrame with ``Pattern`` and ``Tag`` columns."""
    column_names = ('Pattern', 'Tag')
    return pd.DataFrame({name: [] for name in column_names})
# Empty (Pattern, Tag) frame; it is passed to extract_json_info() below.
df = create_df()
def extract_json_info(json_file, df):
    """Return *df* extended with one ``[pattern, tag]`` row per intent pattern.

    Args:
        json_file: Parsed intents document. Each item of its ``'intents'``
            list must provide ``'patterns'`` (iterable of strings) and
            ``'tag'``.
        df: Two-column frame (pattern, tag), e.g. from ``create_df()``.

    Returns:
        A frame with the rows of *df* followed by the new rows, in intent
        order, with a fresh 0..n-1 index. Use the returned frame: *df* itself
        is not modified in place. If there are no patterns, *df* is returned
        unchanged.
    """
    # Collect everything first. Enlarging a DataFrame one ``.loc`` row at a
    # time copies it on every append, which is quadratic in the pattern count.
    rows = [
        [pattern, intent['tag']]
        for intent in json_file['intents']
        for pattern in intent['patterns']
    ]
    if not rows:
        return df
    new_rows = pd.DataFrame(rows, columns=df.columns, dtype=object)
    if df.empty:
        # Nothing to keep; this also avoids concatenating with an empty frame,
        # whose dtype handling is deprecated in recent pandas.
        return new_rows
    return pd.concat([df, new_rows], ignore_index=True)
df = extract_json_info(intents, df)
df2 = df.copy()

# Distinct tags in order of first appearance, then trimmed of surrounding
# whitespace (trimming happens after de-duplication).
labels = [tag.strip() for tag in df2['Tag'].unique().tolist()]
num_labels = len(labels)

# Forward and reverse lookups between integer ids and tag strings. If two
# trimmed tags collide, the later id wins in label2id.
id2label = dict(enumerate(labels))
label2id = {label: idx for idx, label in id2label.items()}
def chatPhobert(text):
    """Classify *text* and return the first canned response of the predicted intent.

    The pipeline's top label is expected to be an intent tag from the intents
    JSON (trimmed of whitespace, as in the label mapping built above).

    Args:
        text: The user's question.

    Returns:
        The first entry of the matching intent's ``'responses'`` list. Returns
        ``""`` when the predicted label matches no intent tag, or when that
        intent has no responses. The page tells users that an empty answer
        means no suitable answer was found.
    """
    predicted_tag = chatbot_pipeline(text)[0]['label']
    # Look the intent up by its tag rather than by label2id's position. Those
    # ids number only the de-duplicated tags that have at least one pattern,
    # so they drift from positions in intents['intents'] when an intent has no
    # patterns or a tag repeats.
    responses = next(
        (
            intent['responses']
            for intent in intents['intents']
            if intent['tag'].strip() == predicted_tag
        ),
        None,
    )
    if not responses:
        print(f"chatPhobert: no response for predicted label {predicted_tag!r}")
        return ""
    print(responses[0])
    return responses[0]
# Conversation log kept in Streamlit session state; the guard creates it only
# on the first run so it persists across reruns of this script.
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []

# Page header and the question box.
st.title("Chatbot Phobert Law")
st.write("Hi! Tôi là trợ lý của bạn trong việc trả lời các câu hỏi về pháp luật. Nếu câu trả lời trống trơn, đừng lo, chỉ là hệ thống không thấy câu trả lời phù hợp!!")
text = st.text_input("User: ", key="input")
# prompt = st.chat_input("Hãy chat gì đó!")
# if prompt:
# result = chatRoberta(text)
# st.write(f"HUIT Chatbot: {result[0]['answer']}")
# if st.button("Chat!"):
# if text:
# result = chatRoberta(text)
# st.write(f"Chatbot: {result[0]['answer']}")
# else:
# st.write("Hãy chat gì đó!")
def get_response(text):
    """Echo the user's question on the page and return the chatbot's reply.

    The echoed text is the user's input, so it is headed as the question. The
    old "The Answer is:" heading mislabeled it next to the "The Response is:"
    section that shows the actual reply.

    Args:
        text: The user's question.

    Returns:
        Whatever ``chatPhobert(text)`` returns (the reply string).
    """
    st.subheader("The Question is:")
    st.write(text)
    return chatPhobert(text)
if st.button("Chat!"):
    if not text.strip():
        # Empty or whitespace-only input: prompt the user ("Say something!",
        # the author's own message) instead of classifying an empty string and
        # logging it in the history.
        st.write("Hãy chat gì đó!")
    else:
        st.session_state['chat_history'].append(("User", text))
        response = get_response(text)
        st.subheader("The Response is:")
        message = st.empty()
        # Reveal the reply word by word with a trailing cursor. Per-character
        # updates re-sent the growing text once per character. Splitting on " "
        # and re-joining with " " reproduces the reply exactly.
        result = ""
        for position, word in enumerate(response.split(" ")):
            result = word if position == 0 else f"{result} {word}"
            message.markdown(result + "❚ ")
        message.markdown(result)
        st.session_state['chat_history'].append(("Bot", result))
# Replay the whole conversation as text areas. Widget keys combine the speaker
# with the turn index so they stay unique across the history.
for i, (sender, message) in enumerate(st.session_state['chat_history']):
    is_user = sender == "User"
    st.text_area(
        "User:" if is_user else "Bot:",
        value=message,
        height=100,
        max_chars=None,
        key=("user_" if is_user else "bot_") + str(i),
    )