Commit ·
c05a501
1
Parent(s): 4c7c808
add flux map analysis script and configuration for detecting and tracking active regions
Browse files- README.md +10 -3
- analysis/ablation_analysis.py +126 -0
- analysis/ablation_lollipop.py +162 -0
- analysis/flux_map_analysis.py +916 -0
- analysis/flux_map_config.yaml +67 -0
- analysis/spatial_performance.py +1 -1
- forecasting/inference/flare_analysis.py +0 -0
- forecasting/inference/flare_analysis_poster.py +0 -0
- pipeline_config.yaml +18 -1
- run_pipeline.py +17 -17
README.md
CHANGED
|
@@ -44,6 +44,11 @@ The solar soft X-ray (SXR) irradiance is a long-standing proxy of solar activity
|
|
| 44 |
|
| 45 |
```text
|
| 46 |
FOXES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
├── data # Data cleaning and preprocessing
|
| 48 |
│ ├── align_data.py # Align AIA and SXR timestamps; save matched pairs
|
| 49 |
│ ├── euv_data_cleaning.py # EUV image quality filtering and cleaning
|
|
@@ -67,8 +72,7 @@ FOXES
|
|
| 67 |
│ ├── inference
|
| 68 |
│ │ ├── inference.py # Batch inference; writes predictions.csv
|
| 69 |
│ │ ├── evaluation.py # Compute metrics and generate evaluation plots
|
| 70 |
-
│ │ ├──
|
| 71 |
-
│ │ ├── local_config.yaml # Config for inference.py and flare_analysis.py
|
| 72 |
│ │ └── evaluation_config.yaml # Config for evaluation.py
|
| 73 |
│ ├── models
|
| 74 |
│ │ └── vit_patch_model_local.py # ViTLocal: Vision Transformer with patch flux heads
|
|
@@ -126,7 +130,9 @@ FOXES uses a single orchestrator script (`run_pipeline.py`) and a top-level conf
|
|
| 126 |
| 7 | `train` | Train the ViTLocal solar flare forecasting model |
|
| 127 |
| 8 | `inference` | Run batch inference and save a predictions CSV |
|
| 128 |
| 9 | `evaluate` | Compute metrics and generate evaluation plots |
|
| 129 |
-
| 10 | `
|
|
|
|
|
|
|
| 130 |
|
| 131 |
### Usage
|
| 132 |
|
|
@@ -245,6 +251,7 @@ Steps can also be run individually by calling their scripts directly:
|
|
| 245 |
python forecasting/training/train.py -config forecasting/training/train_config.yaml
|
| 246 |
python forecasting/inference/inference.py -config forecasting/inference/local_config.yaml
|
| 247 |
python forecasting/inference/evaluation.py -config forecasting/inference/evaluation_config.yaml
|
|
|
|
| 248 |
```
|
| 249 |
|
| 250 |
---
|
|
|
|
| 44 |
|
| 45 |
```text
|
| 46 |
FOXES
|
| 47 |
+
├── analysis # Post-inference analysis scripts
|
| 48 |
+
│ ├── flux_map_analysis.py # Detect, track, and visualize active regions from flux maps
|
| 49 |
+
│ ├── flux_map_config.yaml # Config for flux_map_analysis.py
|
| 50 |
+
│ ├── spatial_performance.py # Flux-weighted spatial error heatmap on the solar disk
|
| 51 |
+
│ └── ablation_analysis.py # Ablation study visualization
|
| 52 |
├── data # Data cleaning and preprocessing
|
| 53 |
│ ├── align_data.py # Align AIA and SXR timestamps; save matched pairs
|
| 54 |
│ ├── euv_data_cleaning.py # EUV image quality filtering and cleaning
|
|
|
|
| 72 |
│ ├── inference
|
| 73 |
│ │ ├── inference.py # Batch inference; writes predictions.csv
|
| 74 |
│ │ ├── evaluation.py # Compute metrics and generate evaluation plots
|
| 75 |
+
│ │ ├── local_config.yaml # Config for inference.py
|
|
|
|
| 76 |
│ │ └── evaluation_config.yaml # Config for evaluation.py
|
| 77 |
│ ├── models
|
| 78 |
│ │ └── vit_patch_model_local.py # ViTLocal: Vision Transformer with patch flux heads
|
|
|
|
| 130 |
| 7 | `train` | Train the ViTLocal solar flare forecasting model |
|
| 131 |
| 8 | `inference` | Run batch inference and save a predictions CSV |
|
| 132 |
| 9 | `evaluate` | Compute metrics and generate evaluation plots |
|
| 133 |
+
| 10 | `ablation` | Run channel-masking ablation study on a pretrained model |
|
| 134 |
+
| 11 | `spatial_performance` | Generate flux-weighted spatial error heatmap on the solar disk |
|
| 135 |
+
| 12 | `flux_map_analysis` | Detect and track active regions from flux maps; render frames and a movie |
|
| 136 |
|
| 137 |
### Usage
|
| 138 |
|
|
|
|
| 251 |
python forecasting/training/train.py -config forecasting/training/train_config.yaml
|
| 252 |
python forecasting/inference/inference.py -config forecasting/inference/local_config.yaml
|
| 253 |
python forecasting/inference/evaluation.py -config forecasting/inference/evaluation_config.yaml
|
| 254 |
+
python analysis/flux_map_analysis.py --config analysis/flux_map_config.yaml
|
| 255 |
```
|
| 256 |
|
| 257 |
---
|
analysis/ablation_analysis.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import matplotlib.ticker as mticker
|
| 5 |
+
from matplotlib import rcParams
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
PROJECT_ROOT = Path(__file__).parent.parent
|
| 10 |
+
sys.path.insert(0, str(PROJECT_ROOT))
|
| 11 |
+
from forecasting.inference.evaluation import setup_barlow_font
|
| 12 |
+
|
| 13 |
+
setup_barlow_font()
|
| 14 |
+
|
| 15 |
+
# Ablation CSVs live in the repository's untracked data folder. Resolving the
# location from PROJECT_ROOT (defined above) keeps the script runnable from any
# checkout instead of only from one user's home directory.
DATA_DIR = PROJECT_ROOT / "Untracked" / "data"

# Channel tags used in the CSV file names (ablate_<tag>_global_1.csv), paired
# by position with the human-readable LABELS below.
WAVELENGTHS = ["94", "131", "171", "193", "211", "304", "335", "stereo", "all"]
LABELS = [
    "Ablate 94 Å", "Ablate 131 Å", "Ablate 171 Å", "Ablate 193 Å",
    "Ablate 211 Å", "Ablate 304 Å", "Ablate 335 Å", "Ablate STEREO", "Ablate All",
]

# Font fallback chain; this overrides whatever family setup_barlow_font() chose.
rcParams['font.family'] = 'sans-serif'
rcParams['font.sans-serif'] = ['Barlow', 'Arial', 'DejaVu Sans']

# Flux range (W/m^2) of each flare class, keyed by the label drawn at the
# lower bound of that range.
FLARE_CLASSES = {
    'A1.0': (1e-8, 1e-7),
    'B1.0': (1e-7, 1e-6),
    'C1.0': (1e-6, 1e-5),
    'M1.0': (1e-5, 1e-4),
    'X1.0': (1e-4, 1e-3),
}

text_color = 'black'
grid_color = '#CCCCCC'

# Hard limits applied to every panel's axis range.
VMIN_GLOBAL = 1e-9
VMAX_GLOBAL = 1e-2
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def add_flare_class_axes(ax, vmin, vmax):
    """Add top and right secondary axes ticked with flare-class labels.

    A tick is placed at the lower flux bound of every FLARE_CLASSES entry
    whose lower bound lies within [vmin, vmax]. Both secondary axes use an
    identity transform, so they share the scale of the parent axes.

    Args:
        ax: Matplotlib axes to decorate.
        vmin: Lower end of the plotted flux range.
        vmax: Upper end of the plotted flux range.
    """
    # Keep only the class boundaries that fall inside the plotted range.
    visible = [
        (name, bounds[0])
        for name, bounds in FLARE_CLASSES.items()
        if vmin <= bounds[0] <= vmax
    ]
    tick_labels = [name for name, _ in visible]
    tick_positions = [lower for _, lower in visible]

    passthrough = (lambda value: value, lambda value: value)
    top_axis = ax.secondary_xaxis('top', functions=passthrough)
    right_axis = ax.secondary_yaxis('right', functions=passthrough)

    top_axis.set_xticks(tick_positions)
    top_axis.set_xticklabels(tick_labels, fontsize=6, color=text_color, rotation=45, ha='left')
    top_axis.grid(False)
    top_axis.tick_params(length=3)

    right_axis.set_yticks(tick_positions)
    right_axis.set_yticklabels(tick_labels, fontsize=6, color=text_color)
    right_axis.grid(False)
    right_axis.tick_params(length=3)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# One log-log density panel per ablation run, laid out on a 3x3 grid.
fig, axes = plt.subplots(3, 3, figsize=(16, 14), layout='constrained')
axes = axes.flatten()

hb_last = None  # for shared colorbar

for i, (wav, label) in enumerate(zip(WAVELENGTHS, LABELS)):
    # Path(...) accepts DATA_DIR whether it is a str or a Path.
    ab = pd.read_csv(Path(DATA_DIR) / f"ablate_{wav}_global_1.csv")
    gt = ab["groundtruth"].values
    pred = ab["predictions"].values

    # log10 is only defined for strictly positive values.
    mask = (gt > 0) & (pred > 0)
    gt, pred = gt[mask], pred[mask]

    log_mae = np.mean(np.abs(np.log10(gt) - np.log10(pred)))

    # Data range, clamped to the global limits.
    vmin = max(VMIN_GLOBAL, min(gt.min(), pred.min()))
    vmax = min(VMAX_GLOBAL, max(gt.max(), pred.max()))

    ax = axes[i]
    ax.set_facecolor("#FFFFFF")

    # With log-scaled axes, hexbin expects `extent` as powers of ten.
    hb = ax.hexbin(gt, pred, gridsize=80, xscale='log', yscale='log',
                   cmap='bone', mincnt=1, bins='log',
                   extent=(np.log10(vmin), np.log10(vmax),
                           np.log10(vmin), np.log10(vmax)))
    hb_last = hb

    # 1:1 line
    ax.plot([vmin, vmax], [vmin, vmax], ls='--', c='red', alpha=0.85, lw=1.2)

    ax.set_xlim(vmin, vmax)
    ax.set_ylim(vmin, vmax)
    ax.set_xscale('log')
    ax.set_yscale('log')

    #ax.set_title(label, fontsize=11, fontweight='bold', color=text_color)
    ax.set_xlabel(r'Ground Truth (W/m$^2$)', fontsize=8, color=text_color)
    ax.set_ylabel(r'Prediction (W/m$^2$)', fontsize=8, color=text_color)
    ax.tick_params(labelsize=7, colors=text_color)
    ax.grid(True, alpha=0.5, color=grid_color, linewidth=0.5)
    ax.set_axisbelow(True)

    for lbl in ax.get_xticklabels():
        lbl.set_fontfamily('Barlow')
    for lbl in ax.get_yticklabels():
        lbl.set_fontfamily('Barlow')

    ax.text(0.04, 0.96, f"Log MAE = {log_mae:.3f}",
            transform=ax.transAxes, fontsize=8, va='top', color=text_color,
            bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
                      edgecolor='#CCCCCC', alpha=0.85))

    add_flare_class_axes(ax, vmin, vmax)

# Shared colorbar
# NOTE(review): the colorbar is built from the last panel's hexbin only, so its
# scale reflects that panel's counts — confirm this is intended for all panels.
cbar = fig.colorbar(hb_last, ax=axes.tolist(), orientation='vertical', shrink=0.6, pad=0.01)
cbar.set_label("Count (log)", fontsize=11, color=text_color)
cbar.ax.tick_params(labelsize=9, colors=text_color)
cbar.ax.yaxis.set_minor_locator(mticker.LogLocator(base=10, subs='auto', numticks=10))
cbar.ax.tick_params(which='minor', colors=text_color)

#fig.suptitle("Ablation Study: Channel Masking vs. Baseline", fontsize=14, fontweight='bold')

# Write the figure into this checkout's analysis/ folder rather than to a
# user-specific absolute path.
output_path = PROJECT_ROOT / "analysis" / "ablation_3x3.png"
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.show()
print(f"Saved: {output_path.relative_to(PROJECT_ROOT)}")
|
analysis/ablation_lollipop.py
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import matplotlib.patches as mpatches
|
| 6 |
+
import matplotlib.font_manager as fm
|
| 7 |
+
from matplotlib import rcParams
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def setup_barlow_font():
    """Make Barlow the matplotlib font family, falling back to sans-serif.

    If a font whose name or file path contains "barlow" is already known to
    matplotlib's font manager, it is selected directly. Otherwise a short list
    of candidate font files is probed and the first existing one is
    registered. Any failure, or no font found, selects 'sans-serif'.
    """
    try:
        already_known = any(
            'barlow' in f.name.lower() or 'barlow' in f.fname.lower()
            for f in fm.fontManager.ttflist
        )
        if already_known:
            rcParams['font.family'] = 'Barlow'
            return

        # The per-user font folder is resolved through "~" so the lookup works
        # for whoever runs the script, not just one account.
        candidates = (
            '/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf',
            os.path.expanduser('~/Library/Fonts/Barlow-Regular.otf'),
        )
        for path in candidates:
            if os.path.exists(path):
                fm.fontManager.addfont(path)
                rcParams['font.family'] = 'Barlow'
                return
    except Exception:
        # Font setup is best-effort: fall through to the generic family.
        pass
    rcParams['font.family'] = 'sans-serif'
|
| 27 |
+
|
| 28 |
+
setup_barlow_font()
|
| 29 |
+
|
| 30 |
+
# Ablation CSVs live in <repo>/Untracked/data. This script sits in
# <repo>/analysis, so the repository root is two directory levels above it.
# Resolving it this way avoids a user-specific absolute path.
DATA_DIR = os.path.join(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
    "Untracked", "data",
)
# NOTE(review): machine-specific path on an external volume — adjust per machine.
BASELINE_CSV = "/Volumes/T9/FOXES_Misc/batch_results/vit/vit_predictions_test.csv"

# Channel tags used in the CSV file names (ablate_<tag>_global_1.csv).
# NOTE(review): ablation_analysis.py spells the last tag "stereo" in lowercase;
# on a case-sensitive file system only one spelling can match — confirm.
WAVELENGTHS = ["94", "131", "171", "193", "211", "304", "335", "STEREO"]
LABELS = {
    "94": "Ablate 94 Å",
    "131": "Ablate 131 Å",
    "171": "Ablate 171 Å",
    "193": "Ablate 193 Å",
    "211": "Ablate 211 Å",
    "304": "Ablate 304 Å",
    "335": "Ablate 335 Å",
    "STEREO": "Ablate 94, 131, 335 Å\n(STEREO)",
}

# Ground-truth flux bins (W/m^2) used for the per-class error; each bin is
# half-open: lower bound inclusive, upper bound exclusive.
FLARE_CLASSES = {
    '< C': (1e-15, 1e-6),
    'C': (1e-6, 1e-5),
    'M': (1e-5, 1e-4),
    'X': (1e-4, 1e-2),
}

# Marker colour of each class bin.
CLASS_COLORS = {
    '< C': '#4C9BE8',
    'C': '#56C490',
    'M': '#F5A623',
    'X': '#E84C4C',
}
|
| 58 |
+
|
| 59 |
+
# ── Compute metrics ────────────────────────────────────────────────────────────
|
| 60 |
+
def compute_row(label, gt, pred, is_baseline=False):
    """Summarise one run as a dict of log10 mean absolute errors.

    Samples where either value is non-positive are discarded. The result
    holds "label", "overall" (error over all kept samples), "is_baseline",
    and one entry per FLARE_CLASSES key with the error restricted to ground
    truth inside that class's [lo, hi) range. A class with five or fewer
    samples is reported as NaN.
    """
    positive = (gt > 0) & (pred > 0)
    truth = gt[positive]
    estimate = pred[positive]

    # Per-sample absolute error in log10 space, reused for every class bin.
    abs_log_err = np.abs(np.log10(truth) - np.log10(estimate))

    row = {"label": label, "overall": np.mean(abs_log_err), "is_baseline": is_baseline}
    for cls, (lo, hi) in FLARE_CLASSES.items():
        in_class = (truth >= lo) & (truth < hi)
        if in_class.sum() > 5:
            row[cls] = np.mean(abs_log_err[in_class])
        else:
            row[cls] = np.nan
    return row
|
| 69 |
+
|
| 70 |
+
records = []

# Baseline
bl = pd.read_csv(BASELINE_CSV)
records.append(compute_row("FOXES (no ablation)",
                           bl["groundtruth"].values, bl["predictions"].values,
                           is_baseline=True))

for wav in WAVELENGTHS:
    ab = pd.read_csv(os.path.join(DATA_DIR, f"ablate_{wav}_global_1.csv"))
    records.append(compute_row(LABELS[wav], ab["groundtruth"].values, ab["predictions"].values))

# Sort ablation rows by overall MAE (worst first), keep baseline pinned at bottom
ablation_df = pd.DataFrame([r for r in records if not r["is_baseline"]])
ablation_df = ablation_df.sort_values("overall", ascending=False).reset_index(drop=True)
baseline_df = pd.DataFrame([r for r in records if r["is_baseline"]])
df = pd.concat([ablation_df, baseline_df], ignore_index=True)

# ── Plot ───────────────────────────────────────────────────────────────────────
n_rows = len(df)
fig, ax = plt.subplots(figsize=(11, 0.6 * n_rows + 1.5))
#ax.set_facecolor("#FAFAFA")
fig.patch.set_facecolor("#FFFFFF")

y_positions = np.arange(n_rows)

# Separator line between ablations and baseline
ax.axhline(y=n_rows - 1.5, color="#BBBBBB", linewidth=1, linestyle=":", zorder=1)

# df has a fresh 0..n-1 index (ignore_index=True), so the row index doubles as
# the position into y_positions.
for i, row in df.iterrows():
    y = y_positions[i]
    is_bl = row["is_baseline"]

    # Highlight baseline row
    if is_bl:
        ax.axhspan(y - 0.45, y + 0.45, color="#EEF6FF", zorder=0)

    # Span line across per-class range
    class_vals = [row[c] for c in FLARE_CLASSES if not np.isnan(row[c])]
    if class_vals:
        ax.hlines(y, min(class_vals), max(class_vals),
                  color="#CCCCCC", linewidth=2, zorder=1)

    # Stem from 0 to overall
    ax.hlines(y, 0, row["overall"],
              color="#AAAAAA", linewidth=1.2, linestyle="--", zorder=0, alpha=0.6)

    # Per-class dots
    for cls in FLARE_CLASSES:
        val = row[cls]
        if not np.isnan(val):
            ax.scatter(val, y, color=CLASS_COLORS[cls], s=80, zorder=4,
                       edgecolors="white", linewidths=0.6, alpha=0.75)

    # Overall dot
    outline_color = "#1A6BBF" if is_bl else "black"
    ax.scatter(row["overall"], y, color="white", s=190, zorder=3,
               edgecolors=outline_color, linewidths=2.0 if is_bl else 1.5, alpha=0.75)
    ax.scatter(row["overall"], y, color=outline_color, s=75, zorder=3,
               marker="|", linewidths=1.5, alpha=0.75)

tick_colors = ["black"] * n_rows
tick_colors[-1] = "#1A6BBF"  # baseline label in blue
ax.set_yticks(y_positions)
ax.set_yticklabels(df["label"], fontsize=12)
for ticklabel, color in zip(ax.get_yticklabels(), tick_colors):
    ticklabel.set_color(color)
    if color != "black":
        ticklabel.set_fontweight("bold")
ax.set_xlabel("MAE (log$_{10}$ scale)", fontsize=12)
ax.grid(True, axis="x", alpha=0.4, color="#CCCCCC", linewidth=0.6)
ax.set_axisbelow(True)
ax.spines[["top", "right"]].set_visible(False)
ax.tick_params(axis="y", length=0, labelsize=11)
ax.tick_params(axis="x", labelsize=10)

# Legend
class_patches = [
    mpatches.Patch(color=CLASS_COLORS[c], label=f"{c}-class") for c in FLARE_CLASSES
]
overall_patch = mpatches.Patch(facecolor="white", edgecolor="black", label="Overall")
#baseline_patch = mpatches.Patch(facecolor="white", edgecolor="#1A6BBF", label="Baseline (overall)")
ax.legend(handles=class_patches + [overall_patch],
          loc="upper right", fontsize=10, framealpha=0.9,
          edgecolor="#CCCCCC")

# ax.set_title("Ablation Study — Log MAE by Channel & Flare Class",
#              fontsize=14, fontweight="bold", pad=14)
plt.xlim(0, .85)
plt.tight_layout()

# Save next to this script so the output location no longer depends on the
# current working directory, and report the path that was actually written.
output_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ablation_lollipop.png")
plt.savefig(output_path, dpi=450, bbox_inches="tight")
plt.show()
print(f"Saved: {output_path}")
|
analysis/flux_map_analysis.py
ADDED
|
@@ -0,0 +1,916 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Flare Analysis — Frame & Movie Generator
|
| 4 |
+
|
| 5 |
+
Detects and tracks active regions from flux contribution maps,
|
| 6 |
+
then renders per-timestamp frames and stitches them into a movie.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python flux_map_analysis.py --config flux_map_config.yaml
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import os
|
| 16 |
+
import time
|
| 17 |
+
import warnings
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
from heapq import heappush, heappop
|
| 21 |
+
from multiprocessing import Pool
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Dict, List, Optional, Tuple
|
| 24 |
+
|
| 25 |
+
import imageio.v2 as imageio
|
| 26 |
+
import imageio_ffmpeg
|
| 27 |
+
import matplotlib
|
| 28 |
+
matplotlib.use('Agg')
|
| 29 |
+
import matplotlib.dates as mdates
|
| 30 |
+
import matplotlib.font_manager as fm
|
| 31 |
+
import matplotlib.pyplot as plt
|
| 32 |
+
import numpy as np
|
| 33 |
+
import pandas as pd
|
| 34 |
+
import yaml
|
| 35 |
+
from matplotlib import rcParams
|
| 36 |
+
from scipy.ndimage import maximum_filter, gaussian_filter
|
| 37 |
+
from tqdm import tqdm
|
| 38 |
+
|
| 39 |
+
warnings.filterwarnings('ignore')
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# =============================================================================
|
| 43 |
+
# Configuration
|
| 44 |
+
# =============================================================================
|
| 45 |
+
|
| 46 |
+
@dataclass
class FlareAnalysisConfig:
    """Configuration for flare analysis.

    Every field has a default, so a config source only needs to list the
    values it wants to override. Instances are normally built with
    ``from_yaml`` (file on disk) or ``from_dict`` (already-parsed mapping).
    """

    # Paths
    flux_path: Optional[str] = None
    aia_path: Optional[str] = None
    predictions_csv: Optional[str] = None
    output_dir: Optional[str] = None

    # Time range
    start_time: Optional[str] = None
    end_time: Optional[str] = None

    # Detection
    min_flux_threshold: float = 1e-7
    threshold_std_multiplier: float = 3.0
    spatial_smoothing_sigma: float = 1.0
    radial_expansion_threshold_percentile: float = 30.0
    peak_neighborhood_sizes: Tuple[int, ...] = (10, 15, 20, 25)
    peak_min_scale_agreement: int = 2
    peak_scale_tolerance: int = 2
    min_peak_distance: int = 10

    # Grid
    grid_size: Tuple[int, int] = (64, 64)
    patch_size: int = 8
    input_size: int = 512

    # Tracking
    max_tracking_distance: int = 8
    flux_ratio_weight: float = 0.1
    size_ratio_weight: float = 0.1
    distance_weight: float = 1.0
    age_bonus_weight: float = 1.0  # scales 1/(1+age) penalty on new tracks
    cadence_seconds: float = 60.0
    max_gap_frames: int = 1  # frames a track can persist without a detection before expiring

    # Movie / output
    create_movie: bool = False
    plot_window_hours: float = 4.0
    movie_fps: float = 2.0
    movie_frame_interval_minutes: float = 1.0
    movie_num_workers: int = 4
    movie_dpi: float = 75.0
    movie_frame_format: str = 'jpg'
    movie_jpeg_quality: int = 90

    @classmethod
    def from_yaml(cls, path: str) -> "FlareAnalysisConfig":
        """Build a config from the YAML file at *path*.

        An empty file yields the all-defaults configuration. See
        ``from_dict`` for how the parsed mapping is interpreted.
        """
        with open(path, encoding="utf-8") as f:
            raw = yaml.safe_load(f) or {}
        return cls.from_dict(raw)

    @classmethod
    def from_dict(cls, raw: Dict) -> "FlareAnalysisConfig":
        """Build a config from an already-parsed (possibly nested) mapping.

        Rules applied, in order:

        * One level of nesting is flattened; section names are discarded.
        * Inside the ``movie`` section, a key that is not itself a field but
          whose ``movie_``-prefixed form is (``fps``, ``dpi``, ...) is stored
          under the prefixed name.
        * ``start`` / ``end`` are accepted as aliases for ``start_time`` /
          ``end_time``.
        * Lists given for ``grid_size`` and ``peak_neighborhood_sizes`` become
          tuples.
        * String values for float-typed fields are converted with ``float``.
        * Unknown keys and ``None`` values are ignored.

        Raises:
            ValueError: if a string supplied for a float field is not numeric.
        """
        valid = cls.__dataclass_fields__

        # Flatten one level of nesting
        flat: Dict = {}
        for key, val in raw.items():
            if not isinstance(val, dict):
                flat[key] = val
            elif key == 'movie':
                # The 'movie' section uses short keys (fps, dpi, …) — prefix them with 'movie_'
                for k, v in val.items():
                    prefixed = f'movie_{k}'
                    if k not in valid and prefixed in valid:
                        flat[prefixed] = v
                    else:
                        flat[k] = v
            else:
                flat.update(val)

        # Renamed YAML keys
        if 'start' in flat:
            flat['start_time'] = flat.pop('start')
        if 'end' in flat:
            flat['end_time'] = flat.pop('end')

        # Lists → tuples for tuple-typed fields
        for k in ('grid_size', 'peak_neighborhood_sizes'):
            if k in flat and isinstance(flat[k], list):
                flat[k] = tuple(flat[k])

        kwargs = {}
        for k, v in flat.items():
            if k not in valid or v is None:
                continue
            # PyYAML resolves exponent literals without a decimal point (such
            # as 1e-7) to str, which would otherwise leave a string in a
            # float-typed field.
            if isinstance(v, str) and isinstance(valid[k].default, float):
                v = float(v)
            kwargs[k] = v
        return cls(**kwargs)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# =============================================================================
|
| 134 |
+
# Utilities
|
| 135 |
+
# =============================================================================
|
| 136 |
+
|
| 137 |
+
def flux_to_goes_class(flux: float) -> str:
    """Convert physical flux (W/m²) to GOES class string.

    Accepts Python numbers and NumPy integer/floating scalars, so values
    pulled out of float32 arrays are classified too. Anything else, and any
    NaN, infinite, zero or negative flux, yields "N/A".

    The letter is chosen by decade (X ≥ 1e-4, M ≥ 1e-5, C ≥ 1e-6, B ≥ 1e-7,
    otherwise A) and followed by the flux expressed in units of that decade
    with one decimal place, e.g. 1.5e-5 → "M1.5".
    """
    if not isinstance(flux, (int, float, np.integer, np.floating)):
        return "N/A"
    flux = float(flux)
    if not np.isfinite(flux) or flux <= 0:
        return "N/A"

    for prefix, scale in (("X", 1e-4), ("M", 1e-5), ("C", 1e-6), ("B", 1e-7)):
        if flux >= scale:
            break
    else:
        prefix, scale = "A", 1e-8

    magnitude = flux / scale
    if prefix != "X":
        # Cap below 10 so rounding can never print e.g. "C10.0". X is the top
        # class with no upper bound, so its multiplier is left uncapped.
        magnitude = min(magnitude, 9.9)
    return f"{prefix}{magnitude:.1f}"
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def setup_barlow_font() -> None:
    """Point matplotlib's default font family at Barlow when it can be found.

    Fonts already registered with matplotlib are checked first, preferring a
    family named exactly "Barlow" or "Barlow Regular" over other Barlow
    variants. If none is registered, a fixed list of font-file locations is
    probed and the first existing file is registered and selected. When
    nothing is found, or any step raises, the family is set to 'sans-serif'.
    """
    try:
        registered = [
            font.name for font in fm.fontManager.ttflist
            if 'barlow' in font.name.lower()
        ]
        if registered:
            family = registered[0]
            for name in registered:
                if name.lower() in ('barlow', 'barlow regular'):
                    family = name
                    break
            rcParams['font.family'] = family
            return

        font_files = [
            os.path.expanduser('~/Library/Fonts/Barlow-Regular.otf'),
            os.path.expanduser('~/Library/Fonts/Barlow-Regular.ttf'),
            '/Library/Fonts/Barlow-Regular.otf',
            '/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf',
        ]
        for font_file in font_files:
            if not os.path.exists(font_file):
                continue
            fm.fontManager.addfont(font_file)
            # Use the family name stored in the file rather than assuming it.
            rcParams['font.family'] = fm.FontProperties(fname=font_file).get_name()
            return
    except Exception:
        # Best-effort: any failure falls through to the generic family below.
        pass
    rcParams['font.family'] = 'sans-serif'
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def load_aia_image_at_time(aia_path: Optional[Path], timestamp: str) -> Optional[np.ndarray]:
    """Load AIA image as normalised RGB composite (channels 0, 1, 2 → 94, 131, 171 Å).

    Looks for ``<timestamp>.npy`` in *aia_path* itself and then in its
    ``test``, ``train`` and ``val`` subdirectories (those that exist). The
    first file that loads as a 3-D array with at least three leading channels
    is used.

    Args:
        aia_path: Root directory of the AIA stacks, or None.
        timestamp: File stem of the frame to load.

    Returns:
        An (H, W, 3) float array in which each of the first three channels is
        min-max scaled to [0, 1], or None when no usable file is found. A
        channel with no dynamic range is returned as all zeros.
    """
    if aia_path is None or not aia_path.exists():
        return None

    search_dirs = [aia_path] + [aia_path / s for s in ('test', 'train', 'val') if (aia_path / s).exists()]
    for d in search_dirs:
        fp = d / f"{timestamp}.npy"
        if not fp.exists():
            continue
        try:
            data = np.load(fp)  # (7, H, W)
            if data.ndim == 3 and data.shape[0] >= 3:
                rgb = np.zeros((data.shape[1], data.shape[2], 3))
                for i in range(3):
                    ch = data[i]
                    lo = ch.min()
                    span = ch.max() - lo
                    # A flat channel cannot be rescaled; leave it at zero
                    # rather than copying raw, unnormalised values.
                    if span > 0:
                        rgb[..., i] = (ch - lo) / span
                return rgb
        except Exception:
            # Best-effort loader: an unreadable file just means "try the next
            # directory".
            continue
    return None
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
# =============================================================================
|
| 208 |
+
# Region Detection & Tracking
|
| 209 |
+
# =============================================================================
|
| 210 |
+
|
| 211 |
+
class FluxContributionAnalyzer:
    """Detects and tracks active regions from per-patch flux contribution maps.

    Detection runs per timestamp (threshold, optional smoothing, multi-scale
    peak finding, radial region growing). Tracking then links the detected
    regions across consecutive timestamps with a greedy scored assignment.
    """

    def __init__(self, config: "FlareAnalysisConfig", output_dir: Optional[Path] = None):
        """Store configuration and load the (optionally time-filtered) predictions.

        Args:
            config: Analysis configuration; the attributes read here are
                ``flux_path``, ``aia_path``, ``grid_size``, ``patch_size``,
                ``input_size``, ``predictions_csv``, ``start_time`` and
                ``end_time``.
            output_dir: Directory for the tracking debug log; ``None``
                disables the log.
        """
        self.config = config
        self.flux_path = Path(config.flux_path) if config.flux_path else None
        self.aia_path = Path(config.aia_path) if config.aia_path else None
        self.output_dir = output_dir
        self.grid_size = config.grid_size
        self.patch_size = config.patch_size
        self.input_size = config.input_size
        # timestamp -> int16 label map produced by the detection step
        self.region_labels_cache: Dict[str, np.ndarray] = {}

        if config.predictions_csv:
            self.predictions_df = pd.read_csv(config.predictions_csv)
            self.predictions_df['datetime'] = pd.to_datetime(self.predictions_df['timestamp'])
            self.predictions_df = self.predictions_df.sort_values('datetime')
            if config.start_time and config.end_time:
                start, end = pd.to_datetime(config.start_time), pd.to_datetime(config.end_time)
                mask = (self.predictions_df['datetime'] >= start) & (self.predictions_df['datetime'] <= end)
                self.predictions_df = self.predictions_df[mask].reset_index(drop=True)
            print(f"Loaded {len(self.predictions_df)} predictions "
                  f"({self.predictions_df['datetime'].min()} → {self.predictions_df['datetime'].max()})")
        else:
            self.predictions_df = pd.DataFrame()

    # ------------------------------------------------------------------
    # Data loading
    # ------------------------------------------------------------------

    def load_flux_contributions(self, timestamp: str) -> Optional[np.ndarray]:
        """Return the flux map stored as ``<flux_path>/<timestamp>.npy``, or ``None``."""
        if self.flux_path is None:
            return None
        fp = self.flux_path / f"{timestamp}.npy"
        return np.load(fp) if fp.exists() else None

    # ------------------------------------------------------------------
    # Peak detection
    # ------------------------------------------------------------------

    def _find_flux_peaks_single_scale(self, flux: np.ndarray, size: int) -> Tuple[List, List]:
        """Find local maxima of ``flux`` within a ``size`` x ``size`` neighbourhood.

        Only finite, strictly positive cells can be peaks.

        Returns:
            ``(coords, fluxes)`` where ``coords`` is a list of ``(y, x)``
            tuples and ``fluxes`` the matching flux values.
        """
        valid = np.isfinite(flux) & (flux > 0)
        # -inf guarantees that an invalid cell never wins the maximum filter.
        masked = np.where(valid, flux, -np.inf)
        local_max = (maximum_filter(masked, size=size) == masked) & valid
        ys, xs = np.where(local_max)
        coords = list(zip(ys.tolist(), xs.tolist()))
        fluxes = [float(flux[y, x]) for y, x in coords]
        return coords, fluxes

    def _find_flux_peaks_multiscale(self, flux: np.ndarray) -> Tuple[List, List]:
        """Keep peaks that recur across several neighbourhood sizes.

        Peaks from each size in ``config.peak_neighborhood_sizes`` are merged
        into one registry entry when they lie within
        ``config.peak_scale_tolerance`` cells (per axis) of the entry's first
        coordinate. Entries seen at least ``config.peak_min_scale_agreement``
        times survive, sorted by decreasing flux, and are finally thinned by
        ``config.min_peak_distance``.
        """
        cfg = self.config
        registry: Dict[Tuple, dict] = {}

        for size in cfg.peak_neighborhood_sizes:
            coords, fluxes = self._find_flux_peaks_single_scale(flux, size)
            for (y, x), fv in zip(coords, fluxes):
                matched = next(
                    ((py, px) for (py, px) in registry
                     if abs(y - py) <= cfg.peak_scale_tolerance and abs(x - px) <= cfg.peak_scale_tolerance),
                    None
                )
                if matched is not None:
                    entry = registry[matched]
                    entry['count'] += 1
                    if fv > entry['best_flux']:
                        entry['best_flux'] = fv
                        entry['best_coord'] = (y, x)
                else:
                    registry[(y, x)] = {'count': 1, 'best_flux': fv, 'best_coord': (y, x)}

        stable = [(e['best_coord'], e['best_flux'])
                  for e in registry.values() if e['count'] >= cfg.peak_min_scale_agreement]
        if not stable:
            return [], []

        stable.sort(key=lambda p: p[1], reverse=True)

        coords = [p[0] for p in stable]
        fluxes = [p[1] for p in stable]

        if cfg.min_peak_distance > 0 and len(coords) > 1:
            coords, fluxes = self._merge_close_peaks(coords, fluxes, cfg.min_peak_distance)

        return coords, fluxes

    def _merge_close_peaks(self, coords, fluxes, min_dist):
        """Greedily drop peaks closer than ``min_dist`` to a stronger kept peak.

        Peaks are visited from strongest to weakest; survivors are returned in
        their original input order.
        """
        order = np.argsort(fluxes)[::-1]
        kept = []
        for i in order:
            if all(np.hypot(coords[i][0] - coords[j][0], coords[i][1] - coords[j][1]) >= min_dist
                   for j in kept):
                kept.append(i)
        kept = sorted(kept)
        return [coords[i] for i in kept], [fluxes[i] for i in kept]

    # ------------------------------------------------------------------
    # Region segmentation (radial flood-fill from peaks)
    # ------------------------------------------------------------------

    def _detect_regions_with_peak_clustering(
        self, flux_contrib: np.ndarray, pred_data: pd.Series
    ) -> Tuple[List[Dict], Optional[np.ndarray], str]:
        """Segment one flux map into regions grown outwards from its peaks.

        Args:
            flux_contrib: 2-D per-patch flux map.
            pred_data: Prediction row for this timestamp (currently unused).

        Returns:
            ``(regions, labels, reason)``: a list of region dicts, the integer
            label map (``None`` when nothing was detected) and a
            human-readable status string used in the tracking log.
        """
        cfg = self.config
        # Non-finite cells are excluded everywhere below; an ``inf`` patch
        # would otherwise pass the threshold and poison smoothing and sums.
        finite_pos = np.isfinite(flux_contrib) & (flux_contrib > 0)
        valid = flux_contrib[finite_pos]
        if valid.size == 0:
            return [], None, "no_valid_flux"

        total_flux = float(valid.sum())
        log_flux = np.log(valid)
        # Log-normal style threshold: median + k * sigma in log space,
        # never below the configured absolute floor.
        threshold = max(
            np.exp(np.median(log_flux) + cfg.threshold_std_multiplier * np.std(log_flux)),
            cfg.min_flux_threshold,
        )
        keep = finite_pos & (flux_contrib > threshold)
        above = int(keep.sum())
        masked = np.where(keep, flux_contrib, 0.0)

        if above == 0:
            return [], None, f"all_below_threshold(thr={threshold:.3e} total={total_flux:.3e})"

        if cfg.spatial_smoothing_sigma > 0:
            masked = gaussian_filter(masked, sigma=cfg.spatial_smoothing_sigma)

        peak_coords, peak_fluxes = self._find_flux_peaks_multiscale(masked)
        if not peak_coords:
            return [], None, f"no_peaks(thr={threshold:.3e} above={above} total={total_flux:.3e})"

        # Radial flood-fill from all peaks simultaneously (Dijkstra-style)
        labels = np.zeros_like(masked, dtype=np.int32)
        valid_vals = masked[(masked > 0) & np.isfinite(masked)]
        growth_threshold = (
            np.percentile(valid_vals, cfg.radial_expansion_threshold_percentile) if valid_vals.size else 0
        )

        # Heap entries: (distance to own peak, tie-breaker, y, x, label, peak_y, peak_x)
        pq, counter = [], 0
        for idx, (py, px) in enumerate(peak_coords):
            labels[py, px] = idx + 1
            heappush(pq, (0.0, counter, py, px, idx + 1, py, px))
            counter += 1

        neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
        H, W = masked.shape
        while pq:
            _, _, y, x, label, py, px = heappop(pq)
            for dy, dx in neighbors:
                ny, nx = y + dy, x + dx
                if 0 <= ny < H and 0 <= nx < W and labels[ny, nx] == 0 and masked[ny, nx] > growth_threshold:
                    # A cell is claimed when first reached, not when popped.
                    labels[ny, nx] = label
                    new_dist = np.hypot(ny - py, nx - px)
                    heappush(pq, (new_dist, counter, ny, nx, label, py, px))
                    counter += 1

        regions = []
        skipped_below_min = 0
        half_patch = self.patch_size // 2
        for lid in range(1, len(peak_coords) + 1):
            mask = labels == lid
            ys, xs = np.where(mask)
            if ys.size == 0:
                continue
            fv = masked[mask]
            total = float(fv.sum())
            if total < cfg.min_flux_threshold:
                skipped_below_min += 1
                continue
            cy, cx = float(ys.mean()), float(xs.mean())
            peak_y, peak_x = peak_coords[lid - 1]
            regions.append({
                'id': len(regions) + 1,
                'region_label': lid,
                'size': int(ys.size),
                'sum_flux': total,
                'max_flux': float(fv.max()),
                'centroid_patch_y': cy,
                'centroid_patch_x': cx,
                'centroid_img_y': cy * self.patch_size + half_patch,
                'centroid_img_x': cx * self.patch_size + half_patch,
                'peak_y': peak_y,
                'peak_x': peak_x,
                'peak_img_y': peak_y * self.patch_size + half_patch,
                'peak_img_x': peak_x * self.patch_size + half_patch,
                'peak_flux': peak_fluxes[lid - 1],
                'mask': mask,
            })

        n_peaks = len(peak_coords)
        reason = (f"ok: {len(regions)} regions from {n_peaks} peaks"
                  f" thr={threshold:.3e} above={above} total={total_flux:.3e}"
                  + (f" skipped={skipped_below_min}_below_min_flux" if skipped_below_min else ""))
        return regions, labels, reason

    def _detect_regions_worker(self, timestamp: str) -> Tuple[str, Optional[List], Optional[np.ndarray], str]:
        """Detect regions for one timestamp; never raises (pool-friendly).

        Returns:
            ``(timestamp, regions, labels, reason)``. ``regions`` and
            ``labels`` are ``None`` when the frame could not be processed;
            ``reason`` says why.
        """
        try:
            flux = self.load_flux_contributions(timestamp)
            if flux is None:
                return timestamp, None, None, "no_flux_file"
            pred = self.predictions_df[self.predictions_df['timestamp'] == timestamp]
            if pred.empty:
                return timestamp, None, None, "no_prediction_row"
            regions, labels, reason = self._detect_regions_with_peak_clustering(flux, pred.iloc[0])
            return timestamp, regions, (labels.astype(np.int16) if labels is not None else None), reason
        except Exception as e:
            return timestamp, None, None, f"exception: {e}"

    # ------------------------------------------------------------------
    # Tracking
    # ------------------------------------------------------------------

    def track_regions_over_time(self, timestamps: List[str]) -> Dict:
        """Detect regions for every timestamp and link them into tracks.

        Args:
            timestamps: Timestamps to process, in temporal order.

        Returns:
            Mapping ``track_id -> [(timestamp, region_dict), ...]``. As side
            effects, ``self.region_labels_cache`` is filled and, when
            ``self.output_dir`` is set, ``tracking_debug.log`` is written.
        """
        cfg = self.config
        print("Detecting regions (parallel)…")
        n_workers = max(1, min((os.cpu_count() or 1) - 1, len(timestamps)))

        all_regions: Dict[str, List] = {}
        detection_reasons: Dict[str, str] = {}
        # Labels are collected locally and merged into ``self`` only after the
        # pool is done: the bound method pickles ``self`` for every dispatched
        # task, so growing the cache mid-run would re-serialise it each time.
        labels_by_ts: Dict[str, np.ndarray] = {}
        with Pool(processes=n_workers) as pool:
            for ts, regions, labels, reason in tqdm(
                pool.imap(self._detect_regions_worker, timestamps),
                total=len(timestamps), desc="Detecting regions"
            ):
                detection_reasons[ts] = reason
                if regions is not None:
                    all_regions[ts] = regions
                if labels is not None:
                    labels_by_ts[ts] = labels
        self.region_labels_cache.update(labels_by_ts)

        print("Tracking regions across time…")
        print(f"  max_tracking_distance={cfg.max_tracking_distance}  "
              f"max_gap_frames={cfg.max_gap_frames}  "
              f"age_bonus_weight={cfg.age_bonus_weight}  "
              f"distance_weight={cfg.distance_weight}")
        tracks: Dict[int, List] = {}
        next_id = 1
        last_seen: Dict[int, int] = {}   # track_id → frame index when last matched
        _debug_log: List[str] = []       # per-frame tracking log

        for frame_idx, ts in enumerate(tqdm(timestamps, desc="Tracking")):
            # Expire tracks that haven't been seen within max_gap_frames
            active = {tid for tid, fi in last_seen.items()
                      if frame_idx - fi <= cfg.max_gap_frames}

            if ts not in all_regions:
                det_reason = detection_reasons.get(ts, "unknown")
                _debug_log.append(f"{ts}  SKIP   {det_reason}")
                continue

            current_regions = all_regions[ts]

            # Build all valid (score, region_idx, track_id) candidates
            candidates = []
            for ri, region in enumerate(current_regions):
                cur_flux = region.get('sum_flux', 0.0)
                cur_size = region.get('size', 1)
                for tid in active:
                    history = tracks[tid]
                    # Smooth position over last few frames to reduce centroid jitter.
                    # Use PATCH coordinates so max_tracking_distance is in patch units (matching config).
                    n_smooth = min(5, len(history))
                    avg_x = np.mean([h[1]['centroid_patch_x'] for h in history[-n_smooth:]])
                    avg_y = np.mean([h[1]['centroid_patch_y'] for h in history[-n_smooth:]])
                    dist = np.hypot(
                        region['centroid_patch_x'] - avg_x,
                        region['centroid_patch_y'] - avg_y,
                    )
                    _, last = history[-1]
                    if dist >= cfg.max_tracking_distance:
                        continue
                    lf = last.get('sum_flux', 1e-15)
                    ls = last.get('size', 1)
                    flux_ratio = max(cur_flux, lf) / max(min(cur_flux, lf), 1e-15)
                    size_ratio = max(cur_size, ls) / max(min(cur_size, ls), 1)
                    track_age = len(tracks[tid])
                    # Discount grows with age: 0 (new) → age_bonus_weight (very old)
                    # Makes established tracks harder to beat at equal distance
                    age_discount = cfg.age_bonus_weight * track_age / (1.0 + track_age)
                    score = (cfg.distance_weight * dist
                             + cfg.flux_ratio_weight * flux_ratio
                             + cfg.size_ratio_weight * size_ratio
                             - age_discount)
                    candidates.append((score, ri, tid))

            # Greedy one-to-one assignment: best scores first, each region/track used once
            candidates.sort()
            assigned_regions: set = set()
            assigned_tracks: set = set()
            assignments: Dict[int, int] = {}   # region_idx → track_id
            for score, ri, tid in candidates:
                if ri in assigned_regions or tid in assigned_tracks:
                    continue
                assignments[ri] = tid
                assigned_regions.add(ri)
                assigned_tracks.add(tid)

            # Log detection outcome for this frame
            det_reason = detection_reasons.get(ts, "unknown")
            _debug_log.append(f"{ts}  DETECT {det_reason}")

            # Log active-but-unmatched tracks (gaps)
            for tid in active:
                if tid not in assigned_tracks:
                    gap = frame_idx - last_seen.get(tid, frame_idx)
                    cx = tracks[tid][-1][1]['centroid_patch_x']
                    cy = tracks[tid][-1][1]['centroid_patch_y']
                    _debug_log.append(
                        f"{ts}  GAP    track={tid:3d}  age={len(tracks[tid]):4d}  "
                        f"gap_frames={gap:2d}  last_patch=({cx:.1f},{cy:.1f})"
                    )

            # Apply assignments; spawn new track for unmatched regions
            for ri, region in enumerate(current_regions):
                r = region.copy()
                r['timestamp'] = ts
                if ri in assignments:
                    tid = assignments[ri]
                    r['id'] = tid
                    tracks[tid].append((ts, r))
                    cx, cy = r['centroid_patch_x'], r['centroid_patch_y']
                    _debug_log.append(
                        f"{ts}  MATCH  track={tid:3d}  age={len(tracks[tid]):4d}  "
                        f"patch=({cx:.1f},{cy:.1f})  flux={r.get('sum_flux', 0):.3e}"
                    )
                else:
                    r['id'] = next_id
                    tracks[next_id] = [(ts, r)]
                    cx, cy = r['centroid_patch_x'], r['centroid_patch_y']
                    _debug_log.append(
                        f"{ts}  NEW    track={next_id:3d}  age=   1  "
                        f"patch=({cx:.1f},{cy:.1f})  flux={r.get('sum_flux', 0):.3e}"
                    )
                    next_id += 1
                tid = r['id']
                last_seen[tid] = frame_idx

        tracks = {k: v for k, v in tracks.items() if v}
        print(f"Found {len(tracks)} region tracks across {len(timestamps)} timestamps")

        if self.output_dir and _debug_log:
            log_path = Path(self.output_dir) / "tracking_debug.log"
            # The header contains non-ASCII characters; fix the encoding so the
            # write does not depend on the platform's default locale.
            with open(log_path, 'w', encoding='utf-8') as f:
                f.write(f"# Tracking log — {len(tracks)} tracks, {len(timestamps)} timestamps\n")
                f.write(f"# max_tracking_distance={cfg.max_tracking_distance}  "
                        f"max_gap_frames={cfg.max_gap_frames}  "
                        f"age_bonus_weight={cfg.age_bonus_weight}\n")
                f.write("#\n# timestamp  event  track  age  detail\n")
                f.write('\n'.join(_debug_log))
            print(f"Tracking debug log → {log_path}")

        return tracks

    def detect_flare_events(self, timestamps: Optional[List[str]] = None) -> pd.DataFrame:
        """Run detection + tracking and return a per-timestamp events DataFrame.

        Args:
            timestamps: Timestamps to analyse; defaults to every timestamp in
                the loaded predictions.

        Returns:
            One row per (track, timestamp) that has a matching prediction row;
            an empty DataFrame when there is nothing to report.
        """
        has_predictions = 'timestamp' in self.predictions_df.columns
        if timestamps is None:
            if not has_predictions:
                # No predictions CSV was configured: nothing to iterate over.
                print("No predictions loaded — no events to record.")
                return pd.DataFrame()
            timestamps = self.predictions_df['timestamp'].tolist()

        tracks = self.track_regions_over_time(timestamps)

        # One lookup table instead of a full DataFrame scan per event. The
        # first row wins for duplicated timestamps, as ``.iloc[0]`` did.
        pred_by_ts: Dict = {}
        if has_predictions:
            pred_by_ts = (
                self.predictions_df
                .drop_duplicates('timestamp', keep='first')
                .set_index('timestamp', drop=False)
                .to_dict('index')
            )

        rows = []
        for track_id, history in tracks.items():
            for ts, r in history:
                pred = pred_by_ts.get(ts)
                if pred is None:
                    continue
                rows.append({
                    'timestamp': ts,
                    'datetime': pred['datetime'],
                    'prediction': pred['predictions'],
                    'groundtruth': pred.get('groundtruth', None),
                    'region_size': r.get('size', 0),
                    'sum_flux': r.get('sum_flux', 0.0),
                    'max_flux': r.get('max_flux', 0.0),
                    'mean_flux': r.get('sum_flux', 0.0) / max(r.get('size', 1), 1),
                    'centroid_patch_y': r.get('centroid_patch_y', 0.0),
                    'centroid_patch_x': r.get('centroid_patch_x', 0.0),
                    'centroid_img_y': r.get('centroid_img_y', 0.0),
                    'centroid_img_x': r.get('centroid_img_x', 0.0),
                    'peak_img_y': r.get('peak_img_y', None),
                    'peak_img_x': r.get('peak_img_x', None),
                    'region_label': r.get('region_label', None),
                    'track_id': track_id,
                })

        print(f"Recorded {len(rows)} events from {len(tracks)} tracks")
        return pd.DataFrame(rows) if rows else pd.DataFrame()
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
# =============================================================================
|
| 594 |
+
# Frame Generation
|
| 595 |
+
# =============================================================================
|
| 596 |
+
|
| 597 |
+
# Colours cycled across FOXES tracks
|
| 598 |
+
_TRACK_COLORS = [
|
| 599 |
+
'#E6194B', '#3CB44B', '#FFE119', '#4363D8', '#F58231',
|
| 600 |
+
'#911EB4', '#42D4F4', '#F032E6', '#BFEF45', '#FABED4',
|
| 601 |
+
'#469990', '#DCBEFF', '#9A6324', '#FFFAC8', '#800000',
|
| 602 |
+
'#AAFFC3', '#808000', '#FFD8B1', '#000075', '#A9A9A9',
|
| 603 |
+
]
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def _generate_single_frame(args: Tuple) -> Optional[str]:
    """Render one movie frame (AIA image + flux time series) to disk.

    Designed for multiprocessing: takes a single picklable argument and never
    raises.

    Args:
        args: ``(frame_idx, timestamp, fd)`` where ``fd`` is a dict with the
            keys ``flare_events_df``, ``predictions_df``,
            ``region_labels_cache``, ``config``, ``track_color_map``,
            ``plot_window_hours``, ``aia_path`` and ``frames_dir``.

    Returns:
        Path of the written frame as a string, or ``None`` on failure.
    """
    setup_barlow_font()

    frame_idx, timestamp, fd = args
    fig = None
    try:
        flare_events_df = fd['flare_events_df']
        predictions_df = fd['predictions_df']
        region_labels_cache = fd['region_labels_cache']
        config = fd['config']
        track_color_map = fd['track_color_map']
        plot_window_hours = fd['plot_window_hours']
        aia_path = fd['aia_path']
        frames_dir = Path(fd['frames_dir'])

        # Image edge length in pixels; 512 stays the fallback for configs
        # that do not define ``input_size``.
        img_size = int(getattr(config, 'input_size', 512) or 512)

        current_time = pd.to_datetime(timestamp)
        window_start = current_time - pd.Timedelta(hours=plot_window_hours / 2)
        window_end = current_time + pd.Timedelta(hours=plot_window_hours / 2)

        # ── Figure: AIA (left) + SXR timeseries (right) ─────────────────────
        fig = plt.figure(figsize=(14, 7))
        gs = fig.add_gridspec(1, 2, width_ratios=[1, 1], wspace=0.3,
                              left=0.07, right=0.97, top=0.93, bottom=0.10)
        ax_aia = fig.add_subplot(gs[0])
        ax_sxr = fig.add_subplot(gs[1])

        # ── AIA image ────────────────────────────────────────────────────────
        aia_image = load_aia_image_at_time(Path(aia_path), timestamp) if aia_path else None
        if aia_image is not None:
            ax_aia.imshow(aia_image, origin='lower', aspect='equal', alpha=0.9)
        else:
            # Black placeholder keeps the axes extent stable when no image exists.
            ax_aia.imshow(np.zeros((img_size, img_size, 3)), origin='lower', aspect='equal')

        ax_aia.set_title(f'{current_time.strftime("%Y-%m-%d %H:%M:%S")}', fontsize=11)
        ax_aia.set_xlabel('X (pixels)', fontsize=9)
        ax_aia.set_ylabel('Y (pixels)', fontsize=9)

        # ── Region contours + FOXES markers ──────────────────────────────────
        region_labels = region_labels_cache.get(timestamp)
        current_events = (
            flare_events_df[flare_events_df['timestamp'] == timestamp].copy()
            if not flare_events_df.empty and 'timestamp' in flare_events_df.columns
            else pd.DataFrame()
        )
        plotted_tracks: set = set()

        for _, ev in current_events.iterrows():
            tid = ev['track_id']
            if tid in plotted_tracks:
                continue
            plotted_tracks.add(tid)

            cx, cy = ev.get('centroid_img_x'), ev.get('centroid_img_y')
            if pd.isna(cx) or pd.isna(cy) or not (0 <= cx <= img_size) or not (0 <= cy <= img_size):
                continue

            px = ev.get('peak_img_x') if pd.notna(ev.get('peak_img_x')) else cx
            py = ev.get('peak_img_y') if pd.notna(ev.get('peak_img_y')) else cy
            color = track_color_map.get(tid, _TRACK_COLORS[0])
            cur_flux = ev.get('sum_flux', 0.0)
            is_active = cur_flux >= config.min_flux_threshold

            # Contour
            rl = ev.get('region_label')
            if region_labels is not None and pd.notna(rl) and int(rl) > 0:
                region_mask = region_labels == int(rl)
                if np.any(region_mask):
                    try:
                        # Upsample the patch-grid mask towards image
                        # resolution for crisp contours on the AIA image.
                        scale = max(1, img_size // region_labels.shape[0])
                        mask_up = region_mask.repeat(scale, axis=0).repeat(scale, axis=1).astype(float)
                        ax_aia.contour(mask_up, levels=[0.5],
                                       colors=color, linewidths=4.0 if is_active else 2.5,
                                       alpha=0.9, extent=[0, img_size, 0, img_size])
                    except Exception:
                        # Best effort: a failed contour must not drop the frame.
                        pass

            # Marker
            if is_active:
                ax_aia.plot(px, py, '*', markersize=15, color=color,
                            markeredgecolor='black', markeredgewidth=2, alpha=0.7, zorder=15)
                ax_aia.annotate(f'FOXES: {flux_to_goes_class(cur_flux)}', (px, py),
                                xytext=(15, 15), textcoords='offset points', fontsize=11,
                                color='black', weight='bold',
                                bbox=dict(boxstyle='round,pad=0.3', facecolor=color,
                                          alpha=0.95, edgecolor='black', linewidth=2))
            else:
                ax_aia.plot(px, py, 'o', markersize=10, color=color,
                            markeredgecolor='white', markeredgewidth=1.5, alpha=0.8, zorder=12)

        # ── SXR timeseries ───────────────────────────────────────────────────
        if predictions_df is not None and not predictions_df.empty:
            in_win = predictions_df[
                (predictions_df['datetime'] >= window_start) &
                (predictions_df['datetime'] <= window_end)
            ]
            if not in_win.empty:
                if 'groundtruth' in in_win.columns:
                    ax_sxr.plot(in_win['datetime'], in_win['groundtruth'],
                                'b-', linewidth=1.5, alpha=0.8, label='GOES (Truth)')
                if 'predictions' in in_win.columns:
                    ax_sxr.plot(in_win['datetime'], in_win['predictions'],
                                'r--', linewidth=1.5, alpha=0.8, label='FOXES')

        # Track fluxes
        all_tracks_in_win = (
            flare_events_df[
                (flare_events_df['datetime'] >= window_start) &
                (flare_events_df['datetime'] <= window_end)
            ] if not flare_events_df.empty else pd.DataFrame()
        )
        first_other = True
        for tid, tdata in (all_tracks_in_win.groupby('track_id') if not all_tracks_in_win.empty else []):
            tdata = tdata.sort_values('datetime')
            color = track_color_map.get(tid, _TRACK_COLORS[0])
            is_active = tdata['sum_flux'].max() >= config.min_flux_threshold
            if is_active:
                ax_sxr.plot(tdata['datetime'], tdata['sum_flux'],
                            color=color, linewidth=2.5, alpha=0.9, label=f'Track {tid}', zorder=4)
            else:
                # Only the first inactive track gets a legend entry.
                label = 'Other tracks' if first_other else None
                ax_sxr.plot(tdata['datetime'], tdata['sum_flux'],
                            color=color, linewidth=0.9, alpha=0.35, label=label, zorder=3)
                first_other = False

        ax_sxr.axvline(current_time, color='#E5446D', linewidth=2, alpha=0.8, zorder=10)
        ax_sxr.set_xlim(window_start, window_end)
        ax_sxr.set_yscale('log')
        ax_sxr.set_ylabel('Flux (W/m²)', fontsize=9)
        ax_sxr.set_xlabel('Time (UTC)', fontsize=9)
        ax_sxr.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
        ax_sxr.xaxis.set_major_locator(mdates.HourLocator(interval=1))
        plt.setp(ax_sxr.xaxis.get_majorticklabels(), rotation=0)
        # Skip the legend when nothing was labelled (avoids a warning per frame).
        handles, _ = ax_sxr.get_legend_handles_labels()
        if handles:
            ax_sxr.legend(loc='lower right', fontsize=8, framealpha=1)
        ax_sxr.grid(True, alpha=0.3)

        fig.tight_layout()

        fmt = getattr(config, 'movie_frame_format', 'jpg').lower()
        dpi = getattr(config, 'movie_dpi', 75.0)
        ext = 'jpg' if fmt in ('jpg', 'jpeg') else 'png'
        frame_path = frames_dir / f"frame_{frame_idx:06d}.{ext}"
        fig.savefig(frame_path, dpi=dpi, format=ext)
        return str(frame_path)

    except Exception as e:
        print(f"Error creating frame {frame_idx} ({timestamp}): {e}")
        return None
    finally:
        # Close exactly this figure; closing via global pyplot state would
        # also close unrelated figures when frames are rendered in-process.
        if fig is not None:
            plt.close(fig)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
# =============================================================================
|
| 759 |
+
# Movie Assembly
|
| 760 |
+
# =============================================================================
|
| 761 |
+
|
| 762 |
+
def create_flare_movie(
    flare_events_df: pd.DataFrame,
    output_dir: Path,
    config: FlareAnalysisConfig,
    predictions_csv: Optional[str] = None,
    analyzer: Optional[FluxContributionAnalyzer] = None,
    fps: float = 2.0,
    frame_interval_minutes: float = 1.0,
    num_workers: int = 4,
) -> Optional[str]:
    """Generate per-timestamp frames and stitch them into an MP4.

    Args:
        flare_events_df: Events table; must provide ``timestamp``,
            ``datetime`` and ``track_id`` columns.
        output_dir: Run directory; output goes to ``<output_dir>/movies``.
        config: Analysis configuration (``plot_window_hours`` and
            ``aia_path`` are read here; the frame renderer reads more).
        predictions_csv: Optional predictions CSV for the time-series panel.
        analyzer: Analyzer whose ``region_labels_cache`` supplies contours.
        fps: Frame rate of the movie.
        frame_interval_minutes: Minimum spacing between rendered frames.
        num_workers: Worker processes for rendering; ``<= 1`` renders
            in-process.

    Returns:
        Path of the movie as a string, or ``None`` when there was no data or
        no frame could be rendered.
    """
    setup_barlow_font()

    if flare_events_df.empty:
        print("No flare data — skipping movie.")
        return None

    output_dir = Path(output_dir)
    movie_dir = output_dir / "movies"
    movie_dir.mkdir(parents=True, exist_ok=True)

    # Load predictions for timeseries
    predictions_df = None
    if predictions_csv and Path(predictions_csv).exists():
        predictions_df = pd.read_csv(predictions_csv)
        dt_col = 'datetime' if 'datetime' in predictions_df.columns else 'timestamp'
        predictions_df['datetime'] = pd.to_datetime(predictions_df[dt_col])

    flare_events_df = flare_events_df.copy()
    flare_events_df['datetime'] = pd.to_datetime(flare_events_df['datetime'])

    all_timestamps = sorted(flare_events_df['timestamp'].unique())

    # Subsample by frame_interval_minutes
    timestamps_to_use, last_dt = [], None
    for ts in all_timestamps:
        dt = pd.to_datetime(ts)
        if last_dt is None or (dt - last_dt).total_seconds() >= frame_interval_minutes * 60:
            timestamps_to_use.append(ts)
            last_dt = dt

    print(f"Creating movie: {len(timestamps_to_use)} frames @ {fps} fps")

    # Assign consistent colours to tracks
    unique_tracks = flare_events_df['track_id'].unique()
    track_color_map = {tid: _TRACK_COLORS[i % len(_TRACK_COLORS)] for i, tid in enumerate(unique_tracks)}

    frames_dir = movie_dir / "frames_temp"
    frames_dir.mkdir(exist_ok=True)

    labels_cache = analyzer.region_labels_cache if analyzer else {}
    shared = {
        'flare_events_df': flare_events_df,
        'predictions_df': predictions_df,
        'frames_dir': str(frames_dir),
        'config': config,
        'track_color_map': track_color_map,
        'plot_window_hours': config.plot_window_hours,
        'aia_path': config.aia_path,
    }
    # Each frame needs only its own label map; shipping the whole cache with
    # every task would pickle every map once per frame.
    frame_args = [
        (i, ts, {**shared,
                 'region_labels_cache': {ts: labels_cache[ts]} if ts in labels_cache else {}})
        for i, ts in enumerate(timestamps_to_use)
    ]

    if num_workers > 1:
        with Pool(processes=num_workers) as pool:
            results = list(tqdm(pool.imap(_generate_single_frame, frame_args),
                                total=len(frame_args), desc="Generating frames"))
    else:
        results = [_generate_single_frame(a) for a in tqdm(frame_args, desc="Generating frames")]

    frame_paths = sorted(
        (Path(p) for p in results if p is not None),
        key=lambda p: p.name
    )

    if not frame_paths:
        print("No frames generated.")
        return None

    # Stitch into MP4 via imageio (reads frames as RGB → passes to ffmpeg correctly)
    datetimes = [pd.to_datetime(ts) for ts in timestamps_to_use]
    movie_name = (f"flare_movie_{datetimes[0].strftime('%Y%m%d')}"
                  f"_{datetimes[-1].strftime('%Y%m%d')}.mp4")
    movie_path = movie_dir / movie_name

    # Read first frame to get dimensions
    first_frame = imageio.imread(str(frame_paths[0]))
    h, w = first_frame.shape[:2]
    # yuv420p requires even dimensions
    w = w if w % 2 == 0 else w - 1
    h = h if h % 2 == 0 else h - 1

    t0 = time.time()
    writer = imageio_ffmpeg.write_frames(
        str(movie_path),
        size=(w, h),
        fps=fps,
        codec='libx264',
        pix_fmt_in='rgb24',
        pix_fmt_out='yuv420p',
        output_params=['-preset', 'veryfast', '-crf', '25', '-movflags', '+faststart'],
    )
    writer.send(None)  # initialise
    try:
        for fp in tqdm(frame_paths, desc="Writing movie"):
            if fp.exists():
                frame = imageio.imread(str(fp))
                if frame.ndim == 2:
                    # Greyscale frame: replicate to three channels.
                    frame = np.stack([frame] * 3, axis=-1)
                # ffmpeg is told ``rgb24``; PNG frames carry an alpha channel
                # that must be dropped or every row would be misaligned.
                writer.send(frame[:h, :w, :3].tobytes())
    finally:
        # Always terminate the ffmpeg subprocess, even on a failed frame.
        writer.close()

    print(f"Movie saved → {movie_path}  ({time.time() - t0:.1f}s)")
    print(f"Frames kept → {frames_dir}")

    return str(movie_path)
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
# =============================================================================
|
| 878 |
+
# Entry point
|
| 879 |
+
# =============================================================================
|
| 880 |
+
|
| 881 |
+
def main() -> None:
    """Command-line entry point for the flux-map flare analysis.

    Parses the required ``--config`` argument, loads it through
    ``FlareAnalysisConfig.from_yaml``, and creates a timestamped
    ``run_<YYYYmmdd_HHMMSS>`` directory under ``config.output_dir`` (the
    current directory is used when that value is falsy). Detected events are
    written to ``flare_events.csv`` in that directory when the returned frame
    is non-empty, and ``create_flare_movie`` is called when
    ``config.create_movie`` is truthy.
    """
    cli = argparse.ArgumentParser(description="Flare Analysis — Frame & Movie Generator")
    cli.add_argument("--config", required=True, help="Path to YAML config file")
    config = FlareAnalysisConfig.from_yaml(cli.parse_args().config)

    # Each invocation writes into its own timestamped sub-directory.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_dir = Path(config.output_dir or '.') / f"run_{stamp}"
    out_dir.mkdir(parents=True, exist_ok=True)
    print(f"Output: {out_dir}")

    analyzer = FluxContributionAnalyzer(config, output_dir=out_dir)
    flare_events_df = analyzer.detect_flare_events()

    # No CSV is written when detection produced an empty frame.
    if not flare_events_df.empty:
        events_csv = out_dir / "flare_events.csv"
        flare_events_df.to_csv(events_csv, index=False)
        print(f"Saved {len(flare_events_df)} events → {events_csv}")

    if config.create_movie:
        create_flare_movie(
            flare_events_df=flare_events_df,
            output_dir=out_dir,
            config=config,
            predictions_csv=config.predictions_csv,
            analyzer=analyzer,
            fps=config.movie_fps,
            frame_interval_minutes=config.movie_frame_interval_minutes,
            num_workers=config.movie_num_workers,
        )

    print(f"\nDone. Results in {out_dir}")
|
| 913 |
+
|
| 914 |
+
|
| 915 |
+
if __name__ == "__main__":
|
| 916 |
+
main()
|
analysis/flux_map_config.yaml
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# =============================================================================
|
| 2 |
+
# Flare Analysis Configuration
|
| 3 |
+
# =============================================================================
|
| 4 |
+
# Usage: python analysis/flux_map_analysis.py --config analysis/flux_map_config.yaml
|
| 5 |
+
# =============================================================================
|
| 6 |
+
|
| 7 |
+
# -----------------------------------------------------------------------------
|
| 8 |
+
# Paths
|
| 9 |
+
# -----------------------------------------------------------------------------
|
| 10 |
+
paths:
|
| 11 |
+
flux_path: "/Volumes/T9/FOXES_Data/flux"
|
| 12 |
+
aia_path: "/Volumes/T9/FOXES_Data/AIA_processed"
|
| 13 |
+
predictions_csv: "/Volumes/T9/FOXES_Misc/batch_results/vit/vit_predictions_test.csv"
|
| 14 |
+
output_dir: "/Volumes/T9/FOXES_Data/flare_analysis"
|
| 15 |
+
|
| 16 |
+
# -----------------------------------------------------------------------------
|
| 17 |
+
# Time Range (null = use full predictions CSV range)
|
| 18 |
+
# -----------------------------------------------------------------------------
|
| 19 |
+
time_range:
|
| 20 |
+
start: "2024-03-23T00:00:00" # e.g. "2024-03-23T00:00:00"
|
| 21 |
+
end: "2024-03-23T03:00:00"
|
| 22 |
+
|
| 23 |
+
# -----------------------------------------------------------------------------
|
| 24 |
+
# Detection
|
| 25 |
+
# -----------------------------------------------------------------------------
|
| 26 |
+
detection:
|
| 27 |
+
min_flux_threshold: 1.0e-7 # W/m² — patches below this are ignored
|
| 28 |
+
threshold_std_multiplier: 4.0 # flux mask: mean + N*std
|
| 29 |
+
spatial_smoothing_sigma: 1.0 # Gaussian pre-smoothing (patches, 0 = off)
|
| 30 |
+
radial_expansion_threshold_percentile: 30.0 # flood-fill growth cutoff percentile
|
| 31 |
+
peak_neighborhood_sizes: [10, 15, 20, 25] # multi-scale local-max windows
|
| 32 |
+
peak_min_scale_agreement: 2 # peaks must appear at N scales
|
| 33 |
+
peak_scale_tolerance: 10 # patch-distance to count as "same peak"
|
| 34 |
+
min_peak_distance: 5 # min patches between distinct peaks
|
| 35 |
+
|
| 36 |
+
# -----------------------------------------------------------------------------
|
| 37 |
+
# Grid / Patch Parameters
|
| 38 |
+
# -----------------------------------------------------------------------------
|
| 39 |
+
grid:
|
| 40 |
+
grid_size: [64, 64]
|
| 41 |
+
patch_size: 8
|
| 42 |
+
input_size: 512
|
| 43 |
+
|
| 44 |
+
# -----------------------------------------------------------------------------
|
| 45 |
+
# Region Tracking
|
| 46 |
+
# -----------------------------------------------------------------------------
|
| 47 |
+
tracking:
|
| 48 |
+
max_tracking_distance: 10 # max patch-distance between frames to link regions
|
| 49 |
+
flux_ratio_weight: 0 # weight of flux-ratio term in linking score
|
| 50 |
+
size_ratio_weight: 0 # weight of size-ratio term in linking score
|
| 51 |
+
distance_weight: 1.0 # weight of spatial distance in linking score
|
| 52 |
+
age_bonus_weight: 2.0 # bias toward established tracks (scales 1/(1+age))
|
| 53 |
+
cadence_seconds: 60.0 # expected data cadence
|
| 54 |
+
max_gap_frames: 15 # frames a track can go undetected before expiring
|
| 55 |
+
|
| 56 |
+
# -----------------------------------------------------------------------------
|
| 57 |
+
# Movie / Output
|
| 58 |
+
# -----------------------------------------------------------------------------
|
| 59 |
+
movie:
|
| 60 |
+
create_movie: true
|
| 61 |
+
plot_window_hours: 2.0 # SXR plot time window around current frame
|
| 62 |
+
fps: 30.0
|
| 63 |
+
frame_interval_minutes: 1.0 # one frame per minute of data
|
| 64 |
+
num_workers: 8
|
| 65 |
+
dpi: 75.0
|
| 66 |
+
frame_format: "jpg" # "jpg" (fast) or "png" (quality)
|
| 67 |
+
jpeg_quality: 90
|
analysis/spatial_performance.py
CHANGED
|
@@ -46,7 +46,7 @@ CROP_FACTOR = 1.1 # AIA images cropped at 1.1 solar radii
|
|
| 46 |
SOLAR_RADIUS_PATCHES = (GRID_SIZE / 2) / CROP_FACTOR # ≈ 29.1 patches
|
| 47 |
|
| 48 |
# Patches beyond ±PATCH_CROP_RADIUS from center (in original 64×64 patch units) are masked.
|
| 49 |
-
PATCH_CROP_RADIUS =
|
| 50 |
|
| 51 |
# Percentile cap for colorbar scaling (applied to non-NaN values).
|
| 52 |
# e.g. 99 clips the top 1% of values so detail in the bulk is visible.
|
|
|
|
| 46 |
SOLAR_RADIUS_PATCHES = (GRID_SIZE / 2) / CROP_FACTOR # ≈ 29.1 patches
|
| 47 |
|
| 48 |
# Patches beyond ±PATCH_CROP_RADIUS from center (in original 64×64 patch units) are masked.
|
| 49 |
+
PATCH_CROP_RADIUS = 24
|
| 50 |
|
| 51 |
# Percentile cap for colorbar scaling (applied to non-NaN values).
|
| 52 |
# e.g. 99 clips the top 1% of values so detail in the bulk is visible.
|
forecasting/inference/flare_analysis.py
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
forecasting/inference/flare_analysis_poster.py
DELETED
|
The diff for this file is too large to render.
See raw diff
|
|
|
pipeline_config.yaml
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
#
|
| 6 |
# Usage:
|
| 7 |
# python run_pipeline.py --config pipeline_config.yaml --steps all
|
| 8 |
-
# python run_pipeline.py --config pipeline_config.yaml --steps train,inference,
|
| 9 |
# python run_pipeline.py --list
|
| 10 |
#
|
| 11 |
# Variables
|
|
@@ -152,6 +152,23 @@ spatial_performance:
|
|
| 152 |
predictions_csv: "${base_dir}/inference/predictions.csv"
|
| 153 |
out_dir: "${base_dir}/inference/spatial_performance"
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
# -----------------------------------------------------------------------------
|
| 156 |
# Evaluation (step: evaluate)
|
| 157 |
# -----------------------------------------------------------------------------
|
|
|
|
| 5 |
#
|
| 6 |
# Usage:
|
| 7 |
# python run_pipeline.py --config pipeline_config.yaml --steps all
|
| 8 |
+
# python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flux_map_analysis
|
| 9 |
# python run_pipeline.py --list
|
| 10 |
#
|
| 11 |
# Variables
|
|
|
|
| 152 |
predictions_csv: "${base_dir}/inference/predictions.csv"
|
| 153 |
out_dir: "${base_dir}/inference/spatial_performance"
|
| 154 |
|
| 155 |
+
# -----------------------------------------------------------------------------
|
| 156 |
+
# Flux map analysis (step: flux_map_analysis)
|
| 157 |
+
# Detects and tracks active regions from per-patch flux contribution maps,
|
| 158 |
+
# renders side-by-side AIA + SXR frames, and stitches them into a movie.
|
| 159 |
+
# -----------------------------------------------------------------------------
|
| 160 |
+
flux_map_analysis:
|
| 161 |
+
config: "analysis/flux_map_config.yaml"
|
| 162 |
+
# overrides: # uncomment to override config values
|
| 163 |
+
# paths:
|
| 164 |
+
# flux_path: "${base_dir}/flux"
|
| 165 |
+
# aia_path: "${base_dir}/AIA_processed"
|
| 166 |
+
# predictions_csv: "${base_dir}/inference/predictions.csv"
|
| 167 |
+
# output_dir: "${base_dir}/inference/flux_map_analysis"
|
| 168 |
+
# time_range:
|
| 169 |
+
# start: null # null = full predictions CSV range
|
| 170 |
+
# end: null
|
| 171 |
+
|
| 172 |
# -----------------------------------------------------------------------------
|
| 173 |
# Evaluation (step: evaluate)
|
| 174 |
# -----------------------------------------------------------------------------
|
run_pipeline.py
CHANGED
|
@@ -13,12 +13,12 @@ Runs any combination of pipeline steps in order:
|
|
| 13 |
6. normalize - Compute SXR normalization stats on train split (data/sxr_normalization.py)
|
| 14 |
7. train - Train the ViTLocal forecasting model (forecasting/training/train.py)
|
| 15 |
8. inference - Run batch inference on val/test data (forecasting/inference/inference.py)
|
| 16 |
-
9.
|
| 17 |
|
| 18 |
Usage:
|
| 19 |
python run_pipeline.py --list
|
| 20 |
python run_pipeline.py --config pipeline_config.yaml --steps all
|
| 21 |
-
python run_pipeline.py --config pipeline_config.yaml --steps train,inference
|
| 22 |
"""
|
| 23 |
|
| 24 |
import argparse
|
|
@@ -116,9 +116,9 @@ STEP_ORDER = [
|
|
| 116 |
"train",
|
| 117 |
"inference",
|
| 118 |
"evaluate",
|
| 119 |
-
"flare_analysis",
|
| 120 |
"ablation",
|
| 121 |
"spatial_performance",
|
|
|
|
| 122 |
]
|
| 123 |
|
| 124 |
STEP_INFO = {
|
|
@@ -166,10 +166,6 @@ STEP_INFO = {
|
|
| 166 |
"description": "Compute metrics and generate evaluation plots from predictions CSV",
|
| 167 |
"script": ROOT / "forecasting" / "inference" / "evaluation.py",
|
| 168 |
},
|
| 169 |
-
"flare_analysis": {
|
| 170 |
-
"description": "Detect, track, and match flares; generate plots/movies",
|
| 171 |
-
"script": ROOT / "forecasting" / "inference" / "flare_analysis.py",
|
| 172 |
-
},
|
| 173 |
"ablation": {
|
| 174 |
"description": "Run Gaussian noise channel-masking ablation study on pretrained model",
|
| 175 |
"script": ROOT / "forecasting" / "inference" / "ablation_inference.py",
|
|
@@ -178,6 +174,10 @@ STEP_INFO = {
|
|
| 178 |
"description": "Generate flux-weighted spatial error heatmap on the solar disk",
|
| 179 |
"script": ROOT / "analysis" / "spatial_performance.py",
|
| 180 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
}
|
| 182 |
|
| 183 |
|
|
@@ -329,15 +329,6 @@ def build_commands(step: str, cfg: dict, force: bool) -> list[list[str]] | None:
|
|
| 329 |
config_path = str(write_merged_config(config_path, ev["overrides"], "evaluate_config"))
|
| 330 |
return [base + ["-config", config_path]]
|
| 331 |
|
| 332 |
-
if step == "flare_analysis":
|
| 333 |
-
if not require(["config"], "inference"):
|
| 334 |
-
return None
|
| 335 |
-
inf = cfg["inference"]
|
| 336 |
-
config_path = inf["config"]
|
| 337 |
-
if inf.get("overrides"):
|
| 338 |
-
config_path = str(write_merged_config(config_path, inf["overrides"], "inference_config"))
|
| 339 |
-
return [base + ["--config", config_path]]
|
| 340 |
-
|
| 341 |
if step == "ablation":
|
| 342 |
if not require(["config"], "ablation"):
|
| 343 |
return None
|
|
@@ -358,6 +349,15 @@ def build_commands(step: str, cfg: dict, force: bool) -> list[list[str]] | None:
|
|
| 358 |
cmd += ["--out_dir", sp["out_dir"]]
|
| 359 |
return [cmd]
|
| 360 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
return [base]
|
| 362 |
|
| 363 |
|
|
@@ -398,7 +398,7 @@ def list_steps():
|
|
| 398 |
print(f" {i}. {step:<16} {STEP_INFO[step]['description']}")
|
| 399 |
print()
|
| 400 |
print("Use --steps all to run every step, or comma-separate specific steps.")
|
| 401 |
-
print("Example: --steps train,inference
|
| 402 |
|
| 403 |
|
| 404 |
def main():
|
|
|
|
| 13 |
6. normalize - Compute SXR normalization stats on train split (data/sxr_normalization.py)
|
| 14 |
7. train - Train the ViTLocal forecasting model (forecasting/training/train.py)
|
| 15 |
8. inference - Run batch inference on val/test data (forecasting/inference/inference.py)
|
| 16 |
+
9. flux_map_analysis - Detect and track active regions from flux maps; render per-frame movie (analysis/flux_map_analysis.py)
|
| 17 |
|
| 18 |
Usage:
|
| 19 |
python run_pipeline.py --list
|
| 20 |
python run_pipeline.py --config pipeline_config.yaml --steps all
|
| 21 |
+
python run_pipeline.py --config pipeline_config.yaml --steps train,inference
|
| 22 |
"""
|
| 23 |
|
| 24 |
import argparse
|
|
|
|
| 116 |
"train",
|
| 117 |
"inference",
|
| 118 |
"evaluate",
|
|
|
|
| 119 |
"ablation",
|
| 120 |
"spatial_performance",
|
| 121 |
+
"flux_map_analysis",
|
| 122 |
]
|
| 123 |
|
| 124 |
STEP_INFO = {
|
|
|
|
| 166 |
"description": "Compute metrics and generate evaluation plots from predictions CSV",
|
| 167 |
"script": ROOT / "forecasting" / "inference" / "evaluation.py",
|
| 168 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
"ablation": {
|
| 170 |
"description": "Run Gaussian noise channel-masking ablation study on pretrained model",
|
| 171 |
"script": ROOT / "forecasting" / "inference" / "ablation_inference.py",
|
|
|
|
| 174 |
"description": "Generate flux-weighted spatial error heatmap on the solar disk",
|
| 175 |
"script": ROOT / "analysis" / "spatial_performance.py",
|
| 176 |
},
|
| 177 |
+
"flux_map_analysis": {
|
| 178 |
+
"description": "Detect and track active regions from flux maps; render per-frame movie",
|
| 179 |
+
"script": ROOT / "analysis" / "flux_map_analysis.py",
|
| 180 |
+
},
|
| 181 |
}
|
| 182 |
|
| 183 |
|
|
|
|
| 329 |
config_path = str(write_merged_config(config_path, ev["overrides"], "evaluate_config"))
|
| 330 |
return [base + ["-config", config_path]]
|
| 331 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 332 |
if step == "ablation":
|
| 333 |
if not require(["config"], "ablation"):
|
| 334 |
return None
|
|
|
|
| 349 |
cmd += ["--out_dir", sp["out_dir"]]
|
| 350 |
return [cmd]
|
| 351 |
|
| 352 |
+
if step == "flux_map_analysis":
|
| 353 |
+
if not require(["config"], "flux_map_analysis"):
|
| 354 |
+
return None
|
| 355 |
+
fma = cfg["flux_map_analysis"]
|
| 356 |
+
config_path = fma["config"]
|
| 357 |
+
if fma.get("overrides"):
|
| 358 |
+
config_path = str(write_merged_config(config_path, fma["overrides"], "flux_map_analysis_config"))
|
| 359 |
+
return [base + ["--config", config_path]]
|
| 360 |
+
|
| 361 |
return [base]
|
| 362 |
|
| 363 |
|
|
|
|
| 398 |
print(f" {i}. {step:<16} {STEP_INFO[step]['description']}")
|
| 399 |
print()
|
| 400 |
print("Use --steps all to run every step, or comma-separate specific steps.")
|
| 401 |
+
print("Example: --steps train,inference\n")
|
| 402 |
|
| 403 |
|
| 404 |
def main():
|