sergey21000 commited on
Commit
e8cd255
·
verified ·
1 Parent(s): f0eed67

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +183 -0
  2. requirements.txt +4 -0
  3. utils.py +109 -0
  4. yolo_classes.json +82 -0
app.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shutil
2
+ from pathlib import Path
3
+ from typing import List, Dict, Union, Tuple, Literal, Optional
4
+
5
+ import numpy as np
6
+ import gradio as gr
7
+ from gradio.components.base import Component
8
+ from ultralytics import YOLO
9
+
10
+ from utils import download_model, detect_image, detect_video, get_csv_annotate
11
+
12
+
13
+ # ======================= MODEL ===================================
14
+
15
# Local directory where downloaded checkpoints are cached; created on import.
MODELS_DIR = Path('models')
MODELS_DIR.mkdir(exist_ok=True)

# Checkpoint file name -> download URL. All files live under one release tag,
# so the table is derived from the shared base URL.
_MODELS_BASE_URL = 'https://github.com/ultralytics/assets/releases/download/v8.1.0'
MODELS = {
    f'yolov8{size}.pt': f'{_MODELS_BASE_URL}/yolov8{size}.pt'
    for size in ('n', 's', 'm', 'l', 'x')
}
MODEL_NAMES = [*MODELS]

# The first listed checkpoint is downloaded (if not cached) and loaded at
# import time; it becomes the initial model of the UI.
model_path = download_model(MODEL_NAMES[0], MODELS_DIR, MODELS)
default_model = YOLO(model_path)

# File extensions that detect() routes to the image / video pipelines.
IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png']
VIDEO_EXTENSIONS = ['.mp4', '.avi']
32
+
33
+
34
+ # =================== ADDITIONAL INTERFACE FUNCTIONS ==============================
35
+
36
def change_model(model_state: Dict[str, YOLO], model_name: str):
    """Download (if needed) and load the selected checkpoint into the state.

    Args:
        model_state: Mutable dict holding the active model under the
            ``'model'`` key; it is updated in place.
        model_name: Key of ``MODELS`` identifying the checkpoint to load.

    Returns:
        A status message for the UI naming the initialised model.
    """
    progress = gr.Progress()
    progress(0.3, desc='Загрузка модели')
    # download_model() needs the cache directory and the URL table as well;
    # calling it with the name alone raises TypeError on every model switch.
    model_path = download_model(model_name, MODELS_DIR, MODELS)
    progress(0.7, desc='Инициализация модели')
    model_state['model'] = YOLO(model_path)
    return f"Модель {model_name} инициализирована"
43
+
44
+
45
def detect(file_path: str, file_link: str, model_state: Dict[str, YOLO], conf: float, iou: float):
    """Run detection on an uploaded file or on a link and report the status.

    A non-empty ``file_link`` takes precedence over ``file_path``.

    Args:
        file_path: Path of the uploaded file, or ``None`` when nothing was
            uploaded.
        file_link: Direct image link or YouTube link; may be empty.
        model_state: Dict holding the active model under the ``'model'`` key.
        conf: Confidence threshold forwarded to the detector.
        iou: IoU threshold forwarded to the detector.

    Returns:
        ``(result, status_message)`` where ``result`` is whatever
        ``detect_image`` / ``detect_video`` returned, or ``(None, None)``
        when the input is missing or has an unsupported format.
    """
    model = model_state['model']
    file_link = (file_link or '').strip()
    source = file_link or file_path
    if not source:
        # Neither a file nor a link was supplied; there is nothing to split
        # an extension from, so report it instead of crashing on None.
        gr.Info('Неверный формат изображения или видео...')
        return None, None
    source = str(source)

    # Compare extensions case-insensitively so that e.g. "photo.JPG" works.
    file_ext = f'.{source.rsplit(".", 1)[-1].lower()}'
    # Keep in sync with detect_video(), which accepts both link forms.
    is_youtube = 'youtube.com' in file_link or 'youtu.be' in file_link

    if file_ext in IMAGE_EXTENSIONS:
        np_image = detect_image(source, model, conf, iou)
        return np_image, "Детекция завершена, открытие изображения..."
    if file_ext in VIDEO_EXTENSIONS or is_youtube:
        video_path = detect_video(source, model, conf, iou)
        return video_path, "Детекция завершена, конвертация и открытие видео..."

    gr.Info('Неверный формат изображения или видео...')
    return None, None
