# IceAge / gradio_app.py
# Committed by UsmanGhias: "Added Three files" (commit 0bf3fe4, verified)
import gradio as gr
import tensorflow as tf
import numpy as np
from pyrsgis import raster, convert
from sklearn.preprocessing import StandardScaler
from PIL import Image
import io
# Location of the serialized Keras model; a relative path, so it is looked up
# from the process's current working directory.
MODEL_PATH = 'SGDNet.h5'

# Loaded once when the module is imported, so every prediction request reuses
# the same in-memory model instead of reloading it from disk.
model = tf.keras.models.load_model(MODEL_PATH)
def predict(image_path):
    """Run the module-level ``model`` on a raster file and return a PNG of the result.

    The raster is read with all bands and flattened to one row per pixel. Each
    column is standardized, and the table is fed to ``model``. One probability
    per pixel is taken from the model output and reshaped back onto the raster
    grid. It is then encoded as an 8-bit greyscale PNG (0 -> 0.0, 255 -> 1.0).

    Args:
        image_path: Path to a raster file readable by ``pyrsgis.raster.read``.

    Returns:
        bytes: PNG-encoded image with the same width/height as the input raster.
    """
    ds, bands = raster.read(image_path, bands='all')

    # One row per pixel, one column per band. Guard against a 1-D table
    # (presumably from a single-band raster) so the scaler and the reshape
    # below always receive 2-D input.
    table = convert.array_to_table(bands)
    if table.ndim == 1:
        table = table.reshape(-1, 1)

    # NOTE(review): the scaler is fit on this image alone, so features are
    # standardized with per-image statistics rather than training-set ones.
    # If the model was trained with a saved scaler, load and reuse it here.
    table = StandardScaler().fit_transform(table)

    # Insert a length-1 middle axis: the model is fed (pixels, 1, bands).
    table = table[:, np.newaxis, :]

    predicted = model.predict(table)
    # With two or more output columns, column 1 is used, presumably the
    # glacier class; a single-column (sigmoid-style) output is used as-is
    # instead of raising IndexError.
    prob_column = 1 if predicted.shape[1] > 1 else 0
    prob = predicted[:, prob_column].reshape(ds.RasterYSize, ds.RasterXSize)

    # Clip before scaling: values outside [0, 1] would otherwise overflow or
    # wrap unpredictably when cast to uint8.
    grey = (np.clip(prob, 0.0, 1.0) * 255).astype(np.uint8)

    buffer = io.BytesIO()
    Image.fromarray(grey).save(buffer, format='PNG')
    return buffer.getvalue()
def save_uploaded_file(uploaded_file):
    """Return a local filesystem path for an uploaded file, writing it to disk only if needed.

    Depending on the Gradio version, ``gr.File`` hands its callback either a
    path string or a temporary-file object whose ``.name`` is the path of a
    file Gradio has already written. Both are used in place; they are never
    re-opened for writing, since that would truncate the upload. An in-memory
    upload exposing ``.getbuffer()`` (e.g. a ``BytesIO``) is written to the
    current working directory under the base name of its ``.name``.

    Args:
        uploaded_file: A path (``str`` / ``os.PathLike``), a file object whose
            ``.name`` is an existing file, or an in-memory upload with
            ``.name`` and ``.getbuffer()``.

    Returns:
        str: Path of the file on disk.

    Raises:
        ValueError: If ``uploaded_file`` is None or cannot be resolved to a file.
    """
    import os

    if uploaded_file is None:
        raise ValueError("No file was uploaded.")

    # Already a path on disk.
    if isinstance(uploaded_file, (str, os.PathLike)):
        return os.fspath(uploaded_file)

    # In-memory upload: persist its bytes. Only the base name is used, so a
    # client-supplied filename cannot direct the write outside the working dir.
    if hasattr(uploaded_file, "getbuffer"):
        target = os.path.basename(uploaded_file.name)
        with open(target, "wb") as f:
            f.write(uploaded_file.getbuffer())
        return target

    # File object wrapping a file that already exists: just hand back its path.
    path = getattr(uploaded_file, "name", None)
    if path and os.path.isfile(path):
        return path

    raise ValueError(f"Cannot resolve uploaded file to a path: {uploaded_file!r}")
def _run_prediction(uploaded_file):
    """Click handler: predict glacier boundaries for the uploaded raster.

    ``predict`` returns PNG bytes, but ``gr.Image`` displays numpy arrays, PIL
    images or file paths, not raw bytes. The bytes are therefore decoded into
    a PIL image before being returned to the output component.
    """
    if uploaded_file is None:
        # Show a readable message in the UI instead of a stack trace.
        raise gr.Error("Please upload a satellite image first.")
    png_bytes = predict(save_uploaded_file(uploaded_file))
    return Image.open(io.BytesIO(png_bytes))


with gr.Blocks() as app:
    with gr.Row():
        with gr.Column():
            file_input = gr.File(label="Upload your satellite image")
            submit_button = gr.Button("Predict")
        with gr.Column():
            image_output = gr.Image(label="Predicted Glacier Boundaries")
    submit_button.click(
        fn=_run_prediction,
        inputs=file_input,
        outputs=image_output,
    )

if __name__ == "__main__":
    app.launch()