griffingoodwin04 commited on
Commit
c05a501
·
1 Parent(s): 4c7c808

add flux map analysis script and configuration for detecting and tracking active regions

Browse files
README.md CHANGED
@@ -44,6 +44,11 @@ The solar soft X-ray (SXR) irradiance is a long-standing proxy of solar activity
44
 
45
  ```text
46
  FOXES
 
 
 
 
 
47
  ├── data # Data cleaning and preprocessing
48
  │ ├── align_data.py # Align AIA and SXR timestamps; save matched pairs
49
  │ ├── euv_data_cleaning.py # EUV image quality filtering and cleaning
@@ -67,8 +72,7 @@ FOXES
67
  │ ├── inference
68
  │ │ ├── inference.py # Batch inference; writes predictions.csv
69
  │ │ ├── evaluation.py # Compute metrics and generate evaluation plots
70
- │ │ ├── flare_analysis.py # Detect, track, and match flares; generate plots
71
- │ │ ├── local_config.yaml # Config for inference.py and flare_analysis.py
72
  │ │ └── evaluation_config.yaml # Config for evaluation.py
73
  │ ├── models
74
  │ │ └── vit_patch_model_local.py # ViTLocal: Vision Transformer with patch flux heads
@@ -126,7 +130,9 @@ FOXES uses a single orchestrator script (`run_pipeline.py`) and a top-level conf
126
  | 7 | `train` | Train the ViTLocal solar flare forecasting model |
127
  | 8 | `inference` | Run batch inference and save a predictions CSV |
128
  | 9 | `evaluate` | Compute metrics and generate evaluation plots |
129
- | 10 | `flare_analysis` | Detect, track, and match flares; generate plots/movies |
 
 
130
 
131
  ### Usage
132
 
@@ -245,6 +251,7 @@ Steps can also be run individually by calling their scripts directly:
245
  python forecasting/training/train.py -config forecasting/training/train_config.yaml
246
  python forecasting/inference/inference.py -config forecasting/inference/local_config.yaml
247
  python forecasting/inference/evaluation.py -config forecasting/inference/evaluation_config.yaml
 
248
  ```
249
 
250
  ---
 
44
 
45
  ```text
46
  FOXES
47
+ ├── analysis # Post-inference analysis scripts
48
+ │ ├── flux_map_analysis.py # Detect, track, and visualize active regions from flux maps
49
+ │ ├── flux_map_config.yaml # Config for flux_map_analysis.py
50
+ │ ├── spatial_performance.py # Flux-weighted spatial error heatmap on the solar disk
51
+ │ └── ablation_analysis.py # Ablation study visualization
52
  ├── data # Data cleaning and preprocessing
53
  │ ├── align_data.py # Align AIA and SXR timestamps; save matched pairs
54
  │ ├── euv_data_cleaning.py # EUV image quality filtering and cleaning
 
72
  │ ├── inference
73
  │ │ ├── inference.py # Batch inference; writes predictions.csv
74
  │ │ ├── evaluation.py # Compute metrics and generate evaluation plots
75
+ │ │ ├── local_config.yaml # Config for inference.py
 
76
  │ │ └── evaluation_config.yaml # Config for evaluation.py
77
  │ ├── models
78
  │ │ └── vit_patch_model_local.py # ViTLocal: Vision Transformer with patch flux heads
 
130
  | 7 | `train` | Train the ViTLocal solar flare forecasting model |
131
  | 8 | `inference` | Run batch inference and save a predictions CSV |
132
  | 9 | `evaluate` | Compute metrics and generate evaluation plots |
133
+ | 10 | `ablation` | Run channel-masking ablation study on a pretrained model |
134
+ | 11 | `spatial_performance` | Generate flux-weighted spatial error heatmap on the solar disk |
135
+ | 12 | `flux_map_analysis` | Detect and track active regions from flux maps; render frames and a movie |
136
 
137
  ### Usage
138
 
 
251
  python forecasting/training/train.py -config forecasting/training/train_config.yaml
252
  python forecasting/inference/inference.py -config forecasting/inference/local_config.yaml
253
  python forecasting/inference/evaluation.py -config forecasting/inference/evaluation_config.yaml
254
+ python analysis/flux_map_analysis.py --config analysis/flux_map_config.yaml
255
  ```
256
 
257
  ---
