Reality8081 commited on
Commit
35beba6
·
1 Parent(s): 9466500

Update src

Browse files
Files changed (4) hide show
  1. app.py +52 -41
  2. src/data_processing.py +4 -4
  3. src/inference.py +22 -26
  4. src/train.py +81 -80
app.py CHANGED
@@ -6,7 +6,7 @@ import numpy as np
6
  import yfinance as yf
7
  from datetime import datetime, timedelta
8
 
9
- from src.inference import predict_next_day
10
 
11
  # --- CÁC HÀM HỖ TRỢ TÍNH TOÁN KỸ THUẬT CHO UI ---
12
  def calculate_ui_technical_indicators(df):
@@ -42,7 +42,7 @@ def calculate_ui_technical_indicators(df):
42
  def generate_quant_dashboard(ticker, model_name):
43
  try:
44
  # 1. Gọi Inference Engine
45
- preds, last_close, last_date, _ = predict_next_day(ticker, model_name)
46
 
47
  # 2. Lấy dữ liệu OHLCV 90 ngày để vẽ Candlestick & tính toán Context
48
  # Sử dụng yfinance trực tiếp để render UI mượt mà, độc lập với backend load_data nặng nề
@@ -60,7 +60,12 @@ def generate_quant_dashboard(ticker, model_name):
60
  rsi_val = last_row['RSI_14']
61
  macd_h = last_row['MACD_Hist']
62
 
63
- next_day = pd.to_datetime(last_date) + pd.offsets.BDay(1)
 
 
 
 
 
64
 
65
  except Exception as e:
66
  error_html = f"""<div style='background-color:#3a1010; padding:15px; border-left: 4px solid #ff4d4d; color: #ff8080;'>
@@ -76,6 +81,14 @@ def generate_quant_dashboard(ticker, model_name):
76
  consensus_html = ""
77
  target_price = 0
78
 
 
 
 
 
 
 
 
 
79
  if model_name == "Cả Hai":
80
  target_price = (price_lr + price_svr) / 2
81
  spread_bps = abs(pred_lr - pred_svr) * 10000 # Basis points
@@ -90,28 +103,32 @@ def generate_quant_dashboard(ticker, model_name):
90
  direction = "UNCERTAIN / CHOPPY ⚠️"
91
 
92
  consensus_html = f"""
93
- <div style="background:#1a1a24; border: 1px solid #333; padding: 15px; border-radius: 5px; margin-bottom: 10px;">
94
- <p style="color:#8892b0; margin:0; font-size:12px; font-family: monospace;">ALGO CONSENSUS ENGINE</p>
95
- <h3 style="color:{color}; margin: 5px 0;">{status}: {direction}</h3>
96
- <p style="color:#a8b2d1; margin:0; font-size:13px;">Divergence Spread: <b>{spread_bps:.1f} bps</b></p>
97
- <div style="display:flex; justify-content: space-between; margin-top: 10px; font-family: monospace; font-size: 13px;">
98
- <span style="color: {'#00ff00' if pred_lr>0 else '#ff3333'}">LR: {pred_lr*100:+.2f}% (${price_lr:.2f})</span>
99
- <span style="color: {'#00ff00' if pred_svr>0 else '#ff3333'}">SVR: {pred_svr*100:+.2f}% (${price_svr:.2f})</span>
100
- </div>
101
- </div>
102
- """
103
  else:
104
  active_pred = pred_lr if model_name == "Linear Regression" else pred_svr
105
  target_price = price_lr if model_name == "Linear Regression" else price_svr
106
  dir_color = "#00ff00" if active_pred > 0 else "#ff3333"
107
  dir_text = "BULLISH 📈" if active_pred > 0 else "BEARISH 📉"
108
  consensus_html = f"""
