{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5e0c3b1a-7d2f-4c8e-9a61-0b4f2e7d9c13",
   "metadata": {},
   "source": [
    "# Acoustic scene prediction for a single audio file\n",
    "\n",
    "Runs the pretrained `split5` baseline (`Network.from_pretrained`) on one audio file and prints the highest-scoring of the ten TAU 2022 scene classes together with its softmax score.\n",
    "\n",
    "Set `audio_path` below to your file. The file must already be sampled at 32 kHz (no resampling is done here); multi-channel files are averaged to mono before feature extraction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09225787-6a4b-4484-b00b-d0f731915a81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import soundfile as sf\n",
    "import torch\n",
    "from IPython.display import Audio, display\n",
    "\n",
    "from models.baseline import Network\n",
    "from models.mel import AugmentMelSTFT"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b2d4f60-1c3e-4a7b-b5d9-6e0f1a2c3d47",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c377b699-2c2e-468e-88b0-6767338988c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_path = \"/path/to/audio.wav\"\n",
    "\n",
    "# The pipeline below only accepts 32 kHz input; files at any other rate are\n",
    "# rejected instead of being silently processed at the wrong rate.\n",
    "EXPECTED_SAMPLE_RATE = 32_000\n",
    "\n",
    "# Index i of the model's logits is reported as tau2022_classes[i].\n",
    "tau2022_classes = [\n",
    "    \"airport\",\n",
    "    \"bus\",\n",
    "    \"metro\",\n",
    "    \"metro_station\",\n",
    "    \"park\",\n",
    "    \"public_square\",\n",
    "    \"shopping_mall\",\n",
    "    \"street_pedestrian\",\n",
    "    \"street_traffic\",\n",
    "    \"tram\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa950347-df0d-4135-801a-d54525c57e58",
   "metadata": {},
   "outputs": [],
   "source": [
    "display(Audio(audio_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2a9e7c15-4f3b-4d28-8e6a-b1c0d9f4e582",
   "metadata": {},
   "source": [
    "## Load the pretrained model and mel front-end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c41f8e2d-9b6a-4e07-a3d5-7f2b0c8e1946",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both modules are put in eval mode because this notebook only runs inference.\n",
    "mel = AugmentMelSTFT().eval()\n",
    "model = Network.from_pretrained(\"split5\").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7b3a0d4-6c5f-4b19-9d82-3a4e5f6c7b08",
   "metadata": {},
   "source": [
    "## Load the audio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d6c2b9f-0e1a-4f53-8c7d-5b9a2e3f1c60",
   "metadata": {},
   "outputs": [],
   "source": [
    "audio_np, sr = sf.read(audio_path, dtype=np.float32)\n",
    "if sr != EXPECTED_SAMPLE_RATE:\n",
    "    raise ValueError(\n",
    "        f\"expected {EXPECTED_SAMPLE_RATE} Hz audio, got {sr} Hz: {audio_path}\"\n",
    "    )\n",
    "\n",
    "# sf.read returns (samples,) for mono files and (samples, channels) otherwise.\n",
    "# Average the channels so the tensor below is always (1, samples).\n",
    "if audio_np.ndim == 2:\n",
    "    audio_np = audio_np.mean(axis=1)\n",
    "\n",
    "# Add a batch axis of size one: (samples,) -> (1, samples).\n",
    "audio = torch.as_tensor(audio_np).unsqueeze(0)\n",
    "audio.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f0a1b2c-3d4e-4f56-a789-0b1c2d3e4f5a",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79faad26-0f20-439d-b152-10f4666db41d",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    # Assumed (1, mel_bins, frames) for a (1, samples) input -- confirm against models.mel.AugmentMelSTFT.\n",
    "    mel_spec = mel(audio)\n",
    "    # Add a leading axis -> (1, 1, mel_bins, frames) under the assumption above.\n",
    "    logits = model(mel_spec.unsqueeze(0))\n",
    "\n",
    "# Assumed (1, classes); drop the batch axis -> (classes,).\n",
    "logits = logits.squeeze(0)\n",
    "\n",
    "# Guard the class-name lookup below: a mismatch would silently mislabel predictions.\n",
    "if logits.shape != (len(tau2022_classes),):\n",
    "    raise ValueError(\n",
    "        f\"model produced logits of shape {tuple(logits.shape)}, \"\n",
    "        f\"expected ({len(tau2022_classes)},)\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b5c6d7e8-f9a0-4b1c-8d2e-3f4a5b6c7d8e",
   "metadata": {},
   "source": [
    "## Prediction\n",
    "\n",
    "The reported score is the softmax probability of the top class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a7b8c9d-0e1f-4a2b-9c3d-4e5f6a7b8c9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = torch.softmax(logits, dim=0)\n",
    "best_prediction_idx = int(torch.argmax(logits))\n",
    "\n",
    "print(f\"Prediction: {tau2022_classes[best_prediction_idx]} (score: {scores[best_prediction_idx]:0.2f})\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
"display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" } }, "nbformat": 4, "nbformat_minor": 5 }