analysis/ablation_analysis.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import matplotlib.ticker as mticker
5
+ from matplotlib import rcParams
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ PROJECT_ROOT = Path(__file__).parent.parent
10
+ sys.path.insert(0, str(PROJECT_ROOT))
11
+ from forecasting.inference.evaluation import setup_barlow_font
12
+
13
+ setup_barlow_font()
14
+
15
+ DATA_DIR = "/Users/griffingoodwin/Documents/gitrepos/FOXES/Untracked/data"
16
+ WAVELENGTHS = ["94", "131", "171", "193", "211", "304", "335", "stereo", "all"]
17
+ LABELS = ["Ablate 94 Å", "Ablate 131 Å", "Ablate 171 Å", "Ablate 193 Å",
18
+ "Ablate 211 Å", "Ablate 304 Å", "Ablate 335 Å", "Ablate STEREO", "Ablate All"]
19
+
20
+ rcParams['font.family'] = 'sans-serif'
21
+ rcParams['font.sans-serif'] = ['Barlow', 'Arial', 'DejaVu Sans']
22
+
23
+ FLARE_CLASSES = {
24
+ 'A1.0': (1e-8, 1e-7),
25
+ 'B1.0': (1e-7, 1e-6),
26
+ 'C1.0': (1e-6, 1e-5),
27
+ 'M1.0': (1e-5, 1e-4),
28
+ 'X1.0': (1e-4, 1e-3),
29
+ }
30
+
31
+ text_color = 'black'
32
+ grid_color = '#CCCCCC'
33
+
34
+ VMIN_GLOBAL = 1e-9
35
+ VMAX_GLOBAL = 1e-2
36
+
37
+
38
+ def add_flare_class_axes(ax, vmin, vmax):
39
+ def identity(x):
40
+ return x
41
+
42
+ ax_top = ax.secondary_xaxis('top', functions=(identity, identity))
43
+ ax_right = ax.secondary_yaxis('right', functions=(identity, identity))
44
+
45
+ positions, labels = [], []
46
+ for cls, (lo, hi) in FLARE_CLASSES.items():
47
+ if vmin <= lo <= vmax:
48
+ positions.append(lo)
49
+ labels.append(cls)
50
+
51
+ ax_top.set_xticks(positions)
52
+ ax_top.set_xticklabels(labels, fontsize=6, color=text_color, rotation=45, ha='left')
53
+ ax_top.grid(False)
54
+ ax_top.tick_params(length=3)
55
+
56
+ ax_right.set_yticks(positions)
57
+ ax_right.set_yticklabels(labels, fontsize=6, color=text_color)
58
+ ax_right.grid(False)
59
+ ax_right.tick_params(length=3)
60
+
61
+
62
+ fig, axes = plt.subplots(3, 3, figsize=(16, 14), layout='constrained')
63
+ axes = axes.flatten()
64
+
65
+ hb_last = None # for shared colorbar
66
+
67
+ for i, (wav, label) in enumerate(zip(WAVELENGTHS, LABELS)):
68
+ ab = pd.read_csv(f"{DATA_DIR}/ablate_{wav}_global_1.csv")
69
+ gt = ab["groundtruth"].values
70
+ pred = ab["predictions"].values
71
+
72
+ mask = (gt > 0) & (pred > 0)
73
+ gt, pred = gt[mask], pred[mask]
74
+
75
+ log_mae = np.mean(np.abs(np.log10(gt) - np.log10(pred)))
76
+
77
+ vmin = max(VMIN_GLOBAL, min(gt.min(), pred.min()))
78
+ vmax = min(VMAX_GLOBAL, max(gt.max(), pred.max()))
79
+
80
+ ax = axes[i]
81
+ ax.set_facecolor("#FFFFFF")
82
+
83
+ hb = ax.hexbin(gt, pred, gridsize=80, xscale='log', yscale='log',
84
+ cmap='bone', mincnt=1, bins='log',
85
+ extent=(np.log10(vmin), np.log10(vmax),
86
+ np.log10(vmin), np.log10(vmax)))
87
+ hb_last = hb
88
+
89
+ # 1:1 line
90
+ ax.plot([vmin, vmax], [vmin, vmax], ls='--', c='red', alpha=0.85, lw=1.2)
91
+
92
+ ax.set_xlim(vmin, vmax)
93
+ ax.set_ylim(vmin, vmax)
94
+ ax.set_xscale('log')
95
+ ax.set_yscale('log')
96
+
97
+ #ax.set_title(label, fontsize=11, fontweight='bold', color=text_color)
98
+ ax.set_xlabel(r'Ground Truth (W/m$^2$)', fontsize=8, color=text_color)
99
+ ax.set_ylabel(r'Prediction (W/m$^2$)', fontsize=8, color=text_color)
100
+ ax.tick_params(labelsize=7, colors=text_color)
101
+ ax.grid(True, alpha=0.5, color=grid_color, linewidth=0.5)
102
+ ax.set_axisbelow(True)
103
+
104
+ for lbl in ax.get_xticklabels():
105
+ lbl.set_fontfamily('Barlow')
106
+ for lbl in ax.get_yticklabels():
107
+ lbl.set_fontfamily('Barlow')
108
+
109
+ ax.text(0.04, 0.96, f"Log MAE = {log_mae:.3f}",
110
+ transform=ax.transAxes, fontsize=8, va='top', color=text_color,
111
+ bbox=dict(boxstyle='round,pad=0.3', facecolor='white',
112
+ edgecolor='#CCCCCC', alpha=0.85))
113
+
114
+ add_flare_class_axes(ax, vmin, vmax)
115
+
116
+ # Shared colorbar
117
+ cbar = fig.colorbar(hb_last, ax=axes.tolist(), orientation='vertical', shrink=0.6, pad=0.01)
118
+ cbar.set_label("Count (log)", fontsize=11, color=text_color)
119
+ cbar.ax.tick_params(labelsize=9, colors=text_color)
120
+ cbar.ax.yaxis.set_minor_locator(mticker.LogLocator(base=10, subs='auto', numticks=10))
121
+ cbar.ax.tick_params(which='minor', colors=text_color)
122
+
123
+ #fig.suptitle("Ablation Study: Channel Masking vs. Baseline", fontsize=14, fontweight='bold')
124
+ plt.savefig("/Users/griffingoodwin/Documents/gitrepos/FOXES/analysis/ablation_3x3.png", dpi=150, bbox_inches="tight")
125
+ plt.show()
126
+ print("Saved: analysis/ablation_3x3.png")
analysis/ablation_lollipop.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import matplotlib.patches as mpatches
6
+ import matplotlib.font_manager as fm
7
+ from matplotlib import rcParams
8
+
9
+
10
+ def setup_barlow_font():
11
+ try:
12
+ barlow_fonts = [f.name for f in fm.fontManager.ttflist
13
+ if 'barlow' in f.name.lower() or 'barlow' in f.fname.lower()]
14
+ if barlow_fonts:
15
+ rcParams['font.family'] = 'Barlow'
16
+ else:
17
+ for path in ['/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf',
18
+ '/Users/griffingoodwin/Library/Fonts/Barlow-Regular.otf']:
19
+ if os.path.exists(path):
20
+ fm.fontManager.addfont(path)
21
+ rcParams['font.family'] = 'Barlow'
22
+ break
23
+ else:
24
+ rcParams['font.family'] = 'sans-serif'
25
+ except Exception:
26
+ rcParams['font.family'] = 'sans-serif'
27
+
28
+ setup_barlow_font()
29
+
30
+ DATA_DIR = "/Users/griffingoodwin/Documents/gitrepos/FOXES/Untracked/data"
31
+ BASELINE_CSV = "/Volumes/T9/FOXES_Misc/batch_results/vit/vit_predictions_test.csv"
32
+
33
+ WAVELENGTHS = ["94", "131", "171", "193", "211", "304", "335","STEREO"]
34
+ LABELS = {
35
+ "94": "Ablate 94 Å",
36
+ "131": "Ablate 131 Å",
37
+ "171": "Ablate 171 Å",
38
+ "193": "Ablate 193 Å",
39
+ "211": "Ablate 211 Å",
40
+ "304": "Ablate 304 Å",
41
+ "335": "Ablate 335 Å",
42
+ "STEREO": "Ablate 94, 131, 335 Å\n(STEREO)",
43
+ }
44
+
45
+ FLARE_CLASSES = {
46
+ '< C': (1e-15, 1e-6),
47
+ 'C': (1e-6, 1e-5),
48
+ 'M': (1e-5, 1e-4),
49
+ 'X': (1e-4, 1e-2),
50
+ }
51
+
52
+ CLASS_COLORS = {
53
+ '< C': '#4C9BE8',
54
+ 'C': '#56C490',
55
+ 'M': '#F5A623',
56
+ 'X': '#E84C4C',
57
+ }
58
+
59
+ # ── Compute metrics ────────────────────────────────────────────────────────────
60
+ def compute_row(label, gt, pred, is_baseline=False):
61
+ mask = (gt > 0) & (pred > 0)
62
+ gt, pred = gt[mask], pred[mask]
63
+ overall = np.mean(np.abs(np.log10(gt) - np.log10(pred)))
64
+ row = {"label": label, "overall": overall, "is_baseline": is_baseline}
65
+ for cls, (lo, hi) in FLARE_CLASSES.items():
66
+ m = (gt >= lo) & (gt < hi)
67
+ row[cls] = np.mean(np.abs(np.log10(gt[m]) - np.log10(pred[m]))) if m.sum() > 5 else np.nan
68
+ return row
69
+
70
+ records = []
71
+
72
+ # Baseline
73
+ bl = pd.read_csv(BASELINE_CSV)
74
+ records.append(compute_row("FOXES (no ablation)",
75
+ bl["groundtruth"].values, bl["predictions"].values,
76
+ is_baseline=True))
77
+
78
+ for wav in WAVELENGTHS:
79
+ ab = pd.read_csv(f"{DATA_DIR}/ablate_{wav}_global_1.csv")
80
+ records.append(compute_row(LABELS[wav], ab["groundtruth"].values, ab["predictions"].values))
81
+
82
+ # Sort ablation rows by overall MAE (worst first), keep baseline pinned at bottom
83
+ ablation_df = pd.DataFrame([r for r in records if not r["is_baseline"]])
84
+ ablation_df = ablation_df.sort_values("overall", ascending=False).reset_index(drop=True)
85
+ baseline_df = pd.DataFrame([r for r in records if r["is_baseline"]])
86
+ df = pd.concat([ablation_df, baseline_df], ignore_index=True)
87
+
88
+ # ── Plot ───────────────────────────────────────────────────────────────────────
89
+ n_rows = len(df)
90
+ fig, ax = plt.subplots(figsize=(11, 0.6 * n_rows + 1.5))
91
+ #ax.set_facecolor("#FAFAFA")
92
+ fig.patch.set_facecolor("#FFFFFF")
93
+
94
+ y_positions = np.arange(n_rows)
95
+
96
+ # Separator line between ablations and baseline
97
+ ax.axhline(y=n_rows - 1.5, color="#BBBBBB", linewidth=1, linestyle=":", zorder=1)
98
+
99
+ for i, row in df.iterrows():
100
+ y = y_positions[i]
101
+ is_bl = row["is_baseline"]
102
+
103
+ # Highlight baseline row
104
+ if is_bl:
105
+ ax.axhspan(y - 0.45, y + 0.45, color="#EEF6FF", zorder=0)
106
+
107
+ # Span line across per-class range
108
+ class_vals = [row[c] for c in FLARE_CLASSES if not np.isnan(row[c])]
109
+ if class_vals:
110
+ ax.hlines(y, min(class_vals), max(class_vals),
111
+ color="#CCCCCC", linewidth=2, zorder=1)
112
+
113
+ # Stem from 0 to overall
114
+ ax.hlines(y, 0, row["overall"],
115
+ color="#AAAAAA", linewidth=1.2, linestyle="--", zorder=0, alpha=0.6)
116
+
117
+ # Per-class dots
118
+ for cls in FLARE_CLASSES:
119
+ val = row[cls]
120
+ if not np.isnan(val):
121
+ ax.scatter(val, y, color=CLASS_COLORS[cls], s=80, zorder=4,
122
+ edgecolors="white", linewidths=0.6, alpha=0.75)
123
+
124
+ # Overall dot
125
+ outline_color = "#1A6BBF" if is_bl else "black"
126
+ ax.scatter(row["overall"], y, color="white", s=190, zorder=3,
127
+ edgecolors=outline_color, linewidths=2.0 if is_bl else 1.5, alpha=0.75)
128
+ ax.scatter(row["overall"], y, color=outline_color, s=75, zorder=3,
129
+ marker="|", linewidths=1.5, alpha=0.75)
130
+
131
+ tick_colors = ["black"] * n_rows
132
+ tick_colors[-1] = "#1A6BBF" # baseline label in blue
133
+ ax.set_yticks(y_positions)
134
+ ax.set_yticklabels(df["label"], fontsize=12)
135
+ for ticklabel, color in zip(ax.get_yticklabels(), tick_colors):
136
+ ticklabel.set_color(color)
137
+ if color != "black":
138
+ ticklabel.set_fontweight("bold")
139
+ ax.set_xlabel("MAE (log$_{10}$ scale)", fontsize=12)
140
+ ax.grid(True, axis="x", alpha=0.4, color="#CCCCCC", linewidth=0.6)
141
+ ax.set_axisbelow(True)
142
+ ax.spines[["top", "right"]].set_visible(False)
143
+ ax.tick_params(axis="y", length=0, labelsize=11)
144
+ ax.tick_params(axis="x", labelsize=10)
145
+
146
+ # Legend
147
+ class_patches = [
148
+ mpatches.Patch(color=CLASS_COLORS[c], label=f"{c}-class") for c in FLARE_CLASSES
149
+ ]
150
+ overall_patch = mpatches.Patch(facecolor="white", edgecolor="black", label="Overall")
151
+ #baseline_patch = mpatches.Patch(facecolor="white", edgecolor="#1A6BBF", label="Baseline (overall)")
152
+ ax.legend(handles=class_patches + [overall_patch],
153
+ loc="upper right", fontsize=10, framealpha=0.9,
154
+ edgecolor="#CCCCCC")
155
+
156
+ # ax.set_title("Ablation Study — Log MAE by Channel & Flare Class",
157
+ # fontsize=14, fontweight="bold", pad=14)
158
+ plt.xlim(0, .85)
159
+ plt.tight_layout()
160
+ plt.savefig("ablation_lollipop.png", dpi=450, bbox_inches="tight")
161
+ plt.show()
162
+ print("Saved: analysis/ablation_lollipop.png")
analysis/flux_map_analysis.py ADDED
@@ -0,0 +1,916 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flare Analysis — Frame & Movie Generator
4
+
5
+ Detects and tracks active regions from flux contribution maps,
6
+ then renders per-timestamp frames and stitches them into a movie.
7
+
8
+ Usage:
9
+ python flux_map_analysis.py --config flux_map_config.yaml
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import os
16
+ import time
17
+ import warnings
18
+ from dataclasses import dataclass
19
+ from datetime import datetime
20
+ from heapq import heappush, heappop
21
+ from multiprocessing import Pool
22
+ from pathlib import Path
23
+ from typing import Dict, List, Optional, Tuple
24
+
25
+ import imageio.v2 as imageio
26
+ import imageio_ffmpeg
27
+ import matplotlib
28
+ matplotlib.use('Agg')
29
+ import matplotlib.dates as mdates
30
+ import matplotlib.font_manager as fm
31
+ import matplotlib.pyplot as plt
32
+ import numpy as np
33
+ import pandas as pd
34
+ import yaml
35
+ from matplotlib import rcParams
36
+ from scipy.ndimage import maximum_filter, gaussian_filter
37
+ from tqdm import tqdm
38
+
39
+ warnings.filterwarnings('ignore')
40
+
41
+
42
+ # =============================================================================
43
+ # Configuration
44
+ # =============================================================================
45
+
46
+ @dataclass
47
+ class FlareAnalysisConfig:
48
+ """Configuration for flare analysis."""
49
+
50
+ # Paths
51
+ flux_path: Optional[str] = None
52
+ aia_path: Optional[str] = None
53
+ predictions_csv: Optional[str] = None
54
+ output_dir: Optional[str] = None
55
+
56
+ # Time range
57
+ start_time: Optional[str] = None
58
+ end_time: Optional[str] = None
59
+
60
+ # Detection
61
+ min_flux_threshold: float = 1e-7
62
+ threshold_std_multiplier: float = 3.0
63
+ spatial_smoothing_sigma: float = 1.0
64
+ radial_expansion_threshold_percentile: float = 30.0
65
+ peak_neighborhood_sizes: Tuple[int, ...] = (10, 15, 20, 25)
66
+ peak_min_scale_agreement: int = 2
67
+ peak_scale_tolerance: int = 2
68
+ min_peak_distance: int = 10
69
+
70
+ # Grid
71
+ grid_size: Tuple[int, int] = (64, 64)
72
+ patch_size: int = 8
73
+ input_size: int = 512
74
+
75
+ # Tracking
76
+ max_tracking_distance: int = 8
77
+ flux_ratio_weight: float = 0.1
78
+ size_ratio_weight: float = 0.1
79
+ distance_weight: float = 1.0
80
+ age_bonus_weight: float = 1.0 # scales 1/(1+age) penalty on new tracks
81
+ cadence_seconds: float = 60.0
82
+ max_gap_frames: int = 1 # frames a track can persist without a detection before expiring
83
+
84
+ # Movie / output
85
+ create_movie: bool = False
86
+ plot_window_hours: float = 4.0
87
+ movie_fps: float = 2.0
88
+ movie_frame_interval_minutes: float = 1.0
89
+ movie_num_workers: int = 4
90
+ movie_dpi: float = 75.0
91
+ movie_frame_format: str = 'jpg'
92
+ movie_jpeg_quality: int = 90
93
+
94
+ @classmethod
95
+ def from_yaml(cls, path: str) -> "FlareAnalysisConfig":
96
+ with open(path) as f:
97
+ raw = yaml.safe_load(f) or {}
98
+
99
+ # Flatten one level of nesting
100
+ # The 'movie' section uses short keys (fps, dpi, …) — prefix them with 'movie_'
101
+ flat: Dict = {}
102
+ for key, val in raw.items():
103
+ if isinstance(val, dict):
104
+ if key == 'movie':
105
+ valid = cls.__dataclass_fields__
106
+ for k, v in val.items():
107
+ if k in valid:
108
+ flat[k] = v
109
+ elif f'movie_{k}' in valid:
110
+ flat[f'movie_{k}'] = v
111
+ else:
112
+ flat[k] = v
113
+ else:
114
+ flat.update(val)
115
+ else:
116
+ flat[key] = val
117
+
118
+ # Renamed YAML keys
119
+ if 'start' in flat:
120
+ flat['start_time'] = flat.pop('start')
121
+ if 'end' in flat:
122
+ flat['end_time'] = flat.pop('end')
123
+
124
+ # Lists → tuples for tuple-typed fields
125
+ for k in ('grid_size', 'peak_neighborhood_sizes'):
126
+ if k in flat and isinstance(flat[k], list):
127
+ flat[k] = tuple(flat[k])
128
+
129
+ valid = {f for f in cls.__dataclass_fields__}
130
+ return cls(**{k: v for k, v in flat.items() if k in valid and v is not None})
131
+
132
+
133
+ # =============================================================================
134
+ # Utilities
135
+ # =============================================================================
136
+
137
+ def flux_to_goes_class(flux: float) -> str:
138
+ """Convert physical flux (W/m²) to GOES class string."""
139
+ if not isinstance(flux, (int, float)) or np.isnan(flux) or flux <= 0:
140
+ return "N/A"
141
+ if flux >= 1e-4:
142
+ prefix, scale = "X", 1e-4
143
+ elif flux >= 1e-5:
144
+ prefix, scale = "M", 1e-5
145
+ elif flux >= 1e-6:
146
+ prefix, scale = "C", 1e-6
147
+ elif flux >= 1e-7:
148
+ prefix, scale = "B", 1e-7
149
+ else:
150
+ prefix, scale = "A", 1e-8
151
+ magnitude = min(flux / scale, 9.9)
152
+ return f"{prefix}{magnitude:.1f}" if magnitude != int(magnitude) else f"{prefix}{int(magnitude)}.0"
153
+
154
+
155
+ def setup_barlow_font() -> None:
156
+ """Register and activate the Barlow font if available."""
157
+ try:
158
+ barlow_fonts = [
159
+ (f.name, f.fname) for f in fm.fontManager.ttflist
160
+ if 'barlow' in f.name.lower()
161
+ ]
162
+ if barlow_fonts:
163
+ preferred = next((n for n, _ in barlow_fonts if n.lower() in ('barlow', 'barlow regular')), barlow_fonts[0][0])
164
+ rcParams['font.family'] = preferred
165
+ return
166
+
167
+ search_paths = [
168
+ os.path.expanduser('~/Library/Fonts/Barlow-Regular.otf'),
169
+ os.path.expanduser('~/Library/Fonts/Barlow-Regular.ttf'),
170
+ '/Library/Fonts/Barlow-Regular.otf',
171
+ '/usr/share/fonts/truetype/barlow/Barlow-Regular.ttf',
172
+ ]
173
+ for path in search_paths:
174
+ if os.path.exists(path):
175
+ fm.fontManager.addfont(path)
176
+ from matplotlib.font_manager import FontProperties
177
+ rcParams['font.family'] = FontProperties(fname=path).get_name()
178
+ return
179
+ except Exception:
180
+ pass
181
+ rcParams['font.family'] = 'sans-serif'
182
+
183
+
184
+ def load_aia_image_at_time(aia_path: Path, timestamp: str) -> Optional[np.ndarray]:
185
+ """Load AIA image as normalised RGB composite (channels 0, 1, 2 → 94, 131, 171 Å)."""
186
+ if aia_path is None or not aia_path.exists():
187
+ return None
188
+
189
+ search_dirs = [aia_path] + [aia_path / s for s in ('test', 'train', 'val') if (aia_path / s).exists()]
190
+ for d in search_dirs:
191
+ fp = d / f"{timestamp}.npy"
192
+ if fp.exists():
193
+ try:
194
+ data = np.load(fp) # (7, H, W)
195
+ if data.ndim == 3 and data.shape[0] >= 3:
196
+ rgb = np.zeros((data.shape[1], data.shape[2], 3))
197
+ for i in range(3):
198
+ ch = data[i]
199
+ r = ch.max() - ch.min()
200
+ rgb[..., i] = (ch - ch.min()) / r if r > 0 else ch
201
+ return rgb
202
+ except Exception:
203
+ continue
204
+ return None
205
+
206
+
207
+ # =============================================================================
208
+ # Region Detection & Tracking
209
+ # =============================================================================
210
+
211
+ class FluxContributionAnalyzer:
212
+ """Detects and tracks active regions from per-patch flux contribution maps."""
213
+
214
+ def __init__(self, config: FlareAnalysisConfig, output_dir: Optional[Path] = None):
215
+ self.config = config
216
+ self.flux_path = Path(config.flux_path) if config.flux_path else None
217
+ self.aia_path = Path(config.aia_path) if config.aia_path else None
218
+ self.output_dir = output_dir
219
+ self.grid_size = config.grid_size
220
+ self.patch_size = config.patch_size
221
+ self.input_size = config.input_size
222
+ self.region_labels_cache: Dict[str, np.ndarray] = {}
223
+
224
+ if config.predictions_csv:
225
+ self.predictions_df = pd.read_csv(config.predictions_csv)
226
+ self.predictions_df['datetime'] = pd.to_datetime(self.predictions_df['timestamp'])
227
+ self.predictions_df = self.predictions_df.sort_values('datetime')
228
+ if config.start_time and config.end_time:
229
+ start, end = pd.to_datetime(config.start_time), pd.to_datetime(config.end_time)
230
+ mask = (self.predictions_df['datetime'] >= start) & (self.predictions_df['datetime'] <= end)
231
+ self.predictions_df = self.predictions_df[mask].reset_index(drop=True)
232
+ print(f"Loaded {len(self.predictions_df)} predictions "
233
+ f"({self.predictions_df['datetime'].min()} → {self.predictions_df['datetime'].max()})")
234
+ else:
235
+ self.predictions_df = pd.DataFrame()
236
+
237
+ # ------------------------------------------------------------------
238
+ # Data loading
239
+ # ------------------------------------------------------------------
240
+
241
+ def load_flux_contributions(self, timestamp: str) -> Optional[np.ndarray]:
242
+ if self.flux_path is None:
243
+ return None
244
+ fp = self.flux_path / f"{timestamp}.npy"
245
+ return np.load(fp) if fp.exists() else None
246
+
247
+ # ------------------------------------------------------------------
248
+ # Peak detection
249
+ # ------------------------------------------------------------------
250
+
251
+ def _find_flux_peaks_single_scale(self, flux: np.ndarray, size: int) -> Tuple[List, List]:
252
+ valid = np.isfinite(flux) & (flux > 0)
253
+ masked = np.where(valid, flux, -np.inf)
254
+ local_max = (maximum_filter(masked, size=size) == masked) & valid
255
+ ys, xs = np.where(local_max)
256
+ coords = list(zip(ys.tolist(), xs.tolist()))
257
+ fluxes = [float(flux[y, x]) for y, x in coords]
258
+ return coords, fluxes
259
+
260
+ def _find_flux_peaks_multiscale(self, flux: np.ndarray) -> Tuple[List, List]:
261
+ cfg = self.config
262
+ registry: Dict[Tuple, dict] = {}
263
+
264
+ for size in cfg.peak_neighborhood_sizes:
265
+ coords, fluxes = self._find_flux_peaks_single_scale(flux, size)
266
+ for (y, x), fv in zip(coords, fluxes):
267
+ matched = next(
268
+ ((py, px) for (py, px) in registry
269
+ if abs(y - py) <= cfg.peak_scale_tolerance and abs(x - px) <= cfg.peak_scale_tolerance),
270
+ None
271
+ )
272
+ if matched:
273
+ e = registry[matched]
274
+ e['count'] += 1
275
+ if fv > e['best_flux']:
276
+ e['best_flux'] = fv
277
+ e['best_coord'] = (y, x)
278
+ else:
279
+ registry[(y, x)] = {'count': 1, 'best_flux': fv, 'best_coord': (y, x)}
280
+
281
+ stable = [(e['best_coord'], e['best_flux'])
282
+ for e in registry.values() if e['count'] >= cfg.peak_min_scale_agreement]
283
+ if not stable:
284
+ return [], []
285
+
286
+ stable.sort(key=lambda p: p[1], reverse=True)
287
+
288
+ coords = [p[0] for p in stable]
289
+ fluxes = [p[1] for p in stable]
290
+
291
+ if cfg.min_peak_distance > 0 and len(coords) > 1:
292
+ coords, fluxes = self._merge_close_peaks(coords, fluxes, cfg.min_peak_distance)
293
+
294
+ return coords, fluxes
295
+
296
+ def _merge_close_peaks(self, coords, fluxes, min_dist):
297
+ order = np.argsort(fluxes)[::-1]
298
+ kept = []
299
+ for i in order:
300
+ if all(np.hypot(coords[i][0] - coords[j][0], coords[i][1] - coords[j][1]) >= min_dist
301
+ for j in kept):
302
+ kept.append(i)
303
+ kept = sorted(kept)
304
+ return [coords[i] for i in kept], [fluxes[i] for i in kept]
305
+
306
+ # ------------------------------------------------------------------
307
+ # Region segmentation (radial flood-fill from peaks)
308
+ # ------------------------------------------------------------------
309
+
310
+ def _detect_regions_with_peak_clustering(
311
+ self, flux_contrib: np.ndarray, pred_data: pd.Series
312
+ ) -> Tuple[List[Dict], Optional[np.ndarray], str]:
313
+
314
+ cfg = self.config
315
+ valid = flux_contrib[np.isfinite(flux_contrib) & (flux_contrib > 0)]
316
+ if valid.size == 0:
317
+ return [], None, "no_valid_flux"
318
+
319
+ total_flux = float(flux_contrib[flux_contrib > 0].sum())
320
+ log_flux = np.log(valid)
321
+ threshold = max(
322
+ np.exp(np.median(log_flux) + cfg.threshold_std_multiplier * np.std(log_flux)),
323
+ cfg.min_flux_threshold,
324
+ )
325
+ above = int((flux_contrib > threshold).sum())
326
+ masked = np.where(flux_contrib > threshold, flux_contrib, 0.0)
327
+
328
+ if above == 0:
329
+ return [], None, f"all_below_threshold(thr={threshold:.3e} total={total_flux:.3e})"
330
+
331
+ if cfg.spatial_smoothing_sigma > 0:
332
+ masked = gaussian_filter(masked, sigma=cfg.spatial_smoothing_sigma)
333
+
334
+ peak_coords, peak_fluxes = self._find_flux_peaks_multiscale(masked)
335
+ if not peak_coords:
336
+ return [], None, f"no_peaks(thr={threshold:.3e} above={above} total={total_flux:.3e})"
337
+
338
+ # Radial flood-fill from all peaks simultaneously (Dijkstra-style)
339
+ labels = np.zeros_like(masked, dtype=np.int32)
340
+ valid_vals = masked[(masked > 0) & np.isfinite(masked)]
341
+ growth_threshold = np.percentile(valid_vals, cfg.radial_expansion_threshold_percentile) if valid_vals.size else 0
342
+
343
+ pq, counter = [], 0
344
+ for idx, ((py, px), _) in enumerate(zip(peak_coords, peak_fluxes)):
345
+ labels[py, px] = idx + 1
346
+ heappush(pq, (0.0, counter, py, px, idx + 1, py, px))
347
+ counter += 1
348
+
349
+ neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
350
+ H, W = masked.shape
351
+ while pq:
352
+ dist, _, y, x, label, py, px = heappop(pq)
353
+ for dy, dx in neighbors:
354
+ ny, nx = y + dy, x + dx
355
+ if 0 <= ny < H and 0 <= nx < W and labels[ny, nx] == 0 and masked[ny, nx] > growth_threshold:
356
+ labels[ny, nx] = label
357
+ new_dist = np.hypot(ny - py, nx - px)
358
+ heappush(pq, (new_dist, counter, ny, nx, label, py, px))
359
+ counter += 1
360
+
361
+ regions = []
362
+ skipped_below_min = 0
363
+ for lid in range(1, len(peak_coords) + 1):
364
+ mask = labels == lid
365
+ ys, xs = np.where(mask)
366
+ if ys.size == 0:
367
+ continue
368
+ fv = masked[mask]
369
+ total = float(fv.sum())
370
+ if total < cfg.min_flux_threshold:
371
+ skipped_below_min += 1
372
+ continue
373
+ cy, cx = float(ys.mean()), float(xs.mean())
374
+ peak_y, peak_x = peak_coords[lid - 1]
375
+ regions.append({
376
+ 'id': len(regions) + 1,
377
+ 'region_label': lid,
378
+ 'size': int(ys.size),
379
+ 'sum_flux': total,
380
+ 'max_flux': float(fv.max()),
381
+ 'centroid_patch_y': cy,
382
+ 'centroid_patch_x': cx,
383
+ 'centroid_img_y': cy * self.patch_size + self.patch_size // 2,
384
+ 'centroid_img_x': cx * self.patch_size + self.patch_size // 2,
385
+ 'peak_y': peak_y,
386
+ 'peak_x': peak_x,
387
+ 'peak_img_y': peak_y * self.patch_size + self.patch_size // 2,
388
+ 'peak_img_x': peak_x * self.patch_size + self.patch_size // 2,
389
+ 'peak_flux': peak_fluxes[lid - 1],
390
+ 'mask': mask,
391
+ })
392
+
393
+ n_peaks = len(peak_coords)
394
+ reason = (f"ok: {len(regions)} regions from {n_peaks} peaks"
395
+ f" thr={threshold:.3e} above={above} total={total_flux:.3e}"
396
+ + (f" skipped={skipped_below_min}_below_min_flux" if skipped_below_min else ""))
397
+ return regions, labels, reason
398
+
399
+ def _detect_regions_worker(self, timestamp: str) -> Tuple[str, Optional[List], Optional[np.ndarray], str]:
400
+ try:
401
+ flux = self.load_flux_contributions(timestamp)
402
+ if flux is None:
403
+ return timestamp, None, None, "no_flux_file"
404
+ pred = self.predictions_df[self.predictions_df['timestamp'] == timestamp]
405
+ if pred.empty:
406
+ return timestamp, None, None, "no_prediction_row"
407
+ regions, labels, reason = self._detect_regions_with_peak_clustering(flux, pred.iloc[0])
408
+ return timestamp, regions, (labels.astype(np.int16) if labels is not None else None), reason
409
+ except Exception as e:
410
+ return timestamp, None, None, f"exception: {e}"
411
+
412
+ # ------------------------------------------------------------------
413
+ # Tracking
414
+ # ------------------------------------------------------------------
415
+
416
+ def track_regions_over_time(self, timestamps: List[str]) -> Dict:
417
+ cfg = self.config
418
+ print("Detecting regions (parallel)…")
419
+ n_workers = max(1, min((os.cpu_count() or 1) - 1, len(timestamps)))
420
+
421
+ all_regions: Dict[str, List] = {}
422
+ detection_reasons: Dict[str, str] = {}
423
+ with Pool(processes=n_workers) as pool:
424
+ for ts, regions, labels, reason in tqdm(
425
+ pool.imap(self._detect_regions_worker, timestamps),
426
+ total=len(timestamps), desc="Detecting regions"
427
+ ):
428
+ detection_reasons[ts] = reason
429
+ if regions is not None:
430
+ all_regions[ts] = regions
431
+ if labels is not None:
432
+ self.region_labels_cache[ts] = labels
433
+
434
+ print("Tracking regions across time…")
435
+ print(f" max_tracking_distance={cfg.max_tracking_distance} "
436
+ f"max_gap_frames={cfg.max_gap_frames} "
437
+ f"age_bonus_weight={cfg.age_bonus_weight} "
438
+ f"distance_weight={cfg.distance_weight}")
439
+ tracks: Dict[int, List] = {}
440
+ next_id = 1
441
+ last_seen: Dict[int, int] = {} # track_id → frame index when last matched
442
+ _debug_log: List[str] = [] # per-frame tracking log
443
+
444
+ for frame_idx, ts in enumerate(tqdm(timestamps, desc="Tracking")):
445
+ # Expire tracks that haven't been seen within max_gap_frames
446
+ active = {tid for tid, fi in last_seen.items()
447
+ if frame_idx - fi <= cfg.max_gap_frames}
448
+
449
+ if ts not in all_regions:
450
+ det_reason = detection_reasons.get(ts, "unknown")
451
+ _debug_log.append(f"{ts} SKIP {det_reason}")
452
+ continue
453
+
454
+ current_regions = all_regions[ts]
455
+
456
+ # Build all valid (score, region_idx, track_id) candidates
457
+ candidates = []
458
+ for ri, region in enumerate(current_regions):
459
+ cur_flux = region.get('sum_flux', 0.0)
460
+ cur_size = region.get('size', 1)
461
+ for tid in active:
462
+ history = tracks[tid]
463
+ # Smooth position over last few frames to reduce centroid jitter.
464
+ # Use PATCH coordinates so max_tracking_distance is in patch units (matching config).
465
+ n_smooth = min(5, len(history))
466
+ avg_x = np.mean([h[1]['centroid_patch_x'] for h in history[-n_smooth:]])
467
+ avg_y = np.mean([h[1]['centroid_patch_y'] for h in history[-n_smooth:]])
468
+ dist = np.hypot(
469
+ region['centroid_patch_x'] - avg_x,
470
+ region['centroid_patch_y'] - avg_y,
471
+ )
472
+ _, last = history[-1]
473
+ if dist >= cfg.max_tracking_distance:
474
+ continue
475
+ lf = last.get('sum_flux', 1e-15)
476
+ ls = last.get('size', 1)
477
+ flux_ratio = max(cur_flux, lf) / max(min(cur_flux, lf), 1e-15)
478
+ size_ratio = max(cur_size, ls) / max(min(cur_size, ls), 1)
479
+ track_age = len(tracks[tid])
480
+ # Discount grows with age: 0 (new) → age_bonus_weight (very old)
481
+ # Makes established tracks harder to beat at equal distance
482
+ age_discount = cfg.age_bonus_weight * track_age / (1.0 + track_age)
483
+ score = (cfg.distance_weight * dist
484
+ + cfg.flux_ratio_weight * flux_ratio
485
+ + cfg.size_ratio_weight * size_ratio
486
+ - age_discount)
487
+ candidates.append((score, ri, tid))
488
+
489
+ # Greedy one-to-one assignment: best scores first, each region/track used once
490
+ candidates.sort()
491
+ assigned_regions: set = set()
492
+ assigned_tracks: set = set()
493
+ assignments: Dict[int, int] = {} # region_idx → track_id
494
+ for score, ri, tid in candidates:
495
+ if ri in assigned_regions or tid in assigned_tracks:
496
+ continue
497
+ assignments[ri] = tid
498
+ assigned_regions.add(ri)
499
+ assigned_tracks.add(tid)
500
+
501
+ # Log detection outcome for this frame
502
+ det_reason = detection_reasons.get(ts, "unknown")
503
+ _debug_log.append(f"{ts} DETECT {det_reason}")
504
+
505
+ # Log active-but-unmatched tracks (gaps)
506
+ for tid in active:
507
+ if tid not in assigned_tracks:
508
+ gap = frame_idx - last_seen.get(tid, frame_idx)
509
+ cx = tracks[tid][-1][1]['centroid_patch_x']
510
+ cy = tracks[tid][-1][1]['centroid_patch_y']
511
+ _debug_log.append(
512
+ f"{ts} GAP track={tid:3d} age={len(tracks[tid]):4d} "
513
+ f"gap_frames={gap:2d} last_patch=({cx:.1f},{cy:.1f})"
514
+ )
515
+
516
+ # Apply assignments; spawn new track for unmatched regions
517
+ for ri, region in enumerate(current_regions):
518
+ r = region.copy()
519
+ r['timestamp'] = ts
520
+ if ri in assignments:
521
+ tid = assignments[ri]
522
+ r['id'] = tid
523
+ tracks[tid].append((ts, r))
524
+ cx, cy = r['centroid_patch_x'], r['centroid_patch_y']
525
+ _debug_log.append(
526
+ f"{ts} MATCH track={tid:3d} age={len(tracks[tid]):4d} "
527
+ f"patch=({cx:.1f},{cy:.1f}) flux={r.get('sum_flux', 0):.3e}"
528
+ )
529
+ else:
530
+ r['id'] = next_id
531
+ tracks[next_id] = [(ts, r)]
532
+ cx, cy = r['centroid_patch_x'], r['centroid_patch_y']
533
+ _debug_log.append(
534
+ f"{ts} NEW track={next_id:3d} age= 1 "
535
+ f"patch=({cx:.1f},{cy:.1f}) flux={r.get('sum_flux', 0):.3e}"
536
+ )
537
+ next_id += 1
538
+ tid = r['id']
539
+ last_seen[tid] = frame_idx
540
+
541
+ tracks = {k: v for k, v in tracks.items() if v}
542
+ print(f"Found {len(tracks)} region tracks across {len(timestamps)} timestamps")
543
+
544
+ if self.output_dir and _debug_log:
545
+ log_path = Path(self.output_dir) / "tracking_debug.log"
546
+ with open(log_path, 'w') as f:
547
+ f.write(f"# Tracking log — {len(tracks)} tracks, {len(timestamps)} timestamps\n")
548
+ f.write(f"# max_tracking_distance={cfg.max_tracking_distance} "
549
+ f"max_gap_frames={cfg.max_gap_frames} "
550
+ f"age_bonus_weight={cfg.age_bonus_weight}\n")
551
+ f.write("#\n# timestamp event track age detail\n")
552
+ f.write('\n'.join(_debug_log))
553
+ print(f"Tracking debug log → {log_path}")
554
+
555
+ return tracks
556
+
557
+ def detect_flare_events(self, timestamps: Optional[List[str]] = None) -> pd.DataFrame:
558
+ """Run detection + tracking and return a per-timestamp events DataFrame."""
559
+ if timestamps is None:
560
+ timestamps = self.predictions_df['timestamp'].tolist()
561
+
562
+ tracks = self.track_regions_over_time(timestamps)
563
+ rows = []
564
+ for track_id, history in tracks.items():
565
+ for ts, r in history:
566
+ pred = self.predictions_df[self.predictions_df['timestamp'] == ts]
567
+ if pred.empty:
568
+ continue
569
+ pred = pred.iloc[0]
570
+ rows.append({
571
+ 'timestamp': ts,
572
+ 'datetime': pred['datetime'],
573
+ 'prediction': pred['predictions'],
574
+ 'groundtruth': pred.get('groundtruth', None),
575
+ 'region_size': r.get('size', 0),
576
+ 'sum_flux': r.get('sum_flux', 0.0),
577
+ 'max_flux': r.get('max_flux', 0.0),
578
+ 'mean_flux': r.get('sum_flux', 0.0) / max(r.get('size', 1), 1),
579
+ 'centroid_patch_y': r.get('centroid_patch_y', 0.0),
580
+ 'centroid_patch_x': r.get('centroid_patch_x', 0.0),
581
+ 'centroid_img_y': r.get('centroid_img_y', 0.0),
582
+ 'centroid_img_x': r.get('centroid_img_x', 0.0),
583
+ 'peak_img_y': r.get('peak_img_y', None),
584
+ 'peak_img_x': r.get('peak_img_x', None),
585
+ 'region_label': r.get('region_label', None),
586
+ 'track_id': track_id,
587
+ })
588
+
589
+ print(f"Recorded {len(rows)} events from {len(tracks)} tracks")
590
+ return pd.DataFrame(rows) if rows else pd.DataFrame()
591
+
592
+
593
+ # =============================================================================
594
+ # Frame Generation
595
+ # =============================================================================
596
+
597
+ # Colours cycled across FOXES tracks
598
+ _TRACK_COLORS = [
599
+ '#E6194B', '#3CB44B', '#FFE119', '#4363D8', '#F58231',
600
+ '#911EB4', '#42D4F4', '#F032E6', '#BFEF45', '#FABED4',
601
+ '#469990', '#DCBEFF', '#9A6324', '#FFFAC8', '#800000',
602
+ '#AAFFC3', '#808000', '#FFD8B1', '#000075', '#A9A9A9',
603
+ ]
604
+
605
+
606
+ def _generate_single_frame(args: Tuple) -> Optional[str]:
607
+ """Render one frame. Designed for multiprocessing."""
608
+ setup_barlow_font()
609
+
610
+ frame_idx, timestamp, fd = args
611
+ try:
612
+ flare_events_df = fd['flare_events_df']
613
+ predictions_df = fd['predictions_df']
614
+ region_labels_cache = fd['region_labels_cache']
615
+ config = fd['config']
616
+ track_color_map = fd['track_color_map']
617
+ plot_window_hours = fd['plot_window_hours']
618
+ aia_path = fd['aia_path']
619
+ frames_dir = Path(fd['frames_dir'])
620
+
621
+ current_time = pd.to_datetime(timestamp)
622
+ window_start = current_time - pd.Timedelta(hours=plot_window_hours / 2)
623
+ window_end = current_time + pd.Timedelta(hours=plot_window_hours / 2)
624
+
625
+ # ── Figure: AIA (left) + SXR timeseries (right) ─────────────────────
626
+ fig = plt.figure(figsize=(14, 7))
627
+ gs = fig.add_gridspec(1, 2, width_ratios=[1, 1], wspace=0.3,
628
+ left=0.07, right=0.97, top=0.93, bottom=0.10)
629
+ ax_aia = fig.add_subplot(gs[0])
630
+ ax_sxr = fig.add_subplot(gs[1])
631
+
632
+ # ── AIA image ────────────────────────────────────────────────────────
633
+ aia_image = load_aia_image_at_time(Path(aia_path), timestamp) if aia_path else None
634
+ if aia_image is not None:
635
+ ax_aia.imshow(aia_image, origin='lower', aspect='equal', alpha=0.9)
636
+ else:
637
+ ax_aia.imshow(np.zeros((512, 512, 3)), origin='lower', aspect='equal')
638
+
639
+ ax_aia.set_title(f'{current_time.strftime("%Y-%m-%d %H:%M:%S")}', fontsize=11)
640
+ ax_aia.set_xlabel('X (pixels)', fontsize=9)
641
+ ax_aia.set_ylabel('Y (pixels)', fontsize=9)
642
+
643
+ # ── Region contours + FOXES markers ──────────────────────────────────
644
+ region_labels = region_labels_cache.get(timestamp)
645
+ current_events = (
646
+ flare_events_df[flare_events_df['timestamp'] == timestamp].copy()
647
+ if not flare_events_df.empty and 'timestamp' in flare_events_df.columns
648
+ else pd.DataFrame()
649
+ )
650
+ plotted_tracks: set = set()
651
+
652
+ for _, ev in current_events.iterrows():
653
+ tid = ev['track_id']
654
+ if tid in plotted_tracks:
655
+ continue
656
+ plotted_tracks.add(tid)
657
+
658
+ cx, cy = ev.get('centroid_img_x'), ev.get('centroid_img_y')
659
+ if pd.isna(cx) or pd.isna(cy) or not (0 <= cx <= 512) or not (0 <= cy <= 512):
660
+ continue
661
+
662
+ px = ev.get('peak_img_x') if pd.notna(ev.get('peak_img_x')) else cx
663
+ py = ev.get('peak_img_y') if pd.notna(ev.get('peak_img_y')) else cy
664
+ color = track_color_map.get(tid, _TRACK_COLORS[0])
665
+ cur_flux = ev.get('sum_flux', 0.0)
666
+ is_active = cur_flux >= config.min_flux_threshold
667
+
668
+ # Contour
669
+ rl = ev.get('region_label')
670
+ if region_labels is not None and pd.notna(rl) and int(rl) > 0:
671
+ region_mask = region_labels == int(rl)
672
+ if np.any(region_mask):
673
+ try:
674
+ # Upsample 64×64 mask to 512×512 for crisp contours on AIA image
675
+ scale = 512 // region_labels.shape[0]
676
+ mask_up = region_mask.repeat(scale, axis=0).repeat(scale, axis=1).astype(float)
677
+ ax_aia.contour(mask_up, levels=[0.5],
678
+ colors=color, linewidths=4.0 if is_active else 2.5,
679
+ alpha=0.9, extent=[0, 512, 0, 512])
680
+ except Exception:
681
+ pass
682
+
683
+ # Marker
684
+ if is_active:
685
+ ax_aia.plot(px, py, '*', markersize=15, color=color,
686
+ markeredgecolor='black', markeredgewidth=2, alpha=0.7, zorder=15)
687
+ ax_aia.annotate(f'FOXES: {flux_to_goes_class(cur_flux)}', (px, py),
688
+ xytext=(15, 15), textcoords='offset points', fontsize=11,
689
+ color='black', weight='bold',
690
+ bbox=dict(boxstyle='round,pad=0.3', facecolor=color,
691
+ alpha=0.95, edgecolor='black', linewidth=2))
692
+ else:
693
+ ax_aia.plot(px, py, 'o', markersize=10, color=color,
694
+ markeredgecolor='white', markeredgewidth=1.5, alpha=0.8, zorder=12)
695
+
696
+ # ── SXR timeseries ───────────────────────────────────────────────────
697
+ if predictions_df is not None and not predictions_df.empty:
698
+ in_win = predictions_df[
699
+ (predictions_df['datetime'] >= window_start) &
700
+ (predictions_df['datetime'] <= window_end)
701
+ ]
702
+ if not in_win.empty:
703
+ if 'groundtruth' in in_win.columns:
704
+ ax_sxr.plot(in_win['datetime'], in_win['groundtruth'],
705
+ 'b-', linewidth=1.5, alpha=0.8, label='GOES (Truth)')
706
+ if 'predictions' in in_win.columns:
707
+ ax_sxr.plot(in_win['datetime'], in_win['predictions'],
708
+ 'r--', linewidth=1.5, alpha=0.8, label='FOXES')
709
+
710
+ # Track fluxes
711
+ all_tracks_in_win = (
712
+ flare_events_df[
713
+ (flare_events_df['datetime'] >= window_start) &
714
+ (flare_events_df['datetime'] <= window_end)
715
+ ] if not flare_events_df.empty else pd.DataFrame()
716
+ )
717
+ first_other = True
718
+ for tid, tdata in (all_tracks_in_win.groupby('track_id') if not all_tracks_in_win.empty else []):
719
+ tdata = tdata.sort_values('datetime')
720
+ color = track_color_map.get(tid, _TRACK_COLORS[0])
721
+ is_active = tdata['sum_flux'].max() >= config.min_flux_threshold
722
+ if is_active:
723
+ ax_sxr.plot(tdata['datetime'], tdata['sum_flux'],
724
+ color=color, linewidth=2.5, alpha=0.9, label=f'Track {tid}', zorder=4)
725
+ else:
726
+ label = 'Other tracks' if first_other else None
727
+ ax_sxr.plot(tdata['datetime'], tdata['sum_flux'],
728
+ color=color, linewidth=0.9, alpha=0.35, label=label, zorder=3)
729
+ first_other = False
730
+
731
+ ax_sxr.axvline(current_time, color='#E5446D', linewidth=2, alpha=0.8, zorder=10)
732
+ ax_sxr.set_xlim(window_start, window_end)
733
+ ax_sxr.set_yscale('log')
734
+ ax_sxr.set_ylabel('Flux (W/m²)', fontsize=9)
735
+ ax_sxr.set_xlabel('Time (UTC)', fontsize=9)
736
+ ax_sxr.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
737
+ ax_sxr.xaxis.set_major_locator(mdates.HourLocator(interval=1))
738
+ plt.setp(ax_sxr.xaxis.get_majorticklabels(), rotation=0)
739
+ ax_sxr.legend(loc='lower right', fontsize=8, framealpha=1)
740
+ ax_sxr.grid(True, alpha=0.3)
741
+
742
+ plt.tight_layout()
743
+
744
+ fmt = getattr(config, 'movie_frame_format', 'jpg').lower()
745
+ dpi = getattr(config, 'movie_dpi', 75.0)
746
+ ext = 'jpg' if fmt in ('jpg', 'jpeg') else 'png'
747
+ frame_path = frames_dir / f"frame_{frame_idx:06d}.{ext}"
748
+ plt.savefig(frame_path, dpi=dpi, format=ext)
749
+ plt.close()
750
+ return str(frame_path)
751
+
752
+ except Exception as e:
753
+ plt.close('all')
754
+ print(f"Error creating frame {frame_idx} ({timestamp}): {e}")
755
+ return None
756
+
757
+
758
+ # =============================================================================
759
+ # Movie Assembly
760
+ # =============================================================================
761
+
762
+ def create_flare_movie(
763
+ flare_events_df: pd.DataFrame,
764
+ output_dir: Path,
765
+ config: FlareAnalysisConfig,
766
+ predictions_csv: Optional[str] = None,
767
+ analyzer: Optional[FluxContributionAnalyzer] = None,
768
+ fps: float = 2.0,
769
+ frame_interval_minutes: float = 1.0,
770
+ num_workers: int = 4,
771
+ ) -> Optional[str]:
772
+ """Generate per-timestamp frames and stitch into an MP4."""
773
+ setup_barlow_font()
774
+
775
+ if flare_events_df.empty:
776
+ print("No flare data — skipping movie.")
777
+ return None
778
+
779
+ output_dir = Path(output_dir)
780
+ movie_dir = output_dir / "movies"
781
+ movie_dir.mkdir(parents=True, exist_ok=True)
782
+
783
+ # Load predictions for timeseries
784
+ predictions_df = None
785
+ if predictions_csv and Path(predictions_csv).exists():
786
+ predictions_df = pd.read_csv(predictions_csv)
787
+ dt_col = 'datetime' if 'datetime' in predictions_df.columns else 'timestamp'
788
+ predictions_df['datetime'] = pd.to_datetime(predictions_df[dt_col])
789
+
790
+ flare_events_df = flare_events_df.copy()
791
+ flare_events_df['datetime'] = pd.to_datetime(flare_events_df['datetime'])
792
+
793
+ all_timestamps = sorted(flare_events_df['timestamp'].unique())
794
+
795
+ # Subsample by frame_interval_minutes
796
+ timestamps_to_use, last_dt = [], None
797
+ for ts in all_timestamps:
798
+ dt = pd.to_datetime(ts)
799
+ if last_dt is None or (dt - last_dt).total_seconds() >= frame_interval_minutes * 60:
800
+ timestamps_to_use.append(ts)
801
+ last_dt = dt
802
+
803
+ print(f"Creating movie: {len(timestamps_to_use)} frames @ {fps} fps")
804
+
805
+ # Assign consistent colours to tracks
806
+ unique_tracks = flare_events_df['track_id'].unique()
807
+ track_color_map = {tid: _TRACK_COLORS[i % len(_TRACK_COLORS)] for i, tid in enumerate(unique_tracks)}
808
+
809
+ frames_dir = movie_dir / "frames_temp"
810
+ frames_dir.mkdir(exist_ok=True)
811
+
812
+ frame_data = {
813
+ 'flare_events_df': flare_events_df,
814
+ 'predictions_df': predictions_df,
815
+ 'frames_dir': str(frames_dir),
816
+ 'region_labels_cache': analyzer.region_labels_cache if analyzer else {},
817
+ 'config': config,
818
+ 'track_color_map': track_color_map,
819
+ 'plot_window_hours': config.plot_window_hours,
820
+ 'aia_path': config.aia_path,
821
+ }
822
+
823
+ frame_args = [(i, ts, frame_data) for i, ts in enumerate(timestamps_to_use)]
824
+
825
+ if num_workers > 1:
826
+ with Pool(processes=num_workers) as pool:
827
+ results = list(tqdm(pool.imap(_generate_single_frame, frame_args),
828
+ total=len(frame_args), desc="Generating frames"))
829
+ else:
830
+ results = [_generate_single_frame(a) for a in tqdm(frame_args, desc="Generating frames")]
831
+
832
+ frame_paths = sorted(
833
+ (Path(p) for p in results if p is not None),
834
+ key=lambda p: p.name
835
+ )
836
+
837
+ if not frame_paths:
838
+ print("No frames generated.")
839
+ return None
840
+
841
+ # Stitch into MP4 via imageio (reads frames as RGB → passes to ffmpeg correctly)
842
+ datetimes = [pd.to_datetime(ts) for ts in timestamps_to_use]
843
+ movie_name = (f"flare_movie_{datetimes[0].strftime('%Y%m%d')}"
844
+ f"_{datetimes[-1].strftime('%Y%m%d')}.mp4")
845
+ movie_path = movie_dir / movie_name
846
+
847
+ # Read first frame to get dimensions
848
+ first_frame = imageio.imread(str(frame_paths[0]))
849
+ h, w = first_frame.shape[:2]
850
+ # yuv420p requires even dimensions
851
+ w = w if w % 2 == 0 else w - 1
852
+ h = h if h % 2 == 0 else h - 1
853
+
854
+ t0 = time.time()
855
+ writer = imageio_ffmpeg.write_frames(
856
+ str(movie_path),
857
+ size=(w, h),
858
+ fps=fps,
859
+ codec='libx264',
860
+ pix_fmt_in='rgb24',
861
+ pix_fmt_out='yuv420p',
862
+ output_params=['-preset', 'veryfast', '-crf', '25', '-movflags', '+faststart'],
863
+ )
864
+ writer.send(None) # initialise
865
+ for fp in tqdm(frame_paths, desc="Writing movie"):
866
+ if fp.exists():
867
+ frame = imageio.imread(str(fp))
868
+ writer.send(frame[:h, :w].tobytes())
869
+ writer.close()
870
+
871
+ print(f"Movie saved → {movie_path} ({time.time() - t0:.1f}s)")
872
+ print(f"Frames kept → {frames_dir}")
873
+
874
+ return str(movie_path)
875
+
876
+
877
+ # =============================================================================
878
+ # Entry point
879
+ # =============================================================================
880
+
881
+ def main() -> None:
882
+ parser = argparse.ArgumentParser(description="Flare Analysis — Frame & Movie Generator")
883
+ parser.add_argument("--config", required=True, help="Path to YAML config file")
884
+ args = parser.parse_args()
885
+
886
+ config = FlareAnalysisConfig.from_yaml(args.config)
887
+
888
+ run_ts = datetime.now().strftime("%Y%m%d_%H%M%S")
889
+ out_dir = Path(config.output_dir or '.') / f"run_{run_ts}"
890
+ out_dir.mkdir(parents=True, exist_ok=True)
891
+ print(f"Output: {out_dir}")
892
+
893
+ analyzer = FluxContributionAnalyzer(config, output_dir=out_dir)
894
+ flare_events_df = analyzer.detect_flare_events()
895
+
896
+ if not flare_events_df.empty:
897
+ flare_events_df.to_csv(out_dir / "flare_events.csv", index=False)
898
+ print(f"Saved {len(flare_events_df)} events → {out_dir / 'flare_events.csv'}")
899
+
900
+ if config.create_movie:
901
+ create_flare_movie(
902
+ flare_events_df = flare_events_df,
903
+ output_dir = out_dir,
904
+ config = config,
905
+ predictions_csv = config.predictions_csv,
906
+ analyzer = analyzer,
907
+ fps = config.movie_fps,
908
+ frame_interval_minutes = config.movie_frame_interval_minutes,
909
+ num_workers = config.movie_num_workers,
910
+ )
911
+
912
+ print(f"\nDone. Results in {out_dir}")
913
+
914
+
915
+ if __name__ == "__main__":
916
+ main()
analysis/flux_map_config.yaml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================================================================
2
+ # Flare Analysis Configuration
3
+ # =============================================================================
4
+ # Usage: python analysis/flux_map_analysis.py --config analysis/flux_map_config.yaml
5
+ # =============================================================================
6
+
7
+ # -----------------------------------------------------------------------------
8
+ # Paths
9
+ # -----------------------------------------------------------------------------
10
+ paths:
11
+ flux_path: "/Volumes/T9/FOXES_Data/flux"
12
+ aia_path: "/Volumes/T9/FOXES_Data/AIA_processed"
13
+ predictions_csv: "/Volumes/T9/FOXES_Misc/batch_results/vit/vit_predictions_test.csv"
14
+ output_dir: "/Volumes/T9/FOXES_Data/flare_analysis"
15
+
16
+ # -----------------------------------------------------------------------------
17
+ # Time Range (null = use full predictions CSV range)
18
+ # -----------------------------------------------------------------------------
19
+ time_range:
20
+ start: "2024-03-23T00:00:00" # e.g. "2024-03-23T00:00:00"
21
+ end: "2024-03-23T3:00:00"
22
+
23
+ # -----------------------------------------------------------------------------
24
+ # Detection
25
+ # -----------------------------------------------------------------------------
26
+ detection:
27
+ min_flux_threshold: 1.0e-7 # W/m² — patches below this are ignored
28
+ threshold_std_multiplier: 4.0 # flux mask: mean + N*std
29
+ spatial_smoothing_sigma: 1.0 # Gaussian pre-smoothing (patches, 0 = off)
30
+ radial_expansion_threshold_percentile: 30.0 # flood-fill growth cutoff percentile
31
+ peak_neighborhood_sizes: [10, 15, 20, 25] # multi-scale local-max windows
32
+ peak_min_scale_agreement: 2 # peaks must appear at N scales
33
+ peak_scale_tolerance: 10 # patch-distance to count as "same peak"
34
+ min_peak_distance: 5 # min patches between distinct peaks
35
+
36
+ # -----------------------------------------------------------------------------
37
+ # Grid / Patch Parameters
38
+ # -----------------------------------------------------------------------------
39
+ grid:
40
+ grid_size: [64, 64]
41
+ patch_size: 8
42
+ input_size: 512
43
+
44
+ # -----------------------------------------------------------------------------
45
+ # Region Tracking
46
+ # -----------------------------------------------------------------------------
47
+ tracking:
48
+ max_tracking_distance: 10 # max patch-distance between frames to link regions
49
+ flux_ratio_weight: 0 # weight of flux-ratio term in linking score
50
+ size_ratio_weight: 0 # weight of size-ratio term in linking score
51
+ distance_weight: 1.0 # weight of spatial distance in linking score
52
+ age_bonus_weight: 2.0 # bias toward established tracks (scales 1/(1+age))
53
+ cadence_seconds: 60.0 # expected data cadence
54
+ max_gap_frames: 15 # frames a track can go undetected before expiring
55
+
56
+ # -----------------------------------------------------------------------------
57
+ # Movie / Output
58
+ # -----------------------------------------------------------------------------
59
+ movie:
60
+ create_movie: true
61
+ plot_window_hours: 2.0 # SXR plot time window around current frame
62
+ fps: 30.0
63
+ frame_interval_minutes: 1.0 # one frame per minute of data
64
+ num_workers: 8
65
+ dpi: 75.0
66
+ frame_format: "jpg" # "jpg" (fast) or "png" (quality)
67
+ jpeg_quality: 90
analysis/spatial_performance.py CHANGED
@@ -46,7 +46,7 @@ CROP_FACTOR = 1.1 # AIA images cropped at 1.1 solar radii
46
  SOLAR_RADIUS_PATCHES = (GRID_SIZE / 2) / CROP_FACTOR # ≈ 29.1 patches
