Spaces:
Runtime error
Runtime error
ahmedghani
committed on
Commit
•
8d39dd5
1
Parent(s):
6b562d4
whisper demo added
Browse files- app.py +177 -0
- packages.txt +1 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import whisper
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
import streamlit as st
|
5 |
+
|
6 |
+
# Human-readable language name -> Whisper language code.
# Used only to display the list of supported languages in the sidebar,
# but the codes are kept consistent with whisper.tokenizer.LANGUAGES
# (note: Hebrew is "he" there, not the deprecated ISO-639 code "iw").
LANGUAGES = {
    "english": "en",
    "chinese": "zh",
    "german": "de",
    "spanish": "es",
    "russian": "ru",
    "korean": "ko",
    "french": "fr",
    "japanese": "ja",
    "portuguese": "pt",
    "turkish": "tr",
    "polish": "pl",
    "catalan": "ca",
    "dutch": "nl",
    "arabic": "ar",
    "swedish": "sv",
    "italian": "it",
    "indonesian": "id",
    "hindi": "hi",
    "finnish": "fi",
    "vietnamese": "vi",
    "hebrew": "he",
    "ukrainian": "uk",
    "greek": "el",
    "malay": "ms",
    "czech": "cs",
    "romanian": "ro",
    "danish": "da",
    "hungarian": "hu",
    "tamil": "ta",
    "norwegian": "no",
    "thai": "th",
    "urdu": "ur",
    "croatian": "hr",
    "bulgarian": "bg",
    "lithuanian": "lt",
    "latin": "la",
    "maori": "mi",
    "malayalam": "ml",
    "welsh": "cy",
    "slovak": "sk",
    "telugu": "te",
    "persian": "fa",
    "latvian": "lv",
    "bengali": "bn",
    "serbian": "sr",
    "azerbaijani": "az",
    "slovenian": "sl",
    "kannada": "kn",
    "estonian": "et",
    "macedonian": "mk",
    "breton": "br",
    "basque": "eu",
    "icelandic": "is",
    "armenian": "hy",
    "nepali": "ne",
    "mongolian": "mn",
    "bosnian": "bs",
    "kazakh": "kk",
    "albanian": "sq",
    "swahili": "sw",
    "galician": "gl",
    "marathi": "mr",
    "punjabi": "pa",
    "sinhala": "si",
    "khmer": "km",
    "shona": "sn",
    "yoruba": "yo",
    "somali": "so",
    "afrikaans": "af",
    "occitan": "oc",
    "georgian": "ka",
    "belarusian": "be",
    "tajik": "tg",
    "sindhi": "sd",
    "gujarati": "gu",
    "amharic": "am",
    "yiddish": "yi",
    "lao": "lo",
    "uzbek": "uz",
    "faroese": "fo",
    "haitian creole": "ht",
    "pashto": "ps",
    "turkmen": "tk",
    "nynorsk": "nn",
    "maltese": "mt",
    "sanskrit": "sa",
    "luxembourgish": "lb",
    "myanmar": "my",
    "tibetan": "bo",
    "tagalog": "tl",
    "malagasy": "mg",
    "assamese": "as",
    "tatar": "tt",
    "hawaiian": "haw",
    "lingala": "ln",
    "hausa": "ha",
    "bashkir": "ba",
    "javanese": "jw",
    "sundanese": "su",
}
|
107 |
+
|
108 |
+
def decode(model, mel, options):
    """Run Whisper decoding on a mel spectrogram and return the recognized text.

    Thin wrapper around ``whisper.decode`` that extracts just the ``.text``
    field from the decoding result.
    """
    return whisper.decode(model, mel, options).text
