from collections import deque

import streamlit as st
import torch
from streamlit_player import st_player
from transformers import AutoModelForCTC, Wav2Vec2Processor

from streaming import ffmpeg_stream

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Options forwarded to st_player; "onProgress" events every 200 ms trigger the
# reruns that drive the subtitle loop in main().
player_options = {
    "events": ["onProgress"],
    "progress_interval": 200,
    "volume": 1.0,
    "playing": True,
    "loop": False,
    "controls": False,
    "muted": False,
    "config": {"youtube": {"playerVars": {"start": 1}}},
}

st.title("YouTube Video Spanish ASR")
st.write("Acknowledgement: This demo is based on Anton Lozhkov's cool Space : https://huggingface.co/spaces/anton-l/youtube-subs-wav2vec")

# disable rapid fading in and out on `st.code` updates
# (the style payload was previously an empty string, which rendered nothing)
st.markdown("<style>.element-container{opacity:1 !important}</style>", unsafe_allow_html=True)


@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
def load_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
    """Load the wav2vec2 processor and CTC model for ``model_path``.

    The model is moved to ``device``. ``st.cache`` memoizes the result per
    ``model_path``. The ``hash_funcs`` entry maps every model parameter to a
    constant hash, so the cache never tries to hash the weights.

    Note: the default checkpoint is never used by this app, because the model
    picker below always passes an explicit path.

    Returns:
        ``(processor, model)`` tuple.
    """
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = AutoModelForCTC.from_pretrained(model_path).to(device)
    return processor, model


model_path = st.radio(
    "Select a model",
    (
        "jonatasgrosman/wav2vec2-xls-r-1b-spanish",
        "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm",
        "facebook/wav2vec2-large-xlsr-53-spanish",
        "glob-asr/xls-r-es-test-lm",
    ),
)
processor, model = load_model(model_path)


def stream_text(url, chunk_duration_ms, pad_duration_ms):
    """Transcribe the audio of ``url`` chunk by chunk with greedy CTC decoding.

    The generator yields exactly one string per audio chunk produced by
    ``ffmpeg_stream``. That string is the recognized text, or ``""`` when the
    chunk has not completed a word yet. After the last chunk it yields one
    extra string holding any trailing partial word.

    Keeping this 1:1 chunk-to-yield mapping is what lets ``main`` estimate
    how much audio has been consumed by counting ``next()`` calls.

    Args:
        url: media URL handed to ``ffmpeg_stream``.
        chunk_duration_ms: length of each audio chunk, in milliseconds.
        pad_duration_ms: context padding added around each chunk, in ms.
    """
    sampling_rate = processor.feature_extractor.sampling_rate
    # calculate the length of logits to cut from the sides of the output to account for input padding
    # (_get_feat_extract_output_lengths is a private transformers helper)
    output_pad_len = model._get_feat_extract_output_lengths(int(sampling_rate * pad_duration_ms / 1000))

    # define the audio chunk generator
    stream = ffmpeg_stream(url, sampling_rate, chunk_duration_ms=chunk_duration_ms, pad_duration_ms=pad_duration_ms)

    leftover_text = ""
    for i, chunk in enumerate(stream):
        input_values = processor(chunk, sampling_rate=sampling_rate, return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits[0]
        if i > 0:
            logits = logits[output_pad_len : len(logits) - output_pad_len]
        else:  # don't count padding at the start of the clip
            logits = logits[: len(logits) - output_pad_len]

        predicted_ids = torch.argmax(logits, dim=-1).cpu().tolist()
        if processor.decode(predicted_ids).strip():
            leftover_ids = processor.tokenizer.encode(leftover_text)
            # concat the last word (or its part) from the last frame with the current text
            text = processor.decode(leftover_ids + predicted_ids)
            # don't return the last word in case it's just partially recognized
            if " " in text:
                text, leftover_text = text.rsplit(" ", 1)
            else:
                leftover_text = text
                text = ""
            # Always yield (even "") so each chunk corresponds to one next() call;
            # skipping the yield made the caller undercount consumed audio.
            yield text
        else:
            # Silent chunk: flush whatever partial word was held back.
            yield leftover_text
            leftover_text = ""
    yield leftover_text


def _tag_url(url, tag):
    """Return ``url`` with a ``hash=<tag>`` query parameter, replacing an old one.

    The parameter only exists to change the URL passed to ``st_player``, which
    forces the embedded player to reload. ``?`` is used as separator when the
    URL has no query string yet, so the video id is not corrupted.
    """
    base = url.split("&hash=")[0].split("?hash=")[0]
    separator = "&" if "?" in base else "?"
    return f"{base}{separator}hash={tag}"


def main():
    """Render the input form, the video player and the rolling subtitles.

    Each player progress event reruns this function. On each rerun, at most
    one more chunk of text is pulled from the ASR generator, and only when
    the played time has passed the amount of audio already transcribed.
    """
    state = st.session_state
    st.header("Video ASR Streamlit from Youtube Link")

    with st.form(key="inputs_form"):
        initial_url = "https://youtu.be/ghOqTkGzX7I?t=60"
        state.youtube_url = st.text_input("YouTube URL", initial_url)
        state.chunk_duration_ms = st.slider("Audio chunk duration (ms)", 2000, 10000, 3000, 100)
        state.pad_duration_ms = st.slider("Padding duration (ms)", 100, 5000, 1000, 100)
        submit_button = st.form_submit_button(label="Submit")

    if submit_button or "asr_stream" not in state or "player_url" not in state:
        # a hack to update the video player on value changes: give the player a
        # URL that differs on every submission. It is kept in its own state key
        # because state.youtube_url is reassigned from the text input on every rerun.
        state.run_id = state.get("run_id", 0) + 1
        state.player_url = _tag_url(
            state.youtube_url,
            f"{state.chunk_duration_ms}-{state.pad_duration_ms}-{state.run_id}",
        )
        state.asr_stream = stream_text(
            state.youtube_url, state.chunk_duration_ms, state.pad_duration_ms
        )
        state.chunks_taken = 0
        state.lines = deque([], maxlen=5)  # limit to the last n lines of subs

    player = st_player(state.player_url, **player_options, key="youtube_player")

    if "asr_stream" in state and player.data and player.data["played"] < 1.0:
        # check how many seconds were played, and if more than processed - write the next text chunk
        processed_seconds = state.chunks_taken * (state.chunk_duration_ms / 1000)
        if processed_seconds < player.data["playedSeconds"]:
            # A default avoids a StopIteration crash once the audio stream is exhausted.
            text = next(state.asr_stream, None)
            if text is not None:
                state.lines.append(text)
                state.chunks_taken += 1

    if "lines" in state:
        # print the lines of subs
        st.code("\n".join(state.lines))


if __name__ == "__main__":
    main()