afurkank committed
Commit
8fdb96f
1 Parent(s): d473317

Delete diarization_utils.py

Files changed (1)
  1. diarization_utils.py +0 -141
diarization_utils.py DELETED
```diff
@@ -1,141 +0,0 @@
-import torch
-import numpy as np
-from torchaudio import functional as F
-from transformers.pipelines.audio_utils import ffmpeg_read
-from starlette.exceptions import HTTPException
-import sys
-
-# Code from insanely-fast-whisper:
-# https://github.com/Vaibhavs10/insanely-fast-whisper
-
-import logging
-logger = logging.getLogger(__name__)
-
-def preprocess_inputs(inputs, sampling_rate):
-    inputs = ffmpeg_read(inputs, sampling_rate)
-
-    if sampling_rate != 16000:
-        inputs = F.resample(
-            torch.from_numpy(inputs), sampling_rate, 16000
-        ).numpy()
-
-    if len(inputs.shape) != 1:
-        logger.error(f"Diarization pipeline expects single channel audio, received {inputs.shape}")
-        raise HTTPException(
-            status_code=400,
-            detail=f"Diarization pipeline expects single channel audio, received {inputs.shape}"
-        )
-
-    # diarization model expects float32 torch tensor of shape `(channels, seq_len)`
-    diarizer_inputs = torch.from_numpy(inputs).float()
-    diarizer_inputs = diarizer_inputs.unsqueeze(0)
-
-    return inputs, diarizer_inputs
-
```
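`preprocess_inputs` decodes the raw upload bytes with ffmpeg, resamples to 16 kHz if needed, and hands the diarizer a `(channels, seq_len)` float32 tensor. A minimal sketch of that shape contract, using synthetic mono audio in place of real file bytes (`ffmpeg_read` needs actual encoded media):

```python
# Sketch of the tensor shapes preprocess_inputs produces, on synthetic audio.
import numpy as np
import torch

sampling_rate = 16000
mono = np.random.randn(2 * sampling_rate).astype(np.float32)  # 2 s of mono noise

# Same shaping the function performs after decoding/resampling:
diarizer_inputs = torch.from_numpy(mono).float().unsqueeze(0)
assert diarizer_inputs.shape == (1, 2 * sampling_rate)  # (channels, seq_len)
```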
```diff
-
-def diarize_audio(diarizer_inputs, diarization_pipeline, parameters):
-    diarization = diarization_pipeline(
-        {"waveform": diarizer_inputs, "sample_rate": parameters.sampling_rate},
-        num_speakers=parameters.num_speakers,
-        min_speakers=parameters.min_speakers,
-        max_speakers=parameters.max_speakers,
-    )
-
-    segments = []
-    for segment, track, label in diarization.itertracks(yield_label=True):
-        segments.append(
-            {
-                "segment": {"start": segment.start, "end": segment.end},
-                "track": track,
-                "label": label,
-            }
-        )
-
-    # diarizer output may contain consecutive segments from the same speaker (e.g. {(0 -> 1, speaker_1), (1 -> 1.5, speaker_1), ...})
-    # we combine these segments to give overall timestamps for each speaker's turn (e.g. {(0 -> 1.5, speaker_1), ...})
-    new_segments = []
-    prev_segment = cur_segment = segments[0]
-
-    for i in range(1, len(segments)):
-        cur_segment = segments[i]
-
-        # check if we have changed speaker ("label")
-        if cur_segment["label"] != prev_segment["label"] and i < len(segments):
-            # add the start/end times for the super-segment to the new list
-            new_segments.append(
-                {
-                    "segment": {
-                        "start": prev_segment["segment"]["start"],
-                        "end": cur_segment["segment"]["start"],
-                    },
-                    "speaker": prev_segment["label"],
-                }
-            )
-            prev_segment = segments[i]
-
-    # add the last segment(s) if there was no speaker change
-    new_segments.append(
-        {
-            "segment": {
-                "start": prev_segment["segment"]["start"],
-                "end": cur_segment["segment"]["end"],
-            },
-            "speaker": prev_segment["label"],
-        }
-    )
-
-    return new_segments
-
```
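The second half of `diarize_audio` collapses consecutive tracks with the same label into a single speaker turn. A standalone toy run of that merge logic (no pyannote needed; the segment values are made up):

```python
# Toy illustration of the merge step: back-to-back segments from the same
# speaker become one turn with combined start/end timestamps.
segments = [
    {"segment": {"start": 0.0, "end": 1.0}, "label": "SPEAKER_00"},
    {"segment": {"start": 1.0, "end": 1.5}, "label": "SPEAKER_00"},
    {"segment": {"start": 1.5, "end": 3.0}, "label": "SPEAKER_01"},
]

new_segments, prev = [], segments[0]
for cur in segments[1:]:
    if cur["label"] != prev["label"]:
        # speaker changed: close out the previous speaker's turn
        new_segments.append({
            "segment": {"start": prev["segment"]["start"],
                        "end": cur["segment"]["start"]},
            "speaker": prev["label"],
        })
        prev = cur
# close the final turn
new_segments.append({
    "segment": {"start": prev["segment"]["start"],
                "end": segments[-1]["segment"]["end"]},
    "speaker": prev["label"],
})
print(new_segments)
# [{'segment': {'start': 0.0, 'end': 1.5}, 'speaker': 'SPEAKER_00'},
#  {'segment': {'start': 1.5, 'end': 3.0}, 'speaker': 'SPEAKER_01'}]
```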
```diff
-
-def post_process_segments_and_transcripts(new_segments, transcript, group_by_speaker) -> list:
-    # get the end timestamps for each chunk from the ASR output
-    end_timestamps = np.array(
-        [chunk["timestamp"][-1] if chunk["timestamp"][-1] is not None else sys.float_info.max for chunk in transcript])
-    segmented_preds = []
-
-    # align the diarizer timestamps and the ASR timestamps
-    for segment in new_segments:
-        # get the diarizer end timestamp
-        end_time = segment["segment"]["end"]
-        # find the ASR end timestamp that is closest to the diarizer's end timestamp and cut the transcript to here
-        upto_idx = np.argmin(np.abs(end_timestamps - end_time))
-
-        if group_by_speaker:
-            segmented_preds.append(
-                {
-                    "speaker": segment["speaker"],
-                    "text": "".join(
-                        [chunk["text"] for chunk in transcript[: upto_idx + 1]]
-                    ),
-                    "timestamp": (
-                        transcript[0]["timestamp"][0],
-                        transcript[upto_idx]["timestamp"][1],
-                    ),
-                }
-            )
-        else:
-            for i in range(upto_idx + 1):
-                segmented_preds.append({"speaker": segment["speaker"], **transcript[i]})
-
-        # crop the transcript and timestamp lists according to the latest timestamp (for faster argmin)
-        transcript = transcript[upto_idx + 1:]
-        end_timestamps = end_timestamps[upto_idx + 1:]
-
-        if len(end_timestamps) == 0:
-            break
-
-    return segmented_preds
-
```
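Alignment works by matching each diarizer turn's end time to the nearest ASR chunk end time. A small worked example, with hypothetical chunks in the `transformers` pipeline timestamp format:

```python
# Which ASR chunks fall inside a speaker turn ending at t = 1.5 s?
import numpy as np

transcript = [
    {"text": " Hello", "timestamp": (0.0, 0.8)},
    {"text": " there.", "timestamp": (0.8, 1.4)},
    {"text": " Hi!", "timestamp": (1.6, 2.9)},
]
turn_end = 1.5  # diarizer end timestamp for the first speaker's turn

end_timestamps = np.array([c["timestamp"][1] for c in transcript])
upto_idx = int(np.argmin(np.abs(end_timestamps - turn_end)))
print(upto_idx)  # 1 -> chunks 0..1 are attributed to this speaker
```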
```diff
-
-def diarize(diarization_pipeline, file, parameters, asr_outputs):
-    _, diarizer_inputs = preprocess_inputs(file, parameters.sampling_rate)
-
-    segments = diarize_audio(
-        diarizer_inputs,
-        diarization_pipeline,
-        parameters
-    )
-
-    return post_process_segments_and_transcripts(
-        segments, asr_outputs["chunks"], group_by_speaker=False
-    )
```
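For reference, a minimal usage sketch of the deleted `diarize` entry point. The pyannote checkpoint name, file path, and `parameters` object are assumptions, not part of this repo; in the app, `parameters` presumably comes from the request model and `asr_outputs` from a `transformers` ASR pipeline run with `return_timestamps=True`:

```python
# Hypothetical end-to-end call of the deleted diarize() helper;
# assumes the module above is still importable (pre-deletion).
from types import SimpleNamespace
from pyannote.audio import Pipeline

# assumed checkpoint; requires accepting the model's terms and an HF token
diarization_pipeline = Pipeline.from_pretrained(
    "pyannote/speaker-diarization-3.1", use_auth_token="hf_..."
)

parameters = SimpleNamespace(
    sampling_rate=16000,
    num_speakers=None, min_speakers=None, max_speakers=None,
)

with open("meeting.wav", "rb") as f:  # hypothetical input file
    file_bytes = f.read()

# stand-in for real ASR output with chunk timestamps
asr_outputs = {"chunks": [{"text": " Hello there.", "timestamp": (0.0, 1.4)}]}

labelled_chunks = diarize(diarization_pipeline, file_bytes, parameters, asr_outputs)
# -> the same chunks with a "speaker" key added, e.g.
# [{"speaker": "SPEAKER_00", "text": " Hello there.", "timestamp": (0.0, 1.4)}]
```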