CyborgPaloma committed on
Commit
45f95fb
1 Parent(s): b0b0fcb

Upload 2 files


My bad scripts I used to make the model/dataset

Files changed (2)
  1. prepare_data.py +299 -0
  2. train_model.py +315 -0
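
Both scripts are driven by a single --config JSON file. As a reference, the sketch below collects every key the code actually reads; the key names come straight from the scripts, but the values are illustrative guesses rather than the settings used for the released model/dataset.

# make_config.py -- hypothetical helper, not part of this commit.
import json

example_config = {
    "directory": "data/samples",            # .wav root; one subfolder per label
    "output_directory": "data/spectrograms",
    "sample_rate": 44100,
    "n_fft": 2048,
    "hop_length": 256,
    "n_mels": 128,
    "min_duration": 0.1,
    "max_frames": 512,                      # cap on spectrogram width (guess)
    "augment": True,
    "noise_snr": 10,
    "pitch_steps": 2,
    "batch_size": 32,
    "patience": 10,
}

with open("config.json", "w") as f:
    json.dump(example_config, f, indent=2)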
prepare_data.py ADDED
@@ -0,0 +1,299 @@
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Sampler
from tqdm import tqdm
import librosa
import logging
import argparse
import json
import time
import torchaudio
from torchvision import transforms
import pickle
import random

def configure_logging():
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[logging.StreamHandler()])
    logging.info("Logging is set up.")
    print("Logging is set up.")

def parse_args():
    parser = argparse.ArgumentParser(description='Spectrogram Dataset Preparation')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()

def load_config(config_path):
    logging.info(f"Loading configuration from {config_path}")
    print(f"Loading configuration from {config_path}")
    try:
        with open(config_path, 'r') as f:
            config = json.load(f)
        logging.info("Configuration loaded successfully")
        print("Configuration loaded successfully")
        return config
    except Exception as e:
        logging.error(f"Failed to load config file: {e}", exc_info=True)
        print(f"Failed to load config file: {e}")
        raise

def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    logging.debug(f"Validating audio with sr={sr}, target_sr={target_sr}, min_duration={min_duration}")
    print(f"Validating audio with sr={sr}, target_sr={target_sr}, min_duration={min_duration}")
    if sr != target_sr:
        logging.warning(f"Resampling from {sr} to {target_sr}")
        print(f"Resampling from {sr} to {target_sr}")
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    if len(y) < min_duration * target_sr:
        pad_length = int(min_duration * target_sr - len(y))
        y = np.pad(y, (0, pad_length), mode='constant')
        logging.info(f"Audio file padded with {pad_length} samples")
        print(f"Audio file padded with {pad_length} samples")
    return y, target_sr

def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    logging.debug(f"Stripping silence with sr={sr}, top_db={top_db}, pad_duration={pad_duration}")
    print(f"Stripping silence with sr={sr}, top_db={top_db}, pad_duration={pad_duration}")
    y_trimmed, _ = librosa.effects.trim(y, top_db=top_db)
    pad_length = int(pad_duration * sr)
    y_padded = np.pad(y_trimmed, pad_length, mode='constant')
    return y_padded

def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128, target_sr=44100, min_duration=0.1):
    try:
        logging.info(f"Loading file: {file_path}")
        print(f"Loading file: {file_path}")
        y, sr = librosa.load(file_path, sr=None)
        logging.debug(f"Loaded file: {file_path} with sr={sr}")
        print(f"Loaded file: {file_path} with sr={sr}")
        y, sr = validate_audio(y, sr, target_sr, min_duration)
        y = strip_silence(y, sr)
    except Exception as e:
        logging.error(f"Error reading {file_path}: {e}", exc_info=True)
        print(f"Error reading {file_path}: {e}")
        return None

    y = librosa.util.normalize(y)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
    S_dB = librosa.power_to_db(S, ref=np.max)
    logging.debug(f"Generated spectrogram for file: {file_path}")
    print(f"Generated spectrogram for file: {file_path}")

    return S_dB

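# Rough output shape, for reference (with the defaults above): librosa yields
# 1 + len(y) // hop_length frames, so at sr=44100 and hop_length=256 one
# second of audio becomes a (128, ~173) dB-scaled array.
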
def validate_spectrogram(spectrogram, n_mels=128):
    logging.debug(f"Validating spectrogram with n_mels={n_mels}")
    print(f"Validating spectrogram with n_mels={n_mels}")
    if spectrogram.shape[0] != n_mels:
        raise ValueError(f"Spectrogram has incorrect number of mel bands: {spectrogram.shape[0]}")
    if spectrogram.shape[1] == 0:
        raise ValueError("Spectrogram has zero frames")
    return True

def save_spectrogram(spectrogram, save_path):
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    np.save(save_path, spectrogram)
    logging.debug(f"Spectrogram saved at {save_path}")
    print(f"Spectrogram saved at {save_path}")

class AddNoise(torch.nn.Module):
    def __init__(self, noise_type='white', snr=10):
        super(AddNoise, self).__init__()
        self.noise_type = noise_type
        self.snr = snr

    def forward(self, waveform):
        noise = torch.randn_like(waveform)
        signal_power = waveform.norm(p=2)
        noise_power = noise.norm(p=2)
        noise = noise * (signal_power / noise_power) / (10 ** (self.snr / 20))
        return waveform + noise

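# Worked example of the SNR scaling above: with snr=10 dB,
# 10 ** (snr / 20) ≈ 3.162, so the noise is rescaled to an L2 norm about
# 3.162x smaller than the signal's -- the usual amplitude-ratio dB
# convention, since 20 * log10(3.162) ≈ 10.
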
class SpectrogramDataset(Dataset):
    def __init__(self, config, directory, process_new=True):
        logging.info("Initializing SpectrogramDataset...")
        print("Initializing SpectrogramDataset...")
        self.directory = directory
        self.output_directory = config['output_directory']
        self.spectrograms = []
        self.labels = []
        self.label_to_index = {}
        self.process_new = process_new
        self.config = config

        # Paths for saving and loading cache
        self.cache_path = os.path.join(self.output_directory, 'cache_data.npy')
        self.dataset_path = os.path.join(self.output_directory, 'spectrogram_dataset.pkl')

        if os.path.exists(self.dataset_path):
            self.load_dataset()
        else:
            if os.path.exists(self.cache_path):
                os.remove(self.cache_path)
                logging.info(f"Cache cleared at {self.cache_path}")
                print(f"Cache cleared at {self.cache_path}")

            self.load_data()
            self.save_dataset()

        self.transforms = transforms.Compose([
            torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
            torchaudio.transforms.TimeMasking(time_mask_param=30)
        ]) if self.config['augment'] else None

        self.audio_transforms = torch.nn.Sequential(
            AddNoise(snr=self.config['noise_snr']),
            torchaudio.transforms.PitchShift(self.config['sample_rate'], n_steps=self.config['pitch_steps'])
        ) if self.config['augment'] else None
        logging.info("SpectrogramDataset initialized successfully")
        print("SpectrogramDataset initialized successfully")

    def save_dataset(self):
        with open(self.dataset_path, 'wb') as f:
            pickle.dump(self, f)
        logging.info(f"Dataset object saved at {self.dataset_path}")
        print(f"Dataset object saved at {self.dataset_path}")

    def load_dataset(self):
        with open(self.dataset_path, 'rb') as f:
            obj = pickle.load(f)
        self.__dict__.update(obj.__dict__)
        logging.info(f"Dataset object loaded from {self.dataset_path}")
        print(f"Dataset object loaded from {self.dataset_path}")

    def process_file(self, file_path):
        logging.debug(f"Processing file: {file_path}")
        print(f"Processing file: {file_path}")
        try:
            label = os.path.basename(os.path.dirname(file_path))
            if label not in self.label_to_index:
                self.label_to_index[label] = len(self.label_to_index)
            relative_path = os.path.relpath(file_path, self.directory)
            spectrogram_path = os.path.join(self.output_directory, os.path.splitext(relative_path)[0] + '_spectrogram.npy')
            if not os.path.exists(spectrogram_path) and self.process_new:
                spectrogram = audio_to_spectrogram(file_path, n_fft=self.config['n_fft'], hop_length=self.config['hop_length'], n_mels=self.config['n_mels'], target_sr=self.config['sample_rate'], min_duration=self.config['min_duration'])
                if spectrogram is not None:
                    if spectrogram.shape[1] > self.config['max_frames']:
                        spectrogram = spectrogram[:, :self.config['max_frames']]
                    try:
                        validate_spectrogram(spectrogram, n_mels=self.config['n_mels'])
                        save_spectrogram(spectrogram, spectrogram_path)
                        logging.debug(f"Spectrogram saved: {spectrogram_path}")
                        print(f"Spectrogram saved: {spectrogram_path}")
                    except Exception as e:
                        logging.error(f"Error validating/saving spectrogram: {e}", exc_info=True)
                        print(f"Error validating/saving spectrogram: {e}")
            if os.path.exists(spectrogram_path):
                try:
                    spectrogram = np.load(spectrogram_path)
                    validate_spectrogram(spectrogram, n_mels=self.config['n_mels'])
                    spectrogram_tensor = torch.tensor(spectrogram, dtype=torch.float32)
                    self.spectrograms.append(spectrogram_tensor)
                    self.labels.append(self.label_to_index[label])
                    logging.debug(f"Spectrogram loaded and appended for file: {file_path}")
                    print(f"Spectrogram loaded and appended for file: {file_path}")
                except Exception as e:
                    logging.error(f"Error loading spectrogram {spectrogram_path}: {e}", exc_info=True)
                    print(f"Error loading spectrogram {spectrogram_path}: {e}")
        except Exception as e:
            logging.error(f"Exception in process_file: {e}", exc_info=True)
            print(f"Exception in process_file: {e}")

    def load_data(self):
        start_time = time.time()
        logging.info("Starting to load and process files...")
        print("Starting to load and process files...")
        files_to_process = [os.path.join(root, file) for root, _, files in os.walk(self.directory) for file in files if file.lower().endswith('.wav')]
        total_files = len(files_to_process)
        logging.info(f"Total files to process: {total_files}")
        print(f"Total files to process: {total_files}")

        for file_path in tqdm(files_to_process, desc="Processing files"):
            self.process_file(file_path)

        end_time = time.time()
        logging.info(f"Data loading and processing took {end_time - start_time:.2f} seconds")
        print(f"Data loading and processing took {end_time - start_time:.2f} seconds")

        self.save_cached_data(self.cache_path)

    def save_cached_data(self, cache_path):
        os.makedirs(os.path.dirname(cache_path), exist_ok=True)
        np.save(cache_path, {'spectrograms': self.spectrograms, 'labels': self.labels})
        logging.debug(f"Cached data saved at {cache_path}")
        print(f"Cached data saved at {cache_path}")

    def __len__(self):
        return len(self.spectrograms)

    def __getitem__(self, idx):
        spectrogram, label = self.spectrograms[idx], self.labels[idx]
        if self.config['augment']:
            if spectrogram.shape[1] >= 256:  # Ensure sufficient width for PitchShift
                spectrogram = self.audio_transforms(spectrogram.unsqueeze(0)).squeeze(0)
            spectrogram = self.transforms(spectrogram.unsqueeze(0)).squeeze(0)
        return spectrogram, label

def collate_fn(batch):
    spectrograms, labels = zip(*batch)
    labels = torch.tensor(labels, dtype=torch.long)
    max_length = max(s.size(1) for s in spectrograms)
    max_freq = max(s.size(0) for s in spectrograms)
    spectrograms_padded = torch.zeros(len(spectrograms), max_freq, max_length)
    for i, s in enumerate(spectrograms):
        if s.dim() == 3 and s.size(2) == 1:
            s = s.squeeze(2)
        spectrograms_padded[i, :s.size(0), :s.size(1)] = s
    return spectrograms_padded, labels

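# Padding example for the collate above: spectrograms of shape (128, 300) and
# (128, 431) come back as a single (2, 128, 431) batch, with the shorter clip
# zero-padded on the right.
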
class SmartBatchingSampler(Sampler):
    def __init__(self, data_source, batch_size):
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self):
        # Sort by spectrogram width so each batch holds similar-length clips,
        # then shuffle the batches themselves to keep training order random.
        # Note: pooled_indices already includes the final partial batch, so
        # nothing needs to be yielded a second time afterwards.
        sorted_indices = sorted(range(len(self.data_source)), key=lambda i: self.data_source[i][0].shape[1], reverse=True)
        pooled_indices = [sorted_indices[i:i + self.batch_size] for i in range(0, len(sorted_indices), self.batch_size)]
        random.shuffle(pooled_indices)
        for p in pooled_indices:
            yield from p

    def __len__(self):
        # A Sampler's length is the number of indices it yields per epoch.
        return len(self.data_source)

if __name__ == '__main__':
    print("Starting script")
    try:
        # Configure logging first so load_config's log lines are not lost.
        configure_logging()
        print("Logging configured")

        args = parse_args()
        print(f"Arguments parsed: {args}")
        config = load_config(args.config)
        print(f"Config loaded: {config}")

        logging.info("Script started.")
        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        dataloader = DataLoader(dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=SmartBatchingSampler(dataset, config['batch_size']))
        for batch in dataloader:
            spectrograms, labels = batch
            logging.info(f"Spectrograms batch shape: {spectrograms.shape}")
            logging.info(f"Labels batch shape: {labels.shape}")
            print(f"Spectrograms batch shape: {spectrograms.shape}")
            print(f"Labels batch shape: {labels.shape}")
            break

        logging.info(f"Total files processed: {len(dataset)}")
        print(f"Total files processed: {len(dataset)}")
    except Exception as e:
        logging.error(f"Exception occurred: {e}", exc_info=True)
        print(f"Exception occurred: {e}")
    finally:
        logging.info("Script ended.")
        print("Script ended")
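
As a quick sanity check of how collate_fn and SmartBatchingSampler fit together, here is a minimal standalone sketch; DummySet is invented for illustration and is not part of this commit.

# sanity_check.py -- toy data only; shapes are arbitrary.
import torch
from prepare_data import collate_fn, SmartBatchingSampler

class DummySet:
    """Three fake spectrograms with 4 mel bands and different widths."""
    def __init__(self):
        self.items = [(torch.randn(4, w), 0) for w in (10, 7, 3)]
    def __len__(self):
        return len(self.items)
    def __getitem__(self, i):
        return self.items[i]

ds = DummySet()
batch, labels = collate_fn([ds[i] for i in range(len(ds))])
print(batch.shape)    # torch.Size([3, 4, 10]) -- zero-padded to the widest clip

sampler = SmartBatchingSampler(ds, batch_size=2)
print(list(sampler))  # indices sorted by width, grouped into shuffled batches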
train_model.py ADDED
@@ -0,0 +1,315 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split, RandomSampler, SequentialSampler
import logging
import argparse
import json
from datetime import datetime
import optuna
from prepare_data import SpectrogramDataset, collate_fn
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

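# Shape check for the block above (illustrative): a stride-2 block halves the
# spatial dims while the 1x1-conv shortcut matches channels, e.g.
#   ResidualBlock(64, 128, stride=2)(torch.randn(1, 64, 32, 32)).shape
#   -> torch.Size([1, 128, 16, 16])
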
class AudioResNet(nn.Module):
    def __init__(self, num_classes=6, dropout_rate=0.5):
        super(AudioResNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        self.dropout = nn.Dropout(dropout_rate)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))  # Global Average Pooling
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        for i in range(num_blocks):
            layers.append(ResidualBlock(in_channels if i == 0 else out_channels, out_channels, stride if i == 0 else 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.gap(x)  # Apply Global Average Pooling
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

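# Because of AdaptiveAvgPool2d((1, 1)), the classifier head always sees a
# 512-dim vector regardless of the input's height/width, so the variable-width
# (zero-padded) batches from collate_fn pass through unchanged, e.g.
#   AudioResNet(num_classes=6)(torch.randn(2, 1, 128, 431)).shape
#   -> torch.Size([2, 6])
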
# Configure logging: basicConfig already installs a console handler, so only
# a file handler is added on top (a second StreamHandler would print every
# message twice).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger()
fh = logging.FileHandler('training.log')
fh.setLevel(logging.INFO)
fh.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(fh)

def parse_args():
    parser = argparse.ArgumentParser(description='Train Sample Classifier Model')
    parser.add_argument('--config', type=str, required=True, help='Path to the config file')
    return parser.parse_args()

def load_config(config_path):
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")
    with open(config_path, 'r') as f:
        config = json.load(f)
    return config

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    total_correct = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        # collate_fn yields (batch, n_mels, frames); unsqueeze(1) adds the
        # channel dim the Conv2d stem expects: (batch, 1, n_mels, frames).
        outputs = model(inputs.unsqueeze(1))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        total_correct += (predicted == labels).sum().item()

    train_loss = running_loss / len(train_loader.dataset)
    train_accuracy = total_correct / len(train_loader.dataset)
    return train_loss, train_accuracy

def validate_one_epoch(model, val_loader, criterion, device):
    model.eval()
    val_loss = 0.0
    val_correct = 0
    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs.unsqueeze(1))
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()

    val_loss /= len(val_loader.dataset)
    val_accuracy = val_correct / len(val_loader.dataset)
    return val_loss, val_accuracy

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, patience=10, max_epochs=50):
    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(max_epochs):
        train_loss, train_accuracy = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_accuracy = validate_one_epoch(model, val_loader, criterion, device)

        log_message = (f'Epoch {epoch + 1}:\n'
                       f'Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}, '
                       f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}\n')
        logging.info(log_message)

        scheduler.step(val_loss)
        current_lr = optimizer.param_groups[0]['lr']
        logging.info(f'Current learning rate: {current_lr}')

        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
            torch.save(model.state_dict(), 'best_model.pth')
        else:
            patience_counter += 1

        if patience_counter >= patience:
            logging.info('Early stopping triggered')
            break

        if (epoch + 1) % 10 == 0:
            checkpoint_path = f'checkpoint_epoch_{epoch + 1}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            logging.info(f'Model saved to {checkpoint_path}')

def evaluate_model(model, test_loader, device, class_names):
    model.eval()
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs.unsqueeze(1))
            _, preds = torch.max(outputs, 1)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
    logging.info(classification_report(all_labels, all_preds, target_names=class_names))
    plot_confusion_matrix(all_labels, all_preds, class_names)

def plot_confusion_matrix(labels, preds, class_names, save_path=None):
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", xticklabels=class_names, yticklabels=class_names)
    plt.ylabel('Actual')
    plt.xlabel('Predicted')
    plt.title('Confusion Matrix')
    if save_path:
        plt.savefig(save_path)
    plt.show()

def objective(trial, train_loader, val_loader, num_classes):
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-3, log=True)
    dropout_rate = trial.suggest_float('dropout_rate', 0.2, 0.5)

    model = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate).to(device)
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

    best_loss = float('inf')
    patience_counter = 0

    for epoch in range(10):
        train_loss, _ = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, _ = validate_one_epoch(model, val_loader, criterion, device)
        scheduler.step(val_loss)

        if val_loss < best_loss:
            best_loss = val_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= 3:
            break

    # Report the best validation loss seen, not whatever the last epoch gave.
    return best_loss

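# Note: each trial above gets at most 10 epochs with its own 3-epoch early
# stop; the study in main() then minimizes the returned validation loss.
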
def verify_dataset_and_loader(dataset, train_loader, val_loader, test_loader):
    try:
        logger.info(f"Dataset length: {len(dataset)}")
        logger.info(f"Train dataset length: {len(train_loader.dataset)}")
        logger.info(f"Validation dataset length: {len(val_loader.dataset)}")
        logger.info(f"Test dataset length: {len(test_loader.dataset)}")

        for idx in range(len(train_loader.dataset)):
            _ = train_loader.dataset[idx]
        logger.info("Train dataset verification passed")

        for idx in range(len(val_loader.dataset)):
            _ = val_loader.dataset[idx]
        logger.info("Validation dataset verification passed")

        for idx in range(len(test_loader.dataset)):
            _ = test_loader.dataset[idx]
        logger.info("Test dataset verification passed")
    except IndexError as e:
        logger.error(f"Dataset index error: {e}")

def verify_sampler_indices(loader, name):
    indices = list(loader.sampler)
    logger.info(f"{name} sampler indices: {indices[:10]}... (total: {len(indices)})")
    max_index = max(indices)
    if max_index >= len(loader.dataset):
        logger.error(f"{name} sampler index out of range: {max_index} >= {len(loader.dataset)}")
    else:
        logger.info(f"{name} sampler indices within range.")

def main():
    try:
        args = parse_args()
        config = load_config(args.config)

        dataset = SpectrogramDataset(config, config['directory'], process_new=True)
        if len(dataset) == 0:
            raise ValueError("The dataset is empty. Please check the data loading process.")
        num_classes = len(dataset.label_to_index)
        class_names = list(dataset.label_to_index.keys())

        train_size = int(0.7 * len(dataset))
        val_size = int(0.15 * len(dataset))
        test_size = len(dataset) - train_size - val_size
        train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

        # Inverse-frequency sample weights so rare classes are drawn as often
        # as common ones during training.
        train_labels = [dataset.labels[i] for i in train_dataset.indices]
        class_counts = np.bincount(train_labels)
        class_weights = 1. / class_counts
        sample_weights = class_weights[train_labels]
        sampler = WeightedRandomSampler(sample_weights, len(sample_weights))

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=RandomSampler(val_dataset))
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], collate_fn=collate_fn, sampler=SequentialSampler(test_dataset))

        verify_dataset_and_loader(dataset, train_loader, val_loader, test_loader)
        verify_sampler_indices(train_loader, "Train")
        verify_sampler_indices(val_loader, "Validation")
        verify_sampler_indices(test_loader, "Test")

        study = optuna.create_study(direction='minimize')
        study.optimize(lambda trial: objective(trial, train_loader, val_loader, num_classes), n_trials=50)

        print('Best hyperparameters: ', study.best_params)

        best_params = study.best_params
        model = AudioResNet(num_classes=num_classes, dropout_rate=best_params['dropout_rate']).to(device)
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.parameters(), lr=best_params['learning_rate'], weight_decay=best_params['weight_decay'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

        train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, patience=config['patience'])

        model.load_state_dict(torch.load('best_model.pth'))
        evaluate_model(model, test_loader, device, class_names)
    except Exception as e:
        logging.error(f"An error occurred: {e}", exc_info=True)

if __name__ == '__main__':
    main()
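
Finally, a worked toy example of the inverse-frequency class weighting that main() feeds to WeightedRandomSampler (numbers chosen purely for illustration):

import numpy as np
from torch.utils.data import WeightedRandomSampler

train_labels = [0, 0, 0, 1]                   # class 0 is 3x as common as class 1
class_counts = np.bincount(train_labels)      # [3, 1]
class_weights = 1.0 / class_counts            # [0.333..., 1.0]
sample_weights = class_weights[train_labels]  # one weight per training sample
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
# Each draw now picks class 1 with probability 1/2 instead of 1/4, so
# batches are roughly class-balanced in expectation.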