109
- <div style="background:#1a1a24; border: 1px solid #333; padding: 15px; border-radius: 5px; margin-bottom: 10px;">
110
- <p style="color:#8892b0; margin:0; font-size:12px; font-family: monospace;">SINGLE MODEL ACTIVATED: {model_name.upper()}</p>
111
- <h3 style="color:{dir_color}; margin: 5px 0;">DIRECTION: {dir_text}</h3>
112
- <p style="color:#a8b2d1; margin:0; font-size:13px;">Expected Return: <b>{active_pred*100:+.2f}%</b></p>
113
- </div>
114
- """
 
 
 
 
115
 
116
  # 4. Market Context Panel (Technical Stats)
117
  rsi_color = "#ff3333" if rsi_val > 70 else ("#00ff00" if rsi_val < 30 else "#a8b2d1")
@@ -134,7 +151,7 @@ def generate_quant_dashboard(ticker, model_name):
134
  # 5. Vẽ biểu đồ Plotly cấp độ Institutional
135
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
136
  vertical_spacing=0.03, row_heights=[0.75, 0.25],
137
- subplot_titles=(f"{ticker} - PRICE ACTION & PROJECTIONS", "VOLUME"))
138
 
139
  # Candlestick
140
  fig.add_trace(go.Candlestick(
@@ -150,33 +167,27 @@ def generate_quant_dashboard(ticker, model_name):
150
  # Volume subplot
151
  colors = ['#00ff00' if row['Close'] >= row['Open'] else '#ff3333' for _, row in df_ui.iterrows()]
152
  fig.add_trace(go.Bar(x=df_ui['Date'], y=df_ui['Volume'], marker_color=colors, name='Volume'), row=2, col=1)
153
-
154
  # --- Thêm điểm dự báo và Confidence Interval (Error Bands) dựa trên ATR ---
155
  if model_name in ["Linear Regression", "Cả Hai"]:
156
- fig.add_trace(go.Scatter(
157
- x=[df_ui['Date'].iloc[-1], next_day], y=[last_close, price_lr],
158
- mode='lines+markers', name='LR Target',
159
- line=dict(color='#ff00ff', dash='dot'), marker=dict(size=10, symbol='diamond')
160
- ))
161
 
162
  if model_name in ["SVR", "Cả Hai"]:
163
- fig.add_trace(go.Scatter(
164
- x=[df_ui['Date'].iloc[-1], next_day], y=[last_close, price_svr],
165
- mode='lines+markers', name='SVR Target',
166
- line=dict(color='#00ffff', dash='dot'), marker=dict(size=10, symbol='diamond')
167
- ))
168
-
169
  # Error Band (±1 ATR cho mức target)
170
- fig.add_trace(go.Scatter(
171
- x=[next_day, next_day],
172
- y=[target_price - atr_val, target_price + atr_val],
173
- mode='lines', name='±1 ATR Volatility Band',
174
- line=dict(color='rgba(255, 255, 255, 0.4)', width=5)
175
- ))
176
 
177
  # Tối ưu giao diện Plotly Dark Mode
178
  fig.update_layout(
179
- height=700,
180
  template="plotly_dark",
181
  plot_bgcolor='#0d0d14', paper_bgcolor='#0d0d14',
182
  margin=dict(l=40, r=40, t=40, b=40),
@@ -200,7 +211,7 @@ body { background-color: #0d0d14; color: #e6e6fa; font-family: 'Inter', sans-ser
200
  with gr.Blocks(title="Quant Terminal | Stock ML", css=css, theme=gr.themes.Monochrome()) as demo:
201
  gr.Markdown("""
202
  <div style="padding: 10px 0; border-bottom: 2px solid #333;">
203
- <h1 style="color: #e6e6fa; margin: 0; font-family: monospace;">⚡ QUANTRONIC ML TERMINAL </h1>
204
  <p style="color: #8892b0; margin: 0; font-family: monospace;">SVR & Ridge Regression Predictive Analytics Engine</p>
205
  </div>
206
  """)
@@ -220,7 +231,7 @@ with gr.Blocks(title="Quant Terminal | Stock ML", css=css, theme=gr.themes.Monoc
220
 
221
  # MAIN AREA (Charts)
222
  with gr.Column(scale=3):
223
- plot_chart = gr.Plot()
224
 
225
  btn_predict.click(
226
  fn=generate_quant_dashboard,
 
6
  import yfinance as yf
7
  from datetime import datetime, timedelta
8
 
9
+ from src.inference import predict_horizons
10
 
11
  # --- CÁC HÀM HỖ TRỢ TÍNH TOÁN KỸ THUẬT CHO UI ---
12
  def calculate_ui_technical_indicators(df):
 
42
  def generate_quant_dashboard(ticker, model_name):
43
  try:
44
  # 1. Gọi Inference Engine
45
+ preds, last_close, last_date, _ = predict_horizons(ticker, model_name)
46
 
47
  # 2. Lấy dữ liệu OHLCV 90 ngày để vẽ Candlestick & tính toán Context
48
  # Sử dụng yfinance trực tiếp để render UI mượt mà, độc lập với backend load_data nặng nề
 
60
  rsi_val = last_row['RSI_14']
61
  macd_h = last_row['MACD_Hist']
62
 
63
+ base_date = pd.to_datetime(last_date)
64
+ dates_future = {
65
+ 1: base_date + pd.offsets.BDay(1),
66
+ 7: base_date + pd.offsets.BDay(7),
67
+ 21: base_date + pd.offsets.BDay(21)
68
+ }
69
 
70
  except Exception as e:
71
  error_html = f"""<div style='background-color:#3a1010; padding:15px; border-left: 4px solid #ff4d4d; color: #ff8080;'>
 
