aadnk's picture
Fix pad_timestamps
02253c6
from abc import ABC, abstractmethod
from collections import Counter
from typing import Any, Iterator, List, Dict
from pprint import pprint
# Workaround for https://github.com/tensorflow/tensorflow/issues/48797
try:
import tensorflow as tf
except ModuleNotFoundError:
# Error handling
pass
import torch
import ffmpeg
import numpy as np
from src.utils import format_timestamp
# ---------------------------------------------------------------------------
# Default tuning values for the Silero-based VAD pipeline.
# TODO: Make these configurable?
# ---------------------------------------------------------------------------

# Threshold handed to Silero's get_speech_timestamps()
SPEECH_TRESHOLD = 0.3

# Sections separated by at most this many seconds of silence get merged
MAX_SILENT_PERIOD = 10

# Seconds; do not create segments larger than 2.5 minutes
MAX_MERGE_SIZE = 150

# Seconds of padding: start each detected text segment early / end it late
SEGMENT_PADDING_LEFT = 1
SEGMENT_PADDING_RIGHT = 1

# Whether to attempt to transcribe non-speech
TRANSCRIBE_NON_SPEECH = False

# Minimum size (seconds) of segments to process
MIN_SEGMENT_DURATION = 1

# Length (seconds) of audio fed to the VAD in one go: 60 minutes
VAD_MAX_PROCESSING_CHUNK = 3600
class AbstractTranscription(ABC):
    """
    Base class for VAD-driven transcription.

    Subclasses decide which sections of an audio file should be transcribed by implementing
    `get_transcribe_timestamps`. This class pads and merges those sections, optionally expands
    them to also cover the gaps between them, runs Whisper on every section and shifts the
    resulting segments back onto the timeline of the full file.
    """

    def __init__(self, segment_padding_left: float = None, segment_padding_right: float = None,
                 max_silent_period: float = None, max_merge_size: float = None,
                 transcribe_non_speech: bool = False):
        # Sample rate (Hz) that audio is decoded/resampled to by get_audio_segment
        self.sampling_rate = 16000
        # Seconds added before / after each section (None disables padding on that side)
        self.segment_padding_left = segment_padding_left
        self.segment_padding_right = segment_padding_right
        # Sections at most this many seconds apart are merged (None disables merging)
        self.max_silent_period = max_silent_period
        # Merging stops once a merged section would exceed this many seconds (None = unlimited)
        self.max_merge_size = max_merge_size
        # Whether to also transcribe the gaps between the detected sections
        self.transcribe_non_speech = transcribe_non_speech

    def get_audio_segment(self, str, start_time: str = None, duration: str = None):
        """
        Decode (part of) an audio file as a mono waveform at `self.sampling_rate`.

        Parameters
        ----------
        str: str
            Path of the audio file. NOTE: this parameter shadows the builtin `str`; the name is
            kept only for backward compatibility with keyword callers.
        start_time: str
            Start offset in FFMPEG time syntax, or None to start at the beginning.
        duration: str
            Length to decode in FFMPEG time syntax, or None to read until the end.
        """
        return load_audio(str, self.sampling_rate, start_time, duration)

    @abstractmethod
    def get_transcribe_timestamps(self, audio: str):
        """
        Get the start and end timestamps of the sections that should be transcribed by this VAD method.

        Parameters
        ----------
        audio: str
            The audio file.

        Returns
        -------
        A list of dicts with 'start' and 'end' keys, in fractional seconds.
        """
        return

    def transcribe(self, audio: str, whisperCallable):
        """
        Transcribe the given audio file.

        Parameters
        ----------
        audio: str
            The audio file.
        whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor]], dict[str, Union[dict, Any]]]
            The callback that is used to invoke Whisper on an audio file/buffer. Its result must
            contain the keys 'text', 'segments' and 'language'.

        Returns
        -------
        A dict with the concatenated 'text' of all sections, the Whisper 'segments' shifted onto
        the timeline of the full file, and the most frequently detected 'language' ("" when no
        section was transcribed).
        """
        # Get speech timestamps from the full audio file
        seconds_timestamps = self.get_transcribe_timestamps(audio)

        padded = self.pad_timestamps(seconds_timestamps, self.segment_padding_left, self.segment_padding_right)
        merged = self.merge_timestamps(padded, self.max_silent_period, self.max_merge_size)

        print("Timestamps:")
        pprint(merged)

        if self.transcribe_non_speech:
            max_audio_duration = get_audio_duration(audio)

            # Expand segments to include the gaps between them
            merged = self.expand_gaps(merged, total_duration=max_audio_duration)

            print("Transcribing non-speech:")
            pprint(merged)

        result = {
            'text': "",
            'segments': [],
            'language': ""
        }
        language_counter = Counter()

        # For each time segment, run whisper
        for segment in merged:
            segment_start = segment['start']
            segment_end = segment['end']
            segment_expand_amount = segment.get('expand_amount', 0)

            segment_duration = segment_end - segment_start

            # Skip sections too short to be worth sending to Whisper
            if segment_duration < MIN_SEGMENT_DURATION:
                continue

            segment_audio = self.get_audio_segment(audio, start_time=str(segment_start), duration=str(segment_duration))

            print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ", segment_duration, "expanded: ", segment_expand_amount)
            segment_result = whisperCallable(segment_audio)

            # Whisper reports times relative to the decoded section: drop anything past the end
            # of the section and shift the rest onto the timeline of the full file
            adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)

            # Append to output
            result['text'] += segment_result['text']
            result['segments'].extend(adjusted_segments)

            # Increment detected language
            language_counter[segment_result['language']] += 1

        if language_counter:
            result['language'] = language_counter.most_common(1)[0][0]

        return result

    def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
        """
        Return the segments with explicit gap entries ({'start', 'end', 'gap': True}) inserted
        wherever there is positive empty space before a segment (and, if `total_duration` is
        given, after the last one).

        Parameters
        ----------
        segments: Iterator[dict]
            Segments with 'start' and 'end' keys, in order.
        min_gap_length: float
            Gaps shorter than this are not emitted; None emits every positive gap.
        total_duration: float
            Total length of the audio, or None to skip the trailing gap.
        """
        result = []
        last_end_time = 0

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            # Only positive space is a gap; overlapping segments produce no gap entry
            delta = segment_start - last_end_time
            if delta > 0 and (min_gap_length is None or delta >= min_gap_length):
                result.append({ 'start': last_end_time, 'end': segment_start, 'gap': True })

            last_end_time = segment_end
            result.append(segment)

        # Also include the trailing gap up to the total duration, if specified
        if total_duration is not None and last_end_time < total_duration:
            delta = total_duration - last_end_time
            if min_gap_length is None or delta >= min_gap_length:
                result.append({ 'start': last_end_time, 'end': total_duration, 'gap': True })

        return result

    def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
        """
        Expand the end time of each segment to the start of the next segment, so that together
        the segments cover the whole audio.

        A gap entry ({'start': 0, ..., 'gap': True}) is prepended when the first segment does not
        start at 0, expanded segments are copies carrying an 'expand_amount' key (seconds added),
        and the last segment is extended to `total_duration` when that is given. The input list
        and its dicts are never modified.
        """
        result = []

        if not segments:
            return result

        # Add gap at the beginning if needed
        if segments[0]['start'] > 0:
            result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True })

        for current_segment, next_segment in zip(segments, segments[1:]):
            delta = next_segment['start'] - current_segment['end']

            # Expand if the gap actually exists
            if delta >= 0:
                current_segment = current_segment.copy()
                current_segment['expand_amount'] = delta
                current_segment['end'] = next_segment['start']

            result.append(current_segment)

        # The loop above only emits segments that have a successor; add the final one explicitly
        result.append(segments[-1])

        # Also extend the final segment to the total duration, if specified
        if total_duration is not None:
            last_segment = result[-1]

            if last_segment['end'] < total_duration:
                last_segment = last_segment.copy()
                last_segment['end'] = total_duration
                result[-1] = last_segment

        return result

    def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
        """
        Shift segments by `adjust_seconds`, returning copies.

        When `max_source_time` is given, segments starting after it (before shifting) are dropped
        and the remaining end times are clamped to it.
        """
        result = []

        for segment in segments:
            segment_start = float(segment['start'])
            segment_end = float(segment['end'])

            # Filter out / clamp segments outside the source window
            if max_source_time is not None:
                if segment_start > max_source_time:
                    continue
                segment_end = min(max_source_time, segment_end)

            new_segment = segment.copy()

            # Add to start and end
            new_segment['start'] = segment_start + adjust_seconds
            new_segment['end'] = segment_end + adjust_seconds
            result.append(new_segment)

        return result

    def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
        """
        Return new {'start', 'end'} entries with each timestamp padded on the left and/or right.

        Left padding never reaches before 0 or before the (padded) end of the previous entry;
        right padding never reaches past the start of the next entry. A padding of None disables
        that side; when both paddings are 0 the input list is returned unchanged.
        """
        if padding_left == 0 and padding_right == 0:
            return timestamps

        result = []
        prev_entry = None

        for i, curr_entry in enumerate(timestamps):
            next_entry = timestamps[i + 1] if i + 1 < len(timestamps) else None

            segment_start = curr_entry['start']
            segment_end = curr_entry['end']

            if padding_left is not None:
                segment_start = max(prev_entry['end'] if prev_entry else 0, segment_start - padding_left)
            if padding_right is not None:
                segment_end = segment_end + padding_right

                # Do not pad past the next segment
                if next_entry is not None:
                    segment_end = min(next_entry['start'], segment_end)

            new_entry = { 'start': segment_start, 'end': segment_end }
            prev_entry = new_entry
            result.append(new_entry)

        return result

    def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_merge_gap: float, max_merge_size: float):
        """
        Merge consecutive timestamps that are at most `max_merge_gap` seconds apart, as long as
        the merged entry would not exceed `max_merge_size` seconds (None = unlimited).

        A single input entry that is already longer than `max_merge_size` is kept as-is (it is
        not split). The input dicts are never modified. When `max_merge_gap` is None the input
        list is returned unchanged.
        """
        if max_merge_gap is None:
            return timestamps

        result = []
        current_entry = None

        for entry in timestamps:
            if current_entry is None:
                # Work on a copy so merging never mutates the caller's dicts
                current_entry = entry.copy()
                continue

            # Get distance to the previous entry
            distance = entry['start'] - current_entry['end']
            # End and size the current entry would have if this entry were merged into it
            merged_end = max(current_entry['end'], entry['end'])
            merged_size = merged_end - current_entry['start']

            if distance <= max_merge_gap and (max_merge_size is None or merged_size <= max_merge_size):
                # Merge
                current_entry['end'] = merged_end
            else:
                # Output current entry
                result.append(current_entry)
                current_entry = entry.copy()

        # Add final entry
        if current_entry is not None:
            result.append(current_entry)

        return result

    def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
        """Return new {'start', 'end'} entries with both times multiplied by `factor`."""
        return [
            { 'start': entry['start'] * factor, 'end': entry['end'] * factor }
            for entry in timestamps
        ]
