"""Gradio demo that segments an image with a U-Net, FPN or DeepLabV3 model.

Three ``segmentation_models_pytorch`` networks are built, their trained
weights are loaded from local ``.pth`` checkpoints, and a small Gradio
interface lets the user pick a model and view the predicted mask.
"""
import torch
from torch import nn
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pandas as pd
import segmentation_models_pytorch as smp
import gradio as gr

num_classes = 2
model_unet_path = "unet_model.pth"
model_fpn_path = "fpn_model.pth"
model_deeplab_path = "deeplabv3_model.pth"
image_path = "leaf11.jpg"

# Pick the device used for inference: CUDA first, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# encoder_weights=None: no pre-trained encoder weights are requested here;
# every parameter is overwritten by the checkpoints loaded further down.
model_unet = smp.Unet(
    encoder_name="resnet18",  # encoder backbone
    encoder_weights=None,
    in_channels=3,  # RGB input
    classes=num_classes,  # one output channel per class
)

model_fpn = smp.FPN(
    encoder_name="resnet18",  # encoder backbone
    encoder_weights=None,
    in_channels=3,  # RGB input
    classes=num_classes,  # one output channel per class
)

model_deeplab = smp.DeepLabV3(
    encoder_name="resnet34",  # encoder backbone
    encoder_weights=None,
    in_channels=3,  # RGB input
    classes=num_classes,  # one output channel per class
)

# Maps the dropdown choice to (model, checkpoint path). Insertion order is the
# order in which the choices are shown in the UI.
_MODEL_REGISTRY = {
    "unet": (model_unet, model_unet_path),
    "fpn": (model_fpn, model_fpn_path),
    "deeplab": (model_deeplab, model_deeplab_path),
}

for _model, _checkpoint in _MODEL_REGISTRY.values():
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    # map_location='cpu' lets a checkpoint saved on a GPU machine be read
    # anywhere; the model is moved to the inference device right afterwards.
    _model.load_state_dict(torch.load(_checkpoint, map_location=torch.device('cpu')))
    # The input tensor is sent to `device` in pred_one_image, so the weights
    # must live on the same device or the forward pass fails.
    _model.to(device)
    _model.eval()


def pred_one_image(inp, option):
    """Run one segmentation model on an image and return the predicted mask.

    Args:
        inp: PIL image supplied by the Gradio image component.
        option: Name of the model to use: "unet", "fpn" or "deeplab".

    Returns:
        A 256x256 single-channel PIL image holding the per-pixel class index
        scaled as ``index / 2 * 255``.

    Raises:
        ValueError: If no image was provided or ``option`` is not a known
            model name.
    """
    if inp is None:
        raise ValueError("No input image was provided.")
    if option not in _MODEL_REGISTRY:
        raise ValueError(
            f"Unknown model option {option!r}; expected one of {list(_MODEL_REGISTRY)}."
        )
    model_load = _MODEL_REGISTRY[option][0]

    # Resize to the fixed 256x256 network input and force three channels.
    pixels = np.array(inp.resize((256, 256)).convert("RGB"))
    # HWC -> CHW, then add the batch dimension: (1, 3, 256, 256).
    pixels = np.moveaxis(pixels, -1, 0)
    batch = torch.tensor(pixels).float().unsqueeze(0).to(device)

    model_load.eval()
    with torch.no_grad():
        output = model_load(batch)

    # Per-pixel predicted class label (argmax over the class channel).
    predictions = torch.argmax(output, dim=1)
    # NOTE(review): the divisor is a fixed 2, so with num_classes = 2 the
    # class-1 pixels map to 127, not 255 - confirm this is intended.
    pred_array = (predictions[0].cpu().numpy() / 2 * 255).astype(np.uint8)
    return Image.fromarray(pred_array)


# A default value keeps the dropdown from submitting None when untouched.
dropdown = gr.Dropdown(list(_MODEL_REGISTRY), value="unet")

interface = gr.Interface(
    fn=pred_one_image,
    inputs=[gr.Image(type="pil"), dropdown],
    outputs=gr.Image(type="pil"),
    examples=[[image_path, 'unet']],
)
interface.launch(debug=False)