| |
|
|
| import os |
| import numpy as np |
| import cv2 |
| import tensorflow as tf |
| import onnxruntime as ort |
| from tensorflow.keras.preprocessing.image import load_img |
|
|
| from src.utils import get_logger, get_gradcam_heatmap, get_last_conv_layer |
|
|
# Module-level logger; BrainTumorPredictor._load reports which model file it loads through it.
logger = get_logger("predict")
|
|
|
|
class BrainTumorPredictor:
    """
    Unified predictor supporting TF model, ONNX FP32,
    ONNX Dynamic INT8, and ONNX Static INT8.

    Args:
        cfg: Parsed configuration. Reads ``cfg["data"]["image_size"]``,
            ``cfg["data"]["classes"]``, ``cfg["models"]["save_dir"]`` and
            ``cfg["models"]["onnx_dir"]``.
        backend: One of :attr:`BACKENDS`; selects which model artifact is loaded.

    Raises:
        ValueError: If ``backend`` is not one of :attr:`BACKENDS`.
    """

    BACKENDS = ["tensorflow", "onnx_fp32", "onnx_dynamic", "onnx_static"]

    # ONNX backend -> (file name inside onnx_dir, label used in log/error messages).
    _ONNX_MODELS = {
        "onnx_fp32": ("model_fp32.onnx", "ONNX FP32"),
        "onnx_dynamic": ("model_dynamic_int8.onnx", "ONNX Dynamic INT8"),
        "onnx_static": ("model_static_int8.onnx", "ONNX Static INT8"),
    }

    def __init__(self, cfg: dict, backend: str = "tensorflow"):
        if backend not in self.BACKENDS:
            raise ValueError(f"backend must be one of {self.BACKENDS}")

        self.backend = backend
        # Passed to load_img as target_size, i.e. (height, width).
        self.image_size = tuple(cfg["data"]["image_size"])
        self.class_names = cfg["data"]["classes"]
        self.save_dir = cfg["models"]["save_dir"]
        self.onnx_dir = cfg["models"]["onnx_dir"]

        # Exactly one of these is populated by _load, depending on the backend.
        self.tf_model = None
        self.ort_session = None
        self._load(backend)

    def _load(self, backend: str):
        """Load the model for ``backend`` into ``self.tf_model`` or ``self.ort_session``.

        Raises:
            RuntimeError: If an INT8 ONNX model cannot be opened by ONNX Runtime;
                the original exception is chained as ``__cause__``.
        """
        if backend == "tensorflow":
            path = os.path.join(self.save_dir, "ft_best.h5")
            logger.info(f"Loading TF model from {path}")
            self.tf_model = tf.keras.models.load_model(path, compile=False)
            return

        filename, label = self._ONNX_MODELS[backend]
        path = os.path.join(self.onnx_dir, filename)
        logger.info(f"Loading {label} from {path}")

        if backend == "onnx_fp32":
            # FP32 load errors propagate exactly as raised by ONNX Runtime.
            self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            return

        try:
            self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        except Exception as e:
            # Quantized graphs may rely on operators missing from some ORT builds;
            # chain the original error so the real cause is not lost.
            raise RuntimeError(
                f"{label} model is not supported in this ONNX Runtime build: {e}"
            ) from e

    def preprocess(self, image_path: str) -> tuple:
        """Load an image at ``self.image_size`` and normalize it for the model.

        Returns:
            ``(img, arr, img_input)``: the image returned by ``load_img``, its
            pixel array divided by 255, and ``arr`` as a float32 batch of one
            (leading batch axis added).
        """
        img = load_img(image_path, target_size=self.image_size)
        arr = np.array(img) / 255.0
        img_input = np.expand_dims(arr, axis=0).astype(np.float32)
        return img, arr, img_input

    def _run_model(self, img_input: np.ndarray) -> np.ndarray:
        """Run the loaded backend on a preprocessed batch and return the first row of scores."""
        if self.backend == "tensorflow":
            return self.tf_model.predict(img_input, verbose=0)[0]
        inp_name = self.ort_session.get_inputs()[0].name
        out_name = self.ort_session.get_outputs()[0].name
        return self.ort_session.run([out_name], {inp_name: img_input})[0][0]

    def _build_result(self, probs) -> dict:
        """Convert one row of per-class scores into the dict returned by :meth:`predict`.

        Raises:
            ValueError: If ``probs`` is not 1-D with one score per configured class.
        """
        probs = np.asarray(probs)
        # zip() would silently truncate a mismatch between model and config.
        if probs.ndim != 1 or probs.shape[0] != len(self.class_names):
            raise ValueError(
                f"Model output of shape {probs.shape} does not match the "
                f"{len(self.class_names)} configured class names"
            )

        pred_idx = int(np.argmax(probs))
        confidence = float(probs[pred_idx]) * 100
        all_probs = {cls: float(p) * 100 for cls, p in zip(self.class_names, probs)}

        return {
            "predicted_class": self.class_names[pred_idx],
            "confidence": round(confidence, 2),
            "all_probabilities": all_probs,
            "backend": self.backend,
        }

    def predict(self, image_path: str) -> dict:
        """Classify one image.

        Returns:
            Dict with ``predicted_class``, ``confidence`` (percent, rounded to 2
            decimals), ``all_probabilities`` (class -> percent) and ``backend``.
        """
        _, _, img_input = self.preprocess(image_path)
        return self._build_result(self._run_model(img_input))

    def predict_with_gradcam(self, image_path: str) -> dict:
        """Classify one image and add a Grad-CAM heatmap and overlay to the result.

        Adds ``gradcam_overlay`` (uint8 RGB image) and ``heatmap`` (heatmap resized
        to the input image) to the dict returned by :meth:`predict`.

        Raises:
            RuntimeError: If the backend is not ``"tensorflow"``.
        """
        if self.backend != "tensorflow":
            raise RuntimeError("Grad-CAM is only supported with tensorflow backend.")

        # Load the image once and reuse it for both prediction and Grad-CAM.
        _, arr, img_input = self.preprocess(image_path)
        result = self._build_result(self._run_model(img_input))

        last_conv = get_last_conv_layer(self.tf_model)
        heatmap, _ = get_gradcam_heatmap(self.tf_model, img_input, last_conv)

        # cv2.resize takes dsize as (width, height); arr is (height, width, channels).
        height, width = arr.shape[:2]
        heatmap_resized = cv2.resize(heatmap, (width, height))

        # NOTE(review): assumes the heatmap is normalized to [0, 1]; clip so that
        # out-of-range values cannot wrap around in the uint8 conversion.
        heatmap_u8 = np.uint8(255 * np.clip(heatmap_resized, 0.0, 1.0))
        heatmap_colored = cv2.cvtColor(
            cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET),
            cv2.COLOR_BGR2RGB,
        )
        # Round (not truncate) when mapping the normalized pixels back to uint8.
        base = np.rint(arr * 255.0).astype(np.uint8)
        overlay = cv2.addWeighted(base, 0.6, heatmap_colored, 0.4, 0)

        result["gradcam_overlay"] = overlay
        result["heatmap"] = heatmap_resized
        return result