47
 
48
  # Patches beyond ±PATCH_CROP_RADIUS from center (in original 64×64 patch units) are masked.
49
- PATCH_CROP_RADIUS = 19
50
 
51
  # Percentile cap for colorbar scaling (applied to non-NaN values).
52
  # e.g. 99 clips the top 1% of values so detail in the bulk is visible.
 
46
  SOLAR_RADIUS_PATCHES = (GRID_SIZE / 2) / CROP_FACTOR # ≈ 29.1 patches
47
 
48
  # Patches beyond ±PATCH_CROP_RADIUS from center (in original 64×64 patch units) are masked.
49
+ PATCH_CROP_RADIUS = 24
50
 
51
  # Percentile cap for colorbar scaling (applied to non-NaN values).
52
  # e.g. 99 clips the top 1% of values so detail in the bulk is visible.
forecasting/inference/flare_analysis.py DELETED
The diff for this file is too large to render. See raw diff
 
forecasting/inference/flare_analysis_poster.py DELETED
The diff for this file is too large to render. See raw diff
 
pipeline_config.yaml CHANGED
@@ -5,7 +5,7 @@
5
  #
6
  # Usage:
7
  # python run_pipeline.py --config pipeline_config.yaml --steps all
8
- # python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flare_analysis
9
  # python run_pipeline.py --list
