File size: 34,789 Bytes
5a169ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
import os
import torch
import nibabel as nib
from flask import Flask, request, render_template, redirect, url_for, flash, jsonify
import tempfile
import yaml
import traceback # For detailed error printing
import zipfile
import dicom2nifti
import shutil
import subprocess # To run unzip command
import SimpleITK as sitk
import itk
import numpy as np
from scipy.signal import medfilt
import skimage.filters
import cv2 # For Gaussian Blur
import io # For saving plots to memory
import base64 # For encoding plots
import uuid # For unique IDs

# Configure Matplotlib for non-GUI backend *before* importing pyplot
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

# --- Preprocessing Imports ---
try:
    # Adjust import path based on Docker structure
    # Assumes HD_BET is now at /app/BrainIAC/HD_BET
    from HD_BET.run import run_hd_bet
    # Import MONAI saliency visualizer
    from monai.visualize.gradient_based import GuidedBackpropSmoothGrad
except ImportError as e:
    print(f"Could not import HD_BET or MONAI visualize: {e}. Advanced features might fail.")
    run_hd_bet = None
    GuidedBackpropSmoothGrad = None

# Import necessary components from your existing modules
from model import Backbone, SingleScanModel, Classifier
# Removed: from dataset2 import NormalSynchronizedTransform3D
# Import specific MONAI transforms needed
from monai.transforms import Resized, ScaleIntensityd # Removed ToTensord, will handle manually

app = Flask(__name__)
# Flask needs a secret key to sign the session cookie used by flash().
# Prefer a key supplied through the environment so it is not hard-coded in
# source control; the literal fallback keeps existing deployments working but
# should be overridden in production.
app.secret_key = os.environ.get('FLASK_SECRET_KEY', 'supersecretkey')

# --- Constants for Preprocessing ---
# Resolve to an absolute directory: if __file__ is relative (e.g. launched as
# `python app.py`), dirname() would return '' and every path below would be
# resolved against the current working directory instead.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
TEMPLATE_DIR = os.path.join(APP_DIR, "golden_image", "mni_templates")
PARAMS_RIGID_PATH = os.path.join(TEMPLATE_DIR, "Parameters_Rigid.txt")
# NOTE(review): originally commented as the "adult template"; the file name
# suggests the NIHPD asymmetric 13.0-18.5 year template — confirm intent.
DEFAULT_TEMPLATE_PATH = os.path.join(TEMPLATE_DIR, "nihpd_asym_13.0-18.5_t1w.nii")
HD_BET_CONFIG_PATH = os.path.join(APP_DIR, "HD_BET", "config.py")
HD_BET_MODEL_DIR = os.path.join(APP_DIR, "hdbet_model")  # Path to copied models

# --- Configuration Loading ---
def load_config():
    """Load ``config.yml`` from the application directory.

    The returned mapping always contains ``config['data']['image_size']``
    (defaulting to ``[128, 128, 128]`` when absent). If the file does not
    exist, a built-in fallback config (CPU device, default checkpoint path,
    default image size) is returned instead.

    Returns:
        dict: The configuration mapping.
    """
    config_path = os.path.join(APP_DIR, 'config.yml')
    try:
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file)
    except FileNotFoundError:
        print(f"Error: Configuration file not found at {config_path}")
        # Fallback config so the app can still start.
        return {
            'gpu': {'device': 'cpu'},
            'infer': {'checkpoints': 'checkpoints/brainage_model_latest.pt'},
            'data': {'image_size': [128, 128, 128]},  # Default image size
        }

    # safe_load returns None for an empty file, and an empty `data:` section
    # also loads as None; normalise both so the lookups below cannot raise.
    if not isinstance(config, dict):
        if config is not None:
            print(f"Warning: {config_path} does not contain a mapping; ignoring its contents.")
        config = {}
    if not isinstance(config.get('data'), dict):
        config['data'] = {}
    config['data'].setdefault('image_size', [128, 128, 128])
    return config

config = load_config()
# Model input size: use config['data']['image_size'] when it is a 3-element
# list/tuple, otherwise fall back to the 128^3 default below.
DEFAULT_IMAGE_SIZE = (128, 128, 128)
image_size_cfg = config.get('data', {}).get('image_size', DEFAULT_IMAGE_SIZE)
if isinstance(image_size_cfg, (list, tuple)) and len(image_size_cfg) == 3:
    image_size = tuple(image_size_cfg)  # transforms expect a tuple
else:
    print(f"Warning: Invalid image_size in config ({image_size_cfg}). Using default {DEFAULT_IMAGE_SIZE}.")
    image_size = DEFAULT_IMAGE_SIZE

