|
import torch |
|
from torch import nn |
|
import torchvision |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from PIL import Image |
|
import pandas as pd |
|
import segmentation_models_pytorch as smp |
|
import gradio as gr |
|
|
|
# Number of segmentation classes each model predicts (its output channels).
num_classes = 2

# Checkpoint files holding the trained weights for each architecture.
model_unet_path = "unet_model.pth"
model_fpn_path = "fpn_model.pth"
model_deeplab_path = "deeplabv3_model.pth"

# Path to a sample input image.
image_path = "leaf11.jpg"

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

print(f"Using {device} device")
|
|
|
# Every architecture shares the same I/O configuration. The input is 3-channel
# and the output has num_classes channels. The encoders start without
# pretrained weights because trained checkpoints are loaded separately via
# load_state_dict.
_shared_model_kwargs = dict(
    encoder_weights=None,
    in_channels=3,
    classes=num_classes,
)

model_unet = smp.Unet(encoder_name="resnet18", **_shared_model_kwargs)
model_fpn = smp.FPN(encoder_name="resnet18", **_shared_model_kwargs)
model_deeplab = smp.DeepLabV3(encoder_name="resnet34", **_shared_model_kwargs)

# Only needed while building the models; keep the module namespace unchanged.
del _shared_model_kwargs
|
|
|
def pred_one_image(inp, option):
    """Segment one image with the selected model and return the mask image.

    Parameters
    ----------
    inp : PIL.Image.Image
        Input image in any PIL mode; it is converted to RGB and resized to
        256x256 before inference.
    option : str
        Which model to use: ``"unet"``, ``"fpn"`` or ``"deeplab"``.

    Returns
    -------
    PIL.Image.Image
        Grayscale (mode ``"L"``) mask built from the per-pixel argmax class.
        Class ``k`` is drawn with intensity ``k * 255 / (num_classes - 1)``.
        The lowest class is black and the highest is white. The spatial size
        is the model's output size; presumably 256x256, matching the resized
        input (TODO confirm for all three models).

    Raises
    ------
    ValueError
        If no image is given or ``option`` is not a known model key.
    """
    if inp is None:
        raise ValueError("No input image was provided.")

    # With an if/elif chain, an unexpected option (including None when nothing
    # is selected) left the model variable unbound. A lookup table lets us
    # fail with a clear message instead.
    models = {"unet": model_unet, "fpn": model_fpn, "deeplab": model_deeplab}
    model_load = models.get(option)
    if model_load is None:
        raise ValueError(
            f"Unknown model option {option!r}; expected one of {list(models)}."
        )

    # Convert before resizing. PIL always uses nearest-neighbour resampling for
    # "P" (palette) and "1" images, so resizing first would degrade them.
    one_image = np.array(inp.convert("RGB").resize((256, 256)))  # (H, W, 3) uint8

    # HWC -> CHW, then add a batch dimension: (1, 3, 256, 256).
    # Pixel values stay in [0, 255]; no normalization is applied (presumably
    # this matches how the models were trained — TODO confirm).
    one_image = torch.tensor(np.moveaxis(one_image, -1, 0)).float()
    one_image = one_image.unsqueeze(0).to(device)

    # The models are constructed on the CPU. Put the selected one on the same
    # device as the input, otherwise the forward pass fails on CUDA/MPS.
    # Module.to is in place and is a no-op when the model is already there.
    model_load.to(device)
    model_load.eval()
    with torch.no_grad():
        output = model_load(one_image)

    # Per-pixel class index, taken as the argmax over the channel dimension.
    predictions = torch.argmax(output, dim=1)

    # Spread class indices over the full 0-255 range, so that with two classes
    # the foreground is white (255) rather than mid-grey (127).
    scale = 255 / max(num_classes - 1, 1)
    pred_array = (predictions[0].cpu().numpy() * scale).astype(np.uint8)

    return Image.fromarray(pred_array)
|
|
|
|
|
|
|
# Load each model's trained weights (mapped onto the CPU). Then move the model
# to the inference device once and switch it to eval mode. Without the move,
# the CPU-resident models would receive inputs on CUDA/MPS and the forward pass
# would fail.
for _model, _weights_path in (
    (model_unet, model_unet_path),
    (model_fpn, model_fpn_path),
    (model_deeplab, model_deeplab_path),
):
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    _model.load_state_dict(torch.load(_weights_path, map_location=torch.device("cpu")))
    _model.to(device)
    _model.eval()

# Default to "unet" so the model argument is never None when the user submits
# without touching the dropdown.
dropdown = gr.Dropdown(["unet", "fpn", "deeplab"], value="unet")

interface = gr.Interface(
    fn=pred_one_image,
    inputs=[gr.Image(type="pil"), dropdown],
    outputs=gr.Image(type="pil"),
    # Reuse the configured sample image path rather than repeating the literal.
    examples=[[image_path, "unet"]],
)

interface.launch(debug=False)
|
|