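# Hugging Face Space file-viewer header captured with this source; kept as comments so the module parses.
# Author: 3koozy · last commit f1427b4 ("Update app.py") · raw file · 1.52 kB · malware scan: "No virus"
# Viewer actions: history, blame, contribute, delete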
import gradio as gr
import numpy as np
import segmentation_models_pytorch as smp
import torch
import torch.nn.functional as F
from PIL import Image
# Build the U-Net and restore the fine-tuned FloodNet weights onto the CPU.
# These architecture arguments must match the ones the checkpoint was saved
# with; load_state_dict() is strict by default and raises on any mismatch.
_UNET_KWARGS = dict(
    encoder_name="resnet34",  # ResNet-34 encoder backbone
    encoder_weights=None,  # no pretrained download: all weights come from the checkpoint below
    in_channels=3,  # RGB input
    classes=10,  # one output channel (logit map) per segmentation class
)
model = smp.Unet(**_UNET_KWARGS)
model.load_state_dict(
    torch.load('Floodnet_model_e5.pt', map_location=torch.device('cpu'))
)
model.eval()  # switch layers such as BatchNorm/Dropout to inference behaviour
def _to_rgb_array(image):
    """Return *image* as an ``(H, W, 3)`` array.

    Accepts a PIL image (anything exposing ``convert``) or an array-like of
    shape ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)``. Grayscale is replicated
    to three channels and an alpha channel is dropped, because the network is
    built with ``in_channels=3``.

    Raises:
        ValueError: if the input cannot be interpreted as an RGB picture
            (e.g. ``None`` when nothing was uploaded).
    """
    if hasattr(image, "convert"):  # PIL image; duck-typed so no PIL import is needed here
        image = image.convert("RGB")
    pixels = np.asarray(image)
    if pixels.ndim == 2:  # grayscale -> replicate into three channels
        pixels = np.repeat(pixels[..., np.newaxis], 3, axis=-1)
    elif pixels.ndim == 3 and pixels.shape[-1] == 4:  # RGBA -> drop alpha
        pixels = pixels[..., :3]
    if pixels.ndim != 3 or pixels.shape[-1] != 3:
        raise ValueError(
            f"expected an RGB image, got an array of shape {pixels.shape}"
        )
    return pixels


def predict_segmentation(image, net=None, size=(256, 256)) -> "np.ndarray":
    """Segment *image* with the U-Net and return the class mask as an image.

    Args:
        image: The uploaded picture. It may be a PIL image or an array of
            shape ``(H, W)``, ``(H, W, 3)`` or ``(H, W, 4)`` (Gradio's
            default ``type="numpy"`` delivers an array).
        net: Callable mapping a ``(1, 3, h, w)`` float tensor to
            ``(1, num_classes, h, w)`` logits. It defaults to the module-level
            ``model``; the parameter mainly exists so the function can be
            exercised with a stand-in network.
        size: ``(height, width)`` that the image is resized to before
            inference.

    Returns:
        A ``(height, width)`` uint8 array. Class ``k`` is drawn with gray level
        ``k * (255 // (num_classes - 1))`` so classes are visually distinct;
        raw ids 0-9 would render as an almost black image.

    Raises:
        ValueError: if *image* cannot be interpreted as an RGB picture.
    """
    seg_model = model if net is None else net
    pixels = _to_rgb_array(image)
    # (H, W, 3) -> (1, 3, H, W) float32; the U-Net expects channels-first input.
    # Values stay in 0-255: the earlier inference snippet only called .float()
    # on the sample, so no normalization is applied - TODO confirm against the
    # training pipeline.
    batch = torch.from_numpy(np.array(pixels, dtype=np.float32))
    batch = batch.permute(2, 0, 1).unsqueeze(0)
    # Resize on the tensor so numpy and PIL inputs share a single code path.
    batch = F.interpolate(batch, size=tuple(size), mode="bilinear", align_corners=False)
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = seg_model(batch)
    class_ids = logits.argmax(dim=1).squeeze(0)  # (h, w) per-pixel class index
    # Spread the class ids over 0-255 so they are distinguishable on screen.
    step = 255 // max(logits.shape[1] - 1, 1)
    return (class_ids * step).to(torch.uint8).cpu().numpy()
# Gradio UI: one uploaded image in, the predicted class mask out.
# The original `shape=` / `source=` keyword arguments exist only in Gradio 3.x
# (removed in 4.x), and these components were never passed to the Interface.
# predict_segmentation does its own resizing, so no input shape is forced here.
image_input = gr.components.Image(type="pil")
image_output = gr.components.Image(type="pil")
iface = gr.Interface(predict_segmentation, image_input, image_output)

# Launch only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    iface.launch()