BenjiELCA commited on
Commit
2da5c78
·
1 Parent(s): 42199f1

commit with training and evaluation code

Browse files
.gitignore CHANGED
@@ -14,8 +14,6 @@ backup/
14
 
15
  temp.jpg
16
 
17
- Evaluation.ipynb
18
-
19
  study/
20
 
21
  result_bpmn.bpmn
@@ -24,8 +22,8 @@ BPMN_creation.ipynb
24
 
25
  *.png
26
 
27
- *.ipynb
28
-
29
  *.pmw
30
 
31
  best_models.txt
 
 
 
14
 
15
  temp.jpg
16
 
 
 
17
  study/
18
 
19
  result_bpmn.bpmn
 
22
 
23
  *.png
24
 
 
 
25
  *.pmw
26
 
27
  best_models.txt
28
+
29
+ Wizard_creation.ipynb
Evaluation.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Evaluation_colab.ipynb ADDED
@@ -0,0 +1,758 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/"
9
+ },
10
+ "id": "N7fMlFb-n3dJ",
11
+ "outputId": "ed9bb8ea-42a4-4e07-fdfa-d5a9eca0253f"
12
+ },
13
+ "outputs": [
14
+ {
15
+ "name": "stdout",
16
+ "output_type": "stream",
17
+ "text": [
18
+ "Requirement already satisfied: yamlu in /usr/local/lib/python3.10/dist-packages (0.0.17)\n",
19
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from yamlu) (3.7.1)\n",
20
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from yamlu) (1.26.4)\n",
21
+ "Requirement already satisfied: Pillow in /usr/local/lib/python3.10/dist-packages (from yamlu) (9.4.0)\n",
22
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (1.2.1)\n",
23
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (0.12.1)\n",
24
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (4.53.1)\n",
25
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (1.4.5)\n",
26
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (24.1)\n",
27
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (3.1.2)\n",
28
+ "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->yamlu) (2.8.2)\n",
29
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->yamlu) (1.16.0)\n",
30
+ "Requirement already satisfied: optuna in /usr/local/lib/python3.10/dist-packages (3.6.1)\n",
31
+ "Requirement already satisfied: alembic>=1.5.0 in /usr/local/lib/python3.10/dist-packages (from optuna) (1.13.2)\n",
32
+ "Requirement already satisfied: colorlog in /usr/local/lib/python3.10/dist-packages (from optuna) (6.8.2)\n",
33
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from optuna) (1.26.4)\n",
34
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from optuna) (24.1)\n",
35
+ "Requirement already satisfied: sqlalchemy>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from optuna) (2.0.32)\n",
36
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from optuna) (4.66.5)\n",
37
+ "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from optuna) (6.0.2)\n",
38
+ "Requirement already satisfied: Mako in /usr/local/lib/python3.10/dist-packages (from alembic>=1.5.0->optuna) (1.3.5)\n",
39
+ "Requirement already satisfied: typing-extensions>=4 in /usr/local/lib/python3.10/dist-packages (from alembic>=1.5.0->optuna) (4.12.2)\n",
40
+ "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from sqlalchemy>=1.3.0->optuna) (3.0.3)\n",
41
+ "Requirement already satisfied: MarkupSafe>=0.9.2 in /usr/local/lib/python3.10/dist-packages (from Mako->alembic>=1.5.0->optuna) (2.1.5)\n",
42
+ "Requirement already satisfied: streamlit in /usr/local/lib/python3.10/dist-packages (1.37.1)\n",
43
+ "Requirement already satisfied: altair<6,>=4.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (4.2.2)\n",
44
+ "Requirement already satisfied: blinker<2,>=1.0.0 in /usr/lib/python3/dist-packages (from streamlit) (1.4)\n",
45
+ "Requirement already satisfied: cachetools<6,>=4.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (5.5.0)\n",
46
+ "Requirement already satisfied: click<9,>=7.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (8.1.7)\n",
47
+ "Requirement already satisfied: numpy<3,>=1.20 in /usr/local/lib/python3.10/dist-packages (from streamlit) (1.26.4)\n",
48
+ "Requirement already satisfied: packaging<25,>=20 in /usr/local/lib/python3.10/dist-packages (from streamlit) (24.1)\n",
49
+ "Requirement already satisfied: pandas<3,>=1.3.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (2.1.4)\n",
50
+ "Requirement already satisfied: pillow<11,>=7.1.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (9.4.0)\n",
51
+ "Requirement already satisfied: protobuf<6,>=3.20 in /usr/local/lib/python3.10/dist-packages (from streamlit) (3.20.3)\n",
52
+ "Requirement already satisfied: pyarrow>=7.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (14.0.2)\n",
53
+ "Requirement already satisfied: requests<3,>=2.27 in /usr/local/lib/python3.10/dist-packages (from streamlit) (2.32.3)\n",
54
+ "Requirement already satisfied: rich<14,>=10.14.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (13.7.1)\n",
55
+ "Requirement already satisfied: tenacity<9,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (8.5.0)\n",
56
+ "Requirement already satisfied: toml<2,>=0.10.1 in /usr/local/lib/python3.10/dist-packages (from streamlit) (0.10.2)\n",
57
+ "Requirement already satisfied: typing-extensions<5,>=4.3.0 in /usr/local/lib/python3.10/dist-packages (from streamlit) (4.12.2)\n",
58
+ "Requirement already satisfied: gitpython!=3.1.19,<4,>=3.0.7 in /usr/local/lib/python3.10/dist-packages (from streamlit) (3.1.43)\n",
59
+ "Requirement already satisfied: pydeck<1,>=0.8.0b4 in /usr/local/lib/python3.10/dist-packages (from streamlit) (0.9.1)\n",
60
+ "Requirement already satisfied: tornado<7,>=6.0.3 in /usr/local/lib/python3.10/dist-packages (from streamlit) (6.3.3)\n",
61
+ "Requirement already satisfied: watchdog<5,>=2.1.5 in /usr/local/lib/python3.10/dist-packages (from streamlit) (4.0.2)\n",
62
+ "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit) (0.4)\n",
63
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit) (3.1.4)\n",
64
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit) (4.23.0)\n",
65
+ "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair<6,>=4.0->streamlit) (0.12.1)\n",
66
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /usr/local/lib/python3.10/dist-packages (from gitpython!=3.1.19,<4,>=3.0.7->streamlit) (4.0.11)\n",
67
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas<3,>=1.3.0->streamlit) (2.8.2)\n",
68
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3,>=1.3.0->streamlit) (2024.1)\n",
69
+ "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas<3,>=1.3.0->streamlit) (2024.1)\n",
70
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.27->streamlit) (3.3.2)\n",
71
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.27->streamlit) (3.7)\n",
72
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.27->streamlit) (2.0.7)\n",
73
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.27->streamlit) (2024.7.4)\n",
74
+ "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich<14,>=10.14.0->streamlit) (3.0.0)\n",
75
+ "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich<14,>=10.14.0->streamlit) (2.16.1)\n",
76
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /usr/local/lib/python3.10/dist-packages (from gitdb<5,>=4.0.1->gitpython!=3.1.19,<4,>=3.0.7->streamlit) (5.0.1)\n",
77
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->altair<6,>=4.0->streamlit) (2.1.5)\n",
78
+ "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (24.2.0)\n",
79
+ "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (2023.12.1)\n",
80
+ "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (0.35.1)\n",
81
+ "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair<6,>=4.0->streamlit) (0.20.0)\n",
82
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich<14,>=10.14.0->streamlit) (0.1.2)\n",
83
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas<3,>=1.3.0->streamlit) (1.16.0)\n",
84
+ "Mounted at /content/drive\n"
85
+ ]
86
+ }
87
+ ],
88
+ "source": [
89
+ "%pip install yamlu\n",
90
+ "%pip install optuna\n",
91
+ "%pip install streamlit\n",
92
+ "\n",
93
+ "from google.colab import drive\n",
94
+ "import os\n",
95
+ "\n",
96
+ "drive.mount('/content/drive')\n",
97
+ "path = 'drive/MyDrive/ELCA/BPMN project/'\n",
98
+ "\n",
99
+ "os.chdir(path)\n"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {
106
+ "colab": {
107
+ "base_uri": "https://localhost:8080/"
108
+ },
109
+ "id": "YkZcbI53n3Dm",
110
+ "outputId": "18adb94b-567a-46eb-ee84-0b4a77dc00a2"
111
+ },
112
+ "outputs": [
113
+ {
114
+ "name": "stderr",
115
+ "output_type": "stream",
116
+ "text": [
117
+ "100%|██████████| 92/92 [00:30<00:00, 3.04it/s]\n"
118
+ ]
119
+ }
120
+ ],
121
+ "source": [
122
+ "from yamlu import ls\n",
123
+ "from yamlu.coco_read import CocoReader\n",
124
+ "from pathlib import Path\n",
125
+ "import cv2\n",
126
+ "from modules.utils import *\n",
127
+ "from modules.eval import *\n",
128
+ "from modules.train import *\n",
129
+ "from modules.dataset_loader import *\n",
130
+ "\n",
131
+ "dataset_path = Path(\"../data/hdBPMN-COCO\")\n",
132
+ "ls(dataset_path)\n",
133
+ "\n",
134
+ "\n",
135
+ "bpmn_reader = CocoReader(\n",
136
+ " dataset_root=dataset_path,\n",
137
+ " arrow_categories=[\"sequenceFlow\", \"messageFlow\", \"dataAssociation\"],\n",
138
+ ")\n",
139
+ "\n",
140
+ "\n",
141
+ "test_anot = bpmn_reader.parse_split(\"test\")"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "metadata": {
148
+ "colab": {
149
+ "base_uri": "https://localhost:8080/"
150
+ },
151
+ "id": "Ert1SxZbn3Dn",
152
+ "outputId": "6384181c-9129-489a-e694-f6b0cd1b57a7"
153
+ },
154
+ "outputs": [
155
+ {
156
+ "name": "stdout",
157
+ "output_type": "stream",
158
+ "text": [
159
+ "Loaded 92 annotations.\n"
160
+ ]
161
+ }
162
+ ],
163
+ "source": [
164
+ "from torchvision import transforms\n",
165
+ "from modules.utils import object_dict, arrow_dict, class_dict\n",
166
+ "from modules.dataset_loader import create_loader\n",
167
+ "\n",
168
+ "new_size = (1333,1333)\n",
169
+ "\n",
170
+ "model_type = 'object'\n",
171
+ "\n",
172
+ "if model_type == 'object':\n",
173
+ " model_dict = object_dict\n",
174
+ "else:\n",
175
+ " model_dict = arrow_dict\n",
176
+ "\n",
177
+ "transformation_test = transforms.Compose([\n",
178
+ " transforms.ToTensor(),\n",
179
+ "\n",
180
+ "])\n",
181
+ "\n",
182
+ "test_loader = create_loader(new_size, transformation_test, test_anot, batch_size=1, model_type = model_type, seed=42)\n"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "code",
187
+ "execution_count": null,
188
+ "metadata": {
189
+ "id": "hp8jehlrXOay"
190
+ },
191
+ "outputs": [],
192
+ "source": [
193
+ "from modules.train import get_faster_rcnn_model, get_arrow_model\n",
194
+ "import torch\n",
195
+ "\n",
196
+ "# Function to load the models only once and use session state to keep track of it\n",
197
+ "def load_object_models(model_to_load, model_dict):\n",
198
+ " # Adjusted to pass the class_dict directly\n",
199
+ " model = get_faster_rcnn_model(len(model_dict))\n",
200
+ "\n",
201
+ " device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
202
+ " # Load the model weights\n",
203
+ " model.load_state_dict(torch.load('./models/'+ model_to_load, map_location=device))\n",
204
+ "\n",
205
+ "\n",
206
+ " model.to(device)\n",
207
+ "\n",
208
+ " return model\n",
209
+ "\n",
210
+ "def load_arrow_models(model_to_load, arrow_dict):\n",
211
+ " model = get_arrow_model(len(arrow_dict),2)\n",
212
+ "\n",
213
+ " device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
214
+ " # Load the model weights\n",
215
+ " model.load_state_dict(torch.load('./models/'+ model_to_load, map_location=device))\n",
216
+ "\n",
217
+ "\n",
218
+ " model.to(device)\n",
219
+ "\n",
220
+ " return model"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "execution_count": null,
226
+ "metadata": {
227
+ "colab": {
228
+ "base_uri": "https://localhost:8080/"
229
+ },
230
+ "id": "6hdMAQ7RX8K8",
231
+ "outputId": "1315f88b-06c9-42c8-f209-c45abe852e06"
232
+ },
233
+ "outputs": [
234
+ {
235
+ "name": "stdout",
236
+ "output_type": "stream",
237
+ "text": [
238
+ "['model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject2.pth', 'model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject3.pth', 'model_AdamW_2ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject3.pth', 'model_AdamW_3ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject3.pth', 'model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth', 'model_AdamW_2ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth', 'model_AdamW_3ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth', 'model_AdamW_4ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth', 'model_AdamW_5ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth', 'model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_arrow4.pth', 'model_AdamW_2ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_arrow4.pth', 'model_AdamW_3ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_arrow4.pth', 'model_AdamW_4ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_arrow4.pth', 'model_AdamW_5ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_arrow4.pth']\n",
239
+ "There is 14 models to test\n"
240
+ ]
241
+ }
242
+ ],
243
+ "source": [
244
+ "import os\n",
245
+ "model_folder = \"models\"\n",
246
+ "elements = os.listdir(model_folder)\n",
247
+ "elements = [element for element in elements if \"Adam\" in element]\n",
248
+ "#elements = [element for element in elements if \"recall\" in element]\n",
249
+ "print(elements)\n",
250
+ "print(f\"There is {len(elements)} models to test\")"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "metadata": {
257
+ "id": "fwz0QKdxgBJz"
258
+ },
259
+ "outputs": [],
260
+ "source": [
261
+ "from modules.eval import main_evaluation"
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": null,
267
+ "metadata": {
268
+ "colab": {
269
+ "base_uri": "https://localhost:8080/"
270
+ },
271
+ "id": "bHkTL_5Jq_t0",
272
+ "outputId": "026ada96-c865-4a88-c212-8b455d659859"
273
+ },
274
+ "outputs": [
275
+ {
276
+ "name": "stdout",
277
+ "output_type": "stream",
278
+ "text": [
279
+ "There is 8 models to test\n"
280
+ ]
281
+ },
282
+ {
283
+ "name": "stderr",
284
+ "output_type": "stream",
285
+ "text": [
286
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.30it/s]\n"
287
+ ]
288
+ },
289
+ {
290
+ "name": "stdout",
291
+ "output_type": "stream",
292
+ "text": [
293
+ "1: model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject2.pth\n",
294
+ "Labels_precision: 0.9683, Precision: 0.9742, Recall: 0.9438, F1 Score: 0.9588 \n"
295
+ ]
296
+ },
297
+ {
298
+ "name": "stderr",
299
+ "output_type": "stream",
300
+ "text": [
301
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.38it/s]\n"
302
+ ]
303
+ },
304
+ {
305
+ "name": "stdout",
306
+ "output_type": "stream",
307
+ "text": [
308
+ "2: model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject3.pth\n",
309
+ "Labels_precision: 0.9701, Precision: 0.9541, Recall: 0.9600, F1 Score: 0.9571 \n"
310
+ ]
311
+ },
312
+ {
313
+ "name": "stderr",
314
+ "output_type": "stream",
315
+ "text": [
316
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.27it/s]\n"
317
+ ]
318
+ },
319
+ {
320
+ "name": "stdout",
321
+ "output_type": "stream",
322
+ "text": [
323
+ "3: model_AdamW_2ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject3.pth\n",
324
+ "Labels_precision: 0.9701, Precision: 0.9541, Recall: 0.9600, F1 Score: 0.9571 \n"
325
+ ]
326
+ },
327
+ {
328
+ "name": "stderr",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.43it/s]\n"
332
+ ]
333
+ },
334
+ {
335
+ "name": "stdout",
336
+ "output_type": "stream",
337
+ "text": [
338
+ "4: model_AdamW_3ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject3.pth\n",
339
+ "Labels_precision: 0.9699, Precision: 0.9658, Recall: 0.9532, F1 Score: 0.9595 \n"
340
+ ]
341
+ },
342
+ {
343
+ "name": "stderr",
344
+ "output_type": "stream",
345
+ "text": [
346
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.38it/s]\n"
347
+ ]
348
+ },
349
+ {
350
+ "name": "stdout",
351
+ "output_type": "stream",
352
+ "text": [
353
+ "5: model_AdamW_1ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth\n",
354
+ "Labels_precision: 0.9649, Precision: 0.9565, Recall: 0.9607, F1 Score: 0.9586 \n"
355
+ ]
356
+ },
357
+ {
358
+ "name": "stderr",
359
+ "output_type": "stream",
360
+ "text": [
361
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.41it/s]\n"
362
+ ]
363
+ },
364
+ {
365
+ "name": "stdout",
366
+ "output_type": "stream",
367
+ "text": [
368
+ "6: model_AdamW_2ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth\n",
369
+ "Labels_precision: 0.9704, Precision: 0.9700, Recall: 0.9482, F1 Score: 0.9590 \n"
370
+ ]
371
+ },
372
+ {
373
+ "name": "stderr",
374
+ "output_type": "stream",
375
+ "text": [
376
+ "Testing... : 100%|██████���███| 92/92 [00:14<00:00, 6.47it/s]\n"
377
+ ]
378
+ },
379
+ {
380
+ "name": "stdout",
381
+ "output_type": "stream",
382
+ "text": [
383
+ "7: model_AdamW_3ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth\n",
384
+ "Labels_precision: 0.9708, Precision: 0.9631, Recall: 0.9619, F1 Score: 0.9625 \n"
385
+ ]
386
+ },
387
+ {
388
+ "name": "stderr",
389
+ "output_type": "stream",
390
+ "text": [
391
+ "Testing... : 100%|██████████| 92/92 [00:14<00:00, 6.41it/s]"
392
+ ]
393
+ },
394
+ {
395
+ "name": "stdout",
396
+ "output_type": "stream",
397
+ "text": [
398
+ "8: model_AdamW_4ep_4batch_trainval_blur00_crop02_flip02_rotate02_finetune_bestobject4.pth\n",
399
+ "Labels_precision: 0.9708, Precision: 0.9631, Recall: 0.9619, F1 Score: 0.9625 \n"
400
+ ]
401
+ },
402
+ {
403
+ "name": "stderr",
404
+ "output_type": "stream",
405
+ "text": [
406
+ "\n"
407
+ ]
408
+ }
409
+ ],
410
+ "source": [
411
+ "results = {}\n",
412
+ "print(f\"There is {len(elements)} models to test\")\n",
413
+ "for idx, model_name in enumerate(elements):\n",
414
+ " if model_type == 'object':\n",
415
+ " model = load_object_models(model_name, model_dict)\n",
416
+ " else:\n",
417
+ " model = load_arrow_models(model_name, model_dict)\n",
418
+ "\n",
419
+ " labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, test_loader,score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)\n",
420
+ " print(f\"{idx+1}: {model_name}\")\n",
421
+ " print(f\"Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} \")\n",
422
+ " results[model_name] = [labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy]"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "metadata": {
429
+ "colab": {
430
+ "base_uri": "https://localhost:8080/",
431
+ "height": 88
432
+ },
433
+ "id": "v0pe9A7DnbUV",
434
+ "outputId": "80386876-8f54-4166-cb7a-bb7c63cf9414"
435
+ },
436
+ "outputs": [
437
+ {
438
+ "data": {
439
+ "application/vnd.google.colaboratory.intrinsic+json": {
440
+ "type": "string"
441
+ },
442
+ "text/plain": [
443
+ "'for i, metric in enumerate([\\'labels_precision\\', \\'precision\\', \\'recall\\', \\'f1_score\\',\\'key_accuracy\\']):\\n best_model = max(results, key=lambda x: results[x][i])\\n print(f\"Best model for {metric}: {best_model}\")\\n #print all score for this one\\n print(f\\'Labels Precision: {results[best_model][0]:.3f}, Precision: {results[best_model][1]:.3f}, Recall: {results[best_model][2]:.3f}, F1 Score: {results[best_model][3]:.3f}, Key Accuracy: {results[best_model][4]:.3f}\\')'"
444
+ ]
445
+ },
446
+ "execution_count": 9,
447
+ "metadata": {},
448
+ "output_type": "execute_result"
449
+ }
450
+ ],
451
+ "source": [
452
+ "\"\"\"for i, metric in enumerate(['labels_precision', 'precision', 'recall', 'f1_score','key_accuracy']):\n",
453
+ " best_model = max(results, key=lambda x: results[x][i])\n",
454
+ " print(f\"Best model for {metric}: {best_model}\")\n",
455
+ " #print all score for this one\n",
456
+ " print(f'Labels Precision: {results[best_model][0]:.3f}, Precision: {results[best_model][1]:.3f}, Recall: {results[best_model][2]:.3f}, F1 Score: {results[best_model][3]:.3f}, Key Accuracy: {results[best_model][4]:.3f}')\"\"\""
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {
463
+ "colab": {
464
+ "base_uri": "https://localhost:8080/"
465
+ },
466
+ "id": "HMyYdPjLiGMH",
467
+ "outputId": "b5e28040-9703-4d3c-9aeb-8ab945a78c21"
468
+ },
469
+ "outputs": [
470
+ {
471
+ "name": "stderr",
472
+ "output_type": "stream",
473
+ "text": [
474
+ "Downloading: \"https://download.pytorch.org/models/resnet50-0676ba61.pth\" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth\n",
475
+ "100%|██████████| 97.8M/97.8M [00:03<00:00, 31.2MB/s]\n",
476
+ "Testing... : 100%|██████████| 92/92 [00:20<00:00, 4.44it/s]"
477
+ ]
478
+ },
479
+ {
480
+ "name": "stdout",
481
+ "output_type": "stream",
482
+ "text": [
483
+ "best_model_object.pth\n",
484
+ "Labels_precision: 0.9671, Precision: 0.9429, Recall: 0.9682, F1 Score: 0.9553\n"
485
+ ]
486
+ },
487
+ {
488
+ "name": "stderr",
489
+ "output_type": "stream",
490
+ "text": [
491
+ "\n"
492
+ ]
493
+ }
494
+ ],
495
+ "source": [
496
+ "from modules.eval import main_evaluation\n",
497
+ "\n",
498
+ "\n",
499
+ "results = {}\n",
500
+ "model_name = 'best_model_object.pth'\n",
501
+ "model_dict = object_dict\n",
502
+ "model = load_object_models(model_name, model_dict)\n",
503
+ "\n",
504
+ "labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, test_loader,score_threshold=0.5, iou_threshold=0.5, model_type=model_type)\n",
505
+ "print(model_name)\n",
506
+ "print(f\"Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}\")\n",
507
+ "#results[model_name] = [labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy]"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": null,
513
+ "metadata": {
514
+ "colab": {
515
+ "base_uri": "https://localhost:8080/"
516
+ },
517
+ "id": "r6yDD7CljRXA",
518
+ "outputId": "53eb3edc-7dfd-47fc-e9f8-72b208aefd6e"
519
+ },
520
+ "outputs": [
521
+ {
522
+ "name": "stderr",
523
+ "output_type": "stream",
524
+ "text": [
525
+ "Testing... : 100%|██████████| 92/92 [00:15<00:00, 5.83it/s]"
526
+ ]
527
+ },
528
+ {
529
+ "name": "stdout",
530
+ "output_type": "stream",
531
+ "text": [
532
+ "\n",
533
+ "Class Precision: {'background': 0, 'task': 0.967741935483871, 'exclusiveGateway': 0.9433962264150944, 'event': 0.9461077844311377, 'parallelGateway': 0.926829268292683, 'messageEvent': 0.9230769230769231, 'pool': 0.7453416149068323, 'lane': 0.8554216867469879, 'dataObject': 0.8651685393258427, 'dataStore': 1.0, 'subProcess': 0.0, 'eventBasedGateway': 0.7272727272727273, 'timerEvent': 0.7916666666666666}\n",
534
+ "Class Recall: {'background': 0, 'task': 0.9810671256454389, 'exclusiveGateway': 0.9554140127388535, 'event': 0.9294117647058824, 'parallelGateway': 0.9344262295081968, 'messageEvent': 0.9523809523809523, 'pool': 0.96, 'lane': 0.71, 'dataObject': 0.9565217391304348, 'dataStore': 0.64, 'subProcess': 0, 'eventBasedGateway': 0.7272727272727273, 'timerEvent': 0.7916666666666666}\n",
535
+ "Class F1 Score: {'background': 0, 'task': 0.9743589743589743, 'exclusiveGateway': 0.949367088607595, 'event': 0.9376854599406529, 'parallelGateway': 0.9306122448979592, 'messageEvent': 0.9375, 'pool': 0.8391608391608391, 'lane': 0.7759562841530054, 'dataObject': 0.9085545722713865, 'dataStore': 0.7804878048780487, 'subProcess': 0, 'eventBasedGateway': 0.7272727272727273, 'timerEvent': 0.7916666666666666}\n"
536
+ ]
537
+ },
538
+ {
539
+ "name": "stderr",
540
+ "output_type": "stream",
541
+ "text": [
542
+ "\n"
543
+ ]
544
+ }
545
+ ],
546
+ "source": [
547
+ "class_precision, class_recall, class_f1_score = evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.5, iou_threshold=0.5)\n",
548
+ "print(f\"\\nClass Precision: {class_precision}\")\n",
549
+ "print(f\"Class Recall: {class_recall}\")\n",
550
+ "print(f\"Class F1 Score: {class_f1_score}\")"
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": null,
556
+ "metadata": {
557
+ "colab": {
558
+ "base_uri": "https://localhost:8080/"
559
+ },
560
+ "id": "1wtvRs4zqoDN",
561
+ "outputId": "08b8f742-2ef3-4414-d84f-e9d089d32b16"
562
+ },
563
+ "outputs": [
564
+ {
565
+ "name": "stdout",
566
+ "output_type": "stream",
567
+ "text": [
568
+ "Average Precision: 0.9429\n",
569
+ "Average Recall: 0.9682\n",
570
+ "Average F1 Score: 0.9553\n"
571
+ ]
572
+ }
573
+ ],
574
+ "source": [
575
+ "import numpy as np\n",
576
+ "\n",
577
+ "#average each\n",
578
+ "average_precision = np.mean(precision)\n",
579
+ "average_recall = np.mean(recall)\n",
580
+ "average_f1_score = np.mean(f1_score)\n",
581
+ "\n",
582
+ "print(f\"Average Precision: {average_precision:.4f}\")\n",
583
+ "print(f\"Average Recall: {average_recall:.4f}\")\n",
584
+ "print(f\"Average F1 Score: {average_f1_score:.4f}\")"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": null,
590
+ "metadata": {
591
+ "colab": {
592
+ "base_uri": "https://localhost:8080/"
593
+ },
594
+ "id": "aHVvDOEvKdL4",
595
+ "outputId": "f6e636aa-d281-4e67-de43-f1783c06194b"
596
+ },
597
+ "outputs": [
598
+ {
599
+ "name": "stdout",
600
+ "output_type": "stream",
601
+ "text": [
602
+ "Loaded 92 annotations.\n"
603
+ ]
604
+ }
605
+ ],
606
+ "source": [
607
+ "from torchvision import transforms\n",
608
+ "#from modules.utils import object_dict, arrow_dict, class_dict\n",
609
+ "\n",
610
+ "#new_size = (640, 384)\n",
611
+ "new_size = (1333,1333)\n",
612
+ "\n",
613
+ "model_type = 'arrow'\n",
614
+ "\n",
615
+ "if model_type == 'object':\n",
616
+ " model_dict = object_dict\n",
617
+ "else:\n",
618
+ " model_dict = arrow_dict\n",
619
+ "\n",
620
+ "transformation_test = transforms.Compose([\n",
621
+ " transforms.ToTensor(),\n",
622
+ "\n",
623
+ "])\n",
624
+ "\n",
625
+ "test_loader = create_loader(new_size, transformation_test, test_anot, batch_size=1, model_type = model_type)\n"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": null,
631
+ "metadata": {
632
+ "colab": {
633
+ "base_uri": "https://localhost:8080/"
634
+ },
635
+ "id": "gIyJdC3shmGU",
636
+ "outputId": "eccaf29a-b01a-460c-fada-34fbd3f626bf"
637
+ },
638
+ "outputs": [
639
+ {
640
+ "name": "stderr",
641
+ "output_type": "stream",
642
+ "text": [
643
+ "Testing... : 100%|██████████| 92/92 [00:19<00:00, 4.69it/s]"
644
+ ]
645
+ },
646
+ {
647
+ "name": "stdout",
648
+ "output_type": "stream",
649
+ "text": [
650
+ "\n",
651
+ " best_model_arrow.pth\n",
652
+ "Labels_precision: 0.9873, Precision: 0.9203, Recall: 0.9256, F1 Score: 0.9229, Key Accuracy: 0.7065, Reverted Accuracy: 0.0196\n"
653
+ ]
654
+ },
655
+ {
656
+ "name": "stderr",
657
+ "output_type": "stream",
658
+ "text": [
659
+ "\n"
660
+ ]
661
+ }
662
+ ],
663
+ "source": [
664
+ "from modules.eval import main_evaluation\n",
665
+ "\n",
666
+ "results = {}\n",
667
+ "model_name = 'best_model_arrow.pth'\n",
668
+ "model = load_arrow_models(model_name, model_dict)\n",
669
+ "\n",
670
+ "for i in range(5):\n",
671
+ " test_loader = create_loader(new_size, transformation_test, test_anot, batch_size=1, model_type = model_type, seed=42+i)\n",
672
+ " labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, test_loader,score_threshold=0.7, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)\n",
673
+ " print(\"\\n\",model_name)\n",
674
+ " print(f\"Seed: {42+i} ,Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f}, Key Accuracy: {key_accuracy:.4f}, Reverted Accuracy: {reverted_accuracy:.4f}\")\n",
675
+ " #results[model_name] = [labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy]"
676
+ ]
677
+ },
678
+ {
679
+ "cell_type": "code",
680
+ "execution_count": null,
681
+ "metadata": {
682
+ "colab": {
683
+ "base_uri": "https://localhost:8080/"
684
+ },
685
+ "id": "KIUAasG5hzw1",
686
+ "outputId": "6a5617b1-ed1b-4237-dee2-3fa40f14f99a"
687
+ },
688
+ "outputs": [
689
+ {
690
+ "name": "stderr",
691
+ "output_type": "stream",
692
+ "text": [
693
+ "Testing... : 100%|██████████| 92/92 [00:19<00:00, 4.78it/s]"
694
+ ]
695
+ },
696
+ {
697
+ "name": "stdout",
698
+ "output_type": "stream",
699
+ "text": [
700
+ "Class Precision: {'background': 0, 'sequenceFlow': 0.9075697211155378, 'dataAssociation': 0.7788778877887789, 'messageFlow': 0.7914110429447853}\n",
701
+ "Class Recall: {'background': 0, 'sequenceFlow': 0.9366776315789473, 'dataAssociation': 0.7492063492063492, 'messageFlow': 0.7288135593220338}\n",
702
+ "Class F1 Score: {'background': 0, 'sequenceFlow': 0.9218939700526103, 'dataAssociation': 0.7637540453074433, 'messageFlow': 0.7588235294117648}\n"
703
+ ]
704
+ },
705
+ {
706
+ "name": "stderr",
707
+ "output_type": "stream",
708
+ "text": [
709
+ "\n"
710
+ ]
711
+ }
712
+ ],
713
+ "source": [
714
+ "from modules.eval import evaluate_model_by_class\n",
715
+ "\n",
716
+ "class_precision, class_recall, class_f1_score = evaluate_model_by_class(model, test_loader, model_dict, score_threshold=0.7, iou_threshold=0.6)\n",
717
+ "print(f\"Class Precision: {class_precision}\")\n",
718
+ "print(f\"Class Recall: {class_recall}\")\n",
719
+ "print(f\"Class F1 Score: {class_f1_score}\")"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "code",
724
+ "execution_count": null,
725
+ "metadata": {
726
+ "id": "fwkbOQ8Yq019"
727
+ },
728
+ "outputs": [],
729
+ "source": []
730
+ }
731
+ ],
732
+ "metadata": {
733
+ "accelerator": "GPU",
734
+ "colab": {
735
+ "gpuType": "T4",
736
+ "machine_shape": "hm",
737
+ "provenance": []
738
+ },
739
+ "kernelspec": {
740
+ "display_name": "Python 3",
741
+ "name": "python3"
742
+ },
743
+ "language_info": {
744
+ "codemirror_mode": {
745
+ "name": "ipython",
746
+ "version": 3
747
+ },
748
+ "file_extension": ".py",
749
+ "mimetype": "text/x-python",
750
+ "name": "python",
751
+ "nbconvert_exporter": "python",
752
+ "pygments_lexer": "ipython3",
753
+ "version": "3.12.2"
754
+ }
755
+ },
756
+ "nbformat": 4,
757
+ "nbformat_minor": 0
758
+ }
Training_model colab.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Training_model.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
modules/train.py CHANGED
@@ -87,7 +87,7 @@ def prepare_model(dict, opti, learning_rate=0.0003, model_to_load=None, model_ty
87
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
88
  # Load the model weights
89
  if model_to_load:
90
- model.load_state_dict(torch.load('./models/' + model_to_load + '.pth', map_location=device))
91
  print(f"Model '{model_to_load}' loaded")
92
 
93
  model.to(device)
@@ -191,228 +191,187 @@ def evaluate_loss(model, data_loader, device, loss_config=None, print_losses=Fal
191
 
192
 
193
  def training_model(num_epochs, model, data_loader, subset_test_loader,
194
- optimizer, model_to_load=None, change_learning_rate=100, start_key=100,
195
  parameters=None, blur_prob=0.02,
196
  score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
197
  information_training='training', start_epoch=0, loss_config=None, model_type='object',
198
  eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
199
- """
200
- Train the model over a specified number of epochs.
201
-
202
- Parameters:
203
- - num_epochs (int): Number of epochs to train for.
204
- - model (torch.nn.Module): Model to train.
205
- - data_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
206
- - subset_test_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
207
- - optimizer (torch.optim.Optimizer): Optimizer to use for training.
208
- - model_to_load (str, optional): Name of the model to load.
209
- - change_learning_rate (int): Epoch interval to change the learning rate.
210
- - start_key (int): Epoch to start training keypoints.
211
- - parameters (dict, optional): Additional training parameters.
212
- - blur_prob (float): Probability of applying blur augmentation.
213
- - score_threshold (float): Score threshold for evaluation.
214
- - iou_threshold (float): IoU threshold for evaluation.
215
- - early_stop_f1_score (float): F1 score threshold for early stopping.
216
- - information_training (str): Information about the training.
217
- - start_epoch (int): Starting epoch number.
218
- - loss_config (dict, optional): Configuration specifying which losses to use.
219
- - model_type (str): Type of model ('object' or 'arrow').
220
- - eval_metric (str): Evaluation metric ('f1_score', 'precision', 'recall', or 'loss').
221
- - device (torch.device): Device to perform training on.
222
-
223
- Returns:
224
- - model (torch.nn.Module): Trained model.
225
- """
226
- model.train()
227
-
228
- if loss_config is None:
229
- print('No loss config found, all losses will be used.')
230
- else:
231
- # Print the list of the losses that will be used
232
- print('The following losses will be used: ', end='')
233
- for key, value in loss_config.items():
234
- if value:
235
- print(key, end=", ")
236
- print()
237
-
238
- # Initialize lists to store epoch-wise average losses
239
- epoch_avg_losses = []
240
- epoch_avg_loss_classifier = []
241
- epoch_avg_loss_box_reg = []
242
- epoch_avg_loss_objectness = []
243
- epoch_avg_loss_rpn_box_reg = []
244
- epoch_avg_loss_keypoints = []
245
- epoch_precision = []
246
- epoch_recall = []
247
- epoch_f1_score = []
248
- epoch_test_loss = []
249
-
250
- start_tot = time.time()
251
- best_metrics = -1000
252
- best_epoch = 0
253
- best_model_state = None
254
- same = 0
255
- learning_rate = optimizer.param_groups[0]['lr']
256
- bad_test_loss = 0
257
- previous_test_loss = 1000
258
-
259
- if parameters is not None:
260
- batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
261
-
262
- print(f"Let's go training {model_type} model with {num_epochs} epochs!")
263
- if parameters is not None:
264
- print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
265
-
266
- for epoch in range(num_epochs):
267
- if (epoch > 0 and (epoch) % change_learning_rate == 0) or bad_test_loss >= 3:
268
- learning_rate = 0.7 * learning_rate
269
- optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
270
- if best_model_state is not None:
271
- model.load_state_dict(best_model_state)
272
- print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
273
- bad_test_loss = 0
274
- if epoch > 0 and (epoch) == start_key:
275
- print("Now it's training Keypoints also")
276
- loss_config['loss_keypoint'] = True
277
- for name, param in model.named_parameters():
278
- if 'keypoint' in name:
279
- param.requires_grad = True
280
-
281
- model.train()
282
- start = time.time()
283
- total_loss = 0
284
-
285
- # Initialize lists to keep track of individual losses
286
- loss_classifier_list = []
287
- loss_box_reg_list = []
288
- loss_objectness_list = []
289
- loss_rpn_box_reg_list = []
290
- loss_keypoints_list = []
291
-
292
- # Create a tqdm progress bar
293
- progress_bar = tqdm(data_loader, desc=f'Epoch {epoch + 1 + start_epoch}')
294
-
295
- for images, targets_im in progress_bar:
296
- images = [image.to(device) for image in images]
297
- targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
298
-
299
- optimizer.zero_grad()
300
-
301
- loss_dict = model(images, targets)
302
- # Inside the training loop where losses are calculated:
303
- losses = 0
304
- if loss_config is not None:
305
- for key, loss in loss_dict.items():
306
- if loss_config.get(key, False):
307
- if key == 'loss_classifier':
308
- loss *= 3
309
- losses += loss
310
- else:
311
- losses = sum(loss for key, loss in loss_dict.items())
312
-
313
- # Collect individual losses
314
- if loss_dict['loss_classifier']:
315
- loss_classifier_list.append(loss_dict['loss_classifier'].item())
316
- else:
317
- loss_classifier_list.append(0)
318
-
319
- if loss_dict['loss_box_reg']:
320
- loss_box_reg_list.append(loss_dict['loss_box_reg'].item())
321
- else:
322
- loss_box_reg_list.append(0)
323
-
324
- if loss_dict['loss_objectness']:
325
- loss_objectness_list.append(loss_dict['loss_objectness'].item())
326
- else:
327
- loss_objectness_list.append(0)
328
-
329
- if loss_dict['loss_rpn_box_reg']:
330
- loss_rpn_box_reg_list.append(loss_dict['loss_rpn_box_reg'].item())
331
- else:
332
- loss_rpn_box_reg_list.append(0)
333
 
334
- if 'loss_keypoint' in loss_dict:
335
- loss_keypoints_list.append(loss_dict['loss_keypoint'].item())
336
- else:
337
- loss_keypoints_list.append(0)
338
-
339
- losses.backward()
340
- optimizer.step()
341
-
342
- total_loss += losses.item()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
343
 
344
- # Update the description with the current loss
345
- progress_bar.set_description(f'Epoch {epoch + 1 + start_epoch}, Loss: {losses.item():.4f}')
346
-
347
- # Calculate average loss
348
- avg_loss = total_loss / len(data_loader)
349
-
350
- epoch_avg_losses.append(avg_loss)
351
- epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
352
- epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
353
- epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
354
- epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
355
- epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
356
-
357
- # Evaluate the model on the test set
358
- if eval_metric == 'loss':
359
- labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
360
- avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
361
- print(f"Epoch {epoch + 1 + start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
362
- else:
363
- avg_test_loss = 0
364
- labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=0.5, iou_threshold=0.5, distance_threshold=10, key_correction=False, model_type=model_type)
365
- print(f"Epoch {epoch + 1 + start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
366
- avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
367
- print(f"Epoch {epoch + 1 + start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
368
-
369
- print(f"Time: {time.time() - start:.2f} [s]")
370
-
371
- if eval_metric == 'f1_score':
372
- metric_used = f1_score
373
- elif eval_metric == 'precision':
374
- metric_used = precision
375
- elif eval_metric == 'recall':
376
- metric_used = recall
377
- else:
378
- metric_used = -avg_test_loss
379
-
380
- # Check if this epoch's model has the lowest average loss
381
- if metric_used > best_metrics:
382
- best_metrics = metric_used
383
- best_epoch = epoch + 1 + start_epoch
384
- best_model_state = copy.deepcopy(model.state_dict())
385
-
386
- if epoch > 0 and f1_score > early_stop_f1_score:
387
- same += 1
388
-
389
- epoch_precision.append(precision)
390
- epoch_recall.append(recall)
391
- epoch_f1_score.append(f1_score)
392
- epoch_test_loss.append(avg_test_loss)
393
-
394
- name_model = f"model_{type(optimizer).__name__}_{epoch + 1 + start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob * 10)}_crop0{int(crop_prob * 10)}_flip0{int(h_flip_prob * 10)}_rotate0{int(rotate_proba * 10)}_{information_training}"
395
- metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg, epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints, epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]
396
-
397
- if same >= 1:
398
- torch.save(best_model_state, './models/' + name_model + '.pth')
399
- write_results(name_model, metrics_list, start_epoch)
400
- break
401
-
402
- if (epoch + 1 + start_epoch) % 5 == 0:
403
- torch.save(best_model_state, './models/' + name_model + '.pth')
404
- model.load_state_dict(best_model_state)
405
- write_results(name_model, metrics_list, start_epoch)
406
-
407
- if avg_test_loss > previous_test_loss:
408
- bad_test_loss += 1
409
- previous_test_loss = avg_test_loss
410
-
411
- print(f"\n Total time: {(time.time() - start_tot) / 60} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metrics:.4f}")
412
- if best_model_state:
413
  torch.save(best_model_state, './models/' + name_model + '.pth')
414
  model.load_state_dict(best_model_state)
415
  write_results(name_model, metrics_list, start_epoch)
416
- print(f"Name of the best model: model_{type(optimizer).__name__}_{epoch + 1 + start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob * 10)}_crop0{int(crop_prob * 10)}_flip0{int(h_flip_prob * 10)}_rotate0{int(rotate_proba * 10)}_{information_training}")
417
 
418
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
88
  # Load the model weights
89
  if model_to_load:
90
+ model.load_state_dict(torch.load(model_to_load + '.pth', map_location=device))
91
  print(f"Model '{model_to_load}' loaded")
92
 
93
  model.to(device)
 
191
 
192
 
193
  def training_model(num_epochs, model, data_loader, subset_test_loader,
194
+ optimizer, model_to_load=None, change_learning_rate=100, start_key=100, save_every=5,
195
  parameters=None, blur_prob=0.02,
196
  score_threshold=0.7, iou_threshold=0.5, early_stop_f1_score=0.97,
197
  information_training='training', start_epoch=0, loss_config=None, model_type='object',
198
  eval_metric='f1_score', device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
 
200
+ # Set the model to training mode
201
+ model.train()
202
+
203
+ if loss_config is None:
204
+ print('No loss config found, all losses will be used.')
205
+ else:
206
+ # Print the list of the losses that will be used
207
+ print('The following losses will be used: ', end='')
208
+ for key, value in loss_config.items():
209
+ if value:
210
+ print(key, end=", ")
211
+ print()
212
+
213
+ # Initialize lists to store epoch-wise average losses and other metrics
214
+ epoch_avg_losses = []
215
+ epoch_avg_loss_classifier = []
216
+ epoch_avg_loss_box_reg = []
217
+ epoch_avg_loss_objectness = []
218
+ epoch_avg_loss_rpn_box_reg = []
219
+ epoch_avg_loss_keypoints = []
220
+ epoch_precision = []
221
+ epoch_recall = []
222
+ epoch_f1_score = []
223
+ epoch_test_loss = []
224
+
225
+ start_tot = time.time()
226
+ best_metric_value = -1000
227
+ best_epoch = 0
228
+ best_model_state = None
229
+ epochs_with_high_f1 = 0
230
+ learning_rate = optimizer.param_groups[0]['lr']
231
+ bad_test_loss_epochs = 0
232
+ previous_test_loss = 1000
233
+
234
+ if parameters is not None:
235
+ batch_size, crop_prob, rotate_90_proba, h_flip_prob, v_flip_prob, max_rotate_deg, rotate_proba, keep_ratio = parameters.values()
236
+
237
+ print(f"Let's go training {model_type} model with {num_epochs} epochs!")
238
+ if parameters is not None:
239
+ print(f"Learning rate: {learning_rate}, Batch size: {batch_size}, Crop prob: {crop_prob}, H flip prob: {h_flip_prob}, V flip prob: {v_flip_prob}, Max rotate deg: {max_rotate_deg}, Rotate proba: {rotate_proba}, Rotate 90 proba: {rotate_90_proba}, Keep ratio: {keep_ratio}")
240
+
241
+ for epoch in range(num_epochs):
242
+
243
+ if (epoch > 0 and epoch % change_learning_rate == 0) or bad_test_loss_epochs >= 2:
244
+ learning_rate *= 0.7
245
+ optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=learning_rate, eps=1e-08, betas=(0.9, 0.999))
246
+ if best_model_state is not None:
247
+ model.load_state_dict(best_model_state)
248
+ print(f'Learning rate changed to {learning_rate:.4} and the best epoch for now is {best_epoch}')
249
+ bad_test_loss_epochs = 0
250
+
251
+ if epoch > 0 and epoch == start_key:
252
+ print("Now it's training Keypoints also")
253
+ loss_config['loss_keypoint'] = True
254
+ for name, param in model.named_parameters():
255
+ if 'keypoint' in name:
256
+ param.requires_grad = True
257
+
258
+ model.train()
259
+ start = time.time()
260
+ total_loss = 0
261
+
262
+ # Initialize lists to keep track of individual losses
263
+ loss_classifier_list = []
264
+ loss_box_reg_list = []
265
+ loss_objectness_list = []
266
+ loss_rpn_box_reg_list = []
267
+ loss_keypoints_list = []
268
+
269
+ # Create a tqdm progress bar
270
+ progress_bar = tqdm(data_loader, desc=f'Epoch {epoch+1+start_epoch}')
271
+
272
+ for images, targets_im in progress_bar:
273
+ images = [image.to(device) for image in images]
274
+ targets = [{k: v.clone().detach().to(device) for k, v in t.items()} for t in targets_im]
275
+
276
+ optimizer.zero_grad()
277
+
278
+ loss_dict = model(images, targets)
279
+ # Inside the training loop where losses are calculated:
280
+ losses = 0
281
+ if loss_config is not None:
282
+ for key, loss in loss_dict.items():
283
+ if loss_config.get(key, False):
284
+ if key == 'loss_classifier':
285
+ loss *= 3
286
+ losses += loss
287
+ else:
288
+ losses = sum(loss for key, loss in loss_dict.items())
289
+
290
+ # Collect individual losses
291
+ loss_classifier_list.append(loss_dict.get('loss_classifier', torch.tensor(0)).item())
292
+ loss_box_reg_list.append(loss_dict.get('loss_box_reg', torch.tensor(0)).item())
293
+ loss_objectness_list.append(loss_dict.get('loss_objectness', torch.tensor(0)).item())
294
+ loss_rpn_box_reg_list.append(loss_dict.get('loss_rpn_box_reg', torch.tensor(0)).item())
295
+ loss_keypoints_list.append(loss_dict.get('loss_keypoint', torch.tensor(0)).item())
296
+
297
+ losses.backward()
298
+ optimizer.step()
299
+
300
+ total_loss += losses.item()
301
+
302
+ # Update the description with the current loss
303
+ progress_bar.set_description(f'Epoch {epoch+1+start_epoch}, Loss: {losses.item():.4f}')
304
+
305
+ # Calculate average loss
306
+ avg_loss = total_loss / len(data_loader)
307
+
308
+ epoch_avg_losses.append(avg_loss)
309
+ epoch_avg_loss_classifier.append(np.mean(loss_classifier_list))
310
+ epoch_avg_loss_box_reg.append(np.mean(loss_box_reg_list))
311
+ epoch_avg_loss_objectness.append(np.mean(loss_objectness_list))
312
+ epoch_avg_loss_rpn_box_reg.append(np.mean(loss_rpn_box_reg_list))
313
+ epoch_avg_loss_keypoints.append(np.mean(loss_keypoints_list))
314
+
315
+ # Evaluate the model on the test set
316
+ if eval_metric == 'loss':
317
+ labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = 0, 0, 0, 0, 0, 0
318
+ avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
319
+ print(f"Epoch {epoch+1+start_epoch}, Average Training Loss: {avg_loss:.4f}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
320
+ else:
321
+ avg_test_loss = 0
322
+ labels_precision, precision, recall, f1_score, key_accuracy, reverted_accuracy = main_evaluation(model, subset_test_loader, score_threshold=score_threshold, iou_threshold=iou_threshold, distance_threshold=10, key_correction=False, model_type=model_type)
323
+ print(f"Epoch {epoch+1+start_epoch}, Average Loss: {avg_loss:.4f}, Labels_precision: {labels_precision:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1_score:.4f} ", end=", ")
324
+ avg_test_loss = evaluate_loss(model, subset_test_loader, device, loss_config)
325
+ print(f"Epoch {epoch+1+start_epoch}, Average Test Loss: {avg_test_loss:.4f}", end=", ")
326
+
327
+ print(f"Time: {time.time() - start:.2f} [s]")
328
+
329
+ if eval_metric == 'f1_score':
330
+ metric_used = f1_score
331
+ elif eval_metric == 'precision':
332
+ metric_used = precision
333
+ elif eval_metric == 'recall':
334
+ metric_used = recall
335
+ else:
336
+ metric_used = -avg_test_loss
337
+
338
+ # Check if this epoch's model has the best evaluation metric
339
+ if metric_used > best_metric_value:
340
+ best_metric_value = metric_used
341
+ best_epoch = epoch + 1 + start_epoch
342
+ best_model_state = copy.deepcopy(model.state_dict())
343
+
344
+ if epoch > 0 and f1_score > early_stop_f1_score:
345
+ epochs_with_high_f1 += 1
346
+
347
+ epoch_precision.append(precision)
348
+ epoch_recall.append(recall)
349
+ epoch_f1_score.append(f1_score)
350
+ epoch_test_loss.append(avg_test_loss)
351
+
352
+ name_model = f"model_{type(optimizer).__name__}_{epoch+1+start_epoch}ep_{batch_size}batch_trainval_blur0{int(blur_prob*10)}_crop0{int(crop_prob*10)}_flip0{int(h_flip_prob*10)}_rotate0{int(rotate_proba*10)}_{information_training}"
353
+ metrics_list = [epoch_avg_losses, epoch_avg_loss_classifier, epoch_avg_loss_box_reg, epoch_avg_loss_objectness, epoch_avg_loss_rpn_box_reg, epoch_avg_loss_keypoints, epoch_precision, epoch_recall, epoch_f1_score, epoch_test_loss]
354
+
355
+ if epochs_with_high_f1 >= 1:
356
+ torch.save(best_model_state, './models/' + name_model + '.pth')
357
+ write_results(name_model, metrics_list, start_epoch)
358
+ break
359
 
360
+ if (epoch + 1 + start_epoch) % save_every == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  torch.save(best_model_state, './models/' + name_model + '.pth')
362
  model.load_state_dict(best_model_state)
363
  write_results(name_model, metrics_list, start_epoch)
 
364
 
365
+ if avg_test_loss > previous_test_loss:
366
+ bad_test_loss_epochs += 1
367
+ previous_test_loss = avg_test_loss
368
+
369
+ print(f"\nTotal time: {(time.time() - start_tot) / 60:.2f} minutes, Best Epoch is {best_epoch} with an {eval_metric} of {best_metric_value:.4f}")
370
+
371
+ if best_model_state:
372
+ torch.save(best_model_state, './models/' + name_model + '.pth')
373
+ model.load_state_dict(best_model_state)
374
+ write_results(name_model, metrics_list, start_epoch)
375
+ print(f"Name of the best model: {name_model}")
376
+
377
+ return model