Added training scripts
Browse files- training/segmentation_prepare.ipynb +323 -0
- training/segmentation_stage1.yaml +12 -0
- training/segmentation_stage2.yaml +12 -0
- training/segmentation_train.py +35 -0
- training/smsrc_prepare.ipynb +182 -0
- training/smsrc_visualize.ipynb +139 -0
- training/turtle_detector/__init__.py +2 -0
- training/turtle_detector/masks.py +92 -0
- training/turtle_detector/utils.py +96 -0
training/segmentation_prepare.ipynb
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "ebe0faa7",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"This notebook prepares the datasets for training of the turtle detection model. First, it goes through the SeaTurtleID2022 dataset and converts the existing masks into the YOLO format needed by Ultralytics. Then it goes through the TurtlesOfSMSRC dataset, loads the masks created in the smsrc_prepare notebook and again, converts the masks to the YOLO format. Finally, the metadata are merged together and are ready for use by the segmentation_train script, which first trains on SeaTurtleID2022 (photos below water) and then finetunes on the combined dataset (photos above water were added)."
|
| 9 |
+
]
|
| 10 |
+
},
|
| 11 |
+
{
|
| 12 |
+
"cell_type": "code",
|
| 13 |
+
"execution_count": null,
|
| 14 |
+
"id": "a2e66c17",
|
| 15 |
+
"metadata": {},
|
| 16 |
+
"outputs": [],
|
| 17 |
+
"source": [
|
| 18 |
+
"import os\n",
|
| 19 |
+
"import json\n",
|
| 20 |
+
"import shutil\n",
|
| 21 |
+
"import numpy as np\n",
|
| 22 |
+
"import pandas as pd\n",
|
| 23 |
+
"from tqdm import tqdm\n",
|
| 24 |
+
"from wildlife_datasets.datasets import SeaTurtleID2022, TurtlesOfSMSRC\n",
|
| 25 |
+
"from wildlife_datasets.datasets.utils import find_images, parse_bbox_mask\n",
|
| 26 |
+
"from wildlife_datasets.splits import ClosedSetSplit\n",
|
| 27 |
+
"from turtle_detector import get_index, rle_to_yolo, uncompressed_rle_to_yolo"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": null,
|
| 33 |
+
"id": "93be7212",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [],
|
| 36 |
+
"source": [
|
| 37 |
+
"root_out = f'/data/wildlife_datasets/turtle-detector'\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"for addition in ['images/train', 'images/val', 'labels/train', 'labels/val']:\n",
|
| 40 |
+
" for dataset_name in ['SeaTurtleID2022', 'TurtlesOfSMSRC']:\n",
|
| 41 |
+
" os.makedirs(os.path.join(root_out, addition, dataset_name), exist_ok=True)"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "markdown",
|
| 46 |
+
"id": "14b3e193",
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"source": [
|
| 49 |
+
"# SeaTurtleID2022"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "code",
|
| 54 |
+
"execution_count": null,
|
| 55 |
+
"id": "4c664fa1",
|
| 56 |
+
"metadata": {},
|
| 57 |
+
"outputs": [],
|
| 58 |
+
"source": [
|
| 59 |
+
"dataset_name = 'SeaTurtleID2022'\n",
|
| 60 |
+
"root = '/data/wildlife_datasets/data/SeaTurtleID2022'\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"dataset = SeaTurtleID2022(root)\n",
|
| 63 |
+
"if dataset.df['path'].nunique() != len(dataset):\n",
|
| 64 |
+
" raise ValueError('path is not unique')"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"execution_count": null,
|
| 70 |
+
"id": "64391cd5",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"outputs": [],
|
| 73 |
+
"source": [
|
| 74 |
+
"splitter = ClosedSetSplit(0.8)\n",
|
| 75 |
+
"idx_train, idx_test = splitter.split(dataset.df)[0]\n",
|
| 76 |
+
"idx_train += 1\n",
|
| 77 |
+
"idx_test += 1"
|
| 78 |
+
]
|
| 79 |
+
},
|
| 80 |
+
{
|
| 81 |
+
"cell_type": "code",
|
| 82 |
+
"execution_count": null,
|
| 83 |
+
"id": "057b036e",
|
| 84 |
+
"metadata": {},
|
| 85 |
+
"outputs": [],
|
| 86 |
+
"source": [
|
| 87 |
+
"flipper_categories = {\n",
|
| 88 |
+
" '': 0,\n",
|
| 89 |
+
" 'front_left': 2,\n",
|
| 90 |
+
" 'front_right': 3,\n",
|
| 91 |
+
" 'rear_left': 4,\n",
|
| 92 |
+
" 'rear_right': 5,\n",
|
| 93 |
+
"}\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"root_ann = f'{root}/turtles-data/data'\n",
|
| 96 |
+
"with open(os.path.join(root_ann, 'annotations.json')) as file:\n",
|
| 97 |
+
" annotations = json.load(file)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"id": "7ff91b8a",
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"for ann_img in tqdm(annotations['images']):\n",
|
| 108 |
+
" file_name = os.path.join(root_ann, ann_img['file_name'])\n",
|
| 109 |
+
" if ann_img['id'] in idx_train:\n",
|
| 110 |
+
" shutil.copy(file_name, f'{root_out}/images/train/{dataset_name}')\n",
|
| 111 |
+
" elif ann_img['id'] in idx_test:\n",
|
| 112 |
+
" shutil.copy(file_name, f'{root_out}/images/val/{dataset_name}')\n",
|
| 113 |
+
" else:\n",
|
| 114 |
+
" raise ValueError('Split wrong')"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "code",
|
| 119 |
+
"execution_count": null,
|
| 120 |
+
"id": "a5a19846",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"outputs": [],
|
| 123 |
+
"source": [
|
| 124 |
+
"for ann_ann in tqdm(annotations['annotations']):\n",
|
| 125 |
+
" if ann_ann['category_id'] == 1:\n",
|
| 126 |
+
" category_id = 0\n",
|
| 127 |
+
" elif ann_ann['category_id'] == 3:\n",
|
| 128 |
+
" category_id = 1\n",
|
| 129 |
+
" else:\n",
|
| 130 |
+
" location = ann_ann['attributes'].get('location', '')\n",
|
| 131 |
+
" category_id = flipper_categories[location]\n",
|
| 132 |
+
"\n",
|
| 133 |
+
" image_id = ann_ann['image_id']\n",
|
| 134 |
+
" rle = ann_ann['segmentation'] \n",
|
| 135 |
+
" yolo_segments = uncompressed_rle_to_yolo(rle, class_id=category_id)\n",
|
| 136 |
+
" ann_img = annotations['images'][image_id - 1]\n",
|
| 137 |
+
" base_name = os.path.basename(ann_img['file_name'])\n",
|
| 138 |
+
" base_name = os.path.splitext(base_name)[0] + '.txt'\n",
|
| 139 |
+
"\n",
|
| 140 |
+
" if image_id != ann_img['id']:\n",
|
| 141 |
+
" raise ValueError('Image ids are not ordered')\n",
|
| 142 |
+
" if ann_img['id'] in idx_train:\n",
|
| 143 |
+
" file_name = f'{root_out}/labels/train/{dataset_name}/{base_name}'\n",
|
| 144 |
+
" elif ann_img['id'] in idx_test:\n",
|
| 145 |
+
" file_name = f'{root_out}/labels/val/{dataset_name}/{base_name}'\n",
|
| 146 |
+
" else:\n",
|
| 147 |
+
" raise ValueError('Split wrong')\n",
|
| 148 |
+
"\n",
|
| 149 |
+
" with open(file_name, 'a') as myfile:\n",
|
| 150 |
+
" for yolo_segment in yolo_segments:\n",
|
| 151 |
+
" myfile.write(yolo_segment + '\\n')"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "markdown",
|
| 156 |
+
"id": "65a8b7ce",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"source": [
|
| 159 |
+
"# TurtlesOfSMSRC"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"cell_type": "code",
|
| 164 |
+
"execution_count": null,
|
| 165 |
+
"id": "33d88ac1",
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"outputs": [],
|
| 168 |
+
"source": [
|
| 169 |
+
"dataset_name = 'TurtlesOfSMSRC'\n",
|
| 170 |
+
"root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"dataset = TurtlesOfSMSRC(root)\n",
|
| 173 |
+
"masks = pd.read_csv(f'{root}/masks.csv')\n",
|
| 174 |
+
"masks['mask'] = masks['mask'].apply(parse_bbox_mask)"
|
| 175 |
+
]
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"cell_type": "code",
|
| 179 |
+
"execution_count": null,
|
| 180 |
+
"id": "e632972b",
|
| 181 |
+
"metadata": {},
|
| 182 |
+
"outputs": [],
|
| 183 |
+
"source": [
|
| 184 |
+
"splitter = ClosedSetSplit(0.8)\n",
|
| 185 |
+
"idx_train, idx_test = splitter.split(dataset.df)[0]"
|
| 186 |
+
]
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"cell_type": "code",
|
| 190 |
+
"execution_count": null,
|
| 191 |
+
"id": "e6acc2bd",
|
| 192 |
+
"metadata": {},
|
| 193 |
+
"outputs": [],
|
| 194 |
+
"source": [
|
| 195 |
+
"annotation_categories = {\n",
|
| 196 |
+
" 'turtle': 0,\n",
|
| 197 |
+
" 'head': 1,\n",
|
| 198 |
+
" 'flipper_fl': 2,\n",
|
| 199 |
+
" 'flipper_fr': 3,\n",
|
| 200 |
+
" 'flipper_rl': 4,\n",
|
| 201 |
+
" 'flipper_rr': 5,\n",
|
| 202 |
+
"}"
|
| 203 |
+
]
|
| 204 |
+
},
|
| 205 |
+
{
|
| 206 |
+
"cell_type": "code",
|
| 207 |
+
"execution_count": null,
|
| 208 |
+
"id": "b3377df4",
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"outputs": [],
|
| 211 |
+
"source": [
|
| 212 |
+
"for image_id in tqdm(masks['image_id'].unique()):\n",
|
| 213 |
+
" i = get_index(dataset, image_id)\n",
|
| 214 |
+
" file_name = os.path.join(root, dataset.metadata.loc[i, 'path'])\n",
|
| 215 |
+
" if i in idx_train:\n",
|
| 216 |
+
" shutil.copy(file_name, f'{root_out}/images/train/{dataset_name}')\n",
|
| 217 |
+
" elif i in idx_test:\n",
|
| 218 |
+
" shutil.copy(file_name, f'{root_out}/images/val/{dataset_name}')\n",
|
| 219 |
+
" else:\n",
|
| 220 |
+
" raise ValueError('Split wrong')"
|
| 221 |
+
]
|
| 222 |
+
},
|
| 223 |
+
{
|
| 224 |
+
"cell_type": "code",
|
| 225 |
+
"execution_count": null,
|
| 226 |
+
"id": "d1eba0f2",
|
| 227 |
+
"metadata": {},
|
| 228 |
+
"outputs": [],
|
| 229 |
+
"source": [
|
| 230 |
+
"for _, mask in tqdm(masks.iterrows(), total=len(masks)):\n",
|
| 231 |
+
" category_id = annotation_categories[mask['label_side']]\n",
|
| 232 |
+
" image_id = mask['image_id']\n",
|
| 233 |
+
" rle = mask['mask'] \n",
|
| 234 |
+
" yolo_segments = rle_to_yolo(rle, class_id=category_id)\n",
|
| 235 |
+
" i = get_index(dataset, image_id)\n",
|
| 236 |
+
"\n",
|
| 237 |
+
" base_name = os.path.basename(dataset.metadata.loc[i, 'path'])\n",
|
| 238 |
+
" base_name = os.path.splitext(base_name)[0] + '.txt'\n",
|
| 239 |
+
"\n",
|
| 240 |
+
" if i in idx_train:\n",
|
| 241 |
+
" file_name = f'{root_out}/labels/train/{dataset_name}/{base_name}'\n",
|
| 242 |
+
" elif i in idx_test:\n",
|
| 243 |
+
" file_name = f'{root_out}/labels/val/{dataset_name}/{base_name}'\n",
|
| 244 |
+
" else:\n",
|
| 245 |
+
" raise ValueError('Split wrong')\n",
|
| 246 |
+
"\n",
|
| 247 |
+
" with open(file_name, 'a') as myfile:\n",
|
| 248 |
+
" for yolo_segment in yolo_segments:\n",
|
| 249 |
+
" myfile.write(yolo_segment + '\\n')"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "markdown",
|
| 254 |
+
"id": "0b6f6683",
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"source": [
|
| 257 |
+
"# Create metadata"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "code",
|
| 262 |
+
"execution_count": null,
|
| 263 |
+
"id": "d5908b52",
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"n_repeat = {\n",
|
| 268 |
+
" 'SeaTurtleID2022': 1,\n",
|
| 269 |
+
" 'TurtlesOfSMSRC': 30,\n",
|
| 270 |
+
"}\n",
|
| 271 |
+
"\n",
|
| 272 |
+
"# First split and only then oversample to prevent train-test leak\n",
|
| 273 |
+
"images = find_images(root_out)\n",
|
| 274 |
+
"images = root_out + '/' + images['path'] + '/' + images['file']\n",
|
| 275 |
+
"images_train = images[images.str.contains('/train/')]\n",
|
| 276 |
+
"images_test = images[images.str.contains('/val/')]\n",
|
| 277 |
+
"if len(images_train) + len(images_test) != len(images):\n",
|
| 278 |
+
" raise ValueError('The split into train and test images failed.')\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"# Oversample (even the test set)\n",
|
| 281 |
+
"idx_train = []\n",
|
| 282 |
+
"idx_test = []\n",
|
| 283 |
+
"for dataset_name in ['SeaTurtleID2022', 'TurtlesOfSMSRC']:\n",
|
| 284 |
+
" idx_part = list(images_train[images_train.str.contains(dataset_name)].index)\n",
|
| 285 |
+
" idx_train += n_repeat[dataset_name] * idx_part\n",
|
| 286 |
+
" idx_part = list(images_test[images_test.str.contains(dataset_name)].index)\n",
|
| 287 |
+
" idx_test += n_repeat[dataset_name] * idx_part\n",
|
| 288 |
+
"images_train = images_train.loc[idx_train]\n",
|
| 289 |
+
"images_test = images_test.loc[idx_test]\n",
|
| 290 |
+
"\n",
|
| 291 |
+
"# Save the oversampled splits\n",
|
| 292 |
+
"images_train.to_csv(f'{root_out}/train.txt', header=False, index=False)\n",
|
| 293 |
+
"images_test.to_csv(f'{root_out}/val.txt', header=False, index=False)\n",
|
| 294 |
+
"for dataset_name in ['SeaTurtleID2022', 'TurtlesOfSMSRC']:\n",
|
| 295 |
+
" subset_train = images_train[images_train.str.contains(dataset_name)]\n",
|
| 296 |
+
" subset_train.to_csv(f'{root_out}/train_{dataset_name}.txt', header=False, index=False)\n",
|
| 297 |
+
" subset_test = images_test[images_test.str.contains(dataset_name)]\n",
|
| 298 |
+
" subset_test.to_csv(f'{root_out}/val_{dataset_name}.txt', header=False, index=False)"
|
| 299 |
+
]
|
| 300 |
+
}
|
| 301 |
+
],
|
| 302 |
+
"metadata": {
|
| 303 |
+
"kernelspec": {
|
| 304 |
+
"display_name": "sam3",
|
| 305 |
+
"language": "python",
|
| 306 |
+
"name": "python3"
|
| 307 |
+
},
|
| 308 |
+
"language_info": {
|
| 309 |
+
"codemirror_mode": {
|
| 310 |
+
"name": "ipython",
|
| 311 |
+
"version": 3
|
| 312 |
+
},
|
| 313 |
+
"file_extension": ".py",
|
| 314 |
+
"mimetype": "text/x-python",
|
| 315 |
+
"name": "python",
|
| 316 |
+
"nbconvert_exporter": "python",
|
| 317 |
+
"pygments_lexer": "ipython3",
|
| 318 |
+
"version": "3.12.12"
|
| 319 |
+
}
|
| 320 |
+
},
|
| 321 |
+
"nbformat": 4,
|
| 322 |
+
"nbformat_minor": 5
|
| 323 |
+
}
|
training/segmentation_stage1.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics data config for training stage 1: pretraining on the
# SeaTurtleID2022 subset only (see segmentation_train.py).
path: /data/wildlife_datasets/turtle-detector
train: train_SeaTurtleID2022.txt
val: val_SeaTurtleID2022.txt

# Segmentation classes; must match segmentation_stage2.yaml and the
# category ids written by segmentation_prepare.ipynb.
nc: 6
names:
  0: turtle
  1: head
  2: flipper_fl
  3: flipper_fr
  4: flipper_rl
  5: flipper_rr
|
training/segmentation_stage2.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Ultralytics data config for training stage 2: fine-tuning on the
# combined, oversampled dataset (see segmentation_train.py).
path: /data/wildlife_datasets/turtle-detector
train: train.txt
val: val.txt

# Segmentation classes; must match segmentation_stage1.yaml and the
# category ids written by segmentation_prepare.ipynb.
nc: 6
names:
  0: turtle
  1: head
  2: flipper_fl
  3: flipper_fr
  4: flipper_rl
  5: flipper_rr
|
training/segmentation_train.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Two-stage training of the turtle segmentation model.

Stage 1 pretrains a YOLO11 segmentation model on SeaTurtleID2022 (the
large dataset); stage 2 fine-tunes the stage-1 weights on the combined
dataset described by segmentation_stage2.yaml.
"""
import os

from ultralytics import YOLO

project = f"{os.getcwd()}/runs"
device = "cuda:2"
imgsz = 640
epochs = 20

# Options shared by both stages, gathered in one place so the stages
# cannot drift apart. Horizontal/vertical flips are disabled because the
# flipper classes are side-specific (flipper_fl vs flipper_fr etc.);
# flipping an image would invalidate those labels.
common_kwargs = dict(
    project=project,
    epochs=epochs,
    imgsz=imgsz,
    device=device,
    fliplr=0,
    flipud=0,
)

# Stage 1: Pretrain on SeaTurtleID2022 (large dataset)
model = YOLO("yolo11s-seg.pt")
model.train(
    data="segmentation_stage1.yaml",
    name="stage1",
    **common_kwargs,
)

# Stage 2: Fine-tune on combined dataset (balanced)
# freeze=5 keeps the first backbone layers at their stage-1 weights.
model = YOLO(f"{project}/stage1/weights/last.pt")
model.train(
    data="segmentation_stage2.yaml",
    name="stage2",
    freeze=5,
    **common_kwargs,
)
|
training/smsrc_prepare.ipynb
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "926f340c",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"The notebook prepares the SMSRC data for training of the turtle detector. It uses SAM3 to detect the turtle, its head and flippers. Then it uses a heuristic to assign the left/right and front/rear orientation of the flipper. These assignments were manually checked and fixed when not correct.\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"The output of the notebook is the masks.csv file, which is then used in the segmentation_prepare notebook to create the training dataset for detection."
|
| 11 |
+
]
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"cell_type": "code",
|
| 15 |
+
"execution_count": null,
|
| 16 |
+
"id": "6774dc0c",
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"outputs": [],
|
| 19 |
+
"source": [
|
| 20 |
+
"import numpy as np\n",
|
| 21 |
+
"import pandas as pd\n",
|
| 22 |
+
"from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
|
| 23 |
+
"from turtle_detector import assign_flippers, initialize_sam3, mask_to_rle, rle_to_mask, compute_iou, mask_to_bbox"
|
| 24 |
+
]
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"cell_type": "code",
|
| 28 |
+
"execution_count": null,
|
| 29 |
+
"id": "4c8a0449",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"outputs": [],
|
| 32 |
+
"source": [
|
| 33 |
+
"root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
|
| 34 |
+
"dataset = TurtlesOfSMSRC(root)\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"idx_ranges = [\n",
|
| 37 |
+
" (333582414, 333582440),\n",
|
| 38 |
+
" (327367311, 327367335),\n",
|
| 39 |
+
"]\n",
|
| 40 |
+
"idx = np.zeros(len(dataset), dtype=bool)\n",
|
| 41 |
+
"for idx_min, idx_max in idx_ranges:\n",
|
| 42 |
+
" encounter_id = dataset.metadata['encounter_id'].to_numpy()\n",
|
| 43 |
+
" idx += (encounter_id >= idx_min) * (encounter_id <= idx_max)\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"dataset = dataset.get_subset(idx)"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "code",
|
| 50 |
+
"execution_count": null,
|
| 51 |
+
"id": "7a35f952",
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"model, processor = initialize_sam3()"
|
| 56 |
+
]
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"cell_type": "code",
|
| 60 |
+
"execution_count": null,
|
| 61 |
+
"id": "2f521710",
|
| 62 |
+
"metadata": {},
|
| 63 |
+
"outputs": [],
|
| 64 |
+
"source": [
|
| 65 |
+
"prompt_map = {\n",
|
| 66 |
+
" \"head\": \"turtle head\",\n",
|
| 67 |
+
" \"flipper\": \"turtle flipper\",\n",
|
| 68 |
+
" \"turtle\": \"turtle\",\n",
|
| 69 |
+
"}"
|
| 70 |
+
]
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"id": "aef61bcb",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"outputs": [],
|
| 78 |
+
"source": [
|
| 79 |
+
"min_area = 500\n",
|
| 80 |
+
"iou_threshold = 0.1\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"masks = []\n",
|
| 83 |
+
"for i in range(len(dataset)):\n",
|
| 84 |
+
" image_path = f\"{dataset.root}/{dataset.metadata['path'].iloc[i]}\"\n",
|
| 85 |
+
" image = dataset[i]\n",
|
| 86 |
+
" inference_state = processor.set_image(image)\n",
|
| 87 |
+
"\n",
|
| 88 |
+
" for label, prompt in prompt_map.items():\n",
|
| 89 |
+
" processor.reset_all_prompts(inference_state)\n",
|
| 90 |
+
" inference_state = processor.set_text_prompt(state=inference_state, prompt=prompt)\n",
|
| 91 |
+
"\n",
|
| 92 |
+
" for m in inference_state[\"masks\"]:\n",
|
| 93 |
+
" m = m.cpu().numpy().astype(bool)\n",
|
| 94 |
+
" if m.ndim == 3 and m.shape[0] == 1:\n",
|
| 95 |
+
" m = m[0]\n",
|
| 96 |
+
" if m.sum() > min_area:\n",
|
| 97 |
+
" masks.append({\n",
|
| 98 |
+
" 'image_id': dataset.metadata['image_id'].loc[i],\n",
|
| 99 |
+
" 'mask': mask_to_rle(m),\n",
|
| 100 |
+
" 'label': label,\n",
|
| 101 |
+
" })\n",
|
| 102 |
+
"masks = pd.DataFrame(masks)"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "code",
|
| 107 |
+
"execution_count": null,
|
| 108 |
+
"id": "d4a052cd",
|
| 109 |
+
"metadata": {},
|
| 110 |
+
"outputs": [],
|
| 111 |
+
"source": [
|
| 112 |
+
"masks['keep'] = True\n",
|
| 113 |
+
"for _, masks_image in masks.groupby('image_id'):\n",
|
| 114 |
+
" keep = masks_image['keep'].copy()\n",
|
| 115 |
+
" for i, (j, mask_j) in enumerate(masks_image.iterrows()):\n",
|
| 116 |
+
" for k, mask_k in masks_image.iloc[i+1:].iterrows(): \n",
|
| 117 |
+
" if not keep.loc[j] or not keep.loc[k]:\n",
|
| 118 |
+
" continue\n",
|
| 119 |
+
" \n",
|
| 120 |
+
" mj = rle_to_mask(masks.loc[j, 'mask'])\n",
|
| 121 |
+
" mk = rle_to_mask(masks.loc[k, 'mask'])\n",
|
| 122 |
+
"\n",
|
| 123 |
+
" iou = compute_iou(mj, mk)\n",
|
| 124 |
+
" if iou < iou_threshold:\n",
|
| 125 |
+
" continue\n",
|
| 126 |
+
"\n",
|
| 127 |
+
" if mask_j['label'] == mask_k['label']:\n",
|
| 128 |
+
" masks.at[j, 'mask'] = mask_to_rle(mj | mk)\n",
|
| 129 |
+
" keep.loc[k] = False\n",
|
| 130 |
+
"\n",
|
| 131 |
+
" elif {\"head\", \"flipper\"} == {mask_j['label'], mask_k['label']}:\n",
|
| 132 |
+
" if (keep * (masks_image['label'] == 'head')).sum() == 1:\n",
|
| 133 |
+
" if mask_j['label'] == \"flipper\":\n",
|
| 134 |
+
" keep.loc[j] = False\n",
|
| 135 |
+
" else:\n",
|
| 136 |
+
" keep.loc[k] = False\n",
|
| 137 |
+
" else:\n",
|
| 138 |
+
" if mask_j['label'] == \"head\":\n",
|
| 139 |
+
" keep.loc[j] = False\n",
|
| 140 |
+
" else:\n",
|
| 141 |
+
" keep.loc[k] = False\n",
|
| 142 |
+
" masks.loc[masks_image.index, 'keep'] = keep\n",
|
| 143 |
+
"masks = masks[masks['keep']]\n",
|
| 144 |
+
"masks = masks.drop('keep', axis=1)\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"for i, m in masks.iterrows():\n",
|
| 147 |
+
" bbox = mask_to_bbox(rle_to_mask(m['mask']))\n",
|
| 148 |
+
" x0, y0, x1, y1 = bbox\n",
|
| 149 |
+
" masks.loc[i, 'bbox_x'] = x0\n",
|
| 150 |
+
" masks.loc[i, 'bbox_y'] = y0\n",
|
| 151 |
+
" masks.loc[i, 'bbox_w'] = x1 - x0\n",
|
| 152 |
+
" masks.loc[i, 'bbox_h'] = y1 - y0\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"for _, masks_image in masks.groupby('image_id'):\n",
|
| 155 |
+
" masks.loc[masks_image.index, 'label_side'] = assign_flippers(masks_image)['label']\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"masks.to_csv('masks.csv', index=False)"
|
| 158 |
+
]
|
| 159 |
+
}
|
| 160 |
+
],
|
| 161 |
+
"metadata": {
|
| 162 |
+
"kernelspec": {
|
| 163 |
+
"display_name": "sam3",
|
| 164 |
+
"language": "python",
|
| 165 |
+
"name": "python3"
|
| 166 |
+
},
|
| 167 |
+
"language_info": {
|
| 168 |
+
"codemirror_mode": {
|
| 169 |
+
"name": "ipython",
|
| 170 |
+
"version": 3
|
| 171 |
+
},
|
| 172 |
+
"file_extension": ".py",
|
| 173 |
+
"mimetype": "text/x-python",
|
| 174 |
+
"name": "python",
|
| 175 |
+
"nbconvert_exporter": "python",
|
| 176 |
+
"pygments_lexer": "ipython3",
|
| 177 |
+
"version": "3.12.12"
|
| 178 |
+
}
|
| 179 |
+
},
|
| 180 |
+
"nbformat": 4,
|
| 181 |
+
"nbformat_minor": 5
|
| 182 |
+
}
|
training/smsrc_visualize.ipynb
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"id": "56e96915",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [],
|
| 9 |
+
"source": [
|
| 10 |
+
"import os\n",
|
| 11 |
+
"import numpy as np\n",
|
| 12 |
+
"import pandas as pd\n",
|
| 13 |
+
"import matplotlib.pyplot as plt\n",
|
| 14 |
+
"import matplotlib.patches as patches\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"from wildlife_datasets.datasets import TurtlesOfSMSRC\n",
|
| 17 |
+
"from wildlife_datasets.datasets.utils import parse_bbox_mask\n",
|
| 18 |
+
"from turtle_detector import assign_flippers, initialize_sam3, mask_to_rle, rle_to_mask, compute_iou, mask_to_bbox"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
+
"id": "4c8a0449",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"outputs": [],
|
| 27 |
+
"source": [
|
| 28 |
+
"root = '/data/wildlife_datasets/TurtlesOfSMSRC'\n",
|
| 29 |
+
"root_figures = 'figures'\n",
|
| 30 |
+
"dataset = TurtlesOfSMSRC(root)\n",
|
| 31 |
+
"masks = pd.read_csv('masks.csv')\n",
|
| 32 |
+
"masks['mask'] = masks['mask'].apply(parse_bbox_mask)\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"os.makedirs(root_figures, exist_ok=True)"
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": null,
|
| 40 |
+
"id": "2f521710",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"colors_map = {\n",
|
| 45 |
+
" \"head\": 0,\n",
|
| 46 |
+
" \"flipper\": 1,\n",
|
| 47 |
+
" \"turtle\": 2,\n",
|
| 48 |
+
"}"
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
{
|
| 52 |
+
"cell_type": "code",
|
| 53 |
+
"execution_count": null,
|
| 54 |
+
"id": "e82f9db7",
|
| 55 |
+
"metadata": {},
|
| 56 |
+
"outputs": [],
|
| 57 |
+
"source": [
|
| 58 |
+
"for image_id, masks_image in masks.groupby('image_id'):\n",
|
| 59 |
+
" i = np.where(dataset.metadata.image_id == image_id)[0][0]\n",
|
| 60 |
+
" image = dataset[i]\n",
|
| 61 |
+
" width, height = image.size\n",
|
| 62 |
+
"\n",
|
| 63 |
+
" overlay = np.zeros((height, width, 3), dtype=np.float32)\n",
|
| 64 |
+
" for _, m in masks_image.iterrows():\n",
|
| 65 |
+
" mask_bool = rle_to_mask(m['mask']).astype(bool)\n",
|
| 66 |
+
" overlay[mask_bool, colors_map[m['label']]] = 1.0\n",
|
| 67 |
+
"\n",
|
| 68 |
+
" fig, ax = plt.subplots(figsize=(8, 8))\n",
|
| 69 |
+
" plt.imshow(image)\n",
|
| 70 |
+
" plt.imshow(overlay, alpha=0.5)\n",
|
| 71 |
+
"\n",
|
| 72 |
+
" for _, m in masks_image.iterrows():\n",
|
| 73 |
+
" rect = patches.Rectangle(\n",
|
| 74 |
+
" (m['bbox_x'], m['bbox_y']),\n",
|
| 75 |
+
" m['bbox_w'],\n",
|
| 76 |
+
" m['bbox_h'],\n",
|
| 77 |
+
" linewidth=2,\n",
|
| 78 |
+
" edgecolor=\"white\",\n",
|
| 79 |
+
" facecolor=\"none\"\n",
|
| 80 |
+
" )\n",
|
| 81 |
+
" ax.add_patch(rect)\n",
|
| 82 |
+
" ax.text(\n",
|
| 83 |
+
" m['bbox_x'],\n",
|
| 84 |
+
" m['bbox_y'] - 3,\n",
|
| 85 |
+
" m['label_side'],\n",
|
| 86 |
+
" color=\"white\",\n",
|
| 87 |
+
" fontsize=10,\n",
|
| 88 |
+
" weight=\"bold\",\n",
|
| 89 |
+
" bbox=dict(facecolor=\"black\", alpha=0.5, pad=2)\n",
|
| 90 |
+
" )\n",
|
| 91 |
+
" \n",
|
| 92 |
+
" n_head = (masks_image['label'] == 'head').sum()\n",
|
| 93 |
+
" n_flipper = (masks_image['label'] == 'flipper').sum()\n",
|
| 94 |
+
" n_turtle = (masks_image['label'] == 'turtle').sum()\n",
|
| 95 |
+
"\n",
|
| 96 |
+
" plt.axis(\"off\")\n",
|
| 97 |
+
" plt.title(f'{n_head}, {n_flipper}, {n_turtle}')\n",
|
| 98 |
+
" plt.savefig(f'{root_figures}/{image_id}.png', bbox_inches='tight', dpi=600)\n",
|
| 99 |
+
" plt.close()"
|
| 100 |
+
]
|
| 101 |
+
},
|
| 102 |
+
{
|
| 103 |
+
"cell_type": "code",
|
| 104 |
+
"execution_count": null,
|
| 105 |
+
"id": "54035a2d",
|
| 106 |
+
"metadata": {},
|
| 107 |
+
"outputs": [],
|
| 108 |
+
"source": [
|
| 109 |
+
"for image_id, masks_image in masks.groupby('image_id'):\n",
|
| 110 |
+
" if masks_image['label_side'].value_counts().max() > 1:\n",
|
| 111 |
+
" print(f'Image id {image_id} has multiple annotations.')\n",
|
| 112 |
+
" display(masks_image)\n",
|
| 113 |
+
"display(masks['label'].value_counts())\n",
|
| 114 |
+
"display(masks['label_side'].value_counts())"
|
| 115 |
+
]
|
| 116 |
+
}
|
| 117 |
+
],
|
| 118 |
+
"metadata": {
|
| 119 |
+
"kernelspec": {
|
| 120 |
+
"display_name": "sam3",
|
| 121 |
+
"language": "python",
|
| 122 |
+
"name": "python3"
|
| 123 |
+
},
|
| 124 |
+
"language_info": {
|
| 125 |
+
"codemirror_mode": {
|
| 126 |
+
"name": "ipython",
|
| 127 |
+
"version": 3
|
| 128 |
+
},
|
| 129 |
+
"file_extension": ".py",
|
| 130 |
+
"mimetype": "text/x-python",
|
| 131 |
+
"name": "python",
|
| 132 |
+
"nbconvert_exporter": "python",
|
| 133 |
+
"pygments_lexer": "ipython3",
|
| 134 |
+
"version": "3.12.12"
|
| 135 |
+
}
|
| 136 |
+
},
|
| 137 |
+
"nbformat": 4,
|
| 138 |
+
"nbformat_minor": 5
|
| 139 |
+
}
|
training/turtle_detector/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Public API of the turtle_detector helper package: mask/RLE utilities
# plus the helpers used by the training notebooks and scripts.
from .masks import *
from .utils import assign_flippers, get_index, initialize_sam3
|
training/turtle_detector/masks.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import pycocotools.mask as mask_utils
|
| 4 |
+
from PIL import ImageDraw
|
| 5 |
+
|
| 6 |
+
def compute_iou(mask_a, mask_b):
    """Return the intersection-over-union of two binary masks (0.0 when both are empty)."""
    overlap = np.count_nonzero(np.logical_and(mask_a, mask_b))
    combined = np.count_nonzero(np.logical_or(mask_a, mask_b))
    if combined == 0:
        return 0.0
    return overlap / combined
|
| 10 |
+
|
| 11 |
+
def mask_to_bbox(mask):
    """Return the tight bounding box (x_min, y_min, x_max, y_max) of a mask's
    nonzero pixels, or None when the mask is empty."""
    rows, cols = np.where(mask)
    if cols.size == 0:
        return None
    return cols.min(), rows.min(), cols.max(), rows.max()
|
| 16 |
+
|
| 17 |
+
def mask_to_rle(mask, json_safe=True):
    """Encode a binary mask as COCO compressed RLE.

    When json_safe is True (default), the bytes `counts` payload is decoded
    to an ASCII str so the dict can be serialized with the json module.
    """
    encoded = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
    if not json_safe:
        return encoded
    encoded["counts"] = encoded["counts"].decode("ascii")
    return encoded
|
| 22 |
+
|
| 23 |
+
def rle_to_mask(rle):
    """Decode a COCO compressed RLE dict into a binary mask.

    Accepts either a bytes or an ASCII-str `counts` payload (the json-safe
    form produced by mask_to_rle); the input dict is not mutated.
    """
    encoded = dict(rle)
    counts = encoded["counts"]
    if isinstance(counts, str):
        encoded["counts"] = counts.encode("ascii")
    return mask_utils.decode(encoded)
|
| 28 |
+
|
| 29 |
+
def uncompressed_rle_to_mask(rle):
    """Decode COCO-style uncompressed RLE into a binary mask (0/1).

    `rle["counts"]` lists alternating run lengths, starting with a run of
    zeros; `rle["size"]` is (height, width). COCO RLE is column-major, so the
    flat buffer is reshaped in Fortran order.
    """
    height, width = rle["size"]
    flat = np.zeros(height * width, dtype=np.uint8)

    position = 0
    value = 0  # runs alternate 0, 1, 0, 1, ...
    for run in rle["counts"]:
        if value:
            flat[position:position + run] = 1
        position += run
        value ^= 1

    return flat.reshape((height, width), order='F')
|
| 43 |
+
|
| 44 |
+
def mask_to_yolo(mask, class_id=0, min_area=100):
    """Convert a binary mask (0/1) into YOLO polygon segmentation format.

    Parameters
    ----------
    mask : np.ndarray
        2-D binary mask (0/1 values).
    class_id : int
        Class index written as the first token of each output line.
    min_area : float
        Contours with a smaller pixel area are discarded as noise. Defaults
        to 100, matching the previously hard-coded threshold.

    Returns
    -------
    list[str]
        One "<class_id> x1 y1 x2 y2 ..." line per retained outer contour,
        with coordinates normalized to [0, 1].
    """
    h, w = mask.shape

    # ensure 8-bit binary mask for OpenCV
    mask8 = (mask * 255).astype(np.uint8)

    # outer contours only; CHAIN_APPROX_SIMPLE drops redundant collinear points
    contours, _ = cv2.findContours(mask8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    yolo_segments = []
    for contour in contours:
        if cv2.contourArea(contour) < min_area:  # ignore tiny noise
            continue

        contour = contour.squeeze().astype(float)
        if contour.ndim != 2:
            # degenerate contour (single point) — cannot form a polygon
            continue

        # normalize to [0,1]
        contour[:, 0] = contour[:, 0] / float(w)
        contour[:, 1] = contour[:, 1] / float(h)

        coords = contour.flatten().tolist()
        yolo_segments.append(f"{class_id} " + " ".join(f"{x:.6f}" for x in coords))

    return yolo_segments
|
| 71 |
+
|
| 72 |
+
def rle_to_yolo(rle, class_id=0):
    """Decode a compressed RLE and convert the mask to YOLO polygon lines."""
    return mask_to_yolo(rle_to_mask(rle), class_id)
|
| 75 |
+
|
| 76 |
+
def uncompressed_rle_to_yolo(rle, class_id=0):
    """Decode an uncompressed RLE and convert the mask to YOLO polygon lines."""
    return mask_to_yolo(uncompressed_rle_to_mask(rle), class_id)
|
| 79 |
+
|
| 80 |
+
def draw_yolo_on_pil(image, yolo_segments, color=(0, 255, 0)):
    """Draw YOLO polygon segments as closed outlines on a copy of a PIL image.

    Parameters
    ----------
    image : PIL.Image.Image
        Source image; converted to RGB, so the original is not modified.
    yolo_segments : list[str]
        Lines in YOLO segmentation format: "<class_id> x1 y1 x2 y2 ..." with
        coordinates normalized to [0, 1]. The class id token is ignored.
    color : tuple[int, int, int]
        RGB outline color.

    Returns
    -------
    PIL.Image.Image
        The annotated RGB copy.
    """
    img = image.convert("RGB")
    draw = ImageDraw.Draw(img)
    w, h = img.size

    for seg in yolo_segments:
        parts = seg.strip().split()
        # parts[0] is the class id; only the coordinate pairs are drawn
        coords = np.array([float(x) for x in parts[1:]]).reshape(-1, 2)
        points = [(x * w, y * h) for x, y in coords]
        # close the polygon by re-appending the first vertex
        draw.line(points + [points[0]], fill=color, width=2)

    return img
|
training/turtle_detector/utils.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import sam3
|
| 4 |
+
from sam3 import build_sam3_image_model
|
| 5 |
+
from sam3.model.sam3_image_processor import Sam3Processor
|
| 6 |
+
from .masks import rle_to_mask
|
| 7 |
+
|
| 8 |
+
def get_index(dataset, image_id):
    """Return the metadata index of the unique row whose image_id matches.

    Raises ValueError when the id is missing or matched by multiple rows.
    """
    matches = dataset.metadata['image_id'] == image_id
    if matches.sum() != 1:
        raise ValueError('image_id not found or found multiple times.')
    return dataset.metadata[matches].index[0]
|
| 13 |
+
|
| 14 |
+
def mask_centroid(mask):
    """Return the (x, y) centroid of a binary mask's nonzero pixels."""
    rows, cols = np.nonzero(mask)
    return np.array([cols.mean(), rows.mean()])
|
| 17 |
+
|
| 18 |
+
def rle_centroid(rle):
    """Return the (x, y) centroid of a compressed-RLE-encoded mask."""
    return mask_centroid(rle_to_mask(rle))
|
| 20 |
+
|
| 21 |
+
def assign_flippers(df):
    """Relabel generic 'flipper' annotations with their body position.

    Flipper rows in *df* (columns 'label' and 'mask', masks as compressed
    RLE) are renamed 'flipper_fl' / 'flipper_fr' (front left/right) and
    'flipper_rl' / 'flipper_rr' (rear left/right) using the mask centroids:
    the head centroid defines the forward axis and its perpendicular splits
    left from right. Works on a copy; the input frame is not modified.

    The copy is returned unchanged when there is not exactly one 'head'
    row, when no 'flipper' rows exist, or when more than four are present.
    """
    df = df.copy()

    # Check that there is only one head; otherwise orientation is ambiguous
    head_rows = df[df['label'] == 'head']
    if len(head_rows) != 1:
        return df

    # Compute the head centroid
    head_center = rle_centroid(head_rows.iloc[0]['mask'])

    # Extract the flippers
    flippers = df[df['label'] == 'flipper']
    n_flippers = len(flippers)
    if n_flippers == 0:
        return df

    # Compute the flipper centroids
    flipper_centers = np.vstack([
        rle_centroid(rle) for rle in flippers['mask']
    ])

    # Vector from turtle center to head defines "forward"
    # NOTE(review): if the head centroid coincides with the mean flipper
    # centroid, the norm is zero and this divides by zero — TODO confirm
    # whether annotations guarantee this cannot happen.
    turtle_center = flipper_centers.mean(axis=0)
    forward_vec = head_center - turtle_center
    forward_vec /= np.linalg.norm(forward_vec)

    # Perpendicular defines left/right
    left_vec = np.array([-forward_vec[1], forward_vec[0]])

    # Project flippers onto the forward and lateral axes
    forward_proj = flipper_centers @ forward_vec
    lateral_proj = flipper_centers @ left_vec

    if n_flippers <= 2:
        # Always front flippers
        order = np.argsort(lateral_proj)
        left_idx, right_idx = order[0], order[-1]

        df.loc[flippers.index[left_idx], 'label'] = 'flipper_fl'
        df.loc[flippers.index[right_idx], 'label'] = 'flipper_fr'
        return df
    elif n_flippers <= 4:
        # Sort by forward distance: the two most forward are the front pair.
        # BUGFIX: the rear set must exclude the front pair. The previous
        # order_fwd[:2] always had length 2, so with 3 flippers the middle
        # one appeared in both sets (front label overwritten by a rear one)
        # and the single-rear branch below was unreachable. order_fwd[:-2]
        # is identical for 4 flippers and has length 1 for 3.
        order_fwd = np.argsort(forward_proj)
        rear_idxs = order_fwd[:-2]
        front_idxs = order_fwd[-2:]

        # Front flippers: split left/right by lateral projection
        front_l = front_idxs[np.argmin(lateral_proj[front_idxs])]
        front_r = front_idxs[np.argmax(lateral_proj[front_idxs])]

        df.loc[flippers.index[front_l], 'label'] = 'flipper_fl'
        df.loc[flippers.index[front_r], 'label'] = 'flipper_fr'

        # Rear flippers (if present)
        if len(rear_idxs) == 2:
            rear_l = rear_idxs[np.argmin(lateral_proj[rear_idxs])]
            rear_r = rear_idxs[np.argmax(lateral_proj[rear_idxs])]

            df.loc[flippers.index[rear_l], 'label'] = 'flipper_rl'
            df.loc[flippers.index[rear_r], 'label'] = 'flipper_rr'
        else:
            # 3 flippers: assign only the most rear one
            idx = rear_idxs[0]
            side = 'l' if lateral_proj[idx] < 0 else 'r'
            df.loc[flippers.index[idx], 'label'] = f'flipper_r{side}'

    # More than 4 flippers: left unlabelled (annotation likely needs review)
    return df
|
| 90 |
+
|
| 91 |
+
def initialize_sam3():
    """Build the SAM3 image model and a Sam3Processor wired to it.

    Resolves the BPE tokenizer vocabulary relative to the installed sam3
    package. Returns a (model, processor) pair; the processor uses a 0.5
    confidence threshold.
    """
    sam3_root = os.path.join(os.path.dirname(sam3.__file__), "..")
    bpe_path = f"{sam3_root}/sam3/assets/bpe_simple_vocab_16e6.txt.gz"
    sam3_model = build_sam3_image_model(bpe_path=bpe_path)
    return sam3_model, Sam3Processor(sam3_model, confidence_threshold=0.5)
|