celery22 commited on
Commit
a50af6a
1 Parent(s): cca1850

upload file

Browse files
Files changed (2) hide show
  1. app.py +72 -0
  2. cucumber_resnet18_last_model.pth +3 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torchvision.models import resnet18
7
+ from torchvision.transforms import functional as F
8
+
9
+ def main():
10
+ # モデルの準備
11
+ model_ft = resnet18(num_classes = 11,pretrained=False)
12
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
13
+ model_ft = model_ft.to(device)
14
+ if torch.cuda.is_available():
15
+ model_ft.load_state_dict(torch.load('cucumber_resnet18_last_model.pth'))
16
+ else:
17
+ model_ft.load_state_dict(
18
+ torch.load('cucumber_resnet18_last_model.pth', map_location=torch.device("cpu"))
19
+ )
20
+ model_ft.eval()
21
+
22
+ # 学種済みモデルのラベルの取得
23
+ # response = requests.get("https://git.io/JJkYN")
24
+ # labels = response.text.split("\n")
25
+ labels = [
26
+ "健全",
27
+ "うどんこ病",
28
+ "灰色かび病",
29
+ "炭疽病",
30
+ "べと病",
31
+ "褐斑病",
32
+ "つる枯病",
33
+ "斑点細菌病",
34
+ "CCYV",
35
+ "モザイク病",
36
+ "MYSV",
37
+ ]
38
+
39
+ # 画像分類を行う関数を定義
40
+ @torch.no_grad()
41
+ def inference(gr_input):
42
+ img = Image.fromarray(gr_input.astype("uint8"), "RGB")
43
+
44
+ # 前処理
45
+ img = F.resize(img, (224, 224))
46
+ img = F.to_tensor(img)
47
+ img = img.unsqueeze(0)
48
+ # img = F.normalize(
49
+ # img,
50
+ # mean=[0.485, 0.456, 0.406],
51
+ # std=[0.229, 0.224, 0.225],
52
+ # )
53
+
54
+ # 推論
55
+ output = model_ft(img).squeeze(0)
56
+ probs = nn.functional.softmax(output, dim=0).numpy()
57
+
58
+ # ラベルごとの確率をdictとして返す
59
+ return {labels[i]: float(probs[i]) for i in range(11)}
60
+
61
+ # 入力の形式を画像とする
62
+ inputs = gr.inputs.Image()
63
+ # 出力はラベル形式で,top5まで表示する
64
+ outputs = gr.outputs.Label(num_top_classes=5)
65
+
66
+ # サーバーの立ち上げ
67
+ interface = gr.Interface(fn=inference, inputs=inputs, outputs=outputs)
68
+ interface.launch()
69
+
70
+
71
+ if __name__ == "__main__":
72
+ main()
cucumber_resnet18_last_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46ea5b3332a072a019bc4f78fb86e172ad39f3e38b463cb194170501f903eba1
3
+ size 44806605