|
import os
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import tkinter as tk
|
|
from tkinter import ttk, filedialog, messagebox
|
|
import shutil
|
|
from pathlib import Path
|
|
import numpy as np
|
|
import librosa
|
|
import torchaudio
|
|
from torchvision import transforms
|
|
from tqdm import tqdm
|
|
|
|
class ResidualBlock(nn.Module):
    """Basic residual unit: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus skip, then ReLU.

    When the block changes spatial resolution (``stride != 1``) or the channel
    count, the skip path is a 1x1 strided convolution followed by batch norm so
    its output matches the main path's shape; otherwise the input is passed
    through unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path. Convolutions carry no bias because BatchNorm follows each one.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: project only when the shapes would otherwise disagree.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Identity()
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)
|
|
|
|
class AudioResNet(nn.Module):
    """ResNet-18-style classifier for single-channel spectrogram "images".

    Expects a batched 4-D tensor with one input channel, i.e. ``(batch, 1, H, W)``.
    The network applies a 7x7 stride-2 stem, a 3x3 stride-2 max-pool, four stages
    of two ``ResidualBlock``s each (64 -> 128 -> 256 -> 512 channels), global
    average pooling, and a two-layer fully-connected head with dropout.

    ``forward`` returns per-class log-probabilities (``log_softmax`` over dim 1).
    """

    def __init__(self, num_classes=6, dropout_rate=0.5):
        super().__init__()
        # NOTE: attribute names and creation order define the state_dict keys
        # that saved checkpoints rely on; do not rename or reorder them.

        # Stem.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; each stage after the first halves H and W.
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        # Head.
        self.dropout = nn.Dropout(dropout_rate)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first may downsample or change width."""
        blocks = []
        channels, step = in_channels, stride
        for _ in range(num_blocks):
            blocks.append(ResidualBlock(channels, out_channels, step))
            # Every subsequent block keeps the new width and resolution.
            channels, step = out_channels, 1
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map a spectrogram batch to ``(batch, num_classes)`` log-probabilities."""
        features = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        # Global average pool to (batch, 512, 1, 1), then flatten to (batch, 512).
        pooled = self.gap(features)
        pooled = pooled.view(pooled.size(0), -1)

        hidden = self.dropout(F.relu(self.fc1(pooled)))
        return F.log_softmax(self.fc2(hidden), dim=1)
|
|
|
|
def load_model(model_path='checkpoint_epoch_50.pth', num_classes=6, dropout_rate=0.5):
    """Build an ``AudioResNet`` and restore its weights from ``model_path``.

    The checkpoint is mapped onto the CPU and passed directly to
    ``load_state_dict``, so it must be a bare state dict whose keys match
    ``AudioResNet``'s module names. The returned model is in evaluation mode.

    NOTE(review): ``torch.load`` unpickles the file, which can execute
    arbitrary code; only load checkpoints from trusted sources.
    """
    network = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate)
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    network.load_state_dict(weights)
    # Inference mode: dropout disabled, BatchNorm uses its running statistics.
    network.eval()
    return network
|
|
|
|
def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    """Resample ``y`` to ``target_sr`` and zero-pad it to at least ``min_duration`` seconds.

    Args:
        y: Audio sample array (presumably mono 1-D, as ``librosa.load``
            returns by default -- TODO confirm for other callers).
        sr: Sample rate of ``y``.
        target_sr: Desired sample rate; resampling is skipped when ``sr`` already matches.
        min_duration: Minimum clip length in seconds; shorter clips get trailing zeros.

    Returns:
        Tuple ``(samples, target_sr)``.
    """
    samples = y if sr == target_sr else librosa.resample(y, orig_sr=sr, target_sr=target_sr)

    # Kept as a float: the shortfall is compared before truncating to whole samples.
    shortfall = min_duration * target_sr - len(samples)
    if shortfall > 0:
        samples = np.pad(samples, (0, int(shortfall)), mode='constant')
    return samples, target_sr
|
|
|
|
def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    """Trim leading/trailing silence, then add ``pad_duration`` seconds of zeros on each side.

    Trimming is delegated to ``librosa.effects.trim`` with the given ``top_db``
    threshold; only the trimmed signal is kept (the interval it reports is discarded).
    """
    trimmed = librosa.effects.trim(y, top_db=top_db)[0]
    margin = int(pad_duration * sr)  # samples of silence added before and after
    return np.pad(trimmed, margin, mode='constant')
|
|
|
|
def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128, target_sr=44100, min_duration=0.1):
    """Load an audio file and return its mel spectrogram in dB.

    Pipeline: load at the native rate -> ``validate_audio`` (resample to
    ``target_sr``, pad to ``min_duration``) -> ``strip_silence`` -> peak
    normalization -> mel spectrogram -> power-to-dB relative to the maximum.

    Args:
        file_path: Path of the audio file to read.
        n_fft: FFT window size passed to ``librosa.feature.melspectrogram``.
        hop_length: Hop between frames passed to ``librosa.feature.melspectrogram``.
        n_mels: Number of mel bands.
        target_sr: Sample rate the audio is resampled to before analysis.
        min_duration: Minimum length in seconds enforced by ``validate_audio``.

    Returns:
        The dB-scaled mel spectrogram (``(n_mels, frames)`` for mono input), or
        ``None`` if the file could not be read or processed. The error is printed.
    """
    try:
        y, sr = librosa.load(file_path, sr=None)
        y, sr = validate_audio(y, sr, target_sr, min_duration)
        y = strip_silence(y, sr)
    except Exception as e:
        print(f"Error reading {file_path}: {e}")
        return None

    # Feature extraction can also fail for an individual file (e.g. librosa
    # rejects non-finite samples). Report it and return None, as for a read
    # error, so one bad file does not abort a whole batch of classifications.
    try:
        y = librosa.util.normalize(y)
        S = librosa.feature.melspectrogram(y=y, sr=sr, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels)
        S_dB = librosa.power_to_db(S, ref=np.max)
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None

    return S_dB
|
|
|
|
def classify_file(model, file_path, spectrogram_save_path):
    """Classify one audio file and save its spectrogram with ``np.save``.

    Args:
        model: Callable returning log-probabilities of shape ``(1, num_classes)``
            for a ``(1, 1, n_mels, frames)`` input, e.g. the ``AudioResNet`` from
            ``load_model``. Presumably already in eval mode -- this function does
            not switch modes.
        file_path: Audio file to classify.
        spectrogram_save_path: Destination for the spectrogram array; missing
            parent directories are created.

    Returns:
        ``(confidence, class_index)`` for the most probable class, or
        ``(None, None)`` if the spectrogram could not be computed.
    """
    spectrogram = audio_to_spectrogram(file_path)
    if spectrogram is None:
        return None, None

    # A bare filename has an empty dirname, and os.makedirs('') would raise.
    save_dir = os.path.dirname(spectrogram_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.save(spectrogram_save_path, spectrogram)

    # Add batch and channel axes expected by the 2-D conv network.
    batch = torch.tensor(spectrogram, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        log_probs = model(batch)

    # The model emits log-probabilities, so exp() recovers probabilities.
    confidence, predicted = torch.max(torch.exp(log_probs), 1)
    return confidence.item(), predicted.item()
|
|
|
|
def _is_inside(path, folder):
    """Return True if ``path`` lies strictly below ``folder`` (lexical, absolute paths)."""
    return Path(os.path.abspath(folder)) in Path(os.path.abspath(path)).parents


def _claim_destination(folder, name, claimed):
    """Return a path in ``folder`` for ``name`` not yet used this run, recording it in ``claimed``."""
    stem, ext = os.path.splitext(name)
    candidate = os.path.join(folder, name)
    suffix = 0
    # normcase so that case-insensitive filesystems are treated as collisions too.
    while os.path.normcase(candidate) in claimed:
        suffix += 1
        candidate = os.path.join(folder, f"{stem}_{suffix}{ext}")
    claimed.add(os.path.normcase(candidate))
    return candidate


def sort_files(model, input_folder, output_folder, confidence_threshold=0.9, progress_callback=None):
    """Classify every ``.wav`` under ``input_folder`` and copy confident hits into label folders.

    Each file found recursively is passed to ``classify_file``, which saves its
    spectrogram under ``<output_folder>/Spectrograms/<relative path>.npy``.
    If the top-class probability is at least ``confidence_threshold``, the file
    is copied into ``<output_folder>/<label>``. Same-named files from different
    sub-folders get a numeric suffix (``kick_1.wav``) instead of overwriting
    each other.

    Args:
        model: Classifier accepted by ``classify_file``.
        input_folder: Root folder scanned recursively for ``.wav`` files.
        output_folder: Destination root; created if missing. When it lies
            inside ``input_folder``, files already under it are not re-scanned.
        confidence_threshold: Minimum probability required to copy a file.
        progress_callback: Optional ``callable(done, total)`` invoked after each file.
    """
    spectrogram_folder = os.path.join(output_folder, "Spectrograms")
    os.makedirs(output_folder, exist_ok=True)

    # The GUI places output_folder inside input_folder. Without this filter a
    # second run would re-classify earlier copies, and shutil.copy would raise
    # SameFileError when copying a file onto itself.
    skip_output = _is_inside(output_folder, input_folder)
    files = sorted(
        path for path in Path(input_folder).rglob('*.wav')
        if not (skip_output and _is_inside(path, output_folder))
    )
    total_files = len(files)

    claimed = set()
    for done, file in enumerate(files, start=1):
        spectrogram_save_path = os.path.join(spectrogram_folder, os.path.relpath(file, input_folder)) + '.npy'
        confidence, label = classify_file(model, file, spectrogram_save_path)
        if confidence is not None and confidence >= confidence_threshold:
            label_folder = os.path.join(output_folder, str(label))
            os.makedirs(label_folder, exist_ok=True)
            shutil.copy(file, _claim_destination(label_folder, file.name, claimed))
        if progress_callback:
            progress_callback(done, total_files)
|
|
|
|
class Application(tk.Frame):
    """Tk front-end: pick a folder, then sort its ``.wav`` files by predicted class.

    Sorted copies and saved spectrograms are written to ``<folder>/Sorted``.
    """

    def __init__(self, master=None):
        super().__init__(master)
        self.master = master
        self.pack()
        self.create_widgets()

    def create_widgets(self):
        """Build the folder entry, browse button, progress bar, and action buttons."""
        self.label = tk.Label(self, text="Select Folder:")
        self.label.pack()

        self.entry = tk.Entry(self, width=50)
        self.entry.pack()

        self.browse_button = tk.Button(self, text="Browse", command=self.browse_folder)
        self.browse_button.pack()

        self.progress = tk.IntVar()
        self.progress_bar = ttk.Progressbar(self, orient="horizontal", length=400, mode="determinate", variable=self.progress)
        self.progress_bar.pack()

        self.sort_button = tk.Button(self, text="Sort Files", command=self.sort_files)
        self.sort_button.pack()

        # NOTE(review): this attribute shadows tk.Misc.quit() on this instance;
        # kept under the same name for compatibility with existing callers.
        self.quit = tk.Button(self, text="Quit", fg="red", command=self.master.destroy)
        self.quit.pack()

    def browse_folder(self):
        """Ask for a folder and put it in the entry; a cancelled dialog keeps the old value."""
        folder_selected = filedialog.askdirectory()
        # askdirectory returns an empty value when the dialog is cancelled.
        if not folder_selected:
            return
        self.entry.delete(0, tk.END)
        self.entry.insert(0, folder_selected)

    def update_progress(self, current, total):
        """Show ``current`` of ``total`` files as a percentage and repaint the bar."""
        self.progress.set(int(current / total * 100))
        # update() also processes pending events, keeping the window responsive.
        self.progress_bar.update()

    def sort_files(self):
        """Load the model and sort the selected folder, reporting the outcome in a dialog."""
        input_folder = self.entry.get()
        # An empty entry would otherwise make sort_files scan the current directory.
        if not os.path.isdir(input_folder):
            messagebox.showerror("Error", f"Please select an existing folder (got {input_folder!r}).")
            return
        output_folder = os.path.join(input_folder, "Sorted")
        model_path = "0Shot1Shot2ShotV0.1.pth"

        # update_progress() pumps the event loop, so a second click on this
        # button would start a nested sort mid-run; disable it until done.
        self.sort_button.config(state=tk.DISABLED)
        self.progress.set(0)
        try:
            # Inside the try so a missing or incompatible checkpoint is shown in a dialog.
            model = load_model(model_path)
            sort_files(model, input_folder, output_folder, progress_callback=self.update_progress)
            messagebox.showinfo("Success", "Files sorted successfully!")
        except Exception as e:
            messagebox.showerror("Error", str(e))
        finally:
            try:
                self.sort_button.config(state=tk.NORMAL)
            except tk.TclError:
                pass  # the window was closed (Quit) while sorting
|
|
|
|
if __name__ == "__main__":
    # Launch the GUI only when run as a script, so the model and audio helpers
    # defined above can be imported without opening a window.
    root = tk.Tk()
    app = Application(master=root)
    app.master.title("Sample Sorter")
    app.mainloop()
|
|
|