import os import tensorflow as tf import transformers from tensorflow import keras from transformers import BertTokenizer, TFBertModel import pandas as pd from datetime import date, timedelta import requests import time from typing import List, Optional, Dict, Any class BertPredictor: """ 用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。 """ def __init__(self, tokenizer_name: str = 'hfl/rbt3', max_news_per_keyword: int = 5): """ 初始化預測器,載入分詞器、預訓練模型並獲取新聞。 Args: tokenizer_name (str): BERT 分詞器的名稱。 max_news_per_keyword (int): 每個關鍵字要抓取的新聞最大數量。 """ # --- 路徑和檔案名稱設置 --- self.current_dir = os.path.dirname(os.path.abspath(__file__)) self.model_path = os.path.join(self.current_dir, 'Best-complete-model.h5') # 檔案名稱用今天的日期,但內容是昨天的 today_date_str = date.today().strftime('%Y-%m-%d') self.news_csv_path = os.path.join(self.current_dir, f'news_{today_date_str}.csv') ### 抓取固定檔案而非今天,成果展示要註解 self.news_csv_path = os.path.join(self.current_dir, "news_2025-09-12.csv") # 用於API查詢的日期仍然是昨天 self.target_date = date.today() - timedelta(days=1) self.target_date_str = self.target_date.strftime('%Y-%m-%d') # --- GNews API 設定 --- self.api_key = "fd12e84a158c7d9eaf31627aaae0927a" # 請替換成您的 API Key self.base_url = "https://gnews.io/api/v4/search" self.keywords = ["Fed", "Interest Rates", "Inflation", "Tariffs", "ADR", "Treasury Yields"] self.max_news_per_keyword = max_news_per_keyword # --- 模型相關設置 --- self.text_max_length = 256 self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name) # 載入最佳模型 print("正在加載模型...") self.model = keras.models.load_model( self.model_path, custom_objects={'TFBertModel': TFBertModel} ) print("模型加載完成。") # --- 初始化流程 --- self._check_file_and_get_news_if_needed() # --- 內部使用方法 --- def _encode_texts(self, texts: list): """將文本轉換為 BERT 輸入格式 (input_ids, attention_mask)""" return self.tokenizer( texts, max_length=self.text_max_length, padding='max_length', truncation=True, return_tensors='tf' ) def _predict(self, new_text: str) -> float: """ 對單一新聞文本進行預測。 Args: new_text (str): 待預測的新聞文本。 Returns: float: 預測的股市影響分數。 """ new_encoding = self._encode_texts([new_text]) predicted_score = self.model.predict(dict(new_encoding), verbose=0)[0][0] return float(predicted_score) def _check_file_and_get_news_if_needed(self): """ 檢查今天的 news csv 是否存在。如果不存在,則呼叫 _get_news() 進行抓取。 """ if not os.path.exists(self.news_csv_path): print(f"找不到今天的檔案 '{os.path.basename(self.news_csv_path)}'。") self._get_news() else: print(f"已找到今天的檔案 '{os.path.basename(self.news_csv_path)}',將跳過新聞抓取步驟。") def _get_news(self): """ 使用 GNews API 抓取目標日期(昨天)的新聞,即時預測分數並儲存。 """ print("開始執行新聞抓取與即時預測...") print(f"搜尋日期設定為:{self.target_date_str} (將存檔至檔名含今日日期的檔案)") results = [] for kw in self.keywords: params = { "q": kw, "lang": "en", "country": "us", "max": self.max_news_per_keyword, "in": "title,description", "apikey": self.api_key, "from": f"{self.target_date_str}T00:00:00Z", "to": f"{self.target_date_str}T23:59:59Z" } try: response = requests.get(self.base_url, params=params) response.raise_for_status() data = response.json() print(f"關鍵字 '{kw}' 成功抓取到: {data.get('totalArticles', 0)} 則新聞") if "articles" in data: for article in data["articles"]: published_date = pd.to_datetime(article['publishedAt']).strftime('%Y-%m-%d') news_content = f"{article['title']} - {article.get('description', '')}" score = self._predict(news_content) results.append({ "時間": published_date, "分數": score, "內容": news_content }) except requests.exceptions.RequestException as e: print(f"錯誤:API 請求失敗 - {e}") continue finally: time.sleep(0.5) if not results: print("抓取完成。未找到任何相關新聞。") df_to_save = pd.DataFrame(columns=['時間', '分數', '內容']) else: print(f"成功抓取並預測 {len(results)} 筆新聞。") df_to_save = pd.DataFrame(results) try: print(f"正在將結果寫入檔案 '{self.news_csv_path}'...") df_to_save.to_csv(self.news_csv_path, index=False, encoding='utf-8-sig') print(f"成功!檔案已儲存至 '{self.news_csv_path}'。") except IOError as e: print(f"錯誤:寫入檔案失敗 - {e}") # --- 公開方法 --- def get_news_index(self) -> Optional[float]: """ 從今天的 news csv 檔案中讀取所有新聞分數並回傳其平均值。 Returns: float or None: 所有新聞的平均分數,如果檔案不存在或為空則回傳 None。 """ try: df = pd.read_csv(self.news_csv_path) if df.empty or '分數' not in df.columns: print(f"'{self.news_csv_path}' 為空或缺少 '分數' 欄位。") return None average_score = pd.to_numeric(df['分數'], errors='coerce').mean() return average_score if pd.notna(average_score) else None except FileNotFoundError: print(f"錯誤:找不到檔案 '{self.news_csv_path}'。") return None except Exception as e: print(f"讀取或計算 CSV 檔案時發生錯誤:{e}") return None def get_news(self) -> Optional[List[str]]: """ 讀取今天的 news csv 檔案,並以 list 格式回傳分數絕對值最高的三則新聞內容。 """ try: df = pd.read_csv(self.news_csv_path) df['分數'] = pd.to_numeric(df['分數'], errors='coerce') df.dropna(subset=['分數'], inplace=True) if df.empty: return [] df['abs_score'] = df['分數'].abs() top_3_news_df = df.sort_values(by='abs_score', ascending=False).head(3) # 將 '內容' 欄位轉換為 list of strings return top_3_news_df['內容'].tolist() except FileNotFoundError: print(f"錯誤:找不到檔案 '{self.news_csv_path}'。") return None except Exception as e: print(f"讀取或處理 CSV 檔案時發生錯誤:{e}") return None # --- 主程式區塊:只有當腳本直接執行時才運行 --- if __name__ == "__main__": if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Best-complete-model.h5')): print("錯誤:找不到模型文件 'Best-complete-model.h5'。請先訓練模型並確保它已保存。") else: predictor = BertPredictor(max_news_per_keyword=3) print("\n" + "="*30) avg_score = predictor.get_news_index() if avg_score is not None: print(f"從新聞檔案中計算出的平均分數為:{avg_score:.4f}") else: print("無法計算新聞檔案中的平均分數。") print("\n" + "="*30) top_news_content = predictor.get_news() if top_news_content: print("\n分數絕對值最高的三則新聞內容:") for i, content in enumerate(top_news_content): print(f" {i+1}. {content}") elif top_news_content == []: print("新聞檔案中無有效內容可顯示。") else: print("無法獲取最高分新聞。")