Spaces:
Sleeping
Sleeping
# ======================== (A) 导入库和配置环境 ======================== | |
import pandas as pd | |
import numpy as np | |
import time | |
import warnings | |
import os | |
import logging | |
import tempfile | |
import traceback | |
# 数据分析与建模 | |
from scipy import stats | |
from statsmodels.tsa.stattools import adfuller | |
from statsmodels.tsa.seasonal import STL | |
from statsmodels.stats.diagnostic import acorr_ljungbox | |
from prophet import Prophet | |
import pmdarima as pm | |
from sklearn.metrics import mean_absolute_error, mean_squared_error | |
from joblib import Parallel, delayed | |
# 可视化 | |
import matplotlib.pyplot as plt | |
import seaborn as sns | |
# Web UI | |
import gradio as gr | |
# --- 全局设置 --- | |
warnings.filterwarnings("ignore") | |
logging.getLogger('prophet').setLevel(logging.ERROR) | |
logging.getLogger('cmdstanpy').setLevel(logging.ERROR) | |
# 配置中文字体 (使用 packages.txt 安装的字体) | |
plt.rcParams['font.sans-serif'] = ['WenQuanYi Zen Hei'] | |
plt.rcParams['axes.unicode_minus'] = False | |
# --- 输出文件夹设置 --- | |
OUTPUT_DIR = 'outputs' | |
if not os.path.exists(OUTPUT_DIR): | |
os.makedirs(OUTPUT_DIR) | |
# ======================== (B) 辅助函数 ======================== | |
def calculate_metrics(actual, predicted): | |
metrics_df = pd.DataFrame({'actual': actual, 'predicted': predicted}).dropna() | |
if metrics_df.empty: | |
return {'MAE': np.nan, 'RMSE': np.nan, 'MAPE': np.nan, 'sMAPE': np.nan} | |
clean_actual, clean_predicted = metrics_df['actual'], metrics_df['predicted'] | |
mae = mean_absolute_error(clean_actual, clean_predicted) | |
rmse = np.sqrt(mean_squared_error(clean_actual, clean_predicted)) | |
actual_safe = np.where(clean_actual == 0, 1e-6, clean_actual) | |
mape = np.mean(np.abs((clean_actual - clean_predicted) / actual_safe)) * 100 | |
smape = 200 * np.mean(np.abs(clean_actual - clean_predicted) / (np.abs(clean_actual) + np.abs(clean_predicted))) | |
return {'MAE': mae, 'RMSE': rmse, 'MAPE': mape, 'sMAPE': smape} | |
# ======================== (C) 主分析函数 (Gradio核心) ======================== | |
def run_full_analysis(progress=gr.Progress(track_tqdm=True)): | |
""" | |
这个主函数封装了所有的分析步骤,并通过 yield 返回结果来实时更新Gradio界面。 | |
""" | |
# --- 1. 初始化 --- | |
log_lines = ["## 🚀 数据分析流程已启动..."] | |
figure_paths = [] | |
final_report_text = "" | |
report_file_path = None | |
# 辅助函数,用于将Matplotlib Figure保存为临时图片文件并返回路径 | |
def save_fig_to_path(fig): | |
# 使用 NamedTemporaryFile 来创建一个不会被立即删除的临时文件 | |
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile: | |
fig.savefig(tmpfile.name) | |
figure_paths.append(tmpfile.name) | |
plt.close(fig) # 操作完成后关闭图形,释放内存 | |
# 辅助函数,用于更新UI状态 | |
def update_ui(new_log_line=None): | |
if new_log_line: | |
log_lines.append(new_log_line) | |
# [log, gallery, final_report, download_button] | |
return "\n\n".join(log_lines), figure_paths, final_report_text, report_file_path | |
yield update_ui() # 立即显示启动信息 | |
try: | |
# --- 2. 数据清洗 --- | |
log_lines.append("### 1. 数据清洗") | |
df = pd.read_excel('gmqrkl.xlsx') | |
df['Date'] = pd.to_datetime(df['Date']) | |
df = df.drop_duplicates(subset=['Date']).sort_values('Date').reset_index(drop=True) | |
log_lines.append(f"✅ 数据读取并去重成功,共 {len(df)} 行。") | |
df['Value'] = df['Value'].replace(0, np.nan) | |
df['Value'].interpolate(method='linear', limit_direction='both', inplace=True) | |
log_lines.append("✅ 零值替换与线性插值完成。") | |
df.set_index('Date', inplace=True) | |
ts_data = df['Value'] | |
yield update_ui() | |
# --- 3. 平稳性检验与差分 --- | |
log_lines.append("### 2. 平稳性检验与差分") | |
current_data = ts_data.dropna() | |
adf_result = adfuller(current_data) | |
if adf_result[1] < 0.05: | |
log_lines.append(f"✅ 序列在 d=0 阶差分后达到平稳 (p={adf_result[1]:.4f})。") | |
d_order = 0 | |
else: | |
current_data_diff = current_data.diff().dropna() | |
adf_result_diff = adfuller(current_data_diff) | |
d_order = 1 | |
if adf_result_diff[1] < 0.05: | |
log_lines.append(f"✅ 序列在 d=1 阶差分后达到平稳 (p={adf_result_diff[1]:.4f})。") | |
else: | |
log_lines.append(f"⚠️ 1阶差分后仍未平稳 (p={adf_result_diff[1]:.4f}),将使用 d=1 继续分析。") | |
current_data = current_data_diff | |
ts_stationary = current_data | |
yield update_ui() | |
# --- 4. 白噪声检验 --- | |
log_lines.append("### 3. 白噪声检验") | |
lags = min(10, len(ts_stationary) // 5) | |
lb_test_result = acorr_ljungbox(ts_stationary, lags=[lags], return_df=True) | |
if lb_test_result['lb_pvalue'].iloc[0] > 0.05: | |
log_lines.append(f"⚠️ 序列可能是白噪声(p-value = {lb_test_result['lb_pvalue'].iloc[0]:.4f}),模型可能无效。") | |
else: | |
log_lines.append(f"✅ 通过白噪声检验 (p-value = {lb_test_result['lb_pvalue'].iloc[0]:.4f}),可以进行后续建模。") | |
yield update_ui() | |
# --- 5. 季节性检验与分解 --- | |
log_lines.append("\n### 4. 季节性检验与STL分解") | |
period = 365 | |
seasonal_enabled = len(ts_data) > 2 * 14 # 数据多于两周则开启周季节性 | |
m_period = 7 if seasonal_enabled else 1 | |
log_lines.append(f"✅ 季节性参数设定: m={m_period}, seasonal={seasonal_enabled}") | |
if len(ts_data) >= 2 * period: | |
seasonal_period_for_stl = period if period % 2 != 0 else period + 1 | |
log_lines.append(f"✅ 准备进行STL分解,周期(period)={period},季节平滑窗口(seasonal)={seasonal_period_for_stl}。") | |
yield update_ui() | |
stl = STL(ts_data, period=period, seasonal=seasonal_period_for_stl) | |
res = stl.fit() | |
fig = res.plot() | |
fig.set_size_inches(12, 8) | |
fig.suptitle(f'STL 分解图 (周期={period})', fontsize=16, y=0.98) | |
plt.tight_layout() | |
save_fig_to_path(fig) | |
log_lines.append("✅ STL分解图已生成。") | |
else: | |
log_lines.append("⚠️ 数据长度不足以进行年度季节性分解。") | |
yield update_ui() | |
# --- 6. 混合策略回测优化窗口大小 --- | |
log_lines.append("\n### 5. 优化训练窗口大小") | |
log_lines.append("⏳ **此步骤计算量大,可能需要5-15分钟,请耐心等待...**") | |
yield update_ui() | |
def evaluate_window_hybrid(window_size, time_series, d, m, seasonal): | |
errors = [] | |
series_values = time_series.values | |
backtest_length = 100 | |
if len(series_values) <= window_size + backtest_length: return {'window_size': window_size, 'mae': np.inf} | |
end_index = len(series_values) | |
start_index = end_index - backtest_length | |
for i in range(start_index, end_index): | |
train_window = pd.Series(series_values[i-window_size:i], index=time_series.index[i-window_size:i]) | |
test_point = series_values[i] | |
use_seasonal = seasonal and (len(train_window) >= 2 * m) | |
try: | |
model = pm.auto_arima(train_window, d=d, seasonal=use_seasonal, m=m, max_p=2, max_q=2, | |
stepwise=True, trace=False, error_action='ignore', suppress_warnings=True) | |
forecast = model.predict(n_periods=1)[0] | |
errors.append(test_point - forecast) | |
except Exception: continue | |
if not errors: return {'window_size': window_size, 'mae': np.inf} | |
return {'window_size': window_size, 'mae': np.mean(np.abs(errors))} | |
window_sizes_to_test = np.arange(70, 211, 14) | |
with Parallel(n_jobs=-1) as parallel: | |
results = parallel( | |
delayed(evaluate_window_hybrid)(ws, ts_data, d_order, m_period, seasonal_enabled) for ws in window_sizes_to_test | |
) | |
window_results_df = pd.DataFrame(results).sort_values('mae').set_index('window_size') | |
if not window_results_df.empty and np.isfinite(window_results_df['mae'].min()): | |
best_window_size = window_results_df['mae'].idxmin() | |
log_lines.append(f"✅ **窗口优化完成!** 基于MAE,最佳训练窗口大小为: **{best_window_size}** 天。") | |
else: | |
best_window_size = 90 | |
log_lines.append(f"⚠️ 窗口优化失败,使用默认窗口大小: {best_window_size} 天。") | |
fig, ax = plt.subplots(figsize=(12, 6)) | |
ax.plot(window_results_df.index, window_results_df['mae'], marker='o', label='MAE') | |
ax.set_title('训练窗口大小对预测误差的影响') | |
ax.set_xlabel('训练窗口天数'); ax.set_ylabel('误差值'); ax.legend(); ax.grid(True) | |
save_fig_to_path(fig) | |
yield update_ui() | |
# --- 7 & 8. 动态滚动预测与评估 --- | |
log_lines.append("\n### 6. 动态滚动预测与评估") | |
log_lines.append("⏳ **此步骤同样耗时,正在进行模型滚动预测...**") | |
yield update_ui() | |
split_point_roll = int(len(ts_data) * 0.8) | |
test_rolling_target = ts_data.iloc[split_point_roll:] | |
rolling_predictions = {} | |
# SARIMA 滚动 | |
sarima_rolling_preds = [] | |
for i in progress.tqdm(range(len(test_rolling_target)), desc="SARIMA Rolling Forecast"): | |
train_window = ts_data.iloc[split_point_roll + i - best_window_size : split_point_roll + i] | |
try: | |
model = pm.auto_arima(train_window, d=d_order, m=m_period, seasonal=seasonal_enabled, | |
stepwise=True, trace=False, error_action='ignore', suppress_warnings=True) | |
sarima_rolling_preds.append(model.predict(n_periods=1)[0]) | |
except: | |
sarima_rolling_preds.append(sarima_rolling_preds[-1] if sarima_rolling_preds else np.nan) | |
rolling_predictions['Auto-SARIMA'] = pd.Series(sarima_rolling_preds, index=test_rolling_target.index).ffill() | |
log_lines.append("✅ Auto-SARIMA 滚动预测完成。") | |
yield update_ui() | |
# Prophet 滚动 | |
prophet_rolling_preds = [] | |
prophet_model = None | |
for i, (date, value) in enumerate(progress.tqdm(test_rolling_target.items(), desc="Prophet Rolling Forecast")): | |
if i % 14 == 0 or prophet_model is None: | |
train_upto_date = ts_data.loc[:date - pd.Timedelta(days=1)] | |
prophet_train_df = train_upto_date.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}) | |
prophet_model = Prophet(yearly_seasonality='auto', weekly_seasonality=seasonal_enabled, daily_seasonality=False).fit(prophet_train_df) | |
future_df = pd.DataFrame({'ds': [date]}) | |
forecast = prophet_model.predict(future_df) | |
prophet_rolling_preds.append(forecast['yhat'].iloc[0]) | |
rolling_predictions['Prophet'] = pd.Series(prophet_rolling_preds, index=test_rolling_target.index) | |
log_lines.append("✅ Prophet 滚动预测完成。") | |
# 评估 | |
rolling_evaluation_results = {name: calculate_metrics(test_rolling_target, preds) for name, preds in rolling_predictions.items()} | |
rolling_evaluation_df = pd.DataFrame(rolling_evaluation_results).T.sort_values(by='MAE') | |
best_rolling_model_name = rolling_evaluation_df.index[0] | |
log_lines.append("\n**滚动预测性能对比:**") | |
log_lines.append(f"```\n{rolling_evaluation_df.to_markdown()}\n```") | |
log_lines.append(f"\n==> ✅ 最佳滚动预测模型是: **{best_rolling_model_name}**") | |
yield update_ui() | |
# --- 9. 滚动预测可视化 --- | |
log_lines.append("\n### 7. 生成结果图表") | |
fig, ax = plt.subplots(figsize=(15, 8)) | |
ax.plot(ts_data, label='历史数据', color='gray', alpha=0.5) | |
ax.plot(test_rolling_target, label='真实值 (测试集)', color='blue', linewidth=2) | |
for model_name, preds in rolling_predictions.items(): | |
is_best = ' (最佳)' if model_name == best_rolling_model_name else '' | |
ax.plot(preds.dropna(), label=f'{model_name} 预测{is_best}', linestyle='--') | |
ax.set_title('滚动预测结果对比'); ax.legend(); ax.grid(True) | |
save_fig_to_path(fig) | |
yield update_ui() | |
# --- 10. 最终未来预测 --- | |
forecast_horizon = 90 | |
log_lines.append(f"\n### 8. 使用最佳模型 `{best_rolling_model_name}` 预测未来 {forecast_horizon} 天") | |
# 训练最终模型 | |
if 'Auto-SARIMA' in best_rolling_model_name: | |
final_train_data = ts_data.iloc[-best_window_size:] | |
final_model = pm.auto_arima(final_train_data, d=d_order, m=m_period, seasonal=seasonal_enabled, | |
stepwise=True, trace=False, error_action='ignore', suppress_warnings=True) | |
final_forecast_values, conf_int = final_model.predict(n_periods=forecast_horizon, return_conf_int=True) | |
else: # Prophet | |
final_train_data = ts_data | |
final_prophet_train_df = final_train_data.reset_index().rename(columns={'Date': 'ds', 'Value': 'y'}) | |
final_model = Prophet(yearly_seasonality='auto', weekly_seasonality=seasonal_enabled, daily_seasonality=False).fit(final_prophet_train_df) | |
future_df = final_model.make_future_dataframe(periods=forecast_horizon, freq='D') | |
forecast_obj = final_model.predict(future_df) | |
final_forecast_values = forecast_obj['yhat'].iloc[-forecast_horizon:].values | |
conf_int = np.column_stack((forecast_obj['yhat_lower'].iloc[-forecast_horizon:].values, forecast_obj['yhat_upper'].iloc[-forecast_horizon:].values)) | |
future_dates = pd.date_range(start=ts_data.index[-1] + pd.Timedelta(days=1), periods=forecast_horizon) | |
final_forecast_series = pd.Series(final_forecast_values, index=future_dates) | |
# 可视化最终预测 | |
fig, ax = plt.subplots(figsize=(15, 8)) | |
ax.plot(ts_data.tail(365), label='近期历史数据', color='blue') | |
ax.plot(final_forecast_series, label=f'未来 {forecast_horizon} 天预测', color='red', linestyle='--') | |
ax.fill_between(future_dates, conf_int[:, 0], conf_int[:, 1], color='red', alpha=0.2, label='95% 置信区间') | |
ax.set_title(f'最终未来用量预测 (基于 {best_rolling_model_name})'); ax.legend(); ax.grid(True) | |
save_fig_to_path(fig) | |
# 生成最终报告 | |
final_report_text = f""" | |
# 药品用量预测分析报告 | |
## 1. 数据概览 | |
- **数据时间范围**: {ts_data.index.min().strftime('%Y-%m-%d')} to {ts_data.index.max().strftime('%Y-%m-%d')} | |
- **总数据点**: {len(ts_data)} | |
- **平均用量**: {ts_data.mean():.2f} | |
## 2. 分析与建模参数 | |
- **平稳性差分阶数 (d)**: {d_order} | |
- **季节性周期 (m)**: {m_period} | |
- **最佳训练窗口**: {best_window_size} 天 | |
## 3. 模型评估 (基于动态滚动预测) | |
通过在历史数据上进行滚动预测,我们能更真实地评估模型在实际应用中的表现。 | |
{rolling_evaluation_df.to_markdown()} | |
## 4. 最终结论与未来预测 | |
- **最佳模型**: **{best_rolling_model_name}** 被选为最终预测模型,因为它在滚动预测中表现最佳(MAE最低)。 | |
- **未来预测**: 已使用 `{best_rolling_model_name}` 模型对未来 **{forecast_horizon}** 天的用量进行预测。 | |
- **预测摘要**: | |
- 未来一周平均日用量: **{final_forecast_series.head(7).mean():.2f}** | |
- 未来一月平均日用量: **{final_forecast_series.head(30).mean():.2f}** | |
""".strip() | |
report_file_path = os.path.join(OUTPUT_DIR, 'final_analysis_report.txt') | |
with open(report_file_path, 'w', encoding='utf-8') as f: | |
f.write(final_report_text) | |
log_lines.append("\n## 🎉 全部分析流程完成!请查看最终报告和图表。") | |
yield update_ui() | |
except Exception as e: | |
log_lines.append(f"\n\n❌ **分析过程中断,出现错误:**\n`{str(e)}`") | |
log_lines.append(f"\n**Traceback:**\n```{traceback.format_exc()}```") | |
yield update_ui() | |
# ======================== (D) Gradio 界面构建 ======================== | |
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo: | |
gr.Markdown( | |
""" | |
# 📈 智能时序预测与分析平台 (动态回测版) | |
欢迎使用!请确保您的数据文件 `gmqrkl.xlsx` 已上传至本 Space 的文件库中。 | |
然后,点击下方按钮,启动包含 **窗口优化** 和 **动态滚动预测** 的完整分析流程。 | |
**注意:完整流程计算量大,可能需要10-20分钟。请耐心等待,并观察下方日志区的实时进度。** | |
""" | |
) | |
start_button = gr.Button("🚀 点击这里,开始完整分析", variant="primary") | |
gr.Markdown("---") | |
with gr.Tabs(): | |
with gr.TabItem("📊 可视化图表", id=0): | |
gallery_output = gr.Gallery(label="分析图表", elem_id="gallery", columns=[1], height="auto", object_fit="contain") | |
with gr.TabItem("📝 实时分析日志", id=1): | |
log_output = gr.Markdown("点击按钮后,分析日志将实时显示在这里...") | |
with gr.TabItem("📋 最终报告与下载", id=2): | |
final_report_output = gr.Markdown("分析完成后,最终报告将显示在这里。") | |
download_output = gr.File(label="下载报告文件") | |
start_button.click( | |
fn=run_full_analysis, | |
inputs=None, | |
outputs=[log_output, gallery_output, final_report_output, download_output] | |
) | |
gr.Markdown("<p style='text-align: center; font-size: 12px; color: grey;'>Powered by Gradio and Hugging Face Spaces.</p>") | |
if __name__ == "__main__": | |
demo.launch() |