# --- Model Loading ---
def load_model(device, checkpoint_path):
    """Build the brain-age model and load trained weights onto ``device``.

    Args:
        device: Target device, used as ``map_location`` and for ``model.to``.
        checkpoint_path: Either a checkpoint file path (absolute, or relative
            to the application directory) or a config dict from which
            ``['infer']['checkpoints']`` is read. Any other value falls back to
            the module-level ``config``. The module-level call passes the full
            config dict.

    Returns:
        The model in eval mode, or ``None`` if loading failed (the error is
        printed).
    """
    backbone = Backbone()
    classifier = Classifier(d_model=2048)  # Make sure d_model matches your trained model
    model = SingleScanModel(backbone, classifier)

    checkpoint_path_abs = None
    try:
        if isinstance(checkpoint_path, (str, os.PathLike)):
            relative_path = checkpoint_path
        else:
            source_cfg = checkpoint_path if isinstance(checkpoint_path, dict) else config
            relative_path = source_cfg.get('infer', {}).get(
                'checkpoints', 'checkpoints/brainage_model_latest.pt')
        # Relative paths are resolved against app.py's directory; os.path.join
        # leaves an absolute relative_path unchanged.
        checkpoint_path_abs = os.path.join(APP_DIR, relative_path)

        checkpoint = torch.load(checkpoint_path_abs, map_location=device)
        # Support both full training checkpoints and bare state dicts.
        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)
        model.to(device)
        model.eval()
        print(f"Model loaded successfully from {checkpoint_path_abs} onto {device}.")
        return model
    except FileNotFoundError:
        print(f"Error: Checkpoint file not found at {checkpoint_path_abs}")
        return None
    except Exception as e:
        print(f"Error loading model checkpoint: {e}")
        traceback.print_exc()
        return None

# Inference device comes from config (default CPU). If a CUDA device is
# configured but unavailable, fall back to CPU: otherwise torch.load with
# map_location=cuda fails and the app starts without a model.
device = torch.device(config.get('gpu', {}).get('device', 'cpu'))
if device.type == 'cuda' and not torch.cuda.is_available():
    print(f"Warning: configured device '{device}' is unavailable; falling back to CPU.")
    device = torch.device('cpu')
model = load_model(device, config)  # Pass full config for path finding

# --- Preprocessing Functions from preprocess_utils.py ---
def bias_field_correction(img_array):
    """Return ``img_array`` after SimpleITK N4 bias-field correction.

    The array is wrapped as a SimpleITK image (cast to float32 when needed),
    an Otsu threshold image is used as the N4 mask, and N4 runs with four
    fitting levels. The corrected voxels come back as a NumPy array.
    """
    sitk_img = sitk.GetImageFromArray(img_array)
    if sitk_img.GetPixelID() != sitk.sitkFloat32:
        sitk_img = sitk.Cast(sitk_img, sitk.sitkFloat32)  # N4 works on float images

    # insideValue=0, outsideValue=1, 200 histogram bins; presumably this marks
    # the brighter (foreground) voxels with 1 for N4 — confirm if changed.
    otsu_mask = sitk.OtsuThreshold(sitk_img, 0, 1, 200)

    n4 = sitk.N4BiasFieldCorrectionImageFilter()
    levels = 4
    # Iteration budget per level doubles from 50 and is capped at 200:
    # [50, 100, 200, 200].
    n4.SetMaximumNumberOfIterations([min(50 << level, 200) for level in range(levels)])

    print("  Running N4 Bias Field Correction...")
    corrected = n4.Execute(sitk_img, otsu_mask)
    print("  N4 Correction finished.")
    return sitk.GetArrayFromImage(corrected)

def denoise(volume, kernel_size=3):
    """Return ``volume`` smoothed by a median filter of width ``kernel_size``."""
    announcement = f"  Applying median filter denoising (kernel={kernel_size})..."
    print(announcement)
    filtered = medfilt(volume, kernel_size=kernel_size)
    return filtered

def rescale_intensity(volume, percentils=(0.5, 99.5), bins_num=256):
    """Linearly rescale intensities using robust foreground percentiles.

    An Otsu threshold separates foreground from background. The
    ``percentils`` percentiles of the foreground voxels (those at or above
    the threshold and > 0) define the input range. The WHOLE volume is then
    mapped onto that range and clipped, so background voxels below the
    lower percentile end up at 0.

    Args:
        volume: Input array (converted to float32).
        percentils: (low, high) percentiles in [0, 100] of the foreground
            voxels used as the scaling range. The name is kept for API
            compatibility.
        bins_num: If 0, output is in [0, 1]; otherwise values are rounded to
            integers in [0, bins_num - 1] (stored as float32).

    Returns:
        float32 array with the same shape as ``volume``.
    """
    print("  Rescaling intensity...")
    volume_float = volume.astype(np.float32)
    try:
        t = skimage.filters.threshold_otsu(volume_float, nbins=256)
        print(f"    Otsu threshold found: {t}")
        # Foreground = voxels at/above the Otsu threshold that are also positive
        # (equivalent to zeroing voxels < t and keeping those > 0, without a copy).
        obj_volume = volume_float[(volume_float >= t) & (volume_float > 0)]
    except ValueError:  # Handle cases with near-uniform intensity
        print("    Otsu failed (likely uniform image), skipping background mask.")
        obj_volume = volume_float.ravel()

    if obj_volume.size == 0:
        print("    Warning: No foreground voxels found after Otsu. Scaling full volume.")
        min_value = np.min(volume_float)
        max_value = np.max(volume_float)
    else:
        min_value, max_value = np.percentile(obj_volume, [percentils[0], percentils[1]])

    print(f"    Intensity range used for scaling: [{min_value:.2f}, {max_value:.2f}]")
    # Avoid division by zero when the range collapses.
    denominator = max_value - min_value
    if denominator < 1e-6:
        denominator = 1e-6

    scaled = (volume_float - min_value) / denominator
    if bins_num == 0:
        output_volume = np.clip(scaled, 0.0, 1.0)
    else:
        output_volume = np.clip(np.round(scaled * (bins_num - 1)), 0, bins_num - 1)

    return output_volume.astype(np.float32)

