dfghj1345 commited on
Commit
ebcff8f
1 Parent(s): 7bdedf2

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +79 -0
  2. fpn_model.pth +3 -0
  3. leaf11.jpg +0 -0
  4. requirements.txt +3 -0
  5. unet_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torchvision
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+ import pandas as pd
8
+ import segmentation_models_pytorch as smp
9
+ import gradio as gr
10
+
11
+ num_classes = 2
12
+ model_unet_path = "unet_model.pth"
13
+ model_fpn_path = "fpn_model.pth"
14
+ model_deeplab_path = "deeplabv3_model.pth"
15
+ image_path = "leaf11.jpg"
16
+
17
+ # Get cpu or gpu device for training.
18
+ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
19
+ print(f"Using {device} device")
20
+
21
+ model_unet = smp.Unet(
22
+ encoder_name="resnet18", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
23
+ encoder_weights=None, # use `imagenet` pre-trained weights for encoder initialization
24
+ in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
25
+ classes=num_classes, # model output channels (number of classes in your dataset)
26
+ )
27
+
28
+ model_fpn = smp.FPN(
29
+ encoder_name="resnet18", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
30
+ encoder_weights=None, # use `imagenet` pre-trained weights for encoder initialization
31
+ in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
32
+ classes=num_classes, # model output channels (number of classes in your dataset)
33
+ )
34
+
35
+ model_deeplab = smp.DeepLabV3(
36
+ encoder_name="resnet34", # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
37
+ encoder_weights=None, # use `imagenet` pre-trained weights for encoder initialization
38
+ in_channels=3, # model input channels (1 for gray-scale images, 3 for RGB, etc.)
39
+ classes=num_classes, # model output channels (number of classes in your dataset)
40
+ )
41
+
42
+ def pred_one_image(inp,option):
43
+ one_image = np.array(inp.resize((256, 256)).convert("RGB"))
44
+ # convert to other format HWC -> CHW
45
+ one_image = np.moveaxis(one_image, -1, 0)
46
+ # mask = np.expand_dims(mask, 0)
47
+ one_image = torch.tensor(one_image).float()
48
+ one_image = one_image.unsqueeze(0)
49
+ one_image = one_image.to(device)
50
+ if option == "unet":
51
+ model_load = model_unet
52
+ elif option == "fpn":
53
+ model_load = model_fpn
54
+ elif option == "deeplab":
55
+ model_load = model_deeplab
56
+ model_load.eval()
57
+ with torch.no_grad():
58
+ output = model_load(one_image)
59
+ # print(output.shape)
60
+ predictions = torch.argmax(output, dim=1) # 获取预测的类别标签图像
61
+ pred_array = (predictions[0].cpu().numpy()/2*255).astype(np.uint8)
62
+ # print(pred_array.shape)
63
+ pred_img = Image.fromarray(pred_array)
64
+ # pred_img.save("pred.png")
65
+ # print(predictions.shape)
66
+ return pred_img
67
+
68
+
69
+
70
+ model_unet.load_state_dict(torch.load(model_unet_path,map_location=torch.device('cpu')))
71
+ model_fpn.load_state_dict(torch.load(model_fpn_path,map_location=torch.device('cpu')))
72
+ model_deeplab.load_state_dict(torch.load(model_deeplab_path,map_location=torch.device('cpu')))
73
+
74
+ dropdown = gr.Dropdown(["unet", "fpn","deeplab"])
75
+ interface = gr.Interface(fn=pred_one_image,
76
+ inputs=[gr.Image(type="pil"),dropdown],
77
+ outputs=gr.Image(type="pil"),
78
+ examples=[["leaf11.jpg",'unet']],)
79
+ interface.launch(debug=False)
fpn_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b5bba08e48b133fecd0aa289955977555f279629860d98fa02811265eeb893c
3
+ size 52282682
leaf11.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch
2
+ torchvision
3
+ segmentation-models-pytorch
unet_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a66d41ff565cab4119957301a8ce147339725b05937fe3a9c68a81374f43c6f
3
+ size 57425076