{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "27933625-f946-4fce-a622-e92ea518fad1",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "# 1. Mandatory: imports and configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8674dce1-4885-4bc9-8b90-1d847c38e6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from itertools import cycle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from sklearn.metrics import (accuracy_score, auc, confusion_matrix, precision_recall_curve,\n",
    "                             precision_recall_fscore_support, roc_curve)\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import label_binarize\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# Fix the random seeds so class balancing, weight initialisation and batch shuffling are reproducible\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Only 'background', 'tackle-live' and 'tackle-replay' are classified: the incomplete classes (3, 4)\n",
    "# are mapped to their complete counterparts (1, 2), since every such frame is still part of a tackle.\n",
    "class_mapping = {0: 0, 1: 1, 2: 2, 3: 1, 4: 2}\n",
    "map_to_complete_classes = np.vectorize(class_mapping.__getitem__, otypes=[np.int64])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46a4597f",
   "metadata": {},
   "source": [
    "# 2. Only if you did not download the DINOv2 CLS tokens together with the labels: complete the steps below (otherwise skip to step 3)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f1bd72b-ed98-4669-908c-2b103bcacda5",
   "metadata": {},
   "source": [
    "## Load labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98e09803-9862-4e29-aaff-3bdcd4e0fe53",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path to the directory of JSON label files; adjust to your setup\n",
    "path_to_labels = '/home/evan/D1/project/code/start_end_labels'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b41d5fd2-ee4a-4f02-98b9-887e48115c47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Should be 425 files, code just to verify\n",
    "num_of_labels = len(os.listdir(path_to_labels))\n",
    "num_of_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ef791d8-a268-4436-ad18-150d645bef73",
   "metadata": {},
   "outputs": [],
   "source": [
    "categorical_mapping = {'background': 0, 'tackle-live': 1, 'tackle-replay': 2, 'tackle-live-incomplete': 3, 'tackle-replay-incomplete': 4}\n",
    "\n",
    "list_of_labels = []\n",
    "\n",
    "# Sort to make sure order is maintained (it must match the sorted order of the feature files)\n",
    "for label_file in sorted(os.listdir(path_to_labels)):\n",
    "    full_path = os.path.join(path_to_labels, label_file)\n",
    "\n",
    "    with open(full_path, 'r') as fh:\n",
    "        data = json.load(fh)\n",
    "\n",
    "    frame_count = data['media_attributes']['frame_count']\n",
    "    tackles = data['events']\n",
    "\n",
    "    # One label per frame; frames not covered by any event stay background (0)\n",
    "    labels_of_current_file = np.zeros(frame_count)\n",
    "\n",
    "    for tackle in tackles:\n",
    "        start_frame = tackle['frame_start']\n",
    "        end_frame = tackle['frame_end']\n",
    "        # A start frame of 0 would silently label the last frame via index -1, so fail loudly instead\n",
    "        if start_frame < 1 or end_frame > frame_count:\n",
    "            raise ValueError(f'{label_file}: event frames {start_frame}-{end_frame} fall outside 1-{frame_count}')\n",
    "\n",
    "        # Frame numbers are 1-based and inclusive while array indices are 0-based:\n",
    "        # frames start..end map to indices start-1..end-1, i.e. the slice [start-1, end)\n",
    "        labels_of_current_file[start_frame - 1:end_frame] = categorical_mapping[tackle['type']]\n",
    "\n",
    "    list_of_labels.append(labels_of_current_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b302d94a-d18c-4e41-929b-3c8f4d547afa",
   "metadata": {},
   "source": [
    "## Verify that the 1-based to 0-based frame shift is correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "286b27a8-1c9a-4ba9-9996-deeef7927195",
   "metadata": {},
   "outputs": [],
   "source": [
    "first_file_labels = list_of_labels[0]\n",
    "\n",
    "# Should give [0, 1, 1, 0]: the tackle spans frames 181-207 (1-based), which are array\n",
    "# indices 180-206 because array indexing starts at 0 instead of 1 like the frame counting.\n",
    "first_file_labels[[179, 180, 206, 207]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "88650952-a098-4ae3-ba3b-d67f5d17c41b",
   "metadata": {},
   "source": [
    "## Map incomplete class labels to their respective complete classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c48db00-b367-4f38-aa59-de5164d11fe9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the unmapped labels for reference; a new list is built so they are not overwritten\n",
    "prev_list_of_labels = list_of_labels\n",
    "list_of_labels = [map_to_complete_classes(labels_of_file) for labels_of_file in list_of_labels]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ee69c1f0-db9d-4848-9b3c-2556e09d1991",
   "metadata": {},
   "source": [
    "## Load DINOv2 features and extract CLS tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20b2ee27-5d94-4301-9229-aa9486360a73",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define path to DINOv2 features (one .pt file per label file)\n",
    "path_to_tensors = '/home/evan/D1/project/code/processed_features/last_hidden_states'\n",
    "\n",
    "# Sorted so the feature order matches the sorted label files. Token 0 is the CLS token;\n",
    "# .clone() copies it out so the full hidden-state tensor of each file can be freed.\n",
    "cls_token_chunks = [\n",
    "    torch.load(os.path.join(path_to_tensors, tensor_file), map_location='cpu')[:, 0, :].clone()\n",
    "    for tensor_file in sorted(os.listdir(path_to_tensors))\n",
    "]\n",
    "# Concatenate once instead of growing a tensor inside the loop (quadratic copying)\n",
    "all_cls_tokens = torch.cat(cls_token_chunks, dim=0)\n",
    "\n",
    "# Should have shape: total_frames, feature_vector (1024)\n",
    "print('CLS tokens shape: ', all_cls_tokens.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03c8f5ed-5b04-456d-a9fd-8d493878ea18",
   "metadata": {},
   "source": [
    "### Concatenate the per-file labels into one array"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9bc68a4-5c33-43b6-a9e1-febb035ea2fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Incomplete classes were already mapped to complete ones in the cell above\n",
    "all_labels_concatenated = np.concatenate(list_of_labels, axis=0)\n",
    "\n",
    "# Length should be total number of frames, i.e. exactly one label per CLS token\n",
    "print('Length of all labels concatenated: ', len(all_labels_concatenated))\n",
    "assert len(all_labels_concatenated) == len(all_cls_tokens), 'Labels and CLS tokens are misaligned'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f644964d",
   "metadata": {},
   "source": [
    "# 3. Only if you downloaded the DINOv2 CLS tokens together with the labels: follow below"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab5f971c",
   "metadata": {},
   "source": [
    "The next cell can be skipped if you completed step 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e2600aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Place the paths to your downloaded CLS tokens and labels below:\n",
    "cls_path = '/home/evan/D1/project/code/full_concat_dino_features.pt'\n",
    "labels_path = '/home/evan/D1/project/code/all_labels_concatenated.npy'\n",
    "\n",
    "all_cls_tokens = torch.load(cls_path, map_location='cpu')\n",
    "\n",
    "# Map incomplete classes to their complete counterparts (class_mapping is defined in step 1)\n",
    "all_labels_concatenated = map_to_complete_classes(np.load(labels_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01b360a4",
   "metadata": {},
   "source": [
    "# 4. Follow below (both paths continue here)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4561d68-a149-4a00-9a7d-e0e69bbcfa53",
   "metadata": {},
   "source": [
    "## Balance classes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68e2e245-36d3-464e-85ae-6d5f30ebe164",
   "metadata": {},
   "source": [
    "### Move CLS tokens to CPU and convert them to NumPy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b8a9fe-d3ac-4d6c-b0a9-5c32a2593495",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Guarded so that re-running the cell is harmless once the tokens are already a NumPy array\n",
    "if torch.is_tensor(all_cls_tokens):\n",
    "    all_cls_tokens = all_cls_tokens.detach().cpu().numpy()\n",
    "elif not isinstance(all_cls_tokens, np.ndarray):\n",
    "    # Fallback for a sequence of per-frame tensors\n",
    "    all_cls_tokens = np.array([e.detach().cpu().numpy() for e in all_cls_tokens])\n",
    "print('CLS tokens shape after conversion: ', all_cls_tokens.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6074527-9ddc-4b9e-b933-a6c5af9cd134",
   "metadata": {},
   "source": [
    "### Verify that the label order is correct"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea1425ae-6588-4c71-8a08-7f9c0adc7422",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Should give [0, 1, 1, 0]: the first file's tackle spans frames 181-207 (1-based), which are array\n",
    "# indices 180-206 because array indexing starts at 0 instead of 1 like the frame counting.\n",
    "all_labels_concatenated[[179, 180, 206, 207]]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e851954-e2d7-41fd-956f-92df09a79e8b",
   "metadata": {},
   "source": [
    "### Function for balancing the class distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "479daf78-11c0-4ded-9bb3-8fa34d12c6d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def balance_classes(X, y, target_samples=7500, rng=None):\n",
    "    \"\"\"Randomly undersample each class to at most `target_samples` samples.\n",
    "\n",
    "    Classes with fewer samples are kept in full (no oversampling), so the result is only\n",
    "    exactly balanced when every class has at least `target_samples` samples. The returned\n",
    "    samples are grouped by class, so shuffle them before training (the train DataLoader does).\n",
    "\n",
    "    Args:\n",
    "        X: Array of samples, indexed along axis 0 in the same order as `y`.\n",
    "        y: 1-D array of class labels.\n",
    "        target_samples: Maximum number of samples kept per class.\n",
    "        rng: Object with a NumPy-style `choice` method (e.g. `np.random.default_rng(0)`);\n",
    "            defaults to the global `np.random` state, which is seeded in step 1.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(X_balanced, y_balanced)`.\n",
    "    \"\"\"\n",
    "    rng = np.random if rng is None else rng\n",
    "    classes, counts = np.unique(y, return_counts=True)\n",
    "    if classes.size == 0:\n",
    "        return X[:0], y[:0]\n",
    "\n",
    "    indices_to_keep = np.hstack([\n",
    "        # Never request more samples than the class actually has\n",
    "        rng.choice(np.flatnonzero(y == label), min(target_samples, count), replace=False)\n",
    "        for label, count in zip(classes, counts)\n",
    "    ])\n",
    "\n",
    "    return X[indices_to_keep], y[indices_to_keep]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6cf24d79-27d7-499e-b856-e58938cef5e7",
   "metadata": {},
   "source": [
    "### Split into train and test without shuffling, to preserve frame order"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c9fbaec-2849-48d0-867d-e0ad39682135",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(all_cls_tokens, all_labels_concatenated, test_size=0.2, shuffle=False, stratify=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35fa46bb-258a-4b6e-a8c0-56c47c791d55",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_balanced, y_train_balanced = balance_classes(X_train, y_train)\n",
    "X_test_balanced, y_test_balanced = balance_classes(X_test, y_test)\n",
    "print('Total number of samples:', len(all_labels_concatenated))\n",
    "print('')\n",
    "\n",
    "print('Total distribution of labels: \\n', np.unique(all_labels_concatenated, return_counts=True))\n",
    "print('')\n",
    "\n",
    "print('Distribution within training set: \\n', np.unique(y_train_balanced, return_counts=True))\n",
    "print('')\n",
    "\n",
    "print('Distribution within test set: \\n', np.unique(y_test_balanced, return_counts=True))\n",
    "print('')\n",
    "\n",
    "print('Training shape: ', X_train_balanced.shape, y_train_balanced.shape)\n",
    "print('')\n",
    "\n",
    "print('Test shape: ', X_test_balanced.shape, y_test_balanced.shape)\n",
    "print('')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b6bf3b4-5d67-41b4-9c6b-8d02d3923366",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert data to torch tensors\n",
    "X_train = torch.tensor(X_train_balanced, dtype=torch.float32)\n",
    "y_train = torch.tensor(y_train_balanced, dtype=torch.long)\n",
    "X_test = torch.tensor(X_test_balanced, dtype=torch.float32)\n",
    "y_test = torch.tensor(y_test_balanced, dtype=torch.long)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d7250f4-c820-4c00-9bde-77bdc3cdd2e2",
   "metadata": {},
   "source": [
    "## Create datasets and DataLoaders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "532583ed-65e9-4339-b94d-6cdb704c0ed7",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Shuffle the training batches (the balanced samples are grouped by class)\n",
    "train_dataset = TensorDataset(X_train, y_train)\n",
    "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "\n",
    "# Keep the test order fixed so predictions line up with the evaluation targets\n",
    "test_dataset = TensorDataset(X_test, y_test)\n",
    "test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5ef7b5d4-04e1-4c2e-9476-2537a6785893",
   "metadata": {},
   "source": [
    "## Model class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7120ab9-c016-4eba-9588-77afde98a639",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultiLayerClassifier(nn.Module):\n",
    "    \"\"\"MLP with one hidden layer that classifies a single CLS-token feature vector.\n",
    "\n",
    "    Args:\n",
    "        input_size: Length of the input feature vector.\n",
    "        num_classes: Number of output classes (logits).\n",
    "        hidden_size: Width of the hidden layer.\n",
    "        dropout: Dropout probability applied after the hidden layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, input_size, num_classes, hidden_size=128, dropout=0.5):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size, bias=True)\n",
    "        self.dropout1 = nn.Dropout(dropout)\n",
    "        # Kept as fc3 (a former fc2 layer was removed) so parameter names stay unchanged\n",
    "        self.fc3 = nn.Linear(hidden_size, num_classes, bias=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return unnormalised class logits for a batch of feature vectors.\"\"\"\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = self.dropout1(x)\n",
    "        return self.fc3(x)\n",
    "\n",
    "\n",
    "model = MultiLayerClassifier(1024, 3)\n",
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b0ba056-0a73-466f-b65e-a3261e1a69f1",
   "metadata": {},
   "source": [
    "## L1-regularization helper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ebd6211c-fc94-4557-947b-5a3fac89c1ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "def l1_regularization(model, lambda_l1):\n",
    "    \"\"\"Return `lambda_l1` times the L1 norm of all model parameters (weights and biases).\n",
    "\n",
    "    The sum is built from the parameters themselves, so the penalty ends up on the same\n",
    "    device as the model and stays differentiable.\n",
    "    \"\"\"\n",
    "    l1_penalty = sum(param.abs().sum() for param in model.parameters())\n",
    "    return lambda_l1 * l1_penalty"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00735f1f-2bf9-4aae-90c2-61e44973f699",
   "metadata": {},
   "source": [
    "## Loss, optimizer and L1-strength initialization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4efe9d8-fc72-4701-a1a9-d463c6b33dfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss and optimizer (weight_decay adds an L2 term on top of the explicit L1 penalty)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)\n",
    "lambda_l1 = 1e-5  # L1 regularization strength"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e87f7513-47d0-491e-9073-9289eda1b484",
   "metadata": {},
   "source": [
    "## Training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4260c3bc-25c2-48f0-b79c-b6d7cc0c14eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 10\n",
    "train_losses, test_losses = [], []\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    model.train()\n",
    "    train_loss = 0\n",
    "    for X_batch, y_batch in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(X_batch)\n",
    "        data_loss = criterion(outputs, y_batch)\n",
    "\n",
    "        # Optimise the data loss plus the L1 penalty...\n",
    "        loss = data_loss + l1_regularization(model, lambda_l1)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # ...but only track the data loss, so the train curve is comparable to the test curve\n",
    "        train_loss += data_loss.item()\n",
    "    train_losses.append(train_loss / len(train_loader))\n",
    "\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    all_preds, all_targets, all_outputs = [], [], []\n",
    "    with torch.no_grad():\n",
    "        for X_batch, y_batch in test_loader:\n",
    "            outputs = model(X_batch)\n",
    "            test_loss += criterion(outputs, y_batch).item()\n",
    "            predicted = outputs.argmax(dim=1)\n",
    "            all_preds.extend(predicted.numpy())\n",
    "            all_targets.extend(y_batch.numpy())\n",
    "            all_outputs.extend(outputs.numpy())\n",
    "    test_losses.append(test_loss / len(test_loader))\n",
    "\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(all_targets, all_preds, average='weighted', zero_division=0)\n",
    "    accuracy = accuracy_score(all_targets, all_preds)\n",
    "    # Report every other epoch and always the final one\n",
    "    if epoch % 2 == 0 or epoch == epochs - 1:\n",
    "        print(f'Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}, Accuracy: {accuracy:.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "615f685e-fb19-46f8-afba-b76fb730ed49",
   "metadata": {},
   "source": [
    "## Train vs. test loss graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "597b4570-1579-470e-8f11-f72b7b04b816",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Epochs are numbered from 1, matching the training log\n",
    "epoch_numbers = range(1, len(train_losses) + 1)\n",
    "plt.plot(epoch_numbers, train_losses, label='Train Loss')\n",
    "plt.plot(epoch_numbers, test_losses, label='Test Loss')\n",
    "plt.legend()\n",
    "plt.title('Train vs Test Loss')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1babe3bd-da5b-4f0d-9d83-9ca4d73922c5",
   "metadata": {},
   "source": [
    "## Confusion matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c0b0fa3-814e-474c-bbe1-31152305e17b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(np.unique(all_targets, return_counts=True))\n",
    "print(np.unique(all_preds, return_counts=True))\n",
    "\n",
    "labels = ['background', 'tackle-live', 'tackle-replay']\n",
    "# Pass the labels explicitly so the matrix stays 3x3 (matching the tick labels)\n",
    "# even if a class never occurs in the targets or predictions\n",
    "conf_matrix = confusion_matrix(all_targets, all_preds, labels=list(range(len(labels))))\n",
    "sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)\n",
    "plt.xlabel('Predicted Label')\n",
    "plt.ylabel('True Label')\n",
    "plt.show()\n",
    "\n",
    "\n",
    "def showClassWiseAcc(conf_matrix, class_names=None):\n",
    "    \"\"\"Print per-class accuracy (recall): correct predictions / true samples of each class.\n",
    "\n",
    "    Classes without any true samples are reported as nan, without a divide-by-zero warning.\n",
    "    `class_names` defaults to the class indices.\n",
    "    \"\"\"\n",
    "    support = conf_matrix.sum(axis=1)\n",
    "    with np.errstate(divide='ignore', invalid='ignore'):\n",
    "        class_accuracies = conf_matrix.diagonal() / support\n",
    "\n",
    "    names = class_names if class_names is not None else range(len(class_accuracies))\n",
    "    print('\\n'.join(f'Accuracy for class {name}: {acc:.4f}' for name, acc in zip(names, class_accuracies)))\n",
    "\n",
    "\n",
    "showClassWiseAcc(conf_matrix, labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "480ddfd5-6ac4-46ed-92db-b556c8bfbd7d",
   "metadata": {},
   "source": [
    "## ROC curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ddc52d39-7612-43ad-ae44-345119122112",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw logits from the final evaluation pass serve as per-class scores\n",
    "y_score = np.array(all_outputs)\n",
    "y_true = np.array(all_targets)\n",
    "n_classes = len(labels)\n",
    "\n",
    "# One-vs-rest targets, aligned with y_score because both come from the same evaluation pass\n",
    "y_test_one_hot = np.eye(n_classes)[y_true]\n",
    "\n",
    "fpr, tpr, roc_auc = {}, {}, {}\n",
    "for i in range(n_classes):\n",
    "    fpr[i], tpr[i], _ = roc_curve(y_test_one_hot[:, i], y_score[:, i])\n",
    "    roc_auc[i] = auc(fpr[i], tpr[i])\n",
    "\n",
    "# Plot all ROC curves\n",
    "plt.figure()\n",
    "colors = ['blue', 'red', 'green', 'darkorange', 'purple']\n",
    "for i, color in zip(range(n_classes), colors):\n",
    "    plt.plot(fpr[i], tpr[i], color=color, lw=2,\n",
    "             label=f'ROC curve of class {labels[i]} (area = {roc_auc[i]:0.2f})')\n",
    "\n",
    "plt.plot([0, 1], [0, 1], 'k--', lw=2)\n",
    "plt.xlim([0.0, 1.0])\n",
    "plt.ylim([0.0, 1.05])\n",
    "plt.xlabel('False Positive Rate')\n",
    "plt.ylabel('True Positive Rate')\n",
    "print('Receiver operating characteristic for multi-class')\n",
    "plt.legend(loc='lower right')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "45c05c14-99d8-49e6-ad64-7e6ad565c0ca",
   "metadata": {},
   "source": [
    "## Multi-class precision-recall curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c779274-252f-4248-bf57-a07c665c618c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reuses y_score and n_classes from the ROC cell; targets come from the same evaluation pass\n",
    "y_test_bin = label_binarize(np.array(all_targets), classes=list(range(n_classes)))\n",
    "\n",
    "# Distinct names so the epoch-level precision/recall metrics are not overwritten\n",
    "precision_recall = {}\n",
    "for i in range(n_classes):\n",
    "    class_precision, class_recall, _ = precision_recall_curve(y_test_bin[:, i], y_score[:, i])\n",
    "    precision_recall[i] = (class_precision, class_recall)\n",
    "\n",
    "colors = cycle(['navy', 'turquoise', 'darkorange', 'cornflowerblue', 'teal'])\n",
    "\n",
    "plt.figure(figsize=(6, 4))\n",
    "for i, color in zip(range(n_classes), colors):\n",
    "    class_precision, class_recall = precision_recall[i]\n",
    "    plt.plot(class_recall, class_precision, color=color, lw=2, label=labels[i])\n",
    "\n",
    "plt.xlabel('Recall')\n",
    "plt.ylabel('Precision')\n",
    "print('Multi-Class Precision-Recall Curve')\n",
    "plt.legend(loc='best')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python (evan31818)",
   "language": "python",
   "name": "evan31818"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}