81
  consensus_html = ""
82
  target_price = 0
83
 
84
+ def get_avg_price(h):
85
+ if model_name == "Cả Hai":
86
+ return (preds[h]["Linear Regression"]["pred_price"] + preds[h]["SVR"]["pred_price"]) / 2
87
+ else:
88
+ return preds[h][model_name]["pred_price"]
89
+ target_1d = get_avg_price(1)
90
+ target_7d = get_avg_price(7)
91
+ target_21d = get_avg_price(21)
92
  if model_name == "Cả Hai":
93
  target_price = (price_lr + price_svr) / 2
94
  spread_bps = abs(pred_lr - pred_svr) * 10000 # Basis points
 
103
  direction = "UNCERTAIN / CHOPPY ⚠️"
104
 
105
  consensus_html = f"""
106
+ <div style="background:#1a1a24; border: 1px solid #333; padding: 15px; border-radius: 5px; font-family: monospace;">
107
+ <p style="color:#8892b0; margin:0 0 10px 0; font-size:12px;">LAST CLOSE: {last_date}</p>
108
+ <h2 style="color:white; margin:0 0 15px 0; border-bottom: 1px solid #333; padding-bottom: 10px;">${last_close:.2f}</h2>
109
+ <table style="width: 100%; color: #a8b2d1; font-size: 13px;">
110
+ <tr><td style="padding: 4px 0;">Target T+1 (Day)</td><td style="text-align: right; font-weight: bold; color: #ffd700;">${target_1d:.2f}</td></tr>
111
+ <tr><td style="padding: 4px 0;">Target T+7 (Week)</td><td style="text-align: right; font-weight: bold; color: #ffaa00;">${target_7d:.2f}</td></tr>
112
+ <tr><td style="padding: 4px 0;">Target T+21 (Month)</td><td style="text-align: right; font-weight: bold; color: #ff5500;">${target_21d:.2f}</td></tr>
113
+ </table>
114
+ </div>
115
+ """
116
  else:
117
  active_pred = pred_lr if model_name == "Linear Regression" else pred_svr
118
  target_price = price_lr if model_name == "Linear Regression" else price_svr
119
  dir_color = "#00ff00" if active_pred > 0 else "#ff3333"
120
  dir_text = "BULLISH 📈" if active_pred > 0 else "BEARISH 📉"
121
  consensus_html = f"""
122
+ <div style="background:#1a1a24; border: 1px solid #333; padding: 15px; border-radius: 5px; font-family: monospace;">
123
+ <p style="color:#8892b0; margin:0 0 10px 0; font-size:12px;">LAST CLOSE: {last_date}</p>
124
+ <h2 style="color:white; margin:0 0 15px 0; border-bottom: 1px solid #333; padding-bottom: 10px;">${last_close:.2f}</h2>
125
+ <table style="width: 100%; color: #a8b2d1; font-size: 13px;">
126
+ <tr><td style="padding: 4px 0;">Target T+1 (Day)</td><td style="text-align: right; font-weight: bold; color: #ffd700;">${target_1d:.2f}</td></tr>
127
+ <tr><td style="padding: 4px 0;">Target T+7 (Week)</td><td style="text-align: right; font-weight: bold; color: #ffaa00;">${target_7d:.2f}</td></tr>
128
+ <tr><td style="padding: 4px 0;">Target T+21 (Month)</td><td style="text-align: right; font-weight: bold; color: #ff5500;">${target_21d:.2f}</td></tr>
129
+ </table>
130
+ </div>
131
+ """
132
 
133
  # 4. Market Context Panel (Technical Stats)
134
  rsi_color = "#ff3333" if rsi_val > 70 else ("#00ff00" if rsi_val < 30 else "#a8b2d1")
 
151
  # 5. Vẽ biểu đồ Plotly cấp độ Institutional
152
  fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
153
  vertical_spacing=0.03, row_heights=[0.75, 0.25],
154
+ subplot_titles=(f"{ticker} - MULTI-HORIZON PROJECTIONS", "VOLUME"))
155
 
156
  # Candlestick