|
|
|
|
def _print_result(result: dict) -> None:
    """Print the prediction summary followed by a text bar chart of all class probabilities."""
    print("\n" + "=" * 42)
    print(f" PREDICTION : {result['predicted_class'].upper()}")
    print(f" CONFIDENCE : {result['confidence']:.2f}%")
    print(f" BACKEND : {result['backend']}")
    print("=" * 42)
    print(" All probabilities:")
    # Highest probability first; one bar character per 4 percentage points.
    for cls, prob in sorted(result["all_probabilities"].items(), key=lambda x: -x[1]):
        bar = "█" * int(prob / 4)
        marker = " ← predicted" if cls == result["predicted_class"] else ""
        print(f" {cls:<15} {prob:5.1f}% {bar}{marker}")


def _show_gradcam(image_path: str, predictor, result: dict) -> None:
    """Show the input image, the Grad-CAM heatmap and the overlay side by side."""
    # Imported lazily so plain prediction does not require matplotlib.
    import matplotlib.pyplot as plt

    _, axes = plt.subplots(1, 3, figsize=(13, 4))
    img = load_img(image_path, target_size=predictor.image_size)
    panels = (
        (img, {}, "Input MRI"),
        (result["heatmap"], {"cmap": "jet"}, "Grad-CAM"),
        (
            result["gradcam_overlay"],
            {},
            f"Pred: {result['predicted_class']} ({result['confidence']:.1f}%)",
        ),
    )
    for ax, (image, imshow_kwargs, title) in zip(axes, panels):
        ax.imshow(image, **imshow_kwargs)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()


def main(argv=None) -> None:
    """Command-line entry point: classify one MRI image and print the result.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]`` (argparse default).
    """
    import argparse
    import sys

    from src.utils import load_config

    parser = argparse.ArgumentParser(description="Brain Tumor MRI Predictor")
    parser.add_argument("--image", required=True)
    parser.add_argument("--backend", default="tensorflow", choices=BrainTumorPredictor.BACKENDS)
    parser.add_argument("--gradcam", action="store_true")
    parser.add_argument("--config", default="config.yaml", help="Path to the YAML config file.")
    args = parser.parse_args(argv)

    cfg = load_config(args.config)
    predictor = BrainTumorPredictor(cfg, backend=args.backend)

    show_gradcam = args.gradcam and args.backend == "tensorflow"
    if args.gradcam and not show_gradcam:
        # Grad-CAM needs the Keras model; tell the user instead of ignoring the flag silently.
        print(
            "Warning: --gradcam is only supported with the tensorflow backend; ignoring it.",
            file=sys.stderr,
        )

    if show_gradcam:
        result = predictor.predict_with_gradcam(args.image)
    else:
        result = predictor.predict(args.image)

    # Print before plotting: plt.show() blocks until the window is closed.
    _print_result(result)
    if show_gradcam:
        _show_gradcam(args.image, predictor, result)


if __name__ == "__main__":
    main()