def equalize_hist(volume, bins_num=256):
    """Histogram-equalize the foreground (values > 1e-6) of ``volume``.

    Foreground voxels are mapped through their normalized cumulative
    histogram onto [0, bins_num - 1]; background voxels are left as-is.
    With no foreground the input object is returned unchanged; otherwise a
    float32 copy is returned.
    """
    print("  Performing histogram equalization...")
    foreground = volume > 1e-6  # epsilon instead of exact zero for floats
    values = volume[foreground]

    if not values.size:
        print("    Warning: No non-zero voxels found for histogram equalization. Skipping.")
        return volume

    counts, edges = np.histogram(values, bins_num, range=(values.min(), values.max()))
    cumulative = np.cumsum(counts)
    lookup = (bins_num - 1) * cumulative / float(cumulative[-1])

    result = np.copy(volume)
    # Map each foreground value via the left bin edges -> normalized CDF.
    result[foreground] = np.interp(values, edges[:-1], lookup)
    return result.astype(np.float32)

def enhance(img_array, run_bias_correction=True, kernel_size=3, percentils=(0.5, 99.5), bins_num=256, run_equalize_hist=True):
    """Run the full enhancement pipeline (from preprocess_utils).

    Steps: optional N4 bias-field correction -> median-filter denoising ->
    percentile intensity rescaling -> optional histogram equalization.

    Args:
        img_array: Input volume (converted to float32).
        run_bias_correction: Whether to run N4 bias-field correction.
        kernel_size: Median filter size passed to ``denoise``.
        percentils: (low, high) percentiles passed to ``rescale_intensity``.
        bins_num: Number of intensity bins for rescaling/equalization.
        run_equalize_hist: Whether to run histogram equalization.

    Returns:
        The enhanced volume.

    Raises:
        RuntimeError: if any step fails (original exception chained as cause).
    """
    print("Starting enhancement pipeline...")
    volume = img_array.astype(np.float32)  # Ensure float input
    try:
        if run_bias_correction:
            volume = bias_field_correction(volume)
        volume = denoise(volume, kernel_size)
        volume = rescale_intensity(volume, percentils, bins_num)
        if run_equalize_hist:
            volume = equalize_hist(volume, bins_num)
        print("Enhancement pipeline finished.")
        return volume
    except Exception as e:
        print(f"Error during enhancement: {e}")
        traceback.print_exc()
        # Chain the cause so callers/logs keep the original traceback.
        raise RuntimeError(f"Failed enhancing image: {e}") from e

# --- Registration Function ---
def register_image(input_nifti_path, output_nifti_path):
    """Register ``input_nifti_path`` to DEFAULT_TEMPLATE_PATH with itk-elastix.

    Uses the Elastix parameter file at PARAMS_RIGID_PATH and writes the
    registered moving image to ``output_nifti_path``.

    Raises:
        FileNotFoundError: if the parameter file or the template is missing.
    """
    print(f"Registering {input_nifti_path} to {DEFAULT_TEMPLATE_PATH}")
    prerequisites = (
        (PARAMS_RIGID_PATH, "Elastix parameter file not found at {}"),
        (DEFAULT_TEMPLATE_PATH, "Default template file not found at {}"),
    )
    for required_path, message in prerequisites:
        if not os.path.exists(required_path):
            raise FileNotFoundError(message.format(required_path))

    template = itk.imread(DEFAULT_TEMPLATE_PATH, itk.F)
    subject = itk.imread(input_nifti_path, itk.F)

    params = itk.ParameterObject.New()
    params.AddParameterFile(PARAMS_RIGID_PATH)

    registered, _transform_params = itk.elastix_registration_method(
        template,
        subject,
        parameter_object=params,
        log_to_console=False,  # keep console output clean
    )
    itk.imwrite(registered, output_nifti_path)
    print(f"Registration output saved to {output_nifti_path}")

# --- Enhanced Image Function (calls actual enhance) ---
def run_enhance_on_file(input_nifti_path, output_nifti_path):
    """Run ``enhance`` (with N4 enabled) on a NIfTI file and save the result.

    Origin, spacing and direction are copied from the input image so the
    output stays in the same physical space.
    """
    print(f"Running full enhancement on {input_nifti_path}")
    source = sitk.ReadImage(input_nifti_path)
    voxels = sitk.GetArrayFromImage(source)

    result = sitk.GetImageFromArray(enhance(voxels, run_bias_correction=True))
    result.CopyInformation(source)  # preserve geometry metadata
    sitk.WriteImage(result, output_nifti_path)
    print(f"Enhanced image saved to {output_nifti_path}")

