File size: 2,449 Bytes
d7d2bc1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import yfinance as yf
import pandas as pd
import numpy as np
from catboost import CatBoostRegressor
from sklearn.multioutput import MultiOutputRegressor
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import os

def fetch_data(ticker, period="3mo"):
    """Download recent price history for ``ticker`` and keep the OHLCV columns.

    Args:
        ticker: Symbol forwarded to ``yf.download``.
        period: Look-back period string forwarded to ``yf.download``
            (defaults to ``"3mo"``).

    Returns:
        The downloaded frame restricted to the Open, High, Low, Close and
        Volume columns, with rows containing missing values dropped.
    """
    wanted_columns = ['Open', 'High', 'Low', 'Close', 'Volume']
    history = yf.download(ticker, period=period)
    return history[wanted_columns].dropna()

def prepare_training_data(df, window_size=5, forecast_days=5):
    """Build sliding-window training samples from a price frame.

    Each sample pairs ``window_size`` consecutive rows of ``df`` (all columns,
    flattened row by row) with the ``forecast_days`` values of the ``Close``
    column that immediately follow that window.

    Args:
        df: Frame with a ``Close`` column; every column is used as a feature.
        window_size: Number of consecutive rows in each input window.
        forecast_days: Number of following ``Close`` values used as targets.

    Returns:
        A tuple ``(X, Y)`` of arrays with one row per sample. ``X`` has
        ``window_size * number_of_columns`` values per row and ``Y`` has
        ``forecast_days`` values per row.

    Raises:
        ValueError: If ``window_size`` or ``forecast_days`` is smaller than 1,
            or ``df`` has fewer than ``window_size + forecast_days`` rows.
    """
    if window_size < 1 or forecast_days < 1:
        raise ValueError(
            "window_size and forecast_days must both be at least 1 "
            f"(got window_size={window_size}, forecast_days={forecast_days})"
        )

    # The last valid window starts at len(df) - window_size - forecast_days,
    # so the number of samples is that offset plus one.
    sample_count = len(df) - window_size - forecast_days + 1
    if sample_count <= 0:
        raise ValueError(
            f"Not enough rows ({len(df)}) to build a sample with "
            f"window_size={window_size} and forecast_days={forecast_days}"
        )

    features = df.values
    closes = df['Close'].values

    X = [features[i:i + window_size].flatten() for i in range(sample_count)]
    Y = [
        closes[i + window_size:i + window_size + forecast_days]
        for i in range(sample_count)
    ]
    # Reshape keeps Y two-dimensional even when df['Close'] yields a
    # one-column frame rather than a series.
    return np.array(X), np.array(Y).reshape(sample_count, -1)

def main(ticker, forecast_days):
    """Train a CatBoost model on recent prices and plot a short-term forecast.

    Downloads price history for ``ticker`` with ``fetch_data``, builds
    sliding-window samples with ``prepare_training_data``, standardizes the
    features, fits a ``MultiOutputRegressor`` wrapping ``CatBoostRegressor``,
    and predicts the next ``forecast_days`` closing prices from the most
    recent window. The forecast is plotted; the plot is shown when a
    ``DISPLAY`` environment variable is set and otherwise saved as a PNG
    named ``catboost_<ticker>_forecast.png`` (relative path).

    Args:
        ticker: Symbol to download, e.g. ``"TSLA"``.
        forecast_days: Number of future closing prices to predict.

    Raises:
        ValueError: If ``forecast_days`` is smaller than 1, or the downloaded
            data has too few rows to build at least one training sample.
    """
    # Single source of truth for the window length, so the training samples
    # and the prediction input can never disagree.
    window_size = 10

    if forecast_days < 1:
        raise ValueError(f"forecast_days must be at least 1, got {forecast_days}")

    print(f"下載 {ticker} 股票資料...")
    df = fetch_data(ticker)

    # Fail early with a readable message instead of an opaque reshape/fit
    # error further down when the download is empty or too short.
    if len(df) <= window_size + forecast_days:
        raise ValueError(
            f"Not enough data for {ticker}: got {len(df)} rows, need more than "
            f"{window_size + forecast_days} (window_size={window_size}, "
            f"forecast_days={forecast_days})"
        )

    print("建立訓練資料...")
    X, Y = prepare_training_data(df, window_size=window_size, forecast_days=forecast_days)

    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    print("訓練 CatBoost 模型...")
    model = MultiOutputRegressor(CatBoostRegressor(
        iterations=500,
        learning_rate=0.05,
        depth=6,
        verbose=100
    ))
    model.fit(X_scaled, Y)

    # Predict from the most recent window of rows.
    print(f"使用最新資料預測未來 {forecast_days} 天...")
    latest_input = df.iloc[-window_size:].values.flatten().reshape(1, -1)
    latest_input_scaled = scaler.transform(latest_input)
    forecast = model.predict(latest_input_scaled)[0]

    # Plot the forecast.
    fig = plt.figure(figsize=(10, 5))
    plt.plot(range(1, forecast_days + 1), forecast, label='Predicted', color='orange')
    plt.title(f"{ticker} Forecast for Next {forecast_days} Days (CatBoost)")
    plt.xlabel("Days Ahead")
    plt.ylabel("Predicted Close Price")
    plt.legend()
    plt.grid(True)

    # Show the plot when DISPLAY is set; otherwise write it to a file.
    if "DISPLAY" in os.environ:
        plt.show()
    else:
        filename = f"catboost_{ticker.lower()}_forecast.png"
        plt.savefig(filename)
        print(f"圖已儲存為 {filename}")

    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Command-line entry point: read the ticker symbol and forecast horizon,
    # then run the full download / train / plot pipeline.
    cli = argparse.ArgumentParser(
        description="Forecast future stock price using CatBoost",
    )
    cli.add_argument(
        '--ticker',
        type=str,
        default='TSLA',
        help='股票代碼,例如 TSLA, AAPL',
    )
    cli.add_argument(
        '--days',
        type=int,
        default=5,
        help='預測未來幾天的收盤價',
    )
    options = cli.parse_args()

    main(options.ticker, options.days)