class VadSileroTranscription(AbstractTranscription):
    """
    Transcription whose sections come from the Silero voice activity detector,
    loaded through `torch.hub` (or shared with another instance via `copy`).
    """

    def __init__(self, segment_padding_left=SEGMENT_PADDING_LEFT, segment_padding_right=SEGMENT_PADDING_RIGHT,
                 max_silent_period=MAX_SILENT_PERIOD, max_merge_size=MAX_MERGE_SIZE, transcribe_non_speech: bool = False,
                 copy = None):
        super().__init__(segment_padding_left=segment_padding_left, segment_padding_right=segment_padding_right,
                         max_silent_period=max_silent_period, max_merge_size=max_merge_size, transcribe_non_speech=transcribe_non_speech)

        if copy:
            # Reuse the model and helper already loaded by another instance instead of loading them again
            self.model = copy.model
            self.get_speech_timestamps = copy.get_speech_timestamps
        else:
            self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
            (self.get_speech_timestamps, _, _, _, _) = utils

    def get_transcribe_timestamps(self, audio: str):
        """
        Run Silero VAD over `audio` and return the detected speech sections as a list of
        {'start', 'end'} dicts in seconds on the timeline of the full file.
        """
        audio_duration = get_audio_duration(audio)
        result = []

        # Divide processing of audio into chunks of at most VAD_MAX_PROCESSING_CHUNK seconds
        chunk_start = 0.0

        while chunk_start < audio_duration:
            chunk_duration = min(audio_duration - chunk_start, VAD_MAX_PROCESSING_CHUNK)

            print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
            wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))

            sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
            # Convert sample indices to seconds
            seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
            # These times are relative to the chunk: clamp them to the chunk length (adjust_timestamp
            # compares max_source_time against the unshifted times), then shift them onto the full file
            adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_duration)

            result.extend(adjusted)
            chunk_start += chunk_duration

        return result
