Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_webrtc import webrtc_streamer, WebRtcMode | |
from transformers import Wav2Vec2Processor, Wav2Vec2Model | |
import torch | |
import numpy as np | |
import wave | |
import io | |
import asyncio | |
# Load Wav2Vec2 processor and model
@st.cache_resource
def _load_wav2vec2(model_name="facebook/wav2vec2-base"):
    """Load the Wav2Vec2 processor and model, cached across script reruns.

    Streamlit re-executes this script from the top on every interaction.
    Without caching, the pretrained weights would be loaded again each
    time; ``st.cache_resource`` keeps one shared copy instead.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            processor and the model.

    Returns:
        A ``(processor, model)`` tuple with the model switched to eval mode.
    """
    loaded_processor = Wav2Vec2Processor.from_pretrained(model_name)
    loaded_model = Wav2Vec2Model.from_pretrained(model_name)
    # Inference only: put the model in evaluation mode.
    loaded_model.eval()
    return loaded_processor, loaded_model


# Module-level names kept so the rest of the script can keep using them.
processor, model = _load_wav2vec2()
# Function to generate embeddings
def generate_embedding(audio, samplerate):
    """Return one embedding per input by averaging Wav2Vec2 hidden states.

    Args:
        audio: Raw audio samples handed straight to the module-level
            ``processor``.
        samplerate: Sampling rate of ``audio``, forwarded to the processor
            as ``sampling_rate``.

    Returns:
        The model's ``last_hidden_state`` averaged over dimension 1
        (the time axis), as a tensor.
    """
    encoded = processor(
        audio,
        sampling_rate=samplerate,
        return_tensors="pt",
        padding=True,
    )
    features = encoded.input_values

    # Gradients are not needed for a forward-only pass.
    with torch.no_grad():
        hidden_states = model(features).last_hidden_state

    # Mean across time
    pooled = hidden_states.mean(dim=1)
    return pooled
# Streamlit interface: page heading and a one-line usage hint.
st.title("Live Audio Recording and Embedding with Wav2Vec 2.0")
st.write("Record audio in your browser and generate embeddings.")

# WebRTC audio recording: request the microphone only, no camera.
audio_only_constraints = {"audio": True, "video": False}
webrtc_ctx = webrtc_streamer(
    key="audio",
    mode=WebRtcMode.SENDONLY,
    media_stream_constraints=audio_only_constraints,
    async_processing=False,
)
# Sample rate the recording is converted to before it is stored and embedded.
TARGET_SAMPLE_RATE = 16000  # Default Wav2Vec2 sample rate


def _frame_to_mono_float32(frame):
    """Decode one received audio frame into mono float32 samples in [-1, 1].

    The frame's own sample format and channel layout are honoured instead
    of reinterpreting its raw bytes, which only works by accident when the
    frame already happens to be mono float32.
    """
    samples = frame.to_ndarray()
    channel_count = max(len(frame.layout.channels), 1)
    if frame.format.is_planar:
        # Planar data holds one row per channel; make channels the last axis.
        per_channel = samples.reshape(channel_count, -1).T
    else:
        # Packed data interleaves the channels within a single row.
        per_channel = samples.reshape(-1, channel_count)

    kind = per_channel.dtype.kind
    if kind == "u":
        # Unsigned PCM is offset-binary: re-centre on zero before scaling.
        midpoint = float(np.iinfo(per_channel.dtype).max // 2 + 1)
        scaled = (per_channel.astype(np.float32) - midpoint) / midpoint
    elif kind == "i":
        full_scale = float(np.iinfo(per_channel.dtype).max) + 1.0
        scaled = per_channel.astype(np.float32) / full_scale
    else:
        scaled = per_channel.astype(np.float32)

    # Downmix to mono by averaging the channels.
    return scaled.mean(axis=1)


def _resample_linear(samples, source_rate, target_rate):
    """Resample a 1-D signal by linear interpolation.

    No anti-aliasing filter is applied; swap in a dedicated resampler if
    higher fidelity is required.
    """
    if source_rate == target_rate or samples.size == 0:
        return samples
    target_length = max(int(round(samples.size * target_rate / source_rate)), 1)
    source_times = np.arange(samples.size) / source_rate
    target_times = np.arange(target_length) / target_rate
    return np.interp(target_times, source_times, samples).astype(np.float32)


def _encode_wav_pcm16(samples, samplerate):
    """Encode mono float samples in [-1, 1] as 16-bit PCM WAV bytes."""
    # The WAV header below declares 2-byte samples, so the payload must
    # really be little-endian int16.
    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767.0).astype("<i2")
    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)
        wf.setframerate(samplerate)
        wf.writeframes(pcm16.tobytes())
    return buffer.getvalue()


if webrtc_ctx.audio_receiver:
    try:
        # NOTE(review): this appears to return only the frames buffered at
        # call time, so each script run handles a short snippet; confirm
        # against the streamlit-webrtc documentation.
        audio_frames = webrtc_ctx.audio_receiver.get_frames()
        if not audio_frames:
            st.warning("No audio frames received yet.")
        else:
            # Frames carry their own sample rate (WebRTC audio is commonly
            # 48 kHz), so read it rather than assuming the target rate.
            source_rate = audio_frames[0].sample_rate
            received = np.concatenate(
                [_frame_to_mono_float32(frame) for frame in audio_frames]
            )

            samplerate = TARGET_SAMPLE_RATE
            audio_array = _resample_linear(received, source_rate, samplerate)

            # Save the recorded audio as an in-memory WAV file
            wav_bytes = _encode_wav_pcm16(audio_array, samplerate)

            # Display audio playback
            st.audio(wav_bytes, format="audio/wav")

            # Generate embedding
            embedding = generate_embedding(audio_array, samplerate)
            st.success("Audio embedding generated!")
            st.write("Embedding Shape:", embedding.shape)

            # Save embedding to an in-memory .npy file
            embedding_file = io.BytesIO()
            np.save(embedding_file, embedding.numpy())
            embedding_bytes = embedding_file.getvalue()

            # Provide download links
            st.download_button(
                "Download Recorded Audio",
                wav_bytes,
                file_name="recorded_audio.wav",
                mime="audio/wav",
            )
            st.download_button(
                "Download Embedding",
                embedding_bytes,
                file_name="embedding.npy",
                mime="application/octet-stream",
            )
    except Exception as e:
        # Top-level UI boundary: surface the failure on the page instead of
        # crashing the script run.
        st.error(f"An error occurred: {e}")