# 0Shot1Shot-v0.1 / sorter.py
# CyborgPaloma's picture
# Upload 5 files
# c0d8e31 verified
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import shutil
from pathlib import Path
import numpy as np
import librosa
import torchaudio
from torchvision import transforms
from tqdm import tqdm
class ResidualBlock(nn.Module):
    """Residual unit built from two 3x3 convolutions, each followed by batch norm.

    The block input is added back to the output of the second batch norm
    before the final ReLU. When the stride is not 1 or the channel count
    changes, the input is first projected by a 1x1 convolution plus batch
    norm so the two tensors can be summed; otherwise it is passed through
    unchanged.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main path: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip path: identity when shapes already agree, 1x1 projection otherwise.
        shapes_match = stride == 1 and in_channels == out_channels
        if shapes_match:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        main += self.shortcut(x)
        return F.relu(main)
class AudioResNet(nn.Module):
    """ResNet-style classifier for single-channel 2-D inputs.

    Layout: a 7x7 stride-2 convolution with batch norm and ReLU, a 3x3
    stride-2 max pool, four stages of two ``ResidualBlock``s each
    (64, 128, 256 and 512 channels), global average pooling, and a
    512 -> 1024 -> ``num_classes`` fully connected head with dropout
    between the two linear layers.

    Args:
        num_classes: Size of the output layer.
        dropout_rate: Dropout probability applied after the first linear layer.
    """

    def __init__(self, num_classes=6, dropout_rate=0.5):
        super().__init__()

        # Stem.
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first halves the spatial size.
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)

        # Classification head.
        self.dropout = nn.Dropout(dropout_rate)
        self.gap = nn.AdaptiveAvgPool2d((1, 1))  # global average pooling
        self.fc1 = nn.Linear(512, 1024)
        self.fc2 = nn.Linear(1024, num_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks into one sequential stage.

        Only the first block receives ``in_channels`` and ``stride``; the
        remaining blocks map ``out_channels`` to ``out_channels`` at stride 1.
        """
        stage = []
        block_in, block_stride = in_channels, stride
        for _ in range(num_blocks):
            stage.append(ResidualBlock(block_in, out_channels, block_stride))
            block_in, block_stride = out_channels, 1
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1) for ``x``."""
        features = self.maxpool(F.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        # Collapse each channel to a single value, then flatten per sample.
        pooled = self.gap(features)
        pooled = pooled.view(pooled.size(0), -1)

        hidden = self.dropout(F.relu(self.fc1(pooled)))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
def load_model(model_path='checkpoint_epoch_50.pth', num_classes=6, dropout_rate=0.5):
    """Build an ``AudioResNet``, load its weights from disk and switch it to eval mode.

    Args:
        model_path: Path to a file holding the model's state dict.
        num_classes: Forwarded to ``AudioResNet``.
        dropout_rate: Forwarded to ``AudioResNet``.

    Returns:
        The model with weights loaded onto the CPU, in evaluation mode.
    """
    network = AudioResNet(num_classes=num_classes, dropout_rate=dropout_rate)

    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    network.load_state_dict(weights)

    network.eval()
    return network
def validate_audio(y, sr, target_sr=44100, min_duration=0.1):
    """Resample audio to ``target_sr`` and zero-pad it to at least ``min_duration`` seconds.

    Args:
        y: Audio samples (assumed to be a 1-D mono array, as produced by the
            caller in this file -- TODO confirm for other callers).
        sr: Sample rate of ``y``.
        target_sr: Sample rate the returned signal should have.
        min_duration: Minimum length of the returned signal, in seconds.

    Returns:
        A ``(samples, target_sr)`` tuple. ``samples`` is ``y`` resampled when
        ``sr != target_sr`` and padded with trailing zeros when too short.
    """
    if sr != target_sr:
        y = librosa.resample(y, orig_sr=sr, target_sr=target_sr)

    # Round the required sample count up: truncating the padding length (as
    # int() does) could leave the signal a fraction of a sample short of
    # min_duration whenever min_duration * target_sr is not a whole number.
    min_samples = int(np.ceil(min_duration * target_sr))
    missing = min_samples - len(y)
    if missing > 0:
        y = np.pad(y, (0, missing), mode='constant')

    return y, target_sr
def strip_silence(y, sr, top_db=20, pad_duration=0.1):
    """Trim leading/trailing silence, then surround the result with zeros.

    Args:
        y: Audio samples.
        sr: Sample rate of ``y``; used to convert ``pad_duration`` to samples.
        top_db: Threshold passed to ``librosa.effects.trim``.
        pad_duration: Seconds of zero padding added before and after the
            trimmed signal.

    Returns:
        The trimmed signal with ``int(pad_duration * sr)`` zeros on each side.
    """
    trimmed = librosa.effects.trim(y, top_db=top_db)[0]
    margin = int(pad_duration * sr)
    return np.pad(trimmed, (margin, margin), mode='constant')
def audio_to_spectrogram(file_path, n_fft=2048, hop_length=256, n_mels=128, target_sr=44100, min_duration=0.1):
    """Load an audio file and convert it to a mel spectrogram in decibels.

    The file is loaded at its native sample rate, passed through
    ``validate_audio`` and ``strip_silence``, normalised, and converted to a
    mel power spectrogram that is expressed in dB relative to its maximum.

    Args:
        file_path: Audio file to read.
        n_fft: FFT window size for the mel spectrogram.
        hop_length: Hop length for the mel spectrogram.
        n_mels: Number of mel bands.
        target_sr: Sample rate the audio is brought to before analysis.
        min_duration: Minimum signal length in seconds, see ``validate_audio``.

    Returns:
        The dB-scaled mel spectrogram, or ``None`` when loading or
        pre-processing the file raised an exception (the error is printed).
    """
    try:
        samples, rate = librosa.load(file_path, sr=None)
        samples, rate = validate_audio(samples, rate, target_sr, min_duration)
        samples = strip_silence(samples, rate)
    except Exception as e:
        # Best effort: report the unreadable file and let the caller skip it.
        print(f"Error reading {file_path}: {e}")
        return None

    normalised = librosa.util.normalize(samples)
    mel_power = librosa.feature.melspectrogram(
        y=normalised,
        sr=rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    return librosa.power_to_db(mel_power, ref=np.max)
def classify_file(model, file_path, spectrogram_save_path):
    """Classify one audio file and save its spectrogram to disk.

    Args:
        model: Callable model returning log-probabilities over classes along
            dim 1 (its output is exponentiated to obtain probabilities).
        file_path: Audio file to classify.
        spectrogram_save_path: Destination for the spectrogram, written with
            ``np.save``. Missing parent directories are created.

    Returns:
        ``(confidence, label)`` where ``confidence`` is the highest class
        probability as a float and ``label`` the matching class index, or
        ``(None, None)`` when the spectrogram could not be computed.
    """
    spectrogram = audio_to_spectrogram(file_path)
    if spectrogram is None:
        return None, None

    # A bare file name has an empty dirname, and os.makedirs('') raises
    # FileNotFoundError, so only create the directory when there is one.
    save_dir = os.path.dirname(spectrogram_save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    np.save(spectrogram_save_path, spectrogram)

    # Add batch and channel dimensions in front of the 2-D spectrogram.
    batch = torch.tensor(spectrogram, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    with torch.no_grad():
        log_probs = model(batch)
        probabilities = torch.exp(log_probs)
        confidence, predicted = torch.max(probabilities, 1)
    return confidence.item(), predicted.item()
def sort_files(model, input_folder, output_folder, confidence_threshold=0.9, progress_callback=None):
    """Classify every ``*.wav`` under ``input_folder`` and copy confident hits into label folders.

    Each file is classified with ``classify_file``; its spectrogram is stored
    under ``<output_folder>/Spectrograms`` mirroring the input layout. Files
    whose confidence reaches ``confidence_threshold`` are copied to
    ``<output_folder>/<label>``. Files that cannot be read are skipped.

    Args:
        model: Model forwarded to ``classify_file``.
        input_folder: Folder searched recursively for ``*.wav`` files.
        output_folder: Folder receiving the label folders and spectrograms;
            created when missing.
        confidence_threshold: Minimum confidence required to copy a file.
        progress_callback: Optional ``callback(done, total)`` invoked after
            each processed file.
    """
    spectrogram_folder = os.path.join(output_folder, "Spectrograms")
    os.makedirs(output_folder, exist_ok=True)

    # When the output folder lives inside the input tree (the GUI uses
    # "<input>/Sorted"), a second run would otherwise pick up the copies made
    # by the first run and try to copy them onto themselves.
    input_root = Path(input_folder).resolve()
    output_root = Path(output_folder).resolve()
    output_is_nested = input_root in output_root.parents

    files = [
        candidate
        for candidate in Path(input_folder).rglob('*.wav')
        if not (output_is_nested and output_root in candidate.resolve().parents)
    ]
    total_files = len(files)

    for idx, file in enumerate(files):
        spectrogram_save_path = os.path.join(spectrogram_folder, os.path.relpath(file, input_folder)) + '.npy'
        confidence, label = classify_file(model, file, spectrogram_save_path)
        if confidence is not None and confidence >= confidence_threshold:
            label_folder = os.path.join(output_folder, str(label))
            os.makedirs(label_folder, exist_ok=True)
            try:
                shutil.copy(file, label_folder)
            except shutil.SameFileError:
                # The file already sits in its label folder (input and output
                # folders coincide); nothing to copy.
                pass
        if progress_callback:
            progress_callback(idx + 1, total_files)
class Application(tk.Frame):
    """Small Tk front end: pick a folder, sort its samples, watch a progress bar.

    The sorted output is written to a ``Sorted`` sub-folder of the selected
    folder, using the model file ``0Shot1Shot2ShotV0.1.pth`` resolved
    relative to the current working directory.
    """

    def __init__(self, master=None):
        super().__init__(master)
        self.master = master
        self.pack()
        self.create_widgets()

    def create_widgets(self):
        """Create and pack the folder entry, buttons and progress bar."""
        self.label = tk.Label(self, text="Select Folder:")
        self.label.pack()

        self.entry = tk.Entry(self, width=50)
        self.entry.pack()

        self.browse_button = tk.Button(self, text="Browse", command=self.browse_folder)
        self.browse_button.pack()

        # The bar displays the value of this variable as a percentage (0-100).
        self.progress = tk.IntVar()
        self.progress_bar = ttk.Progressbar(
            self, orient="horizontal", length=400, mode="determinate", variable=self.progress
        )
        self.progress_bar.pack()

        self.sort_button = tk.Button(self, text="Sort Files", command=self.sort_files)
        self.sort_button.pack()

        # NOTE: this attribute shadows the inherited Tk ``quit()`` method; the
        # name is kept so existing references to ``self.quit`` keep working.
        self.quit = tk.Button(self, text="Quit", fg="red", command=self.master.destroy)
        self.quit.pack()

    def browse_folder(self):
        """Ask for a directory and put it in the entry; a cancelled dialog changes nothing."""
        folder_selected = filedialog.askdirectory()
        if not folder_selected:
            # Dialog was cancelled: keep whatever the user had entered before.
            return
        self.entry.delete(0, tk.END)
        self.entry.insert(0, folder_selected)

    def update_progress(self, current, total):
        """Show ``current`` out of ``total`` as a percentage on the progress bar.

        A ``total`` of zero means there was nothing to do and is shown as
        complete instead of raising ``ZeroDivisionError``.
        """
        percent = int(current / total * 100) if total else 100
        self.progress.set(percent)
        # Redraw immediately; sorting runs inside a Tk callback, so the event
        # loop does not get a chance to repaint on its own.
        self.progress_bar.update()

    def sort_files(self):
        """Sort the samples of the selected folder and report the outcome in a dialog."""
        input_folder = self.entry.get()
        if not os.path.isdir(input_folder):
            # An empty entry would otherwise make the sorter scan the current
            # working directory.
            messagebox.showerror("Error", "Please select an existing folder to sort.")
            return

        output_folder = os.path.join(input_folder, "Sorted")
        model_path = "0Shot1Shot2ShotV0.1.pth"
        self.progress.set(0)
        try:
            # Loading the model is inside the try block so that a missing or
            # corrupt checkpoint is reported in a dialog as well.
            model = load_model(model_path)
            sort_files(model, input_folder, output_folder, progress_callback=self.update_progress)
        except Exception as e:
            messagebox.showerror("Error", str(e))
        else:
            messagebox.showinfo("Success", "Files sorted successfully!")
# Only launch the GUI when the file is executed as a script, so that importing
# this module (for example to reuse load_model or sort_files) does not open a
# window and block in the Tk main loop.
if __name__ == "__main__":
    root = tk.Tk()
    app = Application(master=root)
    app.master.title("Sample Sorter")
    app.mainloop()