car_crash / app.py
aizanlabs's picture
Update app.py
b14c819 verified
raw
history blame
No virus
2.43 kB
import torch
import librosa
import pickle
import numpy as np
import gradio as gr
class ML_model:
    """CPU audio classifier: audio file in, category label out.

    Loads a pickled PyTorch model and an index-to-category mapping from
    ``support_file/``. A prediction turns the audio into a dB-scaled mel
    spectrogram, rescales it to an 8-bit image and feeds it to the model.
    """

    # Every clip is resampled to SAMPLE_RATE and padded/trimmed to CLIP_SECONDS.
    SAMPLE_RATE = 44100
    CLIP_SECONDS = 5

    # Mel spectrogram parameters.
    N_FFT = 2048
    HOP_LENGTH = 512
    N_MELS = 128
    FMIN = 20
    FMAX = 8300

    def __init__(self):
        """Load the model (in eval mode) and the index-to-category mapping."""
        # NOTE(review): this unpickles a complete model object. Recent PyTorch
        # releases changed the default of torch.load's `weights_only` flag,
        # which may reject such files -- confirm against the pinned version.
        self.ml_model = torch.load(
            "support_file/resnet_carcrash_94.pth",
            map_location=torch.device('cpu'),
        )
        self.ml_model.eval()

        # SECURITY: pickle executes arbitrary code on load. Only ever point
        # this at the support file shipped with the app, never at user input.
        with open('support_file/indtocat.pkl', 'rb') as f:
            self.i2c = pickle.load(f)

    def spec_to_image(self, spec, eps=1e-6):
        """Rescale a spectrogram to a ``uint8`` array spanning 0..255.

        The array is standardised (zero mean, unit variance; ``eps`` keeps the
        division safe when the standard deviation is 0) and then min-max
        scaled to the 0..255 range before being truncated to ``np.uint8``.

        Args:
            spec: NumPy array of spectrogram values.
            eps: Small constant added to the standard deviation.

        Returns:
            ``np.uint8`` array with the same shape as ``spec``. A constant
            input (no dynamic range) yields an all-zero array.
        """
        mean = spec.mean()
        std = spec.std()
        spec_norm = (spec - mean) / (std + eps)
        spec_min, spec_max = spec_norm.min(), spec_norm.max()

        span = spec_max - spec_min
        if span == 0:
            # Constant spectrogram: min-max scaling would divide by zero and
            # produce NaNs whose uint8 cast is undefined.
            return np.zeros(spec_norm.shape, dtype=np.uint8)

        spec_scaled = 255 * (spec_norm - spec_min) / span
        return spec_scaled.astype(np.uint8)

    def get_melspectrogram_db(self, file_path):
        """Load an audio file and return its mel spectrogram in decibels.

        The audio is resampled to ``SAMPLE_RATE`` on load so that the clip
        length and the mel filterbank are computed at the rate the samples
        actually have. Clips shorter than ``CLIP_SECONDS`` are reflect-padded
        on both sides; longer clips are truncated.

        Args:
            file_path: Path of the audio file to analyse.

        Returns:
            Mel spectrogram converted to dB relative to its own maximum.

        Raises:
            ValueError: If the file decodes to zero samples.
        """
        sr = self.SAMPLE_RATE
        # Resample while loading; loading at the native rate and then assuming
        # 44.1 kHz would mis-scale both duration and frequency axes.
        wav, _ = librosa.load(file_path, sr=sr)
        if wav.shape[0] == 0:
            raise ValueError("audio file contains no samples")

        target_len = self.CLIP_SECONDS * sr
        if wav.shape[0] < target_len:
            # Symmetric padding: with an odd deficit the result is one sample
            # longer than target_len (kept as in the original preprocessing).
            pad = int(np.ceil((target_len - wav.shape[0]) / 2))
            wav = np.pad(wav, pad, mode='reflect')
        else:
            wav = wav[:target_len]

        spec = librosa.feature.melspectrogram(
            y=wav,
            sr=sr,
            n_fft=self.N_FFT,
            hop_length=self.HOP_LENGTH,
            n_mels=self.N_MELS,
            fmin=self.FMIN,
            fmax=self.FMAX,
        )
        return librosa.power_to_db(spec, ref=np.max)

    def get_prediction(self, file_path):
        """Classify one audio file and return its category.

        Args:
            file_path: Path of the audio file to classify.

        Returns:
            The ``i2c`` entry for the highest-scoring model output.
        """
        spec_db = self.get_melspectrogram_db(file_path)
        input_image = self.spec_to_image(spec_db)
        # Two leading axes are added in front of the image (presumably batch
        # and channel -- confirm against the model's expected input).
        input_tensor = torch.tensor(
            input_image[np.newaxis, np.newaxis, ...],
            dtype=torch.float32,
            device='cpu',
        )
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            predictions = self.ml_model(input_tensor)
        predicted_index = predictions.argmax(dim=1).item()
        return self.i2c[predicted_index]
# Lazily created, shared ML_model instance (see _get_model).
_MODEL = None


def _get_model():
    """Return the shared ML_model, loading it from disk on first use."""
    global _MODEL
    if _MODEL is None:
        # Concurrent first requests may each load once; the last one wins,
        # which only costs a redundant load.
        _MODEL = ML_model()
    return _MODEL


def predict(file_path):
    """Gradio callback: classify the audio file at ``file_path``.

    The model is loaded once and reused across requests instead of being
    re-read from disk on every call.

    Args:
        file_path: Path of the uploaded audio file.

    Returns:
        The predicted category as returned by ``ML_model.get_prediction``.
    """
    return _get_model().get_prediction(file_path)
# Gradio front end: one audio upload in, the predicted label out as text.
# The upload is handed to `predict` as a path on disk (type="filepath").
_audio_input = gr.Audio(type="filepath", label="Upload your audio file")

# Example clips offered in the UI.
# NOTE(review): the "input_files..." names may be missing a path separator
# -- confirm against the files actually shipped with the app.
_example_clips = [
    "input_fileszQ1QmqrakIA_5-talking.wav",
    "car_crash.wav",
    "input_fileszQ1QmqrakIA_13-crash.wav",
]

interface = gr.Interface(
    fn=predict,
    inputs=_audio_input,
    outputs="text",
    title="Car Crash Sound Detection",
    description="Upload a car crash sound clip and the model will identify the crash type.",
    examples=_example_clips,
    cache_examples=False,
)

# Start the app; share=True additionally asks Gradio for a public share link.
interface.launch(share=True)