pragyaa committed on
Commit
0874a62
1 Parent(s): b1c3d0e

Upload 10 files

Files changed (10)
  1. __init__.py +0 -0
  2. download.py +79 -0
  3. modelCache.py +17 -0
  4. segments.py +55 -0
  5. source.py +70 -0
  6. utils-original.py +115 -0
  7. utils.py +129 -0
  8. vad.py +537 -0
  9. vadParallel.py +255 -0
  10. whisperContainer.py +127 -0
__init__.py ADDED
File without changes
download.py ADDED
@@ -0,0 +1,79 @@
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
+ class FilenameCollectorPP(PostProcessor):
9
+ def __init__(self):
10
+ super(FilenameCollectorPP, self).__init__(None)
11
+ self.filenames = []
12
+
13
+ def run(self, information):
14
+ self.filenames.append(information["filepath"])
15
+ return [], information
16
+
17
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
18
+ try:
19
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
20
+ except yt_dlp.utils.DownloadError as e:
21
+ # If the file name was too long (OS error 36), retry with a truncated output template
22
+ if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
23
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
24
+ raise  # re-raise any other download error instead of silently returning None
25
+
26
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
27
+ # Create a temporary directory to store the downloaded files
28
+ if destinationDirectory is None:
29
+ destinationDirectory = mkdtemp()
30
+
31
+ ydl_opts = {
32
+ "format": "bestaudio/best",
33
+ 'outtmpl':destinationDirectory+'1.wav',
34
+ 'paths': {
35
+ 'home': destinationDirectory
36
+ }
37
+ }
38
+ if (playlistItems):
39
+ ydl_opts['playlist_items'] = playlistItems
40
+
41
+ # Add output template if specified
42
+ if outputTemplate:
43
+ ydl_opts['outtmpl'] = outputTemplate
44
+
45
+ filename_collector = FilenameCollectorPP()
46
+
47
+ with YoutubeDL(ydl_opts) as ydl:
48
+ if maxDuration and maxDuration > 0:
49
+ info = ydl.extract_info(url, download=False)
50
+ entries = "entries" in info and info["entries"] or [info]
51
+
52
+ total_duration = 0
53
+
54
+ # Compute total duration
55
+ for entry in entries:
56
+ total_duration += float(entry["duration"])
57
+
58
+ if total_duration >= maxDuration:
59
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=maxDuration, message="Video is too long")
60
+
61
+ ydl.add_post_processor(filename_collector)
62
+ ydl.download([url])
63
+
64
+ if len(filename_collector.filenames) <= 0:
65
+ raise Exception("Cannot download " + url)
66
+
67
+ result = []
68
+
69
+ for filename in filename_collector.filenames:
70
+ result.append(filename)
71
+ print("Downloaded " + filename)
72
+
73
+ return result
74
+
75
+ class ExceededMaximumDuration(Exception):
76
+ def __init__(self, videoDuration, maxDuration, message):
77
+ self.videoDuration = videoDuration
78
+ self.maxDuration = maxDuration
79
+ super().__init__(message)
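A rough usage sketch for download_url, assuming these modules are importable as the src package (as the imports in source.py suggest) and that yt-dlp and ffmpeg are installed; the URL below is a placeholder:

    from src.download import download_url, ExceededMaximumDuration

    try:
        # Download the best audio track of the first playlist item into a temporary directory
        files = download_url("https://www.youtube.com/watch?v=XXXXXXXXXXX", maxDuration=600)
        print(files)  # e.g. ['/tmp/tmpabc123/Some title.webm']
    except ExceededMaximumDuration as e:
        print("Too long: {}s (limit {}s)".format(e.videoDuration, e.maxDuration))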
modelCache.py ADDED
@@ -0,0 +1,17 @@
1
+ class ModelCache:
2
+ def __init__(self):
3
+ self._cache = dict()
4
+
5
+ def get(self, model_key: str, model_factory):
6
+ result = self._cache.get(model_key)
7
+
8
+ if result is None:
9
+ result = model_factory()
10
+ self._cache[model_key] = result
11
+ return result
12
+
13
+ def clear(self):
14
+ self._cache.clear()
15
+
16
+ # A global cache of models. This is mainly used by the daemon processes to avoid loading the same model multiple times.
17
+ GLOBAL_MODEL_CACHE = ModelCache()
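The cache keys a model factory by name, so repeated lookups return the same instance. A small sketch of how GLOBAL_MODEL_CACHE is meant to be used; the key name and factory below are made up for illustration:

    from src.modelCache import GLOBAL_MODEL_CACHE

    def expensive_factory():
        print("Loading model...")  # only runs on the first lookup
        return object()            # stand-in for a real model

    model_a = GLOBAL_MODEL_CACHE.get("example-model", expensive_factory)
    model_b = GLOBAL_MODEL_CACHE.get("example-model", expensive_factory)
    assert model_a is model_b      # the second call returns the cached instance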
segments.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+ if max_merge_size is None:
11
+ return timestamps
12
+
13
+ if padding_left is None:
14
+ padding_left = 0
15
+ if padding_right is None:
16
+ padding_right = 0
17
+
18
+ processed_time = 0
19
+ current_segment = None
20
+
21
+ for i in range(len(timestamps)):
22
+ next_segment = timestamps[i]
23
+
24
+ delta = next_segment['start'] - processed_time
25
+
26
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
27
+ if current_segment is None or (merge_window is not None and delta > merge_window) \
28
+ or next_segment['end'] - current_segment['start'] > max_merge_size:
29
+ # Finish the current segment
30
+ if current_segment is not None:
31
+ # Add right padding
32
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
33
+ current_segment['end'] += finish_padding
34
+ delta -= finish_padding
35
+
36
+ result.append(current_segment)
37
+
38
+ # Start a new segment
39
+ current_segment = copy.deepcopy(next_segment)
40
+
41
+ # Pad the segment
42
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
43
+ processed_time = current_segment['end']
44
+
45
+ else:
46
+ # Merge the segment
47
+ current_segment['end'] = next_segment['end']
48
+ processed_time = current_segment['end']
49
+
50
+ # Add the last segment
51
+ if current_segment is not None:
52
+ current_segment['end'] += padding_right
53
+ result.append(current_segment)
54
+
55
+ return result
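A worked example of merge_timestamps: nearby segments are merged while the gap stays within merge_window and the merged size stays under max_merge_size, and each emitted segment is padded on both sides. The input values are illustrative:

    from src.segments import merge_timestamps

    timestamps = [
        {'start': 1.0, 'end': 3.0},
        {'start': 4.0, 'end': 6.0},    # within the 5 s merge window, so merged with the previous segment
        {'start': 20.0, 'end': 22.0},  # far away, so it starts a new padded segment
    ]
    merged = merge_timestamps(timestamps, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
    print(merged)  # [{'start': 0.0, 'end': 7.0}, {'start': 19.0, 'end': 23.0}]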
source.py ADDED
@@ -0,0 +1,70 @@
1
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourselves
2
+ import os
3
+ import pathlib
4
+ from typing import List
5
+ import zipfile
6
+
7
+ import ffmpeg
8
+ from more_itertools import unzip
9
+
10
+ from src.download import ExceededMaximumDuration, download_url
11
+
12
+ MAX_FILE_PREFIX_LENGTH = 17
13
+
14
+ class AudioSource:
15
+ def __init__(self, source_path, source_name = None):
16
+ self.source_path = source_path
17
+ self.source_name = source_name
18
+
19
+ # Load source name if not provided
20
+ if (self.source_name is None):
21
+ file_path = pathlib.Path(self.source_path)
22
+ self.source_name = file_path.name
23
+
24
+ def get_full_name(self):
25
+ return self.source_name
26
+
27
+ def get_short_name(self, max_length: int = MAX_FILE_PREFIX_LENGTH):
28
+ file_path = pathlib.Path(self.source_name)
29
+ short_name = file_path.stem[:max_length] + file_path.suffix
30
+
31
+ return short_name
32
+
33
+ def __str__(self) -> str:
34
+ return self.source_path
35
+
36
+ class AudioSourceCollection:
37
+ def __init__(self, sources: List[AudioSource]):
38
+ self.sources = sources
39
+
40
+ def __iter__(self):
41
+ return iter(self.sources)
42
+
43
+ def get_audio_source_collection(urlData: str, multipleFiles: List, microphoneData: str, input_audio_max_duration: float = -1) -> List[AudioSource]:
44
+ output: List[AudioSource] = []
45
+
46
+ if urlData:
47
+ # Download from YouTube. This could also be a playlist or a channel.
48
+ output.extend([ AudioSource(x) for x in download_url(urlData, input_audio_max_duration, playlistItems=None) ])
49
+ else:
50
+ # Add input files
51
+ if (multipleFiles is not None):
52
+ output.extend([ AudioSource(x.name) for x in multipleFiles ])
53
+ if (microphoneData is not None):
54
+ output.append(AudioSource(microphoneData))
55
+
56
+ total_duration = 0
57
+
58
+ # Calculate total audio length. We do this even if input_audio_max_duration
59
+ # is disabled to ensure that all the audio files are valid.
60
+ for source in output:
61
+ audioDuration = ffmpeg.probe(source.source_path)["format"]["duration"]
62
+ total_duration += float(audioDuration)
63
+
64
+ # Ensure the total duration of the audio is not too long
65
+ if input_audio_max_duration > 0:
66
+ if float(total_duration) > input_audio_max_duration:
67
+ raise ExceededMaximumDuration(videoDuration=total_duration, maxDuration=input_audio_max_duration, message="Video(s) is too long")
68
+
69
+ # Return a list of audio sources
70
+ return output
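A sketch of calling get_audio_source_collection with local files. Gradio normally supplies upload objects with a .name attribute, so a tiny wrapper stands in for that here; "example.mp3" is a placeholder for an existing file, and ffmpeg must be installed because every source is probed for its duration:

    from src.source import get_audio_source_collection

    class FakeUpload:
        def __init__(self, name):
            self.name = name

    sources = get_audio_source_collection(urlData=None, multipleFiles=[FakeUpload("example.mp3")],
                                          microphoneData=None, input_audio_max_duration=-1)
    for source in sources:
        print(source.get_full_name(), "->", source.get_short_name())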
utils-original.py ADDED
@@ -0,0 +1,115 @@
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+
8
+
9
+ def exact_div(x, y):
10
+ assert x % y == 0
11
+ return x // y
12
+
13
+
14
+ def str2bool(string):
15
+ str2val = {"True": True, "False": False}
16
+ if string in str2val:
17
+ return str2val[string]
18
+ else:
19
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
20
+
21
+
22
+ def optional_int(string):
23
+ return None if string == "None" else int(string)
24
+
25
+
26
+ def optional_float(string):
27
+ return None if string == "None" else float(string)
28
+
29
+
30
+ def compression_ratio(text) -> float:
31
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
32
+
33
+
34
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
35
+ assert seconds >= 0, "non-negative timestamp expected"
36
+ milliseconds = round(seconds * 1000.0)
37
+
38
+ hours = milliseconds // 3_600_000
39
+ milliseconds -= hours * 3_600_000
40
+
41
+ minutes = milliseconds // 60_000
42
+ milliseconds -= minutes * 60_000
43
+
44
+ seconds = milliseconds // 1_000
45
+ milliseconds -= seconds * 1_000
46
+
47
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
48
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
49
+
50
+
51
+ def write_txt(transcript: Iterator[dict], file: TextIO):
52
+ for segment in transcript:
53
+ print(segment['text'].strip(), file=file, flush=True)
54
+
55
+
56
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
+ print("WEBVTT\n", file=file)
58
+ for segment in transcript:
59
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
60
+
61
+ print(
62
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
63
+ f"{text}\n",
64
+ file=file,
65
+ flush=True,
66
+ )
67
+
68
+
69
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
70
+ """
71
+ Write a transcript to a file in SRT format.
72
+ Example usage:
73
+ from pathlib import Path
74
+ from whisper.utils import write_srt
75
+ result = transcribe(model, audio_path, temperature=temperature, **args)
76
+ # save SRT
77
+ audio_basename = Path(audio_path).stem
78
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
79
+ write_srt(result["segments"], file=srt)
80
+ """
81
+ for i, segment in enumerate(transcript, start=1):
82
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
+
84
+ # write srt lines
85
+ print(
86
+ f"{i}\n"
87
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
88
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
89
+ f"{text}\n",
90
+ file=file,
91
+ flush=True,
92
+ )
93
+
94
+ def process_text(text: str, maxLineWidth=None):
95
+ if (maxLineWidth is None or maxLineWidth < 0):
96
+ return text
97
+
98
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
99
+ return '\n'.join(lines)
100
+
101
+ def slugify(value, allow_unicode=False):
102
+ """
103
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
104
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
105
+ dashes to single dashes. Remove characters that aren't alphanumerics,
106
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
107
+ trailing whitespace, dashes, and underscores.
108
+ """
109
+ value = str(value)
110
+ if allow_unicode:
111
+ value = unicodedata.normalize('NFKC', value)
112
+ else:
113
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
114
+ value = re.sub(r'[^\w\s-]', '', value.lower())
115
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
utils.py ADDED
@@ -0,0 +1,129 @@
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+ import audioread
8
+
9
+
10
+ def exact_div(x, y):
11
+ assert x % y == 0
12
+ return x // y
13
+
14
+ def duration_detector(path):
15
+ length = 0
16
+ with audioread.audio_open(path) as f:
17
+ length = int(f.duration)
18
+
19
+ hours = length // 3600 # calculate in hours
20
+ length %= 3600
21
+ mins = length // 60 # calculate in minutes
22
+ length %= 60
23
+ seconds = length # calculate in seconds
24
+ print('Total duration of {}: {}:{}:{}'.format(path, hours, mins, seconds))
25
+ #return "{}:{}:{}".format(hours, mins, seconds)
26
+ return hours,mins,seconds
27
+
28
+ def str2bool(string):
29
+ str2val = {"True": True, "False": False}
30
+ if string in str2val:
31
+ return str2val[string]
32
+ else:
33
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
34
+
35
+
36
+ def optional_int(string):
37
+ return None if string == "None" else int(string)
38
+
39
+
40
+ def optional_float(string):
41
+ return None if string == "None" else float(string)
42
+
43
+
44
+ def compression_ratio(text) -> float:
45
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
46
+
47
+
48
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
49
+ assert seconds >= 0, "non-negative timestamp expected"
50
+ milliseconds = round(seconds * 1000.0)
51
+
52
+ hours = milliseconds // 3_600_000
53
+ milliseconds -= hours * 3_600_000
54
+
55
+ minutes = milliseconds // 60_000
56
+ milliseconds -= minutes * 60_000
57
+
58
+ seconds = milliseconds // 1_000
59
+ milliseconds -= seconds * 1_000
60
+
61
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
62
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
63
+
64
+
65
+ def write_txt(transcript: Iterator[dict], file: TextIO):
66
+ for segment in transcript:
67
+ print(segment['text'].strip(), file=file, flush=True)
68
+
69
+
70
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
71
+ print("WEBVTT\n", file=file)
72
+ for segment in transcript:
73
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
74
+
75
+ print(
76
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
77
+ f"{text}\n",
78
+ file=file,
79
+ flush=True,
80
+ )
81
+
82
+
83
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
84
+ """
85
+ Write a transcript to a file in SRT format.
86
+ Example usage:
87
+ from pathlib import Path
88
+ from whisper.utils import write_srt
89
+ result = transcribe(model, audio_path, temperature=temperature, **args)
90
+ # save SRT
91
+ audio_basename = Path(audio_path).stem
92
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
93
+ write_srt(result["segments"], file=srt)
94
+ """
95
+ for i, segment in enumerate(transcript, start=1):
96
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
97
+
98
+ # write srt lines
99
+ print(
100
+ f"{i}\n"
101
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
102
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
103
+ f"{text}\n",
104
+ file=file,
105
+ flush=True,
106
+ )
107
+
108
+ def process_text(text: str, maxLineWidth=None):
109
+ if (maxLineWidth is None or maxLineWidth < 0):
110
+ return text
111
+
112
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
113
+ return '\n'.join(lines)
114
+
115
+ def slugify(value, allow_unicode=False):
116
+ """
117
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
118
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
119
+ dashes to single dashes. Remove characters that aren't alphanumerics,
120
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
121
+ trailing whitespace, dashes, and underscores.
122
+ """
123
+ value = str(value)
124
+ if allow_unicode:
125
+ value = unicodedata.normalize('NFKC', value)
126
+ else:
127
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
128
+ value = re.sub(r'[^\w\s-]', '', value.lower())
129
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
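A few quick checks of the helpers above; write_srt accepts any text stream, so an in-memory buffer works for illustration, and the example strings are arbitrary:

    from io import StringIO
    from src.utils import format_timestamp, slugify, write_srt

    print(format_timestamp(3661.5, always_include_hours=True, fractionalSeperator=','))  # 01:01:01,500
    print(slugify("Mötley Crüe Live!"))  # motley-crue-live

    buffer = StringIO()
    write_srt([{'start': 0.0, 'end': 2.5, 'text': ' Hello world'}], file=buffer, maxLineWidth=40)
    print(buffer.getvalue())  # index, timestamp range, then the wrapped text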
vad.py ADDED
@@ -0,0 +1,537 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+ import time
4
+
5
+ from typing import Any, Deque, Iterator, List, Dict
6
+
7
+ from pprint import pprint
8
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
9
+
10
+ from src.segments import merge_timestamps
11
+ from src.whisperContainer import WhisperCallback
12
+
13
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
14
+ try:
15
+ import tensorflow as tf
16
+ except ModuleNotFoundError:
17
+ # Error handling
18
+ pass
19
+
20
+ import torch
21
+
22
+ import ffmpeg
23
+ import numpy as np
24
+
25
+ from src.utils import format_timestamp
26
+ from enum import Enum
27
+
28
+ class NonSpeechStrategy(Enum):
29
+ """
30
+ Ignore non-speech segments.
31
+ """
32
+ SKIP = 1
33
+ """
34
+ Just treat non-speech segments as speech.
35
+ """
36
+ CREATE_SEGMENT = 2
37
+ """
38
+ Expand speech segments into subsequent non-speech segments.
39
+ """
40
+ EXPAND_SEGMENT = 3
41
+
42
+ # Defaults for Silero
43
+ SPEECH_TRESHOLD = 0.3
44
+
45
+ # Minimum size of segments to process
46
+ MIN_SEGMENT_DURATION = 1
47
+
48
+ # The maximum time for texts from old segments to be used in the next segment
49
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
50
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
51
+
52
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
53
+
54
+ class TranscriptionConfig(ABC):
55
+ def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
56
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
57
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
58
+ self.non_speech_strategy = non_speech_strategy
59
+ self.segment_padding_left = segment_padding_left
60
+ self.segment_padding_right = segment_padding_right
61
+ self.max_silent_period = max_silent_period
62
+ self.max_merge_size = max_merge_size
63
+ self.max_prompt_window = max_prompt_window
64
+ self.initial_segment_index = initial_segment_index
65
+
66
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
67
+ def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
68
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
69
+ max_merge_size: float = None, max_prompt_window: float = None, initial_segment_index = -1):
70
+ super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window, initial_segment_index)
71
+ self.periodic_duration = periodic_duration
72
+
73
+ class AbstractTranscription(ABC):
74
+ def __init__(self, sampling_rate: int = 16000):
75
+ self.sampling_rate = sampling_rate
76
+
77
+ def get_audio_segment(self, audio: str, start_time: str = None, duration: str = None):
78
+ return load_audio(audio, self.sampling_rate, start_time, duration)
79
+
80
+ def is_transcribe_timestamps_fast(self):
81
+ """
82
+ Determine if get_transcribe_timestamps is fast enough to not need parallelization.
83
+ """
84
+ return False
85
+
86
+ @abstractmethod
87
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
88
+ """
89
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method.
90
+
91
+ Parameters
92
+ ----------
93
+ audio: str
94
+ The audio file.
95
+ config: TranscriptionConfig
96
+ The transcription configuration.
97
+
98
+ Returns
99
+ -------
100
+ A list of start and end timestamps, in fractional seconds.
101
+ """
102
+ return
103
+
104
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: TranscriptionConfig, total_duration: float):
105
+ """
106
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method,
107
+ after merging the given segments using the specified configuration.
108
+
109
+ Parameters
110
+ ----------
111
+ timestamps: List[Dict[str, Any]]
112
+ The start/end timestamps (segments) to merge.
113
+ config: TranscriptionConfig
114
+ The transcription configuration.
115
+
116
+ Returns
117
+ -------
118
+ A list of start and end timestamps, in fractional seconds.
119
+ """
120
+ merged = merge_timestamps(timestamps, config.max_silent_period, config.max_merge_size,
121
+ config.segment_padding_left, config.segment_padding_right)
122
+
123
+ if config.non_speech_strategy != NonSpeechStrategy.SKIP:
124
+ # Expand segments to include the gaps between them
125
+ if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
126
+ # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
127
+ merged = self.fill_gaps(merged, total_duration=total_duration, max_expand_size=config.max_merge_size)
128
+ elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
129
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
130
+ merged = self.expand_gaps(merged, total_duration=total_duration)
131
+ else:
132
+ raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
133
+
134
+ print("Transcribing non-speech:")
135
+ pprint(merged)
136
+ return merged
137
+
138
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig):
139
+ """
140
+ Transcribe the given audio file.
141
+
142
+ Parameters
143
+ ----------
144
+ audio: str
145
+ The audio file.
146
+ whisperCallable: WhisperCallback
147
+ A callback object to call to transcribe each segment.
148
+
149
+ Returns
150
+ -------
151
+ A dictionary containing the transcribed text, the segments, and the detected language.
152
+ """
153
+
154
+ max_audio_duration = get_audio_duration(audio)
155
+ timestamp_segments = self.get_transcribe_timestamps(audio, config, 0, max_audio_duration)
156
+
157
+ # Get speech timestamps from full audio file
158
+ merged = self.get_merged_timestamps(timestamp_segments, config, max_audio_duration)
159
+
160
+ # A deque of transcribed segments that is passed to the next segment as a prompt
161
+ prompt_window = deque()
162
+
163
+ print("Processing timestamps:")
164
+ pprint(merged)
165
+
166
+ result = {
167
+ 'text': "",
168
+ 'segments': [],
169
+ 'language': ""
170
+ }
171
+ languageCounter = Counter()
172
+ detected_language = None
173
+
174
+ segment_index = config.initial_segment_index
175
+
176
+ # For each time segment, run whisper
177
+ for segment in merged:
178
+ segment_index += 1
179
+ segment_start = segment['start']
180
+ segment_end = segment['end']
181
+ segment_expand_amount = segment.get('expand_amount', 0)
182
+ segment_gap = segment.get('gap', False)
183
+
184
+ segment_duration = segment_end - segment_start
185
+
186
+ if segment_duration < MIN_SEGMENT_DURATION:
187
+ continue
188
+
189
+ # Audio to run on Whisper
190
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
191
+ # Previous segments to use as a prompt
192
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
193
+
194
+ # Detected language
195
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
196
+
197
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
198
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
199
+ segment_result = whisperCallable.invoke(segment_audio, segment_index, segment_prompt, detected_language)
200
+
201
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
202
+
203
+ # Propagate expand amount to the segments
204
+ if (segment_expand_amount > 0):
205
+ segment_without_expansion = segment_duration - segment_expand_amount
206
+
207
+ for adjusted_segment in adjusted_segments:
208
+ adjusted_segment_end = adjusted_segment['end']
209
+
210
+ # Add expand amount if the segment got expanded
211
+ if (adjusted_segment_end > segment_without_expansion):
212
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
213
+
214
+ # Append to output
215
+ result['text'] += segment_result['text']
216
+ result['segments'].extend(adjusted_segments)
217
+
218
+ # Increment detected language
219
+ if not segment_gap:
220
+ languageCounter[segment_result['language']] += 1
221
+
222
+ # Update prompt window
223
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
224
+
225
+ if detected_language is not None:
226
+ result['language'] = detected_language
227
+
228
+ return result
229
+
230
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
231
+ if (config.max_prompt_window is not None and config.max_prompt_window > 0):
232
+ # Add segments to the current prompt window (unless it is a speech gap)
233
+ if not segment_gap:
234
+ for segment in adjusted_segments:
235
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
236
+ prompt_window.append(segment)
237
+
238
+ while (len(prompt_window) > 0):
239
+ first_end_time = prompt_window[0].get('end', 0)
240
+ # Time expanded in the segments should be discounted from the prompt window
241
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
242
+
243
+ if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
244
+ prompt_window.popleft()
245
+ else:
246
+ break
247
+
248
+ def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
249
+ result = []
250
+ last_end_time = 0
251
+
252
+ for segment in segments:
253
+ segment_start = float(segment['start'])
254
+ segment_end = float(segment['end'])
255
+
256
+ if (last_end_time != segment_start):
257
+ delta = segment_start - last_end_time
258
+
259
+ if (min_gap_length is None or delta >= min_gap_length):
260
+ result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
261
+
262
+ last_end_time = segment_end
263
+ result.append(segment)
264
+
265
+ # Also include total duration if specified
266
+ if (total_duration is not None and last_end_time < total_duration):
267
+ delta = total_duration - last_end_time
268
+
269
+ if (min_gap_length is None or delta >= min_gap_length):
270
+ result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
271
+
272
+ return result
273
+
274
+ # Expand the end time of each segment to the start of the next segment
275
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
276
+ result = []
277
+
278
+ if len(segments) == 0:
279
+ return result
280
+
281
+ # Add gap at the beginning if needed
282
+ if (segments[0]['start'] > 0):
283
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
284
+
285
+ for i in range(len(segments) - 1):
286
+ current_segment = segments[i]
287
+ next_segment = segments[i + 1]
288
+
289
+ delta = next_segment['start'] - current_segment['end']
290
+
291
+ # Expand if the gap actually exists
292
+ if (delta >= 0):
293
+ current_segment = current_segment.copy()
294
+ current_segment['expand_amount'] = delta
295
+ current_segment['end'] = next_segment['start']
296
+
297
+ result.append(current_segment)
298
+
299
+ # Add last segment
300
+ last_segment = segments[-1]
301
+ result.append(last_segment)
302
+
303
+ # Also include total duration if specified
304
+ if (total_duration is not None):
305
+ last_segment = result[-1]
306
+
307
+ if (last_segment['end'] < total_duration):
308
+ last_segment = last_segment.copy()
309
+ last_segment['end'] = total_duration
310
+ result[-1] = last_segment
311
+
312
+ return result
313
+
314
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
315
+ result = []
316
+
317
+ if len(segments) == 0:
318
+ return result
319
+
320
+ # Add gap at the beginning if needed
321
+ if (segments[0]['start'] > 0):
322
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
323
+
324
+ for i in range(len(segments) - 1):
325
+ expanded = False
326
+ current_segment = segments[i]
327
+ next_segment = segments[i + 1]
328
+
329
+ delta = next_segment['start'] - current_segment['end']
330
+
331
+ if (max_expand_size is not None and delta <= max_expand_size):
332
+ # Just expand the current segment
333
+ current_segment = current_segment.copy()
334
+ current_segment['expand_amount'] = delta
335
+ current_segment['end'] = next_segment['start']
336
+ expanded = True
337
+
338
+ result.append(current_segment)
339
+
340
+ # Add a gap to the next segment if needed
341
+ if (delta >= 0 and not expanded):
342
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
343
+
344
+ # Add last segment
345
+ last_segment = segments[-1]
346
+ result.append(last_segment)
347
+
348
+ # Also include total duration if specified
349
+ if (total_duration is not None):
350
+ last_segment = result[-1]
351
+
352
+ delta = total_duration - last_segment['end']
353
+
354
+ if (delta > 0):
355
+ if (max_expand_size is not None and delta <= max_expand_size):
356
+ # Expand the last segment
357
+ last_segment = last_segment.copy()
358
+ last_segment['expand_amount'] = delta
359
+ last_segment['end'] = total_duration
360
+ result[-1] = last_segment
361
+ else:
362
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
363
+
364
+ return result
365
+
366
+ def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
367
+ result = []
368
+
369
+ for segment in segments:
370
+ segment_start = float(segment['start'])
371
+ segment_end = float(segment['end'])
372
+
373
+ # Filter segments?
374
+ if (max_source_time is not None):
375
+ if (segment_start > max_source_time):
376
+ continue
377
+ segment_end = min(max_source_time, segment_end)
378
+
379
+ new_segment = segment.copy()
380
+
381
+ # Add to start and end
382
+ new_segment['start'] = segment_start + adjust_seconds
383
+ new_segment['end'] = segment_end + adjust_seconds
384
+ result.append(new_segment)
385
+ return result
386
+
387
+ def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
388
+ result = []
389
+
390
+ for entry in timestamps:
391
+ start = entry['start']
392
+ end = entry['end']
393
+
394
+ result.append({
395
+ 'start': start * factor,
396
+ 'end': end * factor
397
+ })
398
+ return result
399
+
400
+
401
+ class VadSileroTranscription(AbstractTranscription):
402
+ def __init__(self, sampling_rate: int = 16000, cache: ModelCache = None):
403
+ super().__init__(sampling_rate=sampling_rate)
404
+ self.model = None
405
+ self.cache = cache
406
+ self._initialize_model()
407
+
408
+ def _initialize_model(self):
409
+ if (self.cache is not None):
410
+ model_key = "VadSileroTranscription"
411
+ self.model, self.get_speech_timestamps = self.cache.get(model_key, self._create_model)
412
+ print("Loaded Silerio model from cache.")
413
+ else:
414
+ self.model, self.get_speech_timestamps = self._create_model()
415
+ print("Created Silerio model")
416
+
417
+ def _create_model(self):
418
+ model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
419
+
420
+ # Silero does not benefit from multi-threading
421
+ torch.set_num_threads(1) # JIT
422
+ (get_speech_timestamps, _, _, _, _) = utils
423
+
424
+ return model, get_speech_timestamps
425
+
426
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig, start_time: float, end_time: float):
427
+ result = []
428
+
429
+ print("Getting timestamps from audio file: {}, start: {}, duration: {}".format(audio, start_time, end_time))
430
+ perf_start_time = time.perf_counter()
431
+
432
+ # Divide processing of the audio into chunks
433
+ chunk_start = start_time
434
+
435
+ while (chunk_start < end_time):
436
+ chunk_duration = min(end_time - chunk_start, VAD_MAX_PROCESSING_CHUNK)
437
+
438
+ print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
439
+ wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
440
+
441
+ sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
442
+ seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
443
+ adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
444
+
445
+ #pprint(adjusted)
446
+
447
+ result.extend(adjusted)
448
+ chunk_start += chunk_duration
449
+
450
+ perf_end_time = time.perf_counter()
451
+ print("VAD processing took {} seconds".format(perf_end_time - perf_start_time))
452
+
453
+ return result
454
+
455
+ def __getstate__(self):
456
+ # We only need the sampling rate
457
+ return { 'sampling_rate': self.sampling_rate }
458
+
459
+ def __setstate__(self, state):
460
+ self.sampling_rate = state['sampling_rate']
461
+ self.model = None
462
+ # Use the global cache
463
+ self.cache = GLOBAL_MODEL_CACHE
464
+ self._initialize_model()
465
+
466
+ # A very simple VAD that just marks every N seconds as speech
467
+ class VadPeriodicTranscription(AbstractTranscription):
468
+ def __init__(self, sampling_rate: int = 16000):
469
+ super().__init__(sampling_rate=sampling_rate)
470
+
471
+ def is_transcribe_timestamps_fast(self):
472
+ # This is a very fast VAD - no need to parallelize it
473
+ return True
474
+
475
+ def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig, start_time: float, end_time: float):
476
+ result = []
477
+
478
+ # Generate a timestamp every N seconds
479
+ start_timestamp = start_time
480
+
481
+ while (start_timestamp < end_time):
482
+ end_timestamp = min(start_timestamp + config.periodic_duration, end_time)
483
+ segment_duration = end_timestamp - start_timestamp
484
+
485
+ # Minimum duration is 1 second
486
+ if (segment_duration >= 1):
487
+ result.append( { 'start': start_timestamp, 'end': end_timestamp } )
488
+
489
+ start_timestamp = end_timestamp
490
+
491
+ return result
492
+
493
+ def get_audio_duration(file: str):
494
+ return float(ffmpeg.probe(file)["format"]["duration"])
495
+
496
+ def load_audio(file: str, sample_rate: int = 16000,
497
+ start_time: str = None, duration: str = None):
498
+ """
499
+ Open an audio file and read as mono waveform, resampling as necessary
500
+
501
+ Parameters
502
+ ----------
503
+ file: str
504
+ The audio file to open
505
+
506
+ sample_rate: int
507
+ The sample rate to resample the audio if necessary
508
+
509
+ start_time: str
510
+ The start time, using the standard FFMPEG time duration syntax, or None to disable.
511
+
512
+ duration: str
513
+ The duration, using the standard FFMPEG time duration syntax, or None to disable.
514
+
515
+ Returns
516
+ -------
517
+ A NumPy array containing the audio waveform, in float32 dtype.
518
+ """
519
+ try:
520
+ inputArgs = {'threads': 0}
521
+
522
+ if (start_time is not None):
523
+ inputArgs['ss'] = start_time
524
+ if (duration is not None):
525
+ inputArgs['t'] = duration
526
+
527
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
528
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
529
+ out, _ = (
530
+ ffmpeg.input(file, **inputArgs)
531
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
532
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
533
+ )
534
+ except ffmpeg.Error as e:
535
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
536
+
537
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
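A sketch of running the Silero VAD on its own to get speech timestamps; "audio.wav" is a placeholder path, the model is fetched via torch.hub on first use, and ffmpeg is needed for decoding:

    from src.vad import VadSileroTranscription, PeriodicTranscriptionConfig, get_audio_duration

    vad = VadSileroTranscription()
    config = PeriodicTranscriptionConfig(periodic_duration=30, max_silent_period=1, max_merge_size=30,
                                         segment_padding_left=1, segment_padding_right=1)

    duration = get_audio_duration("audio.wav")
    timestamps = vad.get_transcribe_timestamps("audio.wav", config, 0, duration)

    # Merge nearby speech segments before handing them to Whisper
    merged = vad.get_merged_timestamps(timestamps, config, duration)
    print(merged)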
vadParallel.py ADDED
@@ -0,0 +1,255 @@
1
+ import multiprocessing
2
+ import threading
3
+ import time
4
+ from src.vad import AbstractTranscription, TranscriptionConfig, get_audio_duration
5
+ from src.whisperContainer import WhisperCallback
6
+
7
+ from multiprocessing import Pool
8
+
9
+ from typing import Any, Dict, List
10
+ import os
11
+
12
+
13
+ class ParallelContext:
14
+ def __init__(self, num_processes: int = None, auto_cleanup_timeout_seconds: float = None):
15
+ self.num_processes = num_processes
16
+ self.auto_cleanup_timeout_seconds = auto_cleanup_timeout_seconds
17
+ self.lock = threading.Lock()
18
+
19
+ self.ref_count = 0
20
+ self.pool = None
21
+ self.cleanup_timer = None
22
+
23
+ def get_pool(self):
24
+ # Initialize pool lazily
25
+ if (self.pool is None):
26
+ context = multiprocessing.get_context('spawn')
27
+ self.pool = context.Pool(self.num_processes)
28
+
29
+ self.ref_count = self.ref_count + 1
30
+
31
+ if (self.auto_cleanup_timeout_seconds is not None):
32
+ self._stop_auto_cleanup()
33
+
34
+ return self.pool
35
+
36
+ def return_pool(self, pool):
37
+ if (self.pool == pool and self.ref_count > 0):
38
+ self.ref_count = self.ref_count - 1
39
+
40
+ if (self.ref_count == 0):
41
+ if (self.auto_cleanup_timeout_seconds is not None):
42
+ self._start_auto_cleanup()
43
+
44
+ def _start_auto_cleanup(self):
45
+ if (self.cleanup_timer is not None):
46
+ self.cleanup_timer.cancel()
47
+ self.cleanup_timer = threading.Timer(self.auto_cleanup_timeout_seconds, self._execute_cleanup)
48
+ self.cleanup_timer.start()
49
+
50
+ print("Started auto cleanup of pool in " + str(self.auto_cleanup_timeout_seconds) + " seconds")
51
+
52
+ def _stop_auto_cleanup(self):
53
+ if (self.cleanup_timer is not None):
54
+ self.cleanup_timer.cancel()
55
+ self.cleanup_timer = None
56
+
57
+ print("Stopped auto cleanup of pool")
58
+
59
+ def _execute_cleanup(self):
60
+ print("Executing cleanup of pool")
61
+
62
+ if (self.ref_count == 0):
63
+ self.close()
64
+
65
+ def close(self):
66
+ self._stop_auto_cleanup()
67
+
68
+ if (self.pool is not None):
69
+ print("Closing pool of " + str(self.num_processes) + " processes")
70
+ self.pool.close()
71
+ self.pool.join()
72
+ self.pool = None
73
+
74
+ class ParallelTranscriptionConfig(TranscriptionConfig):
75
+ def __init__(self, device_id: str, override_timestamps, initial_segment_index, copy: TranscriptionConfig = None):
76
+ super().__init__(copy.non_speech_strategy, copy.segment_padding_left, copy.segment_padding_right, copy.max_silent_period, copy.max_merge_size, copy.max_prompt_window, initial_segment_index)
77
+ self.device_id = device_id
78
+ self.override_timestamps = override_timestamps
79
+
80
+ class ParallelTranscription(AbstractTranscription):
81
+ # Silero VAD typically takes about 3 seconds per minute, so there's no need to split the chunks
82
+ # into segments smaller than 2 minutes (at least 6 seconds of work per CPU core)
83
+ MIN_CPU_CHUNK_SIZE_SECONDS = 2 * 60
84
+
85
+ def __init__(self, sampling_rate: int = 16000):
86
+ super().__init__(sampling_rate=sampling_rate)
87
+
88
+ def transcribe_parallel(self, transcription: AbstractTranscription, audio: str, whisperCallable: WhisperCallback, config: TranscriptionConfig,
89
+ cpu_device_count: int, gpu_devices: List[str], cpu_parallel_context: ParallelContext = None, gpu_parallel_context: ParallelContext = None):
90
+ total_duration = get_audio_duration(audio)
91
+
92
+ # First, get the timestamps for the original audio
93
+ if (cpu_device_count > 1 and not transcription.is_transcribe_timestamps_fast()):
94
+ merged = self._get_merged_timestamps_parallel(transcription, audio, config, total_duration, cpu_device_count, cpu_parallel_context)
95
+ else:
96
+ timestamp_segments = transcription.get_transcribe_timestamps(audio, config, 0, total_duration)
97
+ merged = transcription.get_merged_timestamps(timestamp_segments, config, total_duration)
98
+
99
+ # We must make sure the whisper model is downloaded
100
+ if (len(gpu_devices) > 1):
101
+ whisperCallable.model_container.ensure_downloaded()
102
+
103
+ # Split into a list for each device
104
+ # TODO: Split by time instead of by number of chunks
105
+ merged_split = list(self._split(merged, len(gpu_devices)))
106
+
107
+ # Parameters that will be passed to the transcribe function
108
+ parameters = []
109
+ segment_index = config.initial_segment_index
110
+
111
+ for i in range(len(gpu_devices)):
112
+ # Note that device_segment_list can be empty. But we will still create a process for it,
113
+ # as otherwise we run the risk of assigning the same device to multiple processes.
114
+ device_segment_list = list(merged_split[i]) if i < len(merged_split) else []
115
+ device_id = gpu_devices[i]
116
+
117
+ print("Device " + str(device_id) + " (index " + str(i) + ") has " + str(len(device_segment_list)) + " segments")
118
+
119
+ # Create a new config with the given device ID
120
+ device_config = ParallelTranscriptionConfig(device_id, device_segment_list, segment_index, config)
121
+ segment_index += len(device_segment_list)
122
+
123
+ parameters.append([audio, whisperCallable, device_config])
124
+
125
+ merged = {
126
+ 'text': '',
127
+ 'segments': [],
128
+ 'language': None
129
+ }
130
+
131
+ created_context = False
132
+
133
+ perf_start_gpu = time.perf_counter()
134
+
135
+ # Spawn a separate process for each device
136
+ try:
137
+ if (gpu_parallel_context is None):
138
+ gpu_parallel_context = ParallelContext(len(gpu_devices))
139
+ created_context = True
140
+
141
+ # Get a pool of processes
142
+ pool = gpu_parallel_context.get_pool()
143
+
144
+ # Run the transcription in parallel
145
+ results = pool.starmap(self.transcribe, parameters)
146
+
147
+ for result in results:
148
+ # Merge the results
149
+ if (result['text'] is not None):
150
+ merged['text'] += result['text']
151
+ if (result['segments'] is not None):
152
+ merged['segments'].extend(result['segments'])
153
+ if (result['language'] is not None):
154
+ merged['language'] = result['language']
155
+
156
+ finally:
157
+ # Return the pool to the context
158
+ if (gpu_parallel_context is not None):
159
+ gpu_parallel_context.return_pool(pool)
160
+ # Always close the context if we created it
161
+ if (created_context):
162
+ gpu_parallel_context.close()
163
+
164
+ perf_end_gpu = time.perf_counter()
165
+ print("Parallel transcription took " + str(perf_end_gpu - perf_start_gpu) + " seconds")
166
+
167
+ return merged
168
+
169
+ def _get_merged_timestamps_parallel(self, transcription: AbstractTranscription, audio: str, config: TranscriptionConfig, total_duration: float,
170
+ cpu_device_count: int, cpu_parallel_context: ParallelContext = None):
171
+ parameters = []
172
+
173
+ chunk_size = max(total_duration / cpu_device_count, self.MIN_CPU_CHUNK_SIZE_SECONDS)
174
+ chunk_start = 0
175
+ cpu_device_id = 0
176
+
177
+ perf_start_time = time.perf_counter()
178
+
179
+ # Create chunks that will be processed on the CPU
180
+ while (chunk_start < total_duration):
181
+ chunk_end = min(chunk_start + chunk_size, total_duration)
182
+
183
+ if (chunk_end - chunk_start < 1):
184
+ # No need to process chunks that are less than 1 second
185
+ break
186
+
187
+ print("Parallel VAD: Executing chunk from " + str(chunk_start) + " to " +
188
+ str(chunk_end) + " on CPU device " + str(cpu_device_id))
189
+ parameters.append([audio, config, chunk_start, chunk_end])
190
+
191
+ cpu_device_id += 1
192
+ chunk_start = chunk_end
193
+
194
+ created_context = False
195
+
196
+ # Spawn a separate process for each device
197
+ try:
198
+ if (cpu_parallel_context is None):
199
+ cpu_parallel_context = ParallelContext(cpu_device_count)
200
+ created_context = True
201
+
202
+ # Get a pool of processes
203
+ pool = cpu_parallel_context.get_pool()
204
+
205
+ # Run the transcription in parallel. Note that transcription must be picklable.
206
+ results = pool.starmap(transcription.get_transcribe_timestamps, parameters)
207
+
208
+ timestamps = []
209
+
210
+ # Flatten the results
211
+ for result in results:
212
+ timestamps.extend(result)
213
+
214
+ merged = transcription.get_merged_timestamps(timestamps, config, total_duration)
215
+
216
+ perf_end_time = time.perf_counter()
217
+ print("Parallel VAD processing took {} seconds".format(perf_end_time - perf_start_time))
218
+ return merged
219
+
220
+ finally:
221
+ # Return the pool to the context
222
+ if (cpu_parallel_context is not None):
223
+ cpu_parallel_context.return_pool(pool)
224
+ # Always close the context if we created it
225
+ if (created_context):
226
+ cpu_parallel_context.close()
227
+
228
+ def get_transcribe_timestamps(self, audio: str, config: ParallelTranscriptionConfig, start_time: float, duration: float):
229
+ return []
230
+
231
+ def get_merged_timestamps(self, timestamps: List[Dict[str, Any]], config: ParallelTranscriptionConfig, total_duration: float):
232
+ # Override timestamps that will be processed
233
+ if (config.override_timestamps is not None):
234
+ print("Using override timestamps of size " + str(len(config.override_timestamps)))
235
+ return config.override_timestamps
236
+ return super().get_merged_timestamps(timestamps, config, total_duration)
237
+
238
+ def transcribe(self, audio: str, whisperCallable: WhisperCallback, config: ParallelTranscriptionConfig):
239
+ # Override device ID the first time
240
+ if (os.environ.get("INITIALIZED", None) is None):
241
+ os.environ["INITIALIZED"] = "1"
242
+
243
+ # Note that this may be None if the user didn't specify a device. In that case, Whisper will
244
+ # just use the default GPU device.
245
+ if (config.device_id is not None):
246
+ print("Using device " + config.device_id)
247
+ os.environ["CUDA_VISIBLE_DEVICES"] = config.device_id
248
+
249
+ return super().transcribe(audio, whisperCallable, config)
250
+
251
+ def _split(self, a, n):
252
+ """Split a list into n approximately equal parts."""
253
+ k, m = divmod(len(a), n)
254
+ return (a[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(n))
255
+
whisperContainer.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # External programs
2
+ import os
3
+ import whisper
4
+
5
+ from src.modelCache import GLOBAL_MODEL_CACHE, ModelCache
6
+
7
+ class WhisperContainer:
8
+ def __init__(self, model_name: str, device: str = None, download_root: str = None, cache: ModelCache = None):
9
+ self.model_name = model_name
10
+ self.device = device
11
+ self.download_root = download_root
12
+ self.cache = cache
13
+
14
+ # Will be created on demand
15
+ self.model = None
16
+
17
+ def get_model(self):
18
+ if self.model is None:
19
+
20
+ if (self.cache is None):
21
+ self.model = self._create_model()
22
+ else:
23
+ model_key = "WhisperContainer." + self.model_name + ":" + (self.device if self.device else '')
24
+ self.model = self.cache.get(model_key, self._create_model)
25
+ return self.model
26
+
27
+ def ensure_downloaded(self):
28
+ """
29
+ Ensure that the model is downloaded. This is useful if you want to ensure that the model is downloaded before
30
+ passing the container to a subprocess.
31
+ """
32
+ # Warning: Using private API here
33
+ try:
34
+ root_dir = self.download_root
35
+
36
+ if root_dir is None:
37
+ root_dir = os.path.join(os.path.expanduser("~"), ".cache", "whisper")
38
+
39
+ if self.model_name in whisper._MODELS:
40
+ whisper._download(whisper._MODELS[self.model_name], root_dir, False)
41
+ return True
42
+ except Exception as e:
43
+ # Given that the API is private, it could change at any time. We don't want to crash the program
44
+ print("Error pre-downloading model: " + str(e))
45
+ return False
46
+
47
+ def _create_model(self):
48
+ print("Loading whisper model " + self.model_name)
49
+ return whisper.load_model(self.model_name, device=self.device, download_root=self.download_root)
50
+
51
+ def create_callback(self, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
52
+ """
53
+ Create a WhisperCallback object that can be used to transcribe audio files.
54
+
55
+ Parameters
56
+ ----------
57
+ language: str
58
+ The target language of the transcription. If not specified, the language will be inferred from the audio content.
59
+ task: str
60
+ The task - either translate or transcribe.
61
+ initial_prompt: str
62
+ The initial prompt to use for the transcription.
63
+ decodeOptions: dict
64
+ Additional options to pass to the decoder. Must be pickleable.
65
+
66
+ Returns
67
+ -------
68
+ A WhisperCallback object.
69
+ """
70
+ return WhisperCallback(self, language=language, task=task, initial_prompt=initial_prompt, **decodeOptions)
71
+
72
+ # This is required for multiprocessing
73
+ def __getstate__(self):
74
+ return { "model_name": self.model_name, "device": self.device, "download_root": self.download_root }
75
+
76
+ def __setstate__(self, state):
77
+ self.model_name = state["model_name"]
78
+ self.device = state["device"]
79
+ self.download_root = state["download_root"]
80
+ self.model = None
81
+ # Depickled objects must use the global cache
82
+ self.cache = GLOBAL_MODEL_CACHE
83
+
84
+
85
+ class WhisperCallback:
86
+ def __init__(self, model_container: WhisperContainer, language: str = None, task: str = None, initial_prompt: str = None, **decodeOptions: dict):
87
+ self.model_container = model_container
88
+ self.language = language
89
+ self.task = task
90
+ self.initial_prompt = initial_prompt
91
+ self.decodeOptions = decodeOptions
92
+
93
+ def invoke(self, audio, segment_index: int, prompt: str, detected_language: str):
94
+ """
95
+ Perform the transcription of the given audio file or data.
96
+
97
+ Parameters
98
+ ----------
99
+ audio: Union[str, np.ndarray, torch.Tensor]
100
+ The audio file to transcribe, or the audio data as a numpy array or torch tensor.
101
+ segment_index: int
102
+ The index of the current segment; the initial prompt is only included for the first segment (index 0).
105
+ prompt: str
106
+ The prompt to use for the transcription.
107
+ detected_language: str
108
+ The detected language of the audio file.
109
+
110
+ Returns
111
+ -------
112
+ The result of the Whisper call.
113
+ """
114
+ model = self.model_container.get_model()
115
+
116
+ return model.transcribe(audio, \
117
+ language=self.language if self.language else detected_language, task=self.task, \
118
+ initial_prompt=self._concat_prompt(self.initial_prompt, prompt) if segment_index == 0 else prompt, \
119
+ **self.decodeOptions)
120
+
121
+ def _concat_prompt(self, prompt1, prompt2):
122
+ if (prompt1 is None):
123
+ return prompt2
124
+ elif (prompt2 is None):
125
+ return prompt1
126
+ else:
127
+ return prompt1 + " " + prompt2
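A minimal sketch of driving a single Whisper call through the container and callback; "tiny" and "audio.wav" are placeholder choices, and the openai-whisper package must be installed:

    from src.whisperContainer import WhisperContainer

    container = WhisperContainer("tiny")
    callback = container.create_callback(language="en", task="transcribe")

    # segment_index=0 means the initial prompt (if any) is prepended to this call
    result = callback.invoke("audio.wav", segment_index=0, prompt=None, detected_language=None)
    print(result["text"])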