10
  #
11
  # Variables
@@ -152,6 +152,23 @@ spatial_performance:
152
  predictions_csv: "${base_dir}/inference/predictions.csv"
153
  out_dir: "${base_dir}/inference/spatial_performance"
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  # -----------------------------------------------------------------------------
156
  # Evaluation (step: evaluate)
157
  # -----------------------------------------------------------------------------
 
5
  #
6
  # Usage:
7
  # python run_pipeline.py --config pipeline_config.yaml --steps all
8
+ # python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flux_map_analysis
9
  # python run_pipeline.py --list
10
  #
11
  # Variables
 
152
  predictions_csv: "${base_dir}/inference/predictions.csv"
153
  out_dir: "${base_dir}/inference/spatial_performance"
154
 
155
+ # -----------------------------------------------------------------------------
156
+ # Flux map analysis (step: flux_map_analysis)
157
+ # Detects and tracks active regions from per-patch flux contribution maps,
158
+ # renders side-by-side AIA + SXR frames, and stitches them into a movie.
159
+ # -----------------------------------------------------------------------------
160
+ flux_map_analysis:
161
+ config: "analysis/flux_map_config.yaml"
162
+ # overrides: # uncomment to override config values
163
+ # paths:
164
+ # flux_path: "${base_dir}/flux"
165
+ # aia_path: "${base_dir}/AIA_processed"
166
+ # predictions_csv: "${base_dir}/inference/predictions.csv"
167
+ # output_dir: "${base_dir}/inference/flux_map_analysis"
168
+ # time_range:
169
+ # start: null # null = full predictions CSV range
170
+ # end: null
171
+
172
  # -----------------------------------------------------------------------------
