MesserMMP commited on
Commit
e4194f4
·
1 Parent(s): c2d9714

add files for inference

Browse files
inference/metrics_visualization.py ADDED
@@ -0,0 +1,393 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Визуализация предсказаний SYNTAX:
3
+ - точки (SYNTAX GT vs предсказания модели) для нескольких датасетов;
4
+ - зоны риска (низкий / высокий риск);
5
+ - области ±σ и ±2σ вокруг диагонали;
6
+ - логистические тренды для каждого датасета.
7
+
8
+ Скрипт не зависит от PyTorch/Lightning и используется на этапе инференса.
9
+ Сохранение осуществляется в папку `visualizations/` внутри проекта.
10
+ """
11
+
12
+ import os
13
+ import numpy as np
14
+ import plotly.graph_objects as go
15
+ from scipy.optimize import curve_fit # type: ignore
16
+
17
+
18
def visualize_final_syntax_plotly_multi(
    datasets,
    r2_values,
    gt_row,
    postfix=None,
    threshold=22.0,
    recall_values=None,
    backbone=False,
):
    """
    Unified SYNTAX visualization: scatter points, risk zones and logistic trends.

    Parameters
    ----------
    datasets : dict[str, tuple[list[float], list[float]]]
        Mapping {dataset_name: (syntax_true_list, syntax_pred_list)}.
    r2_values : dict[str, float]
        R^2 score per dataset.
    gt_row : str
        String included in the plot title (e.g. "ENSEMBLE" or "BOTH").
    postfix : str | None
        Suffix for the saved file name.
    threshold : float
        SYNTAX threshold (usually 22.0) separating the risk zones.
    recall_values : dict[str, float] | None
        Recall per dataset (may be None).
    backbone : bool
        If True, save into `visualizations/backbone`, otherwise `visualizations/`.
    """
    # ========== TUNING CONSTANTS ==========
    DATA_MIN = 0.0
    DATA_MAX = 60.0

    PADDING = 0.5

    # The sigma band widens linearly with the score: base + slope * x.
    SIGMA_SLOPE = 0.15
    SIGMA_BASE = 1.4

    PLOT_WIDTH = 980
    PLOT_HEIGHT = 980

    BASE_FONT_SIZE = 16
    TITLE_FONT_SIZE = 22
    AXIS_TICK_FONT_SIZE = 15
    LEGEND_FONT_SIZE = 14

    MARKER_SIZE = 11
    MARKER_LINE_WIDTH = 1.1
    LINE_WIDTH = 2
    TREND_LINE_WIDTH = 3

    PLOT_BG_COLOR = "rgba(235,238,245,1)"
    PAPER_BG_COLOR = "white"
    LEGEND_BG_COLOR = "rgba(255,255,255,0.94)"
    GRID_COLOR = "rgba(100,116,139,0.18)"

    MARGIN_LEFT = 70
    MARGIN_RIGHT = 24
    MARGIN_TOP = 78
    MARGIN_BOTTOM = 70

    LEGEND_X = 0.04
    LEGEND_Y = 0.99

    COLORS = ["#1E88E5", "#8E24AA", "#A0D137", "#EA1D1D", "#06EE0D", "#FB8C00"]
    SYMBOLS = ["circle", "x", "square", "diamond", "triangle-up", "star"]

    SIGMA_POINTS = 400
    TREND_POINTS = 500

    # ========== HELPERS ==========

    def _logistic_time(t, R0, Rmax, t50, k):
        """Logistic function over the SYNTAX score axis."""
        t = np.asarray(t, dtype=float)
        # Clamp non-positive scores so (t50 / t) ** k stays well-defined.
        t_safe = np.where(t <= 0, 1e-3, t)
        return R0 + (Rmax - R0) / (1.0 + (t50 / t_safe) ** k)

    def _fit_logistic(x, y, domain=(DATA_MIN, DATA_MAX), n=TREND_POINTS):
        """
        Fit a logistic curve to (x, y).
        Returns (X, Y) of the fitted curve, or (None, None) if the fit failed.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        m = np.isfinite(x) & np.isfinite(y)
        # curve_fit needs at least as many points as free parameters (4).
        if m.sum() < 4:
            return None, None

        x_m, y_m = x[m], y[m]
        x_min = max(float(np.min(x_m)), float(domain[0]))
        x_max = min(float(np.max(x_m)), float(domain[1]))
        if not np.isfinite(x_min) or not np.isfinite(x_max) or x_max <= x_min:
            return None, None

        x_pos = x_m[x_m > 0]
        if x_pos.size == 0:
            return None, None

        # Initial guesses: plateau levels from percentiles, midpoint from the median.
        R0_init = float(np.percentile(y_m, 10))
        Rmax_init = float(np.percentile(y_m, 90))
        t50_init = float(np.median(x_pos))
        k_init = 1.0

        lower = [-10.0, 0.0, 1e-3, 0.01]
        upper = [60.0, 80.0, 60.0, 10.0]

        try:
            popt, _ = curve_fit(
                _logistic_time,
                x_m,
                y_m,
                p0=[R0_init, Rmax_init, t50_init, k_init],
                bounds=(lower, upper),
                maxfev=20000,
            )
        except Exception:
            # Fit failures are expected for degenerate data; caller skips the trend.
            return None, None

        X = np.linspace(x_min, x_max, n)
        Y = _logistic_time(X, *popt)
        return X, Y

    # ========== MAIN CODE ==========
    fig = go.Figure()

    line_min = DATA_MIN - PADDING
    line_max = DATA_MAX + PADDING
    domain = (line_min, line_max)

    base_font = dict(
        family="Inter, Roboto, Helvetica Neue, Arial, sans-serif",
        size=BASE_FONT_SIZE,
    )

    # ---------- Thresholds and lines (legendrank=0) ----------
    fig.add_trace(
        go.Scatter(
            x=[line_min, threshold, threshold, line_min],
            y=[line_min, line_min, threshold, threshold],
            fill="toself",
            fillcolor="rgba(255, 82, 82, 0.12)",
            line=dict(color="rgba(0,0,0,0)"),
            name="Low-risk zone",
            legendgroup="zones",
            legendgrouptitle_text="Пороги и линии",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[threshold, line_max, line_max, threshold],
            y=[threshold, threshold, line_max, line_max],
            fill="toself",
            fillcolor="rgba(76, 175, 80, 0.14)",
            line=dict(color="rgba(0,0,0,0)"),
            name="High-risk zone",
            legendgroup="zones",
            showlegend=True,
            hoverinfo="skip",
            legendrank=0,
        )
    )

    # Vertical + horizontal threshold lines drawn as one trace (None splits segments).
    fig.add_trace(
        go.Scatter(
            x=[threshold, threshold, None, line_min, line_max],
            y=[line_min, line_max, None, threshold, threshold],
            mode="lines",
            name=rf"$\mathrm{{SYNTAX}}={threshold}$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(46,125,50,0.85)", width=LINE_WIDTH, dash="dash"),
            legendrank=0,
            hoverinfo="skip",
        )
    )

    x_vals = np.linspace(line_min, line_max, SIGMA_POINTS)
    sigma_upper = x_vals + SIGMA_BASE + SIGMA_SLOPE * x_vals
    sigma_lower = x_vals - SIGMA_BASE - SIGMA_SLOPE * x_vals
    two_sigma_upper = x_vals + 2 * SIGMA_BASE + 2 * SIGMA_SLOPE * x_vals
    two_sigma_lower = x_vals - 2 * SIGMA_BASE - 2 * SIGMA_SLOPE * x_vals

    def _add_band(lower_band, upper_band, fillcolor, name):
        """Add one filled band around the diagonal (shared styling for ±σ / ±2σ)."""
        fig.add_trace(
            go.Scatter(
                x=np.concatenate([x_vals, x_vals[::-1]]),
                y=np.concatenate([lower_band, upper_band[::-1]]),
                fill="toself",
                fillcolor=fillcolor,
                line=dict(color="rgba(0,0,0,0)"),
                name=name,
                legendgroup="zones",
                showlegend=True,
                hoverinfo="skip",
                legendrank=0,
            )
        )

    # ±2σ first so the narrower ±σ band is drawn on top of it.
    _add_band(two_sigma_lower, two_sigma_upper, "rgba(255,193,7,0.18)", r"$\pm 2\sigma$")
    _add_band(sigma_lower, sigma_upper, "rgba(255,152,0,0.30)", r"$\pm \sigma$")

    fig.add_trace(
        go.Scatter(
            x=[line_min, line_max],
            y=[line_min, line_max],
            mode="lines",
            name=r"$y=x$",
            legendgroup="zones",
            showlegend=True,
            line=dict(color="rgba(30,30,30,0.85)", width=LINE_WIDTH),
            legendrank=0,
        )
    )

    # ---------- Datasets (legendrank=20) ----------
    first_dataset = True
    for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue

        r2 = r2_values.get(label, None)
        recall = recall_values.get(label, None) if recall_values else None
        hover_lines = [f"<b>{label}</b>"]
        if r2 is not None:
            hover_lines.append(f"R² = {r2:.3f}")
        if recall is not None:
            hover_lines.append(f"Recall = {recall:.3f}")
        hovertemplate = (
            "<br>".join(hover_lines)
            + "<br>GT: %{x:.3f}<br>Pred: %{y:.3f}<extra></extra>"
        )

        fig.add_trace(
            go.Scatter(
                x=x,
                y=y,
                mode="markers",
                name=label,
                legendgroup="datasets",
                # Group title is attached to the first visible dataset only.
                legendgrouptitle_text=("Датасеты" if first_dataset else None),
                showlegend=True,
                marker=dict(
                    color=COLORS[i % len(COLORS)],
                    size=MARKER_SIZE,
                    opacity=0.96,
                    symbol=SYMBOLS[i % len(SYMBOLS)],
                    line=dict(
                        width=MARKER_LINE_WIDTH, color="rgba(255,255,255,0.95)"
                    ),
                ),
                hovertemplate=hovertemplate,
                legendrank=20,
            )
        )
        first_dataset = False

    # ---------- Trends: logistic (legendrank=30) ----------
    first_trend = True
    for i, (label, (syntax_true, syntax_pred)) in enumerate(datasets.items()):
        x = np.array(syntax_true, dtype=float)
        y = np.array(syntax_pred, dtype=float)
        if x.size == 0 or y.size == 0:
            continue

        Xc, Yc = _fit_logistic(x, y, domain=domain)
        if Xc is not None:
            fig.add_trace(
                go.Scatter(
                    x=Xc,
                    y=Yc,
                    mode="lines",
                    name=label,  # full dataset name, no short aliases
                    legendgroup="trends",
                    legendgrouptitle_text=(
                        "Тренды (логистические)" if first_trend else None
                    ),
                    showlegend=True,
                    line=dict(
                        color=COLORS[i % len(COLORS)], width=TREND_LINE_WIDTH
                    ),
                    hoverinfo="skip",
                    legendrank=30,
                )
            )
            first_trend = False

    # ---------- Layout ----------
    title_text = f"SYNTAX predictions ({gt_row})"
    if postfix:
        title_text += f" {postfix}"

    fig.update_layout(
        title=dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(
                size=TITLE_FONT_SIZE,
                family=base_font["family"],
                color="rgba(15,23,42,1)",
            ),
        ),
        font=base_font,
        xaxis_title=r"$\mathrm{SYNTAX\ GT}$",
        yaxis_title=r"$\mathrm{SYNTAX\ predictions}$",
        width=PLOT_WIDTH,
        height=PLOT_HEIGHT,
        plot_bgcolor=PLOT_BG_COLOR,
        paper_bgcolor=PAPER_BG_COLOR,
        legend=dict(
            x=LEGEND_X,
            y=LEGEND_Y,
            bgcolor=LEGEND_BG_COLOR,
            bordercolor="#CBD5E1",
            borderwidth=1,
            font=dict(size=LEGEND_FONT_SIZE, family=base_font["family"]),
            tracegroupgap=8,
            itemclick="toggle",
            itemdoubleclick="toggleothers",
            groupclick="toggleitem",
        ),
        xaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            constrain="domain",
        ),
        yaxis=dict(
            showgrid=True,
            gridcolor=GRID_COLOR,
            gridwidth=1,
            zeroline=False,
            tickfont=dict(size=AXIS_TICK_FONT_SIZE),
            range=[line_min, line_max],
            # Keep the plot square: 1 unit on y equals 1 unit on x.
            scaleanchor="x",
            scaleratio=1,
            constrain="domain",
        ),
        margin=dict(
            l=MARGIN_LEFT,
            r=MARGIN_RIGHT,
            t=MARGIN_TOP,
            b=MARGIN_BOTTOM,
        ),
    )

    # ---------- Saving ----------
    save_dir = "visualizations"
    if backbone:
        save_dir = os.path.join(save_dir, "backbone")
    os.makedirs(save_dir, exist_ok=True)

    postfix_html = str(postfix) if postfix else "syntax"
    save_path_html = os.path.join(save_dir, f"{postfix_html}.html")
    fig.write_html(save_path_html, include_mathjax="cdn")
    print(f"Saved visualization with logistic trends: {save_path_html}")
