sergeipetrov HF staff commited on
Commit
36cfeae
1 Parent(s): c7d2a58

Create diarization_utils.py

Browse files
Files changed (1) hide show
  1. diarization_utils.py +141 -0
diarization_utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from torchaudio import functional as F
4
+ from transformers.pipelines.audio_utils import ffmpeg_read
5
+ from fastapi import HTTPException
6
+ import sys
7
+
8
+ # Code from insanely-fast-whisper:
9
+ # https://github.com/Vaibhavs10/insanely-fast-whisper
10
+
11
+ import logging
12
+ logger = logging.getLogger(__name__)
13
+
14
+ def preprocess_inputs(inputs, sampling_rate):
15
+ inputs = ffmpeg_read(inputs, sampling_rate)
16
+
17
+ if sampling_rate != 16000:
18
+ inputs = F.resample(
19
+ torch.from_numpy(inputs), sampling_rate, 16000
20
+ ).numpy()
21
+
22
+ if len(inputs.shape) != 1:
23
+ logger.error(f"Diarization pipeline expecs single channel audio, received {inputs.shape}")
24
+ raise HTTPException(
25
+ status_code=400,
26
+ detail=f"Diarization pipeline expecs single channel audio, received {inputs.shape}"
27
+ )
28
+
29
+ # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
30
+ diarizer_inputs = torch.from_numpy(inputs).float()
31
+ diarizer_inputs = diarizer_inputs.unsqueeze(0)
32
+
33
+ return inputs, diarizer_inputs
34
+
35
+
36
+ def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
37
+ diarization = diarization_pipeline(
38
+ {"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate},
39
+ num_speakers=parameters.num_speakers,
40
+ min_speakers=parameters.min_speakers,
41
+ max_speakers=parameters.max_speakers,
42
+ )
43
+
44
+ segments = []
45
+ for segment, track, label in diarization.itertracks(yield_label=True):
46
+ segments.append(
47
+ {
48
+ "segment": {"start": segment.start, "end": segment.end},
49
+ "track": track,
50
+ "label": label,
51
+ }
52
+ )
53
+
54
+ # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
55
+ # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
56
+ new_segments = []
57
+ prev_segment = cur_segment = segments[0]
58
+
59
+ for i in range(1, len(segments)):
60
+ cur_segment = segments[i]
61
+
62
+ # check if we have changed speaker ("label")
63
+ if cur_segment["label"] != prev_segment["label"] and i < len(segments):
64
+ # add the start/end times for the super-segment to the new list
65
+ new_segments.append(
66
+ {
67
+ "segment": {
68
+ "start": prev_segment["segment"]["start"],
69
+ "end": cur_segment["segment"]["start"],
70
+ },
71
+ "speaker": prev_segment["label"],
72
+ }
73
+ )
74
+ prev_segment = segments[i]
75
+
76
+ # add the last segment(s) if there was no speaker change
77
+ new_segments.append(
78
+ {
79
+ "segment": {
80
+ "start": prev_segment["segment"]["start"],
81
+ "end": cur_segment["segment"]["end"],
82
+ },
83
+ "speaker": prev_segment["label"],
84
+ }
85
+ )
86
+
87
+ return new_segments
88
+
89
+
90
+ def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
91
+ # get the end timestamps for each chunk from the ASR output
92
+ end_timestamps = np.array(
93
+ [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
94
+ segmented_preds = []
95
+
96
+ # align the diarizer timestamps and the ASR timestamps
97
+ for segment in new_segments:
98
+ # get the diarizer end timestamp
99
+ end_time = segment["segment"]["end"]
100
+ # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
101
+ upto_idx = np.argmin(np.abs(end_timestamps - end_time))
102
+
103
+ if group_by_speaker:
104
+ segmented_preds.append(
105
+ {
106
+ "speaker": segment["speaker"],
107
+ "text": "".join(
108
+ [chunk["text"] for chunk in transcript[: upto_idx + 1]]
109
+ ),
110
+ "timestamp": (
111
+ transcript[0]["timestamp"][0],
112
+ transcript[upto_idx]["timestamp"][1],
113
+ ),
114
+ }
115
+ )
116
+ else:
117
+ for i in range(upto_idx + 1):
118
+ segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
119
+
120
+ # crop the transcripts and timestamp lists according to the latest timestamp (for faster argmin)
121
+ transcript = transcript[upto_idx + 1:]
122
+ end_timestamps = end_timestamps[upto_idx + 1:]
123
+
124
+ if len(end_timestamps) == 0:
125
+ break
126
+
127
+ return segmented_preds
128
+
129
+
130
+ def diarize(diarization_pipeline, file, parameters, asr_outputs):
131
+ _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
132
+
133
+ segments = diarize_audio(
134
+ diarizer_inputs,
135
+ diarization_pipeline,
136
+ parameters
137
+ )
138
+
139
+ return post_process_segments_and_transcripts(
140
+ segments, asr_outputs["chunks"], group_by_speaker=False
141
+ )