Spaces:
Sleeping
Sleeping
import time | |
import datetime | |
import logging | |
import soundfile | |
import streamlit as st | |
from streamlit_webrtc import webrtc_streamer, AudioProcessorBase, WebRtcMode | |
import numpy as np | |
import pydub | |
from pathlib import Path | |
from asr import load_model, inference | |
# Directory that receives one log file per app start (created on demand).
LOG_DIR = "./logs"
# Root directory under which uploaded audio files are stored.
DATA_DIR = "./data"

# Module-level logger; handlers are attached in the ``__main__`` section.
logger = logging.getLogger(__name__)
# Define a custom audio processor to handle microphone input
class AudioProcessor(AudioProcessorBase):
    """Buffer incoming microphone frames as 16-bit PCM sample chunks.

    streamlit-webrtc hands frames to a processor through ``recv`` (one frame)
    or ``recv_queued`` (every frame queued since the previous call); both are
    implemented here and delegate to ``recv_audio``, which is kept as the
    method that actually stores the samples.

    Attributes:
        audio_data: list of 1-D ``int16`` arrays, one per received frame.
        sample_rate: sample rate reported by the most recent frame, or
            ``None`` if no frame has reported one yet.
        channels: channel count reported by the most recent frame, or
            ``None`` if no frame has reported one yet.
    """

    def __init__(self):
        self.audio_data = []
        self.sample_rate = None
        self.channels = None

    def recv(self, frame):
        """Handle a single frame delivered by streamlit-webrtc."""
        return self.recv_audio(frame)

    async def recv_queued(self, frames):
        """Handle every queued frame so that none is dropped from the capture."""
        return [self.recv_audio(frame) for frame in frames]

    def recv_audio(self, frame):
        """Store the samples of ``frame`` and return the frame unchanged.

        The frame's sample buffer is reinterpreted as ``int16`` and flattened.
        NOTE(review): this assumes the frame carries packed signed 16-bit
        samples; other sample formats would be misread - confirm upstream.
        """
        # ``np.frombuffer`` needs a contiguous buffer; make that explicit.
        samples = np.ascontiguousarray(frame.to_ndarray())
        audio_array = np.frombuffer(samples, dtype=np.int16)
        self.audio_data.append(audio_array)

        # Remember the stream format when the frame exposes it, so consumers
        # do not have to assume a sample rate or channel count.
        self.sample_rate = getattr(frame, "sample_rate", None) or self.sample_rate
        layout = getattr(frame, "layout", None)
        self.channels = getattr(layout, "nb_channels", None) or self.channels
        return frame

    def get_audio_data(self):
        """Return all captured samples as one array, or ``None`` if empty."""
        if not self.audio_data:
            return None
        return np.concatenate(self.audio_data, axis=0)
def upload_audio() -> "Path | None":
    """Render the file uploader, save the uploaded audio and preview it.

    The uploaded file is decoded with ``soundfile`` and written again into a
    fresh timestamped directory below ``DATA_DIR``; the saved copy is then
    shown with ``st.audio``.

    Returns:
        Path of the saved audio file, or ``None`` when no file has been
        uploaded.
    """
    uploaded_file = st.file_uploader("Choose a audio file(wav, mp3, flac)", type=['wav','mp3','flac'])
    if uploaded_file is None:
        return None

    audio_data, samplerate = soundfile.read(uploaded_file)

    # Hyphens only: ':' is not allowed in Windows path components.
    now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    audio_dir = Path(DATA_DIR) / now_time
    audio_dir.mkdir(parents=True, exist_ok=True)

    # The file name comes from the client; keep only its last path component
    # so the write cannot land outside ``audio_dir``.
    safe_name = Path(uploaded_file.name).name
    if safe_name in ("", ".", ".."):
        safe_name = "upload.wav"
    audio_path = audio_dir / safe_name
    soundfile.write(audio_path, audio_data, samplerate)

    # Preview the saved copy in the page.
    st.audio(audio_path.read_bytes(), format=uploaded_file.type)
    return audio_path