# --- Skull Stripping Function ---
def run_skull_stripping(input_nifti_path, output_dir):
    """Skull-strip ``input_nifti_path`` with HD-BET (fast mode, CPU, no TTA).

    Args:
        input_nifti_path: Path to a ``.nii`` / ``.nii.gz`` volume.
        output_dir: Directory for the outputs (created if missing).

    Returns:
        (output_file_path, output_mask_path): ``<base>_bet.nii.gz`` and the
        mask path this function expects, ``<base>_bet_mask.nii.gz``. Only the
        first path is checked for existence.

    Raises:
        RuntimeError: if HD-BET could not be imported or produced no output.
        FileNotFoundError: if no HD-BET config file can be found.
    """
    print(f"Running HD-BET skull stripping on {input_nifti_path}")
    if run_hd_bet is None:
        raise RuntimeError("HD-BET module could not be imported. Cannot perform skull stripping.")

    # The model-weights location is fixed in HD_BET/paths.py, so no
    # HD_BET_MODELS environment variable is set here.

    # Resolve the config file, falling back to the nested package layout.
    config_to_use = HD_BET_CONFIG_PATH
    if not os.path.exists(config_to_use):
        alt_config_path = os.path.join(APP_DIR, "HD_BET", "HD_BET", "config.py")
        if not os.path.exists(alt_config_path):
            raise FileNotFoundError(f"HD-BET config file not found at {HD_BET_CONFIG_PATH} or {alt_config_path}")
        print(f"Warning: Using alternative HD-BET config path: {alt_config_path}")
        config_to_use = alt_config_path

    # Strip only the trailing NIfTI extension; str.replace would also remove
    # ".nii" occurring elsewhere in the file name.
    base_name = os.path.basename(input_nifti_path)
    for ext in (".nii.gz", ".nii"):
        if base_name.endswith(ext):
            base_name = base_name[:-len(ext)]
            break
    output_file_path = os.path.join(output_dir, f"{base_name}_bet.nii.gz")
    output_mask_path = os.path.join(output_dir, f"{base_name}_bet_mask.nii.gz")

    os.makedirs(output_dir, exist_ok=True)

    run_hd_bet(input_nifti_path, output_file_path,
               mode="fast",
               device='cpu',
               config_file=config_to_use,
               postprocess=False,
               do_tta=False,
               keep_mask=True,
               overwrite=True)

    if not os.path.exists(output_file_path):
        raise RuntimeError(f"HD-BET did not produce the expected output file: {output_file_path}")

    print(f"Skull stripping output saved to {output_file_path}")
    return output_file_path, output_mask_path

# --- Image Preprocessing ---
# MONAI dictionary transforms; they operate on samples shaped like
# {"image": tensor}, so the key list must match the key used to build them.
resize_transform = Resized(spatial_size=image_size, keys=["image"])
scale_transform = ScaleIntensityd(maxv=1.0, minv=0.0, keys=["image"])

def preprocess_nifti(nifti_path):
    """Load and preprocess a NIfTI file, returning a 5D model input tensor.

    Kept for backward compatibility. This function was a line-for-line copy
    of ``preprocess_nifti_for_model``, so it now delegates to it to keep the
    two code paths from drifting apart.

    Returns:
        torch.Tensor with batch and channel axes added (5D), on ``device``.
    """
    return preprocess_nifti_for_model(nifti_path)

# --- Final NIfTI Preprocessing for Model ---
def preprocess_nifti_for_model(nifti_path):
    """Load a NIfTI volume and build the 5D tensor the model expects.

    Steps: load voxels as floats, drop trailing length-1 axes, add a channel
    axis, resize to ``image_size`` (``resize_transform``), scale intensities
    to [0, 1] (``scale_transform``), add a batch axis and move to ``device``.

    Raises:
        ValueError: if the data is not 3-D after dropping trailing length-1
            axes, or the final tensor is not 5-D.
    """
    print(f"Preprocessing NIfTI for model: {nifti_path}")
    scan_data = nib.load(nifti_path).get_fdata()
    print(f"  Loaded scan data shape: {scan_data.shape}")
    # A NIfTI can store a 3-D scan with trailing length-1 axes, e.g.
    # (X, Y, Z, 1); without dropping them the spatial resize would receive
    # the wrong number of spatial dimensions.
    while scan_data.ndim > 3 and scan_data.shape[-1] == 1:
        scan_data = scan_data[..., 0]
    if scan_data.ndim != 3:
        raise ValueError(f"Expected a 3D volume, got data of shape {scan_data.shape}.")
    scan_tensor = torch.tensor(scan_data, dtype=torch.float32).unsqueeze(0)  # Add C dim
    print(f"  Shape after tensor+channel: {scan_tensor.shape}")
    sample = {"image": scan_tensor}
    sample_resized = resize_transform(sample)
    print(f"  Shape after resize: {sample_resized['image'].shape}")
    sample_scaled = scale_transform(sample_resized)
    print(f"  Shape after scaling: {sample_scaled['image'].shape}")
    input_tensor = sample_scaled["image"].unsqueeze(0).to(device)  # Add B dim
    print(f"  Final shape for model: {input_tensor.shape}")
    if input_tensor.dim() != 5:
        raise ValueError(f"Preprocessing resulted in incorrect shape: {input_tensor.shape}. Expected 5D.")
    return input_tensor

