Spaces:
Runtime error
Runtime error
romsyflux
commited on
Commit
•
5a1ff9c
1
Parent(s):
e4ba58a
Updated match logic
Browse files
app.py
CHANGED
@@ -6,6 +6,7 @@ from pyannote.audio import Pipeline
|
|
6 |
import os
|
7 |
from dotenv import load_dotenv
|
8 |
import plotly.graph_objects as go
|
|
|
9 |
|
10 |
load_dotenv()
|
11 |
|
@@ -88,7 +89,7 @@ def transcribe(sr, data):
|
|
88 |
processed_data = np.array(data).astype(np.float32) / 32767.0
|
89 |
|
90 |
# results from the pipeline
|
91 |
-
transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})
|
92 |
|
93 |
return transcription_res
|
94 |
|
@@ -104,17 +105,26 @@ def transcribe_diarize(audio):
|
|
104 |
diarization_res = diarization_pipeline(
|
105 |
{"waveform": waveform_tensor, "sample_rate": sr}
|
106 |
)
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
# Get diarization information
|
109 |
starts, ends, speakers = diarization_info(diarization_res)
|
110 |
|
111 |
# results from the transcription pipeline
|
112 |
-
diarized_transcription =
|
113 |
-
|
114 |
-
# Get transcription results for each speaker segment
|
115 |
-
for start_time, end_time, speaker_id in zip(starts, ends, speakers):
|
116 |
-
segment = data[int(start_time * sr) : int(end_time * sr)]
|
117 |
-
diarized_transcription += f"{speaker_id} {round(start_time, 2)}:{round(end_time, 2)} \t {transcribe(sr, segment)}\n"
|
118 |
|
119 |
# Plot diarization
|
120 |
diarization_plot = plot_diarization(starts, ends, speakers)
|
@@ -132,8 +142,8 @@ demo = gr.Interface(
|
|
132 |
gr.Plot(label="Visualization"),
|
133 |
],
|
134 |
examples=["sample1.wav"],
|
135 |
-
title="Automatic Speech Recognition with Diarization 🗣️
|
136 |
-
description="
|
137 |
)
|
138 |
|
139 |
if __name__ == "__main__":
|
|
|
6 |
import os
|
7 |
from dotenv import load_dotenv
|
8 |
import plotly.graph_objects as go
|
9 |
+
from .utils.diarize_utils import match_segments
|
10 |
|
11 |
load_dotenv()
|
12 |
|
|
|
89 |
processed_data = np.array(data).astype(np.float32) / 32767.0
|
90 |
|
91 |
# results from the pipeline
|
92 |
+
transcription_res = pipe({"sampling_rate": sr, "raw": processed_data},return_timestamps=True)
|
93 |
|
94 |
return transcription_res
|
95 |
|
|
|
105 |
diarization_res = diarization_pipeline(
|
106 |
{"waveform": waveform_tensor, "sample_rate": sr}
|
107 |
)
|
108 |
+
dia_seg, dia_label = [], []
|
109 |
+
for segment, _, label in diarization_res.itertracks(yield_label=True):
|
110 |
+
dia_seg.append([segment.start, segment.end])
|
111 |
+
dia_label.append(label)
|
112 |
+
assert (
|
113 |
+
dia_seg
|
114 |
+
), "The result from the diarization pipeline: `diarization_segments` is empty. No segments found from the diarization process."
|
115 |
+
segmented_preds = transcription_res["chunks"]
|
116 |
+
dia_seg = np.array(dia_seg)
|
117 |
+
asr_seg = np.array([[*chunk["timestamp"]] for chunk in segmented_preds])
|
118 |
+
|
119 |
+
asr_labels = match_segments(dia_seg, dia_label, asr_seg, threshold=0.0, no_match_label="NO_SPEAKER")
|
120 |
+
|
121 |
+
for i, label in enumerate(asr_labels):
|
122 |
+
segmented_preds[i]["speaker"] = label
|
123 |
# Get diarization information
|
124 |
starts, ends, speakers = diarization_info(diarization_res)
|
125 |
|
126 |
# results from the transcription pipeline
|
127 |
+
diarized_transcription = segmented_preds
|
|
|
|
|
|
|
|
|
|
|
128 |
|
129 |
# Plot diarization
|
130 |
diarization_plot = plot_diarization(starts, ends, speakers)
|
|
|
142 |
gr.Plot(label="Visualization"),
|
143 |
],
|
144 |
examples=["sample1.wav"],
|
145 |
+
title="Automatic Speech Recognition with Diarization 🗣️",
|
146 |
+
description="Whisper V3 Large & Pyannote Speaker Diarization V3.1 \nTranscribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
|
147 |
)
|
148 |
|
149 |
if __name__ == "__main__":
|