WildlifeDatasets committed
Commit bb14d6a · unverified · 1 Parent(s): 7fddf9c

Added training scripts
training/segmentation_prepare.ipynb ADDED
@@ -0,0 +1,323 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "ebe0faa7",
+ "metadata": {},
+ "source": [
+ "This notebook prepares the datasets for training the turtle detection model. First, it goes through the SeaTurtleID2022 dataset and converts the existing masks into the YOLO format required by Ultralytics. Then it goes through the TurtlesOfSMSRC dataset, loads the masks created in the smsrc_prepare notebook and again converts them to the YOLO format. Finally, the metadata are merged and ready for the segmentation_train script, which first trains on SeaTurtleID2022 (underwater photos) and then fine-tunes on the combined dataset (with above-water photos added)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a2e66c17",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import json\n",
+ "import shutil\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from tqdm import tqdm\n",
+ "from wildlife_datasets.datasets import SeaTurtleID2022, TurtlesOfSMSRC\n",
+ "from wildlife_datasets.datasets.utils import find_images, parse_bbox_mask\n",
+ "from wildlife_datasets.splits import ClosedSetSplit\n",
+ "from turtle_detector import get_index, rle_to_yolo, uncompressed_rle_to_yolo"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "93be7212",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "root_out = '/data/wildlife_datasets/turtle-detector'\n",
+ "\n",
+ "for addition in ['images/train', 'images/val', 'labels/train', 'labels/val']:\n",
+ "    for dataset_name in ['SeaTurtleID2022', 'TurtlesOfSMSRC']:\n",
+ "        os.makedirs(os.path.join(root_out, addition, dataset_name), exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "14b3e193",
+ "metadata": {},
+ "source": [
+ "# SeaTurtleID2022"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4c664fa1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset_name = 'SeaTurtleID2022'\n",
+ "root = '/data/wildlife_datasets/data/SeaTurtleID2022'\n",
+ "\n",
+ "dataset = SeaTurtleID2022(root)\n",
+ "if dataset.df['path'].nunique() != len(dataset):\n",
+ "    raise ValueError('path is not unique')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "64391cd5",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "splitter = ClosedSetSplit(0.8)\n",
+ "idx_train, idx_test = splitter.split(dataset.df)[0]\n",
+ "idx_train += 1  # image ids in annotations.json are 1-based\n",
+ "idx_test += 1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "057b036e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "flipper_categories = {\n",
+ "    '': 0,\n",
+ "    'front_left': 2,\n",
+ "    'front_right': 3,\n",
+ "    'rear_left': 4,\n",
+ "    'rear_right': 5,\n",
+ "}\n",
+ "\n",
+ "root_ann = f'{root}/turtles-data/data'\n",
+ "with open(os.path.join(root_ann, 'annotations.json')) as file:\n",
+ "    annotations = json.load(file)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7ff91b8a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for ann_img in tqdm(annotations['images']):\n",
+ "    file_name = os.path.join(root_ann, ann_img['file_name'])\n",
+ "    if ann_img['id'] in idx_train:\n",
+ "        shutil.copy(file_name, f'{root_out}/images/train/{dataset_name}')\n",
+ "    elif ann_img['id'] in idx_test:\n",
+ "        shutil.copy(file_name, f'{root_out}/images/val/{dataset_name}')\n",
+ "    else:\n",
+ "        raise ValueError('Split wrong')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a5a19846",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for ann_ann in tqdm(annotations['annotations']):\n",
+ "    if ann_ann['category_id'] == 1:\n",
+ "        category_id = 0\n",
+ "    elif ann_ann['category_id'] == 3:\n",
+ "        category_id = 1\n",
+ "    else:\n",
+ "        location = ann_ann['attributes'].get('location', '')\n",
+ "        category_id = flipper_categories[location]\n",
+ "\n",
+ "    image_id = ann_ann['image_id']\n",
+ "    rle = ann_ann['segmentation']\n",
+ "    yolo_segments = uncompressed_rle_to_yolo(rle, class_id=category_id)\n",
+ "    ann_img = annotations['images'][image_id - 1]\n",
+ "    base_name = os.path.basename(ann_img['file_name'])\n",
+ "    base_name = os.path.splitext(base_name)[0] + '.txt'\n",
+ "\n",
+ "    if image_id != ann_img['id']:\n",
+ "        raise ValueError('Image ids are not ordered')\n",
+ "    if ann_img['id'] in idx_train:\n",
+ "        file_name = f'{root_out}/labels/train/{dataset_name}/{base_name}'\n",
+ "    elif ann_img['id'] in idx_test:\n",
+ "        file_name = f'{root_out}/labels/val/{dataset_name}/{base_name}'\n",
+ "    else:\n",
+ "        raise ValueError('Split wrong')\n",
+ "\n",
+ "    with open(file_name, 'a') as myfile:\n",
+ "        for yolo_segment in yolo_segments:\n",
+ "            myfile.write(yolo_segment + '\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "65a8b7ce",
+ "metadata": {},
+ "source": [
+ "# TurtlesOfSMSRC"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "33d88ac1",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset_name = 'TurtlesOfSMSRC'\n",
+ "root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
+ "\n",
+ "dataset = TurtlesOfSMSRC(root)\n",
+ "masks = pd.read_csv(f'{root}/masks.csv')\n",
+ "masks['mask'] = masks['mask'].apply(parse_bbox_mask)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e632972b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "splitter = ClosedSetSplit(0.8)\n",
+ "idx_train, idx_test = splitter.split(dataset.df)[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e6acc2bd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "annotation_categories = {\n",
+ "    'turtle': 0,\n",
+ "    'head': 1,\n",
+ "    'flipper_fl': 2,\n",
+ "    'flipper_fr': 3,\n",
+ "    'flipper_rl': 4,\n",
+ "    'flipper_rr': 5,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b3377df4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for image_id in tqdm(masks['image_id'].unique()):\n",
+ "    i = get_index(dataset, image_id)\n",
+ "    file_name = os.path.join(root, dataset.metadata.loc[i, 'path'])\n",
+ "    if i in idx_train:\n",
+ "        shutil.copy(file_name, f'{root_out}/images/train/{dataset_name}')\n",
+ "    elif i in idx_test:\n",
+ "        shutil.copy(file_name, f'{root_out}/images/val/{dataset_name}')\n",
+ "    else:\n",
+ "        raise ValueError('Split wrong')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d1eba0f2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for _, mask in tqdm(masks.iterrows(), total=len(masks)):\n",
+ "    category_id = annotation_categories[mask['label_side']]\n",
+ "    image_id = mask['image_id']\n",
+ "    rle = mask['mask']\n",
+ "    yolo_segments = rle_to_yolo(rle, class_id=category_id)\n",
+ "    i = get_index(dataset, image_id)\n",
+ "\n",
+ "    base_name = os.path.basename(dataset.metadata.loc[i, 'path'])\n",
+ "    base_name = os.path.splitext(base_name)[0] + '.txt'\n",
+ "\n",
+ "    if i in idx_train:\n",
+ "        file_name = f'{root_out}/labels/train/{dataset_name}/{base_name}'\n",
+ "    elif i in idx_test:\n",
+ "        file_name = f'{root_out}/labels/val/{dataset_name}/{base_name}'\n",
+ "    else:\n",
+ "        raise ValueError('Split wrong')\n",
+ "\n",
+ "    with open(file_name, 'a') as myfile:\n",
+ "        for yolo_segment in yolo_segments:\n",
+ "            myfile.write(yolo_segment + '\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0b6f6683",
+ "metadata": {},
+ "source": [
+ "# Create metadata"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d5908b52",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "n_repeat = {\n",
+ "    'SeaTurtleID2022': 1,\n",
+ "    'TurtlesOfSMSRC': 30,\n",
+ "}\n",
+ "\n",
+ "# First split and only then oversample to prevent train-test leakage\n",
+ "images = find_images(root_out)\n",
+ "images = root_out + '/' + images['path'] + '/' + images['file']\n",
+ "images_train = images[images.str.contains('/train/')]\n",
+ "images_test = images[images.str.contains('/val/')]\n",
+ "if len(images_train) + len(images_test) != len(images):\n",
+ "    raise ValueError('The split into train and test images failed.')\n",
+ "\n",
+ "# Oversample (even the test set)\n",
+ "idx_train = []\n",
+ "idx_test = []\n",
+ "for dataset_name in ['SeaTurtleID2022', 'TurtlesOfSMSRC']:\n",
+ "    idx_part = list(images_train[images_train.str.contains(dataset_name)].index)\n",
+ "    idx_train += n_repeat[dataset_name] * idx_part\n",
+ "    idx_part = list(images_test[images_test.str.contains(dataset_name)].index)\n",
+ "    idx_test += n_repeat[dataset_name] * idx_part\n",
+ "images_train = images_train.loc[idx_train]\n",
+ "images_test = images_test.loc[idx_test]\n",
+ "\n",
+ "# Save the oversampled splits\n",
+ "images_train.to_csv(f'{root_out}/train.txt', header=False, index=False)\n",
+ "images_test.to_csv(f'{root_out}/val.txt', header=False, index=False)\n",
+ "for dataset_name in ['SeaTurtleID2022', 'TurtlesOfSMSRC']:\n",
+ "    subset_train = images_train[images_train.str.contains(dataset_name)]\n",
+ "    subset_train.to_csv(f'{root_out}/train_{dataset_name}.txt', header=False, index=False)\n",
+ "    subset_test = images_test[images_test.str.contains(dataset_name)]\n",
+ "    subset_test.to_csv(f'{root_out}/val_{dataset_name}.txt', header=False, index=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "sam3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
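
Each label file written by this notebook holds one line per polygon in the Ultralytics segmentation format: a class id followed by normalized x y pairs, as produced by mask_to_yolo below. A minimal sanity check over one of the generated label folders, assuming the directory layout created in the first cell, might look like this sketch:

import os

root_out = '/data/wildlife_datasets/turtle-detector'
label_dir = f'{root_out}/labels/train/SeaTurtleID2022'

for name in os.listdir(label_dir):
    with open(os.path.join(label_dir, name)) as file:
        for line in file:
            parts = line.split()
            class_id = int(parts[0])
            coords = [float(x) for x in parts[1:]]
            # class ids 0-5: turtle, head and the four flippers
            assert 0 <= class_id <= 5
            # each polygon has at least three points, all normalized to [0, 1]
            assert len(coords) >= 6 and len(coords) % 2 == 0
            assert all(0.0 <= c <= 1.0 for c in coords)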
training/segmentation_stage1.yaml ADDED
@@ -0,0 +1,12 @@
+ path: /data/wildlife_datasets/turtle-detector
+ train: train_SeaTurtleID2022.txt
+ val: val_SeaTurtleID2022.txt
+
+ nc: 6
+ names:
+   0: turtle
+   1: head
+   2: flipper_fl
+   3: flipper_fr
+   4: flipper_rl
+   5: flipper_rr
training/segmentation_stage2.yaml ADDED
@@ -0,0 +1,12 @@
+ path: /data/wildlife_datasets/turtle-detector
+ train: train.txt
+ val: val.txt
+
+ nc: 6
+ names:
+   0: turtle
+   1: head
+   2: flipper_fl
+   3: flipper_fr
+   4: flipper_rl
+   5: flipper_rr
training/segmentation_train.py ADDED
@@ -0,0 +1,35 @@
+ import os
+ from ultralytics import YOLO
+
+ project = f"{os.getcwd()}/runs"
+ device = "cuda:2"
+ imgsz = 640
+ epochs = 20
+
+ # Stage 1: Pretrain on SeaTurtleID2022 (large dataset)
+ model = YOLO("yolo11s-seg.pt")
+ model.train(
+     data="segmentation_stage1.yaml",
+     project=project,
+     name="stage1",
+     epochs=epochs,
+     imgsz=imgsz,
+     device=device,
+     fliplr=0,  # flips disabled: the left/right flipper classes are orientation-dependent
+     flipud=0,
+ )
+
+ # Stage 2: Fine-tune on combined dataset (balanced)
+ model = YOLO(f"{project}/stage1/weights/last.pt")
+ model.train(
+     data="segmentation_stage2.yaml",
+     project=project,
+     name="stage2",
+     epochs=epochs,
+     imgsz=imgsz,
+     device=device,
+     fliplr=0,
+     flipud=0,
+     # Freeze the first 5 layers to preserve the stage-1 features
+     freeze=5,
+ )
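
Once both stages finish, the stage-2 weights can be checked with the standard Ultralytics calls. A minimal sketch, where the weights path follows the project/name arguments above and the image path and confidence threshold are placeholders:

from ultralytics import YOLO

model = YOLO("runs/stage2/weights/best.pt")

# Re-evaluate mask quality on the combined validation split
metrics = model.val(data="segmentation_stage2.yaml", imgsz=640)
print(metrics.seg.map)  # mask mAP50-95

# Run a single prediction; results[0].masks holds the predicted polygons
results = model.predict("some_turtle.jpg", imgsz=640, conf=0.25)
results[0].save(filename="some_turtle_pred.jpg")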
training/smsrc_prepare.ipynb ADDED
@@ -0,0 +1,182 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "926f340c",
+ "metadata": {},
+ "source": [
+ "The notebook prepares the SMSRC data for training of the turtle detector. It uses SAM3 to detect the turtle, its head and its flippers. Then it uses a heuristic to assign the left/right and front/rear orientation of each flipper. These assignments were manually checked and fixed where incorrect.\n",
+ "\n",
+ "The output of the notebook is the masks.csv file, which is then used in the segmentation_prepare notebook to create the training dataset for detection."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6774dc0c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
+ "from turtle_detector import assign_flippers, initialize_sam3, mask_to_rle, rle_to_mask, compute_iou, mask_to_bbox"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4c8a0449",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
+ "dataset = TurtlesOfSMSRC(root)\n",
+ "\n",
+ "idx_ranges = [\n",
+ "    (333582414, 333582440),\n",
+ "    (327367311, 327367335),\n",
+ "]\n",
+ "idx = np.zeros(len(dataset), dtype=bool)\n",
+ "for idx_min, idx_max in idx_ranges:\n",
+ "    encounter_id = dataset.metadata['encounter_id'].to_numpy()\n",
+ "    idx += (encounter_id >= idx_min) * (encounter_id <= idx_max)\n",
+ "\n",
+ "dataset = dataset.get_subset(idx)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "7a35f952",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model, processor = initialize_sam3()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2f521710",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "prompt_map = {\n",
+ "    \"head\": \"turtle head\",\n",
+ "    \"flipper\": \"turtle flipper\",\n",
+ "    \"turtle\": \"turtle\",\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "aef61bcb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "min_area = 500\n",
+ "iou_threshold = 0.1\n",
+ "\n",
+ "masks = []\n",
+ "for i in range(len(dataset)):\n",
+ "    image_path = f\"{dataset.root}/{dataset.metadata['path'].iloc[i]}\"\n",
+ "    image = dataset[i]\n",
+ "    inference_state = processor.set_image(image)\n",
+ "\n",
+ "    for label, prompt in prompt_map.items():\n",
+ "        processor.reset_all_prompts(inference_state)\n",
+ "        inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)\n",
+ "\n",
+ "        for m in inference_state[\"masks\"]:\n",
+ "            m = m.cpu().numpy().astype(bool)\n",
+ "            if m.ndim == 3 and m.shape[0] == 1:\n",
+ "                m = m[0]\n",
+ "            if m.sum() > min_area:\n",
+ "                masks.append({\n",
+ "                    'image_id': dataset.metadata['image_id'].loc[i],\n",
+ "                    'mask': mask_to_rle(m),\n",
+ "                    'label': label,\n",
+ "                })\n",
+ "masks = pd.DataFrame(masks)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d4a052cd",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "masks['keep'] = True\n",
+ "for _, masks_image in masks.groupby('image_id'):\n",
+ "    keep = masks_image['keep'].copy()\n",
+ "    for i, (j, mask_j) in enumerate(masks_image.iterrows()):\n",
+ "        for k, mask_k in masks_image.iloc[i+1:].iterrows():\n",
+ "            if not keep.loc[j] or not keep.loc[k]:\n",
+ "                continue\n",
+ "\n",
+ "            mj = rle_to_mask(masks.loc[j, 'mask'])\n",
+ "            mk = rle_to_mask(masks.loc[k, 'mask'])\n",
+ "\n",
+ "            iou = compute_iou(mj, mk)\n",
+ "            if iou < iou_threshold:\n",
+ "                continue\n",
+ "\n",
+ "            if mask_j['label'] == mask_k['label']:\n",
+ "                masks.at[j, 'mask'] = mask_to_rle(mj | mk)\n",
+ "                keep.loc[k] = False\n",
+ "\n",
+ "            elif {\"head\", \"flipper\"} == {mask_j['label'], mask_k['label']}:\n",
+ "                if (keep * (masks_image['label'] == 'head')).sum() == 1:\n",
+ "                    if mask_j['label'] == \"flipper\":\n",
+ "                        keep.loc[j] = False\n",
+ "                    else:\n",
+ "                        keep.loc[k] = False\n",
+ "                else:\n",
+ "                    if mask_j['label'] == \"head\":\n",
+ "                        keep.loc[j] = False\n",
+ "                    else:\n",
+ "                        keep.loc[k] = False\n",
+ "    masks.loc[masks_image.index, 'keep'] = keep\n",
+ "masks = masks[masks['keep']]\n",
+ "masks = masks.drop('keep', axis=1)\n",
+ "\n",
+ "for i, m in masks.iterrows():\n",
+ "    bbox = mask_to_bbox(rle_to_mask(m['mask']))\n",
+ "    x0, y0, x1, y1 = bbox\n",
+ "    masks.loc[i, 'bbox_x'] = x0\n",
+ "    masks.loc[i, 'bbox_y'] = y0\n",
+ "    masks.loc[i, 'bbox_w'] = x1 - x0\n",
+ "    masks.loc[i, 'bbox_h'] = y1 - y0\n",
+ "\n",
+ "for _, masks_image in masks.groupby('image_id'):\n",
+ "    masks.loc[masks_image.index, 'label_side'] = assign_flippers(masks_image)['label']\n",
+ "\n",
+ "masks.to_csv('masks.csv', index=False)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "sam3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
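
The overlap-resolution cell above keeps acting on a mask pair only when their IoU exceeds 0.1, merging same-label masks and discarding the conflicting one otherwise. The helpers it relies on behave as in this toy sketch, assuming the turtle_detector package from this commit is importable:

import numpy as np
from turtle_detector import compute_iou, mask_to_rle, rle_to_mask

a = np.zeros((10, 10), dtype=np.uint8)
b = np.zeros((10, 10), dtype=np.uint8)
a[2:6, 2:6] = 1  # 16-pixel square
b[4:8, 4:8] = 1  # overlapping 16-pixel square

print(compute_iou(a, b))  # 4 / 28 ~ 0.14, above the 0.1 threshold

# Same-label case: the notebook stores the union and drops one row
merged = mask_to_rle(a | b)
assert rle_to_mask(merged).sum() == 28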
training/smsrc_visualize.ipynb ADDED
@@ -0,0 +1,139 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "56e96915",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.patches as patches\n",
+ "\n",
+ "from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
+ "from wildlife_datasets.datasets.utils import parse_bbox_mask\n",
+ "from turtle_detector import rle_to_mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4c8a0449",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
+ "root_figures = 'figures'\n",
+ "dataset = TurtlesOfSMSRC(root)\n",
+ "masks = pd.read_csv('masks.csv')\n",
+ "masks['mask'] = masks['mask'].apply(parse_bbox_mask)\n",
+ "\n",
+ "os.makedirs(root_figures, exist_ok=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2f521710",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "colors_map = {\n",
+ "    \"head\": 0,\n",
+ "    \"flipper\": 1,\n",
+ "    \"turtle\": 2,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e82f9db7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for image_id, masks_image in masks.groupby('image_id'):\n",
+ "    i = np.where(dataset.metadata.image_id == image_id)[0][0]\n",
+ "    image = dataset[i]\n",
+ "    width, height = image.size\n",
+ "\n",
+ "    overlay = np.zeros((height, width, 3), dtype=np.float32)\n",
+ "    for _, m in masks_image.iterrows():\n",
+ "        mask_bool = rle_to_mask(m['mask']).astype(bool)\n",
+ "        overlay[mask_bool, colors_map[m['label']]] = 1.0\n",
+ "\n",
+ "    fig, ax = plt.subplots(figsize=(8, 8))\n",
+ "    plt.imshow(image)\n",
+ "    plt.imshow(overlay, alpha=0.5)\n",
+ "\n",
+ "    for _, m in masks_image.iterrows():\n",
+ "        rect = patches.Rectangle(\n",
+ "            (m['bbox_x'], m['bbox_y']),\n",
+ "            m['bbox_w'],\n",
+ "            m['bbox_h'],\n",
+ "            linewidth=2,\n",
+ "            edgecolor=\"white\",\n",
+ "            facecolor=\"none\"\n",
+ "        )\n",
+ "        ax.add_patch(rect)\n",
+ "        ax.text(\n",
+ "            m['bbox_x'],\n",
+ "            m['bbox_y'] - 3,\n",
+ "            m['label_side'],\n",
+ "            color=\"white\",\n",
+ "            fontsize=10,\n",
+ "            weight=\"bold\",\n",
+ "            bbox=dict(facecolor=\"black\", alpha=0.5, pad=2)\n",
+ "        )\n",
+ "\n",
+ "    n_head = (masks_image['label'] == 'head').sum()\n",
+ "    n_flipper = (masks_image['label'] == 'flipper').sum()\n",
+ "    n_turtle = (masks_image['label'] == 'turtle').sum()\n",
+ "\n",
+ "    plt.axis(\"off\")\n",
+ "    plt.title(f'{n_head}, {n_flipper}, {n_turtle}')\n",
+ "    plt.savefig(f'{root_figures}/{image_id}.png', bbox_inches='tight', dpi=600)\n",
+ "    plt.close()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "54035a2d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for image_id, masks_image in masks.groupby('image_id'):\n",
+ "    if masks_image['label_side'].value_counts().max() > 1:\n",
+ "        print(f'Image id {image_id} has multiple annotations.')\n",
+ "        display(masks_image)\n",
+ "display(masks['label'].value_counts())\n",
+ "display(masks['label_side'].value_counts())"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "sam3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }
training/turtle_detector/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .masks import *
+ from .utils import assign_flippers, get_index, initialize_sam3
training/turtle_detector/masks.py ADDED
@@ -0,0 +1,92 @@
+ import numpy as np
+ import cv2
+ import pycocotools.mask as mask_utils
+ from PIL import ImageDraw
+
+ def compute_iou(mask_a, mask_b):
+     intersection = np.logical_and(mask_a, mask_b).sum()
+     union = np.logical_or(mask_a, mask_b).sum()
+     return 0.0 if union == 0 else intersection / union
+
+ def mask_to_bbox(mask):
+     ys, xs = np.where(mask)
+     if len(xs) == 0:
+         return None
+     return xs.min(), ys.min(), xs.max(), ys.max()
+
+ def mask_to_rle(mask, json_safe=True):
+     rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
+     if json_safe:
+         rle["counts"] = rle["counts"].decode("ascii")
+     return rle
+
+ def rle_to_mask(rle):
+     rle = rle.copy()
+     if isinstance(rle["counts"], str):
+         rle["counts"] = rle["counts"].encode("ascii")
+     return mask_utils.decode(rle)
+
+ def uncompressed_rle_to_mask(rle):
+     """Decode COCO-style uncompressed RLE into a binary mask (0/1)."""
+     h, w = rle["size"]
+     counts = rle["counts"]
+
+     mask = np.zeros(h * w, dtype=np.uint8)
+     val = 0
+     idx = 0
+     for c in counts:
+         mask[idx:idx + c] = val
+         idx += c
+         val = 1 - val
+     mask = mask.reshape((h, w), order='F')
+     return mask
+
+ def mask_to_yolo(mask, class_id=0):
+     """Convert a binary mask (0/1) into YOLO polygon segmentation format."""
+     h, w = mask.shape
+
+     # ensure 8-bit binary mask
+     mask8 = (mask * 255).astype(np.uint8)
+
+     # find outer contours only
+     contours, _ = cv2.findContours(mask8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+     yolo_segments = []
+     for contour in contours:
+         if cv2.contourArea(contour) < 100:  # ignore tiny noise
+             continue
+
+         contour = contour.squeeze().astype(float)
+         if contour.ndim != 2:
+             continue
+
+         # normalize to [0, 1]
+         contour[:, 0] = contour[:, 0] / float(w)
+         contour[:, 1] = contour[:, 1] / float(h)
+
+         coords = contour.flatten().tolist()
+         yolo_segments.append(f"{class_id} " + " ".join(f"{x:.6f}" for x in coords))
+
+     return yolo_segments
+
+ def rle_to_yolo(rle, class_id=0):
+     mask = rle_to_mask(rle)
+     return mask_to_yolo(mask, class_id)
+
+ def uncompressed_rle_to_yolo(rle, class_id=0):
+     mask = uncompressed_rle_to_mask(rle)
+     return mask_to_yolo(mask, class_id)
+
+ def draw_yolo_on_pil(image, yolo_segments, color=(0, 255, 0)):
+     img = image.convert("RGB")
+     draw = ImageDraw.Draw(img)
+     w, h = img.size
+
+     for seg in yolo_segments:
+         parts = seg.strip().split()
+         class_id = int(parts[0])
+         coords = np.array([float(x) for x in parts[1:]]).reshape(-1, 2)
+         points = [(x * w, y * h) for x, y in coords]
+         draw.line(points + [points[0]], fill=color, width=2)
+
+     return img
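
A quick round trip through these helpers on a synthetic mask shows how the pieces fit together (illustrative values only):

import numpy as np
from turtle_detector.masks import mask_to_rle, rle_to_mask, mask_to_bbox, mask_to_yolo

# Synthetic 200x300 mask with a single filled rectangle
mask = np.zeros((200, 300), dtype=np.uint8)
mask[50:150, 100:250] = 1

rle = mask_to_rle(mask)  # JSON-safe COCO RLE
assert (rle_to_mask(rle) == mask).all()  # lossless round trip

print(mask_to_bbox(mask))  # (x_min, y_min, x_max, y_max) = (100, 50, 249, 149)

segments = mask_to_yolo(mask, class_id=0)
# One entry per outer contour: "0 x1 y1 x2 y2 ..." with coordinates
# normalized by the image width and height
print(len(segments), segments[0][:30])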
training/turtle_detector/utils.py ADDED
@@ -0,0 +1,96 @@
+ import os
+ import numpy as np
+ import sam3
+ from sam3 import build_sam3_image_model
+ from sam3.model.sam3_image_processor import Sam3Processor
+ from .masks import rle_to_mask
+
+ def get_index(dataset, image_id):
+     idx = dataset.metadata['image_id'] == image_id
+     if idx.sum() != 1:
+         raise ValueError('image_id not found or found multiple times.')
+     return dataset.metadata[idx].index[0]
+
+ def mask_centroid(mask):
+     ys, xs = np.nonzero(mask)
+     return np.array([xs.mean(), ys.mean()])
+
+ def rle_centroid(rle):
+     return mask_centroid(rle_to_mask(rle))
+
+ def assign_flippers(df):
+     df = df.copy()
+
+     # Check that there is only one head
+     head_rows = df[df['label'] == 'head']
+     if len(head_rows) != 1:
+         return df
+
+     # Compute the head centroid
+     head_center = rle_centroid(head_rows.iloc[0]['mask'])
+
+     # Extract the flippers
+     flippers = df[df['label'] == 'flipper']
+     n_flippers = len(flippers)
+     if n_flippers == 0:
+         return df
+
+     # Compute the flipper centroids
+     flipper_centers = np.vstack([
+         rle_centroid(rle) for rle in flippers['mask']
+     ])
+
+     # Vector from turtle center to head defines "forward"
+     turtle_center = flipper_centers.mean(axis=0)
+     forward_vec = head_center - turtle_center
+     forward_vec /= np.linalg.norm(forward_vec)
+
+     # Perpendicular defines left/right
+     left_vec = np.array([-forward_vec[1], forward_vec[0]])
+
+     # Project flipper centroids relative to the turtle center
+     forward_proj = (flipper_centers - turtle_center) @ forward_vec
+     lateral_proj = (flipper_centers - turtle_center) @ left_vec
+
+     if n_flippers <= 2:
+         # Always front flippers
+         order = np.argsort(lateral_proj)
+         left_idx, right_idx = order[0], order[-1]
+
+         df.loc[flippers.index[left_idx], 'label'] = 'flipper_fl'
+         df.loc[flippers.index[right_idx], 'label'] = 'flipper_fr'
+         return df
+     elif n_flippers <= 4:
+         # Sort by forward distance
+         order_fwd = np.argsort(forward_proj)
+         rear_idxs = order_fwd[:-2]  # one rear flipper when only three are visible
+         front_idxs = order_fwd[-2:]
+
+         # Front flippers
+         front_l = front_idxs[np.argmin(lateral_proj[front_idxs])]
+         front_r = front_idxs[np.argmax(lateral_proj[front_idxs])]
+
+         df.loc[flippers.index[front_l], 'label'] = 'flipper_fl'
+         df.loc[flippers.index[front_r], 'label'] = 'flipper_fr'
+
+         # Rear flippers (if present)
+         if len(rear_idxs) == 2:
+             rear_l = rear_idxs[np.argmin(lateral_proj[rear_idxs])]
+             rear_r = rear_idxs[np.argmax(lateral_proj[rear_idxs])]
+
+             df.loc[flippers.index[rear_l], 'label'] = 'flipper_rl'
+             df.loc[flippers.index[rear_r], 'label'] = 'flipper_rr'
+         else:
+             # 3 flippers: assign only the most rear one
+             idx = rear_idxs[0]
+             side = 'l' if lateral_proj[idx] < 0 else 'r'
+             df.loc[flippers.index[idx], 'label'] = f'flipper_r{side}'
+
+     return df
+
+ def initialize_sam3():
+     sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
+     bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
+     model = build_sam3_image_model(bpe_path=bpe_path)
+     processor = Sam3Processor(model, confidence_threshold=0.5)
+     return model, processor
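
The assign_flippers heuristic can be exercised on synthetic blobs. A toy example with a head at the top of the frame and one flipper on each side (positions are arbitrary, and the turtle_detector package is assumed importable):

import numpy as np
import pandas as pd
from turtle_detector import assign_flippers, mask_to_rle

def blob(y, x, size=10, shape=(100, 100)):
    # Square blob at (y, x), returned as a COCO RLE as the heuristic expects
    m = np.zeros(shape, dtype=np.uint8)
    m[y:y + size, x:x + size] = 1
    return mask_to_rle(m)

df = pd.DataFrame([
    {'label': 'head',    'mask': blob(10, 45)},
    {'label': 'flipper', 'mask': blob(50, 10)},
    {'label': 'flipper', 'mask': blob(50, 80)},
])

# Two flippers are treated as the front pair and split along the lateral axis
print(assign_flippers(df)['label'].tolist())
# ['head', 'flipper_fl', 'flipper_fr']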