60
+
61
+ # =================== INTERFACE COMPONENTS ==============================
62
+
63
def get_output_media_components(detect_result: Optional[Union[np.ndarray, str, Path]] = None):
    """Build the output widgets matching the kind of detection result.

    A NumPy array is shown in the image widget, a ``str``/``Path`` in the
    video widget; the widget that does not apply is created hidden and empty.
    The Clear button is visible whenever a result exists.

    Returns:
        ``(image_output, video_output, clear_btn)``.
    """
    is_image = isinstance(detect_result, np.ndarray)
    is_video = isinstance(detect_result, (str, Path))
    has_result = detect_result is not None

    # Settings shared by both media widgets.
    media_kwargs = dict(width=640, height=480, label='Output')

    image_output = gr.Image(
        value=detect_result if is_image else None,
        type="numpy",
        visible=is_image,
        **media_kwargs,
    )
    video_output = gr.Video(
        value=detect_result if is_video else None,
        visible=is_video,
        **media_kwargs,
    )
    clear_btn = gr.Button(value='Clear', scale=0, visible=has_result)
    return image_output, video_output, clear_btn
87
+
88
+
89
def get_download_csv_btn(csv_annotations_path: Optional[Path] = None):
    """Build the CSV download button; it is hidden when no CSV path is given."""
    has_csv = csv_annotations_path is not None
    return gr.DownloadButton(
        label='Скачать csv аннотации к видео',
        value=csv_annotations_path,
        scale=0,
        visible=has_csv,
    )
97
+
98
+ # =================== APPLICATION INTERFACE ==========================
99
+
100
css = '''
.gradio-container { width: 70% !important }
'''
with gr.Blocks(css=css) as demo:
    gr.HTML("""<h3 style='text-align: center'>YOLOv8 Detector</h3>""")

    # State holders: the active model, the last detection result
    # (image array or video path) and the path of the generated CSV.
    model_state = gr.State({'model': default_model})
    detect_result = gr.State(None)
    csv_annotations_path = gr.State(None)

    with gr.Row():
        with gr.Column():
            file_path = gr.File(file_types=['image', 'video'], file_count='single', label='Выберите изображение или видео')
            file_link = gr.Textbox(label='Прямая ссылка на изображение или ссылка на YouTube')
            model_name = gr.Radio(choices=MODEL_NAMES, value=MODEL_NAMES[0], label='Модель YOLO')
            conf = gr.Slider(0, 1, value=0.5, step=0.05, label='Порог уверенности')
            iou = gr.Slider(0, 1, value=0.7, step=0.1, label='Порог IOU')
            status_message = gr.Textbox(value='Готово к работе', label='Статус')
            detect_btn = gr.Button('Detect', interactive=True)

        with gr.Column():
            image_output, video_output, clear_btn = get_output_media_components()
            download_csv_btn = get_download_csv_btn()

    # Switching the model: lock the Detect button, load the model, unlock.
    model_name.change(
        fn=lambda: gr.update(interactive=False),
        inputs=None,
        outputs=[detect_btn],
    ).then(
        fn=change_model,
        inputs=[model_state, model_name],
        outputs=[status_message],
    ).then(
        # Unlock with .then (not .success): if loading fails the previously
        # loaded model is still in model_state, and the button must not stay
        # disabled for the rest of the session.
        fn=lambda: gr.update(interactive=True),
        inputs=None,
        outputs=[detect_btn],
    )

    # Detection: run the detector, show the result, reset the status, then
    # build the CSV annotations and expose the download button.
    detect_btn.click(
        fn=detect,
        inputs=[file_path, file_link, model_state, conf, iou],
        outputs=[detect_result, status_message],
    ).success(
        fn=get_output_media_components,
        inputs=[detect_result],
        outputs=[image_output, video_output, clear_btn],
    ).then(
        fn=lambda: 'Готово к работе',
        inputs=None,
        outputs=[status_message],
    ).then(
        fn=get_csv_annotate,
        inputs=[detect_result],
        outputs=[csv_annotations_path],
    ).success(
        fn=get_download_csv_btn,
        inputs=[csv_annotations_path],
        outputs=[download_csv_btn],
    )

    def clear_results_dir(detect_result):
        """Delete the directory holding a video result (no-op for images)."""
        if isinstance(detect_result, Path):
            shutil.rmtree(detect_result.parent, ignore_errors=True)

    # Clearing: hide the output widgets, remove result files, reset state.
    clear_components = [image_output, video_output, clear_btn, download_csv_btn]
    clear_btn.click(
        fn=lambda: [gr.update(visible=False) for _ in clear_components],
        inputs=None,
        outputs=clear_components,
    ).then(
        fn=clear_results_dir,
        inputs=[detect_result],
        outputs=None,
    ).then(
        fn=lambda: (None, None),
        inputs=None,
        outputs=[detect_result, csv_annotations_path]
    )

    gr.HTML("""<h3 style='text-align: center'>
<a href="https://github.com/sergey21000/yolo_gradio_detector" target='_blank'>GitHub Page</a></h3>
""")

