|
import streamlit as st |
|
from PIL import Image |
|
import cv2 |
|
import numpy as np |
|
import time |
|
import models |
|
import torch |
|
|
|
from torchvision import transforms |
|
from torchvision import transforms |
|
|
|
def load_model(path, model):
    """Load a saved ``state_dict`` from *path* into *model* and return it.

    The checkpoint's tensors are mapped onto the CPU (``map_location``), so a
    checkpoint saved on a GPU can be loaded on a CPU-only machine. *model* is
    updated in place; the same object is returned for call chaining.

    NOTE(review): ``torch.load`` unpickles the file, so only load checkpoints
    that come from a trusted source.
    """
    cpu = torch.device('cpu')
    state_dict = torch.load(path, map_location=cpu)
    model.load_state_dict(state_dict)
    return model
|
|
|
def _segment(img, model, out_path, threshold=0.5):
    """Run *model* on *img* and return a thresholded uint8 mask (values 0/255).

    Shared pipeline for ``predict`` and ``predicjt``. The steps are:
    resize to 512x512, ``ToTensor``, normalise with the ImageNet mean/std,
    add a batch dimension, forward pass, sigmoid, threshold, scale to 0/255
    and move channels last. The mask is also written to *out_path* with
    ``cv2.imwrite``.

    Args:
        img: image array as returned by ``cv2.imread``.
        model: network mapping a ``(1, 3, 512, 512)`` tensor to logits.
        out_path: file the mask is written to.
        threshold: probability cut-off; pixels with ``sigmoid >= threshold``
            become 255.

    Returns:
        ``uint8`` array of shape ``(H, W, C)``, where C is the model's output
        channel count (presumably 1 for the single-class models used here).

    Raises:
        ValueError: if *img* is ``None`` (e.g. ``cv2.imread`` failed).
    """
    if img is None:
        raise ValueError("img is None; the input image could not be read")

    # Inference mode: if the network has BatchNorm/Dropout layers, they must
    # use running statistics / be disabled rather than behave as in training.
    model.eval()

    # NOTE(review): images from cv2.imread are BGR, while these ImageNet
    # statistics are conventionally given in RGB order. Confirm the models
    # were trained on BGR input.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    resized = cv2.resize(img, (512, 512))
    tensor = normalize(transforms.ToTensor()(resized).float())
    batch = torch.unsqueeze(tensor, dim=0)

    # No autograd graph is needed for a pure forward pass.
    with torch.no_grad():
        output = model(batch)

    probabilities = torch.sigmoid(output)
    binary = (probabilities >= threshold).float()
    mask = binary[0].cpu().numpy()

    # (C, H, W) in {0, 1} -> (H, W, C) uint8 in {0, 255}.
    mask = (mask * 255).astype('uint8').transpose(1, 2, 0)
    cv2.imwrite(out_path, mask)
    return mask


def predict(img):
    """Segment *img* with the plain U-Net whose weights are in ``model.pth``.

    Also writes the mask to ``test.png``. Returns the uint8 0/255 mask
    produced by ``_segment``.
    """
    model = load_model('model.pth', models.unet(3, 1))
    return _segment(img, model, "test.png")


def predicjt(img):
    """Segment *img* with the spatial-attention U-Net (weights ``saunet.pth``).

    Also writes the mask to ``test1.png``. Returns the uint8 0/255 mask
    produced by ``_segment``. (The misspelled name is kept for existing
    callers.)
    """
    model1 = load_model('saunet.pth', models.SAunet(3, 1))
    return _segment(img, model1, "test1.png")
|
def main():
    """Streamlit page: choose a test image, preview it, and segment it on demand.

    Clicking "Segment" runs both ``predict`` (U-Net) and ``predicjt``
    (spatial-attention U-Net) on the selected image and shows both masks.
    """
    st.title("Image Segmentation Demo")

    image_names = ["01_test.tif", "02_test.tif", "03_test.tif"]
    selected_image_name = st.selectbox("Select an Image", image_names)

    # cv2.imread returns None (it does not raise) when the file is missing
    # or cannot be decoded; stop here instead of crashing further down.
    selected_image = cv2.imread(selected_image_name)
    if selected_image is None:
        st.error(f"Could not read image file: {selected_image_name}")
        return

    # OpenCV decodes colour images in BGR channel order; tell Streamlit so
    # the preview is not shown with red and blue swapped.
    st.image(selected_image, channels="BGR")

    if st.button("Segment"):
        segmented_image = predict(selected_image)
        segmented_image1 = predicjt(selected_image)

        st.image(segmented_image, channels="RGB", caption='U-Net segmentation')
        st.image(segmented_image1, channels="RGB",
                 caption='Spatial Attention U-Net segmentation ')
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Build the demo page only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|
|