# predict.py — inference engine: TF | ONNX FP32 | Dynamic INT8 | Static INT8
import os

import cv2
import numpy as np
import onnxruntime as ort
import tensorflow as tf
from tensorflow.keras.preprocessing.image import load_img

from src.utils import get_gradcam_heatmap, get_last_conv_layer, get_logger

logger = get_logger("predict")

class BrainTumorPredictor:
    """
    Unified predictor supporting TF model, ONNX FP32,
    ONNX Dynamic INT8, and ONNX Static INT8.
    """

    BACKENDS = ["tensorflow", "onnx_fp32", "onnx_dynamic", "onnx_static"]

    def __init__(self, cfg: dict, backend: str = "tensorflow"):
        if backend not in self.BACKENDS:
            raise ValueError(f"backend must be one of {self.BACKENDS}")
        self.backend = backend
        self.image_size = tuple(cfg["data"]["image_size"])
        self.class_names = cfg["data"]["classes"]
        self.save_dir = cfg["models"]["save_dir"]
        self.onnx_dir = cfg["models"]["onnx_dir"]
        self.tf_model = None
        self.ort_session = None
        self._load(backend)
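    # The constructor only reads a handful of config keys. A minimal sketch of
    # the expected config.yaml shape, inferred from the lookups above (the
    # concrete values below are illustrative, not the project's actual settings):
    #
    #   data:
    #     image_size: [224, 224]
    #     classes: [glioma, meningioma, notumor, pituitary]
    #   models:
    #     save_dir: models/
    #     onnx_dir: models/onnx/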

    def _load(self, backend: str):
        """Load the weights for the requested backend from disk."""
        if backend == "tensorflow":
            path = os.path.join(self.save_dir, "ft_best.h5")
            logger.info(f"Loading TF model from {path}")
            self.tf_model = tf.keras.models.load_model(path, compile=False)
        elif backend == "onnx_fp32":
            path = os.path.join(self.onnx_dir, "model_fp32.onnx")
            logger.info(f"Loading ONNX FP32 from {path}")
            self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
        elif backend == "onnx_dynamic":
            path = os.path.join(self.onnx_dir, "model_dynamic_int8.onnx")
            logger.info(f"Loading ONNX Dynamic INT8 from {path}")
            try:
                self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            except Exception as e:
                raise RuntimeError(
                    f"ONNX Dynamic INT8 model is not supported in this ONNX Runtime build: {e}"
                ) from e
        elif backend == "onnx_static":
            path = os.path.join(self.onnx_dir, "model_static_int8.onnx")
            logger.info(f"Loading ONNX Static INT8 from {path}")
            try:
                self.ort_session = ort.InferenceSession(path, providers=["CPUExecutionProvider"])
            except Exception as e:
                raise RuntimeError(
                    f"ONNX Static INT8 model is not supported in this ONNX Runtime build: {e}"
                ) from e
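    # Why the try/except around the INT8 sessions: quantized graphs rely on
    # integer kernels (e.g. MatMulInteger / ConvInteger for dynamic quantization)
    # that not every ONNX Runtime build or version implements, so session
    # creation itself can fail; re-raising with context makes that failure
    # actionable instead of a bare ORT error.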

    def preprocess(self, image_path: str) -> tuple:
        """Load an image and return (PIL image, [0, 1] float array, batched model input)."""
        img = load_img(image_path, target_size=self.image_size)
        arr = np.array(img) / 255.0
        img_input = np.expand_dims(arr, axis=0).astype(np.float32)
        return img, arr, img_input
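    # Shape contract of the tuple above, assuming image_size == (224, 224)
    # (illustrative size; the real value comes from config.yaml):
    #   img       -> PIL.Image.Image, RGB, 224x224
    #   arr       -> np.ndarray, (224, 224, 3), float64 in [0, 1]
    #   img_input -> np.ndarray, (1, 224, 224, 3), float32 batch of one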

    def predict(self, image_path: str) -> dict:
        _, _, img_input = self.preprocess(image_path)
        if self.backend == "tensorflow":
            probs = self.tf_model.predict(img_input, verbose=0)[0]
        else:
            inp_name = self.ort_session.get_inputs()[0].name
            out_name = self.ort_session.get_outputs()[0].name
            probs = self.ort_session.run([out_name], {inp_name: img_input})[0][0]
        pred_idx = int(np.argmax(probs))
        pred_class = self.class_names[pred_idx]
        confidence = float(probs[pred_idx]) * 100
        all_probs = {cls: float(p) * 100 for cls, p in zip(self.class_names, probs)}
        return {
            "predicted_class": pred_class,
            "confidence": round(confidence, 2),
            "all_probabilities": all_probs,
            "backend": self.backend,
        }
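    # Example return value (class names and numbers are illustrative only):
    #   {
    #       "predicted_class": "glioma",
    #       "confidence": 97.42,
    #       "all_probabilities": {"glioma": 97.42, "meningioma": 1.8, ...},
    #       "backend": "onnx_fp32",
    #   }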

    def predict_with_gradcam(self, image_path: str) -> dict:
        if self.backend != "tensorflow":
            raise RuntimeError("Grad-CAM is only supported with the tensorflow backend.")
        result = self.predict(image_path)
        _, arr, img_input = self.preprocess(image_path)
        last_conv = get_last_conv_layer(self.tf_model)
        heatmap, _ = get_gradcam_heatmap(self.tf_model, img_input, last_conv)
        # NOTE: cv2.resize expects (width, height); with a square image_size
        # the argument order is moot.
        heatmap_resized = cv2.resize(heatmap, self.image_size)
        # applyColorMap returns BGR; convert to RGB so matplotlib renders it correctly.
        heatmap_colored = cv2.cvtColor(
            cv2.applyColorMap(np.uint8(255 * heatmap_resized), cv2.COLORMAP_JET),
            cv2.COLOR_BGR2RGB,
        )
        # Blend the original image (60%) with the colored heatmap (40%).
        overlay = cv2.addWeighted(np.uint8(255 * arr), 0.6, heatmap_colored, 0.4, 0)
        result["gradcam_overlay"] = overlay
        result["heatmap"] = heatmap_resized
        return result
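    # The overlay is returned in RGB. To write it to disk with OpenCV, which
    # expects BGR, a caller would do something like (hypothetical snippet):
    #   out = predictor.predict_with_gradcam("scan.jpg")
    #   cv2.imwrite("overlay.png", cv2.cvtColor(out["gradcam_overlay"], cv2.COLOR_RGB2BGR))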

if __name__ == "__main__":
    import argparse

    import matplotlib.pyplot as plt

    from src.utils import load_config

    parser = argparse.ArgumentParser(description="Brain Tumor MRI Predictor")
    parser.add_argument("--image", required=True, help="Path to the input MRI image")
    parser.add_argument("--backend", default="tensorflow", choices=BrainTumorPredictor.BACKENDS)
    parser.add_argument("--gradcam", action="store_true",
                        help="Also render a Grad-CAM visualization (tensorflow backend only)")
    args = parser.parse_args()

    cfg = load_config("config.yaml")
    predictor = BrainTumorPredictor(cfg, backend=args.backend)

    if args.gradcam and args.backend == "tensorflow":
        result = predictor.predict_with_gradcam(args.image)
        fig, axes = plt.subplots(1, 3, figsize=(13, 4))
        img = load_img(args.image, target_size=tuple(cfg["data"]["image_size"]))
        axes[0].imshow(img)
        axes[0].set_title("Input MRI")
        axes[0].axis("off")
        axes[1].imshow(result["heatmap"], cmap="jet")
        axes[1].set_title("Grad-CAM")
        axes[1].axis("off")
        axes[2].imshow(result["gradcam_overlay"])
        axes[2].set_title(f"Pred: {result['predicted_class']} ({result['confidence']:.1f}%)")
        axes[2].axis("off")
        plt.tight_layout()
        plt.show()
    else:
        if args.gradcam:
            logger.warning("--gradcam is ignored for non-tensorflow backends.")
        result = predictor.predict(args.image)
        print("\n" + "=" * 42)
        print(f" PREDICTION : {result['predicted_class'].upper()}")
        print(f" CONFIDENCE : {result['confidence']:.2f}%")
        print(f" BACKEND    : {result['backend']}")
        print("=" * 42)
        print(" All probabilities:")
        for cls, prob in sorted(result["all_probabilities"].items(), key=lambda x: -x[1]):
            bar = "█" * int(prob / 4)
            marker = " ← predicted" if cls == result["predicted_class"] else ""
            print(f" {cls:<15} {prob:5.1f}% {bar}{marker}")