viktor-enzell committed on
Commit
091b848
•
1 Parent(s): cca4571

Refactoring.

Browse files
Files changed (2) hide show
  1. README.md +2 -2
  2. app.py +75 -58
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Wav2vec2 Large Voxrex Swedish 4gram
3
  emoji: 🎙️
4
- colorFrom: orange
5
- colorTo: black
6
  sdk: streamlit
7
  sdk_version: 1.9.0
8
  app_file: app.py
 
1
  ---
2
  title: Wav2vec2 Large Voxrex Swedish 4gram
3
  emoji: 🎙️
4
+ colorFrom: blue
5
+ colorTo: yellow
6
  sdk: streamlit
7
  sdk_version: 1.9.0
8
  app_file: app.py
app.py CHANGED
@@ -5,61 +5,78 @@ import torchaudio
5
  import torchaudio.functional as F
6
 
7
 
8
- st.set_page_config(
9
- page_title="Swedish Speech-to-Text",
10
- page_icon="🎙️"
11
- )
12
- st.image(
13
- "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/320/apple/325/studio-microphone_1f399-fe0f.png",
14
- width=100,
15
- )
16
- st.markdown("""
17
- # Swedish high-quality transcription
18
-
19
- Generate Swedish transcripts for download from an audio file with this high-quality speech-to-text model. The model is KBLab's wav2vec 2.0 large VoxRex Swedish (C) with a 4-gram language model, which you can access [here](https://huggingface.co/viktor-enzell/wav2vec2-large-voxrex-swedish-4gram).
20
- """)
21
-
22
- model_name = "viktor-enzell/wav2vec2-large-voxrex-swedish-4gram"
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
25
- processor = Wav2Vec2ProcessorWithLM.from_pretrained(model_name)
26
-
27
-
28
- def run_inference(file):
29
- waveform, sample_rate = torchaudio.load(file)
30
-
31
- if sample_rate == 16_000:
32
- waveform = waveform[0]
33
- else:
34
- waveform = F.resample(waveform, sample_rate, 16_000)[0]
35
-
36
- inputs = processor(
37
- waveform,
38
- sampling_rate=16_000,
39
- return_tensors="pt",
40
- padding=True
41
- ).to(device)
42
-
43
- with torch.no_grad():
44
- logits = model(**inputs).logits
45
-
46
- return processor.batch_decode(logits.cpu().numpy()).text[0].lower()
47
-
48
-
49
- uploaded_file = st.file_uploader("Choose a file", type=[".wav"])
50
- if uploaded_file is not None:
51
- if uploaded_file.type != "audio/wav":
52
- pass
53
- # TODO: convert to wav
54
- # bytes = uploaded_file.getvalue()
55
- # audio_input = ffmpeg.input(bytes).audio
56
- # audio_output = ffmpeg.output(audio_input, "tmp.wav", format="wav")
57
- # ffmpeg.run(audio_output)
58
-
59
- transcript = run_inference(uploaded_file)
60
-
61
- st.download_button("Download transcript", transcript,
62
- f"{uploaded_file.name}-swedish-transcript.txt")
63
-
64
- with st.expander("Transcript", expanded=True):
65
- st.write(transcript)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import torchaudio.functional as F
6
 
7
 
8
+ class ASR:
9
+ def __init__(self):
10
+ self.model_name = "viktor-enzell/wav2vec2-large-voxrex-swedish-4gram"
11
+ self.device = torch.device(
12
+ "cuda" if torch.cuda.is_available() else "cpu")
13
+ self.model = None
14
+ self.processor = None
15
+
16
+ def load_model(self):
17
+ self.model = Wav2Vec2ForCTC.from_pretrained(
18
+ self.model_name).to(self.device)
19
+ self.processor = Wav2Vec2ProcessorWithLM.from_pretrained(
20
+ self.model_name)
21
+
22
+ def run_inference(self, file):
23
+ waveform, sample_rate = torchaudio.load(file)
24
+
25
+ if sample_rate == 16_000:
26
+ waveform = waveform[0]
27
+ else:
28
+ waveform = F.resample(waveform, sample_rate, 16_000)[0]
29
+
30
+ inputs = self.processor(
31
+ waveform,
32
+ sampling_rate=16_000,
33
+ return_tensors="pt",
34
+ padding=True
35
+ ).to(self.device)
36
+
37
+ with torch.no_grad():
38
+ logits = self.model(**inputs).logits
39
+
40
+ return self.processor.batch_decode(logits.cpu().numpy()).text[0].lower()
41
+
42
+
43
+ @st.cache(allow_output_mutation=True, show_spinner=True)
44
+ def load_model():
45
+ asr = ASR()
46
+ asr.load_model()
47
+ return asr
48
+
49
+
50
+ if __name__ == "__main__":
51
+ st.set_page_config(
52
+ page_title="Swedish Speech-to-Text",
53
+ page_icon="🎙️"
54
+ )
55
+ st.image(
56
+ "https://emojipedia-us.s3.dualstack.us-west-1.amazonaws.com/thumbs/320/apple/325/studio-microphone_1f399-fe0f.png",
57
+ width=100,
58
+ )
59
+ st.markdown("""
60
+ # Swedish high-quality transcription
61
+
62
+ Generate Swedish transcripts for download from an audio file with this high-quality speech-to-text model. The model is KBLab's wav2vec 2.0 large VoxRex Swedish (C) with a 4-gram language model, which you can access [here](https://huggingface.co/viktor-enzell/wav2vec2-large-voxrex-swedish-4gram).
63
+ """)
64
+
65
+ asr = load_model()
66
+
67
+ uploaded_file = st.file_uploader("Choose a file", type=[".wav"])
68
+ if uploaded_file is not None:
69
+ if uploaded_file.type != "audio/wav":
70
+ pass
71
+ # TODO: convert to wav
72
+ # bytes = uploaded_file.getvalue()
73
+ # audio_input = ffmpeg.input(bytes).audio
74
+ # audio_output = ffmpeg.output(audio_input, "tmp.wav", format="wav")
75
+ # ffmpeg.run(audio_output)
76
+
77
+ transcript = asr.run_inference(uploaded_file)
78
+
79
+ st.download_button("Download transcript", transcript, "transcript.txt")
80
+
81
+ with st.expander("Transcript", expanded=True):
82
+ st.write(transcript)