celery22 commited on
Commit
f232657
1 Parent(s): 6a26dcb

Upload 7 files

Browse files
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import torch
4
+ import torch.nn as nn
5
+ from PIL import Image
6
+ from torchvision.models import resnet18
7
+ from torchvision.transforms import functional as F
8
+
9
+ def main():
10
+ # title
11
+ title = 'Cucumber Diseases Diagnosis'
12
+ # description
13
+ description = """ このデモは、きゅうりの葉の病害を分類することを目的としたプロジェクトの進行中のアプリケーションである。
14
+ 現在、健全含む病害10種を分類することができます。
15
+ このモデルはこちらでご覧になれます:https://huggingface.co/spaces/celery22/cucumber_diseases_diagnosis
16
+ お楽しみください。
17
+ """
18
+ # model_path
19
+ model_path = 'cucumber_resnet18_last_model.pth'
20
+ # class_name
21
+ class_name = ["健全","うどんこ病","灰色かび病","炭疽病","べと病","褐斑病","つる枯病","斑点細菌病","CCYV","モザイク病","MYSV"]
22
+
23
+ # example
24
+ example=[
25
+ ['image/231305_20200302150233_01.JPG'],
26
+ ['image/0004_20181120084837_01.jpg'],
27
+ ['image/160001_20170830173740_01.JPG'],
28
+ ['image/152300_20190119175054_01.JPG'],
29
+
30
+ ]
31
+
32
+ # model定義
33
+ model_ft = resnet18(num_classes = len(class_name),pretrained=False)
34
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
35
+ model_ft = model_ft.to(device)
36
+ if torch.cuda.is_available():
37
+ model_ft.load_state_dict(torch.load(model_path))
38
+ else:
39
+ model_ft.load_state_dict(
40
+ torch.load(model_path, map_location=torch.device("cpu"))
41
+ )
42
+ model_ft.eval()
43
+
44
+ # 画像分類を行う関数を定義
45
+ @torch.no_grad()
46
+ def inference(gr_input):
47
+ img = Image.fromarray(gr_input.astype("uint8"), "RGB")
48
+
49
+ # 前処理
50
+ img = F.resize(img, (224, 224))
51
+ img = F.to_tensor(img)
52
+ img = img.unsqueeze(0)
53
+
54
+ # 推論
55
+ output = model_ft(img).squeeze(0)
56
+ probs = nn.functional.softmax(output, dim=0).numpy()
57
+ labels_lenght =len(class_name)
58
+ # ラベルごとの確率をdictとして返す
59
+ return {class_name[i]: float(probs[i]) for i in range(labels_lenght)}
60
+
61
+
62
+ # 入力の形式を画像とする
63
+ inputs = gr.inputs.Image()
64
+
65
+ # 出力はラベル形式で,top5まで表示する
66
+ outputs = gr.outputs.Label(num_top_classes=5)
67
+
68
+ # サーバーの立ち上げ
69
+ interface = gr.Interface(fn=inference,
70
+ inputs=[inputs],
71
+ outputs=outputs,
72
+ examples=example,
73
+ title=title,
74
+ description=description)
75
+
76
+ interface.launch()
77
+
78
+
79
+ if __name__ == "__main__":
80
+ main()
cucumber_resnet18_last_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:46ea5b3332a072a019bc4f78fb86e172ad39f3e38b463cb194170501f903eba1
3
+ size 44806605
image/0004_20181120084837_01.jpg ADDED
image/152300_20190119175054_01.JPG ADDED
image/160001_20170830173740_01.JPG ADDED
image/231305_20200302150233_01.JPG ADDED
requirements.txt ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==22.1.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ altair==4.2.0
5
+ anyio==3.6.2
6
+ asgiref==3.6.0
7
+ async-timeout==4.0.2
8
+ attrs==22.2.0
9
+ backports.zoneinfo==0.2.1
10
+ certifi==2022.12.7
11
+ charset-normalizer==2.1.1
12
+ click==8.1.3
13
+ contourpy==1.0.7
14
+ cycler==0.11.0
15
+ dj-database-url==1.2.0
16
+ Django==4.1.4
17
+ django-environ==0.9.0
18
+ entrypoints==0.4
19
+ fastapi==0.89.1
20
+ ffmpy==0.3.0
21
+ fonttools==4.38.0
22
+ frozenlist==1.3.3
23
+ fsspec==2022.11.0
24
+ gradio==3.16.2
25
+ h11==0.14.0
26
+ httpcore==0.16.3
27
+ httpx==0.23.3
28
+ idna==3.4
29
+ importlib-resources==5.10.2
30
+ Jinja2==3.1.2
31
+ jsonschema==4.17.3
32
+ kiwisolver==1.4.4
33
+ linkify-it-py==1.0.3
34
+ markdown-it-py==2.1.0
35
+ MarkupSafe==2.1.1
36
+ matplotlib==3.6.3
37
+ mdit-py-plugins==0.3.3
38
+ mdurl==0.1.2
39
+ multidict==6.0.4
40
+ numpy==1.24.0
41
+ opencv-python==4.6.0.66
42
+ orjson==3.8.5
43
+ packaging==23.0
44
+ pandas==1.5.2
45
+ Pillow==9.3.0
46
+ pkgutil_resolve_name==1.3.10
47
+ psycopg2==2.9.5
48
+ pycryptodome==3.16.0
49
+ pydantic==1.10.4
50
+ pydub==0.25.1
51
+ pyparsing==3.0.9
52
+ pyrsistent==0.19.3
53
+ python-dateutil==2.8.2
54
+ python-decouple==3.7
55
+ python-multipart==0.0.5
56
+ pytz==2022.7.1
57
+ PyYAML==6.0
58
+ requests==2.28.1
59
+ rfc3986==1.5.0
60
+ six==1.16.0
61
+ sniffio==1.3.0
62
+ sqlparse==0.4.3
63
+ starlette==0.22.0
64
+ toolz==0.12.0
65
+ torch==1.13.1
66
+ torchvision==0.14.1
67
+ typing_extensions==4.4.0
68
+ uc-micro-py==1.0.1
69
+ urllib3==1.26.13
70
+ uvicorn==0.20.0
71
+ websockets==10.4
72
+ whitenoise==6.3.0
73
+ yarl==1.8.2
74
+ zipp==3.11.0