File size: 3,454 Bytes
a50af6a
 
 
 
 
 
 
 
026895a
3d1c14a
36743cd
3d1c14a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03abd43
 
 
 
 
c43bd81
03abd43
f0ac27b
 
b34644c
36743cd
f0ac27b
 
 
36743cd
f0ac27b
 
36743cd
f0ac27b
 
 
 
a50af6a
 
 
f0ac27b
a50af6a
 
 
 
 
 
 
f0ac27b
 
a50af6a
 
 
36743cd
a50af6a
36743cd
f0ac27b
36743cd
99312fa
a50af6a
 
f0ac27b
3d1c14a
b34644c
f0ac27b
 
a50af6a
 
96563da
06de145
2f9e95a
f9f2226
b34644c
ea493f4
a50af6a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import requests
import torch
import torch.nn as nn
from PIL import Image
from torchvision.models import resnet18
from torchvision.transforms import functional as F

def main():
    # ใƒขใƒ‡ใƒซใฎ้ธๆŠž
    plant_models = {
        'cucumber':{
            'model_path':'cucumber_resnet18_last_model.pth',
            'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็ฐ่‰ฒใ‹ใณ็—…","็‚ญ็–ฝ็—…","ในใจ็—…","่คๆ–‘็—…","ใคใ‚‹ๆžฏ็—…","ๆ–‘็‚น็ดฐ่Œ็—…","CCYV","ใƒขใ‚ถใ‚คใ‚ฏ็—…","MYSV"]
        },
        'eggplant':{
            'model_path':'eggplant_resnet18_last_model.pth',
            'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็ฐ่‰ฒใ‹ใณ็—…","่ค่‰ฒๅ††ๆ˜Ÿ็—…","ใ™ใ™ใ‹ใณ็—…","ๅŠ่บซ่Žๅ‡‹็—…","้’ๆžฏ็—…"]
        },
        'strawberry':{
            'model_path':'strawberry_resnet18_last_model.pth',
            'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็‚ญ็–ฝ็—…","่Ž้ป„็—…"]
        },
        'tomato':{
            'model_path':'tomato_resnet18_last_model.pth',
            'labels_list' : ["ๅฅๅ…จ","ใ†ใฉใ‚“ใ“็—…","็ฐ่‰ฒใ‹ใณ็—…","ใ™ใ™ใ‹ใณ็—…","่‘‰ใ‹ใณ็—…","็–ซ็—…","่ค่‰ฒ่ผช็ด‹็—…","้’ๆžฏ็—…","ใ‹ใ„ใ‚ˆใ†็—…","้ป„ๅŒ–่‘‰ๅทป็—…"]
        },

    }
    # examples_images=[
    #         ['image/231305_20200302150233_01.JPG'],
    #         ['image/0004_20181120084837_01.jpg'],
    #         ['image/160001_20170830173740_01.JPG'],
    #         ['image/152300_20190119175054_01.JPG'],
            
    #     ]
    # ใƒขใƒ‡ใƒซใฎๆบ–ๅ‚™ใ™ใ‚‹้–ขๆ•ฐใ‚’ๅฎš็พฉ
    def select_model(plant_name):
 
        model_ft = resnet18(num_classes = len(plant_models[plant_name]['labels_list']),pretrained=False)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model_ft = model_ft.to(device)
        if torch.cuda.is_available():
            model_ft.load_state_dict(torch.load(plant_models[plant_name]['model_path']))
        else:
            model_ft.load_state_dict(
                torch.load(plant_models[plant_name]['model_path'], map_location=torch.device("cpu"))
            )
        model_ft.eval()

        return model_ft

    # ็”ปๅƒๅˆ†้กžใ‚’่กŒใ†้–ขๆ•ฐใ‚’ๅฎš็พฉ
    @torch.no_grad()
    def inference(gr_input,gr_model_type):
        img = Image.fromarray(gr_input.astype("uint8"), "RGB")

        # ๅ‰ๅ‡ฆ็†
        img = F.resize(img, (224, 224))
        img = F.to_tensor(img)
        img = img.unsqueeze(0)

        # ใƒขใƒ‡ใƒซ้ธๆŠž
        model_ft = select_model(gr_model_type)
        # ๆŽจ่ซ–
        output = model_ft(img).squeeze(0)
        probs = nn.functional.softmax(output, dim=0).numpy()
        labels_lenght =len(plant_models[gr_model_type]['labels_list'])
        # ใƒฉใƒ™ใƒซใ”ใจใฎ็ขบ็އใ‚’dictใจใ—ใฆ่ฟ”ใ™
        return {plant_models[gr_model_type]['labels_list'][i]: float(probs[i]) for i in range(labels_lenght)}
    
    model_labels = list(plant_models.keys())
    
    # ๅ…ฅๅŠ›ใฎๅฝขๅผใ‚’็”ปๅƒใจใ™ใ‚‹
    inputs = gr.inputs.Image()
    # ใƒขใƒ‡ใƒซใฎ็จฎ้กžใ‚’้ธๆŠžใ™ใ‚‹
    model_type = gr.inputs.Radio(model_labels, type='value', label='BASE MODEL')

    # ๅ‡บๅŠ›ใฏใƒฉใƒ™ใƒซๅฝขๅผใง๏ผŒtop4ใพใง่กจ็คบใ™ใ‚‹
    outputs = gr.outputs.Label(num_top_classes=4)

    # ใ‚ตใƒผใƒใƒผใฎ็ซ‹ใกไธŠใ’
    interface = gr.Interface(fn=inference, 
        inputs=[inputs, model_type],
        outputs=outputs,
        title='Plant Diseases Diagnosis',
        )

    interface.launch()


if __name__ == "__main__":
    main()