Romain COARD committed on
Commit 3a3a6e5 • 1 Parent(s): d7c84cf

Add application file

Files changed (5)
  1. .gitignore +3 -0
  2. README.md +4 -4
  3. app.py +141 -0
  4. requirements.txt +10 -0
  5. sample1.wav +0 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ __pycache__/
+ flagged/
README.md CHANGED
@@ -1,10 +1,10 @@
  ---
- title: Whisper Diarization
+ title: Whisper Speech To Text
- emoji: 📚
+ emoji: 🏒
- colorFrom: gray
+ colorFrom: yellow
  colorTo: purple
  sdk: gradio
- sdk_version: 4.29.0
+ sdk_version: 4.12.0
  app_file: app.py
  pinned: false
  ---
app.py ADDED
@@ -0,0 +1,141 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
+ import numpy as np
+ from pyannote.audio import Pipeline
+ import os
+ from dotenv import load_dotenv
+ import plotly.graph_objects as go
+ 
+ load_dotenv()
+ 
+ # Check and set device
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ 
+ # Model and pipeline setup
+ model_id = "distil-whisper/distil-small.en"
+ model = AutoModelForSpeechSeq2Seq.from_pretrained(
+     model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
+ )
+ model.to(device)
+ 
+ processor = AutoProcessor.from_pretrained(model_id)
+ 
+ pipe = pipeline(
+     "automatic-speech-recognition",
+     model=model,
+     tokenizer=processor.tokenizer,
+     feature_extractor=processor.feature_extractor,
+     max_new_tokens=128,
+     torch_dtype=torch_dtype,
+     device=device,
+ )
+ 
+ diarization_pipeline = Pipeline.from_pretrained(
+     "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
+ )
+ 
+ 
+ # returns diarization info such as segment start and end times, and speaker id
+ def diarization_info(res):
+     starts = []
+     ends = []
+     speakers = []
+ 
+     for segment, _, speaker in res.itertracks(yield_label=True):
+         starts.append(segment.start)
+         ends.append(segment.end)
+         speakers.append(speaker)
+ 
+     return starts, ends, speakers
+ 
+ 
+ # plot diarization results on a graph
+ def plot_diarization(starts, ends, speakers):
+     fig = go.Figure()
+ 
+     # Define a color map for different speakers
+     num_speakers = len(set(speakers))
+     colors = [f"hsl({h},80%,60%)" for h in np.linspace(0, 360, num_speakers)]
+ 
+     # Plot each segment with its speaker's color
+     for start, end, speaker in zip(starts, ends, speakers):
+         speaker_id = list(set(speakers)).index(speaker)
+         fig.add_trace(
+             go.Scatter(
+                 x=[start, end],
+                 y=[speaker_id, speaker_id],
+                 mode="lines",
+                 line=dict(color=colors[speaker_id], width=15),
+                 showlegend=False,
+             )
+         )
+ 
+     fig.update_layout(
+         title="Speaker Diarization",
+         xaxis=dict(title="Time"),
+         yaxis=dict(title="Speaker"),
+         height=600,
+         width=800,
+     )
+ 
+     return fig
+ 
+ 
+ def transcribe(sr, data):
+     processed_data = np.array(data).astype(np.float32) / 32767.0
+ 
+     # results from the pipeline
+     transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})["text"]
+ 
+     return transcription_res
+ 
+ 
+ def transcribe_diarize(audio):
+     sr, data = audio
+     processed_data = np.array(data).astype(np.float32) / 32767.0
+     waveform_tensor = torch.tensor(processed_data[np.newaxis, :])
+ 
+     transcription_res = transcribe(sr, data)
+ 
+     # results from the diarization pipeline
+     diarization_res = diarization_pipeline(
+         {"waveform": waveform_tensor, "sample_rate": sr}
+     )
+ 
+     # Get diarization information
+     starts, ends, speakers = diarization_info(diarization_res)
+ 
+     # results from the transcription pipeline
+     diarized_transcription = ""
+ 
+     # Get transcription results for each speaker segment
+     for start_time, end_time, speaker_id in zip(starts, ends, speakers):
+         segment = data[int(start_time * sr) : int(end_time * sr)]
+         diarized_transcription += f"{speaker_id} {round(start_time, 2)}:{round(end_time, 2)} \t {transcribe(sr, segment)}\n"
+ 
+     # Plot diarization
+     diarization_plot = plot_diarization(starts, ends, speakers)
+ 
+     return transcription_res, diarized_transcription, diarization_plot
+ 
+ 
+ # creating the gradio interface
+ demo = gr.Interface(
+     fn=transcribe_diarize,
+     inputs=gr.Audio(sources=["upload", "microphone"]),
+     outputs=[
+         gr.Textbox(lines=3, label="Text Transcription"),
+         gr.Textbox(label="Diarized Transcription"),
+         gr.Plot(label="Visualization"),
+     ],
+     examples=["sample1.wav"],
+     title="Automatic Speech Recognition with Diarization 🗣️",
+     description="Transcribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
+ )
+ 
+ 
+ if __name__ == "__main__":
+     demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
+ accelerate==0.24.1
+ transformers==4.33.2
+ gradio==4.12.0
+ pyannote.audio==3.1.1
+ pyannote.core==5.0.0
+ pyannote.database==5.0.1
+ pyannote.metrics==3.2.1
+ pyannote.pipeline==3.0.1
+ python-dotenv==1.0.0
+ plotly==5.18.0
sample1.wav ADDED
Binary file (438 kB).
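
Note: app.py reads a Hugging Face access token from a local .env file (hence the .env entry in .gitignore) and passes it to the gated pyannote/speaker-diarization-3.1 pipeline. A minimal pre-flight sketch for running the Space locally, assuming the token is stored under the HF_KEY variable exactly as app.py expects:

    import os
    from dotenv import load_dotenv

    # Load HF_KEY from a local .env file, mirroring what app.py does at startup.
    load_dotenv()
    if not os.getenv("HF_KEY"):
        raise RuntimeError("HF_KEY is missing; add it to .env before running app.py")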