{ "cells": [ { "cell_type": "code", "execution_count": 15, "id": "6bf499e8-54b0-498b-84b6-aba956cc573b", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "\n", "\n", "from CLAPWrapper import CLAPWrapper\n", "from esc50_dataset import ESC50\n", "import torch.nn.functional as F\n", "import numpy as np\n", "from tqdm import tqdm\n", "from sklearn.metrics import accuracy_score\n", "\n", "import torch\n" ] }, { "cell_type": "code", "execution_count": 16, "id": "082e82b9-56b4-41ce-a8f8-390bb5bc0193", "metadata": {}, "outputs": [], "source": [ "df = pd.read_csv(\"../landscape/landscape_final.csv\")\n", "\n", "classes = list(set(df[\"label\"]))\n", "\n", "prompt = 'this is a sound of '\n", "y = [prompt + x for x in classes]\n", "\n", "class_count = len(classes)" ] }, { "cell_type": "code", "execution_count": 17, "id": "68e72bf4-6c94-438d-b3f3-c46aaa0b88cc", "metadata": {}, "outputs": [], "source": [ "class_dict = {k: v for v, k in enumerate(classes)}" ] }, { "cell_type": "code", "execution_count": 37, "id": "80c437e3-b7e3-41bc-bb9c-fab936648caf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/kuacc/users/bbiner21/.conda/envs/clap/lib/python3.8/site-packages/torchlibrosa/stft.py:193: FutureWarning: Pass size=1024 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", " fft_window = librosa.util.pad_center(fft_window, n_fft)\n", "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']\n", "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.\n", "/kuacc/users/bbiner21/.conda/envs/clap/lib/python3.8/site-packages/transformers/tokenization_utils_base.py:2339: FutureWarning: The `pad_to_max_length` argument is deprecated and will be removed in a future version, use `padding=True` or `padding='longest'` to pad to the longest sequence in the batch, or use `padding='max_length'` to pad to a max length. In this case, you can give a specific length with `max_length` (e.g. `max_length=45`) or leave max_length to None to pad to the maximal input size of the model (e.g. 512 for Bert).\n", " warnings.warn(\n" ] } ], "source": [ "# Load and initialize CLAP\n", "weights_path = \"../clap_weight/CLAP_weights_2022.pth\"\n", "\n", "# Setting use_cuda = True will load the model on a GPU using CUDA\n", "clap_model = CLAPWrapper(weights_path, use_cuda=False)\n", "\n", "# Computing text embeddings\n", "text_embeddings = clap_model.get_text_embeddings(y)\n", "\n", "# Computing audio embeddings\n", "y_preds, y_labels = [], []\n" ] }, { "cell_type": "code", "execution_count": 38, "id": "3093fa76-5c25-4cae-a43c-8368fdfd96fc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 1061/1061 [02:33<00:00, 6.92it/s]\n" ] } ], "source": [ "\n", "gt = []\n", "pred = []\n", "\n", "for i in tqdm(range(len(df.index))):\n", " x = \"/datasets/audio-image/audios/audio_10s/\" + df.iloc[i,1] + \".wav\"\n", " \n", " cur_class = class_dict[df.iloc[i,0]]\n", " one_hot = torch.zeros((1,class_count))\n", " one_hot[0,cur_class] = 1.0 \n", " \n", " gt.append(cur_class)\n", " \n", " \n", "# x, _, one_hot_target = dataset.__getitem__(i)\n", " audio_embeddings = clap_model.get_audio_embeddings([x], resample=True)\n", " \n", " similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)\n", " y_pred = F.softmax(similarity.detach().cpu(), dim=1).numpy()\n", " \n", " pred.append(np.argmax(y_pred, axis=1)[0])\n", " y_preds.append(y_pred)\n", " y_labels.append(one_hot.detach().cpu().numpy())\n" ] }, { "cell_type": "code", "execution_count": 23, "id": "e2247ab8-844d-4eba-b691-4d38051a51a3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ESC50 Accuracy 0.4458058435438266\n" ] }, { "data": { "text/plain": [ "'\\nThe output:\\n\\nESC50 Accuracy: 82.6%\\n\\n'" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\n", "\n", "\n", "# for i in tqdm(range(len(dataset))):\n", "# x, _, one_hot_target = dataset.__getitem__(i)\n", "# audio_embeddings = clap_model.get_audio_embeddings([x], resample=True)\n", "# similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)\n", "# y_pred = F.softmax(similarity.detach().cpu(), dim=1).numpy()\n", "# y_preds.append(y_pred)\n", "# y_labels.append(one_hot_target.detach().cpu().numpy())\n", "\n", "y_labels, y_preds = np.concatenate(y_labels, axis=0), np.concatenate(y_preds, axis=0)\n", "acc = accuracy_score(np.argmax(y_labels, axis=1), np.argmax(y_preds, axis=1))\n", "print('ESC50 Accuracy {}'.format(acc))\n", "\n", "\"\"\"\n", "The output:\n", "\n", "ESC50 Accuracy: 82.6%\n", "\n", "\"\"\"\n" ] }, { "cell_type": "code", "execution_count": 25, "id": "41254964-43ec-4fcb-b1d0-2c9ae76d56f6", "metadata": {}, "outputs": [], "source": [ "gt = []\n", "x = \"/datasets/audio-image/audios/audio_10s/\" + df.iloc[0,1] + \".wav\"\n", "\n", "cur_class = class_dict[df.iloc[0,0]]\n", "one_hot = torch.zeros((1,class_count))\n", "one_hot[0,cur_class] = 1.0 \n", "\n", "gt.append(cur_class)\n", "\n", "\n", "# x, _, one_hot_target = dataset.__getitem__(i)\n", "audio_embeddings = clap_model.get_audio_embeddings([x], resample=True)\n", "\n", "similarity = clap_model.compute_similarity(audio_embeddings, text_embeddings)" ] }, { "cell_type": "code", "execution_count": 31, "id": "7e73d889-05b6-46ab-820a-9728b1623d5a", "metadata": {}, "outputs": [], "source": [ "y_pred = F.softmax(similarity.detach().cpu(), dim=1).numpy()\n" ] }, { "cell_type": "code", "execution_count": 35, "id": "99574178-aba0-467b-a370-679ae927b13b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "3" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.argmax(y_pred, axis=1)[0]" ] }, { "cell_type": "code", "execution_count": 41, "id": "21b42bef-9500-46be-8f3e-2c53b91462d0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([0.28571429, 0.35164835, 0.7877095 , 0.59615385, 0.01639344,\n", " 0.93243243, 0.93292683, 0.03092784, 0.4 ])" ] }, "execution_count": 41, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.metrics import confusion_matrix\n", "\n", "matrix = confusion_matrix(gt, pred)\n", "matrix.diagonal()/matrix.sum(axis=1)" ] }, { "cell_type": "code", "execution_count": 42, "id": "2e96c02a-d789-417e-aaec-a420976bef17", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 8, 2, 0, 0, 0, 0, 18, 0, 0],\n", " [ 5, 64, 1, 11, 0, 0, 100, 1, 0],\n", " [ 1, 1, 141, 5, 2, 3, 23, 1, 2],\n", " [ 2, 1, 0, 31, 0, 1, 15, 0, 2],\n", " [ 70, 51, 0, 0, 3, 2, 40, 17, 0],\n", " [ 1, 1, 0, 3, 0, 69, 0, 0, 0],\n", " [ 2, 1, 7, 0, 0, 0, 153, 0, 1],\n", " [ 30, 85, 0, 1, 0, 0, 72, 6, 0],\n", " [ 1, 0, 0, 1, 0, 1, 0, 0, 2]])" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "matrix" ] }, { "cell_type": "code", "execution_count": 43, "id": "24911c5c-06ed-492f-927d-1555df15b1c5", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['this is a sound of waterfall burbling',\n", " 'this is a sound of wind noise',\n", " 'this is a sound of fire crackling',\n", " 'this is a sound of thunder',\n", " 'this is a sound of squishing water',\n", " 'this is a sound of underwater bubbling',\n", " 'this is a sound of raining',\n", " 'this is a sound of splashing water',\n", " 'this is a sound of explosion']" ] }, "execution_count": 43, "metadata": {}, "output_type": "execute_result" } ], "source": [ "y" ] }, { "cell_type": "code", "execution_count": 44, "id": "e90b1d22-ddcd-421b-a011-ef1054cdf412", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['waterfall burbling',\n", " 'wind noise',\n", " 'fire crackling',\n", " 'thunder',\n", " 'squishing water',\n", " 'underwater bubbling',\n", " 'raining',\n", " 'splashing water',\n", " 'explosion']" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "classes" ] } ], "metadata": { "kernelspec": { "display_name": "clap", "language": "python", "name": "clap" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" } }, "nbformat": 4, "nbformat_minor": 5 }