inference/rnn_apply.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import tqdm
4
+ import torch
5
+ import numpy as np
6
+ import click
7
+ from datetime import datetime
8
+ import lightning.pytorch as pl
9
+ import sklearn.metrics as skm
10
+
11
+ from torch.utils.data import DataLoader
12
+ from torchvision.transforms import transforms as T
13
+ from torchvision.transforms._transforms_video import ToTensorVideo
14
+ from pytorchvideo.transforms import Normalize
15
+
16
+ # Импорты из соседних папок (относительные пути)
17
+ from full_model.rnn_dataset import SyntaxDataset
18
+ from full_model.rnn_model import SyntaxLightningModule
19
+ from metrics_visualization import visualize_final_syntax_plotly_multi
20
+
21
# Select the inference device: first CUDA GPU when available, otherwise CPU.
if torch.cuda.is_available():
    _device_name = "cuda:0"
else:
    _device_name = "cpu"
DEVICE = torch.device(_device_name)
print(f"DEVICE: {DEVICE}")
23
+
24
+
25
def safe_sample_std(values):
    """Sample standard deviation (ddof=1); 0.0 for empty or single-element input."""
    data = np.asarray(values, dtype=float)
    return float(data.std(ddof=1)) if data.size > 1 else 0.0
31
+
32
+
33
def compute_metrics(y_true, y_pred, thr=22.0):
    """Compute R2, MAE, Pearson, MAPE and mean per-class recall at threshold `thr`."""
    true_arr = np.array(y_true, dtype=float)
    pred_arr = np.array(y_pred, dtype=float)

    r2 = float(skm.r2_score(true_arr, pred_arr))
    mae = float(skm.mean_absolute_error(true_arr, pred_arr))
    mape = float(skm.mean_absolute_percentage_error(true_arr, pred_arr))

    # Pearson correlation is undefined for a single sample.
    if len(true_arr) > 1:
        pearson = float(np.corrcoef(true_arr, pred_arr)[0, 1])
    else:
        pearson = 0.0

    # Binarize at the SYNTAX threshold and average recall over both classes.
    true_bin = (true_arr >= thr).astype(int)
    pred_bin = (pred_arr >= thr).astype(int)
    observed = np.unique(np.concatenate([true_bin, pred_bin]))
    if len(observed) > 1:
        per_class = skm.recall_score(true_bin, pred_bin, average=None, labels=[0, 1])
        mean_recall = float(np.mean(per_class))
    else:
        mean_recall = 0.0

    return r2, mae, pearson, mape, mean_recall
