yasinelh's picture
Upload 2 files
8a33342
raw
history blame
2.97 kB
import streamlit as st
from PIL import Image
import cv2
import numpy as np
import time
import models
import torch
from torchvision import transforms
from torchvision import transforms
def load_model(path, model):
    """Load saved weights from *path* into *model* and return the model.

    The checkpoint is deserialized with all tensors mapped onto the CPU,
    then applied to *model* in place via ``load_state_dict``.

    Args:
        path: Location of the saved state dict, passed to ``torch.load``.
        model: Module whose parameters are overwritten with the loaded ones.

    Returns:
        The same *model* object that was passed in.
    """
    cpu_device = torch.device('cpu')
    state_dict = torch.load(path, map_location=cpu_device)
    model.load_state_dict(state_dict)
    return model
def _segment(img, model, out_path, threshold=0.5):
    """Run *model* on *img* and return the binarized mask as a uint8 array.

    Shared inference pipeline for ``predict`` and ``predicjt``: resize to
    512x512, convert to a normalized float tensor, run the model, apply a
    sigmoid, threshold, and scale the result to {0, 255}.

    Args:
        img: Image array accepted by ``cv2.resize`` / ``transforms.ToTensor``.
            NOTE(review): ``main`` passes a ``cv2.imread`` result, which is in
            BGR channel order, while the normalization constants below are
            listed in the order normally used for RGB input -- confirm which
            order the checkpoints were trained with.
        model: Module with weights already loaded.
        out_path: File the mask is additionally written to with ``cv2.imwrite``.
        threshold: Sigmoid probability at or above which a pixel is foreground.

    Returns:
        ``uint8`` array holding only 0 and 255, with the model output's first
        batch element transposed from (C, H, W) to (H, W, C).
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )
    resized = cv2.resize(img, (512, 512))
    tensor = normalize(transforms.ToTensor()(resized).float())
    batch = torch.unsqueeze(tensor, dim=0)

    # Inference must run in eval mode so dropout / batch-norm layers behave
    # deterministically; no_grad avoids building an unused autograd graph.
    model.eval()
    with torch.no_grad():
        output = model(batch)

    mask = (torch.sigmoid(output) >= threshold).float()
    # First (only) batch element, moved to CPU in case the model ran elsewhere.
    mask_array = mask[0].cpu().numpy()
    # Rescale {0, 1} to {0, 255} and switch to channel-last layout.
    mask_array = (mask_array * 255).astype('uint8').transpose(1, 2, 0)
    cv2.imwrite(out_path, mask_array)
    return mask_array


def predict(img):
    """Segment *img* with the ``models.unet`` network loaded from 'model.pth'.

    The mask is also saved to 'test.png'.

    Args:
        img: Input image array (see ``_segment``).

    Returns:
        Binary mask as a ``uint8`` array of 0/255 values.
    """
    model = load_model('model.pth', models.unet(3, 1))
    return _segment(img, model, "test.png")


def predicjt(img):
    """Segment *img* with the ``models.SAunet`` network loaded from 'saunet.pth'.

    The mask is also saved to 'test1.png'.

    Args:
        img: Input image array (see ``_segment``).

    Returns:
        Binary mask as a ``uint8`` array of 0/255 values.
    """
    model = load_model('saunet.pth', models.SAunet(3, 1))
    return _segment(img, model, "test1.png")
def main():
    """Render the Streamlit demo page.

    Lets the user pick one of the predefined test images, shows it, and on
    clicking "Segment" displays the masks produced by ``predict`` and
    ``predicjt``.
    """
    st.title("Image Segmentation Demo")

    # Predefined list of image names offered in the selection box.
    image_names = ["01_test.tif", "02_test.tif", "03_test.tif"]
    selected_image_name = st.selectbox("Select an Image", image_names)

    # cv2.imread returns None (rather than raising) when the file is missing
    # or cannot be decoded, so report that instead of crashing further down.
    selected_image = cv2.imread(selected_image_name)
    if selected_image is None:
        st.error(f"Could not read image: {selected_image_name}")
        return

    # cv2.imread yields BGR data; tell Streamlit so colours are not swapped.
    st.image(selected_image, channels="BGR")

    if st.button("Segment"):
        # Run both networks on the selected image.
        segmented_image = predict(selected_image)
        segmented_image1 = predicjt(selected_image)
        # Display the two segmentation masks.
        st.image(segmented_image, channels="RGB", caption='U-Net segmentation')
        st.image(segmented_image1, channels="RGB", caption='Spatial Attention U-Net segmentation ')
# Launch the demo only when this file is executed as the entry script.
if __name__ == "__main__":
    main()