# Streamlit re-executes the whole script on every widget interaction; caching
# the loaded model as a shared resource avoids reloading it on each rerun.
# The caller already shows its own spinner, so the built-in one is disabled.
@st.cache_resource(show_spinner=False)
def call_load_model():
    """Load the ASR model via ``load_model`` and return it (cached)."""
    return load_model()
# Session-state key under which the last captured microphone audio is kept, so
# it survives the script rerun triggered by the next button click.
_CAPTURED_AUDIO_KEY = "captured_audio_segment"


def _report_upload_transcription(generator, audio_path):
    """Transcribe the uploaded file at ``audio_path`` and show text and timing."""
    start_time = time.time()
    with st.spinner(text='Wait for inference...'):
        output = inference(generator, audio_path)
    elapsed = time.gmtime(time.time() - start_time)
    process_time = time.strftime("%H hour %M min %S secs", elapsed)
    st.success(f"Inference finished in {process_time}.")
    st.write(f"output: {output['text']}")


def _microphone_section(generator):
    """Render the microphone capture UI and transcribe captured audio."""
    st.title("Microphone Input for ASR")

    # Let streamlit-webrtc build the processor and read it back from the
    # returned context: an instance created here would be a brand-new, empty
    # object on every script rerun and would never see any frames.
    ctx = webrtc_streamer(
        key="audio",
        mode=WebRtcMode.SENDONLY,
        audio_processor_factory=AudioProcessor,
        media_stream_constraints={"audio": True, "video": False},
    )

    if st.button("Process Audio"):
        processor = ctx.audio_processor
        audio_data = processor.get_audio_data() if processor is not None else None
        if audio_data is not None:
            # Wrap the raw 16-bit PCM samples in an audio segment. Use the
            # format reported by the processor when it exposes one; otherwise
            # fall back to 16 kHz mono.
            audio_segment = pydub.AudioSegment(
                audio_data.tobytes(),
                frame_rate=getattr(processor, "sample_rate", None) or 16000,
                sample_width=2,  # 16-bit audio
                channels=getattr(processor, "channels", None) or 1,
            )
            st.session_state[_CAPTURED_AUDIO_KEY] = audio_segment
            st.success("Audio captured successfully!")
        else:
            st.warning("No audio data captured!")

    if st.button("Transcribe Audio"):
        # A button is only True during the rerun caused by its own click, so
        # the capture has to be read back from session state here.
        audio_segment = st.session_state.get(_CAPTURED_AUDIO_KEY)
        if audio_segment is not None:
            # Persist the capture as a WAV file and pass its path, which is
            # the input form ``inference`` is given for uploaded files.
            now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
            audio_dir = Path(DATA_DIR) / now_time
            audio_dir.mkdir(parents=True, exist_ok=True)
            wav_path = audio_dir / "microphone.wav"
            audio_segment.export(str(wav_path), format="wav")

            transcription = inference(generator, wav_path)
            st.text_area("Transcription", transcription["text"])
        else:
            st.warning("No audio data to transcribe!")


def main():
    """Build the page: file-upload transcription plus microphone capture."""
    st.header("Speech-to-Text app with streamlit")
    st.markdown(
        """
This STT app is using a fine-tuned MMS ASR model.
"""
    )

    audio_path = upload_audio()
    if audio_path is not None:
        logger.info("Uploaded audio file: %s", audio_path)

    with st.spinner(text="Wait for loading ASR Model..."):
        generator = call_load_model()

    if audio_path is not None:
        _report_upload_transcription(generator, audio_path)

    _microphone_section(generator)
if __name__ == "__main__":
    # Setting logger
    logger.setLevel(logging.INFO)

    # Streamlit re-executes this script on every interaction, while the
    # logger object lives on in the logging module. Attach handlers only once
    # so messages are not duplicated and a new log file is not opened on
    # every rerun.
    if not logger.handlers:
        formatter = logging.Formatter("%(levelname)8s %(asctime)s %(name)s %(message)s")

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

        log_dir = Path(LOG_DIR)
        log_dir.mkdir(parents=True, exist_ok=True)
        # Hyphens only: ':' is not allowed in Windows file names.
        now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
        log_file = log_dir / f"{now_time}.log"
        file_handler = logging.FileHandler(str(log_file), encoding='utf-8')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    logger.info('Start App')
    main()