jennamk14 committed on
Commit
418b0ac
·
verified ·
1 Parent(s): cbb1c3c

Upload prepare_yolo_dataset.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. prepare_yolo_dataset.py +324 -0
prepare_yolo_dataset.py ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from pathlib import Path
4
+ import shutil
5
+
6
+ #######################################################
7
+ # CONFIGURATION SECTION - MODIFY THESE VALUES
8
+ #######################################################
9
+
10
+ # Define source directories for each location
11
# Keys are the location names used in `splits`; values are the directories
# that hold that location's session folders. The script joins them as
# <source dir>/<session name>.
SOURCE_DIRS = {
    'location_1': 'mpala',  # REPLACE WITH YOUR ACTUAL PATH
    'location_2': 'opc',  # REPLACE WITH YOUR ACTUAL PATH
    'location_3': 'wilds'  # REPLACE WITH YOUR ACTUAL PATH
}

# Destination directory. The script creates images/<split> and
# labels/<split> beneath it and writes dataset.yaml into it.
DEST_DIR = "/data"  # REPLACE WITH YOUR ACTUAL PATH

# Define your class labels (class id -> class name). These entries are
# written verbatim under "names:" in dataset.yaml.
CLASS_LABELS = {
    0: "Zebra",
    1: "Giraffe",
    2: "Onager",
    3: "Dog",
}

# Sampling rate (adjust as needed - higher values mean fewer frames).
# The sorted frame list of each video is sliced with this step, so every
# SAMPLING_RATE-th frame is kept, starting with the first.
SAMPLING_RATE = 10
30
+
31
+ # Define the splits (train/test) for the 70/30 strategy
32
# Structure: split name -> location name -> session name -> video names.
# A session value may also be the literal True, in which case the script
# discovers the session's videos itself instead of using an explicit list.
splits = {
    'train': {
        'location_3': {
            'session_1': ['DJI_0034', 'DJI_0035_part1'],  # African Painted Dog (70%)
            'session_2': ['P0140018'],  # Giraffe (70%)
            'session_3': ['P0100010', 'P0110011', 'P0080008', 'P0090009'],  # Persian Onager (70%)

        },
        'location_1': {
            'session_1': ['DJI_0001', 'DJI_0002'],  # Giraffe
            'session_2': ['DJI_0005', 'DJI_0006'],  # Plains zebra
            'session_3': ['DJI_0068', 'DJI_0069'],  # Grevy's zebra
            'session_4': ['DJI_0142', 'DJI_0143', 'DJI_0144'],  # Grevy's zebra
            'session_5': ['DJI_0206', 'DJI_0208'],  # Mixed species
        },
        'location_2': {
            'session_1': ['P0800081', 'P0830086', 'P0840087', 'P0870091'],  # Plains zebra
            'session_2': ['P0910095'],  # Plains zebra
        }
    },
    'test': {
        'location_3': {
            'session_1': ['DJI_0035_part2'],  # African Painted Dog (30%)
            'session_3': ['P0070007', 'P0160016', 'P0120012'],  # Persian Onager (30%)
            'session_2': ['P0150019'],  # Giraffe (30%)
            'session_4': ['P0070010'],  # Grevy's Zebra (100%)
        },
        'location_1': {
            'session_3': ['DJI_0070', 'DJI_0071'],  # Grevy's zebra
            'session_4': ['DJI_0145', 'DJI_0146', 'DJI_0147'],  # Grevy's zebra
            'session_5': ['DJI_0210', 'DJI_0211'],  # Mixed species
        },
        'location_2': {
            'session_1': ['P0860090'],  # Plains zebra
            'session_2': ['P0940098'],  # Plains zebra
        }
    }
}
70
+
71
+ #######################################################
72
+ # SCRIPT CODE - DO NOT MODIFY UNLESS NECESSARY
73
+ #######################################################
74
+
75
# Make sure the image and label output folders exist for every split
# before any file is copied into them.
for split in ('train', 'test'):
    os.makedirs(os.path.join(DEST_DIR, "images", split), exist_ok=True)
    os.makedirs(os.path.join(DEST_DIR, "labels", split), exist_ok=True)
