MutilModelSeg / app.py
azhongai666666's picture
Create app.py
bbf6b72 verified
raw
history blame contribute delete
No virus
3.25 kB
import torch
from torch import nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import segmentation_models_pytorch as smp
import gradio as gr
# Number of segmentation classes each model predicts (= model output channels).
num_classes = 2
# Checkpoint files holding the trained state_dict for each architecture.
model_unet_path = "unet_model.pth"
model_fpn_path = "fpn_model.pth"
model_deeplab_path = "deeplabv3_model.pth"
# Sample image shipped with the app (same file as the Gradio example).
image_path = "leaf11.jpg"

# Pick the best available accelerator: CUDA first, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# Constructor arguments shared by all three architectures.
# encoder_weights=None means the encoders are randomly initialised (no
# pretrained download); the trained weights come from the checkpoint files,
# which are loaded into these models later in this file.
_common_model_kwargs = dict(
    encoder_weights=None,
    in_channels=3,  # RGB input
    classes=num_classes,  # one output channel per class
)
model_unet = smp.Unet(encoder_name="resnet18", **_common_model_kwargs)
model_fpn = smp.FPN(encoder_name="resnet18", **_common_model_kwargs)
# DeepLabV3 uses a deeper encoder; the encoder choice must match the one the
# checkpoint was trained with, otherwise load_state_dict reports mismatched keys.
model_deeplab = smp.DeepLabV3(encoder_name="resnet34", **_common_model_kwargs)
def pred_one_image(inp, option):
    """Segment one image with the model selected in the dropdown.

    Args:
        inp: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submitted without uploading an image.
        option: Model name: ``"unet"``, ``"fpn"`` or ``"deeplab"``.

    Returns:
        A grayscale (mode ``"L"``) PIL image of the predicted class mask,
        with class indices spread over 0..255 (class 0 -> 0, the highest
        class -> 255). Its size is the model's output size, expected to match
        the 256x256 resized input.

    Raises:
        ValueError: if ``inp`` is ``None`` or ``option`` is not a known model.
    """
    if inp is None:
        raise ValueError("Please provide an image to segment.")
    models = {"unet": model_unet, "fpn": model_fpn, "deeplab": model_deeplab}
    model_load = models.get(option)
    if model_load is None:
        raise ValueError(f"Unknown model {option!r}; expected one of {list(models)}.")

    # Resize to the fixed 256x256 input size and drop any alpha/palette mode.
    one_image = np.array(inp.resize((256, 256)).convert("RGB"))
    # HWC -> CHW, then add a batch dimension -> (1, 3, 256, 256).
    # Raw 0..255 pixel values are fed in without normalisation
    # (presumably matching how the models were trained; TODO confirm).
    one_image = torch.tensor(np.moveaxis(one_image, -1, 0)).float().unsqueeze(0)
    # Send the input to wherever the chosen model's weights live, so the
    # forward pass never mixes devices.
    one_image = one_image.to(next(model_load.parameters()).device)

    model_load.eval()
    with torch.no_grad():
        output = model_load(one_image)
    # Per-pixel predicted class label = argmax over the class (channel) axis.
    predictions = torch.argmax(output, dim=1)
    # Map class indices onto the full 0..255 range so the highest class
    # renders white (for a binary mask: background 0, foreground 255).
    scale = 255 / max(num_classes - 1, 1)
    pred_array = (predictions[0].cpu().numpy() * scale).astype(np.uint8)
    return Image.fromarray(pred_array)
# Load the trained weights into each architecture. Checkpoints are mapped to
# the CPU first so they load on any machine, then each model is moved to the
# selected `device` -- the same device the inference input is sent to -- and
# switched to inference mode.
# NOTE(review): torch.load unpickles the file; these are local, trusted
# checkpoints. On torch>=1.13, weights_only=True would restrict it to tensors.
for _model, _checkpoint_path in (
    (model_unet, model_unet_path),
    (model_fpn, model_fpn_path),
    (model_deeplab, model_deeplab_path),
):
    _model.load_state_dict(torch.load(_checkpoint_path, map_location=torch.device('cpu')))
    _model.to(device)
    _model.eval()
del _model, _checkpoint_path

# Default to "unet" so submitting without touching the dropdown does not send
# None as the model name.
dropdown = gr.Dropdown(["unet", "fpn", "deeplab"], value="unet")
interface = gr.Interface(
    fn=pred_one_image,
    inputs=[gr.Image(type="pil"), dropdown],
    outputs=gr.Image(type="pil"),
    examples=[[image_path, "unet"]],
)
interface.launch(debug=False)