Spaces:
Sleeping
Sleeping
import streamlit as st | |
import yfinance as yf | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
import pandas as pd | |
import twstock | |
from datetime import datetime, timedelta | |
def plot_stock_data(stock_symbols, period='1y'): | |
""" | |
繪製股票價格圖表 | |
:param stock_symbols: 股票代號列表 | |
:param period: 時間區間 | |
:return: Plotly figure | |
""" | |
# 創建子圖 | |
fig = make_subplots( | |
rows=len(stock_symbols), | |
cols=1, | |
subplot_titles=[f"股價走勢: {symbol}" for symbol in stock_symbols], | |
vertical_spacing=0.05, | |
specs=[[{"secondary_y": True}] for _ in stock_symbols] | |
) | |
# 為每個股票繪製圖形 | |
for idx, symbol in enumerate(stock_symbols, 1): | |
try: | |
# 獲取股票數據 | |
stock = yf.Ticker(symbol) | |
df = stock.history(period=period) | |
if df.empty: | |
st.warning(f"無法找到 {symbol} 的股票數據") | |
continue | |
# 添加蠟燭圖 | |
fig.add_trace( | |
go.Candlestick( | |
x=df.index, | |
open=df['Open'], | |
high=df['High'], | |
low=df['Low'], | |
close=df['Close'], | |
name=f'{symbol} 價格' | |
), | |
row=idx, col=1 | |
) | |
# 添加成交量柱狀圖 | |
fig.add_trace( | |
go.Bar( | |
x=df.index, | |
y=df['Volume'], | |
name=f'{symbol} 成交量', | |
opacity=0.3 | |
), | |
row=idx, col=1, | |
secondary_y=True | |
) | |
# 添加移動平均線 | |
for ma_days in [5, 20, 60]: | |
ma = df['Close'].rolling(window=ma_days).mean() | |
fig.add_trace( | |
go.Scatter( | |
x=df.index, | |
y=ma, | |
name=f'{symbol} MA{ma_days}', | |
line=dict(width=1) | |
), | |
row=idx, col=1 | |
) | |
except Exception as e: | |
st.error(f"處理 {symbol} 時發生錯誤: {str(e)}") | |
# 更新布局 | |
fig.update_layout( | |
height=400 * len(stock_symbols), | |
title_text="台股分析圖", | |
showlegend=True, | |
xaxis_rangeslider_visible=False, | |
template="plotly_white" | |
) | |
# 更新軸標籤 | |
for i in range(1, len(stock_symbols) + 1): | |
fig.update_xaxes(title_text="日期", row=i, col=1) | |
fig.update_yaxes(title_text="價格 (TWD)", row=i, col=1) | |
fig.update_yaxes(title_text="成交量", row=i, col=1, secondary_y=True) | |
return fig | |
def fetch_recent_stock_data(stock_code): | |
""" | |
使用 twstock 獲取近期股票交易數據 | |
""" | |
try: | |
stock = twstock.Stock(stock_code) | |
recent_data = stock.fetch_31() # 抓取最近 31 天的交易數據 | |
if not recent_data: | |
st.warning(f"無法找到 {stock_code} 的交易數據。") | |
return None | |
# 將數據整理為 DataFrame 格式 | |
data_list = [ | |
{ | |
"Date": data.date.strftime('%Y-%m-%d'), | |
"Open": data.open, | |
"High": data.high, | |
"Low": data.low, | |
"Close": data.close, | |
"Transaction": data.transaction, | |
"Capacity": data.capacity, | |
"Turnover": data.turnover | |
} | |
for data in recent_data | |
] | |
df = pd.DataFrame(data_list) | |
return df | |
except Exception as e: | |
st.error(f"發生錯誤: {e}") | |
return None | |
def main(): | |
st.set_page_config(page_title="台股分析工具", page_icon=":chart_with_upwards_trend:", layout="wide") | |
st.title("🚀 台股分析工具") | |
# 側邊欄設置 | |
with st.sidebar: | |
st.header("股票分析設定") | |
# 股票代碼輸入 | |
stock_input = st.text_input( | |
"股票代號 (用逗號分隔)", | |
value="2330.TW,2454.TW", | |
placeholder="例如: 2330.TW,2454.TW" | |
) | |
# 時間區間選擇 | |
period_select = st.selectbox( | |
"選擇時間區間", | |
["1mo", "3mo", "6mo", "1y", "2y", "5y", "max"], | |
index=3 # 預設為 1y | |
) | |
# 股票分析頁籤 | |
tab1, tab2 = st.tabs(["股價走勢圖", "近期交易數據"]) | |
with tab1: | |
# 股價走勢圖 | |
if st.button("繪製股價走勢圖"): | |
# 處理股票代號 | |
stocks = [s.strip() for s in stock_input.split(',')] | |
stocks = [f"{s}.TW" if not s.endswith('.TW') and s.isdigit() else s for s in stocks] | |
# 創建圖表 | |
fig = plot_stock_data(stocks, period_select) | |
st.plotly_chart(fig, use_container_width=True) | |
with tab2: | |
# 近期交易數據 | |
st.subheader("個股近期交易數據") | |
single_stock_code = st.text_input( | |
"請輸入股票代碼", | |
placeholder="例如: 2330" | |
) | |
if st.button("查詢交易數據"): | |
if single_stock_code: | |
# 獲取近期股票數據 | |
df = fetch_recent_stock_data(single_stock_code) | |
if df is not None: | |
# 顯示數據 | |
st.dataframe(df) | |
# 統計資訊 | |
st.subheader("基本統計") | |
col1, col2, col3 = st.columns(3) | |
with col1: | |
st.metric("平均收盤價", f"{df['Close'].mean():.2f}") | |
with col2: | |
st.metric("最高價", f"{df['High'].max():.2f}") | |
with col3: | |
st.metric("最低價", f"{df['Low'].min():.2f}") | |
# 匯出 CSV | |
csv_data = df.to_csv(index=False).encode('utf-8-sig') | |
st.download_button( | |
label="下載CSV", | |
data=csv_data, | |
file_name=f"{single_stock_code}_recent_30days.csv", | |
mime="text/csv" | |
) | |
if __name__ == "__main__": | |
main() |