# --- Saliency Map Generation ---
def generate_saliency(model, input_tensor_5d):
    """Compute a GuidedBackprop-SmoothGrad saliency map for ``model.backbone``.

    Args:
        model: Loaded model exposing a ``backbone`` attribute.
        input_tensor_5d: 5D model input (as built by the preprocessing
            functions). ``requires_grad`` is enabled only for the duration of
            the call and is always reset to False afterwards.

    Returns:
        (input_3d, saliency_3d): NumPy arrays with singleton axes squeezed
        out, or (None, None) if the saliency computation fails.

    Raises:
        ImportError: if MONAI's visualizer could not be imported.
        ValueError: if ``model`` is None.
    """
    if GuidedBackpropSmoothGrad is None:
        raise ImportError("MONAI visualize components not imported. Cannot generate saliency map.")
    if model is None:
        raise ValueError("Model not loaded. Cannot generate saliency map.")

    print("Generating saliency map...")
    # Saliency is computed on the backbone only, moved to the module device.
    visualizer = GuidedBackpropSmoothGrad(model=model.backbone.to(device),
                                          stdev_spread=0.15,
                                          n_samples=10,
                                          magnitude=True)

    # Enable input gradients only after the visualizer exists: anything that
    # can fail from here on is inside the try, so `finally` always undoes it.
    input_tensor_5d.requires_grad_(True)
    try:
        with torch.enable_grad():
            saliency_map_5d = visualizer(input_tensor_5d.to(device))
        print("Saliency map generated.")

        # Detach, move to CPU and squeeze batch/channel axes for plotting.
        input_3d = input_tensor_5d.squeeze().cpu().detach().numpy()
        saliency_3d = saliency_map_5d.squeeze().cpu().detach().numpy()
        return input_3d, saliency_3d
    except Exception as e:
        print(f"Error during saliency map generation: {e}")
        traceback.print_exc()
        return None, None
    finally:
        input_tensor_5d.requires_grad_(False)

# --- Plotting Function for Single Slice ---
def create_plot_images_for_slice(mri_data_3d, saliency_data_3d, slice_index):
    """Render one slice as base64-encoded PNGs (input, heatmap, overlay).

    The slice is taken along the last axis (``volume[:, :, slice_index]``).
    The MRI slice is windowed to the whole volume's 1st-99th percentile
    range. The saliency slice has negatives zeroed, is Gaussian-blurred with
    a 15x15 kernel, and is divided by the largest non-negative saliency value
    in the volume.

    Args:
        mri_data_3d: 3-D input volume (NumPy array).
        saliency_data_3d: 3-D saliency volume indexed the same way.
        slice_index: Index along axis 2.

    Returns:
        dict with keys 'input_slice', 'heatmap_slice' and 'overlay_slice'
        mapping to base64 PNG strings, or None if an input is None, the index
        is out of range, or plotting fails.
    """
    print(f"  Generating plots for slice index: {slice_index}")
    if mri_data_3d is None or saliency_data_3d is None:
        print("    Input or Saliency data is None, cannot generate plot.")
        return None
    if slice_index < 0 or slice_index >= mri_data_3d.shape[2]:
        print(f"    Error: Slice index {slice_index} out of bounds (0-{mri_data_3d.shape[2]-1}).")
        return None

    # Every figure created here; all are closed in `finally`, even if an
    # exception interrupts plotting before a figure was saved.
    figures = []

    def save_plot_to_base64(fig):
        """Serialize ``fig`` to a base64 PNG string and close it."""
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=75)
        finally:
            plt.close(fig)  # free the figure immediately
        img_str = base64.b64encode(buf.getvalue()).decode('utf-8')
        buf.close()
        return img_str

    def render(draw):
        """Create a 3x3-inch figure, let ``draw(ax)`` fill it, return its PNG."""
        fig, ax = plt.subplots(figsize=(3, 3))
        figures.append(fig)
        draw(ax)
        ax.axis('off')
        return save_plot_to_base64(fig)

    try:
        mri_slice = mri_data_3d[:, :, slice_index]
        saliency_slice_orig = saliency_data_3d[:, :, slice_index]

        # --- Normalize MRI slice with volume-wide percentiles so brightness
        # stays consistent from slice to slice ---
        p1_vol, p99_vol = np.percentile(mri_data_3d, (1, 99))
        mri_norm_denom = p99_vol - p1_vol
        if mri_norm_denom < 1e-6:
            mri_norm_denom = 1e-6
        mri_slice_norm = (np.clip(mri_slice, p1_vol, p99_vol) - p1_vol) / mri_norm_denom

        # --- Process saliency slice ---
        saliency_slice = np.copy(saliency_slice_orig)
        saliency_slice[saliency_slice < 0] = 0  # Ensure non-negative
        saliency_slice_blurred = cv2.GaussianBlur(saliency_slice, (15, 15), 0)
        # Normalize by the volume-wide max of non-negative values. Guard the
        # empty selection (no non-negative voxels): np.max of an empty array
        # raises ValueError.
        non_negative = saliency_data_3d[saliency_data_3d >= 0]
        s_max_vol = non_negative.max() if non_negative.size else 0.0
        if s_max_vol < 1e-6:
            s_max_vol = 1e-6
        print(f"    Calculated Global Max Saliency (s_max_vol) for normalization: {s_max_vol:.4f}")
        saliency_slice_norm = saliency_slice_blurred / s_max_vol
        threshold_value = 0.0
        saliency_slice_thresholded = np.where(saliency_slice_norm > threshold_value, saliency_slice_norm, 0)

        def draw_overlay(ax):
            ax.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower')
            if np.max(saliency_slice_thresholded) > 0:
                # No fixed levels: contour picks levels from this slice's data.
                ax.contour(saliency_slice_thresholded, cmap='magma', origin='lower', linewidths=1.0)

        slice_plots = {
            'input_slice': render(
                lambda ax: ax.imshow(mri_slice_norm, cmap='gray', interpolation='none', origin='lower')),
            'heatmap_slice': render(
                lambda ax: ax.imshow(saliency_slice_thresholded, cmap='magma', interpolation='none', origin='lower')),
            'overlay_slice': render(draw_overlay),
        }

        print(f"    Generated plots successfully for slice {slice_index}.")
        return slice_plots

    except Exception as e:
        print(f"Error generating plots for slice {slice_index}: {e}")
        traceback.print_exc()
        return None
    finally:
        for fig in figures:
            plt.close(fig)  # no-op for figures already closed after saving

