# Hugging Face Spaces page residue: "Spaces: Sleeping" (not part of the program).
import os
import tempfile

import numpy as np
import streamlit as st
import torch
from osgeo import gdal
from PIL import Image
from torchvision import transforms
@st.cache_resource
def load_model():
    """Build the pretrained brain-segmentation U-Net and return it in eval mode.

    Uses the ``mateuszbuda/brain-segmentation-pytorch`` torch.hub entry point
    with the arguments from its published usage example: three input
    channels, one output channel and 32 initial features.

    The result is cached with ``st.cache_resource``. The weights are therefore
    downloaded and the network built once per server process, not on every
    Streamlit rerun. Every rerun re-executes the script, which happens on
    each widget interaction.

    Returns:
        The model with ``.eval()`` already applied.
    """
    # NOTE(review): the previous `progress=True` was dropped. torch.hub.load
    # has no such parameter and forwards unknown keywords to the repo's entry
    # point; the repo's own usage example does not pass it.
    model = torch.hub.load(
        "mateuszbuda/brain-segmentation-pytorch",
        "unet",
        in_channels=3,
        out_channels=1,
        init_features=32,
        pretrained=True,
    )
    model.eval()
    return model
def load_tiff_image(tiff_path):
    """Read the first raster band of a GDAL-readable file into an array.

    Only band 1 is read; any additional bands are ignored.

    Args:
        tiff_path: Filesystem path handed straight to ``gdal.Open``.

    Returns:
        The array produced by ``ReadAsArray()`` for band 1. Returns ``None``
        if GDAL cannot open the file or any exception is raised. In both
        failure cases an error message is shown via ``st.error``.
    """
    try:
        raster = gdal.Open(tiff_path)
        if raster is not None:
            # `raster` stays bound for the whole call, so the band read below
            # never outlives its parent dataset.
            return raster.GetRasterBand(1).ReadAsArray()
        st.error("Failed to load the TIFF image. Please check the file format.")
    except Exception as exc:
        st.error(f"Error loading image: {exc}")
    return None
def preprocess_image(image, *, size=(256, 256)):
    """Turn an image into a normalized 3-channel batch tensor for the U-Net.

    The pretrained ``mateuszbuda/brain-segmentation-pytorch`` U-Net is built
    with three input channels. Single-band (grayscale) rasters are therefore
    replicated across three channels. For inputs with more than three
    channels (e.g. RGBA), only the first three are kept.

    Normalization is per-image and per-channel standardization (zero mean,
    unit population std), as in the model's torch.hub usage example. The
    previous ImageNet constants assumed 8-bit data scaled to [0, 1], which
    does not hold for 16-bit or floating-point TIFF rasters.

    Args:
        image: PIL image or array-like of shape (H, W) or (H, W, C).
        size: Target (height, width) of the model input.
            NOTE(review): the U-Net downsamples four times, so presumably
            both sides should be multiples of 16. Confirm before changing.

    Returns:
        float32 tensor of shape (1, 3, size[0], size[1]).

    Raises:
        ValueError: If the image is empty, not 2-D/3-D, or has 2 channels.
    """
    array = np.asarray(image, dtype=np.float32)
    if array.ndim == 2:
        array = array[..., np.newaxis]
    if array.ndim != 3 or array.size == 0:
        raise ValueError(
            f"Expected a non-empty (H, W) or (H, W, C) image, got shape {array.shape}"
        )

    channels = array.shape[-1]
    if channels == 1:
        array = np.repeat(array, 3, axis=-1)
    elif channels >= 3:
        array = array[..., :3]
    else:
        raise ValueError(f"Cannot map a {channels}-channel image to 3 channels")

    # (H, W, C) -> (1, C, H, W), the layout torch conv/interpolate expect.
    tensor = torch.from_numpy(np.ascontiguousarray(array)).permute(2, 0, 1).unsqueeze(0)
    tensor = torch.nn.functional.interpolate(
        tensor, size=tuple(size), mode="bilinear", align_corners=False, antialias=True
    )

    mean = tensor.mean(dim=(2, 3), keepdim=True)
    std = ((tensor - mean) ** 2).mean(dim=(2, 3), keepdim=True).sqrt()
    # Guard constant images (std == 0) against division by zero.
    return (tensor - mean) / std.clamp_min(1e-6)
