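"""Streamlit speech-to-text demo.

Supports two input paths: uploading an audio file (wav/mp3/flac) and
recording from the microphone via streamlit-webrtc. Both are transcribed
with a fine-tuned MMS ASR model loaded through the local `asr` module.
"""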
import time
import datetime
import logging
import soundfile
import streamlit as st
from streamlit_webrtc import webrtc_streamer, AudioProcessorBase, WebRtcMode
import numpy as np
import pydub
from pathlib import Path
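# `asr` is expected to be a local module in this repo exposing load_model()
# and inference().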
from asr import load_model, inference
LOG_DIR = "./logs"
DATA_DIR = "./data"
logger = logging.getLogger(__name__)
# Define a custom audio processor to handle microphone input
class AudioProcessor(AudioProcessorBase):
    def __init__(self):
        self.audio_data = []
        self.sample_rate = None

    def recv(self, frame):
        # Called by streamlit_webrtc for each incoming audio frame;
        # remember the sample rate and flatten the frame into a 1-D
        # array of 16-bit samples
        self.sample_rate = frame.sample_rate
        audio_array = frame.to_ndarray().flatten().astype(np.int16)
        self.audio_data.append(audio_array)
        return frame

    def get_audio_data(self):
        # Combine all captured audio frames into a single array
        if self.audio_data:
            return np.concatenate(self.audio_data, axis=0)
        return None
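
# streamlit_webrtc invokes recv() on a background worker for each incoming
# frame, independently of Streamlit's script reruns, so the processor keeps
# accumulating audio for as long as the WebRTC stream is running.
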
def upload_audio() -> Path | None:
    # Upload an audio file
    uploaded_file = st.file_uploader("Choose an audio file (wav, mp3, flac)", type=['wav', 'mp3', 'flac'])
    if uploaded_file is not None:
        # Decode the upload and save a copy under a timestamped directory
        audio_data, samplerate = soundfile.read(uploaded_file)
        now = datetime.datetime.now()
        now_time = now.strftime('%Y-%m-%d-%H:%M:%S')
        audio_dir = Path(DATA_DIR) / f"{now_time}"
        audio_dir.mkdir(parents=True, exist_ok=True)
        audio_path = audio_dir / uploaded_file.name
        # soundfile picks the output format from the file extension;
        # note that writing mp3 requires libsndfile >= 1.1
        soundfile.write(audio_path, audio_data, samplerate)
        # Play the saved file back in the app
        with open(audio_path, 'rb') as audio_file:
            audio_bytes = audio_file.read()
        st.audio(audio_bytes, format=uploaded_file.type)
        return audio_path
    return None
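
# st.cache_resource caches the returned model across Streamlit reruns and
# sessions, so load_model() only runs once per server process.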
@st.cache_resource(show_spinner=False)
def call_load_model():
    generator = load_model()
    return generator
def main():
    st.header("Speech-to-Text app with Streamlit")
    st.markdown("This STT app uses a fine-tuned MMS ASR model.")

    audio_path = upload_audio()
    logger.info(f"Uploaded audio file: {audio_path}")

    with st.spinner(text="Loading the ASR model..."):
        generator = call_load_model()

    if audio_path is not None:
        start_time = time.time()
        with st.spinner(text="Running inference..."):
            output = inference(generator, audio_path)
        end_time = time.time()
        process_time = time.gmtime(end_time - start_time)
        process_time = time.strftime("%H hour %M min %S secs", process_time)
        st.success(f"Inference finished in {process_time}.")
        st.write(f"output: {output['text']}")

    st.title("Microphone Input for ASR")
    # Start the WebRTC streamer; streamlit_webrtc constructs the processor
    # itself and exposes the live instance as ctx.audio_processor
    ctx = webrtc_streamer(
        key="audio",
        mode=WebRtcMode.SENDONLY,
        audio_processor_factory=AudioProcessor,
        media_stream_constraints={"audio": True, "video": False},
    )

    if st.button("Process Audio"):
        processor = ctx.audio_processor
        audio_data = processor.get_audio_data() if processor else None
        if audio_data is not None:
            # Wrap the raw 16-bit PCM samples in an audio segment
            audio_segment = pydub.AudioSegment(
                audio_data.tobytes(),
                frame_rate=processor.sample_rate or 16000,  # rate reported by the browser
                sample_width=2,  # 16-bit audio
                channels=1,      # mono
            )
            # Keep the segment in session state so it survives the rerun
            # triggered by the "Transcribe Audio" button below
            st.session_state["audio_segment"] = audio_segment
            st.success("Audio captured successfully!")
            # st.audio(audio_segment.export(format="wav").read(), format="audio/wav")
        else:
            st.warning("No audio data captured!")

    if st.button("Transcribe Audio"):
        audio_segment = st.session_state.get("audio_segment")
        if audio_segment is not None:
            # inference() expects a file path (as in the upload flow),
            # so export the captured audio to a WAV file first
            mic_dir = Path(DATA_DIR) / "mic"
            mic_dir.mkdir(parents=True, exist_ok=True)
            mic_path = mic_dir / "recording.wav"
            audio_segment.export(str(mic_path), format="wav")
            transcription = inference(generator, mic_path)
            st.text_area("Transcription", transcription["text"])
        else:
            st.warning("No audio data to transcribe!")
if __name__ == "__main__":
    # Configure logging to both the console and a timestamped file
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(levelname)8s %(asctime)s %(name)s %(message)s")
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    now = datetime.datetime.now()
    now_time = now.strftime('%Y-%m-%d-%H:%M:%S')
    log_dir = Path(LOG_DIR)
    log_dir.mkdir(parents=True, exist_ok=True)
    log_file = log_dir / f"{now_time}.log"
    file_handler = logging.FileHandler(str(log_file), encoding='utf-8')
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    logger.info('Start App')
    main()