"""Tkinter tool that sorts ``.wav`` files into per-class folders.

Each file is converted to a mel spectrogram, classified by ``AudioResNet``
and, when the prediction is confident enough, copied into a folder named
after the predicted class index.
"""

import os
import shutil
import tkinter as tk
from pathlib import Path
from tkinter import filedialog, messagebox, ttk
from typing import Callable, Optional, Tuple

import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torchvision import transforms
from tqdm import tqdm

# Checkpoint loaded by the GUI; resolved relative to the working directory.
GUI_MODEL_PATH = "0Shot1Shot2ShotV0.1.pth"


class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive shortcut.

    The shortcut is the identity unless the block changes the channel count
    or the spatial stride, in which case a 1x1 conv + batch-norm projects the
    input so it can be added to the main path.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class AudioResNet(nn.Module):
    """Residual CNN over single-channel 2-D inputs (here: mel spectrograms).

    Layout: 7x7 stride-2 stem conv, max-pool, four stages of two
    ``ResidualBlock`` each (64, 128, 256, 512 channels), global average
    pooling, then a 512 -> 1024 -> ``num_classes`` classifier with dropout
    between the two linear layers. ``forward`` returns log-probabilities.
    """

    def __init__(self, num_classes=6, dropout_rate=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        self.dropout = nn.Dropout(dropout_rate)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))  # Global Average Pooling
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` residual blocks.

        Only the first block changes the channel count and applies
        ``stride``; the remaining blocks keep shape.
        """
        layers = []
        for i in range(num_blocks):
            layers.append(ResidualBlock(
                in_channels if i == 0 else out_channels,
                out_channels,
                stride if i == 0 else 1,
            ))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class log-probabilities (``log_softmax`` over dim 1)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.gap(x)  # Apply Global Average Pooling
        x = x.view(x.size(0), -1)  # flatten to (batch, 512)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def load_model(model_path='checkpoint_epoch_50.pth', num_classes=6, dropout_rate=0.5):
    """Build an ``AudioResNet``, load weights from ``model_path`` on CPU and
    return it in eval mode.

    The file is expected to contain a state dict accepted by
    ``load_state_dict``; errors from ``torch.load`` / ``load_state_dict``
    (missing file, mismatched keys) propagate to the caller.
    """
    model = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate)
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)
    model.eval()
    return model


def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    """Resample ``y`` to ``target_sr`` and zero-pad it to ``min_duration`` s.

    Returns:
        ``(samples, target_sr)``; padding is appended at the end only.
    """
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    if len(y) < min_duration * target_sr:
        pad_length = int(min_duration * target_sr - len(y))
        y = np.pad(y, (0, pad_length), mode='constant')
    return y, target_sr


def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    """Trim leading/trailing silence, then add ``pad_duration`` seconds of
    zeros on both ends.

    Trimming uses ``librosa.effects.trim`` with the given ``top_db``.
    """
    y_trimmed, _ = librosa.effects.trim(y, top_db=top_db)
    pad_length = int(pad_duration * sr)
    # A scalar pad width pads the same amount before and after the signal.
    y_padded = np.pad(y_trimmed, pad_length, mode='constant')
    return y_padded


def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128,
                         target_sr=44100, min_duration=0.1):
    """Load an audio file and return its mel spectrogram in dB.

    The signal is resampled/padded (``validate_audio``), silence-trimmed
    (``strip_silence``) and peak-normalized before the mel transform.

    Returns:
        The dB-scaled mel spectrogram (referenced to its maximum), or
        ``None`` if loading or pre-processing the file raised; the error is
        printed rather than propagated so one bad file does not stop a batch.
    """
    try:
        y, sr = librosa.load(file_path, sr=None)  # sr=None keeps native rate
        y, sr = validate_audio(y, sr, target_sr, min_duration)
        y = strip_silence(y, sr)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return None

    y = librosa.util.normalize(y)
    S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft,
                                       hop_length=hop_length, n_mels=n_mels)
    S_dB = librosa.power_to_db(S, ref=np.max)
    return S_dB


def classify_file(model, file_path, spectrogram_save_path) -> Tuple[Optional[float], Optional[int]]:
    """Classify one audio file and save its spectrogram as ``.npy``.

    Returns:
        ``(confidence, class_index)`` for the most probable class, or
        ``(None, None)`` when the spectrogram could not be computed.
    """
    spectrogram = audio_to_spectrogram(file_path)
    if spectrogram is None:
        return None, None

    # os.makedirs('') raises, so only create a directory when the save path
    # actually has one.
    save_dir = os.path.dirname(spectrogram_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.save(spectrogram_save_path, spectrogram)

    # Add batch and channel axes: (n_mels, frames) -> (1, 1, n_mels, frames).
    spectrogram = torch.tensor(spectrogram, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        output = model(spectrogram)
        probabilities = torch.exp(output)  # model emits log-probabilities
        confidence, predicted = torch.max(probabilities, 1)
    return confidence.item(), predicted.item()


def sort_files(model, input_folder, output_folder, confidence_threshold=0.9,
               progress_callback: Optional[Callable[[int, int], None]] = None):
    """Copy confidently classified ``.wav`` files into per-class folders.

    Every ``*.wav`` file found recursively under ``input_folder`` is
    classified; files whose confidence is at least ``confidence_threshold``
    are copied to ``output_folder/<class_index>/``. Spectrograms are saved
    under ``output_folder/Spectrograms`` mirroring the input layout.

    Files that already live under ``output_folder`` are skipped: the output
    folder may be nested inside the input folder, and re-processing earlier
    results would copy a file onto itself (``shutil.SameFileError``).

    Args:
        progress_callback: optional ``callback(done, total)`` invoked after
            each processed file.
    """
    spectrogram_folder = os.path.join(output_folder, "Spectrograms")
    os.makedirs(output_folder, exist_ok=True)

    output_root = Path(output_folder).resolve()
    files = [
        candidate
        for candidate in Path(input_folder).rglob('*.wav')
        if candidate.is_file() and output_root not in candidate.resolve().parents
    ]
    total_files = len(files)

    for idx, file in enumerate(files):
        spectrogram_save_path = os.path.join(
            spectrogram_folder, os.path.relpath(file, input_folder)) + '.npy'
        confidence, label = classify_file(model, file, spectrogram_save_path)

        if confidence is not None and confidence >= confidence_threshold:
            label_folder = os.path.join(output_folder, str(label))
            os.makedirs(label_folder, exist_ok=True)
            # NOTE: files sharing a basename in different sub-folders land on
            # the same destination; the later copy overwrites the earlier.
            shutil.copy(file, label_folder)

        if progress_callback:
            progress_callback(idx + 1, total_files)


class Application(tk.Frame):
    """Main window: folder picker, progress bar and Sort/Quit buttons."""

    def __init__(self, master=None):
        super().__init__(master)
        self.master = master
        self.pack()
        self.create_widgets()

    def create_widgets(self):
        """Create and pack all widgets top to bottom."""
        self.label = tk.Label(self, text="Select Folder:")
        self.label.pack()

        self.entry = tk.Entry(self, width=50)
        self.entry.pack()

        self.browse_button = tk.Button(self, text="Browse", command=self.browse_folder)
        self.browse_button.pack()

        self.progress = tk.IntVar()
        self.progress_bar = ttk.Progressbar(self, orient="horizontal", length=400,
                                            mode="determinate", variable=self.progress)
        self.progress_bar.pack()

        self.sort_button = tk.Button(self, text="Sort Files", command=self.sort_files)
        self.sort_button.pack()

        # NOTE: this attribute shadows tk.Frame.quit(); the name is kept so
        # existing references to ``app.quit`` keep working.
        self.quit = tk.Button(self, text="Quit", fg="red", command=self.master.destroy)
        self.quit.pack()

    def browse_folder(self):
        """Ask for a directory and put it in the entry field."""
        folder_selected = filedialog.askdirectory()
        if not folder_selected:
            # Dialog cancelled: keep whatever the entry already contains.
            return
        self.entry.delete(0, tk.END)
        self.entry.insert(0, folder_selected)

    def update_progress(self, current, total):
        """Set the progress bar to ``current / total`` as a percentage."""
        percent = int(current / total * 100) if total else 0
        self.progress.set(percent)
        # Sorting runs on the GUI thread; redraw explicitly so the bar moves.
        self.progress_bar.update()

    def sort_files(self):
        """Load the model and sort the selected folder into ``<folder>/Sorted``.

        Any failure (including a missing or incompatible checkpoint) is
        reported in an error dialog.
        """
        input_folder = self.entry.get().strip()
        if not input_folder or not os.path.isdir(input_folder):
            # An empty entry would otherwise scan the working directory.
            messagebox.showerror("Error", "Please select a valid folder.")
            return
        output_folder = os.path.join(input_folder, "Sorted")

        # update_progress() processes GUI events, so without this a second
        # click would start a nested sorting run.
        self.sort_button.config(state=tk.DISABLED)
        try:
            model = load_model(GUI_MODEL_PATH)
            sort_files(model, input_folder, output_folder,
                       progress_callback=self.update_progress)
        except Exception as e:
            messagebox.showerror("Error", str(e))
        else:
            messagebox.showinfo("Success", "Files sorted successfully!")
        finally:
            try:
                self.sort_button.config(state=tk.NORMAL)
            except tk.TclError:
                # The window was closed while sorting; nothing to restore.
                pass


if __name__ == "__main__":
    root = tk.Tk()
    app = Application(master=root)
    app.master.title("Sample Sorter")
    app.mainloop()