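# app.py — Hugging Face Space file by innat (commit cef90e2, "Update app.py", 3.03 kB).
# The lines originally here were page chrome from the web file viewer
# (avatar, raw/history/blame links, virus-scan badge), kept only as this note
# so the module parses as Python.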
import functools
import os

import gdown
import gradio as gr
import tensorflow as tf

from config import Parameters
from models.hybrid_model import GradientAccumulation
from utils.model_utils import *
from utils.viz_utils import make_gradcam_heatmap
from utils.viz_utils import save_and_display_gradcam
# Side length the input image is resized to, taken from the project config.
image_size = Parameters().image_size

# Class names looked up by the argmax index of the model's predictions.
str_labels = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
def get_model():
    """Construct the hybrid model and build its variables.

    The model is called once on a dummy all-ones batch so that its variables
    exist before the caller restores trained weights into it.

    Returns:
        A built ``GradientAccumulation`` model instance.
    """
    # Read the config explicitly: ``params`` is not defined in this module
    # (presumably it arrives via ``from utils.model_utils import *``). Using
    # ``Parameters()`` here also guarantees the same ``image_size`` that the
    # module-level preprocessing constant uses.
    config = Parameters()
    model = GradientAccumulation(
        n_gradients=config.num_grad_accumulation, model_name="HybridModel"
    )
    # One forward pass on a (1, H, W, 3) dummy batch builds the model.
    model(tf.ones((1, config.image_size, config.image_size, 3)))
    return model
def get_model_weight(model_id, weight_path="model.h5"):
    """Return a local path to the trained weights, downloading them if absent.

    Args:
        model_id: File id passed to ``gdown.download`` as ``id``.
        weight_path: Local file used as the weights cache. Defaults to
            ``"model.h5"`` relative to the current working directory.

    Returns:
        Path to the weights file.

    Raises:
        RuntimeError: If the download did not produce a file path.
    """
    if os.path.exists(weight_path):
        return weight_path
    # Save under ``weight_path`` explicitly: without ``output`` gdown keeps the
    # remote file name, so the existence check above could never hit and the
    # weights would be downloaded again on every call.
    downloaded = gdown.download(id=model_id, output=weight_path, quiet=False)
    if downloaded is None:
        # NOTE(review): some gdown versions signal failure by returning None;
        # fail loudly here instead of handing None to ``load_weights``.
        raise RuntimeError(f"Failed to download model weights (id={model_id!r}).")
    return downloaded
def load_model(model_id):
    """Fetch the trained weights, build the model, and restore the weights.

    Args:
        model_id: Identifier forwarded to ``get_model_weight``.

    Returns:
        The model with its trained weights loaded.
    """
    weights_file = get_model_weight(model_id)
    net = get_model()
    net.load_weights(weights_file)
    return net
def image_process(image):
    """Turn one input image into a model-ready batch.

    The image is cast to float32, resized to ``image_size`` x ``image_size``,
    and given a leading batch axis.

    Args:
        image: Image from the Gradio input (presumably an H x W x C array —
            TODO confirm).

    Returns:
        A tuple ``(batch, original_shape)``. For an H x W x C input, ``batch``
        has shape ``(1, image_size, image_size, C)``. ``original_shape`` is the
        shape of the cast image before resizing.
    """
    as_float = tf.cast(image, dtype=tf.float32)
    shape_before_resize = as_float.shape
    resized = tf.image.resize(as_float, [image_size, image_size])
    return tf.expand_dims(resized, axis=0), shape_before_resize
# Identifier of the trained weights served by this demo (see get_model_weight).
_MODEL_ID = "1y6tseN0194T6d-4iIh5wo7RL9ttQERe0"


@functools.lru_cache(maxsize=None)
def _cached_model(model_id):
    """Load the model for ``model_id`` once per process and reuse it."""
    return load_model(model_id=model_id)


def predict_fn(image):
    """Classify ``image`` and render both Grad-CAM overlays (called by Gradio).

    Args:
        image: Image received from the Gradio ``Image`` input.

    Returns:
        A list ``[label_text, cnn_overlay, transformer_overlay]`` matching the
        interface's three outputs: the "Predicted: <class>" text, then the
        overlays built from the first and second heatmaps returned by
        ``make_gradcam_heatmap``.
    """
    # Building the model and loading its weights is expensive, so do it on
    # the first request only instead of on every prediction.
    model = _cached_model(_MODEL_ID)
    batch, original_shape = image_process(image)
    heatmap_a, heatmap_b, preds = make_gradcam_heatmap(batch, model)
    # Index 0 selects the single image in the batch.
    int_label = int(tf.argmax(preds, axis=-1).numpy()[0])
    str_label = str_labels[int_label]
    # ``original_shape[:2]`` passes the pre-resize height/width; presumably the
    # overlay is rendered at that size — see viz_utils to confirm.
    overlay_a, overlay_b = (
        save_and_display_gradcam(batch[0], heatmap, image_shape=original_shape[:2])
        for heatmap in (heatmap_a, heatmap_b)
    )
    return [f"Predicted: {str_label}", overlay_a, overlay_b]
iface = gr.Interface(
    fn=predict_fn,
    inputs=gr.inputs.Image(label="Input Image"),
    outputs=[
        gr.outputs.Label(label="Prediction"),
        # NOTE(review): these two are outputs but are declared with
        # ``gr.inputs.Image``. Left as-is because the legacy
        # ``gr.outputs.Image`` default type differs across Gradio versions —
        # confirm against the pinned version before switching.
        gr.inputs.Image(label="CNN GradCAM"),
        gr.inputs.Image(label="Transformer GradCAM"),
    ],
    title="Hybrid EfficientNet Swin Transformer Demo",
    # Category names listed here match ``str_labels``, i.e. the class names the
    # "Predicted: ..." output actually shows.
    description="The model is trained on tf_flowers dataset <a href='https://www.kaggle.com/datasets/alxmamaev/flowers-recognition'>Flowers Recognition Dataset</a>. It provides 5 categories, namely: `daisy`, `dandelion`, `roses`, `sunflowers`, `tulips`. One example from each class is provided in the Example section.",
    article="<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=hybrid-gradcam' alt='visitor badge'></center></div>",
    examples=[
        ["examples/dandelion.jpg"],
        ["examples/sunflower.jpg"],
        ["examples/tulip.jpg"],
        ["examples/daisy.jpg"],
        ["examples/rose.jpg"],
    ],
)

if __name__ == "__main__":
    # Start the server only when run as a script, not as a side effect of
    # importing this module.
    iface.launch(share=True)