# Listen on all network interfaces; pass debug=True when debugging locally.
demo.launch(server_name='0.0.0.0')
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ ultralytics>=8,<9
3
+ gradio>4
4
+ yt_dlp
utils.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import json
4
+ import urllib.request
5
+ from pathlib import Path
6
+ from typing import List, Dict, Union, Tuple, Optional
7
+
8
+ import torch
9
+ import pandas as pd
10
+ import numpy as np
11
+ import cv2
12
+ import yt_dlp
13
+ import gradio as gr
14
+ from ultralytics import YOLO
15
+
16
+
17
# Class-index -> class-name table read from yolo_classes.json in the current
# working directory at import time. Keys are the JSON object's string keys.
YOLO_CLASS_NAMES = json.loads(
    Path('yolo_classes.json').read_text()
)
18
+
19
+
20
def download_model(model_name: str, models_dir: Path, models: dict) -> str:
    """Return the local path of a checkpoint, downloading it if not cached.

    Args:
        model_name: File name of the checkpoint and key into ``models``.
        models_dir: Directory used as the download cache (created if missing).
        models: Mapping of checkpoint file name to download URL.

    Returns:
        The path of the cached checkpoint as a string.

    Raises:
        KeyError: If the file is not cached and ``model_name`` is not a key
            of ``models``.
        urllib.error.URLError: If the download fails.
    """
    model_path = Path(models_dir) / model_name
    if model_path.exists():
        return str(model_path)

    url = models[model_name]
    model_path.parent.mkdir(parents=True, exist_ok=True)
    # Download to a temporary sibling and move it into place only when the
    # transfer completed: an interrupted download must not leave a truncated
    # file that later calls would treat as a valid cached checkpoint.
    tmp_path = model_path.with_name(f'{model_path.name}.part')
    try:
        urllib.request.urlretrieve(url, tmp_path)
        tmp_path.replace(model_path)
    finally:
        tmp_path.unlink(missing_ok=True)
    return str(model_path)
25
+
26
+
27
def detect_image(image_path: str, model: YOLO, conf: float, iou: float) -> np.ndarray:
    """Run the model on one image and return the annotated picture.

    Args:
        image_path: Source passed to ``model.predict``.
        model: Model used for prediction.
        conf: Confidence threshold passed to ``model.predict``.
        iou: IoU threshold passed to ``model.predict``.

    Returns:
        The plotted first prediction result, converted from BGR to RGB.
    """
    report_progress = gr.Progress()
    report_progress(0.5, desc='Детекция изображения...')
    first_result = model.predict(source=image_path, conf=conf, iou=iou)[0]
    annotated_bgr = first_result.plot()
    return cv2.cvtColor(annotated_bgr, cv2.COLOR_BGR2RGB)