173
  # Evaluation (step: evaluate)
174
  # -----------------------------------------------------------------------------
run_pipeline.py CHANGED
@@ -13,12 +13,12 @@ Runs any combination of pipeline steps in order:
13
  6. normalize - Compute SXR normalization stats on train split (data/sxr_normalization.py)
14
  7. train - Train the ViTLocal forecasting model (forecasting/training/train.py)
15
  8. inference - Run batch inference on val/test data (forecasting/inference/inference.py)
16
- 9. flare_analysis - Detect, track, and match flares (forecasting/inference/flare_analysis.py)
17
 
18
  Usage:
19
  python run_pipeline.py --list
20
  python run_pipeline.py --config pipeline_config.yaml --steps all
21
- python run_pipeline.py --config pipeline_config.yaml --steps train,inference,flare_analysis
22
  """
23
 
24
  import argparse
@@ -116,9 +116,9 @@ STEP_ORDER = [
116
  "train",
117
  "inference",
118
  "evaluate",
119
- "flare_analysis",
120
  "ablation",
121
  "spatial_performance",
 
122
  ]
123
 
124
  STEP_INFO = {
@@ -166,10 +166,6 @@ STEP_INFO = {
166
  "description": "Compute metrics and generate evaluation plots from predictions CSV",
167
  "script": ROOT / "forecasting" / "inference" / "evaluation.py",
168
  },
169
- "flare_analysis": {
170
- "description": "Detect, track, and match flares; generate plots/movies",
171
- "script": ROOT / "forecasting" / "inference" / "flare_analysis.py",
172
- },
173
  "ablation": {
174
  "description": "Run Gaussian noise channel-masking ablation study on pretrained model",
175
  "script": ROOT / "forecasting" / "inference" / "ablation_inference.py",
@@ -178,6 +174,10 @@ STEP_INFO = {
178
  "description": "Generate flux-weighted spatial error heatmap on the solar disk",
179
  "script": ROOT / "analysis" / "spatial_performance.py",
180
  },
 
 
 
 
181
  }
182
 
183
 
@@ -329,15 +329,6 @@ def build_commands(step: str, cfg: dict, force: bool) -> list[list[str]] | None:
329
  config_path = str(write_merged_config(config_path, ev["overrides"], "evaluate_config"))
330
  return [base + ["-config", config_path]]
331
 
332
- if step == "flare_analysis":
333
- if not require(["config"], "inference"):
334
- return None
335
- inf = cfg["inference"]
336
- config_path = inf["config"]
337
- if inf.get("overrides"):
338
- config_path = str(write_merged_config(config_path, inf["overrides"], "inference_config"))
339
- return [base + ["--config", config_path]]
340
-
341
  if step == "ablation":
342
  if not require(["config"], "ablation"):
343
  return None
@@ -358,6 +349,15 @@ def build_commands(step: str, cfg: dict, force: bool) -> list[list[str]] | None:
358
  cmd += ["--out_dir", sp["out_dir"]]
359
  return [cmd]
360
 
 
 
 
 
 
 
 
 
 
361
  return [base]
362
 
363
 
@@ -398,7 +398,7 @@ def list_steps():
398
  print(f" {i}. {step:<16} {STEP_INFO[step]['description']}")
399
  print()
400
  print("Use --steps all to run every step, or comma-separate specific steps.")
401
- print("Example: --steps train,inference,flare_analysis\n")
402
 
403
 
404
  def main():
 
13
  6. normalize - Compute SXR normalization stats on train split (data/sxr_normalization.py)
14
  7. train - Train the ViTLocal forecasting model (forecasting/training/train.py)
15
  8. inference - Run batch inference on val/test data (forecasting/inference/inference.py)
16
+ 9. flux_map_analysis - Detect, track, and match flares (analysis/flux_map_analysis.py)
17
 
18
  Usage:
19
  python run_pipeline.py --list
20
  python run_pipeline.py --config pipeline_config.yaml --steps all
21
+ python run_pipeline.py --config pipeline_config.yaml --steps train,inference
22
  """
