|
import streamlit as st |
|
from api import load_model_bert, load_model_lstm, inference |
|
import pandas as pd |
|
from huggingface_hub import hf_hub_download |
|
import os |
|
|
|
|
|
|
|
if os.path.exists("vietnamese_hate_speech_detection_phobert") == False: |
|
try: |
|
os.mkdir("vietnamese_hate_speech_detection_phobert") |
|
except FileExistsError: |
|
pass |
|
|
|
|
|
hf_hub_download( |
|
repo_id="jesse-tong/vietnamese_hate_speech_detection_phobert", |
|
filename="vinai_phobert-base-v2_finetuned.pth", |
|
repo_type="model", |
|
local_dir="vietnamese_hate_speech_detection_phobert" |
|
) |
|
hf_hub_download( |
|
repo_id="jesse-tong/vietnamese_hate_speech_detection_phobert", |
|
filename="distilled_lstm_model.pth", |
|
repo_type="model", |
|
local_dir="vietnamese_hate_speech_detection_phobert" |
|
) |
|
|
|
|
|
|
|
def app(): |
|
st.set_page_config(layout="wide") |
|
st.title("Phân tích ngôn từ thù địch, phân biệt sử dụng PhoBERT và LSTM") |
|
|
|
|
|
|
|
@st.cache_resource |
|
def load_models(): |
|
loading_model_bar = st.progress(0, "Nạp các mô hình...") |
|
|
|
bert_model, bert_device = load_model_bert() |
|
loading_model_bar.progress(50, "Mô hình PhoBERT đã được nạp.") |
|
|
|
lstm_model, lstm_device = load_model_lstm() |
|
loading_model_bar.progress(100, "Mô hình LSTM đã được nạp.") |
|
loading_model_bar.empty() |
|
return bert_model, bert_device, lstm_model, lstm_device |
|
|
|
bert_model, bert_device, lstm_model, lstm_device = load_models() |
|
|
|
|
|
user_input = st.text_area("Nhập các bình luận để phân tích ngôn từ thù địch, phân biệt (xuống dòng cho từng bình luận):") |
|
|
|
if st.button("Phân tích"): |
|
if user_input: |
|
|
|
comments = user_input.splitlines() |
|
|
|
|
|
classification_bar = st.progress(0, "Đang phân tích với PhoBERT...") |
|
bert_predictions = inference(bert_model, bert_device, comments) |
|
st.write("Phân loại của PhoBERT:") |
|
st.table(pd.DataFrame(bert_predictions)) |
|
|
|
classification_bar.progress(50, "Đang phân tích với LSTM...") |
|
|
|
|
|
lstm_predictions = inference(lstm_model, lstm_device, comments) |
|
st.write("Phân loại của LSTM:") |
|
classification_bar.progress(100, "Phân tích hoàn tất!") |
|
classification_bar.empty() |
|
st.table(pd.DataFrame(lstm_predictions)) |
|
else: |
|
st.warning("Hãy nhập một vài bình luận.") |
|
|
|
if __name__ == "__main__": |
|
|
|
app() |