Evanjaa committed on
Commit
2fdb677
1 Parent(s): 5c58464

Upload 2 files

Files changed (2)
  1. feature_extraction.ipynb +433 -0
  2. train_classifier.ipynb +745 -0
feature_extraction.ipynb ADDED
@@ -0,0 +1,433 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Start to finish - DINOv2 feature extraction"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "markdown",
12
+ "metadata": {
13
+ "jp-MarkdownHeadingCollapsed": true
14
+ },
15
+ "source": [
16
+ "## Imports"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {
23
+ "id": "3AdjGBwjnr-5"
24
+ },
25
+ "outputs": [],
26
+ "source": [
27
+ "from transformers import AutoImageProcessor, AutoModel\n",
28
+ "from PIL import Image\n",
29
+ "\n",
30
+ "\n",
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import numpy as np\n",
33
+ "import requests\n",
34
+ "import torch\n",
35
+ "import cv2\n",
36
+ "import os"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "metadata": {
42
+ "id": "qvTYvSVOkLLL"
43
+ },
44
+ "source": [
45
+ "## Initialize pre-trained image processor and model"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "code",
50
+ "execution_count": null,
51
+ "metadata": {
52
+ "colab": {
53
+ "base_uri": "https://localhost:8080/"
54
+ },
55
+ "id": "aRlCk-Tlj8Iv",
56
+ "outputId": "fb51843c-598f-48ad-a1c0-cf8d9bab53f4",
57
+ "scrolled": true
58
+ },
59
+ "outputs": [],
60
+ "source": [
61
+ "# Adjust for cuda - takes up 2193 MiB on device\n",
62
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
63
+ "\n",
64
+ "processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')\n",
65
+ "model = AutoModel.from_pretrained('facebook/dinov2-large').to(device)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "metadata": {
71
+ "jp-MarkdownHeadingCollapsed": true
72
+ },
73
+ "source": [
74
+ "## DINOv2 Feature Extraction"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "from tqdm import tqdm\n",
84
+ "import gc\n",
85
+ "\n",
86
+ "torch.cuda.empty_cache() \n",
87
+ "gc.collect()"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "metadata": {
94
+ "id": "Crq7KD84qz5d"
95
+ },
96
+ "outputs": [],
97
+ "source": [
98
+ "# Path to your videos\n",
99
+ "path_to_videos = './dataset-tacdec/videos'\n",
100
+ "\n",
101
+ "# Directory paths\n",
102
+ "processed_features_dir = './processed_features'\n",
103
+ "last_hidden_states_dir = os.path.join(processed_features_dir, 'last_hidden_states/')\n",
104
+ "pooler_outputs_dir = os.path.join(processed_features_dir, 'pooler_outputs/')\n",
105
+ "\n",
106
+ "# Create directories if they don't exist\n",
107
+ "os.makedirs(last_hidden_states_dir, exist_ok=True)\n",
108
+ "os.makedirs(pooler_outputs_dir, exist_ok=True)\n",
109
+ "\n",
110
+ "# Dictonary with filename as key, all feature extracted frames as values\n",
111
+ "feature_extracted_videos = {}\n",
112
+ "\n",
113
+ "# Define batch size\n",
114
+ "batch_size = 32\n",
115
+ "\n",
116
+ "# Process each video\n",
117
+ "for video_file in tqdm(os.listdir(path_to_videos)):\n",
118
+ " full_path = os.path.join(path_to_videos, video_file)\n",
119
+ "\n",
120
+ " if not os.path.isfile(full_path):\n",
121
+ " continue\n",
122
+ "\n",
123
+ " cap = cv2.VideoCapture(full_path)\n",
124
+ "\n",
125
+ " # List to hold all batch outputs, clear for each video\n",
126
+ " batch_last_hidden_states = []\n",
127
+ " batch_pooler_outputs = []\n",
128
+ " \n",
129
+ " batch_frames = []\n",
130
+ "\n",
131
+ " while True:\n",
132
+ " ret, frame = cap.read()\n",
133
+ " if not ret:\n",
134
+ " \n",
135
+ " # Process the last batch\n",
136
+ " if len(batch_frames) > 0:\n",
137
+ " inputs = processor(images=batch_frames, return_tensors=\"pt\").to(device)\n",
138
+ " \n",
139
+ " with torch.no_grad():\n",
140
+ " outputs = model(**inputs)\n",
141
+ " \n",
142
+ " for key, value in outputs.items():\n",
143
+ " if key == 'last_hidden_state':\n",
144
+ " # batch_last_hidden_states.append(value.cpu().numpy())\n",
145
+ " batch_last_hidden_states.append(value)\n",
146
+ " elif key == 'pooler_output':\n",
147
+ " # batch_pooler_outputs.append(value.cpu().numpy())\n",
148
+ " batch_pooler_outputs.append(value)\n",
149
+ " else:\n",
150
+ " print('Error in key, expected last_hidden_state or pooler_output, got: ', key)\n",
151
+ " break\n",
152
+ "\n",
153
+ " # cv2 comes in BGR, but transformer takes RGB\n",
154
+ " frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n",
155
+ " batch_frames.append(frame_rgb)\n",
156
+ "\n",
157
+ " # Check if batch is full\n",
158
+ " if len(batch_frames) == batch_size:\n",
159
+ " inputs = processor(images=batch_frames, return_tensors=\"pt\").to(device)\n",
160
+ " # outputs = model(**inputs)\n",
161
+ " with torch.no_grad():\n",
162
+ " outputs = model(**inputs)\n",
163
+ " for key, value in outputs.items():\n",
164
+ " if key == 'last_hidden_state':\n",
165
+ " batch_last_hidden_states.append(value)\n",
166
+ " elif key == 'pooler_output':\n",
167
+ " batch_pooler_outputs.append(value)\n",
168
+ " else:\n",
169
+ " print('Error in key, expected last_hidden_state or pooler_output, got: ', key)\n",
170
+ "\n",
171
+ " # Clear batch\n",
172
+ " batch_frames = []\n",
173
+ "\n",
174
+ " \n",
175
+ " all_last_hidden_states = torch.cat(batch_last_hidden_states, dim=0)\n",
176
+ " all_pooler_outputs = torch.cat(batch_pooler_outputs, dim=0)\n",
177
+ "\n",
178
+ " # Save the tensors with the video name as filename\n",
179
+ " pt_filename = video_file.replace('.mp4', '.pt')\n",
180
+ " torch.save(all_last_hidden_states, os.path.join(last_hidden_states_dir, f'{pt_filename}'))\n",
181
+ " torch.save(all_pooler_outputs, os.path.join(pooler_outputs_dir, f'{pt_filename}'))\n",
182
+ " \n",
183
+ "print('Features extracted')"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {
189
+ "jp-MarkdownHeadingCollapsed": true
190
+ },
191
+ "source": [
192
+ "## Reload features to verify "
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "execution_count": null,
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "lhs_torch = torch.load('./processed_features/last_hidden_states/1738_avxeiaxxw6ocr.pt')\n",
202
+ "po_torch = torch.load('./processed_features/pooler_outputs/1738_avxeiaxxw6ocr.pt')\n",
203
+ "\n",
204
+ "print('LHS Torch size: ', lhs_torch.size())\n",
205
+ "print('PO Torch size: ', po_torch.size())\n",
206
+ "\n",
207
+ "for i in range(all_last_hidden_states.size(0)):\n",
208
+ " print(f\"Frame {i}:\")\n",
209
+ " print(all_last_hidden_states[i])\n",
210
+ " print() \n",
211
+ " break\n",
212
+ "\n",
213
+ "for i in range(lhs_torch.size(0)):\n",
214
+ " print(f\"Frame {i}:\")\n",
215
+ " print(all_last_hidden_states[i])\n",
216
+ " print() \n",
217
+ " break\n"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "metadata": {},
223
+ "source": [
224
+ "# Different sorts of plots"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "markdown",
229
+ "metadata": {},
230
+ "source": [
231
+ "## Histogram of video length in seconds"
232
+ ]
233
+ },
234
+ {
235
+ "cell_type": "code",
236
+ "execution_count": null,
237
+ "metadata": {},
238
+ "outputs": [],
239
+ "source": [
240
+ "import os\n",
241
+ "import cv2\n",
242
+ "import numpy as np\n",
243
+ "\n",
244
+ "path_to_videos = './dataset-tacdec/videos'\n",
245
+ "video_lengths = []\n",
246
+ "frame_counts = []\n",
247
+ "\n",
248
+ "# Iterate through each file in the directory\n",
249
+ "for video_file in os.listdir(path_to_videos):\n",
250
+ " full_path = os.path.join(path_to_videos, video_file)\n",
251
+ "\n",
252
+ " if not os.path.isfile(full_path):\n",
253
+ " continue\n",
254
+ "\n",
255
+ " cap = cv2.VideoCapture(full_path)\n",
256
+ "\n",
257
+ " # Calculate the length of the video\n",
258
+ " # Note: Assuming the frame rate information is accurate\n",
259
+ " if cap.isOpened():\n",
260
+ " fps = cap.get(cv2.CAP_PROP_FPS) # Frame rate\n",
261
+ " frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n",
262
+ " duration = frame_count / fps if fps > 0 else 0\n",
263
+ " video_lengths.append(duration)\n",
264
+ " frame_counts.append(frame_count)\n",
265
+ "\n",
266
+ " cap.release()\n",
267
+ "\n",
268
+ "np.save('./video_durations', video_lengths)\n",
269
+ "np.save('./frame_counts', frame_counts)\n"
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "import seaborn as sns\n",
279
+ "\n",
280
+ "# Set the aesthetic style of the plots\n",
281
+ "sns.set(style=\"darkgrid\")\n",
282
+ "\n",
283
+ "# Plotting the histogram for video lengths\n",
284
+ "plt.figure(figsize=(12, 6))\n",
285
+ "sns.histplot(video_lengths, kde=True, color=\"blue\")\n",
286
+ "plt.title('Histogram - Video Lengths')\n",
287
+ "plt.xlabel('Length of Videos (seconds)')\n",
288
+ "plt.ylabel('Number of Videos')\n",
289
+ "\n",
290
+ "# Plotting the histogram for frame counts\n",
291
+ "plt.figure(figsize=(12, 6))\n",
292
+ "sns.histplot(frame_counts, kde=True, color=\"green\")\n",
293
+ "plt.title('Histogram - Number of Frames')\n",
294
+ "plt.xlabel('Frame Count')\n",
295
+ "plt.ylabel('Number of Videos')\n",
296
+ "\n",
297
+ "plt.show()"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "markdown",
302
+ "metadata": {
303
+ "jp-MarkdownHeadingCollapsed": true
304
+ },
305
+ "source": [
306
+ "## Frame count and vid lengths"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "sns.boxplot(x=video_lengths)\n",
316
+ "plt.title('Box Plot of Video Lengths')\n",
317
+ "plt.xlabel('Video Length (seconds)')\n",
318
+ "plt.show()\n",
319
+ "\n",
320
+ "sns.boxplot(x=frame_counts, color=\"r\")\n",
321
+ "plt.title('Box Plot of Frame Counts')\n",
322
+ "plt.xlabel('Frame Count')\n",
323
+ "plt.show()\n"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "metadata": {},
329
+ "source": [
330
+ "## Class distributions"
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "metadata": {
337
+ "scrolled": true
338
+ },
339
+ "outputs": [],
340
+ "source": [
341
+ "import os\n",
342
+ "import json\n",
343
+ "import pandas as pd\n",
344
+ "import matplotlib.pyplot as plt\n",
345
+ "import seaborn as sns\n",
346
+ "\n",
347
+ "path_to_labels = './dataset-tacdec/full_labels'\n",
348
+ "class_counts = {'background': 0, 'tackle-live': 0, 'tackle-replay': 0, 'tackle-live-incomplete': 0, 'tackle-replay-incomplete': 0, 'dummy_class': 0}\n",
349
+ "\n",
350
+ "# Iterate through each JSON file in the labels directory\n",
351
+ "for label_file in os.listdir(path_to_labels):\n",
352
+ " full_path = os.path.join(path_to_labels, label_file)\n",
353
+ "\n",
354
+ " if not os.path.isfile(full_path):\n",
355
+ " continue\n",
356
+ "\n",
357
+ " with open(full_path, 'r') as file:\n",
358
+ " data = json.load(file)\n",
359
+ " frame_sections = data['frames_sections']\n",
360
+ "\n",
361
+ " # Extract annotations\n",
362
+ " for section in frame_sections:\n",
363
+ " for frame_number, frame_data in section.items():\n",
364
+ " class_label = frame_data['radio_answer']\n",
365
+ " if class_label in class_counts:\n",
366
+ " class_counts[class_label] += 1\n",
367
+ "\n",
368
+ "# Convert the dictionary to a DataFrame for Seaborn\n",
369
+ "df_class_counts = pd.DataFrame(list(class_counts.items()), columns=['Class', 'Occurrences'])\n",
370
+ "\n",
371
+ "# Save the DataFrame to a CSV file\n",
372
+ "df_class_counts.to_csv('class_distribution.csv', sep=',', index=False, encoding='utf-8')\n",
373
+ "\n",
374
+ "# Plotting the distribution using Seaborn\n",
375
+ "plt.figure(figsize=(10, 6))\n",
376
+ "sns.barplot(x='Class', y='Occurrences', data=df_class_counts, palette='viridis', alpha=0.75)\n",
377
+ "plt.title('Distribution of Frame Classes')\n",
378
+ "plt.xlabel('Class')\n",
379
+ "plt.ylabel('Number of Occurrences')\n",
380
+ "plt.xticks(rotation=45) # Rotate class names for better readability\n",
381
+ "plt.tight_layout() # Adjust layout to make room for the rotated x-axis labels\n",
382
+ "plt.show()\n"
383
+ ]
384
+ },
385
+ {
386
+ "cell_type": "code",
387
+ "execution_count": null,
388
+ "metadata": {},
389
+ "outputs": [],
390
+ "source": [
391
+ "import pandas as pd\n",
392
+ "import matplotlib.pyplot as plt\n",
393
+ "\n",
394
+ "# Ensure df_class_counts is already created as in the previous script\n",
395
+ "\n",
396
+ "# Create a pie chart\n",
397
+ "plt.figure(figsize=(8, 8))\n",
398
+ "plt.pie(df_class_counts['Occurrences'], labels=df_class_counts['Class'], \n",
399
+ " autopct=lambda p: '{:.1f}%'.format(p), startangle=140, \n",
400
+ " colors=sns.color_palette('bright', len(df_class_counts)))\n",
401
+ "plt.title('Distribution of Frame Classes', fontweight='bold')\n",
402
+ "plt.show()"
403
+ ]
404
+ }
405
+ ],
406
+ "metadata": {
407
+ "colab": {
408
+ "collapsed_sections": [
409
+ "uzdIsbuEpF2w"
410
+ ],
411
+ "provenance": []
412
+ },
413
+ "kernelspec": {
414
+ "display_name": "Python (evan31818)",
415
+ "language": "python",
416
+ "name": "evan31818"
417
+ },
418
+ "language_info": {
419
+ "codemirror_mode": {
420
+ "name": "ipython",
421
+ "version": 3
422
+ },
423
+ "file_extension": ".py",
424
+ "mimetype": "text/x-python",
425
+ "name": "python",
426
+ "nbconvert_exporter": "python",
427
+ "pygments_lexer": "ipython3",
428
+ "version": "3.8.18"
429
+ }
430
+ },
431
+ "nbformat": 4,
432
+ "nbformat_minor": 0
433
+ }
train_classifier.ipynb ADDED
@@ -0,0 +1,745 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "27933625-f946-4fce-a622-e92ea518fad1",
6
+ "metadata": {
7
+ "jp-MarkdownHeadingCollapsed": true
8
+ },
9
+ "source": [
10
+ "## 1. Mandatory"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "id": "8674dce1-4885-4bc9-8b90-1d847c38e6f1",
17
+ "metadata": {},
18
+ "outputs": [],
19
+ "source": [
20
+ "from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, accuracy_score\n",
21
+ "from torch.utils.data import TensorDataset, DataLoader\n",
22
+ "from sklearn.model_selection import train_test_split\n",
23
+ "\n",
24
+ "import matplotlib.pyplot as plt\n",
25
+ "import torch.optim as optim\n",
26
+ "import torch.nn as nn\n",
27
+ "import seaborn as sns\n",
28
+ "import numpy as np\n",
29
+ "import torch\n",
30
+ "import json\n",
31
+ "import os"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "46a4597f",
37
+ "metadata": {},
38
+ "source": [
39
+ "# 2. Complete below - if you did not download DINOv2 cls-tokens together with the labels - Skip to step 3 if done."
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "id": "1f1bd72b-ed98-4669-908c-2b103bcacda5",
45
+ "metadata": {},
46
+ "source": [
47
+ "## Load labels"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "id": "98e09803-9862-4e29-aaff-3bdcd4e0fe53",
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "# Paths to labels\n",
58
+ "path_to_labels = '/home/evan/D1/project/code/start_end_labels'"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "id": "b41d5fd2-ee4a-4f02-98b9-887e48115c47",
65
+ "metadata": {},
66
+ "outputs": [],
67
+ "source": [
68
+ "# Should be 425 files, code just to verify\n",
69
+ "num_of_labels = 0\n",
70
+ "for ind, label in enumerate(os.listdir(path_to_labels)):\n",
71
+ " num_of_labels = ind+1\n",
72
+ "\n",
73
+ "num_of_labels"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "1ef791d8-a268-4436-ad18-150d645bef73",
80
+ "metadata": {},
81
+ "outputs": [],
82
+ "source": [
83
+ "list_of_labels = []\n",
84
+ "\n",
85
+ "categorical_mapping = {'background': 0, 'tackle-live': 1, 'tackle-replay': 2, 'tackle-live-incomplete': 3, 'tackle-replay-incomplete': 4}\n",
86
+ "\n",
87
+ "# Sort to make sure order is maintained\n",
88
+ "for ind, label in enumerate(sorted(os.listdir(path_to_labels))):\n",
89
+ " full_path = os.path.join(path_to_labels, label)\n",
90
+ "\n",
91
+ " with open(full_path, 'r') as file:\n",
92
+ " data = json.load(file)\n",
93
+ " \n",
94
+ " # Extract frame count\n",
95
+ " frame_count = data['media_attributes']['frame_count']\n",
96
+ "\n",
97
+ " # Extract tackles\n",
98
+ " tackles = data['events']\n",
99
+ " \n",
100
+ " labels_of_current_file = np.zeros(frame_count)\n",
101
+ " \n",
102
+ " for tackle in tackles:\n",
103
+ " # Extract variables\n",
104
+ " tackle_class = tackle['type']\n",
105
+ " start_frame = tackle['frame_start']\n",
106
+ " end_frame = tackle['frame_end']\n",
107
+ "\n",
108
+ " # Need to shift start_frame with -1 as array-indexing starts at 0, while \n",
109
+ " # frame count starts at 1\n",
110
+ " for i in range(start_frame-1, end_frame, 1):\n",
111
+ " labels_of_current_file[i] = categorical_mapping[tackle_class]\n",
112
+ "\n",
113
+ " list_of_labels.append(labels_of_current_file)\n"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "b302d94a-d18c-4e41-929b-3c8f4d547afa",
119
+ "metadata": {},
120
+ "source": [
121
+ "## Verify that change is correct"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "286b27a8-1c9a-4ba9-9996-deeef7927195",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "test = list_of_labels[0]\n",
132
+ "\n",
133
+ "for i in range(len(test)):\n",
134
+ " # Should give [0,1,1,0] as 181-107 is the actual sequence, but its moved to 180-206 with array indexing\n",
135
+ " # starting from 0 instead of 1 like the frame counting.\n",
136
+ " if i == 179 or i == 180 or i == 206 or i == 207:\n",
137
+ " print(test[i])"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "id": "88650952-a098-4ae3-ba3b-d67f5d17c41b",
143
+ "metadata": {},
144
+ "source": [
145
+ "## Map incomplete class-labels to instances of their respective 'full-class'"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "id": "2c48db00-b367-4f38-aa59-de5164d11fe9",
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "class_mapping = {0:0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
156
+ "prev_list_of_labels = list_of_labels\n",
157
+ "\n",
158
+ "for i, label in enumerate(list_of_labels):\n",
159
+ " list_of_labels[i] = np.array([class_mapping[frame_class] for frame_class in label])"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "markdown",
164
+ "id": "ee69c1f0-db9d-4848-9b3c-2556e09d1991",
165
+ "metadata": {},
166
+ "source": [
167
+ "## Load DINOv2-features and extract CLS-tokens"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "20b2ee27-5d94-4301-9229-aa9486360a73",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "# Define path to DINOv2-features\n",
178
+ "path_to_tensors = '/home/evan/D1/project/code/processed_features/last_hidden_states'\n",
179
+ "path_to_first_tensor = '/home/evan/D1/project/code/processed_features/last_hidden_states/1738_avxeiaxxw6ocr.pt'\n",
180
+ "\n",
181
+ "all_cls_tokens = torch.load(path_to_first_tensor)[:,0,:]\n",
182
+ "\n",
183
+ "for index, tensor_file in enumerate(sorted(os.listdir(path_to_tensors))[1:]): # Start from the second item\n",
184
+ " full_path = os.path.join(path_to_tensors, tensor_file)\n",
185
+ " cls_token = torch.load(full_path)[:,0,:]\n",
186
+ " all_cls_tokens = torch.cat((all_cls_tokens, cls_token), dim=0)\n",
187
+ "\n",
188
+ "\n",
189
+ "# Should have shape: total_frames, feature_vector (1024)\n",
190
+ "print('CLS tokens shape: ', all_cls_tokens.shape)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "03c8f5ed-5b04-456d-a9fd-8d493878ea18",
196
+ "metadata": {},
197
+ "source": [
198
+ "### Reshape labels list"
199
+ ]
200
+ },
201
+ {
202
+ "cell_type": "code",
203
+ "execution_count": null,
204
+ "id": "c9bc68a4-5c33-43b6-a9e1-febb035ea2fb",
205
+ "metadata": {},
206
+ "outputs": [],
207
+ "source": [
208
+ "all_labels_concatenated = np.concatenate(list_of_labels, axis=0)\n",
209
+ "\n",
210
+ "# Length should be total number of frames\n",
211
+ "print('Length of all labels concatenated: ', len(all_labels_concatenated))\n",
212
+ "\n",
213
+ "\n",
214
+ "\n",
215
+ "# Map imcomplete instances to complete ones. As this approach only looks at 'background', 'tackle-live' and 'tackle-replay',\n",
216
+ "# the incomplete classes can be mapped to their respective others due to a single frame being part of the tackle whatsoever.\n",
217
+ "class_mapping = {0:0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
218
+ "\n",
219
+ "for i, label in enumerate(all_labels_concatenated):\n",
220
+ " all_labels_concatenated[i] = class_mapping[label]"
221
+ ]
222
+ },
223
+ {
224
+ "cell_type": "markdown",
225
+ "id": "f644964d",
226
+ "metadata": {},
227
+ "source": [
228
+ "# 3. If you downloaded the DINOv2 cls-tokens together with the labels, follow below:"
229
+ ]
230
+ },
231
+ {
232
+ "cell_type": "markdown",
233
+ "id": "ab5f971c",
234
+ "metadata": {},
235
+ "source": [
236
+ "The next cell can be skipped if you completed step 1."
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": null,
242
+ "id": "5e2600aa",
243
+ "metadata": {},
244
+ "outputs": [],
245
+ "source": [
246
+ "\n",
247
+ "# Place the path to your cls tokens and labels downloaded below:\n",
248
+ "cls_path = '/home/evan/D1/project/code/full_concat_dino_features.pt'\n",
249
+ "labels_path = '/home/evan/D1/project/code/all_labels_concatenated.npy'\n",
250
+ "\n",
251
+ "all_cls_tokens = torch.load(cls_path)\n",
252
+ "all_labels_concatenated = np.load(labels_path)\n",
253
+ "\n",
254
+ "# Map imcomplete instances to complete ones. As this approach only looks at 'background', 'tackle-live' and 'tackle-replay',\n",
255
+ "# the incomplete classes can be mapped to their respective others due to a single frame being part of the tackle whatsoever.\n",
256
+ "class_mapping = {0:0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
257
+ "\n",
258
+ "for i, label in enumerate(all_labels_concatenated):\n",
259
+ " all_labels_concatenated[i] = class_mapping[label]"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "id": "01b360a4",
265
+ "metadata": {},
266
+ "source": [
267
+ "# 4. Follow below "
268
+ ]
269
+ },
270
+ {
271
+ "cell_type": "markdown",
272
+ "id": "e4561d68-a149-4a00-9a7d-e0e69bbcfa53",
273
+ "metadata": {},
274
+ "source": [
275
+ "## Balance classes"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "markdown",
280
+ "id": "68e2e245-36d3-464e-85ae-6d5f30ebe164",
281
+ "metadata": {},
282
+ "source": [
283
+ "### Move cls-tokens to CPU"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "execution_count": null,
289
+ "id": "61b8a9fe-d3ac-4d6c-b0a9-5c32a2593495",
290
+ "metadata": {},
291
+ "outputs": [],
292
+ "source": [
293
+ "all_cls_tokens = np.array([e.cpu().numpy() for e in all_cls_tokens])\n",
294
+ "print('Tensor shape after reshaping: ', all_cls_tokens.shape)"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "markdown",
299
+ "id": "b6074527-9ddc-4b9e-b933-a6c5af9cd134",
300
+ "metadata": {},
301
+ "source": [
302
+ "### Verify that order is correct"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": null,
308
+ "id": "ea1425ae-6588-4c71-8a08-7f9c0adc7422",
309
+ "metadata": {},
310
+ "outputs": [],
311
+ "source": [
312
+ "for i in range(len(all_labels_concatenated)):\n",
313
+ " # Should give [0,1,1,0] as 181-107 is the actual sequence, but its moved to 180-206 with array indexing\n",
314
+ " # starting from 0 instead of 1 like the frame counting.\n",
315
+ " if i == 179 or i == 180 or i == 206 or i == 207:\n",
316
+ " print(all_labels_concatenated[i])\n",
317
+ "\n",
318
+ " if i > 210:\n",
319
+ " break"
320
+ ]
321
+ },
322
+ {
323
+ "cell_type": "markdown",
324
+ "id": "6e851954-e2d7-41fd-956f-92df09a79e8b",
325
+ "metadata": {},
326
+ "source": [
327
+ "### Class for balancing distribution of classes"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": null,
333
+ "id": "479daf78-11c0-4ded-9bb3-8fa34d12c6d7",
334
+ "metadata": {},
335
+ "outputs": [],
336
+ "source": [
337
+ "def balance_classes(X, y):\n",
338
+ " unique, counts = np.unique(y, return_counts=True)\n",
339
+ " min_samples = counts.min()\n",
340
+ " # Calculate 2.0 times the minimum sample size, rounded down to the nearest integer\n",
341
+ " # target_samples = int(2.0 * min_samples)\n",
342
+ " target_samples = 5000\n",
343
+ " \n",
344
+ " indices_to_keep = np.hstack([\n",
345
+ " np.random.choice(\n",
346
+ " np.where(y == label)[0], \n",
347
+ " min(target_samples, counts[unique.tolist().index(label)]), # Ensure not to exceed the actual count\n",
348
+ " replace=False\n",
349
+ " ) for label in unique\n",
350
+ " ])\n",
351
+ " \n",
352
+ " return X[indices_to_keep], y[indices_to_keep]"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "markdown",
357
+ "id": "6cf24d79-27d7-499e-b856-e58938cef5e7",
358
+ "metadata": {},
359
+ "source": [
360
+ "### Split into train and test, without shuffle to remain order"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": null,
366
+ "id": "9c9fbaec-2849-48d0-867d-e0ad39682135",
367
+ "metadata": {},
368
+ "outputs": [],
369
+ "source": [
370
+ "X_train, X_test, y_train, y_test = train_test_split(all_cls_tokens, all_labels_concatenated, test_size=0.2, shuffle=False, stratify=None)"
371
+ ]
372
+ },
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": null,
376
+ "id": "35fa46bb-258a-4b6e-a8c0-56c47c791d55",
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "X_train_balanced, y_train_balanced = balance_classes(X_train, y_train)\n",
381
+ "X_test_balanced, y_test_balanced = balance_classes(X_test, y_test)\n",
382
+ "print(\"Total number of samples:\", len(all_labels_concatenated))\n",
383
+ "print(\"\")\n",
384
+ "\n",
385
+ "print('Total distribution of labels: \\n', np.unique(all_labels_concatenated, return_counts=True))\n",
386
+ "print(\"\")\n",
387
+ "\n",
388
+ "\n",
389
+ "print('Distribution within training set: \\n', np.unique(y_train_balanced, return_counts=True))\n",
390
+ "print(\"\")\n",
391
+ "\n",
392
+ "print('Distribution within test set: \\n', np.unique(y_test_balanced, return_counts=True))\n",
393
+ "print(\"\")\n",
394
+ "\n",
395
+ "\n",
396
+ "print('Training shape: ', X_train_balanced.shape, y_train_balanced.shape)\n",
397
+ "print(\"\")\n",
398
+ "\n",
399
+ "print('Test shape: ', X_test_balanced.shape, y_test_balanced.shape)\n",
400
+ "print(\"\")"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "execution_count": null,
406
+ "id": "5b6bf3b4-5d67-41b4-9c6b-8d02d3923366",
407
+ "metadata": {},
408
+ "outputs": [],
409
+ "source": [
410
+ "# Convert data to torch tensors\n",
411
+ "X_train = torch.tensor(X_train_balanced, dtype=torch.float32)\n",
412
+ "y_train = torch.tensor(y_train_balanced, dtype=torch.long)\n",
413
+ "X_test = torch.tensor(X_test_balanced, dtype=torch.float32)\n",
414
+ "y_test = torch.tensor(y_test_balanced, dtype=torch.long)"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "markdown",
419
+ "id": "7d7250f4-c820-4c00-9bde-77bdc3cdd2e2",
420
+ "metadata": {},
421
+ "source": [
422
+ "## Create dataset and Dataloaders"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "id": "532583ed-65e9-4339-b94d-6cdb704c0ed7",
429
+ "metadata": {},
430
+ "outputs": [],
431
+ "source": [
432
+ "# Create data loaders\n",
433
+ "batch_size = 64\n",
434
+ "train_dataset = TensorDataset(X_train, y_train)\n",
435
+ "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
436
+ "\n",
437
+ "test_dataset = TensorDataset(X_test, y_test)\n",
438
+ "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)\n"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "id": "5ef7b5d4-04e1-4c2e-9476-2537a6785893",
444
+ "metadata": {},
445
+ "source": [
446
+ "## Model class"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "id": "d7120ab9-c016-4eba-9588-77afde98a639",
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": [
456
+ "import torch.nn as nn\n",
457
+ "import torch.nn.functional as F\n",
458
+ "\n",
459
+ "class MultiLayerClassifier(nn.Module):\n",
460
+ " def __init__(self, input_size, num_classes):\n",
461
+ " super(MultiLayerClassifier, self).__init__()\n",
462
+ " \n",
463
+ " self.fc1 = nn.Linear(input_size, 128, bias=True)\n",
464
+ " self.dropout1 = nn.Dropout(0.5) \n",
465
+ " \n",
466
+ " # self.fc2 = nn.Linear(512, 128)\n",
467
+ " # self.dropout2 = nn.Dropout(0.5)\n",
468
+ " \n",
469
+ " self.fc3 = nn.Linear(128, num_classes, bias=True)\n",
470
+ " \n",
471
+ " def forward(self, x):\n",
472
+ " x = F.relu(self.fc1(x))\n",
473
+ " x = self.dropout1(x)\n",
474
+ " # x = F.relu(self.fc2(x))\n",
475
+ " # x = self.dropout2(x)\n",
476
+ " x = self.fc3(x)\n",
477
+ " \n",
478
+ " return x\n",
479
+ "\n",
480
+ "model = MultiLayerClassifier(1024, 3)\n",
481
+ "model"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "markdown",
486
+ "id": "5b0ba056-0a73-466f-b65e-a3261e1a69f1",
487
+ "metadata": {},
488
+ "source": [
489
+ "## L1-regularization class"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "id": "ebd6211c-fc94-4557-947b-5a3fac89c1ba",
496
+ "metadata": {},
497
+ "outputs": [],
498
+ "source": [
499
+ "def l1_regularization(model, lambda_l1):\n",
500
+ " l1_penalty = torch.tensor(0.) # Ensure the penalty is on the same device as model parameters\n",
501
+ " for param in model.parameters():\n",
502
+ " l1_penalty += torch.norm(param, 1)\n",
503
+ " return lambda_l1 * l1_penalty"
504
+ ]
505
+ },
506
+ {
507
+ "cell_type": "markdown",
508
+ "id": "00735f1f-2bf9-4aae-90c2-61e44973f699",
509
+ "metadata": {},
510
+ "source": [
511
+ "## Loss, optimizer and L1-strength initialization"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": null,
517
+ "id": "c4efe9d8-fc72-4701-a1a9-d463c6b33dfa",
518
+ "metadata": {},
519
+ "outputs": [],
520
+ "source": [
521
+ "# Loss and optimizer\n",
522
+ "criterion = nn.CrossEntropyLoss()\n",
523
+ "optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5) \n",
524
+ "lambda_l1 = 1e-3 # L1 regularization strength"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "markdown",
529
+ "id": "e87f7513-47d0-491e-9073-9289eda1b484",
530
+ "metadata": {},
531
+ "source": [
532
+ "## Training loop"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": null,
538
+ "id": "4260c3bc-25c2-48f0-b79c-b6d7cc0c14eb",
539
+ "metadata": {},
540
+ "outputs": [],
541
+ "source": [
542
+ "epochs = 50\n",
543
+ "train_losses, test_losses = [], []\n",
544
+ "\n",
545
+ "for epoch in range(epochs):\n",
546
+ " model.train()\n",
547
+ " train_loss = 0\n",
548
+ " for X_batch, y_batch in train_loader:\n",
549
+ " optimizer.zero_grad()\n",
550
+ " outputs = model(X_batch)\n",
551
+ " loss = criterion(outputs, y_batch)\n",
552
+ "\n",
553
+ " # Calculate L1 regularization penalty\n",
554
+ " l1_penalty = l1_regularization(model, lambda_l1)\n",
555
+ " \n",
556
+ " # Add L1 penalty to the loss\n",
557
+ " loss += l1_penalty\n",
558
+ " \n",
559
+ " loss.backward()\n",
560
+ " optimizer.step()\n",
561
+ " train_loss += loss.item()\n",
562
+ " train_losses.append(train_loss / len(train_loader))\n",
563
+ "\n",
564
+ " model.eval()\n",
565
+ " test_loss = 0\n",
566
+ " all_preds, all_targets, all_outputs = [], [], []\n",
567
+ " with torch.no_grad():\n",
568
+ " for X_batch, y_batch in test_loader:\n",
569
+ " outputs = model(X_batch)\n",
570
+ " loss = criterion(outputs, y_batch)\n",
571
+ " test_loss += loss.item()\n",
572
+ " _, predicted = torch.max(outputs.data, 1)\n",
573
+ " all_preds.extend(predicted.numpy())\n",
574
+ " all_targets.extend(y_batch.numpy())\n",
575
+ " all_outputs.extend(outputs.numpy())\n",
576
+ " test_losses.append(test_loss / len(test_loader))\n",
577
+ " \n",
578
+ " precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
579
+ " accuracy = accuracy_score(all_targets, all_preds) # Compute accuracy\n",
580
+ " if epoch % 10==0:\n",
581
+ " print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
582
+ ]
583
+ },
584
+ {
585
+ "cell_type": "markdown",
586
+ "id": "615f685e-fb19-46f8-afba-b76fb730ed49",
587
+ "metadata": {},
588
+ "source": [
589
+ "## Train- vs Test-loss graph"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "code",
594
+ "execution_count": null,
595
+ "id": "597b4570-1579-470e-8f11-f72b7b04b816",
596
+ "metadata": {},
597
+ "outputs": [],
598
+ "source": [
599
+ "plt.plot(train_losses, label='Train Loss')\n",
600
+ "plt.plot(test_losses, label='Test Loss')\n",
601
+ "plt.legend()\n",
602
+ "plt.title('Train vs Test Loss')\n",
603
+ "plt.xlabel('Epoch')\n",
604
+ "plt.ylabel('Loss')\n",
605
+ "plt.show()"
606
+ ]
607
+ },
608
+ {
609
+ "cell_type": "markdown",
610
+ "id": "1babe3bd-da5b-4f0d-9d83-9ca4d73922c5",
611
+ "metadata": {},
612
+ "source": [
613
+ "## Confusion matrix"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": null,
619
+ "id": "2c0b0fa3-814e-474c-bbe1-31152305e17b",
620
+ "metadata": {},
621
+ "outputs": [],
622
+ "source": [
623
+ "conf_matrix = confusion_matrix(all_targets, all_preds)\n",
624
+ "labels = [\"background\", \"tackle-live\", \"tackle-replay\",]\n",
625
+ " # \"tackle-live-incomplete\", \"tackle-replay-incomplete\"]\n",
626
+ "sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)\n",
627
+ "# plt.title('Confusion Matrix')\n",
628
+ "plt.xlabel('Predicted Label')\n",
629
+ "plt.ylabel('True Label')\n",
630
+ "plt.show()"
631
+ ]
632
+ },
633
+ {
634
+ "cell_type": "markdown",
635
+ "id": "480ddfd5-6ac4-46ed-92db-b556c8bfbd7d",
636
+ "metadata": {},
637
+ "source": [
638
+ "## ROC Curve"
639
+ ]
640
+ },
641
+ {
642
+ "cell_type": "code",
643
+ "execution_count": null,
644
+ "id": "ddc52d39-7612-43ad-ae44-345119122112",
645
+ "metadata": {},
646
+ "outputs": [],
647
+ "source": [
648
+ "from sklearn.metrics import roc_curve, auc\n",
649
+ "import matplotlib.pyplot as plt\n",
650
+ "\n",
651
+ "y_score= np.array(all_outputs)\n",
652
+ "fpr = dict()\n",
653
+ "tpr = dict()\n",
654
+ "roc_auc = dict()\n",
655
+ "n_classes = len(labels) \n",
656
+ "\n",
657
+ "y_test_one_hot = np.eye(n_classes)[y_test]\n",
658
+ "\n",
659
+ "for i in range(n_classes):\n",
660
+ " fpr[i], tpr[i], _ = roc_curve(y_test_one_hot[:, i], y_score[:, i])\n",
661
+ " roc_auc[i] = auc(fpr[i], tpr[i])\n",
662
+ "\n",
663
+ "# Plot all ROC curves\n",
664
+ "plt.figure()\n",
665
+ "colors = ['blue', 'red', 'green', 'darkorange', 'purple']\n",
666
+ "for i, color in zip(range(n_classes), colors):\n",
667
+ " plt.plot(fpr[i], tpr[i], color=color, lw=2,\n",
668
+ " label='ROC curve of class {0} (area = {1:0.2f})'\n",
669
+ " ''.format(labels[i], roc_auc[i]))\n",
670
+ "\n",
671
+ "plt.plot([0, 1], [0, 1], 'k--', lw=2)\n",
672
+ "plt.xlim([0.0, 1.0])\n",
673
+ "plt.ylim([0.0, 1.05])\n",
674
+ "plt.xlabel('False Positive Rate')\n",
675
+ "plt.ylabel('True Positive Rate')\n",
676
+ "print('Receiver operating characteristic for multi-class')\n",
677
+ "plt.legend(loc=\"lower right\")\n",
678
+ "plt.show()\n"
679
+ ]
680
+ },
681
+ {
682
+ "cell_type": "markdown",
683
+ "id": "45c05c14-99d8-49e6-ad64-7e6ad565c0ca",
684
+ "metadata": {},
685
+ "source": [
686
+ "## Multi-Class Precision-Recall Cruve"
687
+ ]
688
+ },
689
+ {
690
+ "cell_type": "code",
691
+ "execution_count": null,
692
+ "id": "3c779274-252f-4248-bf57-a07c665c618c",
693
+ "metadata": {},
694
+ "outputs": [],
695
+ "source": [
696
+ "from sklearn.metrics import precision_recall_curve\n",
697
+ "from sklearn.preprocessing import label_binarize\n",
698
+ "from itertools import cycle\n",
699
+ "\n",
700
+ "y_test_bin = label_binarize(y_test, classes=range(n_classes))\n",
701
+ "\n",
702
+ "precision_recall = {}\n",
703
+ "\n",
704
+ "for i in range(n_classes):\n",
705
+ " precision, recall, _ = precision_recall_curve(y_test_bin[:, i], y_score[:, i])\n",
706
+ " precision_recall[i] = (precision, recall)\n",
707
+ "\n",
708
+ "colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])\n",
709
+ "\n",
710
+ "plt.figure(figsize=(6, 4))\n",
711
+ "\n",
712
+ "for i, color in zip(range(n_classes), colors):\n",
713
+ " precision, recall = precision_recall[i]\n",
714
+ " plt.plot(recall, precision, color=color, lw=2, label=f'{labels[i]}')\n",
715
+ "\n",
716
+ "plt.xlabel('Recall')\n",
717
+ "plt.ylabel('Precision')\n",
718
+ "print('Multi-Class Precision-Recall Curve')\n",
719
+ "plt.legend(loc='best')\n",
720
+ "plt.show()"
721
+ ]
722
+ }
723
+ ],
724
+ "metadata": {
725
+ "kernelspec": {
726
+ "display_name": "Python (evan31818)",
727
+ "language": "python",
728
+ "name": "evan31818"
729
+ },
730
+ "language_info": {
731
+ "codemirror_mode": {
732
+ "name": "ipython",
733
+ "version": 3
734
+ },
735
+ "file_extension": ".py",
736
+ "mimetype": "text/x-python",
737
+ "name": "python",
738
+ "nbconvert_exporter": "python",
739
+ "pygments_lexer": "ipython3",
740
+ "version": "3.8.18"
741
+ }
742
+ },
743
+ "nbformat": 4,
744
+ "nbformat_minor": 5
745
+ }