Anustup commited on
Commit
a929b20
1 Parent(s): c7060f6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import os, time, librosa, torch
4
+ from pyannote.audio import Pipeline
5
+ from transformers import pipeline
6
+ from utils import second_to_timecode, download_from_youtube
7
+
8
+ MODEL_NAME = 'openai/whisper-medium'
9
+ lang = 'en'
10
+
11
+ chunk_length_s = 9
12
+ vad_activation_min_duration = 9 # sec
13
+ device = 0 if torch.cuda.is_available() else "cpu"
14
+ SAMPLE_RATE = 16_000
15
+
16
+ ######## LOAD MODELS FROM HUB ########
17
+ dia_model = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=os.environ['TOKEN'])
18
+ vad_model = Pipeline.from_pretrained("pyannote/voice-activity-detection", use_auth_token=os.environ['TOKEN'])
19
+ pipe = pipeline(task="automatic-speech-recognition", model=MODEL_NAME, chunk_length_s=chunk_length_s, device=device)
20
+ pipe.model.config.forced_decoder_ids = pipe.tokenizer.get_decoder_prompt_ids(language=lang, task="transcribe")
21
+
22
+ print("----------> Loaded models <-----------")
23
+
24
+ def generator(youtube_link, microphone, file_upload, num_speakers, max_duration, history):
25
+
26
+ if int(youtube_link != '') + int(microphone is not None) + int(file_upload is not None) != 1:
27
+ raise Exception(f"Only one of the source should be given youtube_link={youtube_link}, microphone={microphone}, file_upload={file_upload}")
28
+
29
+ history = history or ""
30
+
31
+ if microphone:
32
+ path = microphone
33
+ elif file_upload:
34
+ path = file_upload
35
+ elif youtube_link:
36
+ path = download_from_youtube(youtube_link)
37
+
38
+ waveform, sampling_rate = librosa.load(path, sr=SAMPLE_RATE, mono=True, duration=max_duration)
39
+
40
+ print(waveform.shape, sampling_rate)
41
+ waveform_tensor = torch.unsqueeze(torch.tensor(waveform), 0).to(device)
42
+
43
+ dia_result = dia_model({
44
+ "waveform": waveform_tensor,
45
+ "sample_rate": sampling_rate,
46
+ }, num_speakers=num_speakers)
47
+
48
+ for speech_turn, track, speaker in dia_result.itertracks(yield_label=True):
49
+ print(f"{speech_turn.start:4.1f} {speech_turn.end:4.1f} {speaker}")
50
+ _start = int(sampling_rate * speech_turn.start)
51
+ _end = int(sampling_rate * speech_turn.end)
52
+ data = waveform[_start: _end]
53
+
54
+ if speech_turn.end - speech_turn.start > vad_activation_min_duration:
55
+ print(f'audio duration {speech_turn.end - speech_turn.start} sec ----> activating VAD')
56
+ vad_output = vad_model({
57
+ 'waveform': waveform_tensor[:, _start:_end],
58
+ 'sample_rate': sampling_rate})
59
+ for vad_turn in vad_output.get_timeline().support():
60
+ vad_start = _start + int(sampling_rate * vad_turn.start)
61
+ vad_end = _start + int(sampling_rate * vad_turn.end)
62
+ prediction = pipe(waveform[vad_start: vad_end])['text']
63
+ history += f"{second_to_timecode(speech_turn.start + vad_turn.start)},{second_to_timecode(speech_turn.start + vad_turn.end)}\n" + \
64
+ f"{prediction}\n\n"
65
+ # f">> {speaker}: {prediction}\n\n"
66
+ yield history, history, None
67
+
68
+ else:
69
+ prediction = pipe(data)['text']
70
+ history += f"{second_to_timecode(speech_turn.start)},{second_to_timecode(speech_turn.end)}\n" + \
71
+ f"{prediction}\n\n"
72
+ # f">> {speaker}: {prediction}\n\n"
73
+
74
+ yield history, history, None
75
+
76
+ # https://support.google.com/youtube/answer/2734698?hl=en#zippy=%2Cbasic-file-formats%2Csubrip-srt-example%2Csubviewer-sbv-example
77
+ file_name = 'transcript.sbv'
78
+ with open(file_name, 'w') as fp:
79
+ fp.write(history)
80
+
81
+ yield history, history, file_name
82
+
83
+ demo = gr.Interface(
84
+ generator,
85
+ inputs=[
86
+ gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL", optional=True),
87
+ gr.inputs.Audio(source="microphone", type="filepath", optional=True),
88
+ gr.inputs.Audio(source="upload", type="filepath", optional=True),
89
+ gr.Number(value=1, label="Number of Speakers"),
90
+ gr.Number(value=120, label="Maximum Duration (Seconds)"),
91
+ 'state',
92
+ ],
93
+ outputs=['text', 'state', 'file'],
94
+ layout="horizontal",
95
+ theme="huggingface",
96
+ allow_flagging="never",
97
+ )
98
+
99
+ # define queue - required for generators
100
+ demo.queue()
101
+
102
+ demo.launch()