# Sample Sorter: classifies .wav files with an AudioResNet and sorts them by predicted label.
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import shutil
from pathlib import Path
import numpy as np
import librosa
import torchaudio
from torchvision import transforms
from tqdm import tqdm
class ResidualBlock(nn.Module):
    """Basic 3x3 residual block: two conv-BN stages plus an identity/projection skip."""

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path: conv-BN, ReLU, conv-BN (ReLU applied in forward).
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # A 1x1 projection is needed whenever spatial size or channel count changes.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        """Return relu(mainpath(x) + shortcut(x))."""
        skip = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + skip)
class AudioResNet(nn.Module):
    """ResNet-18-style CNN over single-channel spectrograms.

    Forward input is (batch, 1, H, W); output is log-probabilities over
    `num_classes` (log_softmax), so callers recover probabilities via exp().
    """

    def __init__(self, num_classes=6, dropout_rate=0.5):
        super().__init__()
        # Stem: 7x7 conv, BN, 3x3 max-pool (ReLU applied in forward).
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; channels double while resolution halves.
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)
        self.dropout = nn.Dropout(dropout_rate)
        # Global Average Pooling makes the head independent of input size.
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack `num_blocks` residual blocks; only the first one may downsample."""
        blocks = [ResidualBlock(in_channels, out_channels, stride)]
        blocks.extend(ResidualBlock(out_channels, out_channels, 1)
                      for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map a (batch, 1, H, W) spectrogram batch to log-probabilities."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)  # flatten (batch, 512, 1, 1) -> (batch, 512)
        x = self.dropout(F.relu(self.fc1(x)))
        return F.log_softmax(self.fc2(x), dim=1)
def load_model(model_path='checkpoint_epoch_50.pth', num_classes=6, dropout_rate=0.5):
    """Build an AudioResNet and restore checkpoint weights onto the CPU.

    Returns the model in eval mode (dropout disabled, BN in inference mode).
    NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    """
    state_dict = torch.load(model_path, map_location=torch.device('cpu'))
    net = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate)
    net.load_state_dict(state_dict)
    net.eval()
    return net
def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    """Normalize a signal to `target_sr` and a minimum length.

    Resamples when `sr` differs from `target_sr`, then right-pads with zeros
    so the result is at least `min_duration` seconds long.
    Returns (samples, target_sr).
    """
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)
    required = min_duration * target_sr
    if len(y) < required:
        # Zero-pad on the right only, up to the minimum sample count.
        y = np.pad(y, (0, int(required - len(y))), mode='constant')
    return y, target_sr
def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    """Trim leading/trailing silence, then re-add `pad_duration` seconds of zeros.

    `top_db` is the threshold (in dB below peak) librosa uses to call a frame
    silent. np.pad with a scalar width pads both ends symmetrically.
    """
    trimmed, _ = librosa.effects.trim(y, top_db=top_db)
    margin = int(pad_duration * sr)
    return np.pad(trimmed, margin, mode='constant')
def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128, target_sr=44100, min_duration=0.1):
    """Load an audio file and convert it to a dB-scaled mel spectrogram.

    Returns a 2-D numpy array (n_mels x frames), or None when the file cannot
    be read or preprocessed.
    """
    try:
        samples, native_sr = librosa.load(file_path, sr=None)  # keep native rate; validate_audio resamples
        samples, rate = validate_audio(samples, native_sr, target_sr, min_duration)
        samples = strip_silence(samples, rate)
    except Exception as e:
        # Best-effort: report and skip unreadable files instead of aborting a batch run.
        print(f"Error reading {file_path}: {e}")
        return None
    samples = librosa.util.normalize(samples)
    mel = librosa.feature.melspectrogram(y=samples, sr=rate, n_fft=n_fft,
                                         hop_length=hop_length, n_mels=n_mels)
    return librosa.power_to_db(mel, ref=np.max)
def classify_file(model, file_path, spectrogram_save_path):
    """Classify one audio file, caching its spectrogram to `spectrogram_save_path` (.npy).

    Returns (confidence, predicted_label) as Python scalars, or (None, None)
    if the audio could not be converted to a spectrogram.
    """
    spec = audio_to_spectrogram(file_path)
    if spec is None:
        return None, None
    # Persist the spectrogram alongside the prediction for later inspection.
    os.makedirs(os.path.dirname(spectrogram_save_path), exist_ok=True)
    np.save(spectrogram_save_path, spec)
    # Shape to NCHW: (batch=1, channel=1, n_mels, frames).
    batch = torch.tensor(spec, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        log_probs = model(batch)
    # The model emits log-softmax; exp() recovers probabilities.
    probs = torch.exp(log_probs)
    confidence, predicted = probs.max(dim=1)
    return confidence.item(), predicted.item()
def sort_files(model, input_folder, output_folder, confidence_threshold=0.9, progress_callback=None):
    """Classify every .wav under `input_folder` (recursively) and copy confident
    predictions into per-label subfolders of `output_folder`.

    Spectrograms are cached under `<output_folder>/Spectrograms`, mirroring the
    input tree. Files below `confidence_threshold` (or that fail to load) are
    left unsorted. `progress_callback(done, total)` is invoked after each file.
    """
    spectrogram_folder = os.path.join(output_folder, "Spectrograms")
    # exist_ok=True is race-safe, unlike the check-then-create pattern.
    os.makedirs(output_folder, exist_ok=True)
    files = list(Path(input_folder).rglob('*.wav'))
    total_files = len(files)
    for idx, file in enumerate(files):
        # Mirror the input tree inside the spectrogram cache.
        spectrogram_save_path = os.path.join(
            spectrogram_folder, os.path.relpath(file, input_folder)) + '.npy'
        confidence, label = classify_file(model, file, spectrogram_save_path)
        if confidence is not None and confidence >= confidence_threshold:
            label_folder = os.path.join(output_folder, str(label))
            os.makedirs(label_folder, exist_ok=True)
            shutil.copy(file, label_folder)
        if progress_callback:
            progress_callback(idx + 1, total_files)
class Application(tk.Frame):
    """Tkinter front-end: pick a folder, classify its .wav files, show progress."""

    def __init__(self, master=None):
        super().__init__(master)
        self.master = master
        self.pack()
        self.create_widgets()

    def create_widgets(self):
        """Build the folder picker, progress bar, and action buttons."""
        self.label = tk.Label(self, text="Select Folder:")
        self.label.pack()
        self.entry = tk.Entry(self, width=50)
        self.entry.pack()
        self.browse_button = tk.Button(self, text="Browse", command=self.browse_folder)
        self.browse_button.pack()
        self.progress = tk.IntVar()
        self.progress_bar = ttk.Progressbar(self, orient="horizontal", length=400,
                                            mode="determinate", variable=self.progress)
        self.progress_bar.pack()
        self.sort_button = tk.Button(self, text="Sort Files", command=self.sort_files)
        self.sort_button.pack()
        self.quit = tk.Button(self, text="Quit", fg="red", command=self.master.destroy)
        self.quit.pack()

    def browse_folder(self):
        """Ask for a directory; keep the current entry if the dialog is cancelled."""
        folder_selected = filedialog.askdirectory()
        if folder_selected:  # askdirectory returns '' on cancel
            self.entry.delete(0, tk.END)
            self.entry.insert(0, folder_selected)

    def update_progress(self, current, total):
        """Progress callback for sort_files: map (done, total) to 0-100%."""
        self.progress.set(int(current / total * 100))
        self.progress_bar.update()  # force a redraw; sorting runs on the UI thread

    def sort_files(self):
        """Load the model and sort the selected folder into <folder>/Sorted."""
        input_folder = self.entry.get()
        # Guard: an empty path would make Path('').rglob scan the CWD.
        if not input_folder:
            messagebox.showerror("Error", "Please select a folder first.")
            return
        output_folder = os.path.join(input_folder, "Sorted")
        model_path = "0Shot1Shot2ShotV0.1.pth"
        try:
            # Load inside the try so a missing/corrupt checkpoint shows an
            # error dialog instead of an uncaught traceback.
            model = load_model(model_path)
            # The bare name resolves to the module-level sort_files(), not this method.
            sort_files(model, input_folder, output_folder,
                       progress_callback=self.update_progress)
            messagebox.showinfo("Success", "Files sorted successfully!")
        except Exception as e:
            messagebox.showerror("Error", str(e))
def main():
    """Launch the Sample Sorter GUI and block until the window closes."""
    root = tk.Tk()
    app = Application(master=root)
    app.master.title("Sample Sorter")
    app.mainloop()


if __name__ == "__main__":
    # Guard so importing this module (e.g. to reuse the model classes)
    # does not open a window.
    main()