import argparse
import base64
import io
import os
import tempfile
from io import BytesIO

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as T
from flask import jsonify
from scipy.io import wavfile

# Model classes must be importable so torch.load can unpickle the saved models
from models import MobileNetV2RawAudio, YAMNet, ElephantCallerNet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(model_path):
    """Load a serialized model and prepare it for inference on the current device."""
    model = torch.load(model_path, map_location=device)
    model.eval()
    return model.to(device)


def create_temp_wav_file(audio_file):
    """Persist an uploaded file to a temporary WAV file and validate it."""
    # Create a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
    temp_file_path = temp_file.name
    temp_file.close()

    # Save the uploaded file as a WAV file
    with open(temp_file_path, 'wb') as temp_wav_file:
        temp_wav_file.write(audio_file.read())

    # Validate the WAV file
    try:
        # Try to read the WAV file with scipy.io.wavfile
        rate, data = wavfile.read(temp_file_path)
        return temp_file_path
    except Exception as e:
        os.remove(temp_file_path)
        return jsonify({'error': str(e)}), 400


def load_model_based_on_condition(condition):
    """Load the appropriate pre-trained model for the requested architecture."""
    if condition == "mobilenet":
        model_path = "MobileNetV2RawAudio_optim.pt"  # Path to your model 1
    elif condition == "yamnet":
        model_path = "YAMNETRawAudio_100.pt"  # Path to your model 2
    elif condition == "elephantnet":
        model_path = "adcnet_ep100.pt"  # Path to your model 3
    else:
        raise ValueError("Invalid condition")
    return load_model(model_path)


def pad_audio(audio, target_duration=6, sr=44100):
    """Zero-pad a 1-D signal so it lasts at least `target_duration` seconds."""
    current_duration = len(audio) / sr
    if current_duration < target_duration:
        samples_to_pad = int((target_duration - current_duration) * sr)
        audio = np.pad(audio, (0, samples_to_pad), mode='constant')
    return audio


def preprocess_audio(waveform, sample_rate, expected_sample_rate=44100, expected_duration=6):
    """Resample, downmix to mono, and pad/trim the waveform to a fixed length."""
    try:
        if sample_rate != expected_sample_rate:
            resampler = T.Resample(orig_freq=sample_rate, new_freq=expected_sample_rate)
            waveform = resampler(waveform)

        # Downmix multi-channel audio to mono
        waveform = waveform.mean(dim=0, keepdim=True) if waveform.size(0) > 1 else waveform

        target_length = expected_sample_rate * expected_duration
        if waveform.size(1) < target_length:
            waveform = pad_audio(waveform.squeeze().numpy(), target_duration=expected_duration, sr=expected_sample_rate)
            waveform = torch.tensor(waveform, dtype=torch.float32).unsqueeze(0)
        elif waveform.size(1) > target_length:
            waveform = waveform[:, :target_length]
        return waveform
    except Exception as e:
        return jsonify({'error': str(e)}), 500


def inference_audio_file(model, audio_source, classes):
    """Run the model on a WAV file path or file-like object and return the prediction."""
    try:
        # torchaudio.load accepts a file path or a file-like object
        waveform, sample_rate = torchaudio.load(audio_source)

        # Preprocess the audio
        waveform = preprocess_audio(waveform, sample_rate)
        waveform = waveform.to(device)

        with torch.no_grad():
            output = model(waveform)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
            predicted_class_index = torch.argmax(probabilities).item()
            predicted_class = classes[predicted_class_index]
        return predicted_class, probabilities
    except Exception as e:
        return jsonify({'error': str(e)}), 500


def generate_spectrogram(temp_file_path):
    """Render a spectrogram of the WAV file and return it as a base64-encoded PNG."""
    try:
        # Read the WAV file using scipy.io.wavfile
        sample_rate, samples = wavfile.read(temp_file_path)
    except Exception as e:
        return jsonify({'error': str(e)}), 400
    # If the data is not in int16 format, convert it
    if samples.dtype == np.float32 or samples.dtype == np.float64:
        samples = np.int16(samples * 32767)

    # Generate the spectrogram
    plt.figure(figsize=(12, 6))
    plt.specgram(samples, Fs=sample_rate)
    plt.xlabel('Time')
    plt.ylabel('Frequency')
    plt.title('Spectrogram of the Audio Wave')
    plt.colorbar(format='%+2.0f dB')

    # Save the spectrogram to a BytesIO object
    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
    buf.seek(0)
    plt.close()

    # Convert the image to a base64 string
    image_base64 = base64.b64encode(buf.getvalue()).decode('utf-8')
    return image_base64


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Audio classification inference")
    parser.add_argument("condition", type=str, help="Condition (e.g., mobilenet, yamnet, elephantnet)")
    parser.add_argument("audio_file", nargs='?', default=None, type=str, help="Path to the audio file")
    args = parser.parse_args()

    # Class labels the models predict
    classes = ["Roar", "Rumble", "Trumpet"]

    # Load model based on condition
    model = load_model_based_on_condition(args.condition).to(device)

    # If args.audio_file is not provided, fall back to a default file
    if not args.audio_file:
        audio_file_path = "Trumpet1.wav"  # Default audio file path
    else:
        audio_file_path = args.audio_file

    # Open the audio file and create a file-like object
    with open(audio_file_path, 'rb') as f:
        audio_data = f.read()
    audio_buffer = io.BytesIO(audio_data)

    # Perform inference
    predicted_class, probabilities = inference_audio_file(model, audio_buffer, classes)
    print("Predicted class:", predicted_class)
    print("Class probabilities:", probabilities)
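
    # Illustrative extra step (not part of the original CLI flow): the
    # generate_spectrogram helper above can be exercised on the same input.
    # This sketch assumes audio_file_path points to a readable mono WAV on disk.
    spectrogram_b64 = generate_spectrogram(audio_file_path)
    if isinstance(spectrogram_b64, str):
        print("Spectrogram PNG (base64, first 60 chars):", spectrogram_b64[:60])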