{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Start to finish - DINOv2 feature extraction" ] }, { "cell_type": "markdown", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "3AdjGBwjnr-5" }, "outputs": [], "source": [ "from transformers import AutoImageProcessor, AutoModel\n", "from PIL import Image\n", "\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import requests\n", "import torch\n", "import cv2\n", "import os" ] }, { "cell_type": "markdown", "metadata": { "id": "qvTYvSVOkLLL" }, "source": [ "## Initialize pre-trained image processor and model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aRlCk-Tlj8Iv", "outputId": "fb51843c-598f-48ad-a1c0-cf8d9bab53f4", "scrolled": true }, "outputs": [], "source": [ "# Adjust for cuda - takes up 2193 MiB on device\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "\n", "processor = AutoImageProcessor.from_pretrained('facebook/dinov2-large')\n", "model = AutoModel.from_pretrained('facebook/dinov2-large').to(device)" ] }, { "cell_type": "markdown", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## DINOv2 Feature Extraction" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from tqdm import tqdm\n", "import gc\n", "\n", "torch.cuda.empty_cache() \n", "gc.collect()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Crq7KD84qz5d" }, "outputs": [], "source": [ "# Path to your videos\n", "path_to_videos = './dataset-tacdec/videos'\n", "\n", "# Directory paths\n", "processed_features_dir = './processed_features'\n", "last_hidden_states_dir = os.path.join(processed_features_dir, 'last_hidden_states/')\n", "pooler_outputs_dir = os.path.join(processed_features_dir, 'pooler_outputs/')\n", "\n", "# Create directories if they don't exist\n", "os.makedirs(last_hidden_states_dir, exist_ok=True)\n", "os.makedirs(pooler_outputs_dir, exist_ok=True)\n", "\n", "# Dictonary with filename as key, all feature extracted frames as values\n", "feature_extracted_videos = {}\n", "\n", "# Define batch size\n", "batch_size = 32\n", "\n", "# Process each video\n", "for video_file in tqdm(os.listdir(path_to_videos)):\n", " full_path = os.path.join(path_to_videos, video_file)\n", "\n", " if not os.path.isfile(full_path):\n", " continue\n", "\n", " cap = cv2.VideoCapture(full_path)\n", "\n", " # List to hold all batch outputs, clear for each video\n", " batch_last_hidden_states = []\n", " batch_pooler_outputs = []\n", " \n", " batch_frames = []\n", "\n", " while True:\n", " ret, frame = cap.read()\n", " if not ret:\n", " \n", " # Process the last batch\n", " if len(batch_frames) > 0:\n", " inputs = processor(images=batch_frames, return_tensors=\"pt\").to(device)\n", " \n", " with torch.no_grad():\n", " outputs = model(**inputs)\n", " \n", " for key, value in outputs.items():\n", " if key == 'last_hidden_state':\n", " # batch_last_hidden_states.append(value.cpu().numpy())\n", " batch_last_hidden_states.append(value)\n", " elif key == 'pooler_output':\n", " # batch_pooler_outputs.append(value.cpu().numpy())\n", " batch_pooler_outputs.append(value)\n", " else:\n", " print('Error in key, expected last_hidden_state or pooler_output, got: ', key)\n", " break\n", "\n", " # cv2 comes in BGR, but transformer takes RGB\n", " frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)\n", " batch_frames.append(frame_rgb)\n", "\n", " # Check if batch is full\n", " if len(batch_frames) == batch_size:\n", " inputs = processor(images=batch_frames, return_tensors=\"pt\").to(device)\n", " # outputs = model(**inputs)\n", " with torch.no_grad():\n", " outputs = model(**inputs)\n", " for key, value in outputs.items():\n", " if key == 'last_hidden_state':\n", " batch_last_hidden_states.append(value)\n", " elif key == 'pooler_output':\n", " batch_pooler_outputs.append(value)\n", " else:\n", " print('Error in key, expected last_hidden_state or pooler_output, got: ', key)\n", "\n", " # Clear batch\n", " batch_frames = []\n", "\n", " \n", " all_last_hidden_states = torch.cat(batch_last_hidden_states, dim=0)\n", " all_pooler_outputs = torch.cat(batch_pooler_outputs, dim=0)\n", "\n", " # Save the tensors with the video name as filename\n", " pt_filename = video_file.replace('.mp4', '.pt')\n", " torch.save(all_last_hidden_states, os.path.join(last_hidden_states_dir, f'{pt_filename}'))\n", " torch.save(all_pooler_outputs, os.path.join(pooler_outputs_dir, f'{pt_filename}'))\n", " \n", "print('Features extracted')" ] }, { "cell_type": "markdown", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## Reload features to verify " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "lhs_torch = torch.load('./processed_features/last_hidden_states/1738_avxeiaxxw6ocr.pt')\n", "po_torch = torch.load('./processed_features/pooler_outputs/1738_avxeiaxxw6ocr.pt')\n", "\n", "print('LHS Torch size: ', lhs_torch.size())\n", "print('PO Torch size: ', po_torch.size())\n", "\n", "for i in range(all_last_hidden_states.size(0)):\n", " print(f\"Frame {i}:\")\n", " print(all_last_hidden_states[i])\n", " print() \n", " break\n", "\n", "for i in range(lhs_torch.size(0)):\n", " print(f\"Frame {i}:\")\n", " print(all_last_hidden_states[i])\n", " print() \n", " break\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Different sorts of plots" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Histogram of video length in seconds" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import os\n", "import cv2\n", "import numpy as np\n", "\n", "path_to_videos = './dataset-tacdec/videos'\n", "video_lengths = []\n", "frame_counts = []\n", "\n", "# Iterate through each file in the directory\n", "for video_file in os.listdir(path_to_videos):\n", " full_path = os.path.join(path_to_videos, video_file)\n", "\n", " if not os.path.isfile(full_path):\n", " continue\n", "\n", " cap = cv2.VideoCapture(full_path)\n", "\n", " # Calculate the length of the video\n", " # Note: Assuming the frame rate information is accurate\n", " if cap.isOpened():\n", " fps = cap.get(cv2.CAP_PROP_FPS) # Frame rate\n", " frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", " duration = frame_count / fps if fps > 0 else 0\n", " video_lengths.append(duration)\n", " frame_counts.append(frame_count)\n", "\n", " cap.release()\n", "\n", "np.save('./video_durations', video_lengths)\n", "np.save('./frame_counts', frame_counts)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import seaborn as sns\n", "\n", "# Set the aesthetic style of the plots\n", "sns.set(style=\"darkgrid\")\n", "\n", "# Plotting the histogram for video lengths\n", "plt.figure(figsize=(12, 6))\n", "sns.histplot(video_lengths, kde=True, color=\"blue\")\n", "plt.title('Histogram - Video Lengths')\n", "plt.xlabel('Length of Videos (seconds)')\n", "plt.ylabel('Number of Videos')\n", "\n", "# Plotting the histogram for frame counts\n", "plt.figure(figsize=(12, 6))\n", "sns.histplot(frame_counts, kde=True, color=\"green\")\n", "plt.title('Histogram - Number of Frames')\n", "plt.xlabel('Frame Count')\n", "plt.ylabel('Number of Videos')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "jp-MarkdownHeadingCollapsed": true }, "source": [ "## Frame count and vid lengths" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.boxplot(x=video_lengths)\n", "plt.title('Box Plot of Video Lengths')\n", "plt.xlabel('Video Length (seconds)')\n", "plt.show()\n", "\n", "sns.boxplot(x=frame_counts, color=\"r\")\n", "plt.title('Box Plot of Frame Counts')\n", "plt.xlabel('Frame Count')\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Class distributions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [ "import os\n", "import json\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "path_to_labels = './dataset-tacdec/full_labels'\n", "class_counts = {'background': 0, 'tackle-live': 0, 'tackle-replay': 0, 'tackle-live-incomplete': 0, 'tackle-replay-incomplete': 0, 'dummy_class': 0}\n", "\n", "# Iterate through each JSON file in the labels directory\n", "for label_file in os.listdir(path_to_labels):\n", " full_path = os.path.join(path_to_labels, label_file)\n", "\n", " if not os.path.isfile(full_path):\n", " continue\n", "\n", " with open(full_path, 'r') as file:\n", " data = json.load(file)\n", " frame_sections = data['frames_sections']\n", "\n", " # Extract annotations\n", " for section in frame_sections:\n", " for frame_number, frame_data in section.items():\n", " class_label = frame_data['radio_answer']\n", " if class_label in class_counts:\n", " class_counts[class_label] += 1\n", "\n", "# Convert the dictionary to a DataFrame for Seaborn\n", "df_class_counts = pd.DataFrame(list(class_counts.items()), columns=['Class', 'Occurrences'])\n", "\n", "# Save the DataFrame to a CSV file\n", "df_class_counts.to_csv('class_distribution.csv', sep=',', index=False, encoding='utf-8')\n", "\n", "# Plotting the distribution using Seaborn\n", "plt.figure(figsize=(10, 6))\n", "sns.barplot(x='Class', y='Occurrences', data=df_class_counts, palette='viridis', alpha=0.75)\n", "plt.title('Distribution of Frame Classes')\n", "plt.xlabel('Class')\n", "plt.ylabel('Number of Occurrences')\n", "plt.xticks(rotation=45) # Rotate class names for better readability\n", "plt.tight_layout() # Adjust layout to make room for the rotated x-axis labels\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "# Ensure df_class_counts is already created as in the previous script\n", "\n", "# Create a pie chart\n", "plt.figure(figsize=(8, 8))\n", "plt.pie(df_class_counts['Occurrences'], labels=df_class_counts['Class'], \n", " autopct=lambda p: '{:.1f}%'.format(p), startangle=140, \n", " colors=sns.color_palette('bright', len(df_class_counts)))\n", "plt.title('Distribution of Frame Classes', fontweight='bold')\n", "plt.show()" ] } ], "metadata": { "colab": { "collapsed_sections": [ "uzdIsbuEpF2w" ], "provenance": [] }, "kernelspec": { "display_name": "Python (evan31818)", "language": "python", "name": "evan31818" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 0 }