79
+
80
def find_images_in_directory(dir_path):
    """Return the names of the image files directly inside ``dir_path``.

    Args:
        dir_path: Directory to scan, given as a ``str`` or ``pathlib.Path``.

    Returns:
        A list of file names (not full paths) that end in ``.jpg``,
        ``.png`` or ``.jpeg`` and are regular files, in directory-listing
        order. An empty list is returned, and an error line printed, when
        the directory is missing, is not a directory, or cannot be read.
    """
    # Coerce once so plain-string paths also work with the "/" operator.
    dir_path = Path(dir_path)

    # Only the listing itself can raise the handled errors, so keep the
    # try body limited to it.
    try:
        entries = os.listdir(dir_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {dir_path}: {e}")
        return []

    return [name for name in entries
            if name.endswith(('.jpg', '.png', '.jpeg')) and os.path.isfile(dir_path / name)]
88
+
89
def find_partitions(session_path):
    """Return the names of the ``partition_*`` subdirectories of a directory.

    Args:
        session_path: Directory to scan, given as a ``str`` or
            ``pathlib.Path``.

    Returns:
        A sorted list of subdirectory names that start with
        ``partition_``. An empty list is returned, and an error line
        printed, when the directory is missing, is not a directory, or
        cannot be read.
    """
    # Coerce once so plain-string paths also work with the "/" operator.
    session_path = Path(session_path)

    try:
        entries = os.listdir(session_path)
    except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
        print(f"Error accessing {session_path}: {e}")
        return []

    # os.listdir order is arbitrary; sorting keeps "the first partition"
    # the same from run to run.
    return sorted(name for name in entries
                  if name.startswith('partition_') and os.path.isdir(session_path / name))
97
+
98
def find_video_images(session_path, video_name):
    """Collect every image that belongs to one video of a session.

    Two layouts are searched and their results are combined:

    * ``<session>/<video>/`` exists as a directory: images are taken from
      its ``partition_*`` subdirectories, or directly from the video
      directory when it contains no partitions.
    * ``<session>/partition_*/``: images whose file name contains
      ``video_name`` are included.

    Args:
        session_path: Session directory, given as a ``str`` or
            ``pathlib.Path``.
        video_name: Name of the video; used both as a directory name and
            as a substring to look for in image file names.

    Returns:
        A list of ``(image_dir, image_name, partition_name)`` tuples.
        ``image_dir`` is a ``Path`` and ``partition_name`` is ``""`` for
        images found directly in the video directory.
    """
    # Coerce once so plain-string paths also work with the "/" operator.
    session_path = Path(session_path)
    all_images = []

    # Layout 1: the video has its own directory inside the session.
    video_path = session_path / video_name
    if os.path.isdir(video_path):
        partitions = find_partitions(video_path)
        if partitions:
            for partition in partitions:
                partition_path = video_path / partition
                all_images.extend(
                    (partition_path, img, partition)
                    for img in find_images_in_directory(partition_path)
                )
        else:
            # No partitions: the frames sit directly in the video directory.
            all_images.extend(
                (video_path, img, "")
                for img in find_images_in_directory(video_path)
            )

    # Layout 2: partitions sit directly in the session directory and the
    # frames of several videos may be mixed, so select by file name.
    for partition in find_partitions(session_path):
        partition_path = session_path / partition
        all_images.extend(
            (partition_path, img, partition)
            for img in find_images_in_directory(partition_path)
            if video_name in img
        )

    return all_images
134
+
135
# Process each location and session: for every configured video, gather
# its frames, keep every SAMPLING_RATE-th one, and copy the frame plus its
# label file (when one is found) into DEST_DIR/<images|labels>/<split>.
for split_name, locations in splits.items():
    for location_name, sessions in locations.items():
        # Get the source directory for this location
        if location_name not in SOURCE_DIRS:
            print(f"Warning: No source directory defined for {location_name}. Skipping.")
            continue

        location_source_dir = Path(SOURCE_DIRS[location_name])

        for session_name, video_info in sessions.items():
            session_path = location_source_dir / session_name

            if not os.path.exists(session_path):
                print(f"Warning: Session path {session_path} does not exist. Skipping.")
                continue

            if video_info is True:
                # Use all videos in the session - detect them from
                # directories or, failing that, from image file names.
                try:
                    videos = [v for v in os.listdir(session_path)
                              if os.path.isdir(session_path / v) and not v.startswith('partition_')]

                    if not videos:
                        partitions = find_partitions(session_path)
                        if partitions:
                            # Take the text before the first "_" of each image
                            # in the first partition as a potential video name.
                            first_partition = session_path / partitions[0]
                            all_imgs = find_images_in_directory(first_partition)
                            # Sorted so the processing order is reproducible.
                            videos = sorted({img.split('_')[0] for img in all_imgs if '_' in img})
                except (FileNotFoundError, NotADirectoryError, PermissionError) as e:
                    print(f"Warning: Could not list directory {session_path}: {e}")
                    continue
            elif not video_info:
                # False, None or an empty list: nothing selected for this
                # session. Skipping avoids iterating a non-iterable value.
                continue
            elif isinstance(video_info, str):
                # A single video name; wrap it so it is not iterated
                # character by character.
                videos = [video_info]
            else:
                # Use specific videos
                videos = video_info

            # Process each video
            for video in videos:
                print(f"Processing {location_name}/{session_name}/{video}...")

                # Find all images for this video (in all partitions)
                frame_info = find_video_images(session_path, video)

                if not frame_info:
                    print(f"Warning: No frames found for {video} in {session_name}")
                    continue

                # Sort frames by file name before sampling.
                # NOTE(review): this is a plain string sort, so it matches
                # temporal order only if frame numbers are zero-padded - confirm.
                frame_info.sort(key=lambda x: x[1])

                # Sample frames at regular intervals
                sampled_frame_info = frame_info[::SAMPLING_RATE]

                # Copy sampled frames and labels to destination
                for frame_dir, frame_name, partition in sampled_frame_info:
                    # Name component for the partition, if there is one
                    partition_str = f"_{partition}" if partition else ""

                    # Copy image; the destination name is prefixed with
                    # location, session, video and partition.
                    src_img = frame_dir / frame_name
                    dest_img_name = f"{location_name}_{session_name}_{video}{partition_str}_{frame_name}"
                    dest_img = Path(DEST_DIR) / "images" / split_name / dest_img_name

                    try:
                        shutil.copy(src_img, dest_img)
                    except OSError as e:  # includes FileNotFoundError
                        print(f"Error copying image {src_img}: {e}")
                        continue

                    # Swap only the final extension for ".txt"; a chained
                    # str.replace would also rewrite ".jpg"/".png" text
                    # occurring earlier in the name.
                    label_name = os.path.splitext(frame_name)[0] + '.txt'

                    # Possible label locations (in order of priority)
                    possible_label_paths = [
                        # 1. Same directory as image
                        frame_dir / label_name,
                        # 2. Labels subdirectory next to the image
                        frame_dir / "labels" / label_name,
                        # 3. Labels directory parallel to partition with same structure
                        session_path / "labels" / partition / label_name,
                        # 4. Flat labels directory for session
                        session_path / "labels" / label_name,
                        # 5. In video directory (if it exists)
                        session_path / video / "labels" / label_name,
                    ]

                    src_label = next(
                        (candidate for candidate in possible_label_paths if os.path.exists(candidate)),
                        None,
                    )

                    if src_label is None:
                        print(f"Warning: No label found for {src_img}")
                        continue

                    dest_label_name = os.path.splitext(dest_img_name)[0] + '.txt'
                    dest_label = Path(DEST_DIR) / "labels" / split_name / dest_label_name
                    try:
                        shutil.copy(src_label, dest_label)
                    except OSError as e:  # includes FileNotFoundError
                        print(f"Error copying label {src_label}: {e}")

print("Dataset split completed successfully!")
248
+
249
+ # Create dataset.yaml file
250
def create_dataset_yaml():
    """Write ``dataset.yaml`` into DEST_DIR for the prepared dataset.

    The file records the absolute dataset root, the train/val/test image
    folders (the validation entry points at the training images, as the
    comment written next to it states) and one ``id: name`` line per entry
    of CLASS_LABELS.
    """
    config_lines = [
        "# YOLOv11 dataset config\n",
        f"path: {os.path.abspath(DEST_DIR)} # dataset root dir\n",
        "train: images/train # train images\n",
        "val: images/train # validation uses train images\n",
        "test: images/test # test images\n\n",
        "# Classes\n",
        "names:\n",
    ]
    config_lines.extend(
        f" {class_id}: {class_name}\n"
        for class_id, class_name in CLASS_LABELS.items()
    )

    with open(f"{DEST_DIR}/dataset.yaml", "w") as f:
        f.writelines(config_lines)


create_dataset_yaml()
264
+
265
# Analyze the distribution of the copied images per split.
#
# Copied images are named "<location>_<session>_<video>..._<frame>", and
# the configured location and session names contain underscores
# themselves (e.g. "location_1", "session_3"). Splitting the file name on
# "_" therefore cannot recover them; match the configured names instead.
known_pairs = set()
for split_locations in splits.values():
    for configured_location, configured_sessions in split_locations.items():
        for configured_session in configured_sessions:
            known_pairs.add((configured_location, configured_session))
# Longest prefix first, so a longer configured name wins over a shorter
# one that happens to be its prefix.
known_pairs = sorted(known_pairs, key=lambda pair: len(pair[0]) + len(pair[1]), reverse=True)

stats = {}

for split in ['train', 'test']:
    locations = {}      # images per location
    species_count = {}  # images per "<location>_<session>"
    total_count = 0

    img_dir = Path(DEST_DIR) / "images" / split
    if not os.path.exists(img_dir):
        # Keep zeroed entries for this split so the report below still works.
        print(f"Warning: Directory {img_dir} does not exist.")
    else:
        for img in os.listdir(img_dir):
            location = session = None
            for configured_location, configured_session in known_pairs:
                if img.startswith(f"{configured_location}_{configured_session}_"):
                    location, session = configured_location, configured_session
                    break

            if location is None:
                # Not produced from the current configuration: fall back
                # to the first two "_"-separated tokens of the name.
                parts = img.split('_')
                if len(parts) < 2:
                    continue
                location, session = parts[0], parts[1]

            # Count by location
            locations[location] = locations.get(location, 0) + 1

            # Count by location and session
            species_key = f"{location}_{session}"
            species_count[species_key] = species_count.get(species_key, 0) + 1

            total_count += 1

    stats[split] = {
        "total": total_count,
        "locations": locations,
        "species": species_count,
    }

# Print stats
for split, data in stats.items():
    print(f"\n{split.upper()} set:")
    print(f"Total images: {data['total']}")

    print("Distribution by location:")
    for loc, count in data["locations"].items():
        percentage = (count/data['total']*100) if data['total'] > 0 else 0
        print(f" - {loc}: {count} ({percentage:.1f}%)")

    print("\nDistribution by location_session:")
    for species_key, count in data["species"].items():
        percentage = (count/data['total']*100) if data['total'] > 0 else 0
        print(f" - {species_key}: {count} ({percentage:.1f}%)")

# Guard the ratio: with no images at all the division would raise
# ZeroDivisionError.
grand_total = stats['train']['total'] + stats['test']['total']
if grand_total > 0:
    print("\nOverall train/test ratio:",
          f"{stats['train']['total'] / grand_total:.1%}",
          f"/ {stats['test']['total'] / grand_total:.1%}")
else:
    print("\nOverall train/test ratio: n/a (no images found)")