lakiet commited on
Commit
069f468
1 Parent(s): b39e213

Upload 14 files

Browse files
cnn_model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class custom_model(nn.Module):
4
+ def __init__(self, input_param: int, output_param: int):
5
+ super().__init__()
6
+ self.layer_block1 = nn.Sequential(
7
+ nn.Conv2d(in_channels=input_param, out_channels=32, kernel_size=3),
8
+ nn.ReLU(),
9
+ nn.BatchNorm2d(num_features=32),
10
+
11
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
12
+ nn.ReLU(),
13
+ nn.BatchNorm2d(num_features=64),
14
+ nn.MaxPool2d(kernel_size=2)
15
+ )
16
+
17
+ self.layer_block2 = nn.Sequential(
18
+ nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3),
19
+ nn.ReLU(),
20
+ nn.BatchNorm2d(num_features=128),
21
+
22
+ nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3),
23
+ nn.ReLU(),
24
+ nn.BatchNorm2d(num_features=128),
25
+ nn.MaxPool2d(kernel_size=2)
26
+ )
27
+
28
+ self.layer_block3 = nn.Sequential(
29
+ nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3),
30
+ nn.ReLU(),
31
+ nn.BatchNorm2d(num_features=256),
32
+
33
+ nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3),
34
+ nn.ReLU(),
35
+ nn.BatchNorm2d(num_features=256),
36
+ nn.MaxPool2d(kernel_size=2)
37
+ )
38
+
39
+ self.layer_block4 = nn.Sequential(
40
+ nn.Flatten(),
41
+ nn.Linear(in_features=256*12*12, out_features=256),
42
+ nn.Dropout(0.5),
43
+ nn.Linear(in_features=256, out_features=output_param)
44
+ )
45
+
46
+ def forward(self, x):
47
+ x = self.layer_block1(x)
48
+ x = self.layer_block2(x)
49
+ x = self.layer_block3(x)
50
+ x = self.layer_block4(x)
51
+ return x
52
+
53
+
examples/Early_blight.JPG ADDED
examples/Leaf_mold.JPG ADDED
examples/Sep_leaf_spot.JPG ADDED
examples/bacterial_spot.JPG ADDED
examples/healthy.JPG ADDED
examples/late_blight.JPG ADDED
examples/mosaic_virus.JPG ADDED
examples/target_spot.JPG ADDED
examples/two_spotted_spider.JPG ADDED
examples/yellow_leaf_curl_virus.JPG ADDED
main.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import torch
4
+ from cnn_model import custom_model
5
+ from timeit import default_timer as timer
6
+ from typing import Tuple, Dict
7
+ from torchvision import transforms
8
+
9
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
+
11
+ class_name = ['Tomato___Bacterial_spot',
12
+ 'Tomato___Early_blight',
13
+ 'Tomato___Late_blight',
14
+ 'Tomato___Leaf_Mold',
15
+ 'Tomato___Septoria_leaf_spot',
16
+ 'Tomato___Spider_mites Two-spotted_spider_mite',
17
+ 'Tomato___Target_Spot',
18
+ 'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
19
+ 'Tomato___Tomato_mosaic_virus',
20
+ 'Tomato___healthy']
21
+
22
+ #Function for gradio
23
+ def predict_gradio(img):
24
+ start_time = timer()
25
+
26
+ image_transform = transforms.Compose([
27
+ transforms.Resize(128),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
30
+ std=[0.229, 0.224, 0.225]),
31
+ ])
32
+
33
+
34
+ #Load model
35
+ model_path = r'Deployment\tomato_plants\model_2.pth'
36
+ loaded_model_2 = custom_model(input_param=3, output_param=10)
37
+ loaded_model_2.load_state_dict(torch.load(f=model_path))
38
+ loaded_model_2 = loaded_model_2.to(device)
39
+ loaded_model_2.eval()
40
+
41
+ with torch.inference_mode():
42
+ transformed_image = image_transform(img).unsqueeze(dim=0)
43
+
44
+ target_image_pred = loaded_model_2(transformed_image.to(device))
45
+
46
+ pred_probs = torch.softmax(target_image_pred, dim=1)
47
+ pred_labels_and_probs = {class_name[i]: float(pred_probs[0][i]) for i in range(len(class_name))}
48
+ pred_time = round(timer() - start_time, 4)
49
+
50
+ return pred_labels_and_probs, pred_time
51
+
52
+ a = 'Deployment\tomato_plants\examples\bacterial_spot.JPG'
53
+
54
+ #Create title
55
+ title = 'Tomato Plants Disease Detector'
56
+ description = 'A custom CNN image classification model to detect 9 diseases on tomato plants'
57
+ articile = 'Created at [Deploy the tomato plant diseases image classification by using Gradio](https://github.com/lakiet1609/Deploy-the-tomato-plant-diseases-image-classification-by-using-Gradio)'
58
+ example_list = [['Deployment/tomato_plants/examples/' + example] for example in os.listdir(r'Deployment\tomato_plants\examples')]
59
+
60
+ # Create the Gradio demo
61
+ demo = gr.Interface(fn=predict_gradio,
62
+ inputs=gr.Image(type='pil'),
63
+ outputs=[gr.Label(num_top_classes=3, label='prediction'),
64
+ gr.Number(label='Prediction time (second)')],
65
+ examples=example_list,
66
+ title=title,
67
+ description=description,
68
+ article=articile)
69
+
70
+ demo.launch(debug=False,
71
+ share=True)
model_2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8cc4a346644e5cbbd873289452f3c7456727a330c169f20cd372801728ee0da5
3
+ size 42292827
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==1.12.0
2
+ torchvision==0.13.0
3
+ gradio==3.1.4