import json

import nltk
import pandas as pd
import streamlit as st
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
print("===================================================================") | |
# Download the NLTK tokenizer data once per session
@st.cache_resource
def download_nltk_punkt():
    nltk.download('punkt')

# Cache loading the PhoBERT model and tokenizer
@st.cache_resource
def load_phoBert():
    model = AutoModelForSequenceClassification.from_pretrained('minhdang14902/Phobert_Law')
    tokenizer = AutoTokenizer.from_pretrained('minhdang14902/Phobert_Law')
    return model, tokenizer

# Call the cached functions
download_nltk_punkt()
phoBert_model, phoBert_tokenizer = load_phoBert()

# Initialize the pipeline with the loaded PhoBERT model and tokenizer.
# "sentiment-analysis" is the transformers alias for text classification;
# here it is used for intent prediction, not sentiment.
chatbot_pipeline = pipeline("sentiment-analysis", model=phoBert_model, tokenizer=phoBert_tokenizer)
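# Shape of a pipeline call, for reference. The tag name and score shown are
# illustrative assumptions, not real output of this model:
#   chatbot_pipeline("some question")  ->  [{'label': 'some_tag', 'score': 0.97}]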
# Load intents from the JSON file (UTF-8 so the Vietnamese text decodes correctly)
def load_json_file(filename):
    with open(filename, encoding='utf-8') as f:
        return json.load(f)

filename = './Law_2907.json'
intents = load_json_file(filename)
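# Expected structure of Law_2907.json, inferred from how it is read below
# (the field names come from the code; the values are placeholders):
# {
#   "intents": [
#     {"tag": "...", "patterns": ["...", ...], "responses": ["...", ...]},
#     ...
#   ]
# }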
# Empty frame that will hold one (pattern, tag) pair per training example
def create_df():
    return pd.DataFrame({'Pattern': [], 'Tag': []})

df = create_df()
# Flatten the intents file into the frame: one row per (pattern, tag) pair
def extract_json_info(json_file, df):
    for intent in json_file['intents']:
        for pattern in intent['patterns']:
            df.loc[len(df.index)] = [pattern, intent['tag']]
    return df

df = extract_json_info(intents, df)
# Build the tag <-> id mappings. The enumeration order must match the order
# of intents in the JSON file, since chatPhobert indexes into it below.
labels = [s.strip() for s in df['Tag'].unique().tolist()]
num_labels = len(labels)
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}
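# Example of the resulting maps (the tag names here are hypothetical):
#   id2label -> {0: 'marriage', 1: 'inheritance', ...}
#   label2id -> {'marriage': 0, 'inheritance': 1, ...}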
def chatPhobert(text):
    # Predict the intent tag for the question, then return the first
    # canned response of the matching intent
    predicted_tag = chatbot_pipeline(text)[0]['label']
    response = intents['intents'][label2id[predicted_tag]]['responses']
    return response[0]
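# Usage sketch (the question string is hypothetical):
#   chatPhobert("What are the legal requirements for marriage?")
#   -> the first response string of the predicted intent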
st.title("Chatbot Phobert Law")
st.write("Hi! I am your assistant for answering questions about the law. "
         "If an answer comes back empty, don't worry: the system simply "
         "could not find a suitable answer!")
text = st.text_input("User: ", key="input")

# Keep the running conversation across Streamlit reruns
if 'chat_history' not in st.session_state:
    st.session_state['chat_history'] = []
def get_response(text):
    # Look up the chatbot's answer for the user's question
    return chatPhobert(text)
if st.button("Chat!"): | |
st.session_state['chat_history'].append(("User", text)) | |
response = get_response(text) | |
st.subheader("The Response is:") | |
message = st.empty() | |
result = "" | |
for chunk in response: | |
result += chunk | |
message.markdown(result + "❚ ") | |
message.markdown(result) | |
st.session_state['chat_history'].append(("Bot", result)) | |
# Render the conversation so far
for i, (sender, content) in enumerate(st.session_state['chat_history']):
    if sender == "User":
        st.text_area("User:", value=content, height=100, max_chars=None, key=f"user_{i}")
    else:
        st.text_area("Bot:", value=content, height=100, max_chars=None, key=f"bot_{i}")