Upload prepare_yolo_dataset.py with huggingface_hub
Browse files- prepare_yolo_dataset.py +324 -0
prepare_yolo_dataset.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
import random
import shutil
from collections import Counter
from pathlib import Path
|
5 |
+
|
6 |
+
#######################################################
# CONFIGURATION SECTION - MODIFY THESE VALUES
#######################################################

# Source directory of each recording location, keyed by the location
# identifiers that are used in `splits` below.
SOURCE_DIRS = {
    "location_1": "mpala",  # REPLACE WITH YOUR ACTUAL PATH
    "location_2": "opc",  # REPLACE WITH YOUR ACTUAL PATH
    "location_3": "wilds",  # REPLACE WITH YOUR ACTUAL PATH
}

# Destination directory of the generated dataset
DEST_DIR = "/data"  # REPLACE WITH YOUR ACTUAL PATH

# Class id -> class name
CLASS_LABELS = {
    0: "Zebra",
    1: "Giraffe",
    2: "Onager",
    3: "Dog",
}

# Sampling rate: every SAMPLING_RATE-th frame of a video is kept, so a
# higher value means fewer frames.
SAMPLING_RATE = 10

# Train/test assignment for the 70/30 strategy:
# split -> location -> session -> list of video names.
splits = {
    "train": {
        "location_3": {
            "session_1": ["DJI_0034", "DJI_0035_part1"],  # African painted dog (70%)
            "session_2": ["P0140018"],  # Giraffe (70%)
            "session_3": ["P0100010", "P0110011", "P0080008", "P0090009"],  # Persian onager (70%)
        },
        "location_1": {
            "session_1": ["DJI_0001", "DJI_0002"],  # Giraffe
            "session_2": ["DJI_0005", "DJI_0006"],  # Plains zebra
            "session_3": ["DJI_0068", "DJI_0069"],  # Grevy's zebra
            "session_4": ["DJI_0142", "DJI_0143", "DJI_0144"],  # Grevy's zebra
            "session_5": ["DJI_0206", "DJI_0208"],  # Mixed species
        },
        "location_2": {
            "session_1": ["P0800081", "P0830086", "P0840087", "P0870091"],  # Plains zebra
            "session_2": ["P0910095"],  # Plains zebra
        },
    },
    "test": {
        "location_3": {
            "session_1": ["DJI_0035_part2"],  # African painted dog (30%)
            "session_3": ["P0070007", "P0160016", "P0120012"],  # Persian onager (30%)
            "session_2": ["P0150019"],  # Giraffe (30%)
            "session_4": ["P0070010"],  # Grevy's zebra (100%)
        },
        "location_1": {
            "session_3": ["DJI_0070", "DJI_0071"],  # Grevy's zebra
            "session_4": ["DJI_0145", "DJI_0146", "DJI_0147"],  # Grevy's zebra
            "session_5": ["DJI_0210", "DJI_0211"],  # Mixed species
        },
        "location_2": {
            "session_1": ["P0860090"],  # Plains zebra
            "session_2": ["P0940098"],  # Plains zebra
        },
    },
}
|
70 |
+
|
71 |
+
#######################################################
|
72 |
+
# SCRIPT CODE - DO NOT MODIFY UNLESS NECESSARY
|
73 |
+
#######################################################
|
74 |
+
|
75 |
+
# Make sure the images/ and labels/ folders exist for both splits
# (already existing folders are left untouched).
for split in ("train", "test"):
    for _subdir in ("images", "labels"):
        os.makedirs(f"{DEST_DIR}/{_subdir}/{split}", exist_ok=True)