33
+
34
+
35
def detect_video(video_path_or_url: str, model: YOLO, conf: float, iou: float) -> Path:
    """Run detection on a video file or YouTube link and save the result.

    YouTube links are downloaded first. Prediction is run with saving of the
    annotated video and of per-frame label files (with confidences) enabled.
    The input video file is deleted after processing.

    Args:
        video_path_or_url: Local video path or a YouTube URL.
        model: Model used for prediction.
        conf: Confidence threshold passed to ``model.predict``.
        iou: IoU threshold passed to ``model.predict``.

    Returns:
        Path of the annotated video inside the prediction save directory.

    Raises:
        gr.Error: If the predictor did not yield a single frame.
    """
    progress = gr.Progress()
    video_path = video_path_or_url
    if 'youtube.com' in video_path_or_url or 'youtu.be' in video_path_or_url:
        progress(0.001, desc='Загрузка видео с YouTube...')
        ydl_opts = {'format': 'bestvideo[height<=720]'}
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            video_info_dict = ydl.extract_info(video_path_or_url, download=True)
            video_path = ydl.prepare_filename(video_info_dict)

    # The frame count is only used to report progress.
    cap = cv2.VideoCapture(video_path)
    try:
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()
    # An unknown frame count (reported as 0 or less) means "no total".
    progress_total = num_frames if num_frames > 0 else None

    generator = model.predict(
        source=video_path,
        # Honour the thresholds chosen by the caller instead of fixed 0.5/0.5.
        conf=conf,
        iou=iou,
        save=True,
        save_txt=True,
        save_conf=True,
        stream=True,
        verbose=False,
    )

    frames_count = 0
    result = None
    for result in generator:
        frames_count += 1
        progress((frames_count, progress_total), desc=f'Детекция видео, шаг {frames_count}/{num_frames}')

    if result is None:
        # Nothing was decoded, so there is no save directory to point at.
        raise gr.Error('Не удалось обработать видео: не прочитано ни одного кадра')

    # NOTE(review): assumes the predictor writes the annotated video as
    # "<input stem>.avi" inside result.save_dir — confirm for the pinned
    # ultralytics version and platform.
    file_name = Path(result.path).with_suffix('.avi').name
    result_video_path = Path(result.save_dir) / file_name
    Path(video_path).unlink(missing_ok=True)
    return result_video_path
69
+
70
+
71
def get_csv_annotate(result_video_path: Path) -> Optional[str]:
    """Collect the per-frame label files of a video result into one CSV.

    Label files are looked up as ``labels/<video stem>_<frame>.txt`` next to
    the result video. Each line is parsed as
    ``class_label x y w h conf``. The CSV has one row per detection plus an
    empty row for every frame number in ``range(total_frames)`` that has no
    detections.

    Args:
        result_video_path: Path of the annotated video. Any value that is not
            a ``Path`` (e.g. an image array or ``None``) is ignored.

    Returns:
        Path of the written CSV file as a string, or ``None`` when
        ``result_video_path`` is not a ``Path``.
    """
    if not isinstance(result_video_path, Path):
        return None

    txts_path = result_video_path.parent / 'labels'
    escaped_pattern = glob.escape(result_video_path.stem)
    matching_txts_path = sorted(txts_path.glob(f'{escaped_pattern}_*.txt'), key=os.path.getmtime)

    df_list = []
    for txt_path in matching_txts_path:
        # The frame number is the part after the last underscore.
        # NOTE(review): frame numbers in label file names may be 1-based while
        # the full-frame range below is 0-based — confirm against ultralytics.
        frame_number = int(txt_path.stem.rsplit('_', 1)[-1])
        with open(txt_path) as file:
            for line in file:
                values = line.split()
                if values:  # skip blank lines instead of producing short rows
                    df_list.append((frame_number, *map(float, values)))

    column_names = ['frame_number', 'class_label', 'x', 'y', 'w', 'h', 'conf']
    df = pd.DataFrame(df_list, columns=column_names)
    df.frame_number = df.frame_number.astype(int)
    df.class_label = df.class_label.astype(int)

    # The JSON table has string keys ("0", "1", ...) while the labels are
    # integers; normalise the keys so the lookup does not yield NaN for all.
    class_names = {int(label): name for label, name in YOLO_CLASS_NAMES.items()}
    df.insert(loc=1, column='class_name', value=df.class_label.map(class_names))

    cap = cv2.VideoCapture(str(result_video_path))
    try:
        # Keep the fractional FPS: truncating it skews every timestamp.
        frames_fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    # Add a row for every frame so frames without detections are listed too.
    full_frames = pd.DataFrame({'frame_number': range(total_frames)})
    df = pd.merge(full_frames, df, on='frame_number', how='outer')

    # An unreadable video reports 0 FPS; leave the timestamps empty rather
    # than dividing by zero.
    frame_sec = df.frame_number / frames_fps if frames_fps > 0 else float('nan')
    df.insert(loc=1, column='frame_sec', value=frame_sec)

    result_csv_path = f'{result_video_path.parent / result_video_path.stem}_annotations.csv'
    df.to_csv(result_csv_path, index=False)
    return result_csv_path
