Kaori1707 commited on
Commit
91df7cc
·
1 Parent(s): 8aa08c5
Files changed (2) hide show
  1. app.py +35 -63
  2. config.yaml +0 -11
app.py CHANGED
@@ -1,26 +1,29 @@
1
  from typing import Any
2
  import pytorch_lightning as pl
3
- from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
4
  import torch
5
  from torch import nn
6
  from torchvision import transforms
7
- from torch.nn import functional as F
8
  import yaml
9
  from yaml.loader import SafeLoader
10
- from PIL import Image
11
  import gradio as gr
12
  import os
13
 
 
14
  class WeedModel(pl.LightningModule):
15
  def __init__(self, params):
16
  super().__init__()
17
  self.params = params
18
-
19
  model = self.params["model"]
20
 
21
- if(model.lower() == "efficientnet"):
22
- if(self.params["pretrained"]): self.base_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
23
- else: self.base_model = efficientnet_v2_s(weights=None)
 
 
 
 
24
  num_ftrs = self.base_model.classifier[-1].in_features
25
  self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])
26
 
@@ -31,41 +34,15 @@ class WeedModel(pl.LightningModule):
31
  embedding = self.base_model(x)
32
  return embedding
33
 
34
- def configure_optimizers(self):
35
- if(self.params["optimizer"] == "Adam"):
36
- optimizer = torch.optim.Adam(self.parameters(), lr=self.params["Lr"])
37
- elif(self.params["optimizer"] == "SGD"):
38
- optimizer = torch.optim.SGD(self.parameters(), lr=self.params["Lr"])
39
- return optimizer
40
-
41
- def training_step(self, train_batch, batch_idx):
42
- x = train_batch["image"]
43
- y = train_batch["label"]
44
-
45
- y_hat = self(x)
46
- loss = F.cross_entropy(y_hat, y)
47
- self.log('metrics/batch/train_loss', loss, prog_bar=False)
48
-
49
- preds = F.softmax(y_hat, dim=-1)
50
-
51
- return loss
52
-
53
- def validation_step(self, val_batch, batch_idx):
54
-
55
- x = val_batch["image"]
56
- y = val_batch["label"]
57
-
58
- y_hat = self(x)
59
- loss = F.cross_entropy(y_hat, y)
60
- self.log('metrics/batch/val_loss', loss)
61
-
62
- def predict_step(self, batch: Any, batch_idx: int=0, dataloader_idx: int = 0) -> Any:
63
  y_hat = self(batch)
64
  preds = torch.softmax(y_hat, dim=-1).tolist()
65
-
66
  # preds = torch.argmax(preds, dim=-1)
67
  return preds
68
-
69
 
70
  def predict(image):
71
 
@@ -80,45 +57,40 @@ title = " AISeed AI Application Demo "
80
  description = "# A Demo of Deep Learning for Weed Classification"
81
  example_list = [["examples/" + example] for example in os.listdir("examples")]
82
 
83
- with open("class_names.txt", "r", encoding='utf-8') as f:
84
  class_names = f.read().splitlines()
85
-
86
  with gr.Blocks() as demo:
87
  demo.title = title
88
  gr.Markdown(description)
89
  with gr.Tabs():
90
- with gr.TabItem("for Images"):
91
  with gr.Row():
92
  with gr.Column():
93
- im = gr.Image(type="pil", label="input image")
94
  with gr.Column():
95
  label_conv = gr.Label(label="Predictions", num_top_classes=4)
96
  btn = gr.Button(value="predict")
97
  btn.click(predict, inputs=im, outputs=[label_conv])
98
  gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])
99
- with gr.TabItem("for Webcam"):
100
- with gr.Row():
101
- with gr.Column():
102
- webcam = gr.Image(type="pil", label="input image", source="webcam")
103
- # capture = gr.Image(type="pil", label="output image")
104
- with gr.Column():
105
- label = gr.Label(label="Predictions", num_top_classes=4)
106
-
107
- webcam.change(predict, inputs=webcam, outputs=[label])
108
-
109
-
110
- if __name__ == '__main__':
111
- with open('config.yaml') as f:
112
  PARAMS = yaml.load(f, Loader=SafeLoader)
113
  print(PARAMS)
114
- model = WeedModel.load_from_checkpoint("model/epoch=08.ckpt", params=PARAMS, map_location=torch.device('cpu'))
 
 
115
  model.eval()
