{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# EGFx wet/dry dataset loader\n",
    "\n",
    "Builds a PyTorch `Dataset` of paired effected (wet) and clean (dry) guitar recordings from `./data/egfx`, splits it into train/validation subsets with a fixed seed, and smoke-tests the validation `DataLoader`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from typing import List\n",
    "\n",
    "import mirdata  # NOTE(review): not used anywhere in this notebook\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchaudio\n",
    "import torchaudio.transforms as T\n",
    "from torch.utils.data import DataLoader, Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "effect_type = [\"Phaser\"]\n",
    "root = Path(\"./data/egfx\")\n",
    "\n",
    "LENGTH = 2**18  # samples per example after resampling (~11.9 s at 22050 Hz)\n",
    "ORIG_SR = 48000  # sample rate the .wav files are expected to have\n",
    "SAMPLE_RATE = 22050\n",
    "TRAIN_SPLIT = 0.8\n",
    "SEED = 42  # makes random_split reproducible"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Raw file scan\n",
    "\n",
    "Quick look at the directory layout: wet files live in `root/EFFECT/PICKUP/`, their clean counterparts under `root/Clean/PICKUP/`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lists are re-created here so re-running the cell does not append duplicates.\n",
    "wet_files, dry_files, labels = [], [], []\n",
    "for i, effect in enumerate(effect_type):\n",
    "    for pickup in sorted((root / effect).iterdir()):\n",
    "        pickup_wet = sorted(pickup.glob(\"*.wav\"))\n",
    "        wet_files += pickup_wet\n",
    "        dry_files += sorted(root.glob(f\"Clean/{pickup.name}/**/*.wav\"))\n",
    "        # One label per file from *this* pickup, not the running total.\n",
    "        labels += [i] * len(pickup_wet)\n",
    "\n",
    "len(wet_files), len(dry_files), len(labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GuitarFXDataset(Dataset):\n",
    "    \"\"\"Paired wet/dry guitar recordings from an EGFx-style directory tree.\n",
    "\n",
    "    Files scanned:\n",
    "      * wet (processed) audio: ``root/EFFECT/PICKUP/*.wav``\n",
    "      * dry (clean) audio:     ``root/Clean/PICKUP/**/*.wav``\n",
    "\n",
    "    Wet and dry files are paired by sorted path order within each pickup\n",
    "    folder, so both lists must have the same length per pickup (checked).\n",
    "\n",
    "    Args:\n",
    "        root: dataset root directory.\n",
    "        sample_rate: target sample rate; audio is resampled to it on load.\n",
    "        length: number of samples each returned signal is zero-padded or\n",
    "            cropped to (after resampling).\n",
    "        effect_type: effect folder names to load. Defaults to every\n",
    "            sub-directory of ``root`` except ``Clean``, sorted by name. An\n",
    "            item's label is the index of its effect in this list.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a pickup folder's wet and dry file counts differ.\n",
    "\n",
    "    Each item is ``(wet, dry, effect_label)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        root: str,\n",
    "        sample_rate: int,\n",
    "        length: int = LENGTH,\n",
    "        effect_type: List[str] = None,\n",
    "    ):\n",
    "        self.length = length\n",
    "        self.sample_rate = sample_rate\n",
    "        self.wet_files = []\n",
    "        self.dry_files = []\n",
    "        self.labels = []\n",
    "        self.root = Path(root)\n",
    "        if effect_type is None:\n",
    "            # Compare the directory *name*: a Path never equals the str \"Clean\".\n",
    "            effect_type = sorted(\n",
    "                d.name for d in self.root.iterdir() if d.is_dir() and d.name != \"Clean\"\n",
    "            )\n",
    "        for i, effect in enumerate(effect_type):\n",
    "            # Sorted so file order (and hence the split) is filesystem-independent.\n",
    "            pickups = sorted(p for p in (self.root / effect).iterdir() if p.is_dir())\n",
    "            for pickup in pickups:\n",
    "                wet = sorted(pickup.glob(\"*.wav\"))\n",
    "                dry = sorted(self.root.glob(f\"Clean/{pickup.name}/**/*.wav\"))\n",
    "                if len(wet) != len(dry):\n",
    "                    raise ValueError(\n",
    "                        f\"{pickup}: {len(wet)} wet files but {len(dry)} dry files\"\n",
    "                    )\n",
    "                self.wet_files += wet\n",
    "                self.dry_files += dry\n",
    "                # One label per file added for *this* pickup, not the running total.\n",
    "                self.labels += [i] * len(wet)\n",
    "        print(\n",
    "            f\"Found {len(self.wet_files)} wet files and {len(self.dry_files)} dry files\"\n",
    "        )\n",
    "        self.resampler = T.Resample(ORIG_SR, sample_rate)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.wet_files)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return ``(wet, dry, effect_label)`` for the ``idx``-th pair.\"\"\"\n",
    "        wet = self._load(self.wet_files[idx])\n",
    "        dry = self._load(self.dry_files[idx])\n",
    "        return (wet, dry, self.labels[idx])\n",
    "\n",
    "    def _load(self, path):\n",
    "        \"\"\"Load ``path``, resample to ``self.sample_rate``, then pad/crop to ``self.length``.\"\"\"\n",
    "        audio, sr = torchaudio.load(path)\n",
    "        # Reuse the cached ORIG_SR resampler; build a matching one for any other rate.\n",
    "        resampler = self.resampler if sr == ORIG_SR else T.Resample(sr, self.sample_rate)\n",
    "        return self._fix_length(resampler(audio))\n",
    "\n",
    "    def _fix_length(self, audio):\n",
    "        \"\"\"Zero-pad or crop the last (time) axis to exactly ``self.length`` samples.\"\"\"\n",
    "        n_samples = audio.shape[-1]\n",
    "        if n_samples < self.length:\n",
    "            return F.pad(audio, (0, self.length - n_samples))\n",
    "        return audio[..., : self.length]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Build and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "guitfx = GuitarFXDataset(\n",
    "    root=root,\n",
    "    sample_rate=SAMPLE_RATE,\n",
    "    effect_type=effect_type,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_size = int(TRAIN_SPLIT * len(guitfx))\n",
    "val_size = len(guitfx) - train_size\n",
    "train_dataset, val_dataset = torch.utils.data.random_split(\n",
    "    guitfx, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)\n",
    ")\n",
    "val = DataLoader(val_dataset, batch_size=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inspect the validation split\n",
    "\n",
    "Check the split sizes and a few validation indices, then verify every validation batch is shape-consistent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_dataset), len(val_dataset), val_dataset.indices[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke-test the loader instead of printing every batch: each wet/dry pair must\n",
    "# have matching shapes, be length-normalised, and carry one label per example.\n",
    "for wet_batch, dry_batch, label_batch in val:\n",
    "    assert wet_batch.shape == dry_batch.shape, (wet_batch.shape, dry_batch.shape)\n",
    "    assert wet_batch.shape[-1] == LENGTH, wet_batch.shape\n",
    "    assert len(label_batch) == len(wet_batch)\n",
    "\n",
    "# Shapes of the last batch as a summary.\n",
    "wet_batch.shape, dry_batch.shape, label_batch"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "94173bdbcc3a07290a92586f1f41e17e9573695669854c49e68cc83ee6746035"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}