yolo_classes.json ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0": "person",
3
+ "1": "bicycle",
4
+ "2": "car",
5
+ "3": "motorcycle",
6
+ "4": "airplane",
7
+ "5": "bus",
8
+ "6": "train",
9
+ "7": "truck",
10
+ "8": "boat",
11
+ "9": "traffic light",
12
+ "10": "fire hydrant",
13
+ "11": "stop sign",
14
+ "12": "parking meter",
15
+ "13": "bench",
16
+ "14": "bird",
17
+ "15": "cat",
18
+ "16": "dog",
19
+ "17": "horse",
20
+ "18": "sheep",
21
+ "19": "cow",
22
+ "20": "elephant",
23
+ "21": "bear",
24
+ "22": "zebra",
25
+ "23": "giraffe",
26
+ "24": "backpack",
27
+ "25": "umbrella",
28
+ "26": "handbag",
29
+ "27": "tie",
30
+ "28": "suitcase",
31
+ "29": "frisbee",
32
+ "30": "skis",
33
+ "31": "snowboard",
34
+ "32": "sports ball",
35
+ "33": "kite",
36
+ "34": "baseball bat",
37
+ "35": "baseball glove",
38
+ "36": "skateboard",
39
+ "37": "surfboard",
40
+ "38": "tennis racket",
41
+ "39": "bottle",
42
+ "40": "wine glass",
43
+ "41": "cup",
44
+ "42": "fork",
45
+ "43": "knife",
46
+ "44": "spoon",
47
+ "45": "bowl",
48
+ "46": "banana",
49
+ "47": "apple",
50
+ "48": "sandwich",
51
+ "49": "orange",
52
+ "50": "broccoli",
53
+ "51": "carrot",
54
+ "52": "hot dog",
55
+ "53": "pizza",
56
+ "54": "donut",
57
+ "55": "cake",
58
+ "56": "chair",
59
+ "57": "couch",
60
+ "58": "potted plant",
61
+ "59": "bed",
62
+ "60": "dining table",
63
+ "61": "toilet",
64
+ "62": "tv",
65
+ "63": "laptop",
66
+ "64": "mouse",
67
+ "65": "remote",
68
+ "66": "keyboard",
69
+ "67": "cell phone",
70
+ "68": "microwave",
71
+ "69": "oven",
72
+ "70": "toaster",
73
+ "71": "sink",
74
+ "72": "refrigerator",
75
+ "73": "book",
76
+ "74": "clock",
77
+ "75": "vase",
78
+ "76": "scissors",
79
+ "77": "teddy bear",
80
+ "78": "hair drier",
81
+ "79": "toothbrush"
82
+ }