Spearoad committed on
Commit
71051a0
1 Parent(s): 7ebc959

add app.py + model files + requirements.txt

.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,130 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f6a71737-070f-4eff-854b-b2432a032fdd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import timm\n",
+ "import gradio as gr\n",
+ "from PIL import Image\n",
+ "\n",
+ "import torchvision.transforms as T\n",
+ "from ultralytics import YOLO\n",
+ "\n",
+ "device = torch.device(\"cpu\")\n",
+ "\n",
+ "num_classes = 3\n",
+ "classify_model = timm.create_model(\n",
+ "    \"efficientnet_b3\", pretrained=False, num_classes=num_classes\n",
+ ")\n",
+ "# Load the trained weights\n",
+ "state_dict = torch.load(\"classify.pth\", map_location=\"cpu\")\n",
+ "classify_model.load_state_dict(state_dict)\n",
+ "classify_model.to(device)\n",
+ "classify_model.eval()\n",
+ "\n",
+ "# Preprocessing (must match what was used during training)\n",
+ "transform_cls = T.Compose([\n",
+ "    T.Resize((224, 224)),\n",
+ "    T.ToTensor(),\n",
+ "    T.Normalize([0.485, 0.456, 0.406],\n",
+ "                [0.229, 0.224, 0.225]),\n",
+ "])\n",
+ "\n",
+ "# A single hub load with the custom weights is enough\n",
+ "yolo_model = torch.hub.load(\n",
+ "    \"ultralytics/yolov5\", \"custom\", path=\"detect.pt\"\n",
+ ")\n",
+ "yolo_model.to(device)\n",
+ "yolo_model.eval()\n",
+ "\n",
+ "# Prediction function\n",
+ "def pipeline_predict(img: Image.Image):\n",
+ "    # --- Classification ---\n",
+ "    img_cls = transform_cls(img).unsqueeze(0).to(device)\n",
+ "    with torch.no_grad():\n",
+ "        out_cls = classify_model(img_cls)\n",
+ "        pred_cls = torch.argmax(out_cls, dim=1).item()\n",
+ "    cls_label = [\"Normal\", \"Minor damage\", \"Severe damage\"][pred_cls]\n",
+ "\n",
+ "    # --- Object detection ---\n",
+ "    results = yolo_model(img, size=640)\n",
+ "    det_img = results.render()[0]  # rendered image (numpy array)\n",
+ "    det_img = Image.fromarray(det_img)\n",
+ "\n",
+ "    return cls_label, det_img"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "94211111-fadd-41dd-b812-7cd056ce5331",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Gradio UI\n",
+ "demo = gr.Interface(\n",
+ "    fn=pipeline_predict,\n",
+ "    inputs=gr.Image(type=\"pil\"),\n",
+ "    outputs=[gr.Label(num_top_classes=1), gr.Image(type=\"pil\")],\n",
+ "    title=\"Road Damage AI (lightweight version)\",\n",
+ "    description=\"Optimized to run even on the free Hugging Face Spaces CPU tier.\",\n",
+ "    allow_flagging=\"never\"\n",
+ ")\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    demo.launch()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "265c26c0-4d58-4506-8a80-c059cfaed3d4",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b13880fe-cae9-4215-b39d-c883466079d6",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ff17d6f2-7e3d-4c5c-8457-7ad197202c4a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
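
Note that this checkpoint resizes classifier inputs to 224×224 while the later revisions below use 300×300, efficientnet_b3's native training resolution. Rather than hard-coding the size, the expected preprocessing can be read from the model's data config; a minimal sketch, assuming a recent timm (>= 0.9):

```python
import timm
from timm.data import resolve_model_data_config, create_transform

model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=3)
cfg = resolve_model_data_config(model)  # e.g. input_size (3, 300, 300), mean, std
transform_cls = create_transform(**cfg, is_training=False)
print(cfg["input_size"], cfg["mean"], cfg["std"])
```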
.ipynb_checkpoints/Untitled1-checkpoint.ipynb ADDED
@@ -0,0 +1,154 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "id": "2617d8a8-b0f1-43a2-90c4-c0f0fc753b88",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import timm\n",
+ "import gradio as gr\n",
+ "from PIL import Image\n",
+ "import torchvision.transforms as T\n",
+ "from ultralytics import YOLO\n",
+ "import numpy as np\n",
+ "\n",
+ "device = torch.device(\"cpu\")\n",
+ "\n",
+ "# Classification model\n",
+ "num_classes = 3\n",
+ "classify_model = timm.create_model(\n",
+ "    \"efficientnet_b3\", pretrained=False, num_classes=num_classes)\n",
+ "\n",
+ "# Load the trained weights\n",
+ "state_dict = torch.load(\"classify.pth\", map_location=\"cpu\")\n",
+ "classify_model.load_state_dict(state_dict)\n",
+ "classify_model.to(device)\n",
+ "classify_model.eval()\n",
+ "\n",
+ "# Preprocessing (must match training)\n",
+ "transform_cls = T.Compose([\n",
+ "    T.Resize((300, 300)),\n",
+ "    T.ToTensor(),\n",
+ "    T.Normalize([0.485, 0.456, 0.406],\n",
+ "                [0.229, 0.224, 0.225]),\n",
+ "])\n",
+ "\n",
+ "def fine_to_coarse(idx: int) -> str:\n",
+ "    return \"Damaged\" if idx in (1, 2) else \"Normal\"\n",
+ "\n",
+ "# Detection model\n",
+ "yolo_model = YOLO(\"detect.pt\")  # load the weights\n",
+ "yolo_model.to(device)\n",
+ "yolo_model.eval()\n",
+ "\n",
+ "# Prediction function\n",
+ "def pipeline_predict(img: Image.Image):\n",
+ "    # Classification\n",
+ "    img_cls = transform_cls(img).unsqueeze(0).to(device)\n",
+ "    with torch.no_grad():\n",
+ "        out_cls = classify_model(img_cls)\n",
+ "        pred_cls = torch.argmax(out_cls, dim=1).item()\n",
+ "\n",
+ "    # Run detection only when damage is predicted\n",
+ "    if pred_cls == 0:\n",
+ "        return \"Normal\", img\n",
+ "    else:\n",
+ "        cls_label = \"Damaged\"\n",
+ "        # --- Object detection (YOLOv8) ---\n",
+ "        results = yolo_model.predict(img, imgsz=640)\n",
+ "        det_img = results[0].plot()  # numpy array (BGR)\n",
+ "        det_img = Image.fromarray(det_img[..., ::-1])  # BGR -> RGB\n",
+ "        return cls_label, det_img"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "id": "592edec4-f7ca-4b47-bcfd-0f203b397353",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7880\n",
+ "* To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7880/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "0: 640x640 1 십자파손, 189.5ms\n",
+ "Speed: 1.7ms preprocess, 189.5ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 640)\n",
+ "\n",
+ "0: 640x640 1 횡방향균열, 191.7ms\n",
+ "Speed: 1.8ms preprocess, 191.7ms inference, 0.7ms postprocess per image at shape (1, 3, 640, 640)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ======================\n",
+ "# 5. Gradio UI\n",
+ "# ======================\n",
+ "demo = gr.Interface(\n",
+ "    fn=pipeline_predict,\n",
+ "    inputs=gr.Image(type=\"pil\"),\n",
+ "    outputs=[gr.Textbox(label=\"Road condition\"), gr.Image(type=\"pil\", label=\"Detection result\")],\n",
+ "    title=\"Road Condition Analysis\",\n",
+ "    description=\"If the classifier's damage probability exceeds a threshold, the detection model marks the damaged regions.\",\n",
+ "    allow_flagging=\"never\"\n",
+ ")\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    demo.launch()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "692ac8db-ebb3-4a87-adec-151ac4459392",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
.ipynb_checkpoints/app-checkpoint.ipynb ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import timm
+ import gradio as gr
+ from PIL import Image
+ import torchvision.transforms as transforms
+ 
+ # 1. Load the classification model
+ # classify.pth stores a state_dict, so rebuild the architecture before loading
+ classify_model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=3)
+ classify_model.load_state_dict(torch.load("classify.pth", map_location="cpu"))
+ classify_model.eval()
+ 
+ # Preprocessing (must match what was used during training)
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ 
+ # 2. Load the detection model (torch.hub fetches ultralytics/yolov5 for detect.pt)
+ detect_model = torch.hub.load("ultralytics/yolov5", "custom", path="detect.pt")
+ 
+ # Prediction function
+ def pipeline(image):
+     # ---- Stage 1: classification ----
+     img = image.convert("RGB")
+     x = transform(img).unsqueeze(0)
+     with torch.no_grad():
+         out = classify_model(x)
+         pred = torch.argmax(out, dim=1).item()
+ 
+     # Interpret the classification result (e.g. 0 = normal, 1 = damaged)
+     if pred == 0:
+         return "Road condition: normal (no detection needed)", image
+ 
+     # ---- Stage 2: detection ----
+     results = detect_model(img)
+     detected_img = results.render()[0]  # numpy array
+     detected_img = Image.fromarray(detected_img)
+ 
+     return "Road condition: damage detected", detected_img
+ 
+ # Gradio UI
+ iface = gr.Interface(
+     fn=pipeline,
+     inputs=gr.Image(type="pil"),
+     outputs=[gr.Textbox(), gr.Image(type="pil")],
+     title="Road Condition Analysis Pipeline",
+     description="Classifies the road condition first, then runs the detection model to mark damage locations when needed."
+ )
+ 
+ if __name__ == "__main__":
+     iface.launch()
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ import timm
+ import gradio as gr
+ from PIL import Image
+ import torchvision.transforms as transforms
+ 
+ # 1. Load the classification model
+ # classify.pth stores a state_dict, so rebuild the architecture before loading
+ classify_model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=3)
+ classify_model.load_state_dict(torch.load("classify.pth", map_location="cpu"))
+ classify_model.eval()
+ 
+ # Preprocessing (must match what was used during training)
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+ 
+ # 2. Load the detection model (torch.hub fetches ultralytics/yolov5 for detect.pt)
+ detect_model = torch.hub.load("ultralytics/yolov5", "custom", path="detect.pt")
+ 
+ # Prediction function
+ def pipeline(image):
+     # ---- Stage 1: classification ----
+     img = image.convert("RGB")
+     x = transform(img).unsqueeze(0)
+     with torch.no_grad():
+         out = classify_model(x)
+         pred = torch.argmax(out, dim=1).item()
+ 
+     # Interpret the classification result (e.g. 0 = normal, 1 = damaged)
+     if pred == 0:
+         return "Road condition: normal (no detection needed)", image
+ 
+     # ---- Stage 2: detection ----
+     results = detect_model(img)
+     detected_img = results.render()[0]  # numpy array
+     detected_img = Image.fromarray(detected_img)
+ 
+     return "Road condition: damage detected", detected_img
+ 
+ # Gradio UI
+ iface = gr.Interface(
+     fn=pipeline,
+     inputs=gr.Image(type="pil"),
+     outputs=[gr.Textbox(), gr.Image(type="pil")],
+     title="Road Condition Analysis Pipeline",
+     description="Classifies the road condition first, then runs the detection model to mark damage locations when needed."
+ )
+ 
+ if __name__ == "__main__":
+     iface.launch()
Untitled.ipynb ADDED
@@ -0,0 +1,122 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f6a71737-070f-4eff-854b-b2432a032fdd",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Downloading: \"https://github.com/ultralytics/yolov5/zipball/master\" to /home/user22313548/.cache/torch/hub/master.zip\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import timm\n",
+ "import gradio as gr\n",
+ "from PIL import Image\n",
+ "\n",
+ "import torchvision.transforms as T\n",
+ "from ultralytics import YOLO\n",
+ "\n",
+ "device = torch.device(\"cpu\")\n",
+ "\n",
+ "num_classes = 3\n",
+ "classify_model = timm.create_model(\n",
+ "    \"efficientnet_b3\", pretrained=False, num_classes=num_classes\n",
+ ")\n",
+ "# Load the trained weights\n",
+ "state_dict = torch.load(\"classify.pth\", map_location=\"cpu\")\n",
+ "classify_model.load_state_dict(state_dict)\n",
+ "classify_model.to(device)\n",
+ "classify_model.eval()\n",
+ "\n",
+ "# Preprocessing (must match what was used during training)\n",
+ "transform_cls = T.Compose([\n",
+ "    T.Resize((224, 224)),\n",
+ "    T.ToTensor(),\n",
+ "    T.Normalize([0.485, 0.456, 0.406],\n",
+ "                [0.229, 0.224, 0.225]),\n",
+ "])\n",
+ "\n",
+ "# A single hub load with the custom weights is enough\n",
+ "yolo_model = torch.hub.load(\n",
+ "    \"ultralytics/yolov5\", \"custom\", path=\"detect.pt\"\n",
+ ")\n",
+ "yolo_model.to(device)\n",
+ "yolo_model.eval()\n",
+ "\n",
+ "# Prediction function\n",
+ "def pipeline_predict(img: Image.Image):\n",
+ "    # --- Classification ---\n",
+ "    img_cls = transform_cls(img).unsqueeze(0).to(device)\n",
+ "    with torch.no_grad():\n",
+ "        out_cls = classify_model(img_cls)\n",
+ "        pred_cls = torch.argmax(out_cls, dim=1).item()\n",
+ "    cls_label = [\"Normal\", \"Minor damage\", \"Severe damage\"][pred_cls]\n",
+ "\n",
+ "    # --- Object detection ---\n",
+ "    results = yolo_model(img, size=640)\n",
+ "    det_img = results.render()[0]  # rendered image (numpy array)\n",
+ "    det_img = Image.fromarray(det_img)\n",
+ "\n",
+ "    return cls_label, det_img"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "94211111-fadd-41dd-b812-7cd056ce5331",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Gradio UI\n",
+ "demo = gr.Interface(\n",
+ "    fn=pipeline_predict,\n",
+ "    inputs=gr.Image(type=\"pil\"),\n",
+ "    outputs=[gr.Label(num_top_classes=1), gr.Image(type=\"pil\")],\n",
+ "    title=\"Road Damage AI (lightweight version)\",\n",
+ "    description=\"Optimized to run even on the free Hugging Face Spaces CPU tier.\",\n",
+ "    allow_flagging=\"never\"\n",
+ ")\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    demo.launch()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "239606a5-63df-49f6-a1dd-69ba979a3e21",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
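
Beyond the rendered overlay, the yolov5 hub Results object also exposes the raw predictions. A short sketch of pulling them out as a DataFrame; `trust_repo` and the test image `road.jpg` are assumptions (torch >= 1.12 otherwise prompts for confirmation on hub downloads):

```python
import torch

# One hub load with the custom weights is sufficient
model = torch.hub.load("ultralytics/yolov5", "custom", path="detect.pt", trust_repo=True)

results = model("road.jpg", size=640)  # "road.jpg" is a hypothetical test image
df = results.pandas().xyxy[0]          # columns: xmin, ymin, xmax, ymax, confidence, class, name
print(df[df.confidence > 0.5])         # keep only reasonably confident detections
```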
Untitled1.ipynb ADDED
@@ -0,0 +1,151 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "id": "2617d8a8-b0f1-43a2-90c4-c0f0fc753b88",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "import timm\n",
+ "import gradio as gr\n",
+ "from PIL import Image\n",
+ "import torchvision.transforms as T\n",
+ "from ultralytics import YOLO\n",
+ "import numpy as np\n",
+ "\n",
+ "device = torch.device(\"cpu\")\n",
+ "\n",
+ "# Classification model\n",
+ "num_classes = 3\n",
+ "classify_model = timm.create_model(\n",
+ "    \"efficientnet_b3\", pretrained=False, num_classes=num_classes)\n",
+ "\n",
+ "# Load the trained weights\n",
+ "state_dict = torch.load(\"classify.pth\", map_location=\"cpu\")\n",
+ "classify_model.load_state_dict(state_dict)\n",
+ "classify_model.to(device)\n",
+ "classify_model.eval()\n",
+ "\n",
+ "# Preprocessing (must match training)\n",
+ "transform_cls = T.Compose([\n",
+ "    T.Resize((300, 300)),\n",
+ "    T.ToTensor(),\n",
+ "    T.Normalize([0.485, 0.456, 0.406],\n",
+ "                [0.229, 0.224, 0.225]),\n",
+ "])\n",
+ "\n",
+ "def fine_to_coarse(idx: int) -> str:\n",
+ "    return \"Damaged\" if idx in (1, 2) else \"Normal\"\n",
+ "\n",
+ "# Detection model\n",
+ "yolo_model = YOLO(\"detect.pt\")  # load the weights\n",
+ "yolo_model.to(device)\n",
+ "yolo_model.eval()\n",
+ "\n",
+ "# Prediction function\n",
+ "def pipeline_predict(img: Image.Image):\n",
+ "    # Classification\n",
+ "    img_cls = transform_cls(img).unsqueeze(0).to(device)\n",
+ "    with torch.no_grad():\n",
+ "        out_cls = classify_model(img_cls)\n",
+ "        pred_cls = torch.argmax(out_cls, dim=1).item()\n",
+ "\n",
+ "    # Run detection only when damage is predicted\n",
+ "    if pred_cls == 0:\n",
+ "        return \"Normal\", img\n",
+ "    else:\n",
+ "        cls_label = \"Damaged\"\n",
+ "        # --- Object detection (YOLOv8) ---\n",
+ "        results = yolo_model.predict(img, imgsz=640)\n",
+ "        det_img = results[0].plot()  # numpy array (BGR)\n",
+ "        det_img = Image.fromarray(det_img[..., ::-1])  # BGR -> RGB\n",
+ "        return cls_label, det_img"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "id": "592edec4-f7ca-4b47-bcfd-0f203b397353",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "* Running on local URL: http://127.0.0.1:7881\n",
+ "* To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "<div><iframe src=\"http://127.0.0.1:7881/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "0: 640x640 1 십자파손, 193.6ms\n",
+ "Speed: 1.7ms preprocess, 193.6ms inference, 0.8ms postprocess per image at shape (1, 3, 640, 640)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# ======================\n",
+ "# 5. Gradio UI\n",
+ "# ======================\n",
+ "demo = gr.Interface(\n",
+ "    fn=pipeline_predict,\n",
+ "    inputs=gr.Image(type=\"pil\"),\n",
+ "    outputs=[gr.Textbox(label=\"Road condition\"), gr.Image(type=\"pil\", label=\"Detection result\")],\n",
+ "    title=\"Road Condition Analysis\",\n",
+ "    description=\"If the classifier's damage probability exceeds a threshold, the detection model marks the damaged regions.\",\n",
+ "    allow_flagging=\"never\"\n",
+ ")\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    demo.launch()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "692ac8db-ebb3-4a87-adec-151ac4459392",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
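
With the ultralytics API used here, `results[0]` carries the structured detections as well as the `plot()` overlay. A minimal sketch of reading boxes, scores, and class names; `road.jpg` is a hypothetical test image:

```python
from ultralytics import YOLO

model = YOLO("detect.pt")
results = model.predict("road.jpg", imgsz=640)

r = results[0]
for box in r.boxes:
    cls_id = int(box.cls[0])                # class index
    conf = float(box.conf[0])               # confidence score
    x1, y1, x2, y2 = box.xyxy[0].tolist()   # corner coordinates in pixels
    print(f"{r.names[cls_id]} {conf:.2f} ({x1:.0f}, {y1:.0f}, {x2:.0f}, {y2:.0f})")
```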
app.py ADDED
@@ -0,0 +1,69 @@
+ import torch
+ import torch.nn.functional as F
+ import timm
+ import gradio as gr
+ from PIL import Image
+ import torchvision.transforms as T
+ from ultralytics import YOLO
+ import numpy as np
+ 
+ device = torch.device("cpu")
+ 
+ # Classification model
+ num_classes = 3
+ classify_model = timm.create_model(
+     "efficientnet_b3", pretrained=False, num_classes=num_classes)
+ 
+ # Load the trained weights
+ state_dict = torch.load("classify.pth", map_location="cpu")
+ classify_model.load_state_dict(state_dict)
+ classify_model.to(device)
+ classify_model.eval()
+ 
+ # Preprocessing (must match training)
+ transform_cls = T.Compose([
+     T.Resize((300, 300)),
+     T.ToTensor(),
+     T.Normalize([0.485, 0.456, 0.406],
+                 [0.229, 0.224, 0.225]),
+ ])
+ 
+ def fine_to_coarse(idx: int) -> str:
+     return "Damaged" if idx in (1, 2) else "Normal"
+ 
+ # Detection model
+ yolo_model = YOLO("detect.pt")  # load the weights
+ yolo_model.to(device)
+ yolo_model.eval()
+ 
+ # Prediction function
+ def pipeline_predict(img: Image.Image):
+     # Classification
+     img_cls = transform_cls(img).unsqueeze(0).to(device)
+     with torch.no_grad():
+         out_cls = classify_model(img_cls)
+         pred_cls = torch.argmax(out_cls, dim=1).item()
+ 
+     # Run detection only when damage is predicted
+     if pred_cls == 0:
+         return "Normal", img
+     else:
+         cls_label = "Damaged"
+         # --- Object detection (YOLOv8) ---
+         results = yolo_model.predict(img, imgsz=640)
+         det_img = results[0].plot()  # numpy array (BGR)
+         det_img = Image.fromarray(det_img[..., ::-1])  # BGR -> RGB
+         return cls_label, det_img
+ 
+ # Gradio UI
+ demo = gr.Interface(
+     fn=pipeline_predict,
+     inputs=gr.Image(type="pil"),
+     outputs=[gr.Textbox(label="Road condition"), gr.Image(type="pil", label="Detection result")],
+     title="Road Condition Analysis",
+     description="If the classifier's damage probability exceeds a threshold, the detection model marks the damaged regions.",
+     allow_flagging="never"
+ )
+ 
+ if __name__ == "__main__":
+     demo.launch()
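
The UI description mentions a probability threshold, but `pipeline_predict` gates detection on the argmax class. A hedged sketch of the threshold behaviour, reusing `classify_model` and `transform_cls` from app.py; the class layout (0 = normal, 1/2 = damage) follows `fine_to_coarse`, and the 0.5 threshold is an assumption to tune on validation data:

```python
import torch
import torch.nn.functional as F

DAMAGE_THRESHOLD = 0.5  # assumed value; tune on a validation set

def is_damaged(img, model, transform, device, threshold=DAMAGE_THRESHOLD):
    x = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = F.softmax(model(x), dim=1)[0]
    # classes: 0 = normal, 1 = minor damage, 2 = severe damage
    damage_prob = probs[1] + probs[2]  # total probability of any damage
    return damage_prob.item() >= threshold
```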
classification_model(85.67).pth:Zone.Identifier ADDED
File without changes
classify.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:74ff4565abff0fd34ad6daf409a073774b55601554837131e50b9281dfd1366b
+ size 43360699
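
classify.pth and detect.pt are stored as Git LFS pointers; if the LFS content is not pulled, `torch.load` sees a ~130-byte text file and fails. A small sanity check, using the sizes recorded in the pointers above:

```python
import os

# expected byte sizes, taken from the LFS pointer files in this commit
EXPECTED = {"classify.pth": 43360699, "detect.pt": 52049362}
for path, size in EXPECTED.items():
    actual = os.path.getsize(path)
    assert actual == size, f"{path}: got {actual} bytes; run `git lfs pull`"
```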
detect.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5d9d8607b52af8891335957573247fa8305462f9c8f4c10e033729c66906c837
+ size 52049362
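
A quick smoke test for the detection weights, assuming the ultralytics package from requirements.txt; it only confirms the checkpoint loads and prints the class names baked into it:

```python
from ultralytics import YOLO

model = YOLO("detect.pt")
print(model.names)  # mapping of class index -> damage class name stored in the weights
```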
last.pt:Zone.Identifier ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ timm
+ gradio
+ Pillow
+ ultralytics
+ numpy
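
The requirements are unpinned, so a Space rebuild can pull breaking releases (newer Gradio versions, for example, deprecate `allow_flagging`). One illustrative pinning; the versions are assumptions and should be replaced with whatever was actually tested:

```
torch==2.3.1
torchvision==0.18.1
timm==1.0.7
gradio==3.50.2
Pillow==10.4.0
ultralytics==8.2.0
numpy==1.26.4
```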