51
+
52
+
53
@click.command()
@click.option("-d", "--dataset-paths", multiple=True,
              help="JSON с метаданными датасетов (относительно dataset_root).")
@click.option("-n", "--dataset-names", multiple=True,
              help="Имена датасетов для метрик/графиков.")
@click.option("-p", "--postfixes", multiple=True,
              help="Суффиксы для файлов предсказаний.")
@click.option("-r", "--dataset-root", type=click.Path(exists=True),
              help="Корень датасета (где лежат JSON и DICOM).")
@click.option("-v", "--video-size", type=click.Tuple([int, int]),
              help="Размер видео (H, W).")
@click.option("--frames-per-clip", type=int,
              help="Количество кадров в клипе.")
@click.option("--num-workers", type=int,
              help="Число DataLoader workers.")
@click.option("--seed", type=int,
              help="Random seed.")
@click.option("--pt-weights-format", is_flag=True,
              help="True → модели в .pt (torch.save), False → .ckpt (Lightning).")
@click.option("--use-scaling", is_flag=True,
              help="Применить a*x+b scaling из JSON.")
@click.option("--scaling-file",
              help="JSON с коэффициентами scaling (относительно dataset_root).")
@click.option("-e", "--ensemble-name",
              help="Имя ансамбля в metrics.json.")
