File size: 13,750 Bytes
8840b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7801528
8840b4b
 
7801528
 
8840b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7801528
 
 
 
 
 
 
 
8840b4b
 
 
 
 
 
 
7801528
8840b4b
 
 
 
 
 
 
 
 
 
 
7801528
8840b4b
 
7801528
 
8840b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7801528
 
 
 
 
 
 
 
8840b4b
 
 
 
 
 
7801528
8840b4b
 
6a73d74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8840b4b
 
 
 
7801528
8840b4b
 
 
 
 
 
 
 
 
7801528
8840b4b
 
7801528
 
8840b4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7801528
 
 
 
 
 
 
8840b4b
 
 
 
 
 
 
7801528
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import torch
from speechbrain.inference.interfaces import Pretrained
import librosa
import numpy as np


class ASR(Pretrained):
    def __init__(self, *args, **kwargs):
        """Forward every positional and keyword argument unchanged to the ``Pretrained`` constructor."""
        super().__init__(*args, **kwargs)

    def encode_batch_w2v2(self, device, wavs, wav_lens=None, normalize=False):
        """Transcribe a batch of waveforms with the ``encoder_w2v2`` module and the configured search.

        Args:
            device: Device the inputs are moved to before the forward pass.
            wavs: Batch of waveforms as a tensor whose first dimension is the
                batch (assumed ``[batch, time]`` - TODO confirm against callers).
            wav_lens: Relative lengths handed to ``hparams.test_search``.
                When omitted, every waveform is treated as full length (1.0).
            normalize: Unused; kept so existing callers keep working.

        Returns:
            One list of words per hypothesis: token id 0 is removed, the rest
            is decoded with ``hparams.tokenizer``, split on single spaces and
            passed through ``filter_repetitions(..., 3)``.
        """
        wavs = wavs.to(device)
        # The signature advertises wav_lens=None, so honour it instead of
        # failing on ``None.to(device)``.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.size(0), device=device)
        else:
            wav_lens = wav_lens.to(device)

        # Inference only: no autograd graph is needed. The searcher consumes
        # the encoder states directly, so the BOS embedding/decoder step whose
        # result was never used has been dropped.
        with torch.no_grad():
            encoded_outputs = self.mods.encoder_w2v2(wavs.detach())
            hypotheses = self.hparams.test_search(encoded_outputs, wav_lens)[0]

        transcripts = []
        for hypothesis in hypotheses:
            # Id 0 is skipped before decoding (presumably blank/padding -
            # confirm against the tokenizer).
            kept_ids = [token for token in hypothesis if token != 0]
            words = self.hparams.tokenizer.decode_ids(kept_ids).split(" ")
            transcripts.append(self.filter_repetitions(words, 3))
        return transcripts


    def encode_batch_whisper(self, device, wavs, wav_lens=None, normalize=False):
        """Transcribe a batch of waveforms with the ``whisper`` module and the configured search.

        Args:
            device: Device the inputs are moved to before the forward pass.
            wavs: Batch of waveforms as a tensor whose first dimension is the
                batch (assumed ``[batch, time]`` - TODO confirm against callers).
            wav_lens: Relative lengths handed to ``hparams.test_search``.
                When omitted, every waveform is treated as full length (1.0).
            normalize: Unused; kept so existing callers keep working.

        Returns:
            One stripped string per hypothesis, decoded with the Whisper
            tokenizer with special tokens skipped.
        """
        wavs = wavs.to(device)
        # The signature advertises wav_lens=None, so honour it instead of
        # failing on ``None.to(device)``.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.size(0), device=device)
        else:
            wav_lens = wav_lens.to(device)

        # Two start tokens per batch item; sizing by the batch keeps the
        # decoder input aligned with the encoder input for batches larger
        # than one (the batch-of-one case is unchanged).
        start_id = self.mods.whisper.config.decoder_start_token_id
        tokens = torch.full((wavs.size(0), 2), start_id, dtype=torch.long, device=device)

        # Inference only: no autograd graph is needed. Only the encoder
        # output feeds the search, so the logits are ignored.
        with torch.no_grad():
            enc_out, _, _ = self.mods.whisper(wavs.detach(), tokens)
            hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)

        tokenizer = self.mods.whisper.tokenizer
        return [tokenizer.decode(hyp, skip_special_tokens=True).strip() for hyp in hyps]


    def filter_repetitions(self, seq, max_repetition_length):
        """Shorten over-long runs of repeated n-grams in a token sequence.

        For every n-gram size ``n`` from ``len(seq) // 2`` down to 1 the
        sequence is scanned with ``n`` round-robin buffers (token ``i`` goes to
        buffer ``i % n``). A token that differs from the last entry of its
        buffer triggers a flush; during a flush only the first
        ``max(max_repetition_length // n, 1)`` rounds of buffered tokens are
        emitted. The result of one pass is the input of the next pass.

        Args:
            seq: Iterable of tokens, compared with ``!=``. It is copied into a
                list first. ``None`` is used internally as a deletion marker,
                so the tokens themselves should not be ``None``.
            max_repetition_length: Budget from which the per-``n`` repetition
                limit is derived. Sequences that are not longer than this
                value are returned unchanged (as a list).

        Returns:
            list: The filtered tokens.
        """
        seq = list(seq)
        output = []
        max_n = len(seq) // 2
        # Longest candidate n-gram size first, down to single tokens.
        for n in range(max_n, 0, -1):
            # Number of rounds that may be emitted for this n (at least one).
            max_repetitions = max(max_repetition_length // n, 1)
            # Don't need to iterate over impossible n values:
            # len(seq) can change a lot during iteration
            if (len(seq) <= n*2) or (len(seq) <= max_repetition_length):
                continue
            iterator = enumerate(seq)
            # Fill first buffers:
            buffers = [[next(iterator)[1]] for _ in range(n)]
            # The guard above ensures len(seq) > 2 * n, so this loop runs at
            # least once and current_buffer is defined for the final flush.
            for seq_index, token in iterator:
                current_buffer = seq_index % n
                if token != buffers[current_buffer][-1]:
                    # No repeat, we can flush some tokens
                    buf_len = sum(map(len, buffers))
                    flush_start = (current_buffer-buf_len) % n
                    # Keep n-1 tokens, but possibly mark some for removal
                    for flush_index in range(buf_len - buf_len%n):
                        if (buf_len - flush_index) > n-1:
                            to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                        else:
                            to_flush = None
                        # Here, repetitions get removed:
                        if (flush_index // n < max_repetitions) and to_flush is not None:
                            output.append(to_flush)
                        elif (flush_index // n >= max_repetitions) and to_flush is None:
                            # A None marker is emitted; it is resolved when
                            # seq is rebuilt below.
                            output.append(to_flush)
                # The token is buffered whether or not a flush happened.
                buffers[current_buffer].append(token)
            # At the end, final flush
            current_buffer += 1
            buf_len = sum(map(len, buffers))
            flush_start = (current_buffer-buf_len) % n
            for flush_index in range(buf_len):
                to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                # Here, repetitions just get removed:
                if flush_index // n < max_repetitions:
                    output.append(to_flush)
            # Rebuild seq from output: every None marker cancels one of the
            # real tokens that follow it.
            seq = []
            to_delete = 0
            for token in output:
                if token is None:
                    to_delete += 1
                elif to_delete > 0:
                    to_delete -= 1
                else:
                    seq.append(token)
            # Start the next (smaller) n with an empty output buffer.
            output = []
        return seq
    

    def increase_volume(self, waveform, threshold_db=-25):
        """Amplify a quiet waveform and return it.

        The level is the mean of the frame-wise RMS values, converted to dB
        with librosa. Both the initial level and, when gain is applied, the
        level after amplification are printed.

        Args:
            waveform: Audio samples accepted by ``librosa.feature.rms``
                (presumably a NumPy array - confirm against callers).
            threshold_db: Level in dB below which gain is applied.

        Returns:
            The input unchanged when its level is not below ``threshold_db``;
            otherwise the input multiplied by the amplitude factor that
            corresponds to ``threshold_db - level``. No clipping is applied.
        """
        def mean_rms_db(signal):
            # Frame-wise RMS -> average -> dB.
            frame_rms = librosa.feature.rms(y=signal)
            return librosa.amplitude_to_db(np.mean(frame_rms))

        level_db = mean_rms_db(waveform)
        print(f"Average Loudness: {level_db} dB")

        # Loud enough already: hand the signal back untouched.
        if not (level_db < threshold_db):
            return waveform

        # Convert the missing dB into a linear amplitude factor and apply it.
        missing_db = threshold_db - level_db
        waveform = waveform * librosa.db_to_amplitude(missing_db)

        # Re-measure purely for the log line.
        level_db = mean_rms_db(waveform)
        print(f"Average Loudness: {level_db} dB")
        return waveform


    def classify_file_w2v2(self, waveform, device):
        """Transcribe a waveform with ``encode_batch_w2v2``, chunking long audio.

        All length computations use 16000 samples per second. Audio of 30 s or
        more is cut into consecutive 30 s chunks; a trailing chunk shorter
        than 1 s is discarded. Every chunk is transcribed as a batch of one
        with a relative length of 1.0.

        Args:
            waveform: 1-D sequence of samples accepted by ``torch.tensor``.
            device: Device the tensors are moved to.

        Returns:
            A list of space-joined transcripts. For chunked audio it holds a
            single string in which each chunk transcript is followed by one
            space (including the last).
        """
        sample_rate = 16000
        chunk_size = 30 * sample_rate

        def transcribe(samples):
            # Fake a batch of one, full relative length.
            batch = torch.tensor(samples).to(device).unsqueeze(0)
            full_length = torch.tensor([1.0]).to(device)
            hypotheses = self.encode_batch_w2v2(device, batch, full_length)
            return [" ".join(words) for words in hypotheses]

        # Short recording: one pass over the whole signal.
        if len(waveform) / sample_rate < 30:
            return transcribe(waveform)

        chunk_texts = []
        for offset in range(0, len(waveform), chunk_size):
            chunk = waveform[offset:offset + chunk_size]
            # Only the trailing chunk can be this short; it is dropped.
            if len(chunk) < sample_rate:
                continue
            chunk_texts.append(transcribe(chunk)[0])
        return ["".join(text + " " for text in chunk_texts)]

        


    def classify_file_whisper_mkd(self, waveform, device):
        """Transcribe a waveform with ``encode_batch_whisper``, chunking long audio.

        All length computations use 16000 samples per second. Audio of 30 s or
        more is cut into consecutive 30 s windows; a trailing window shorter
        than 1 s is discarded. Every window is transcribed as a batch of one
        with a relative length of 1.0.

        Args:
            waveform: 1-D sequence of samples accepted by ``torch.tensor``.
            device: Device the tensors are moved to.

        Returns:
            For short audio, whatever ``encode_batch_whisper`` returns. For
            chunked audio, a one-element list whose string holds each window
            transcript followed by one space (including the last).
        """
        sample_rate = 16000
        window = 30 * sample_rate

        def transcribe(samples):
            # Fake a batch of one, full relative length.
            batch = torch.tensor(samples).to(device).unsqueeze(0)
            full_length = torch.tensor([1.0]).to(device)
            return self.encode_batch_whisper(device, batch, full_length)

        # Short recording: one pass over the whole signal.
        if len(waveform) / sample_rate < 30:
            return transcribe(waveform)

        windows = (
            waveform[offset:offset + window]
            for offset in range(0, len(waveform), window)
        )
        # Windows under one second (only the tail can be) are skipped.
        pieces = [transcribe(part)[0] for part in windows if len(part) >= sample_rate]

        joined = ""
        for piece in pieces:
            joined += piece + " "
        return [joined]


    def classify_file_whisper_mkd_streaming(self, waveform, device):
        """Yield ``encode_batch_whisper`` results chunk by chunk.

        All length computations use 16000 samples per second. Audio of 20 s or
        more is cut into consecutive 20 s chunks; a trailing chunk shorter
        than 1 s is discarded. Shorter audio is processed as a single chunk.
        Every chunk is transcribed as a batch of one with a relative length
        of 1.0.

        Args:
            waveform: 1-D sequence of samples accepted by ``torch.tensor``.
            device: Device the tensors are moved to.

        Yields:
            The return value of ``encode_batch_whisper`` for each chunk, in
            order.
        """
        sample_rate = 16000
        limit_seconds = 20

        if len(waveform) / sample_rate >= limit_seconds:
            step = limit_seconds * sample_rate
            chunks = [
                waveform[offset:offset + step]
                for offset in range(0, len(waveform), step)
            ]
            # Only the trailing chunk can be shorter than one second.
            chunks = [chunk for chunk in chunks if len(chunk) >= sample_rate]
        else:
            chunks = [waveform]

        for chunk in chunks:
            # Fake a batch of one, full relative length.
            batch = torch.tensor(chunk).to(device).unsqueeze(0)
            full_length = torch.tensor([1.0]).to(device)
            yield self.encode_batch_whisper(device, batch, full_length)


    def classify_file_whisper(self, waveform, pipe, device):
        """Transcribe a waveform through an externally supplied pipeline.

        Args:
            waveform: Audio passed straight to ``pipe``.
            pipe: Callable accepting the audio and a ``generate_kwargs``
                keyword, returning a mapping with a ``"text"`` entry.
            device: Unused by this method.

        Returns:
            A one-element list with the ``"text"`` value of the result.
        """
        decoding_options = {"language": "macedonian"}
        result = pipe(waveform, generate_kwargs=decoding_options)
        return [result["text"]]
       

    def classify_file_mms(self, waveform, processor, model, device):
        """Transcribe a waveform with an external processor/model pair, chunking long audio.

        All length computations use 16000 samples per second. Audio of 30 s or
        more is cut into consecutive 30 s chunks; a trailing chunk shorter
        than 1 s is discarded. Each chunk is decoded greedily: arg-max over
        the last logits dimension, first batch item, then
        ``processor.decode``.

        Args:
            waveform: 1-D sequence of samples accepted by ``torch.tensor``.
            processor: Callable invoked as
                ``processor(samples, sampling_rate=16000, return_tensors="pt")``
                whose result supports ``.to(device)`` and ``**`` unpacking;
                it must also provide ``decode(ids)``.
            model: Callable taking the processor output as keyword arguments
                and returning an object with a ``logits`` attribute.
            device: Device the model inputs are moved to.

        Returns:
            A one-element list. For chunked audio each chunk transcript is
            followed by one space (including the last), as before.
        """
        sample_rate = 16_000
        chunk_size = 30 * sample_rate

        def transcribe(samples):
            # The processor builds its own tensors, which are moved to
            # ``device`` right after, so the raw samples are not copied to
            # the device first.
            inputs = processor(
                torch.tensor(samples), sampling_rate=sample_rate, return_tensors="pt"
            ).to(device)
            # Inference only: do not build an autograd graph for the
            # externally supplied model.
            with torch.no_grad():
                logits = model(**inputs).logits
            ids = torch.argmax(logits, dim=-1)[0]
            return processor.decode(ids)

        # Short recording: one pass over the whole signal.
        if len(waveform) / sample_rate < 30:
            return [transcribe(waveform)]

        chunk_texts = []
        for offset in range(0, len(waveform), chunk_size):
            chunk = waveform[offset:offset + chunk_size]
            # Only the trailing chunk can be under one second; it is dropped.
            if len(chunk) < sample_rate:
                continue
            chunk_texts.append(transcribe(chunk))
        # Trailing space after every chunk kept for backward compatibility.
        return ["".join(text + " " for text in chunk_texts)]