Add GPU device support to dataset
- __pycache__/dataset.cpython-39.pyc +0 -0
- dataset.py +11 -2
- notebooks/playground.ipynb +41 -38
__pycache__/dataset.cpython-39.pyc
CHANGED
Binary files a/__pycache__/dataset.cpython-39.pyc and b/__pycache__/dataset.cpython-39.pyc differ
dataset.py
CHANGED
```diff
@@ -7,13 +7,21 @@ import torchaudio
 
 class VoiceDataset(Dataset):
 
-    def __init__(self, data_directory, transformation, target_sample_rate, time_limit_in_secs=5):
+    def __init__(
+        self,
+        data_directory,
+        transformation,
+        target_sample_rate,
+        device,
+        time_limit_in_secs=5,
+    ):
         # file processing
         self._data_path = os.path.join(data_directory)
         self._labels = os.listdir(self._data_path)
-
         self.audio_files_labels = self._join_audio_files()
 
+        self.device = device
+
         # audio processing
         self.transformation = transformation
         self.target_sample_rate = target_sample_rate
@@ -35,6 +43,7 @@ class VoiceDataset(Dataset):
         wav, sr = torchaudio.load(filepath, normalize=True)
 
         # modify wav file, if necessary
+        wav = wav.to(self.device)
         wav = self._resample(wav, sr)
         wav = self._mix_down(wav)
         wav = self._cut_or_pad(wav)
```
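The change itself is small: `__init__` gains a `device` parameter and stores it, and the load path in the second hunk moves each waveform onto that device right after `torchaudio.load`, so the `_resample` / `_mix_down` / `_cut_or_pad` helpers and the transformation all run there. Below is a minimal standalone sketch of that flow; the function name and the inlined resample/mix-down bodies are illustrative stand-ins for the helpers named in the hunk, not the repo's actual implementations:

```python
import torch
import torchaudio

# Illustrative stand-in for VoiceDataset's per-item load path after this commit.
def load_example(filepath, transformation, device, target_sample_rate=16000):
    wav, sr = torchaudio.load(filepath, normalize=True)
    wav = wav.to(device)  # the line this commit adds

    # stand-in for self._resample(wav, sr)
    if sr != target_sample_rate:
        resample = torchaudio.transforms.Resample(sr, target_sample_rate).to(device)
        wav = resample(wav)

    # stand-in for self._mix_down(wav): average channels down to mono
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)

    # cut/pad to the time limit omitted for brevity
    return transformation(wav)
```

One caveat: torchaudio transforms such as `Resample` and `MelSpectrogram` are `nn.Module`s with registered buffers, so once the waveform lives on `"cuda"` the transformation handed to the dataset must be moved there too (`transformation.to(device)`), or applying it raises a device-mismatch error. The notebook run below is on the CPU, where the mismatch cannot occur.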
notebooks/playground.ipynb
CHANGED
```diff
@@ -3,7 +3,7 @@
   {
    "cell_type": "code",
    "execution_count": 8,
-   "id": "…",
+   "id": "7f11e761",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -14,7 +14,7 @@
   {
    "cell_type": "code",
    "execution_count": 10,
-   "id": "…",
+   "id": "f3deb79d",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -24,18 +24,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 76,
+   "id": "eb9888a5",
    "metadata": {},
    "outputs": [],
    "source": [
-    "import os"
+    "import os\n",
+    "\n",
+    "import torch"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 77,
+   "id": "75440e63",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -44,29 +46,30 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 78,
+   "id": "5b51f712",
    "metadata": {},
    "outputs": [
     {
-     "data": {
-      "text/plain": [
-       "…"
-      ]
-     },
-     "execution_count": 15,
-     "metadata": {},
-     "output_type": "execute_result"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Using device cpu\n"
+     ]
     }
    ],
    "source": [
-    "…"
+    "if torch.cuda.is_available():\n",
+    "    device = \"cuda\"\n",
+    "else:\n",
+    "    device = \"cpu\"\n",
+    "print(f\"Using device {device}\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 80,
+   "id": "253f87d6",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -76,13 +79,13 @@
     "    hop_length=512,\n",
     "    n_mels=64\n",
     "    )\n",
-    "dataset = VoiceDataset('../data', mel_spectrogram, 16000,)"
+    "dataset = VoiceDataset('../data', mel_spectrogram, 16000, device)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 81,
+   "id": "3d5c127a",
    "metadata": {},
    "outputs": [
     {
@@ -91,7 +94,7 @@
       "5718"
      ]
     },
-    "execution_count": …,
+    "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
@@ -102,24 +105,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 82,
+   "id": "cbac184f",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "(tensor([[[0.2647, 0.0247, 0.0324, ..., 0.…",
-       "          [0.0812, 0.0178, 0.0890, ..., 0.…",
-       "          [0.0052, 0.0212, 0.1341, ..., 0.…",
+       "(tensor([[[0.2647, 0.0247, 0.0324, ..., 0.0230, 0.1026, 0.5454],\n",
+       "          [0.0812, 0.0178, 0.0890, ..., 0.2376, 0.5061, 0.5292],\n",
+       "          [0.0052, 0.0212, 0.1341, ..., 0.9336, 0.2778, 0.1372],\n",
        "          ...,\n",
-       "          [0.5154, 0.3950, 0.4497, ..., 0.…",
-       "          [0.1919, 0.4804, 0.5144, ..., 0.…",
-       "          [0.1208, 0.4357, 0.4016, ..., 0.…",
+       "          [0.5154, 0.3950, 0.4497, ..., 0.4916, 0.4505, 0.7709],\n",
+       "          [0.1919, 0.4804, 0.5144, ..., 0.5931, 0.4466, 0.4706],\n",
+       "          [0.1208, 0.4357, 0.4016, ..., 0.5168, 0.7007, 0.3696]]]),\n",
        " 'aman')"
      ]
     },
-    "execution_count": …,
+    "execution_count": 82,
    "metadata": {},
    "output_type": "execute_result"
    }
@@ -130,17 +133,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": …,
-   "id": "…",
+   "execution_count": 83,
+   "id": "2bd8c582",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "torch.Size([1, 64, …",
+       "torch.Size([1, 64, 157])"
      ]
     },
-    "execution_count": …,
+    "execution_count": 83,
    "metadata": {},
    "output_type": "execute_result"
    }
@@ -152,7 +155,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "id": "…",
+   "id": "c3c7b1d4",
    "metadata": {},
    "outputs": [],
    "source": []
```
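Stripped of the JSON wrapping, the notebook now does roughly the following. Only `hop_length=512` and `n_mels=64` are visible in the diff context, so the remaining `MelSpectrogram` arguments (and the import of `VoiceDataset`) are assumed typical values, and the trailing `DataLoader` line is a hypothetical next step rather than part of this commit:

```python
import torch
import torchaudio
from torch.utils.data import DataLoader

from dataset import VoiceDataset  # assumed import path

# pick a device, as in the notebook cell above
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device {device}")

mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,  # assumed; not visible in the diff context
    n_fft=1024,         # assumed; not visible in the diff context
    hop_length=512,
    n_mels=64,
).to(device)  # keep the transform's buffers on the same device as the waveforms

dataset = VoiceDataset('../data', mel_spectrogram, 16000, device)
print(len(dataset))       # 5718 in the captured run
spec, label = dataset[0]
print(spec.shape, label)  # torch.Size([1, 64, 157]) and 'aman' above

# hypothetical consumption step, not part of this commit
loader = DataLoader(dataset, batch_size=32, shuffle=True)
```

As a sanity check on the captured outputs: with the constructor's `time_limit_in_secs=5` default, 5 s at 16 kHz is 80,000 samples, and with `hop_length=512` that gives ⌊80000 / 512⌋ + 1 = 157 frames, matching the `torch.Size([1, 64, 157])` shown above.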