ABSA_it / app.py
firstmetis's picture
Update app.py
2dd927f verified
import re
import streamlit as st
import torch
from transformers import MBartForConditionalGeneration, MBartTokenizer
from huggingface_hub import hf_hub_download
# 🎯 โหลดโมเดลจาก Hugging Face
st.markdown(
"""
<style>
.container {
max-width: 700px;
margin: auto;
border-radius: 10px;
background-color: #f9f9f9;
box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}
.content {
text-align: justify;
line-height: 1.6;
}
.under {
text-decoration-line: underline;
text-decoration-style: double;
}
</style>
""",
unsafe_allow_html=True,
)
@st.cache_resource
def load_model():
try:
# 🔹 ดาวน์โหลด model.pth จาก Hugging Face
model_path = hf_hub_download(repo_id="firstmetis/absa_it", filename="model.pth")
# 🔹 โหลดโมเดล MBart
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50")
# 🔹 โหลด tokenizer และเพิ่ม special tokens
tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50")
special_tokens = ['<SYMBOL>', '<ASPECT>', '<OPINION>', '<POS>', '<NEG>', '<NEU>']
tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
# 🔹 ปรับขนาด token embeddings
model.resize_token_embeddings(len(tokenizer))
# 🔹 โหลดพารามิเตอร์
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()
return model, tokenizer
except Exception as e:
st.error(f"❌ เกิดข้อผิดพลาดขณะโหลดโมเดล: {e}")
return None, None
# โหลดโมเดล
model, tokenizer = load_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if model:
model.to(device)
# ฟังก์ชันแปลงผลลัพธ์ (ลบความมั่นใจออก)
def format_output(text):
pattern = r"<SYMBOL>\s*(.*?)\s*<ASPECT>\s*(.*?)\s*<OPINION>\s*(.*?)\s*<(POS|NEG|NEU)>"
match = re.search(pattern, text)
sentiment_mapping = {
"POS": "เชิงบวก (Positive)",
"NEG": "เชิงลบ (Negative)",
"NEU": "เชิงกลาง (Neutral)"
}
sentiment_colors = {
"POS": "#d4edda", # สีเขียว
"NEG": "#f8d7da", # สีแดง
"NEU": "#ffffff" # สีขาว
}
if match:
symbol, aspect, opinion, sentiment = match.groups()
sentiment_text = sentiment_mapping.get(sentiment, sentiment) # แปลง sentiment
return f"""
<b>SYMBOL:</b> <span style= "color: black; background-color: #dbeafe; padding: 3px 6px; border-radius: 5px;">{symbol}</span>&nbsp;&nbsp;
<b>ASPECT:</b> <span style= "color: black; background-color: #ffefd9; padding: 3px 6px; border-radius: 5px;">{aspect}</span>&nbsp;&nbsp;
<b>OPINION:</b> <span style= "color: black; background-color: #f5c6ff; padding: 3px 6px; border-radius: 5px;">{opinion}</span>&nbsp;&nbsp;
<b>SENTIMENT:</b> <span style= "color: black; background-color: {sentiment_colors.get(sentiment, '#ffffff')}; padding: 3px 6px; border-radius: 5px;">{sentiment_text}</span>
"""
return f"{text}"
# ฟังก์ชันสำหรับสร้างข้อความ (ไม่คำนวณ confidence)
def generate_text(input_text):
input_ids = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=512).input_ids
input_ids = input_ids.to(device)
with torch.no_grad():
outputs = model.generate(
input_ids,
num_beams=4,
do_sample=True,
temperature=1.2,
top_k=50,
top_p=0.95,
num_return_sequences=4,
max_length=50,
return_dict_in_generate=True,
output_scores=False # ไม่ต้องการให้คืนค่า logits ของ output
)
sequences = outputs.sequences
# แปลง sequences เป็นข้อความ
output_texts = [
tokenizer.decode(seq, skip_special_tokens=False).replace("</s>", "").replace("<pad>", "").strip()
for seq in sequences
]
# คืนค่าเป็นแค่ข้อความที่สร้างขึ้นจากโมเดล
return output_texts
# **🎯 สร้าง UI ด้วย Streamlit**
st.title("📌 Aspect-based Sentiment Analysis (ABSA)")
st.markdown(
"""
<div class='content'>
<h4>📍 วิธีการใช้งานเว็บไซต์</h4>
<p>
&emsp;1. เลือกพาดหัวข่าวเกี่ยวกับหุ้นที่สนใจโดยมีเงื่อนไขดังนี้</br>
&emsp;&emsp;&emsp;- เป็นข่าวหุ้นไทยในปี พ.ศ.2566-2567</br>
&emsp;&emsp;&emsp;- เป็นข่าวหุ้นไทยที่มีสัญลักษณ์หุ้นชัดเจน</br>
&emsp;&emsp;&emsp;- เป็นข่าวหุ้นไทยที่มีการออกข่าวค่อนข้างบ่อย</br>
&emsp;&emsp;&emsp; <u class="under">ตัวอย่าง</u> : TISCO ปันผลดี เหมาะสะสม บล.ดีบีเอสฯให้เป้า 118 บ.</br>
&emsp;2. นำพาดหัวข่าวใส่ลงช่องว่างด้านล่าง</br>
&emsp;3. กดปุ่ม Apply เพื่อวิเคราะห์
</p>
</div>
""",
unsafe_allow_html=True,
)
st.markdown("ใส่พาดหัวข่าวหุ้น เพื่อวิเคราะห์ Sentiment")
# รับค่าจากผู้ใช้
user_input = st.text_input("✍️ ใส่ข้อความตรงนี้ :", "")
# ปุ่ม Apply
if st.button("Apply"):
if user_input:
responses = generate_text(user_input) # ได้ list ของข้อความที่สร้างจากโมเดล
if responses: # ตรวจสอบว่ามีข้อมูล
for i, response_text in enumerate(responses, 1):
formatted_output = format_output(response_text)
st.markdown(f"**🔹 ผลลัพธ์ {i} :**<br>{formatted_output}", unsafe_allow_html=True)
st.markdown("<hr>", unsafe_allow_html=True) # เส้นคั่นระหว่างผลลัพธ์
else:
st.warning("⚠️ ไม่พบผลลัพธ์ที่สามารถวิเคราะห์ได้")
else:
st.warning("⚠️ กรุณากรอกข้อความก่อนกด Apply")