# audio_recorder / app.py
# mohammedriza-rahman's picture
# Update app.py
# 05a8a19 verified
import streamlit as st
from streamlit_webrtc import webrtc_streamer, WebRtcMode
from transformers import Wav2Vec2Processor, Wav2Vec2Model
import torch
import numpy as np
import wave
import io
import asyncio
# Load Wav2Vec2 processor and model
@st.cache_resource
def _load_wav2vec2():
    """Load the Wav2Vec2 processor and model once and reuse them.

    Streamlit re-executes this script on every rerun; caching the loaded
    objects avoids re-instantiating the checkpoint each time.

    Returns:
        A ``(processor, model)`` tuple, with the model switched to eval mode.
    """
    loaded_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
    loaded_model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
    loaded_model.eval()
    return loaded_processor, loaded_model


processor, model = _load_wav2vec2()
# Function to generate embeddings
def generate_embedding(audio, samplerate):
    """Embed an audio clip with the module-level Wav2Vec2 processor and model.

    Args:
        audio: Raw audio samples, passed straight to ``processor``.
        samplerate: Sampling rate of ``audio``, forwarded to ``processor``.

    Returns:
        The model's ``last_hidden_state`` averaged over dimension 1
        (the time axis), computed with gradients disabled.
    """
    encoded = processor(
        audio,
        sampling_rate=samplerate,
        return_tensors="pt",
        padding=True,
    )
    with torch.no_grad():
        hidden_states = model(encoded.input_values).last_hidden_state
    # Pool the per-step features into one vector per input.
    return hidden_states.mean(dim=1)
# Streamlit interface
st.title("Live Audio Recording and Embedding with Wav2Vec 2.0")
st.write("Record audio in your browser and generate embeddings.")

# Sample rate used for the saved WAV file and passed to generate_embedding().
TARGET_SAMPLE_RATE = 16000


def _frame_to_mono(frame):
    """Convert one received audio frame to a 1-D float32 array in [-1, 1].

    The frame's own sample format decides the dtype returned by
    ``to_ndarray()``; reinterpreting those bytes as float32 (as the previous
    code did) corrupts integer PCM, so the samples are converted by value.
    All channels are averaged into a single mono channel.
    """
    samples = frame.to_ndarray()
    if samples.size == 0:
        return np.zeros(0, dtype=np.float32)

    if frame.format.is_planar:
        # Planar formats: one row per channel -> (samples, channels).
        per_channel = samples.T
    else:
        # Packed formats: a single interleaved row -> (samples, channels).
        per_channel = samples.reshape(frame.samples, -1)

    if np.issubdtype(per_channel.dtype, np.integer):
        # Scale integer PCM to [-1, 1); unsigned formats are centred first.
        info = np.iinfo(per_channel.dtype)
        half_range = (int(info.max) - int(info.min) + 1) / 2.0
        midpoint = int(info.min) + half_range
        per_channel = (per_channel.astype(np.float64) - midpoint) / half_range

    return per_channel.mean(axis=1).astype(np.float32)


def _resample(audio, source_rate, target_rate):
    """Resample a mono float array from ``source_rate`` to ``target_rate``.

    Uses linear interpolation with no anti-aliasing filter; swap in a
    polyphase resampler if higher fidelity is required.
    """
    if source_rate == target_rate or audio.size == 0:
        return audio
    target_len = int(round(audio.size * target_rate / source_rate))
    if target_len == 0:
        return np.zeros(0, dtype=np.float32)
    source_times = np.arange(audio.size) / source_rate
    target_times = np.arange(target_len) / target_rate
    return np.interp(target_times, source_times, audio).astype(np.float32)


def _encode_wav(audio, samplerate):
    """Return ``audio`` (mono floats in [-1, 1]) as 16-bit PCM WAV bytes."""
    # The header below declares 2-byte samples, so write little-endian int16.
    pcm = (np.clip(audio, -1.0, 1.0) * 32767.0).astype("<i2")
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(samplerate)
        wf.writeframes(pcm.tobytes())
    return buffer.getvalue()


def _process_frames(audio_frames):
    """Decode received frames, then show playback, embedding and downloads."""
    chunks = [_frame_to_mono(frame) for frame in audio_frames]
    chunks = [chunk for chunk in chunks if chunk.size]
    if not chunks:
        st.warning("No audio samples were received.")
        return

    # Use the rate reported by the frames instead of assuming 16 kHz.
    source_rate = audio_frames[0].sample_rate
    samplerate = TARGET_SAMPLE_RATE
    audio_array = _resample(np.concatenate(chunks), source_rate, samplerate)

    # Display audio playback
    wav_bytes = _encode_wav(audio_array, samplerate)
    st.audio(wav_bytes, format="audio/wav")

    # Generate embedding
    embedding = generate_embedding(audio_array, samplerate)
    st.success("Audio embedding generated!")
    st.write("Embedding Shape:", embedding.shape)

    # Serialize the embedding in .npy format for download
    embedding_file = io.BytesIO()
    np.save(embedding_file, embedding.numpy())

    # Provide download links
    st.download_button("Download Recorded Audio", wav_bytes, file_name="recorded_audio.wav")
    st.download_button("Download Embedding", embedding_file.getvalue(), file_name="embedding.npy")


# WebRTC audio recording
webrtc_ctx = webrtc_streamer(
    key="audio",
    mode=WebRtcMode.SENDONLY,
    media_stream_constraints={"audio": True, "video": False},
    async_processing=False,
)

if webrtc_ctx.audio_receiver:
    try:
        # NOTE(review): only the frames returned by this single get_frames()
        # call are processed; capturing a longer recording would require
        # accumulating frames across calls - confirm the intended behaviour.
        _process_frames(webrtc_ctx.audio_receiver.get_frames())
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user.
        st.error(f"An error occurred: {e}")