def postprocess_prediction(pred, *, threshold=0.5, apply_sigmoid=False):
    """Threshold a model output into a binary uint8 mask.

    The pretrained ``mateuszbuda/brain-segmentation-pytorch`` U-Net is
    documented on its torch.hub model card as producing a probability map.
    Its output is therefore thresholded directly. The previous unconditional
    ``torch.sigmoid`` mapped every probability p in [0, 1] into
    [0.5, 0.73], so ``> 0.5`` marked every pixel with p > 0 as foreground.

    Args:
        pred: Output tensor, e.g. of shape (1, 1, H, W).
        threshold: Pixels strictly greater than this become 1.
        apply_sigmoid: Set to True when ``pred`` holds raw logits rather than
            probabilities.

    Returns:
        numpy uint8 array of 0/1 values with all singleton dimensions removed.
    """
    if apply_sigmoid:
        pred = torch.sigmoid(pred)
    # .cpu() so this also works when the model ran on a GPU.
    probabilities = pred.detach().cpu().squeeze().numpy()
    return (probabilities > threshold).astype(np.uint8)
# Streamlit app
# Longest side, in pixels, of the on-screen preview of the uploaded raster.
DISPLAY_MAX_SIDE = 2048


def _to_display_uint8(array, max_side=DISPLAY_MAX_SIDE):
    """Return a uint8 preview of ``array`` that ``st.image`` can render.

    ``st.image`` raises for out-of-range pixel values unless ``clamp`` is set.
    Raw uint16 or float rasters would therefore fail to display. Non-uint8
    data is min-max stretched to 0..255, and non-finite pixels map to 0.
    The preview is decimated so its longest side is at most ``max_side``
    pixels, which bounds memory use and the payload sent to the browser.
    """
    step = max(1, (max(array.shape[:2]) + max_side - 1) // max_side)
    preview = array[::step, ::step]
    if preview.dtype == np.uint8:
        return preview
    data = preview.astype(np.float32)
    finite = np.isfinite(data)
    if not finite.any():
        return np.zeros(data.shape, dtype=np.uint8)
    lo = data[finite].min()
    hi = data[finite].max()
    if hi <= lo:
        return np.zeros(data.shape, dtype=np.uint8)
    scaled = (data - lo) / (hi - lo) * 255.0
    scaled[~finite] = 0.0
    return np.clip(scaled, 0, 255).astype(np.uint8)


st.title("TIFF Image Upload and Model Prediction")
# Upload image
# NOTE(review): Streamlit caps uploads at `server.maxUploadSize` (200 MB by
# default). The "up to 5GB" label only holds if that option is raised.
uploaded_file = st.file_uploader(
    "Upload a large TIFF image (up to 5GB)", type=["tiff", "tif"]
)
if uploaded_file is not None:
    # GDAL opens rasters by path, so the upload is written to a unique temp
    # file. A fixed "temp_image.tiff" in the working directory would be shared
    # by concurrent sessions and left behind whenever a step raised.
    fd, temp_path = tempfile.mkstemp(suffix=".tiff")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(uploaded_file.getbuffer())
        tiff_image = load_tiff_image(temp_path)
    finally:
        # load_tiff_image reads the band fully into memory via ReadAsArray,
        # so the file is not needed afterwards.
        os.remove(temp_path)

    if tiff_image is not None:
        st.write("Original Image")
        st.image(
            _to_display_uint8(tiff_image),
            caption="Uploaded Image",
            use_column_width=True,
        )
        model = load_model()
        image_tensor = preprocess_image(Image.fromarray(tiff_image))
        with torch.no_grad():
            prediction = model(image_tensor)
        pred_image = postprocess_prediction(prediction)
        st.write("Model Prediction")
        # The mask holds 0/1; scale to 0/255 so the foreground is visible.
        st.image(pred_image * 255, caption="Predicted Image", use_column_width=True)