# --- Flask Routes ---
@app.route('/', methods=['GET'])
def index():
    """Render the landing page template ``index.html``."""
    page = 'index.html'
    return render_template(page)

@app.route('/predict', methods=['POST'])
def predict():
    if model is None:
        flash('Model not loaded. Cannot perform prediction.', 'error')
        return redirect(url_for('index'))

    # Get form data
    file_type = request.form.get('file_type')
    run_preprocess_flag = request.form.get('preprocess') == 'yes'
    generate_saliency_flag = request.form.get('generate_saliency') == 'yes' # Get saliency flag
    file = request.files.get('scan_file')

    # --- Basic Input Validation ---
    if not file_type:
        flash('Please select a file type (NIfTI or DICOM).', 'error')
        return redirect(url_for('index'))
    if not file or file.filename == '':
        flash('No scan file selected', 'error')
        return redirect(url_for('index'))

    print(f"Received upload: type='{file_type}', filename='{file.filename}', preprocess={run_preprocess_flag}, saliency={generate_saliency_flag}")

    # --- Setup Temporary Directory --- 
    # temp_dir_obj = tempfile.TemporaryDirectory() # <--- PROBLEM: Cleans up automatically
    # Use mkdtemp to create a persistent temporary directory
    # NOTE: Requires a manual cleanup strategy later!
    try:
        temp_dir = tempfile.mkdtemp() 
    except Exception as e:
        print(f"Error creating temporary directory: {e}")
        flash("Server error: Could not create temporary directory.", "error")
        return redirect(url_for('index'))
        
    # Generate a unique ID based on the temp directory name
    unique_id = os.path.basename(temp_dir) 
    print(f"Created persistent temp directory: {temp_dir} (ID: {unique_id})")
    nifti_for_preprocessing_path = None # Path to the NIfTI before optional preprocessing

    try:
        # --- Handle Upload and DICOM Conversion ---
        # --- Handle NIfTI Upload ---
        if file_type == 'nifti':
            if not file.filename.endswith('.nii.gz'):
                flash('Invalid file type for NIfTI selection. Please upload .nii.gz', 'error')
                # temp_dir_obj.cleanup() # No object to cleanup, need manual rmtree
                shutil.rmtree(temp_dir, ignore_errors=True) 
                return redirect(url_for('index'))
            uploaded_file_path = os.path.join(temp_dir, "uploaded_scan.nii.gz")
            file.save(uploaded_file_path)
            print(f"Saved uploaded NIfTI file to: {uploaded_file_path}")
            nifti_for_preprocessing_path = uploaded_file_path

        # --- Handle DICOM Upload ---
        elif file_type == 'dicom':
            if not file.filename.endswith('.zip'):
                flash('Invalid file type for DICOM selection. Please upload a .zip file.', 'error')
                # temp_dir_obj.cleanup()
                shutil.rmtree(temp_dir, ignore_errors=True) 
                return redirect(url_for('index'))
            uploaded_zip_path = os.path.join(temp_dir, "dicom_files.zip")
            file.save(uploaded_zip_path)
            print(f"Saved uploaded DICOM zip to: {uploaded_zip_path}")
            dicom_input_dir = os.path.join(temp_dir, "dicom_input")
            nifti_output_dir = os.path.join(temp_dir, "nifti_output")
            os.makedirs(dicom_input_dir, exist_ok=True)
            os.makedirs(nifti_output_dir, exist_ok=True)
            try:
                # Use shutil.unpack_archive for better cross-platform compatibility potentially
                shutil.unpack_archive(uploaded_zip_path, dicom_input_dir)
                print(f"Unzip successful.")
            except Exception as e:
                print(f"Unzip failed: {e}")
                flash(f'Error unzipping DICOM file: {e}', 'error')
                # temp_dir_obj.cleanup()
                shutil.rmtree(temp_dir, ignore_errors=True) 
                return redirect(url_for('index'))
            try:
                dicom2nifti.convert_directory(dicom_input_dir, nifti_output_dir, compression=True, reorient=True)
                nifti_files = [f for f in os.listdir(nifti_output_dir) if f.endswith('.nii.gz')]
                if not nifti_files:
                    raise RuntimeError("dicom2nifti did not produce a .nii.gz file.")
                nifti_for_preprocessing_path = os.path.join(nifti_output_dir, nifti_files[0])
                print(f"DICOM conversion successful. NIfTI file: {nifti_for_preprocessing_path}")
            except Exception as e:
                print(f"DICOM to NIfTI conversion failed: {e}")
                flash(f'Error converting DICOM to NIfTI: {e}', 'error')
                # temp_dir_obj.cleanup()
                shutil.rmtree(temp_dir, ignore_errors=True) 
                return redirect(url_for('index'))
        else:
            flash('Invalid file type selected.', 'error')
            # temp_dir_obj.cleanup()
            shutil.rmtree(temp_dir, ignore_errors=True) 
            return redirect(url_for('index'))
        
        if not nifti_for_preprocessing_path or not os.path.exists(nifti_for_preprocessing_path):
             flash('Error: Could not find the NIfTI file for processing.', 'error')
             # temp_dir_obj.cleanup()
             shutil.rmtree(temp_dir, ignore_errors=True) 
             return redirect(url_for('index'))

        # --- Optional Preprocessing Steps ---
        nifti_to_predict_path = nifti_for_preprocessing_path
        if run_preprocess_flag:
            print("--- Running Optional Preprocessing Pipeline ---")
            try:
                registered_path = os.path.join(temp_dir, "registered.nii.gz")
                register_image(nifti_for_preprocessing_path, registered_path)
                enhanced_path = os.path.join(temp_dir, "enhanced.nii.gz")
                run_enhance_on_file(registered_path, enhanced_path)
                skullstrip_output_dir = os.path.join(temp_dir, "skullstripped")
                skullstripped_path, _ = run_skull_stripping(enhanced_path, skullstrip_output_dir)
                nifti_to_predict_path = skullstripped_path
                print("--- Optional Preprocessing Pipeline Complete ---")
            except Exception as e:
                print(f"Error during optional preprocessing pipeline: {e}")
                traceback.print_exc()
                flash(f'Error during preprocessing: {e}', 'error')
                # temp_dir_obj.cleanup()
                shutil.rmtree(temp_dir, ignore_errors=True) 
                return redirect(url_for('index'))
        else:
             print("--- Skipping Optional Preprocessing Pipeline ---")

        # --- Final Preprocessing for Model & Prediction ---
        input_tensor_5d = preprocess_nifti_for_model(nifti_to_predict_path)
        print("Performing prediction...")
        with torch.no_grad():
            output = model(input_tensor_5d)
            predicted_age = output.item()
            predicted_age_years = predicted_age / 12 # Adjust if needed
        print(f"Prediction successful: {predicted_age_years:.2f} years")

        # --- Saliency Data Handling (Generate, Save, Get Initial Plot) ---
        saliency_output_for_template = None # Initialize
        if generate_saliency_flag:
            print("--- Generating & Saving Saliency Data ---")
            try:
                input_3d_for_plot, saliency_3d = generate_saliency(model, input_tensor_5d)
                
                if input_3d_for_plot is not None and saliency_3d is not None:
                    num_slices = input_3d_for_plot.shape[2]
                    center_slice_index = num_slices // 2

                    # Save the numpy arrays for the dynamic route
                    input_array_path = os.path.join(temp_dir, f"{unique_id}_input.npy")
                    saliency_array_path = os.path.join(temp_dir, f"{unique_id}_saliency.npy")
                    np.save(input_array_path, input_3d_for_plot)
                    np.save(saliency_array_path, saliency_3d)
                    print(f"Saved input array to {input_array_path}")
                    print(f"Saved saliency array to {saliency_array_path}")

                    # Generate ONLY the center slice plots for the initial view
                    center_slice_plots = create_plot_images_for_slice(input_3d_for_plot, saliency_3d, center_slice_index)

                    if center_slice_plots:
                        # Prepare data structure for the template
                        saliency_output_for_template = {
                            'center_slice_plots': center_slice_plots,
                            'num_slices': num_slices,
                            'center_slice_index': center_slice_index,
                            'unique_id': unique_id, # Pass the ID for filenames
                            'temp_dir_path': temp_dir # Pass the full path for lookup
                        }
                        print("--- Saliency Data Saved & Initial Plot Generated ---")
                    else:
                         print("--- Center Slice Plotting Failed ---")
                         flash('Failed to generate initial saliency plot.', 'warning')
                else:
                     print("--- Saliency Generation Failed --- ")
                     flash('Saliency map generation failed.', 'warning')

            except Exception as e:
                print(f"Error during saliency processing/saving: {e}")
                traceback.print_exc()
                flash('Could not generate or save saliency maps due to an error.', 'error')

        # Render result, passing prediction and potentially the NEW saliency structure
        return render_template('index.html', 
                               prediction=f"{predicted_age_years:.2f} years", 
                               saliency_info=saliency_output_for_template) # Pass the new dict

    except Exception as e:
        flash(f'Error processing file: {e}', 'error')
        print(f"Caught Exception during prediction process: {e}")
        traceback.print_exc()
        # Ensure cleanup happens even if exception occurs mid-process
        # temp_dir_obj.cleanup()
        if temp_dir and os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True) # Manual cleanup on general error
        return redirect(url_for('index'))

    # NOTE: Temporary directory created with mkdtemp is NOT automatically cleaned.
    # Need a separate mechanism (e.g., cron job, background task) to remove old directories
    # from the system's temporary location (e.g., /tmp) based on age.
    # Leaving the directory here so /get_slice can access the files.

