amanmibra committed
Commit: 3806d0c
Parent: 0908871

Add GPU device support to dataset

__pycache__/dataset.cpython-39.pyc CHANGED
Binary files a/__pycache__/dataset.cpython-39.pyc and b/__pycache__/dataset.cpython-39.pyc differ
 
dataset.py CHANGED
@@ -7,13 +7,21 @@ import torchaudio
 
 class VoiceDataset(Dataset):
 
-    def __init__(self, data_directory, transformation, target_sample_rate, time_limit_in_secs=5):
+    def __init__(
+        self,
+        data_directory,
+        transformation,
+        target_sample_rate,
+        device,
+        time_limit_in_secs=5,
+    ):
         # file processing
         self._data_path = os.path.join(data_directory)
         self._labels = os.listdir(self._data_path)
-
         self.audio_files_labels = self._join_audio_files()
 
+        self.device = device
+
         # audio processing
         self.transformation = transformation
         self.target_sample_rate = target_sample_rate
@@ -35,6 +43,7 @@ class VoiceDataset(Dataset):
         wav, sr = torchaudio.load(filepath, normalize=True)
 
         # modify wav file, if necessary
+        wav = wav.to(self.device)
         wav = self._resample(wav, sr)
         wav = self._mix_down(wav)
         wav = self._cut_or_pad(wav)
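
For reference, a minimal sketch of how the new signature is called — `device` is now a positional argument ahead of the `time_limit_in_secs` default. The `hop_length=512` and `n_mels=64` values come from the notebook below; `sample_rate` and `n_fft` are assumptions, since that notebook cell is only partially shown in this diff.

import torch
import torchaudio

from dataset import VoiceDataset

device = "cuda" if torch.cuda.is_available() else "cpu"

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,  # assumed to match target_sample_rate
    n_fft=1024,         # assumed; not visible in the diff
    hop_length=512,
    n_mels=64,
)

# The new `device` argument slots in before time_limit_in_secs.
dataset = VoiceDataset("../data", mel_spectrogram, 16000, device)
features, label = dataset[0]  # waveform is moved to `device` inside __getitem__
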
notebooks/playground.ipynb CHANGED
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 8,
-   "id": "26db4cdb",
+   "id": "7f11e761",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -14,7 +14,7 @@
   {
    "cell_type": "code",
    "execution_count": 10,
-   "id": "c8244b70",
+   "id": "f3deb79d",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -24,18 +24,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 18,
-   "id": "f3fd2d28",
+   "execution_count": 76,
+   "id": "eb9888a5",
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os"
+    "import os\n",
+    "\n",
+    "import torch"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
-   "id": "da9fe647",
+   "execution_count": 77,
+   "id": "75440e63",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,29 +46,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
-   "id": "70905d2d",
+   "execution_count": 78,
+   "id": "5b51f712",
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "dataset.VoiceDataset"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device cpu\n"
+     ]
     }
    ],
    "source": [
-    "VoiceDataset"
+    "if torch.cuda.is_available():\n",
+    "    device = \"cuda\"\n",
+    "else:\n",
+    "    device = \"cpu\"\n",
+    "print(f\"Using device {device}\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 64,
-   "id": "523d28f9",
+   "execution_count": 80,
+   "id": "253f87d6",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -76,13 +79,13 @@
     "    hop_length=512,\n",
     "    n_mels=64\n",
     "    )\n",
-    "dataset = VoiceDataset('../data', mel_spectrogram, 16000,)"
+    "dataset = VoiceDataset('../data', mel_spectrogram, 16000, device)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 65,
-   "id": "0044724d",
+   "execution_count": 81,
+   "id": "3d5c127a",
    "metadata": {},
    "outputs": [
     {
@@ -91,7 +94,7 @@
        "5718"
       ]
      },
-     "execution_count": 65,
+     "execution_count": 81,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -102,24 +105,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 66,
-   "id": "df7a9e58",
+   "execution_count": 82,
+   "id": "cbac184f",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "(tensor([[[0.2647, 0.0247, 0.0324,  ..., 0.0000, 0.0000, 0.0000],\n",
-       "          [0.0812, 0.0178, 0.0890,  ..., 0.0000, 0.0000, 0.0000],\n",
-       "          [0.0052, 0.0212, 0.1341,  ..., 0.0000, 0.0000, 0.0000],\n",
+       "(tensor([[[0.2647, 0.0247, 0.0324,  ..., 0.0230, 0.1026, 0.5454],\n",
+       "          [0.0812, 0.0178, 0.0890,  ..., 0.2376, 0.5061, 0.5292],\n",
+       "          [0.0052, 0.0212, 0.1341,  ..., 0.9336, 0.2778, 0.1372],\n",
        "          ...,\n",
-       "          [0.5154, 0.3950, 0.4497,  ..., 0.0000, 0.0000, 0.0000],\n",
-       "          [0.1919, 0.4804, 0.5144,  ..., 0.0000, 0.0000, 0.0000],\n",
-       "          [0.1208, 0.4357, 0.4016,  ..., 0.0000, 0.0000, 0.0000]]]),\n",
+       "          [0.5154, 0.3950, 0.4497,  ..., 0.4916, 0.4505, 0.7709],\n",
+       "          [0.1919, 0.4804, 0.5144,  ..., 0.5931, 0.4466, 0.4706],\n",
+       "          [0.1208, 0.4357, 0.4016,  ..., 0.5168, 0.7007, 0.3696]]]),\n",
        " 'aman')"
       ]
      },
-     "execution_count": 66,
+     "execution_count": 82,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -130,17 +133,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 67,
-   "id": "df064dbc",
+   "execution_count": 83,
+   "id": "2bd8c582",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "torch.Size([1, 64, 313])"
+       "torch.Size([1, 64, 157])"
       ]
      },
-     "execution_count": 67,
+     "execution_count": 83,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -152,7 +155,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "ed4899bf",
+   "id": "c3c7b1d4",
    "metadata": {},
    "outputs": [],
    "source": []