import librosa
import numpy as np
import torch

from speechbrain.inference.interfaces import Pretrained


class ASR(Pretrained):
    """Custom SpeechBrain inference interface for a seq2seq ASR model
    with a wav2vec 2.0 encoder."""

    def encode_batch(self, device, wavs, wav_lens=None, normalize=False):
        """Encode a batch of waveforms and decode them with beam search.

        Returns a list of word lists, one hypothesis per input waveform.
        """
        wavs = wavs.to(device)
        if wav_lens is None:
            # Assume unpadded, full-length waveforms when no relative
            # lengths are given.
            wav_lens = torch.ones(wavs.size(0))
        wav_lens = wav_lens.to(device)

        # Forward pass through the wav2vec 2.0 encoder
        encoded_outputs = self.mods.encoder_w2v2(wavs.detach())

        # Beam search over the encoder outputs for seq2seq decoding
        predictions = self.hparams.test_search(encoded_outputs, wav_lens)[0]

        # Drop padding tokens (id 0) and map token ids back to words
        predicted_words = []
        for prediction in predictions:
            prediction = [token for token in prediction if token != 0]
            predicted_words.append(
                self.hparams.tokenizer.decode_ids(prediction).split(" ")
            )

        # Remove spurious decoder-loop repetitions from each hypothesis
        return [self.filter_repetitions(sent, 3) for sent in predicted_words]

    def filter_repetitions(self, seq, max_repetition_length):
        """Collapse overly long repetitions of n-grams in a token sequence.

        For each n-gram length n, at most ``max(max_repetition_length // n, 1)``
        consecutive repetitions are kept.
        """
        seq = list(seq)
        output = []
        max_n = len(seq) // 2
        for n in range(max_n, 0, -1):
            max_repetitions = max(max_repetition_length // n, 1)
            # Don't need to iterate over impossible n values:
            # len(seq) can change a lot during iteration
            if (len(seq) <= n * 2) or (len(seq) <= max_repetition_length):
                continue
            iterator = enumerate(seq)
            # Fill first buffers:
            buffers = [[next(iterator)[1]] for _ in range(n)]
            for seq_index, token in iterator:
                current_buffer = seq_index % n
                if token != buffers[current_buffer][-1]:
                    # No repeat, we can flush some tokens
                    buf_len = sum(map(len, buffers))
                    flush_start = (current_buffer - buf_len) % n
                    # Keep n-1 tokens, but possibly mark some for removal
                    for flush_index in range(buf_len - buf_len % n):
                        if (buf_len - flush_index) > n - 1:
                            to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                        else:
                            to_flush = None
                        # Here, repetitions get removed:
                        if (flush_index // n < max_repetitions) and to_flush is not None:
                            output.append(to_flush)
                        elif (flush_index // n >= max_repetitions) and to_flush is None:
                            output.append(to_flush)
                buffers[current_buffer].append(token)
            # At the end, final flush
            current_buffer += 1
            buf_len = sum(map(len, buffers))
            flush_start = (current_buffer - buf_len) % n
            for flush_index in range(buf_len):
                to_flush = buffers[(flush_index + flush_start) % n].pop(0)
                # Here, repetitions just get removed:
                if flush_index // n < max_repetitions:
                    output.append(to_flush)
            # Rebuild seq: each None marker deletes itself plus one
            # following token, which removes the marked repetitions
            seq = []
            to_delete = 0
            for token in output:
                if token is None:
                    to_delete += 1
                elif to_delete > 0:
                    to_delete -= 1
                else:
                    seq.append(token)
            output = []
        return seq

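    # Illustrative behaviour of filter_repetitions (a hand-traced sketch, not
    # part of the original code): with max_repetition_length=1, only a single
    # consecutive copy of any repeated token survives, e.g.
    #
    #     self.filter_repetitions(["a", "a", "a", "a", "b"], 1)
    #     # -> ["a", "b"]
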
    def classify_file(self, path, device):
        """Transcribe an audio file, splitting long recordings on silence.

        Yields the decoded words for each processed segment.
        """
        # Load the audio file at the model's expected sample rate
        waveform, sr = librosa.load(path, sr=16000)

        # Get audio length in seconds
        audio_length = len(waveform) / sr

        if audio_length >= 20:
            print(f"Audio is too long ({audio_length:.2f} seconds), splitting into segments")
            # Detect non-silent intervals; adjust top_db for sensitivity
            non_silent_intervals = librosa.effects.split(waveform, top_db=20)

            segments = []
            current_segment = []
            current_length = 0
            max_duration = 20 * sr  # Maximum segment duration in samples (20 seconds)

            # Greedily pack non-silent intervals into segments of at most
            # max_duration samples. Note that a single interval longer than
            # max_duration is kept whole, so such a segment can still exceed
            # the limit.
            for start, end in non_silent_intervals:
                segment_part = waveform[start:end]
                # If adding the next part would exceed the max duration,
                # store the current segment and start a new one
                if current_segment and current_length + len(segment_part) > max_duration:
                    segments.append(np.concatenate(current_segment))
                    current_segment = []
                    current_length = 0
                current_segment.append(segment_part)
                current_length += len(segment_part)

            # Append the last segment if it's not empty
            if current_segment:
                segments.append(np.concatenate(current_segment))

            # Process each segment
            for i, segment in enumerate(segments):
                print(f"Processing segment {i + 1}/{len(segments)}, length: {len(segment) / sr:.2f} seconds")
                segment_tensor = torch.tensor(segment).to(device)
                # Fake a batch for the segment
                batch = segment_tensor.unsqueeze(0)
                rel_length = torch.tensor([1.0]).to(device)
                # Pass the segment through the ASR model
                yield self.encode_batch(device, batch, rel_length)
        else:
            # Short enough for a single pass; fake a batch:
            waveform = torch.tensor(waveform).to(device)
            batch = waveform.unsqueeze(0)
            rel_length = torch.tensor([1.0]).to(device)
            yield self.encode_batch(device, batch, rel_length)
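

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original interface. Assumptions:
    # this file is saved as `custom_interface.py` inside a SpeechBrain model
    # directory that also contains a matching hyperparams.yaml; the model and
    # audio paths below are placeholders.
    from speechbrain.inference.interfaces import foreign_class

    asr = foreign_class(
        source="path/to/model_dir",  # placeholder model location
        pymodule_file="custom_interface.py",
        classname="ASR",
    )
    # classify_file is a generator: it yields one hypothesis per segment
    for words in asr.classify_file("path/to/audio.wav", device="cpu"):
        print(words)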