# A very simple VAD that just marks every N seconds as speech
class VadPeriodicTranscription(AbstractTranscription):
    """
    Transcription that splits the audio into consecutive windows of `periodic_duration`
    seconds, treating every window as speech.
    """

    def __init__(self, periodic_duration: int):
        super().__init__()
        self.periodic_duration = periodic_duration

    def get_transcribe_timestamps(self, audio: str):
        """
        Return consecutive {'start', 'end'} windows (seconds) covering `audio`.

        Windows shorter than 1 second (only possible for the final one) are omitted.

        Raises
        ------
        ValueError
            If `periodic_duration` is not positive (the loop would otherwise never terminate).
        """
        if self.periodic_duration <= 0:
            raise ValueError("periodic_duration must be positive, got {}".format(self.periodic_duration))

        # Get duration in seconds
        audio_duration = get_audio_duration(audio)
        result = []

        # Generate a timestamp every N seconds
        start_timestamp = 0

        while start_timestamp < audio_duration:
            end_timestamp = min(start_timestamp + self.periodic_duration, audio_duration)
            segment_duration = end_timestamp - start_timestamp

            # Minimum duration is 1 second
            if segment_duration >= 1:
                result.append({ 'start': start_timestamp, 'end': end_timestamp })

            start_timestamp = end_timestamp

        return result
