from collections import deque

import streamlit as st
import torch
from streamlit_player import st_player
from transformers import AutoModelForCTC, Wav2Vec2Processor

from streaming import ffmpeg_stream

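# run inference on the GPU when one is available, otherwise fall back to CPU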
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
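# options forwarded to streamlit-player; onProgress events every 200 ms let
# the app pace subtitle generation against actual playback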
player_options = {
    "events": ["onProgress"],
    "progress_interval": 200,
    "volume": 1.0,
    "playing": True,
    "loop": False,
    "controls": False,
    "muted": False,
    "config": {"youtube": {"playerVars": {"start": 1}}},
}

st.title("YouTube Video Spanish ASR")
st.write("Acknowledgement: This demo is based on Anton Lozhkov's cool Space : https://huggingface.co/spaces/anton-l/youtube-subs-wav2vec")
# disable rapid fading in and out on `st.code` updates
st.markdown("<style>.element-container{opacity:1 !important}</style>", unsafe_allow_html=True)

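# cache the processor/model across Streamlit reruns; torch Parameters are not
# hashable, so tell st.cache to skip them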
@st.cache(hash_funcs={torch.nn.parameter.Parameter: lambda _: None})
def load_model(model_path="facebook/wav2vec2-large-robust-ft-swbd-300h"):
    processor = Wav2Vec2Processor.from_pretrained(model_path)
    model = AutoModelForCTC.from_pretrained(model_path).to(device)
    return processor, model

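# let the user choose among several Spanish wav2vec2 checkpoints from the Hub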
model_path = st.radio(
    "Select a model", (
        "jonatasgrosman/wav2vec2-xls-r-1b-spanish",
        "jonatasgrosman/wav2vec2-large-xlsr-53-spanish",
        "patrickvonplaten/wav2vec2-large-xlsr-53-spanish-with-lm",
        "facebook/wav2vec2-large-xlsr-53-spanish",
        "glob-asr/xls-r-es-test-lm"
    )
)

processor, model = load_model(model_path)

def stream_text(url, chunk_duration_ms, pad_duration_ms):
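    """Stream audio from `url` in padded chunks, transcribe each chunk with
    the CTC model, and yield fully recognized words as they stabilize."""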
    sampling_rate = processor.feature_extractor.sampling_rate

    # calculate the length of logits to cut from the sides of the output to account for input padding
    output_pad_len = model._get_feat_extract_output_lengths(int(sampling_rate * pad_duration_ms / 1000))

    # define the audio chunk generator
    stream = ffmpeg_stream(url, sampling_rate, chunk_duration_ms=chunk_duration_ms, pad_duration_ms=pad_duration_ms)

    leftover_text = ""
    for i, chunk in enumerate(stream):
        input_values = processor(chunk, sampling_rate=sampling_rate, return_tensors="pt").input_values

        with torch.no_grad():
            logits = model(input_values.to(device)).logits[0]
            if i > 0:
                logits = logits[output_pad_len : len(logits) - output_pad_len]
            else:  # the first chunk has no padding at the start of the clip, so trim only the tail
                logits = logits[: len(logits) - output_pad_len]

            predicted_ids = torch.argmax(logits, dim=-1).cpu().tolist()
            if processor.decode(predicted_ids).strip():
                leftover_ids = processor.tokenizer.encode(leftover_text)
                # concat the last word (or its part) from the last frame with the current text
                text = processor.decode(leftover_ids + predicted_ids)
                # don't return the last word in case it's just partially recognized
                if " " in text:
                    text, leftover_text = text.rsplit(" ", 1)
                else:
                    leftover_text = text
                    text = ""
                if text:
                    yield text
            else:
                yield leftover_text
                leftover_text = ""
    yield leftover_text

def main():
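    """Render the page: gather inputs, embed the player, and stream subtitles."""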
    state = st.session_state
    st.header("Video ASR from a YouTube Link")

    with st.form(key="inputs_form"):
        initial_url = "https://youtu.be/ghOqTkGzX7I?t=60"
        state.youtube_url = st.text_input("YouTube URL", initial_url)

        state.chunk_duration_ms = st.slider("Audio chunk duration (ms)", 2000, 10000, 3000, 100)
        state.pad_duration_ms = st.slider("Padding duration (ms)", 100, 5000, 1000, 100)
        submit_button = st.form_submit_button(label="Submit")

    if "lines" in state:
        # print the lines of subs
        st.code("\n".join(state.lines))

    if submit_button or "asr_stream" not in state:
        # a hack to update the video player on value changes
        state.youtube_url = (
            state.youtube_url.split("&hash=")[0]
            + f"&hash={state.chunk_duration_ms}-{state.pad_duration_ms}"
        )
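        # (re)create the transcription generator with the selected settings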
        state.asr_stream = stream_text(
            state.youtube_url, state.chunk_duration_ms, state.pad_duration_ms
        )
        state.chunks_taken = 0
        state.lines = deque([], maxlen=5)  # limit to the last n lines of subs

    player = st_player(state.youtube_url, **player_options, key="youtube_player")

    if "asr_stream" in state and player.data and player.data["played"] < 1.0:
        # compare seconds played with seconds processed; when playback is ahead, emit the next text chunk
        processed_seconds = state.chunks_taken * (state.chunk_duration_ms / 1000)
        if processed_seconds < player.data["playedSeconds"]:
            text = next(state.asr_stream)
            state.lines.append(text)
            state.chunks_taken += 1


if __name__ == "__main__":
    main()