hp733's picture
Update app.py
46c9dec verified
# import os
# import numpy as np
# from PIL import Image
# import cv2
# from flask import Flask, request, render_template
# from werkzeug.utils import secure_filename
# from tensorflow.keras.models import load_model
# from gradcam_utils import generate_and_merge_heatmaps
# app = Flask(__name__)
# UPLOAD_FOLDER = 'static/uploads'
# HEATMAP_PATH = 'static/heatmap.jpg'
# os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# # Load your trained ensemble model
# model = load_model('ensemble_model_best(92.3).h5')
# # Load the three base models (if required for gradcam)
# from models import create_vgg19_model, create_efficientnet_model, create_densenet_model
# vgg_model = create_vgg19_model()
# efficientnet_model = create_efficientnet_model()
# densenet_model = create_densenet_model()
# print('Model loaded. Visit http://127.0.0.1:5000/')
# def get_className(classNo):
# return "Normal" if classNo == 0 else "Pneumonia"
# def getResult(img_path):
# image = cv2.imread(img_path)
# image = Image.fromarray(image, 'RGB')
# image = image.resize((224, 224))
# image = np.array(image)
# input_img = np.expand_dims(image, axis=0) / 255.0
# result = model.predict(input_img)
# result01 = np.argmax(result, axis=1)
# return result01
# @app.route('/', methods=['GET'])
# def index():
# return render_template('index.html')
# @app.route('/predict', methods=['POST'])
# def upload():
# if request.method == 'POST':
# f = request.files['file']
# filename = secure_filename(f.filename)
# file_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
# f.save(file_path)
# # Get prediction
# value = getResult(file_path)
# result = get_className(value[0])
# # Generate Grad-CAM heatmap
# heatmap_img = generate_and_merge_heatmaps(
# file_path, vgg_model, efficientnet_model, densenet_model
# )
# # Save heatmap image
# cv2.imwrite(HEATMAP_PATH, cv2.cvtColor(heatmap_img, cv2.COLOR_RGB2BGR))
# return render_template(
# 'result.html',
# prediction=result,
# original_image=file_path,
# heatmap_image=HEATMAP_PATH
# )
# return None
# if __name__ == '__main__':
# app.run(host='0.0.0.0', port=5000, debug=True)
# import gradio as gr
# import numpy as np
# import cv2
# from PIL import Image
# from tensorflow.keras.models import load_model
# from models import create_vgg19_model, create_efficientnet_model, create_densenet_model
# from gradcam_utils import generate_and_merge_heatmaps
# # Load models
# ensemble_model = load_model("ensemble_model_best(92.3).h5")
# vgg_model = create_vgg19_model()
# efficientnet_model = create_efficientnet_model()
# densenet_model = create_densenet_model()
# def get_class_name(class_id):
# return "Normal" if class_id == 0 else "Pneumonia"
# def predict_and_heatmap(image):
# # Preprocess input image
# img = image.resize((224, 224))
# img_array = np.array(img) / 255.0
# img_array = np.expand_dims(img_array, axis=0)
# # Predict using ensemble model
# prediction = ensemble_model.predict(img_array)
# class_id = np.argmax(prediction[0])
# result = get_class_name(class_id)
# # Save uploaded image temporarily
# temp_img_path = "temp_input.jpg"
# image.save(temp_img_path)
# # Generate Grad-CAM heatmap
# heatmap_img = generate_and_merge_heatmaps(
# temp_img_path, vgg_model, efficientnet_model, densenet_model
# )
# return result, Image.fromarray(heatmap_img)
# # Gradio Interface
# interface = gr.Interface(
# fn=predict_and_heatmap,
# inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
# outputs=[
# gr.Label(label="Prediction"),
# gr.Image(label="Grad-CAM Heatmap")
# ],
# title="Pneumonia Detection Using Deep Learning",
# description="Upload a chest X-ray to detect Pneumonia and see the heatmap visualization (Grad-CAM)."
# )
# if __name__ == "__main__":
# interface.launch()
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from PIL import Image
import os
from models import create_vgg19_model
from gradcam_utils import generate_heatmap_tf_explain
# Load the trained ensemble classifier once, at import time, so every request
# reuses the same model object. The path is relative, so it is resolved
# against the current working directory, not this file's directory.
# (The "92.3" in the filename presumably records validation accuracy — confirm.)
ensemble_model = load_model("ensemble_model_best(92.3).h5")
# Separate VGG19 network used only as the Grad-CAM backbone (tf-explain) in
# predict_and_heatmap; the class prediction itself comes from ensemble_model.
# NOTE(review): create_vgg19_model() lives in models.py (not visible here) —
# confirm it loads trained weights, otherwise the heatmap explains a different
# network than the one that produced the prediction.
vgg_model = create_vgg19_model()
# Label names
def get_class_name(class_id):
    """Translate a predicted class index into its display label.

    Only index 0 maps to "Normal"; any other value is reported as "Pneumonia".
    """
    if class_id == 0:
        return "Normal"
    return "Pneumonia"
# Prediction + Heatmap generation