|
111 |
+
|
112 |
+
def load_audio(path):
    """Load an audio file and return a mono 1-D waveform at 16 kHz.

    Whisper expects 16 kHz mono input, so any other sample rate is resampled
    and multi-channel audio is downmixed by averaging channels.

    Parameters:
        path: file path or file-like object accepted by ``torchaudio.load``
              (the Streamlit uploader passes a file-like object here).

    Returns:
        1-D torch tensor of audio samples at 16 kHz.
    """
    waveform, sample_rate = torchaudio.load(path)
    # torchaudio returns (channels, frames); a stereo file would otherwise
    # survive squeeze(0) as a 2-D tensor and break log_mel_spectrogram.
    if waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sample_rate != 16000:
        waveform = torchaudio.transforms.Resample(sample_rate, 16000)(waveform)
    return waveform.squeeze(0)
|
117 |
+
|
118 |
+
def detect_language(model, mel):
    """Return the most probable language code for the given mel spectrogram.

    Delegates to ``model.detect_language`` and picks the language whose
    probability is highest.
    """
    _, probabilities = model.detect_language(mel)
    best_language = max(probabilities, key=probabilities.get)
    return best_language
|
121 |
+
|
122 |
+
def main():
    """Streamlit entry point: build the UI and run Whisper on uploaded audio."""

    st.title("Whisper ASR Demo")
    st.markdown(
        """
    This is a demo of OpenAI's Whisper ASR model. The model is trained on 680,000 hours of dataset.
    """
    )

    # Model size, with an optional English-only variant (".en" checkpoints).
    model_selection = st.sidebar.selectbox("Select model", ["tiny", "base", "small", "medium", "large"])
    en_model_selection = st.sidebar.checkbox("English only model", value=False)

    if en_model_selection:
        model_selection += ".en"
    st.sidebar.write(f"Model: {model_selection+' (Multilingual)' if not en_model_selection else model_selection + ' (English only)'}")

    if st.sidebar.checkbox("Show supported languages", value=False):
        st.sidebar.info(list(LANGUAGES.keys()))
    st.sidebar.title("Options")

    beam_size = st.sidebar.slider("Beam Size", min_value=1, max_value=10, value=5)
    fp16 = st.sidebar.checkbox("Enable FP16 for faster transcription (It may affect performance)", value=False)

    # Translation is only meaningful for the multilingual checkpoints.
    if not en_model_selection:
        task = st.sidebar.selectbox("Select task", ["transcribe", "translate (To English)"], index=0)
    else:
        task = st.sidebar.selectbox("Select task", ["transcribe"], index=0)
    # BUGFIX: the sidebar label "translate (To English)" is a display string;
    # whisper.DecodingOptions only accepts "transcribe" or "translate".
    if task.startswith("translate"):
        task = "translate"

    st.title("Audio")
    audio_file = st.file_uploader("Upload Audio", type=["wav", "mp3", "flac"])

    if audio_file is not None:
        st.audio(audio_file, format='audio/ogg')
        with st.spinner("Loading model..."):
            model = whisper.load_model(model_selection)
            # Prefer GPU when available; otherwise run on CPU.
            model = model.to("cpu") if not torch.cuda.is_available() else model.to("cuda")

        audio = load_audio(audio_file)
        with st.spinner("Extracting features..."):
            # Whisper operates on fixed-length (30 s) log-mel spectrograms.
            audio = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio).to(model.device)
        if not en_model_selection:
            with st.spinner("Detecting language..."):
                language = detect_language(model, mel)
                st.markdown(f"Detected Language: {language}")
        else:
            language = "en"
        configuration = {"beam_size": beam_size, "fp16": fp16, "task": task, "language": language}
        with st.spinner("Transcribing..."):
            options = whisper.DecodingOptions(**configuration)
            text = decode(model, mel, options)
            st.markdown(f"**Recognized Text:** {text}")
|
175 |
+
|
176 |
+
# Entry point when executed as a script (e.g. via `streamlit run app.py`).
if __name__ == "__main__":
    main()
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
ffmpeg
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
torch
|
3 |
+
torchaudio
|
4 |
+
tqdm
|
5 |
+
more-itertools
|
6 |
+
transformers>=4.19.0
|
7 |
+
ffmpeg-python==0.2.0
streamlit
openai-whisper
|