# Hosting-page header captured with the file (not part of the program):
#   Vineedhar's picture
#   Update app.py
#   e234d10 verified
import gradio as gr
import numpy as np
from PIL import Image
import tensorflow as tf
def iou(y_true, y_pred):
    """Intersection-over-union of two masks, computed in NumPy.

    The arithmetic runs inside ``tf.numpy_function`` so it can be used as a
    Keras metric while being written with plain NumPy operations.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.

    Returns:
        A ``tf.float32`` tensor holding ``intersection / union``.
    """

    def _numpy_iou(truth, pred):
        # Soft overlap: element-wise product summed over every element.
        overlap = (truth * pred).sum()
        combined = truth.sum() + pred.sum() - overlap
        # The 1e-15 on both sides avoids a division by zero when the union is 0.
        ratio = (overlap + 1e-15) / (combined + 1e-15)
        # numpy_function requires the declared output dtype (float32).
        return ratio.astype(np.float32)

    return tf.numpy_function(_numpy_iou, [y_true, y_pred], tf.float32)
def dice_coef(y_true, y_pred):
    """Dice coefficient between a ground-truth mask and a prediction.

    Computes ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``
    over every element of the batch.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.

    Returns:
        A scalar tensor with the Dice coefficient.
    """
    # Same epsilon on both sides of the fraction: with it only in the
    # numerator, two all-zero masks divide by exactly 0 and yield inf
    # (and dice_loss then becomes -inf). With it on both sides the result
    # for two empty masks is 1.
    smooth = 1e-15
    y_true = tf.keras.layers.Flatten()(y_true)
    y_pred = tf.keras.layers.Flatten()(y_pred)
    intersection = tf.reduce_sum(y_true * y_pred)
    denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return (2.0 * intersection + smooth) / (denominator + smooth)
def dice_loss(y_true, y_pred):
    """Dice loss: one minus the Dice coefficient of the two masks.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted mask tensor.

    Returns:
        ``1.0 - dice_coef(y_true, y_pred)``.
    """
    overlap_score = dice_coef(y_true, y_pred)
    return 1.0 - overlap_score
def read_image(file_path, target_size=(256, 256)):
    """Load an image, resize it and scale its pixel values by 1/255.

    Args:
        file_path: Path handed to ``PIL.Image.open``.
        target_size: Size tuple handed to ``Image.resize``
            (Pillow interprets it as ``(width, height)``).

    Returns:
        A ``np.float32`` array of the resized image divided by 255.
        RGB and RGBA inputs keep their channels; any other mode
        (grayscale, palette, CMYK, ...) is converted to RGB first so the
        result always carries colour channels on the last axis.
    """
    # The context manager releases the underlying file handle; a bare
    # Image.open() leaves it open until garbage collection.
    with Image.open(file_path) as source:
        # Palette/grayscale/CMYK uploads would otherwise produce index maps
        # or arrays without three colour channels. RGBA is left untouched
        # so callers that strip the alpha channel keep working unchanged.
        if source.mode in ("RGB", "RGBA"):
            img = source
        else:
            img = source.convert("RGB")
        img = img.resize(target_size)
        # Materialise the pixels while the file is still open.
        x = np.array(img, dtype=np.float32)
    return x / 255.0
def preprocess_image(img):
    """Drop a fourth channel if present and add a leading batch axis.

    Args:
        img: Array whose last axis holds the channels.

    Returns:
        The array with at most its first three channels kept when the last
        axis has length 4, expanded with a new axis at position 0.
    """
    # A last axis of length 4 is treated as RGBA; keep only the first three.
    has_fourth_channel = img.shape[-1] == 4
    channels_kept = img[..., :3] if has_fourth_channel else img
    return np.expand_dims(channels_kept, axis=0)
def predict_image(model, img):
    """Run ``model.predict`` on a batch and return the first result.

    Args:
        model: Object exposing a ``predict`` method.
        img: Batched input passed straight to ``model.predict``.

    Returns:
        Element 0 along the first axis of the prediction.
    """
    batch_output = model.predict(img)
    # Only the first entry of the batch is returned.
    first_prediction = batch_output[0, ...]
    return first_prediction
# Load the saved segmentation model once at start-up. `custom_objects` maps
# the names "dice_coef" and "iou" to the functions defined in this file so
# Keras can resolve them while deserialising
# (presumably the model was compiled with these metrics; confirm).
loaded_model = tf.keras.models.load_model(
    "oryx_road_segmentation_model.h5",
    custom_objects=dict(dice_coef=dice_coef, iou=iou),
)
def process_image(image):
    """Predict a mask for one image file and return it as a PIL image.

    Args:
        image: File path of the input image (the Gradio input below is
            configured with ``type="filepath"``).

    Returns:
        A ``PIL.Image`` built from the prediction, with the same values
        copied into three channels.
    """
    batch = preprocess_image(read_image(image))
    prediction = predict_image(loaded_model, batch)

    # Remove singleton axes and clamp the values into [0, 1].
    mask = np.clip(np.squeeze(prediction), 0, 1)

    # Copy the single-channel mask into three identical channels.
    mask_rgb = np.stack([mask, mask, mask], axis=-1)

    # Scale to 0-255 and cast to uint8 for Image.fromarray.
    mask_rgb_u8 = (mask_rgb * 255).astype(np.uint8)
    return Image.fromarray(mask_rgb_u8)
# Sample images offered as examples in the UI (paths as given, resolved by Gradio).
sample_images = [
    "234989_sat.jpg",
    "751359_sat.jpg",
    "168243_sat.jpg",
    "877873_sat.jpg",
    "836987_sat.jpg",
]

# Gradio Interface: file path in, predicted mask (PIL image) out.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="filepath"),
    outputs=gr.Image(type="pil", label="Predicted Image"),
    title="orYx Models' - Road Segmentation Predictor",
    description="Upload an image or choose a sample and view the model's segmentation for roads on different terrains.",
    examples=sample_images,
)

# Start the server only when this file is executed as a script, so that
# importing the module (tests, tooling) does not launch a blocking server.
if __name__ == "__main__":
    iface.launch(debug=True)