def get_audio_duration(file: str):
    """
    Return the duration of `file` in seconds, as reported in the 'format'
    section of `ffmpeg.probe`.
    """
    format_info = ffmpeg.probe(file)["format"]
    duration_text = format_info["duration"]
    return float(duration_text)
def load_audio(file: str, sample_rate: int = 16000,
               start_time: str = None, duration: str = None):
    """
    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
    file: str
        The audio file to open
    sample_rate: int
        The sample rate to resample the audio to
    start_time: str
        The start time, using the standard FFMPEG time duration syntax, or None to disable.
    duration: str
        The duration, using the standard FFMPEG time duration syntax, or None to disable.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype, scaled to [-1, 1).

    Raises
    ------
    RuntimeError
        If ffmpeg fails; the message contains ffmpeg's stderr output.
    """
    input_args = {'threads': 0}
    if start_time is not None:
        input_args['ss'] = start_time
    if duration is not None:
        input_args['t'] = duration

    try:
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
        out, _ = (
            ffmpeg.input(file, **input_args)
            .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
        )
    except ffmpeg.Error as e:
        # errors="replace": a non-UTF-8 byte in ffmpeg's output must not mask the real failure
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode(errors='replace')}") from e

    # Convert signed 16-bit PCM samples to float32
    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0