# --- New Route for Dynamic Slice Loading ---
def _resolve_saliency_dir(temp_dir_path):
    """Return the canonical form of *temp_dir_path* if it is a safe directory to read from.

    The path must resolve (after following symlinks) to a location strictly inside
    the system temporary directory, which is where ``tempfile.mkdtemp`` creates
    directories when no ``dir=`` argument is given. Returns ``None`` otherwise,
    including for paths ``os.path`` cannot process (e.g. embedded NUL bytes, or a
    different drive on Windows).

    NOTE(review): if the prediction route ever passes a custom ``dir=`` to
    ``mkdtemp``, this root must be updated to match.
    """
    import tempfile  # function-scope import so this helper does not depend on module imports

    try:
        temp_root = os.path.realpath(tempfile.gettempdir())
        resolved = os.path.realpath(temp_dir_path)
        # Reject the temp root itself and anything that escapes it (e.g. via '..' or symlinks).
        if resolved == temp_root or os.path.commonpath([temp_root, resolved]) != temp_root:
            return None
    except ValueError:
        return None
    return resolved


@app.route('/get_slice/<unique_id>/<int:slice_index>')
def get_slice(unique_id, slice_index):
    """Return the input/saliency plots for one slice of previously saved saliency data.

    The prediction route saves ``{unique_id}_input.npy`` and ``{unique_id}_saliency.npy``
    into a temporary directory and passes that directory's path to the template. The
    client sends it back here in the ``path`` query parameter.

    Args:
        unique_id: Filename prefix of the saved arrays.
        slice_index: Index along axis 2 of the saved arrays, the same axis the
            prediction route uses to count slices.

    Returns:
        JSON with the plot data from ``create_plot_images_for_slice`` on success.
        Otherwise ``{"error": ...}`` with one of these statuses:
        400 for a missing path, invalid identifier or out-of-range slice;
        403 for a path outside the system temp directory;
        404 when the saved arrays are not found;
        500 when plotting fails or an unexpected error occurs.
    """
    # Get the actual temporary directory path from query parameter
    temp_dir_path = request.args.get('path')
    if not temp_dir_path:
        print("Error: 'path' query parameter missing in /get_slice request")
        return jsonify({"error": "Required path information missing."}), 400

    # SECURITY: 'path' and 'unique_id' are both client-controlled. Without these checks
    # any caller could make the server load .npy files from arbitrary directories.
    safe_dir = _resolve_saliency_dir(temp_dir_path)
    if safe_dir is None:
        print(f"Error: rejected /get_slice path outside the temp directory: {temp_dir_path!r}")
        return jsonify({"error": "Invalid path information."}), 403
    if os.path.basename(unique_id) != unique_id:
        # The ID must be a bare filename component, not something that changes the directory.
        print(f"Error: rejected /get_slice unique_id with path components: {unique_id!r}")
        return jsonify({"error": "Invalid identifier."}), 400

    # Construct paths using the validated directory and unique ID
    input_array_path = os.path.join(safe_dir, f"{unique_id}_input.npy")
    saliency_array_path = os.path.join(safe_dir, f"{unique_id}_saliency.npy")
    print(f"Attempting to load slice {slice_index} for ID {unique_id} from actual path: {safe_dir}")

    try:
        if not os.path.exists(input_array_path) or not os.path.exists(saliency_array_path):
            print(f"Error: .npy files not found for ID {unique_id} at {safe_dir}")
            return jsonify({"error": "Saliency data not found. It might have expired or failed to save."}), 404

        # allow_pickle=False matches NumPy's default since 1.16.3. It is stated explicitly
        # because these paths are derived from client input.
        input_3d = np.load(input_array_path, allow_pickle=False)
        saliency_3d = np.load(saliency_array_path, allow_pickle=False)
        print(f"Loaded arrays for ID {unique_id}. Input shape: {input_3d.shape}, Saliency shape: {saliency_3d.shape}")

        # Validate the requested slice against the same axis the prediction route uses
        # for num_slices. This lets a bad index surface as a client error, not a 500.
        num_slices = input_3d.shape[2]
        if not 0 <= slice_index < num_slices:
            return jsonify({"error": f"Slice index {slice_index} out of range (0-{num_slices - 1})."}), 400

        # Generate plots for the requested slice using the helper function
        slice_plots = create_plot_images_for_slice(input_3d, saliency_3d, slice_index)

        if slice_plots:
            return jsonify(slice_plots)  # Return plot data as JSON
        return jsonify({"error": f"Failed to generate plots for slice {slice_index}."}), 500

    except Exception as e:
        # Top-level route boundary: log the full traceback but return a generic message.
        print(f"Error in /get_slice for ID {unique_id}, slice {slice_index}: {e}")
        traceback.print_exc()
        return jsonify({"error": "An internal error occurred while fetching the slice data."}), 500

if __name__ == '__main__':
    # Development-server settings for running this module directly.
    server_options = {
        'host': '0.0.0.0',  # listen on every interface so the app is reachable from outside a container
        'port': 5000,
        'debug': False,     # debugger/reloader stays off for production/docker runs
    }
    app.run(**server_options)