Spaces:
Sleeping
Sleeping
| import re | |
| import streamlit as st | |
| import torch | |
| from transformers import MBartForConditionalGeneration, MBartTokenizer | |
| from huggingface_hub import hf_hub_download | |
| # 🎯 โหลดโมเดลจาก Hugging Face | |
| st.markdown( | |
| """ | |
| <style> | |
| .container { | |
| max-width: 700px; | |
| margin: auto; | |
| border-radius: 10px; | |
| background-color: #f9f9f9; | |
| box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1); | |
| } | |
| .content { | |
| text-align: justify; | |
| line-height: 1.6; | |
| } | |
| .under { | |
| text-decoration-line: underline; | |
| text-decoration-style: double; | |
| } | |
| </style> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |
| def load_model(): | |
| try: | |
| # 🔹 ดาวน์โหลด model.pth จาก Hugging Face | |
| model_path = hf_hub_download(repo_id="firstmetis/absa_it", filename="model.pth") | |
| # 🔹 โหลดโมเดล MBart | |
| model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50") | |
| # 🔹 โหลด tokenizer และเพิ่ม special tokens | |
| tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-50") | |
| special_tokens = ['<SYMBOL>', '<ASPECT>', '<OPINION>', '<POS>', '<NEG>', '<NEU>'] | |
| tokenizer.add_special_tokens({'additional_special_tokens': special_tokens}) | |
| # 🔹 ปรับขนาด token embeddings | |
| model.resize_token_embeddings(len(tokenizer)) | |
| # 🔹 โหลดพารามิเตอร์ | |
| model.load_state_dict(torch.load(model_path, map_location="cpu")) | |
| model.eval() | |
| return model, tokenizer | |
| except Exception as e: | |
| st.error(f"❌ เกิดข้อผิดพลาดขณะโหลดโมเดล: {e}") | |
| return None, None | |
| # โหลดโมเดล | |
| model, tokenizer = load_model() | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| if model: | |
| model.to(device) | |
| # ฟังก์ชันแปลงผลลัพธ์ (ลบความมั่นใจออก) | |
| def format_output(text): | |
| pattern = r"<SYMBOL>\s*(.*?)\s*<ASPECT>\s*(.*?)\s*<OPINION>\s*(.*?)\s*<(POS|NEG|NEU)>" | |
| match = re.search(pattern, text) | |
| sentiment_mapping = { | |
| "POS": "เชิงบวก (Positive)", | |
| "NEG": "เชิงลบ (Negative)", | |
| "NEU": "เชิงกลาง (Neutral)" | |
| } | |
| sentiment_colors = { | |
| "POS": "#d4edda", # สีเขียว | |
| "NEG": "#f8d7da", # สีแดง | |
| "NEU": "#ffffff" # สีขาว | |
| } | |
| if match: | |
| symbol, aspect, opinion, sentiment = match.groups() | |
| sentiment_text = sentiment_mapping.get(sentiment, sentiment) # แปลง sentiment | |
| return f""" | |
| <b>SYMBOL:</b> <span style= "color: black; background-color: #dbeafe; padding: 3px 6px; border-radius: 5px;">{symbol}</span> | |
| <b>ASPECT:</b> <span style= "color: black; background-color: #ffefd9; padding: 3px 6px; border-radius: 5px;">{aspect}</span> | |
| <b>OPINION:</b> <span style= "color: black; background-color: #f5c6ff; padding: 3px 6px; border-radius: 5px;">{opinion}</span> | |
| <b>SENTIMENT:</b> <span style= "color: black; background-color: {sentiment_colors.get(sentiment, '#ffffff')}; padding: 3px 6px; border-radius: 5px;">{sentiment_text}</span> | |
| """ | |
| return f"{text}" | |
| # ฟังก์ชันสำหรับสร้างข้อความ (ไม่คำนวณ confidence) | |
| def generate_text(input_text): | |
| input_ids = tokenizer(input_text, return_tensors="pt", truncation=True, padding=True, max_length=512).input_ids | |
| input_ids = input_ids.to(device) | |
| with torch.no_grad(): | |
| outputs = model.generate( | |
| input_ids, | |
| num_beams=4, | |
| do_sample=True, | |
| temperature=1.2, | |
| top_k=50, | |
| top_p=0.95, | |
| num_return_sequences=4, | |
| max_length=50, | |
| return_dict_in_generate=True, | |
| output_scores=False # ไม่ต้องการให้คืนค่า logits ของ output | |
| ) | |
| sequences = outputs.sequences | |
| # แปลง sequences เป็นข้อความ | |
| output_texts = [ | |
| tokenizer.decode(seq, skip_special_tokens=False).replace("</s>", "").replace("<pad>", "").strip() | |
| for seq in sequences | |
| ] | |
| # คืนค่าเป็นแค่ข้อความที่สร้างขึ้นจากโมเดล | |
| return output_texts | |
| # **🎯 สร้าง UI ด้วย Streamlit** | |
| st.title("📌 Aspect-based Sentiment Analysis (ABSA)") | |
| st.markdown( | |
| """ | |
| <div class='content'> | |
| <h4>📍 วิธีการใช้งานเว็บไซต์</h4> | |
| <p> | |
|  1. เลือกพาดหัวข่าวเกี่ยวกับหุ้นที่สนใจโดยมีเงื่อนไขดังนี้</br> | |
|    - เป็นข่าวหุ้นไทยในปี พ.ศ.2566-2567</br> | |
|    - เป็นข่าวหุ้นไทยที่มีสัญลักษณ์หุ้นชัดเจน</br> | |
|    - เป็นข่าวหุ้นไทยที่มีการออกข่าวค่อนข้างบ่อย</br> | |
|     <u class="under">ตัวอย่าง</u> : TISCO ปันผลดี เหมาะสะสม บล.ดีบีเอสฯให้เป้า 118 บ.</br> | |
|  2. นำพาดหัวข่าวใส่ลงช่องว่างด้านล่าง</br> | |
|  3. กดปุ่ม Apply เพื่อวิเคราะห์ | |
| </p> | |
| </div> | |
| """, | |
| unsafe_allow_html=True, | |
| ) | |
| st.markdown("ใส่พาดหัวข่าวหุ้น เพื่อวิเคราะห์ Sentiment") | |
| # รับค่าจากผู้ใช้ | |
| user_input = st.text_input("✍️ ใส่ข้อความตรงนี้ :", "") | |
| # ปุ่ม Apply | |
| if st.button("Apply"): | |
| if user_input: | |
| responses = generate_text(user_input) # ได้ list ของข้อความที่สร้างจากโมเดล | |
| if responses: # ตรวจสอบว่ามีข้อมูล | |
| for i, response_text in enumerate(responses, 1): | |
| formatted_output = format_output(response_text) | |
| st.markdown(f"**🔹 ผลลัพธ์ {i} :**<br>{formatted_output}", unsafe_allow_html=True) | |
| st.markdown("<hr>", unsafe_allow_html=True) # เส้นคั่นระหว่างผลลัพธ์ | |
| else: | |
| st.warning("⚠️ ไม่พบผลลัพธ์ที่สามารถวิเคราะห์ได้") | |
| else: | |
| st.warning("⚠️ กรุณากรอกข้อความก่อนกด Apply") | |