Spaces:
Running
Running
import os
import tempfile

import numpy as np
import streamlit as st
import torch
import torchaudio
from tensorflow.keras.models import load_model
from torchaudio.transforms import Resample
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2Model
# Load models
#
# Streamlit re-executes this script from the top on every user interaction.
# Loading the wav2vec encoder and the Keras classifier directly at module
# level would therefore repeat the expensive load on each rerun, so the
# loaders are wrapped in st.cache_resource and the loaded objects are reused.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    # Older Streamlit releases have no st.cache_resource; fall back to a
    # no-op decorator factory so the script still runs (without caching).
    def _cache_resource(**_kwargs):
        return lambda func: func

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
saved_model_path = "model.h5"


@_cache_resource(show_spinner=False)
def _load_wav2vec():
    """Load the "facebook/mms-1b" feature extractor and encoder onto `device`."""
    extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/mms-1b")
    encoder = Wav2Vec2Model.from_pretrained("facebook/mms-1b").to(device)
    return extractor, encoder


@_cache_resource(show_spinner=False)
def _load_cnn(path):
    """Load the Keras classifier stored at `path`."""
    return load_model(path)


feature_extractor, wav2vec_model = _load_wav2vec()

try:
    cnn_model = _load_cnn(saved_model_path)
except Exception as e:
    # Top-level boundary: report the failure in the UI and halt this run,
    # since nothing below can work without the classifier.
    st.error(f"Error loading TensorFlow model: {e}")
    st.stop()
# Preprocessing Function
def preprocess_audio(audio_path):
    """Load an audio file as a 16 kHz, single-channel waveform.

    The file is read with torchaudio, resampled when its native rate is not
    16000 Hz, and multi-channel audio is averaged into one channel (the
    channel axis is kept).

    Returns:
        ``(waveform, 16000)`` on success, or ``(None, None)`` after showing
        the error with ``st.error`` if any step raises.
    """
    target_rate = 16000
    try:
        signal, native_rate = torchaudio.load(audio_path)

        # Bring the audio to the target rate only when it differs.
        if native_rate != target_rate:
            signal = Resample(native_rate, target_rate)(signal)

        # More than one channel: average them, keeping the channel dimension.
        channel_count = signal.shape[0]
        if channel_count > 1:
            signal = signal.mean(dim=0, keepdim=True)
    except Exception as e:
        st.error(f"Error processing audio file: {e}")
        return None, None
    else:
        return signal, target_rate
# Feature Extraction
def extract_features(audio_path, feature_extractor, wav2vec_model, device):
    """Compute a time-averaged wav2vec embedding for one audio file.

    Args:
        audio_path: Path handed to ``preprocess_audio``.
        feature_extractor: Callable that turns raw samples into model inputs
            (invoked with ``sampling_rate`` and ``return_tensors="pt"``).
        wav2vec_model: Model whose output exposes ``last_hidden_state``.
        device: Torch device the inputs are moved to before inference.

    Returns:
        A NumPy array holding the mean of the hidden states over the frame
        axis for the single input item, or ``None`` when preprocessing failed.
    """
    waveform, fs = preprocess_audio(audio_path)
    if waveform is None:
        return None

    # Drop only the leading channel axis. A bare squeeze() would also remove
    # any other length-1 axis.
    # assumes preprocess_audio yields a (1, num_samples) tensor — TODO confirm
    samples = waveform.squeeze(0).numpy()
    inputs = feature_extractor(samples, sampling_rate=fs, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = wav2vec_model(**inputs)

    embeddings = outputs.last_hidden_state.cpu().numpy()

    # Select the single batch item explicitly instead of squeeze(): when the
    # encoder produces exactly one frame, squeeze() would also drop the frame
    # axis and the mean below would collapse the embedding to a scalar.
    return embeddings[0].mean(axis=0)
# Prediction
def predict_with_cnn(audio_path, cnn_model, feature_extractor, wav2vec_model, device):
    """Classify one audio file as "bonafide" or "spoof".

    The embedding from ``extract_features`` gets a leading and a trailing
    length-1 axis before being passed to ``cnn_model.predict``.

    Returns:
        ``(label, scores, confidence)`` where ``label`` is the name of the
        highest-scoring class, ``scores`` is the first row of the model
        output, and ``confidence`` is the score of the chosen class.
        ``(None, None, None)`` when feature extraction failed.
    """
    embedding = extract_features(audio_path, feature_extractor, wav2vec_model, device)
    if embedding is None:
        return None, None, None

    # Add axis 0 first, then axis 2, so the classifier receives a 3-D batch.
    model_input = np.expand_dims(np.expand_dims(embedding, axis=0), axis=2)
    predictions = cnn_model.predict(model_input)

    class_names = ["bonafide", "spoof"]
    best_index = np.argmax(predictions, axis=1)[0]
    scores = predictions[0]
    confidence = scores[best_index]  # score of the predicted class
    return class_names[best_index], scores, confidence
# Streamlit Application
#
# Page layout: title and intro, a WAV uploader, then the prediction results
# (label, confidence, per-class scores) for the uploaded file.
# NOTE(review): the "π΅" / "π§" / "π" prefixes in the UI strings below look like
# mis-encoded emoji; they are kept verbatim — confirm the intended characters.
st.set_page_config(page_title="π΅ Audio Spoof Detection", layout="wide")
st.title("π΅ Audio Spoof Detection")
st.markdown(
    "\n"
    "This application uses advanced machine learning models to detect whether an audio file is **bonafide** (real) or **spoofed** (fake).\n"
    "Upload a `.wav` file to get started!\n"
)

# File Upload
uploaded_file = st.file_uploader(
    "Upload your audio file (WAV format only):", type=["wav"]
)

if uploaded_file:
    audio_bytes = uploaded_file.getvalue()

    # Display Audio Player (served straight from memory, so it does not
    # depend on the temporary file below).
    st.audio(audio_bytes, format="audio/wav")

    # Processing Audio
    st.write("π§ **Processing the audio...**")

    # The prediction pipeline takes a file path. A unique temporary file is
    # used instead of a fixed name in the working directory so concurrent
    # sessions cannot overwrite each other's upload, and it is always removed
    # afterwards. delete=False lets the file be reopened by path after it is
    # closed.
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        with tmp:
            tmp.write(audio_bytes)
        predicted_class, probabilities, confidence = predict_with_cnn(
            tmp.name, cnn_model, feature_extractor, wav2vec_model, device
        )
    finally:
        os.remove(tmp.name)

    # Display Results
    if predicted_class is not None:
        col1, col2 = st.columns(2)
        with col1:
            st.markdown(f"\n## π **Prediction: `{predicted_class.upper()}`**\n")
            st.markdown(f"### **Confidence**: `{confidence:.2f}`")
        with col2:
            st.write("### Class Probabilities")
            st.bar_chart(probabilities)

        # Display Detailed Probabilities
        st.markdown("### Class Details")
        st.write(f"**Bonafide Probability**: `{probabilities[0]:.2f}`")
        st.write(f"**Spoof Probability**: `{probabilities[1]:.2f}`")
    else:
        st.error("Failed to process the audio file. Please try again.")
else:
    st.info("Please upload a `.wav` audio file to analyze.")