Kaori1707 commited on
Commit
79146ea
·
1 Parent(s): ab90335

first update

Browse files
Files changed (6) hide show
  1. .gitignore +138 -0
  2. app.py +124 -0
  3. class_names.txt +39 -0
  4. config.yaml +19 -0
  5. model/epoch=08.ckpt +3 -0
  6. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+
128
+ # Pyre type checker
129
+ .pyre/
130
+ .neptune
131
+
132
+ #dataset
133
+ data
134
+ crop_data
135
+ examples
136
+ #model
137
+ lightning_logs
138
+ ckpts
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+ import pytorch_lightning as pl
3
+ from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
4
+ import torch
5
+ from torch import nn
6
+ from torchvision import transforms
7
+ from torch.nn import functional as F
8
+ import yaml
9
+ from yaml.loader import SafeLoader
10
+ from PIL import Image
11
+ import gradio as gr
12
+ import os
13
+
14
class WeedModel(pl.LightningModule):
    """LightningModule wrapping an EfficientNet-V2-S backbone for weed classification.

    Expected keys in ``params``: "model" (architecture name), "pretrained"
    (bool), "n_class" (output size), "optimizer" ("Adam" or "SGD"), "Lr".
    """

    def __init__(self, params):
        super().__init__()
        self.params = params

        model = self.params["model"]

        if(model.lower() == "efficientnet"):
            # Start from ImageNet weights when requested, otherwise random init.
            if(self.params["pretrained"]): self.base_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
            else: self.base_model = efficientnet_v2_s(weights=None)
            # Replace the final classifier layer to match the task's class count.
            num_ftrs = self.base_model.classifier[-1].in_features
            self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])
        else:
            print("not prepared model yet!!")

    def forward(self, x):
        """Return raw class logits for a batch of images."""
        embedding = self.base_model(x)
        return embedding

    def configure_optimizers(self):
        """Build the optimizer named by params["optimizer"] with lr params["Lr"]."""
        if(self.params["optimizer"] == "Adam"):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.params["Lr"])
        elif(self.params["optimizer"] == "SGD"):
            optimizer = torch.optim.SGD(self.parameters(), lr=self.params["Lr"])
        return optimizer

    def training_step(self, train_batch, batch_idx):
        """Compute and log cross-entropy loss for one training batch.

        Batches are dicts with "image" and "label" keys.
        """
        x = train_batch["image"]
        y = train_batch["label"]

        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('metrics/batch/train_loss', loss, prog_bar=False)

        # NOTE: the original computed an unused softmax here; removed.
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log cross-entropy loss for one validation batch."""
        x = val_batch["image"]
        y = val_batch["label"]

        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('metrics/batch/val_loss', loss)

    def predict_step(self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0) -> Any:
        """Return per-class softmax probabilities as nested Python lists.

        Wrapped in ``torch.no_grad()`` because this app calls the method
        directly (outside a Lightning Trainer), where gradients would
        otherwise be tracked needlessly.
        """
        with torch.no_grad():
            y_hat = self(batch)
            preds = torch.softmax(y_hat, dim=-1).tolist()
        return preds
68
+
69
+
70
def predict(image):
    """Classify a PIL image and return a {class_name: probability} mapping.

    Relies on the module-level ``transform``, ``model`` and ``class_names``
    that are set up at startup.
    NOTE(review): the last probability is dropped (config has n_class=40 but
    class_names.txt lists 39 names) — confirm this offset is intended.
    """
    batch = transform(image).unsqueeze(0)
    probabilities = model.predict_step(batch)[0]
    labels = {class_names[i]: float(p) for i, p in enumerate(probabilities[:-1])}
    return labels
77
+
78
+
79
# --- Gradio app wiring -------------------------------------------------------
title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Weed Classification"
# One gallery entry per file shipped in the examples/ directory.
example_list = [["examples/" + example] for example in os.listdir("examples")]

# One class label per line, in the order the model's outputs are indexed.
with open("class_names.txt", "r", encoding='utf-8') as f:
    class_names = f.read().splitlines()

with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("for Images"):
            with gr.Row():
                with gr.Column():
                    im = gr.Image(type="pil", label="input image")
                with gr.Column():
                    label_conv = gr.Label(label="Predictions", num_top_classes=4)
                    btn = gr.Button(value="predict")
            btn.click(predict, inputs=im, outputs=[label_conv])
            gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])
        with gr.TabItem("for Webcam"):
            with gr.Row():
                with gr.Column():
                    webcam = gr.Image(type="pil", label="input image", source="webcam")
                with gr.Column():
                    label = gr.Label(label="Predictions", num_top_classes=4)
            # Re-run inference whenever a new webcam frame is captured.
            webcam.change(predict, inputs=webcam, outputs=[label])


if __name__ == '__main__':
    with open('config.yaml') as f:
        PARAMS = yaml.load(f, Loader=SafeLoader)
    print(PARAMS)
    # BUGFIX: the original used "model\epoch=08.ckpt" — a Windows-style
    # backslash (literally backslash + "e", an invalid escape sequence) that
    # fails on Linux. Forward slashes work on every platform Python supports.
    model = WeedModel.load_from_checkpoint("model/epoch=08.ckpt", params=PARAMS).cpu()
    model.eval()

    # Standard ImageNet preprocessing, matching the EfficientNet-V2 weights.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    demo.launch()
class_names.txt ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ๋ชฉํ‘œ_๊ฐœ๋ง์ดˆ
2
+ ๋ชฉํ‘œ_๊ฐœ์‘ฅ๊ฐ“
3
+ ๋ชฉํ‘œ_๋‹จํ’์žŽ๋ผ์ง€ํ’€
4
+ ๋ชฉํ‘œ_๋ผ์ง€ํ’€
5
+ ๋ชฉํ‘œ_๋ง์ดˆ
6
+ ๋ชฉํ‘œ_๋ฏธ๊ตญ๊ฐ€๋ง‰์‚ฌ๋ฆฌ
7
+ ๋ชฉํ‘œ_๋ฐฉ๊ฐ€์ง€๋˜ฅ
8
+ ๋ชฉํ‘œ_๋ณ„๊ฝƒ์•„์žฌ๋น„
9
+ ๋ชฉํ‘œ_๋ถ‰์€์„œ๋‚˜๋ฌผ
10
+ ๋ชฉํ‘œ_์šธ์‚ฐ๋„๊นจ๋น„๋ฐ”๋Š˜
11
+ ๋ชฉํ‘œ_์ฃผํ™์„œ๋‚˜๋ฌผ
12
+ ๋ชฉํ‘œ_ํฐ๋„๊ผฌ๋งˆ๋ฆฌ
13
+ ๋ชฉํ‘œ_ํฐ๋ง์ดˆ
14
+ ๋ชฉํ‘œ_ํฐ๋ฐฉ๊ฐ€์ง€๋˜ฅ
15
+ ๋ชฉํ‘œ_ํ„ธ๋ณ„๊ฝƒ์•„์žฌ๋น„
16
+ ๋ชฉํ‘œ_๋‘ฅ๊ทผ์žŽ๋‚˜ํŒ”๊ฝƒ
17
+ ๋ชฉํ‘œ_๋‘ฅ๊ทผ์žŽ๋ฏธ๊ตญ๋‚˜ํŒ”๊ฝƒ
18
+ ๋ชฉํ‘œ_๋‘ฅ๊ทผ์žŽ์œ ํ™์ดˆ
19
+ ๋ชฉํ‘œ_๋ฏธ๊ตญ๋‚˜ํŒ”๊ฝƒ
20
+ ๋ชฉํ‘œ_๋ณ„๋‚˜ํŒ”๊ฝƒ
21
+ ๋ชฉํ‘œ_์• ๊ธฐ๋‚˜ํŒ”๊ฝƒ
22
+ ๋ชฉํ‘œ_๋ƒ„์ƒˆ๋ช…์•„์ฃผ
23
+ ๋ชฉํ‘œ_์–‘๋ช…์•„์ฃผ
24
+ ๋ชฉํ‘œ_์ข€๋ช…์•„์ฃผ
25
+ ๋ชฉํ‘œ_์ทจ๋ช…์•„์ฃผ
26
+ ๋ชฉํ‘œ_ํฐ๋ช…์•„์ฃผ
27
+ ๋ชฉํ‘œ_๊ฐ€๋Š”ํ„ธ๋น„๋ฆ„
28
+ ๋ชฉํ‘œ_๊ฐ€์‹œ๋น„๋ฆ„
29
+ ๋ชฉํ‘œ_๊ฐœ๋น„๋ฆ„
30
+ ๋ชฉํ‘œ_๊ธดํ„ธ๋น„๋ฆ„
31
+ ๋ชฉํ‘œ_๋ฏผํ„ธ๋น„๋ฆ„
32
+ ๋ชฉํ‘œ_์ฒญ๋น„๋ฆ„
33
+ ๋ชฉํ‘œ_ํ„ธ๋น„๋ฆ„
34
+ ๋ชฉํ‘œ_๊ฐ€๋Š”๋ฏธ๊ตญ์™ธํ’€
35
+ ๋ชฉํ‘œ_๋ˆˆ๊ฐœ๋ถˆ์•Œํ’€
36
+ ๋ชฉํ‘œ_๋ฏธ๊ตญ์™ธํ’€
37
+ ๋ชฉํ‘œ_์„ ๊ฐœ๋ถˆ์•Œํ’€
38
+ ๋ชฉํ‘œ_์ข€๊ฐœ๋ถˆ์•Œํ’€
39
+ ๋ชฉํ‘œ_ํฐ๊ฐœ๋ถˆ์•Œํ’€
config.yaml ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ #Dataset
3
+ "train_path": "./train.txt",
4
+ "test_path": "./test.txt",
5
+ "val_path": "./val.txt",
6
+ "n_data": 1,
7
+
8
+
9
+ #Model
10
+ "model": "EfficientNet", # [Alexnet, VGG, GoogleNet, ResNet, DenseNet, MobileNet, SqueezeNet, ShuffleNet, EfficientNet, SE-ResNet (not available)]
11
+ "pretrained": True,
12
+ "n_class": 40,
13
+
14
+ #Training
15
+ "B_sz": 4,
16
+ "Lr": 0.001,
17
+ "Epoch": 5,
18
+ "optimizer": "Adam" #[Adam, SGD]
19
+ }
model/epoch=08.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c21a880dfa8c41e3c1533f5e6ab5a12a2bc562d99db7c3504b042384a54de1d
3
+ size 244046718
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pytorch-lightning
2
+ torch==1.8.1
3
+ torchvision==0.9.1