23
 
24
  import argparse
 
116
  "train",
117
  "inference",
118
  "evaluate",
 
119
  "ablation",
120
  "spatial_performance",
121
+ "flux_map_analysis",
122
  ]
123
 
124
  STEP_INFO = {
 
166
  "description": "Compute metrics and generate evaluation plots from predictions CSV",
167
  "script": ROOT / "forecasting" / "inference" / "evaluation.py",
168
  },
 
 
 
 
169
  "ablation": {
170
  "description": "Run Gaussian noise channel-masking ablation study on pretrained model",
171
  "script": ROOT / "forecasting" / "inference" / "ablation_inference.py",
 
174
  "description": "Generate flux-weighted spatial error heatmap on the solar disk",
175
  "script": ROOT / "analysis" / "spatial_performance.py",
176
  },
177
+ "flux_map_analysis": {
178
+ "description": "Detect and track active regions from flux maps; render per-frame movie",
179
+ "script": ROOT / "analysis" / "flux_map_analysis.py",
180
+ },
181
  }
182
 
183
 
 
329
  config_path = str(write_merged_config(config_path, ev["overrides"], "evaluate_config"))
330
  return [base + ["-config", config_path]]
331
 
 
 
 
 
 
 
 
 
 
332
  if step == "ablation":
333
  if not require(["config"], "ablation"):
334
  return None
 
349
  cmd += ["--out_dir", sp["out_dir"]]
350
  return [cmd]
351
 
352
+ if step == "flux_map_analysis":
353
+ if not require(["config"], "flux_map_analysis"):
354
+ return None
355
+ fma = cfg["flux_map_analysis"]
356
+ config_path = fma["config"]
357
+ if fma.get("overrides"):
358
+ config_path = str(write_merged_config(config_path, fma["overrides"], "flux_map_analysis_config"))
359
+ return [base + ["--config", config_path]]
360
+
361
  return [base]
362
 
363
 
 
398
  print(f" {i}. {step:<16} {STEP_INFO[step]['description']}")
399
  print()
400
  print("Use --steps all to run every step, or comma-separate specific steps.")
401
+ print("Example: --steps train,inference\n")
402
 
403
 
404
  def main():