@click.option("-m", "--metrics-file",
              help="JSON с метриками экспериментов.")
def main(dataset_paths, dataset_names, postfixes, dataset_root, video_size,
         frames_per_clip, num_workers, seed, pt_weights_format, use_scaling,
         scaling_file, ensemble_name, metrics_file):
    """Run ensemble SYNTAX inference over the given datasets, save metrics and plots.

    Note: --frames-per-clip, --num-workers and --seed are declared with
    type=int so click converts them; downstream APIs (seed_everything,
    DataLoader, dataset length) require integers, not strings.
    """
    pl.seed_everything(seed)
    postfix_plotly = "Ensemble"

    # Checkpoint paths (backbone + RNN head), five folds per artery.
    model_paths = {
        "left": [
            "full_model/checkpoints/leftBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/leftBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ],
        "right": [
            "full_model/checkpoints/rightBinSyntax_R3D_fold00_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold01_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold02_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold03_lstm_mean_post_best.pt",
            "full_model/checkpoints/rightBinSyntax_R3D_fold04_lstm_mean_post_best.pt",
        ]
    }

    # Optional per-fold linear scaling (a*x + b) parameters.
    scaling_params_dict = {}
    if use_scaling:
        postfix_plotly += "_scaled"
        ensemble_name += "_scaled"
        scaling_path = os.path.join(dataset_root, scaling_file)
        if os.path.exists(scaling_path):
            with open(scaling_path, "r") as f:
                scaling_params_dict = json.load(f)
            print(f"Loaded scaling from {scaling_path}")
        else:
            print(f"⚠️ Scaling file not found: {scaling_path}")

    # Accumulated ensemble results, keyed later by ensemble_name in metrics.json.
    ensemble_results = {
        "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "use_scaling": use_scaling,
        "pt_weights_format": pt_weights_format,
        "datasets": {}
    }

    all_datasets, all_r2, all_recalls = {}, {}, {}

    for dataset_path, dataset_name, postfix in zip(dataset_paths, dataset_names, postfixes):
        # Resolve paths relative to dataset_root.
        abs_dataset_path = os.path.join(dataset_root, dataset_path)
        results_file = os.path.join(dataset_root, "coeffs", f"{postfix}.json")

        # Load cached predictions if present, otherwise run inference and cache them.
        if os.path.exists(results_file):
            print(f"[{postfix}] Loading from {results_file}")
            with open(results_file, "r") as f:
                data = json.load(f)
            syntax_true = data["syntax_true"]
            left_preds_all = data["left_preds"]
            right_preds_all = data["right_preds"]
        else:
            print(f"[{postfix}] Computing predictions...")
            left_preds_all, left_sids = run_artery(
                abs_dataset_path, "left", model_paths["left"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            right_preds_all, right_sids = run_artery(
                abs_dataset_path, "right", model_paths["right"],
                video_size, frames_per_clip, num_workers, pt_weights_format
            )
            # Both arteries must cover the same studies in the same order.
            assert left_sids == right_sids

            with open(abs_dataset_path, "r") as f:
                dataset = json.load(f)
            syntax_true = [rec.get("mean_syntax", rec.get("syntax")) for rec in dataset]

            os.makedirs(os.path.dirname(results_file), exist_ok=True)
            save_data = {
                "syntax_true": syntax_true,
                "left_preds": left_preds_all,
                "right_preds": right_preds_all
            }
            with open(results_file, "w") as f:
                json.dump(save_data, f)
            print(f"[{postfix}] Saved to {results_file}")

        # Apply fold-wise linear scaling to left/right predictions.
        if use_scaling:
            left_scaled_all, right_scaled_all = [], []
            for pred_list in left_preds_all:
                scaled = []
                for i, val in enumerate(pred_list):
                    a, b = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
                    scaled.append(a * val + b)
                left_scaled_all.append(scaled)
            for pred_list in right_preds_all:
                scaled = []
                for i, val in enumerate(pred_list):
                    a, b = scaling_params_dict.get(f"fold{i}", (1.0, 0.0))
                    scaled.append(a * val + b)
                right_scaled_all.append(scaled)
        else:
            left_scaled_all, right_scaled_all = left_preds_all, right_preds_all

        # Ensemble: mean over folds of (left + right), clipped below at 0.
        syntax_pred = [max(0.0, float(np.mean([l + r for l, r in zip(l_list, r_list)])))
                       for l_list, r_list in zip(left_scaled_all, right_scaled_all)]

        # Ensemble metrics.
        r2, mae, pearson, mape, mean_recall = compute_metrics(syntax_true, syntax_pred)
        print(f"[{postfix}] ENSEMBLE: R2={r2:.4f}, Pearson={pearson:.4f}, "
              f"MAE={mae:.4f}, MAPE={mape:.4f}, Recall={mean_recall:.4f}")

        # Per-fold metrics for mean/std summaries.
        n_folds = len(left_scaled_all[0]) if left_scaled_all else 0
        fold_metrics = {metric: [] for metric in ["R2", "MAE", "Pearson", "MAPE", "Mean_Recall"]}
        for k in range(n_folds):
            pred_k = [max(0.0, l_list[k] + r_list[k])
                      for l_list, r_list in zip(left_scaled_all, right_scaled_all)]
            fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall = compute_metrics(syntax_true, pred_k)
            for metric, value in zip(fold_metrics.keys(),
                                     [fold_r2, fold_mae, fold_pearson, fold_mape, fold_recall]):
                fold_metrics[metric].append(value)

        fold_summary = {k: {"mean": float(np.mean(v)), "std": safe_sample_std(v), "values": v}
                        for k, v in fold_metrics.items()}

        # Collect data for the combined visualization.
        all_datasets[dataset_name] = (syntax_true, syntax_pred)
        all_r2[dataset_name] = r2
        all_recalls[dataset_name] = mean_recall

        ensemble_results["datasets"][dataset_name] = {
            # Ensemble-level metrics
            "R2": round(r2, 4), "MAE": round(mae, 4),
            "Pearson": round(pearson, 4), "MAPE": round(mape, 4),
            "Mean_Recall": round(mean_recall, 4), "N_samples": len(syntax_true),
            # Per-fold summaries (mean ± std plus raw values)
            **{f"{k}_mean": round(v["mean"], 4) for k, v in fold_summary.items()},
            **{f"{k}_std": round(v["std"], 4) for k, v in fold_summary.items()},
            **{f"{k}_folds": [round(x, 4) for x in v["values"]] for k, v in fold_summary.items()}
        }

    # Persist metrics, merging into the existing history file when readable.
    metrics_path = os.path.join(dataset_root, metrics_file)
    full_history = {}
    if os.path.exists(metrics_path):
        try:
            with open(metrics_path, "r") as f:
                full_history = json.load(f)
        except json.JSONDecodeError:
            print("⚠️ Metrics file corrupted. Creating new.")

    full_history[ensemble_name] = ensemble_results
    with open(metrics_path, "w") as f:
        json.dump(full_history, f, indent=4)
    print(f"✅ Metrics saved: {metrics_path}")

    # Final scatter/trend visualization over all datasets.
    visualize_final_syntax_plotly_multi(
        datasets=all_datasets, r2_values=all_r2, recall_values=all_recalls,
        gt_row="ENSEMBLE", postfix=postfix_plotly
    )
242
+
243
+
244
def run_artery(dataset_path, artery, model_paths, video_size, frames_per_clip,
               num_workers, pt_weights_format=False):
    """Run inference for one artery across all fold models (typically 5 folds).

    Returns (preds_all, sids): for every study, a list with one prediction per
    loaded model, plus the matching study identifiers in loader order.
    """
    # Standard ImageNet normalization statistics.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    test_transform = T.Compose([
        ToTensorVideo(),
        T.Resize(size=video_size, antialias=True),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ])

    val_set = SyntaxDataset(
        root=os.path.dirname(dataset_path),
        meta=dataset_path,
        train=False,
        length=frames_per_clip,
        label="",  # inference mode
        artery=artery,
        inference=True,
        transform=test_transform
    )
    # batch_size=1 is required: the unpacking in the loop below assumes
    # single-element batches for the label/target/id components.
    val_loader = DataLoader(val_set, batch_size=1, num_workers=num_workers,
                            shuffle=False, pin_memory=True)
    print(f"{artery} artery: {len(val_loader)} samples")

    # Load every fold checkpoint that exists on disk; missing ones are skipped
    # with a warning rather than aborting the whole run.
    models = []
    for path in model_paths:
        if not os.path.exists(path):
            print(f"⚠️ Model not found: {path}")
            continue
        model = SyntaxLightningModule(
            num_classes=2, lr=1e-5, variant="lstm_mean",
            weight_decay=0.001, max_epochs=1,
            pl_weight_path=path, pt_weights_format=pt_weights_format
        )
        model.to(DEVICE)
        model.eval()
        models.append(model)
    if not models:
        raise RuntimeError(f"No models loaded for {artery}")

    preds_all, sids = [], []
    with torch.no_grad():
        # NOTE(review): the [y]/[t]/[sid] unpacking assumes each batch element
        # is a single-item list — verify against SyntaxDataset's return shape.
        for x, [y], [t], [sid] in tqdm.tqdm(val_loader, desc=f"{artery} infer"):
            if len(x.shape) == 1:  # empty video: fall back to zero predictions
                val_syntax_list = [0.0] * len(models)
            else:
                x = x.to(DEVICE)
                val_syntax_list = []
                for model in models:
                    pred = model(x)
                    _, val_log = pred  # regression logit (second model output)
                    # exp(...) - 1 presumably inverts a log(1 + syntax) target
                    # transform used at training time — TODO confirm.
                    val = float(torch.exp(val_log).cpu()) - 1
                    val_syntax_list.append(val)
            preds_all.append(val_syntax_list)
            sids.append(sid[0])  # study_uid
    return preds_all, sids
302
+
303
+
304
# CLI entry point: delegate to the click command defined above.
if __name__ == "__main__":
    main()