imageseg / app.py
dfghj1345's picture
Upload 5 files
ebcff8f verified
raw history blame
No virus
3.25 kB
import torch
from torch import nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import segmentation_models_pytorch as smp
import gradio as gr
# Inference configuration: number of segmentation classes, the checkpoint file
# for each architecture, and the example image shown in the UI.
num_classes = 2
model_unet_path = "unet_model.pth"
model_fpn_path = "fpn_model.pth"
model_deeplab_path = "deeplabv3_model.pth"
image_path = "leaf11.jpg"

# Choose the inference device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")
# The three segmentation architectures offered in the demo. They share the same
# configuration: RGB input (3 channels), `num_classes` output channels, and
# encoder_weights=None so no pretrained encoder weights are fetched. The trained
# weights are loaded into these modules later via load_state_dict.
_common_model_kwargs = dict(
    encoder_weights=None,
    in_channels=3,
    classes=num_classes,
)
model_unet = smp.Unet(encoder_name="resnet18", **_common_model_kwargs)
model_fpn = smp.FPN(encoder_name="resnet18", **_common_model_kwargs)
model_deeplab = smp.DeepLabV3(encoder_name="resnet34", **_common_model_kwargs)
def pred_one_image(inp, option):
    """Segment one image with the selected model and return the label mask.

    Args:
        inp: Input image as a PIL image (the Gradio input uses ``type="pil"``).
        option: Model key chosen in the dropdown: ``"unet"``, ``"fpn"`` or
            ``"deeplab"``.

    Returns:
        A single-channel uint8 PIL image at the network's output resolution
        (the input is resized to 256x256). Each pixel is the predicted class
        index scaled by ``/ 2 * 255``.

    Raises:
        ValueError: if no image is given or ``option`` is not a known model key.
    """
    if inp is None:
        raise ValueError("No input image provided.")
    models = {"unet": model_unet, "fpn": model_fpn, "deeplab": model_deeplab}
    try:
        model_load = models[option]
    except KeyError:
        # Previously an unknown/None option left `model_load` unbound.
        raise ValueError(
            f"Unknown model option {option!r}; expected one of {sorted(models)}"
        ) from None

    # Convert to RGB before resizing so palette/greyscale inputs are interpolated
    # in RGB space and the array always has 3 channels.
    one_image = np.array(inp.convert("RGB").resize((256, 256)))
    # HWC -> CHW, then add a batch dimension: (1, 3, 256, 256).
    one_image = np.moveaxis(one_image, -1, 0)
    # NOTE(review): raw 0-255 pixel values are fed without normalization; this
    # must match the preprocessing used in training — confirm.
    one_image = torch.tensor(one_image).float().unsqueeze(0)

    # The model and its input must live on the same device. Moving the model
    # is a no-op once it is already on `device`.
    model_load.to(device)
    one_image = one_image.to(device)
    model_load.eval()
    with torch.no_grad():
        output = model_load(one_image)

    # Per-pixel argmax over the class dimension gives the predicted label map.
    predictions = torch.argmax(output, dim=1)
    # Scale labels into 0-255 for display. With num_classes == 2 this maps
    # class 0 -> 0 and class 1 -> 127 (uint8 truncation of 127.5).
    pred_array = (predictions[0].cpu().numpy() / 2 * 255).astype(np.uint8)
    return Image.fromarray(pred_array)
# Load the trained checkpoints into the architectures built above. Each model is
# then placed on `device` in eval mode, so it matches the device that
# pred_one_image moves its input tensor to. Checkpoints are mapped to the CPU
# first so they load on any machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
for _model, _ckpt_path in (
    (model_unet, model_unet_path),
    (model_fpn, model_fpn_path),
    (model_deeplab, model_deeplab_path),
):
    _model.load_state_dict(torch.load(_ckpt_path, map_location=torch.device("cpu")))
    _model.to(device)
    _model.eval()

# Default to "unet" so that submitting without touching the dropdown does not
# pass None as the model option.
dropdown = gr.Dropdown(["unet", "fpn", "deeplab"], value="unet", label="Model")
interface = gr.Interface(
    fn=pred_one_image,
    inputs=[gr.Image(type="pil"), dropdown],
    outputs=gr.Image(type="pil"),
    examples=[[image_path, "unet"]],
)
interface.launch(debug=False)