flozi00 commited on
Commit
36e1b68
·
1 Parent(s): 7f06476

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -43
app.py CHANGED
@@ -1,49 +1,89 @@
1
  from transformers import pipeline
 
2
  import gradio as gr
3
- from pyctcdecode import BeamSearchDecoderCTC
 
 
4
 
5
- #lmID = "aware-ai/german-lowercase-wiki-5gram"
6
- #decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)
7
- p = pipeline("automatic-speech-recognition", model="aware-ai/wav2vec2-xls-r-1b-5gram-german")
8
- ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
- def transcribe(audio):
12
- transcribed = p(audio, chunk_length_s=16, stride_length_s=(4, 0))["text"]
13
- return transcribed
14
-
15
- def punctuate(transcribed):
16
- punctuated = ttp(transcribed, max_length = 512)[0]["generated_text"]
17
- return punctuated
18
-
19
- def get_asr_interface():
20
- return gr.Interface(
21
- fn=transcribe,
22
- inputs=[
23
- gr.inputs.Audio(source="microphone", type="filepath")
24
- ],
25
- outputs=[
26
- "textbox",
27
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- def get_punctuation_interface():
30
- return gr.Interface(
31
- fn=punctuate,
32
- inputs=[
33
- "textbox",
34
- ],
35
- outputs=[
36
- "textbox",
37
- ])
38
-
39
- interfaces = [
40
- get_asr_interface(),
41
- get_punctuation_interface(),
42
- ]
43
-
44
- names = [
45
- "ASR",
46
- "GRAMMAR",
47
- ]
48
-
49
- gr.TabbedInterface(interfaces, names).launch(server_name = "0.0.0.0", enable_queue=False)
 
1
  from transformers import pipeline
2
+ import torch
3
  import gradio as gr
4
+ import subprocess
5
+ import numpy as np
6
+ import time
7
 
8
+ p = pipeline("automatic-speech-recognition", model="aware-ai/wav2vec2-base-german")
 
 
 
9
 
10
+ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
11
+ model='silero_vad', force_reload=False, onnx=True)
12
+
13
+ def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
14
+ """
15
+ Helper function to read an audio file through ffmpeg.
16
+ """
17
+ ar = f"{sampling_rate}"
18
+ ac = "1"
19
+ format_for_conversion = "f32le"
20
+ ffmpeg_command = [
21
+ "ffmpeg",
22
+ "-i",
23
+ "pipe:0",
24
+ "-ac",
25
+ ac,
26
+ "-ar",
27
+ ar,
28
+ "-f",
29
+ format_for_conversion,
30
+ "-hide_banner",
31
+ "-loglevel",
32
+ "quiet",
33
+ "pipe:1",
34
+ ]
35
 
36
+ try:
37
+ with subprocess.Popen(ffmpeg_command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as ffmpeg_process:
38
+ output_stream = ffmpeg_process.communicate(bpayload)
39
+ except FileNotFoundError as error:
40
+ raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error
41
+ out_bytes = output_stream[0]
42
+ audio = np.frombuffer(out_bytes, np.float32)
43
+ if audio.shape[0] == 0:
44
+ raise ValueError("Malformed soundfile")
45
+ return audio
46
+
47
+ (get_speech_timestamps,
48
+ _, read_audio,
49
+ *_) = utils
50
+
51
+ def is_speech(wav, sr):
52
+ speech_timestamps = get_speech_timestamps(wav, model,
53
+ sampling_rate=sr)
54
+
55
+ return len(speech_timestamps) > 0
56
+
57
+ def transcribe(audio, state={"text": "", "temp_text": "", "audio": ""}):
58
+ if state is None:
59
+ state={"text": "", "temp_text": "", "audio": ""}
60
+ with open(audio, "rb") as f:
61
+ payload = f.read()
62
+ audio = ffmpeg_read(payload, sampling_rate=16000)
63
+ _sr = 16000
64
+
65
+ speech = is_speech(wav_data, _sr)
66
+ if(speech):
67
+ if(state["audio"] is ""):
68
+ state["audio"] = wav_data
69
+ else:
70
+ state["audio"] = np.concatenate((state["audio"], wav_data))
71
+ else:
72
+ if(state["audio"] is not ""):
73
+ text = p(state["audio"])["text"] + "\n"
74
+ state["temp_text"] = text
75
 
76
+ state["text"] += state["temp_text"]
77
+ state["temp_text"] = ""
78
+ state["audio"] = ""
79
+
80
+ time.sleep(0.5)
81
+ return f'{state["text"]} ( {state["temp_text"]} )', state
82
+
83
+ gr.Interface(
84
+ transcribe,
85
+ [gr.Audio(source="microphone", type="filepath", streaming=True), "state"],
86
+
87
+ [gr.Textbox(),"state"],
88
+ live=True
89
+ ).launch(server_name = "0.0.0.0")