157
  fig.add_trace(go.Candlestick(
 
167
  # Volume subplot
168
  colors = ['#00ff00' if row['Close'] >= row['Open'] else '#ff3333' for _, row in df_ui.iterrows()]
169
  fig.add_trace(go.Bar(x=df_ui['Date'], y=df_ui['Volume'], marker_color=colors, name='Volume'), row=2, col=1)
170
+ x_future = [base_date, dates_future[1], dates_future[7], dates_future[21]]
171
  # --- Thêm điểm dự báo và Confidence Interval (Error Bands) dựa trên ATR ---
172
  if model_name in ["Linear Regression", "Cả Hai"]:
173
+ y_lr = [last_close, preds[1]["Linear Regression"]["pred_price"],
174
+ preds[7]["Linear Regression"]["pred_price"], preds[21]["Linear Regression"]["pred_price"]]
175
+ fig.add_trace(go.Scatter(x=x_future, y=y_lr, mode='lines+markers', name='LR Trajectory',
176
+ line=dict(color='#ff00ff', dash='dot'), marker=dict(size=8, symbol='diamond')), row=1, col=1)
 
177
 
178
  if model_name in ["SVR", "Cả Hai"]:
179
+ y_svr = [last_close, preds[1]["SVR"]["pred_price"],
180
+ preds[7]["SVR"]["pred_price"], preds[21]["SVR"]["pred_price"]]
181
+ fig.add_trace(go.Scatter(x=x_future, y=y_svr, mode='lines+markers', name='SVR Trajectory',
182
+ line=dict(color='#00ffff', dash='dot'), marker=dict(size=8, symbol='diamond')), row=1, col=1)
183
+ upper_band = [last_close, target_1d + atr_val*np.sqrt(1), target_7d + atr_val*np.sqrt(7), target_21d + atr_val*np.sqrt(21)]
184
+ lower_band = [last_close, target_1d - atr_val*np.sqrt(1), target_7d - atr_val*np.sqrt(7), target_21d - atr_val*np.sqrt(21)]
185
  # Error Band (±1 ATR cho mức target)
186
+ fig.add_trace(go.Scatter(x=x_future, y=upper_band, mode='lines', name='Risk Cone Upper', line=dict(color='rgba(255, 255, 255, 0.2)')), row=1, col=1)
187
+ fig.add_trace(go.Scatter(x=x_future, y=lower_band, mode='lines', fill='tonexty', fillcolor='rgba(255, 255, 255, 0.05)', name='Risk Cone Lower', line=dict(color='rgba(255, 255, 255, 0.2)')), row=1, col=1)
 
 
 
 
188
 
189
  # Tối ưu giao diện Plotly Dark Mode
190
  fig.update_layout(
 
191
  template="plotly_dark",
192
  plot_bgcolor='#0d0d14', paper_bgcolor='#0d0d14',
193
  margin=dict(l=40, r=40, t=40, b=40),
 
211
  with gr.Blocks(title="Quant Terminal | Stock ML", css=css, theme=gr.themes.Monochrome()) as demo:
212
  gr.Markdown("""
213
  <div style="padding: 10px 0; border-bottom: 2px solid #333;">
214
+ <h1 style="color: #e6e6fa; margin: 0; font-family: monospace;">⚡ QUANTRONIC ML TERMINAL v2.0</h1>
215
  <p style="color: #8892b0; margin: 0; font-family: monospace;">SVR & Ridge Regression Predictive Analytics Engine</p>
216
  </div>
217
  """)
 
231
 
232
  # MAIN AREA (Charts)
233
  with gr.Column(scale=3):
234
+ plot_chart = gr.Plot(height=700)
235
 
236
  btn_predict.click(
237
  fn=generate_quant_dashboard,
src/data_processing.py CHANGED
@@ -88,7 +88,7 @@ def validate_data(df, stage="pre_feature"):
88
  print(f"Validation passed at {stage} (no critical issues).")
89
  return df
90
 
91
- def generate_technical_features(df, is_inference=False):
92
  """
93
  Feature Engineering hoàn toàn mới theo 5 yêu cầu:
94
  1. Corporate actions đã được xử lý ở load_data (auto_adjust=True)
@@ -181,17 +181,17 @@ def generate_technical_features(df, is_inference=False):
181
  data = pd.concat(data_list, ignore_index=True)
182
 
183
  if not is_inference:
184
- data['Target_Return'] = data.groupby('Ticker')['Daily_Return'].shift(-1)
185
  data = data.dropna().reset_index(drop=True)
186
  # === 5. DATA VALIDATION TRƯỚC KHI TRẢ VỀ ===
187
- data = validate_data(data, stage="post_feature_engineering")
188
 
189
  df_backtest = data.copy()
190
  drop_cols = ['Date', 'Ticker', 'Market_Close', 'Target_Return']
191
  X = data.drop(columns=drop_cols, errors='ignore')
192
  y = data['Target_Return'].copy()
193
 
194
- print(f"Generated stationary features & prepared ML data:\n"
195
  f" • Total rows: {len(data)} | Tickers: {data['Ticker'].nunique()}\n"
196
  f" • Features: {X.shape[1]} | X shape: {X.shape} | y shape: {y.shape}")
197
 
 
88
  print(f"Validation passed at {stage} (no critical issues).")
89
  return df
90
 
91
+ def generate_technical_features(df, is_inference=False, target_horizon=1):
92
  """
93
  Feature Engineering hoàn toàn mới theo 5 yêu cầu:
94
  1. Corporate actions đã được xử lý ở load_data (auto_adjust=True)
 
181
  data = pd.concat(data_list, ignore_index=True)
182
 
183
  if not is_inference:
184
+ data['Target_Return'] = data.groupby('Ticker')['Close'].shift(-target_horizon) / data['Close'] - 1
185
  data = data.dropna().reset_index(drop=True)
186
  # === 5. DATA VALIDATION TRƯỚC KHI TRẢ VỀ ===
187
+ data = validate_data(data, f"post_feature_engineering_h{target_horizon}")
188
 
189
  df_backtest = data.copy()
190
  drop_cols = ['Date', 'Ticker', 'Market_Close', 'Target_Return']
191
  X = data.drop(columns=drop_cols, errors='ignore')
192
  y = data['Target_Return'].copy()
193
 
194
+ print(f"Generated data for Horizon {target_horizon} days:\n"
195
  f" • Total rows: {len(data)} | Tickers: {data['Ticker'].nunique()}\n"
196
  f" • Features: {X.shape[1]} | X shape: {X.shape} | y shape: {y.shape}")
197
 
src/inference.py CHANGED
@@ -7,7 +7,7 @@ from src.data_processing import load_data, clean_data, generate_technical_featur
7
 
8
  REPO_ID = "Reality8081/Predict_Stock_SVR_Linear" # << THAY ĐỔI DÒNG NÀY TƯƠNG TỰ
9
  MARKET_SYMBOL = "^GSPC"
10
-
11
  # Tự động tải models từ Hugging Face nếu chưa có tại local
12
  def download_model_if_not_exists(filename):
13
  local_path = os.path.join("models", filename)
@@ -17,8 +17,7 @@ def download_model_if_not_exists(filename):
17
  return path
18
  return local_path
19
 
20
- def predict_next_day(ticker, model_name):
21
- # Lấy data 150 ngày gần nhất để tính đủ các window (SMA 100 cần ít nhất 100 nến)
22
  end_date = datetime.now()
23
  start_date = end_date - timedelta(days=150)
24
 
@@ -26,36 +25,33 @@ def predict_next_day(ticker, model_name):
26
  df_clean = clean_data(df_raw)
27
  df_features, X, _ = generate_technical_features(df_clean, is_inference=True)
28
 
29
- if len(X) == 0:
30
- raise ValueError(f"Không đủ dữ liệu cho {ticker} để tạo đặc trưng.")
31
 
32
- # Lấy dòng cuối cùng (ngày giao dịch gần nhất)
33
  latest_X = X.iloc[[-1]]
34
  latest_data = df_features.iloc[-1]
35
  last_close = latest_data['Close']
36
  last_date = latest_data['Date'].strftime('%Y-%m-%d')
37
 
38
- predictions = {}
39
 
40
- if model_name in ["Linear Regression", "Cả Hai"]:
41
- scaler_lr = joblib.load(download_model_if_not_exists('scaler_Linear.pkl'))
42
- model_lr = joblib.load(download_model_if_not_exists('trained_model_Linear.pkl'))
43
- pred_return_lr = model_lr.predict(scaler_lr.transform(latest_X))[0]
44
- predictions["Linear Regression"] = {
45
- "pred_return": pred_return_lr,
46
- "pred_price": last_close * (1 + pred_return_lr)
47
- }
48
-
49
- if model_name in ["SVR", "Cả Hai"]:
50
- scaler_svr = joblib.load(download_model_if_not_exists('scaler_SVR.pkl'))
51
- model_svr = joblib.load(download_model_if_not_exists('trained_model_SVR.pkl'))
52
- pred_return_svr = model_svr.predict(scaler_svr.transform(latest_X))[0]
53
- predictions["SVR"] = {
54
- "pred_return": pred_return_svr,
55
- "pred_price": last_close * (1 + pred_return_svr)
56
- }
 
57
 
58
- # Lịch sử giá 30 phiên để vẽ biểu đồ
59
  historical_30 = df_features[['Date', 'Close']].tail(30)
60
-
61
  return predictions, last_close, last_date, historical_30
 
7
 
8
  REPO_ID = "Reality8081/Predict_Stock_SVR_Linear" # << THAY ĐỔI DÒNG NÀY TƯƠNG TỰ
9
  MARKET_SYMBOL = "^GSPC"
10
+ HORIZONS = [1, 7, 21]
11
  # Tự động tải models từ Hugging Face nếu chưa có tại local
12
  def download_model_if_not_exists(filename):
13
  local_path = os.path.join("models", filename)
 
17
  return path
18
  return local_path
19
 
20
+ def predict_horizons(ticker, model_name):
 
21
  end_date = datetime.now()
22
  start_date = end_date - timedelta(days=150)
23
 
 
25
  df_clean = clean_data(df_raw)
26
  df_features, X, _ = generate_technical_features(df_clean, is_inference=True)
27
 
28
+ if len(X) == 0: raise ValueError(f"Không đủ dữ liệu cho {ticker}.")
 
29
 
 
30
  latest_X = X.iloc[[-1]]
31
  latest_data = df_features.iloc[-1]
32
  last_close = latest_data['Close']
33
  last_date = latest_data['Date'].strftime('%Y-%m-%d')
34
 
35
+ predictions = {1: {}, 7: {}, 21: {}}
36
 
37
+ for h in HORIZONS:
38
+ if model_name in ["Linear Regression", "Cả Hai"]:
39
+ scaler_lr = joblib.load(download_model_if_not_exists(f'scaler_lr_{h}d.pkl'))
40
+ model_lr = joblib.load(download_model_if_not_exists(f'model_lr_{h}d.pkl'))
41
+ pred_return_lr = model_lr.predict(scaler_lr.transform(latest_X))[0]
42
+ predictions[h]["Linear Regression"] = {
43
+ "pred_return": pred_return_lr,
44
+ "pred_price": last_close * (1 + pred_return_lr)
45
+ }
46
+
47
+ if model_name in ["SVR", "Cả Hai"]:
48
+ scaler_svr = joblib.load(download_model_if_not_exists(f'scaler_svr_{h}d.pkl'))
49
+ model_svr = joblib.load(download_model_if_not_exists(f'model_svr_{h}d.pkl'))
50
+ pred_return_svr = model_svr.predict(scaler_svr.transform(latest_X))[0]
51
+ predictions[h]["SVR"] = {
52
+ "pred_return": pred_return_svr,
53
+ "pred_price": last_close * (1 + pred_return_svr)
54
+ }
55
 
 
56
  historical_30 = df_features[['Date', 'Close']].tail(30)
 
57
  return predictions, last_close, last_date, historical_30
src/train.py CHANGED
@@ -17,104 +17,105 @@ MARKET_SYMBOL = "^GSPC"
17
  START_DATE = "2010-01-01"
18
  END_DATE = datetime.now().strftime('%Y-%m-%d')
19
  REPO_ID = "Reality8081/Predict_Stock_SVR_Linear" # << THAY ĐỔI DÒNG NÀY
 
20
 
21
  def main():
22
  print("1. Đang tải và làm sạch dữ liệu...")
23
  df_raw = load_data(SYMBOLS, MARKET_SYMBOL, START_DATE, END_DATE)
24
  df_clean = clean_data(df_raw)
25
-
26
- print("2. Tạo đặc trưng (Features)...")
27
- _, X, y = generate_technical_features(df_clean, is_inference=False)
28
-
29
- tscv = TimeSeriesSplit(n_splits=5)
30
-
31
- # === TỐI ƯU LINEAR REGRESSION (RIDGE) ===
32
- print("3. Tối ưu siêu tham số Ridge Regression...")
33
- def objective_lr(trial):
34
- alpha = trial.suggest_float('alpha', 1e-4, 1e4, log=True)
35
 
36
  tscv = TimeSeriesSplit(n_splits=5)
37
- fold_scores = []
38
 
39
- for train_idx, val_idx in tscv.split(X):
40
- X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
41
- y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
 
42
 
43
- scaler = StandardScaler()
44
- X_train_scaled = scaler.fit_transform(X_train)
45
- X_val_scaled = scaler.transform(X_val)
46
 
47
- model = Ridge(alpha=alpha, random_state=42)
48
- model.fit(X_train_scaled, y_train)
49
- preds = model.predict(X_val_scaled)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
- rmse = np.sqrt(mean_squared_error(y_val, preds))
52
- fold_scores.append(rmse)
53
-
54
- return np.mean(fold_scores)
55
 
56
- study_lr = optuna.create_study(direction='minimize')
57
- study_lr.optimize(objective_lr, n_trials=20)
58
- best_alpha = study_lr.best_params['alpha']
59
 
60
- # === TỐI ƯU SVR ===
61
- print("4. Tối ưu siêu tham số SVR...")
62
- def objective_svr(trial):
63
- # Chỉ tối ưu siêu tham số SVR
64
- kernel = trial.suggest_categorical('kernel', ['linear', 'rbf'])
65
- C = trial.suggest_float('C', 1e-3, 100.0, log=True)
66
- epsilon = trial.suggest_float('epsilon', 1e-3, 1.0, log=True)
67
- gamma = trial.suggest_categorical('gamma', ['scale', 'auto']) if kernel == 'rbf' else 'scale'
68
-
69
- # Chuẩn bị data với feature cố định
70
-
71
- tscv = TimeSeriesSplit(n_splits=5)
72
- fold_scores = []
73
-
74
- for train_idx, val_idx in tscv.split(X):
75
- X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
76
- y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
77
 
78
- scaler = StandardScaler()
79
- X_train_scaled = scaler.fit_transform(X_train)
80
- X_val_scaled = scaler.transform(X_val)
81
 
82
- X_train_scaled = X_train_scaled.astype('float32')
83
- X_val_scaled = X_val_scaled.astype('float32')
84
- y_train_f32 = y_train.values.astype('float32')
85
- y_val_f32 = y_val.values.astype('float32')
86
 
87
- model = SVR(kernel=kernel, C=C, epsilon=epsilon, gamma=gamma, max_iter=5000)
88
- model.fit(X_train_scaled, y_train)
89
- preds = model.predict(X_val_scaled)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- rmse = np.sqrt(mean_squared_error(y_val, preds))
92
- fold_scores.append(rmse)
93
-
94
- return np.mean(fold_scores)
95
 
96
- study_svr = optuna.create_study(direction='minimize')
97
- study_svr.optimize(objective_svr, n_trials=10) # Set số trial vừa phải
98
-
99
- # === HUẤN LUYỆN MODEL CUỐI CÙNG & LƯU LẠI ===
100
- print("5. Huấn luyện mô hình cuối và lưu trữ...")
101
- os.makedirs("models", exist_ok=True)
102
-
103
- # Ridge
104
- scaler_lr = StandardScaler()
105
- X_scaled_lr = scaler_lr.fit_transform(X)
106
- model_lr = Ridge(alpha=best_alpha, random_state=42)
107
- model_lr.fit(X_scaled_lr, y)
108
- joblib.dump(scaler_lr, 'models/scaler_lr.pkl')
109
- joblib.dump(model_lr, 'models/model_lr.pkl')
110
-
111
- # SVR
112
- scaler_svr = StandardScaler()
113
- X_scaled_svr = scaler_svr.fit_transform(X)
114
- model_svr = SVR(kernel='rbf', C=study_svr.best_params['C'], epsilon=study_svr.best_params['epsilon'], gamma='scale')
115
- model_svr.fit(X_scaled_svr, y)
116
- joblib.dump(scaler_svr, 'models/scaler_svr.pkl')
117
- joblib.dump(model_svr, 'models/model_svr.pkl')
118
 
119
  print("6. Tải mô hình lên Hugging Face Hub...")
120
  hf_token = os.environ.get("HF_TOKEN")
 
17
  START_DATE = "2010-01-01"
18
  END_DATE = datetime.now().strftime('%Y-%m-%d')
19
  REPO_ID = "Reality8081/Predict_Stock_SVR_Linear" # << THAY ĐỔI DÒNG NÀY
20
+ HORIZONS = [1, 7, 21]
21
 
22
  def main():
23
  print("1. Đang tải và làm sạch dữ liệu...")
24
  df_raw = load_data(SYMBOLS, MARKET_SYMBOL, START_DATE, END_DATE)
25
  df_clean = clean_data(df_raw)
26
+ os.makedirs("models", exist_ok=True)
27
+ for h in HORIZONS:
28
+ print("2. Tạo đặc trưng (Features)...")
29
+ _, X, y = generate_technical_features(df_clean, is_inference=False, target_horizon=h)
 
 
 
 
 
 
30
 
31
  tscv = TimeSeriesSplit(n_splits=5)
 
32
 
33
+ # === TỐI ƯU LINEAR REGRESSION (RIDGE) ===
34
+ print("3. Tối ưu siêu tham số Ridge Regression...")
35
+ def objective_lr(trial):
36
+ alpha = trial.suggest_float('alpha', 1e-4, 1e4, log=True)
37
 
38
+ tscv = TimeSeriesSplit(n_splits=5)
39
+ fold_scores = []
 
40
 
41
+ for train_idx, val_idx in tscv.split(X):
42
+ X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
43
+ y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
44
+
45
+ scaler = StandardScaler()
46
+ X_train_scaled = scaler.fit_transform(X_train)
47
+ X_val_scaled = scaler.transform(X_val)
48
+
49
+ model = Ridge(alpha=alpha, random_state=42)
50
+ model.fit(X_train_scaled, y_train)
51
+ preds = model.predict(X_val_scaled)
52
+
53
+ rmse = np.sqrt(mean_squared_error(y_val, preds))
54
+ fold_scores.append(rmse)
55
 
56
+ return np.mean(fold_scores)
 
 
 
57
 
58
+ study_lr = optuna.create_study(direction='minimize')
59
+ study_lr.optimize(objective_lr, n_trials=20)
60
+ best_alpha = study_lr.best_params['alpha']
61
 
62
+ # === TỐI ƯU SVR ===
63
+ print("4. Tối ưu siêu tham số SVR...")
64
+ def objective_svr(trial):
65
+ # Chỉ tối ưu siêu tham số SVR
66
+ kernel = trial.suggest_categorical('kernel', ['linear', 'rbf'])
67
+ C = trial.suggest_float('C', 1e-3, 100.0, log=True)
68
+ epsilon = trial.suggest_float('epsilon', 1e-3, 1.0, log=True)
69
+ gamma = trial.suggest_categorical('gamma', ['scale', 'auto']) if kernel == 'rbf' else 'scale'
 
 
 
 
 
 
 
 
 
70
 
71
+ # Chuẩn bị data với feature cố định
 
 
72
 
73
+ tscv = TimeSeriesSplit(n_splits=5)
74
+ fold_scores = []
 
 
75
 
76
+ for train_idx, val_idx in tscv.split(X):
77
+ X_train, X_val = X.iloc[train_idx], X.iloc[val_idx]
78
+ y_train, y_val = y.iloc[train_idx], y.iloc[val_idx]
79
+
80
+ scaler = StandardScaler()
81
+ X_train_scaled = scaler.fit_transform(X_train)
82
+ X_val_scaled = scaler.transform(X_val)
83
+
84
+ X_train_scaled = X_train_scaled.astype('float32')
85
+ X_val_scaled = X_val_scaled.astype('float32')
86
+ y_train_f32 = y_train.values.astype('float32')
87
+ y_val_f32 = y_val.values.astype('float32')
88
+
89
+ model = SVR(kernel=kernel, C=C, epsilon=epsilon, gamma=gamma, max_iter=5000)
90
+ model.fit(X_train_scaled, y_train)
91
+ preds = model.predict(X_val_scaled)
92
+
93
+ rmse = np.sqrt(mean_squared_error(y_val, preds))
94
+ fold_scores.append(rmse)
95
 
96
+ return np.mean(fold_scores)
 
 
 
97
 
98
+ study_svr = optuna.create_study(direction='minimize')
99
+ study_svr.optimize(objective_svr, n_trials=10) # Set số trial vừa phải
100
+
101
+ # === HUẤN LUYỆN MODEL CUỐI CÙNG & LƯU LẠI ===
102
+ print("5. Huấn luyện mô hình cuối và lưu trữ...")
103
+ os.makedirs("models", exist_ok=True)
104
+
105
+ # Ridge
106
+ scaler_lr = StandardScaler()
107
+ X_scaled_lr = scaler_lr.fit_transform(X)
108
+ model_lr = Ridge(alpha=best_alpha, random_state=42)
109
+ model_lr.fit(X_scaled_lr, y)
110
+ joblib.dump(scaler_lr, f'models/scaler_lr_{h}d.pkl')
111
+ joblib.dump(model_lr, f'models/model_lr_{h}d.pkl')
112
+ # SVR
113
+ scaler_svr = StandardScaler()
114
+ X_scaled_svr = scaler_svr.fit_transform(X)
115
+ model_svr = SVR(kernel='rbf', C=study_svr.best_params['C'], epsilon=study_svr.best_params['epsilon'], gamma='scale')
116
+ model_svr.fit(X_scaled_svr, y)
117
+ joblib.dump(scaler_svr, f'models/scaler_svr_{h}d.pkl')
118
+ joblib.dump(model_svr, f'models/model_svr_{h}d.pkl')
 
119
 
120
  print("6. Tải mô hình lên Hugging Face Hub...")
121
  hf_token = os.environ.get("HF_TOKEN")