# app.py - author: rajsecrets0; commit 1f2b3fa "Create app.py" (verified); 2.45 kB.
import streamlit as st
from PIL import Image
import torch
from torchvision import transforms
import numpy as np
import os
from osgeo import gdal
# Load the pretrained model.
# ``st.cache`` has been deprecated since Streamlit 1.18; ``st.cache_resource``
# is the cache intended for unhashable singletons such as ML models. Prefer it
# and fall back to the legacy decorator only on Streamlit versions without it.
if hasattr(st, "cache_resource"):
    _cache_model = st.cache_resource
else:
    _cache_model = st.cache(allow_output_mutation=True)


@_cache_model
def load_model():
    """Load the pretrained brain-segmentation U-Net from torch.hub, in eval mode.

    Streamlit caches the returned model, so reruns of the script reuse one
    instance instead of reloading it.

    The hub ``unet`` entrypoint forwards every extra keyword argument to
    ``UNet.__init__``. It therefore only accepts the architecture arguments
    below; the previous ``progress=True`` is not one of them and raised
    ``TypeError``. The values match the published pretrained checkpoint.
    """
    model = torch.hub.load(
        "mateuszbuda/brain-segmentation-pytorch",
        "unet",
        in_channels=3,
        out_channels=1,
        init_features=32,
        pretrained=True,
    )
    model.eval()
    return model


# Function to load large TIFF images
def load_tiff_image(tiff_path):
    """Read the first raster band of the TIFF at *tiff_path* via GDAL.

    Returns the band as an array. On failure, returns None after showing an
    error in the Streamlit UI. That happens when GDAL cannot open the file, or
    when any exception is raised while reading.
    """
    try:
        raster = gdal.Open(tiff_path)
        if raster is not None:
            # Only band 1 is read: the image is assumed grayscale/single-band.
            first_band = raster.GetRasterBand(1)
            return first_band.ReadAsArray()
        st.error("Failed to load the TIFF image. Please check the file format.")
    except Exception as exc:
        st.error(f"Error loading image: {exc}")
    return None


# Preprocess image
def preprocess_image(image):
    """Convert *image* into a normalized ``(1, C, 256, 256)`` float tensor.

    *image* may be a PIL image or a numpy array shaped ``(H, W)`` or
    ``(H, W, C)``. A 2-D input yields ``C == 1``.

    Scaling to [0, 1]:
      * uint8 data is divided by 255, exactly as ``transforms.ToTensor`` does.
      * Any other dtype (16/32-bit integers, floats - common in TIFFs) is
        min-max stretched over its finite values. NaN becomes 0; +/-inf is
        clipped to the ends.

    ``ToTensor`` returns an integer tensor for 16/32-bit integer images, and
    ``Normalize`` rejects integer tensors with ``TypeError``. It also leaves
    float images unscaled. That is why the conversion is done by hand here.
    """
    array = np.asarray(image)
    if array.dtype == np.uint8:
        array = array.astype(np.float32) / 255.0
    else:
        array = array.astype(np.float32)
        finite = array[np.isfinite(array)]
        if finite.size and finite.max() > finite.min():
            low, high = finite.min(), finite.max()
            array = np.clip((array - low) / (high - low), 0.0, 1.0)
            array = np.nan_to_num(array, nan=0.0)
        else:
            # Empty, all-NaN or constant image: nothing to stretch.
            array = np.zeros_like(array)
    if array.ndim == 2:
        array = array[np.newaxis, :, :]  # (H, W) -> (1, H, W)
    else:
        array = np.moveaxis(array, -1, 0)  # (H, W, C) -> (C, H, W), like ToTensor
    tensor = torch.from_numpy(np.ascontiguousarray(array))
    transform = transforms.Compose([
        transforms.Resize((256, 256)),  # Resize image for model input
        transforms.Normalize(mean=[0.485], std=[0.229]),  # Normalize
    ])
    return transform(tensor).unsqueeze(0)  # Add batch dimension


# Post-process prediction to display
def postprocess_prediction(pred, threshold=0.5, apply_sigmoid=False):
    """Threshold a model output tensor into a binary ``uint8`` mask.

    Args:
        pred: Model output tensor, e.g. ``(1, 1, H, W)``. All size-1
            dimensions are squeezed away.
        threshold: Values strictly greater than this become 1; all others
            become 0.
        apply_sigmoid: Set True only when *pred* holds raw logits. The
            brain-segmentation U-Net's ``forward`` already returns
            ``torch.sigmoid(...)``. A second sigmoid maps every probability in
            (0, 1) to a value above 0.5, so the mask came out all ones.

    Returns:
        numpy ``uint8`` array containing only 0 and 1.
    """
    if apply_sigmoid:
        pred = torch.sigmoid(pred)
    # .cpu() so this also works when the model runs on a GPU.
    probs = pred.squeeze().detach().cpu().numpy()
    return (probs > threshold).astype(np.uint8)  # Binary mask thresholding


# Streamlit app
import tempfile


def _to_display_uint8(array):
    """Return *array* as ``uint8`` so ``st.image`` accepts any TIFF dtype.

    Without ``clamp``, ``st.image`` raises for integer data outside [0, 255]
    and for float data outside [0.0, 1.0]; 16-bit and float TIFFs routinely
    hit that. uint8 input is returned unchanged. Anything else is min-max
    stretched over its finite values (NaN -> 0, +/-inf clipped to the ends).
    """
    data = np.asarray(array)
    if data.dtype == np.uint8:
        return data
    data = data.astype(np.float32)
    finite = data[np.isfinite(data)]
    if finite.size == 0:
        return np.zeros(data.shape, dtype=np.uint8)
    low, high = finite.min(), finite.max()
    if high == low:
        return np.zeros(data.shape, dtype=np.uint8)
    scaled = np.clip((data - low) / (high - low), 0.0, 1.0)
    return (np.nan_to_num(scaled, nan=0.0) * 255).astype(np.uint8)


st.title("TIFF Image Upload and Model Prediction")
# Upload image.
# NOTE(review): Streamlit's default upload limit is 200 MB. The "5GB" in the
# label only holds if server.maxUploadSize is raised in the Streamlit config.
uploaded_file = st.file_uploader(
    "Upload a large TIFF image (up to 5GB)", type=["tiff", "tif"]
)
if uploaded_file is not None:
    # Use a unique temp file per upload. A fixed "temp_image.tiff" in the
    # working directory is shared by every concurrent session, so simultaneous
    # uploads would overwrite or delete each other's file.
    with tempfile.NamedTemporaryFile(suffix=".tiff", delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
        temp_path = tmp.name
    try:
        tiff_image = load_tiff_image(temp_path)
        if tiff_image is not None:
            st.write("Original Image")
            st.image(
                _to_display_uint8(tiff_image),
                caption="Uploaded Image",
                use_column_width=True,
            )
            model = load_model()
            image = Image.fromarray(tiff_image)
            image_tensor = preprocess_image(image)
            # The pretrained hub U-Net takes 3-channel input. Repeat a
            # single-band tensor across three channels so the first conv
            # layer accepts it.
            if image_tensor.shape[1] == 1:
                image_tensor = image_tensor.repeat(1, 3, 1, 1)
            with torch.no_grad():
                prediction = model(image_tensor)
            pred_image = postprocess_prediction(prediction)
            st.write("Model Prediction")
            # The mask holds 0/1; scale it to 0/255 so the foreground is visible.
            st.image(
                pred_image * 255,
                caption="Predicted Image",
                use_column_width=True,
            )
    finally:
        # Always remove the temp file, even if loading or inference raised.
        os.remove(temp_path)