# Shown instead of a prediction when "Predict" is clicked with no image loaded.
_NO_IMAGE_HTML = """
<div style='text-align: center; font-size: 1.2rem; color: #b45f06;'>
Please upload a chest X-ray (or load the sample) before clicking Predict.
</div>
"""


def _result_badge_html(label, class_id):
    """Render the centred result badge: green for class 0, red otherwise."""
    color = "green" if class_id == 0 else "red"
    return f"""
<div style='
text-align: center;
font-size: 1.5rem;
font-weight: bold;
color: {color};
background-color: #f0f8ff;
border: 2px solid {color};
padding: 10px;
border-radius: 10px;
width: fit-content;
margin: 0 auto;
'>
Result: {label}
</div>
"""


def predict_and_heatmap(image):
    """Classify a chest X-ray with the ensemble and build its Grad-CAM view.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            clicks "Predict" without uploading anything.

    Returns:
        A ``(result_html, heatmap)`` pair for the HTML and Image outputs.
        ``heatmap`` is whatever ``generate_heatmap_tf_explain`` returns, or
        ``None`` when no image was supplied.
    """
    if image is None:
        # Without this guard ``None.resize`` raises AttributeError and the
        # UI only shows a generic error.
        return _NO_IMAGE_HTML, None

    # X-rays are often grayscale ("L") or RGBA PNGs. Those would produce a
    # (1, 224, 224) or (1, 224, 224, 4) batch, whereas the model presumably
    # expects 3 channels. Gradio's Image component may already convert, but
    # this keeps the function safe when called directly. For RGB input,
    # convert() is just a copy.
    rgb_image = image.convert("RGB")

    resized = rgb_image.resize((224, 224))
    # Scale pixels to [0, 1] and add the batch dimension.
    batch = np.expand_dims(np.array(resized) / 255.0, axis=0)

    prediction = ensemble_model.predict(batch)
    class_id = int(np.argmax(prediction[0]))
    label = get_class_name(class_id)

    # Grad-CAM is computed on the separate VGG19 backbone for the class
    # chosen by the ensemble.
    heatmap_img = generate_heatmap_tf_explain(rgb_image, vgg_model, class_index=class_id)
    return _result_badge_html(label, class_id), heatmap_img
# Function to load sample image
def load_sample():
    """Return the bundled sample chest X-ray for the "Load Sample X-ray" button.

    ``Image.open`` is lazy and keeps the file handle open until the pixel
    data is read. Calling ``load()`` here decodes the image immediately.
    Pillow then closes the file it opened itself (single-frame image), so
    repeated clicks do not leave file handles open. It also means a
    truncated or corrupt file fails here rather than later inside Gradio.
    The path is relative to the current working directory.

    Raises:
        FileNotFoundError: if ``sample_pneumonia.jpeg`` does not exist.
    """
    sample = Image.open("sample_pneumonia.jpeg")
    sample.load()
    return sample
# Gradio interface.
# Layout: page header, then one column holding the input image, the HTML
# verdict, the Grad-CAM heatmap and a row of action buttons, then a footer.
# Components appear on screen in the order they are created inside each
# `with` context, so statement order here defines the layout.
with gr.Blocks(theme="soft") as demo:
    # Page title and short usage instructions.
    gr.Markdown("""
<div style="text-align: center; font-size: 2.5rem; font-weight: bold; color: #0b5394; margin-bottom: 1rem;">
🩺 Pneumonia Detection from Chest X-rays
</div>
<div style="text-align: center; font-size: 1.1rem; margin-bottom: 2rem;">
Upload a chest X-ray image to predict if the lungs are Normal or show signs of Pneumonia.
</div>
""")
    with gr.Row():
        with gr.Column(scale=1, min_width=600):
            # type="pil" makes Gradio hand predict_and_heatmap a PIL image.
            image_input = gr.Image(type="pil", label="Upload Chest X-Ray", interactive=True, width=600, height=600)
            # Receives the styled result <div> built by predict_and_heatmap.
            prediction_output = gr.HTML(label="Prediction")
            heatmap_output = gr.Image(label="Grad-CAM Heatmap", width=600, height=600)
            with gr.Row():
                submit_button = gr.Button("Predict")
                clear_button = gr.Button("Clear")
                sample_button = gr.Button("Load Sample X-ray")
            # Predict: image -> (result HTML, heatmap image).
            submit_button.click(fn=predict_and_heatmap, inputs=image_input, outputs=[prediction_output, heatmap_output])
            # Clear: reset the image, the verdict text and the heatmap, in that order.
            clear_button.click(fn=lambda: (None, "", None), inputs=[], outputs=[image_input, prediction_output, heatmap_output])
            # Sample: only fills the input image; the user still clicks Predict.
            sample_button.click(fn=load_sample, inputs=[], outputs=[image_input])
    # Footer with author credit.
    gr.Markdown("""
<div style="text-align: center; font-size: 0.95rem; color: #888; margin-top: 30px;">
Made with ❤️ by <a href="https://github.com/hruthik733" target="_blank" style="color: #0b5394; text-decoration: none; font-weight: bold;">
Hruthik Pavarala</a>
</div>
""")
# Start the Gradio server only when this file is executed directly, not when
# it is imported (importing still loads the models and builds `demo`).
if __name__ == "__main__":
    demo.launch()