116
 
117
- transform = transforms.Compose([
118
- transforms.Resize(256),
119
- transforms.CenterCrop(224),
120
- transforms.ToTensor(),
121
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
122
- ])
123
-
 
 
124
  demo.launch()
 
1
  from typing import Any
2
  import pytorch_lightning as pl
3
+ from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
4
  import torch
5
  from torch import nn
6
  from torchvision import transforms
 
7
  import yaml
8
  from yaml.loader import SafeLoader
 
9
  import gradio as gr
10
  import os
11
 
12
+
13
  class WeedModel(pl.LightningModule):
14
  def __init__(self, params):
15
  super().__init__()
16
  self.params = params
17
+
18
  model = self.params["model"]
19
 
20
+ if model.lower() == "efficientnet":
21
+ if self.params["pretrained"]:
22
+ self.base_model = efficientnet_v2_s(
23
+ weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1
24
+ )
25
+ else:
26
+ self.base_model = efficientnet_v2_s(weights=None)
27
  num_ftrs = self.base_model.classifier[-1].in_features
28
  self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])
29
 
 
34
  embedding = self.base_model(x)
35
  return embedding
36
 
37
+ def predict_step(
38
+ self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0
39
+ ) -> Any:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  y_hat = self(batch)
41
  preds = torch.softmax(y_hat, dim=-1).tolist()
42
+
43
  # preds = torch.argmax(preds, dim=-1)
44
  return preds
45
+
46
 
47
  def predict(image):
48
 
 
57
  description = "# A Demo of Deep Learning for Weed Classification"
58
  example_list = [["examples/" + example] for example in os.listdir("examples")]
59
 
60
+ with open("class_names.txt", "r", encoding="utf-8") as f:
61
  class_names = f.read().splitlines()
62
+
63
  with gr.Blocks() as demo:
64
  demo.title = title
65
  gr.Markdown(description)
66
  with gr.Tabs():
67
+ with gr.TabItem("Images"):
68
  with gr.Row():
69
  with gr.Column():
70
+ im = gr.Image(type="pil", label="input image", sources=["upload", "webcam"])
71
  with gr.Column():
72
  label_conv = gr.Label(label="Predictions", num_top_classes=4)
73
  btn = gr.Button(value="predict")
74
  btn.click(predict, inputs=im, outputs=[label_conv])
75
  gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])
76
+
77
+
78
+ if __name__ == "__main__":
79
+ with open("config.yaml") as f:
 
 
 
 
 
 
 
 
 
80
  PARAMS = yaml.load(f, Loader=SafeLoader)
81
  print(PARAMS)
82
+ model = WeedModel.load_from_checkpoint(
83
+ "model/epoch=08.ckpt", params=PARAMS, map_location=torch.device("cpu")
84
+ )
85
  model.eval()
86
 
87
+ transform = transforms.Compose(
88
+ [
89
+ transforms.Resize(256),
90
+ transforms.CenterCrop(224),
91
+ transforms.ToTensor(),
92
+ # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
93
+ ]
94
+ )
95
+
96
  demo.launch()
config.yaml CHANGED
@@ -1,19 +1,8 @@
1
  {
2
- #Dataset
3
- "train_path": "./train.txt",
4
- "test_path": "./test.txt",
5
- "val_path": "./val.txt",
6
- "n_data": 1,
7
-
8
 
9
  #Model
10
  "model": "EfficientNet", # [Alexnet, VGG, GoogleNet, ResNet, DenseNet, MobileNet, SqueezeNet, ShuffleNet, EfficientNet, SE-ResNet (not available)]
11
  "pretrained": True,
12
  "n_class": 40,
13
 
14
- #Training
15
- "B_sz": 4,
16
- "Lr": 0.001,
17
- "Epoch": 5,
18
- "optimizer": "Adam" #[Adam, SGD]
19
  }
 
1
  {
 
 
 
 
 
 
2
 
3
  #Model
4
  "model": "EfficientNet", # [Alexnet, VGG, GoogleNet, ResNet, DenseNet, MobileNet, SqueezeNet, ShuffleNet, EfficientNet, SE-ResNet (not available)]
5
  "pretrained": True,
6
  "n_class": 40,
7
 
 
 
 
 
 
8
  }