|
79 |
+
|
80 |
+
def find_images_in_directory(dir_path):
    """Return the names of the image files located directly in a directory.

    Args:
        dir_path: Directory to scan, given as a ``str`` or a path object.

    Returns:
        Sorted list of bare file names (no directory part) that end in
        ``.jpg``, ``.png`` or ``.jpeg`` and are regular files. The suffix
        comparison is case-sensitive. An empty list is returned, after
        printing the error, when the directory is missing, is not a
        directory, or cannot be read.
    """
    # Coerce first so that plain-string paths work with the "/" join below
    # instead of raising TypeError.
    dir_path = Path(dir_path)
    try:
        entries = os.listdir(dir_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {dir_path}: {e}")
        return []

    # os.listdir() order is arbitrary; sort so results are reproducible.
    return sorted(
        name for name in entries
        if name.endswith(('.jpg', '.png', '.jpeg')) and os.path.isfile(dir_path / name)
    )
|
88 |
+
|
89 |
+
def find_partitions(session_path):
    """Return the names of the ``partition_*`` sub-directories of a directory.

    Args:
        session_path: Directory to scan, given as a ``str`` or a path object.

    Returns:
        Sorted list of the names of the direct sub-directories whose name
        starts with ``partition_``. An empty list is returned, after printing
        the error, when the directory is missing, is not a directory, or
        cannot be read.
    """
    # Coerce first so that plain-string paths work with the "/" join below
    # instead of raising TypeError.
    session_path = Path(session_path)
    try:
        entries = os.listdir(session_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {session_path}: {e}")
        return []

    # os.listdir() order is arbitrary; sort so "the first partition" is
    # the same one on every run.
    return sorted(
        name for name in entries
        if os.path.isdir(session_path / name) and name.startswith('partition_')
    )
|
97 |
+
|
98 |
+
def find_video_images(session_path, video_name):
    """Collect the image files that belong to one video of a session.

    Two layouts are searched, and the results of both are combined:

    1. ``<session>/<video>/`` is a directory. Its ``partition_*``
       sub-directories are scanned; if it has none, the images lying
       directly in the video directory are taken.
    2. ``<session>/partition_*/`` directories. Only images whose file name
       contains ``video_name`` as a substring are taken.

    Args:
        session_path: Session directory, given as a ``str`` or a path object.
        video_name: Name of the video.

    Returns:
        List of ``(image_dir, image_name, partition_name)`` tuples, where
        ``image_dir`` is a ``Path`` and ``partition_name`` is ``""`` for
        images found directly in the video directory.
    """
    # Coerce first so that plain-string paths work with the "/" joins below
    # instead of raising TypeError.
    session_path = Path(session_path)
    all_images = []

    # Layout 1: the video has its own directory.
    video_path = session_path / video_name
    if os.path.isdir(video_path):
        partitions = find_partitions(video_path)
        if partitions:
            for partition in partitions:
                partition_path = video_path / partition
                all_images.extend(
                    (partition_path, img, partition)
                    for img in find_images_in_directory(partition_path)
                )
        else:
            all_images.extend(
                (video_path, img, "")
                for img in find_images_in_directory(video_path)
            )

    # Layout 2: partitions sit directly in the session directory and are
    # shared between videos, so frames are attributed by file name.
    # NOTE(review): this is a substring test, so a video name that is
    # contained in another video's file names would also match those.
    for partition in find_partitions(session_path):
        partition_path = session_path / partition
        all_images.extend(
            (partition_path, img, partition)
            for img in find_images_in_directory(partition_path)
            if video_name in img
        )

    return all_images
|
134 |
+
|
135 |
+
def _list_session_videos(session_path):
    """Infer the video names of a session when the config asks for all videos.

    Every sub-directory whose name does not start with ``partition_`` is
    taken as a video. If there is none, names are guessed from the image
    files of the first partition, using the text before the first
    underscore of each file name.

    Raises:
        FileNotFoundError, NotADirectoryError, PermissionError: if the
            session directory cannot be listed.
    """
    videos = [v for v in os.listdir(session_path)
              if os.path.isdir(session_path / v) and not v.startswith('partition_')]
    if videos:
        return videos

    partitions = find_partitions(session_path)
    if not partitions:
        return []

    all_imgs = find_images_in_directory(session_path / partitions[0])
    # NOTE(review): for file names such as "DJI_0001_xxx.jpg" this yields
    # "DJI", not "DJI_0001" - confirm the frame naming scheme before
    # relying on this fallback.
    return list({img.split('_')[0] for img in all_imgs if '_' in img})


def _find_label_file(frame_dir, session_path, video, partition, label_name):
    """Return the first existing label path for a frame, or ``None``.

    Candidate locations, in order of priority:

    1. same directory as the image
    2. ``labels`` sub-directory next to the image
    3. ``<session>/labels/<partition>/``
    4. flat ``<session>/labels/``
    5. ``<session>/<video>/labels/``
    """
    candidates = (
        frame_dir / label_name,
        frame_dir / "labels" / label_name,
        session_path / "labels" / partition / label_name,
        session_path / "labels" / label_name,
        session_path / video / "labels" / label_name,
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


# Process each location and session
for split_name, locations in splits.items():
    for location_name, sessions in locations.items():
        # Get the source directory for this location
        if location_name not in SOURCE_DIRS:
            print(f"Warning: No source directory defined for {location_name}. Skipping.")
            continue

        location_source_dir = Path(SOURCE_DIRS[location_name])

        for session_name, video_info in sessions.items():
            session_path = location_source_dir / session_name

            if not os.path.exists(session_path):
                print(f"Warning: Session path {session_path} does not exist. Skipping.")
                continue

            # Work out which videos of this session to use.
            if video_info is True:
                # True means "every video in the session".
                try:
                    videos = _list_session_videos(session_path)
                except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
                    print(f"Warning: Could not list directory {session_path}: {e}")
                    continue
            elif isinstance(video_info, str):
                # A bare string is one video name, not a sequence of
                # single-character names.
                videos = [video_info]
            else:
                # Explicit list of videos; False/None/empty selects nothing
                # instead of crashing the loop below.
                videos = video_info or []

            for video in videos:
                print(f"Processing {location_name}/{session_name}/{video}...")

                # Find all images for this video (in all partitions)
                frame_info = find_video_images(session_path, video)

                if not frame_info:
                    print(f"Warning: No frames found for {video} in {session_name}")
                    continue

                # Sort by file name. This is a plain string sort, so it only
                # matches temporal order if frame numbers are zero-padded.
                frame_info.sort(key=lambda x: x[1])

                # Keep every SAMPLING_RATE-th frame
                sampled_frame_info = frame_info[::SAMPLING_RATE]

                for frame_dir, frame_name, partition in sampled_frame_info:
                    # Partition component of the destination name, if any
                    partition_str = f"_{partition}" if partition else ""

                    # Copy image
                    src_img = frame_dir / frame_name
                    dest_img_name = f"{location_name}_{session_name}_{video}{partition_str}_{frame_name}"
                    dest_img = Path(DEST_DIR) / "images" / split_name / dest_img_name

                    try:
                        shutil.copy(src_img, dest_img)
                    except OSError as e:
                        print(f"Error copying image {src_img}: {e}")
                        continue

                    # Swap only the final extension for ".txt"; a chain of
                    # str.replace() calls would also rewrite ".jpg"/".png"
                    # text occurring elsewhere in the name.
                    label_name = os.path.splitext(frame_name)[0] + ".txt"
                    src_label = _find_label_file(
                        frame_dir, session_path, video, partition, label_name
                    )

                    if src_label is None:
                        # The image stays in the dataset without a label file.
                        print(f"Warning: No label found for {src_img}")
                        continue

                    dest_label_name = os.path.splitext(dest_img_name)[0] + ".txt"
                    dest_label = Path(DEST_DIR) / "labels" / split_name / dest_label_name
                    try:
                        shutil.copy(src_label, dest_label)
                    except OSError as e:
                        print(f"Error copying label {src_label}: {e}")

print("Dataset split completed successfully!")
|
248 |
+
|
249 |
+
# Create dataset.yaml file
|
250 |
+
def create_dataset_yaml(dest_dir=None, class_labels=None):
    """Write the ``dataset.yaml`` config file into the dataset root.

    The file lists the dataset root (as an absolute path), the train, val
    and test image folders, and the class id -> name mapping. The ``val``
    entry points at the training images.

    Args:
        dest_dir: Dataset root directory; it must already exist. Defaults
            to the module-level ``DEST_DIR``.
        class_labels: Mapping of class id to class name, written in
            iteration order. Defaults to the module-level ``CLASS_LABELS``.

    Raises:
        OSError: if the file cannot be created or written.
    """
    if dest_dir is None:
        dest_dir = DEST_DIR
    if class_labels is None:
        class_labels = CLASS_LABELS

    lines = [
        "# YOLOv11 dataset config\n",
        f"path: {os.path.abspath(dest_dir)} # dataset root dir\n",
        "train: images/train # train images\n",
        "val: images/train # validation uses train images\n",
        "test: images/test # test images\n\n",
        "# Classes\n",
        "names:\n",
    ]
    lines.extend(
        f" {class_id}: {class_name}\n"
        for class_id, class_name in class_labels.items()
    )

    # Explicit encoding so non-ASCII class names are written the same way
    # on every platform.
    with open(f"{dest_dir}/dataset.yaml", "w", encoding="utf-8") as f:
        f.writelines(lines)
|
262 |
+
|
263 |
+
create_dataset_yaml()


# Analyze the distribution

def _location_and_session(file_name, known_pairs):
    """Recover ``(location, session)`` from a destination image file name.

    Destination names start with ``<location>_<session>_``, and both parts
    may contain underscores themselves, so the name cannot simply be split
    on ``_``. The name is matched against the configured pairs instead.

    Args:
        file_name: File name found in an ``images/<split>`` folder.
        known_pairs: ``(location, session)`` pairs to try, in order.

    Returns:
        The matching pair. If none matches, the first two
        underscore-separated fields, or ``None`` if there are fewer than two.
    """
    for location, session in known_pairs:
        if file_name.startswith(f"{location}_{session}_"):
            return location, session

    parts = file_name.split('_')
    if len(parts) < 2:
        return None
    return parts[0], parts[1]


# Every (location, session) pair named in the split configuration, longest
# prefix first so that the most specific pair wins.
_known_pairs = []
for _locations in splits.values():
    for _location, _sessions in _locations.items():
        for _session in _sessions:
            if (_location, _session) not in _known_pairs:
                _known_pairs.append((_location, _session))
_known_pairs.sort(key=lambda pair: len(pair[0]) + len(pair[1]), reverse=True)

stats = {}

for split in ['train', 'test']:
    locations = Counter()
    species_count = Counter()
    # Register the entry up front so the report below also works when the
    # image folder is missing.
    stats[split] = {"total": 0, "locations": locations, "species": species_count}

    # Get all images in this split
    img_dir = Path(DEST_DIR) / "images" / split
    if not os.path.exists(img_dir):
        print(f"Warning: Directory {img_dir} does not exist.")
        continue

    # Sorted so the report order is the same on every run.
    for img in sorted(os.listdir(img_dir)):
        parsed = _location_and_session(img, _known_pairs)
        if parsed is None:
            continue
        location, session = parsed

        locations[location] += 1
        # One session is assumed to hold one species group, hence the key name.
        species_count[f"{location}_{session}"] += 1
        stats[split]["total"] += 1

# Print stats
for split, data in stats.items():
    print(f"\n{split.upper()} set:")
    print(f"Total images: {data['total']}")

    print("Distribution by location:")
    for loc, count in data["locations"].items():
        percentage = (count/data['total']*100) if data['total'] > 0 else 0
        print(f" - {loc}: {count} ({percentage:.1f}%)")

    print("\nDistribution by location_session:")
    for species_key, count in data["species"].items():
        percentage = (count/data['total']*100) if data['total'] > 0 else 0
        print(f" - {species_key}: {count} ({percentage:.1f}%)")

_grand_total = stats['train']['total'] + stats['test']['total']
if _grand_total > 0:
    print("\nOverall train/test ratio:",
          f"{stats['train']['total'] / _grand_total:.1%}",
          f"/ {stats['test']['total'] / _grand_total:.1%}")
else:
    # Nothing was copied; avoid dividing by zero.
    print("\nOverall train/test ratio: n/a (no images found)")
|