Commit · ae3884d
Parent(s): 9271e46
Add application file

- models/models will be saved here.txt +0 -0
- modules/__init__.py +0 -0
- modules/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/diarize/__init__.py +0 -0
- modules/diarize/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/diarize/__pycache__/diarize_pipeline.cpython-310.pyc +0 -0
- modules/diarize/__pycache__/diarizer.cpython-310.pyc +0 -0
- modules/diarize/audio_loader.py +179 -0
- modules/diarize/diarize_pipeline.py +94 -0
- modules/diarize/diarizer.py +132 -0
- modules/translation/__init__.py +0 -0
- modules/translation/deepl_api.py +201 -0
- modules/translation/nllb_inference.py +276 -0
- modules/translation/translation_base.py +151 -0
- modules/utils/__init__.py +0 -0
- modules/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/utils/__pycache__/files_manager.cpython-310.pyc +0 -0
- modules/utils/__pycache__/subtitle_manager.cpython-310.pyc +0 -0
- modules/utils/__pycache__/youtube_manager.cpython-310.pyc +0 -0
- modules/utils/files_manager.py +39 -0
- modules/utils/subtitle_manager.py +135 -0
- modules/utils/youtube_manager.py +15 -0
- modules/vad/__init__.py +0 -0
- modules/vad/silero_vad.py +264 -0
- modules/whisper/__init__.py +0 -0
- modules/whisper/__pycache__/__init__.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/faster_whisper_inference.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/whisper_base.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/whisper_factory.cpython-310.pyc +0 -0
- modules/whisper/__pycache__/whisper_parameter.cpython-310.pyc +0 -0
- modules/whisper/faster_whisper_inference.py +191 -0
- modules/whisper/insanely_fast_whisper_inference.py +185 -0
- modules/whisper/whisper_Inference.py +101 -0
- modules/whisper/whisper_base.py +436 -0
- modules/whisper/whisper_factory.py +81 -0
- modules/whisper/whisper_parameter.py +277 -0
- outputs/outputs are saved here.txt +0 -0
- outputs/translations/outputs for translation are saved here.txt +0 -0
models/models will be saved here.txt
ADDED
File without changes

modules/__init__.py
ADDED
File without changes

modules/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (167 Bytes)

modules/diarize/__init__.py
ADDED
File without changes

modules/diarize/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (175 Bytes)

modules/diarize/__pycache__/diarize_pipeline.cpython-310.pyc
ADDED
Binary file (3.06 kB)

modules/diarize/__pycache__/diarizer.cpython-310.pyc
ADDED
Binary file (4.14 kB)
modules/diarize/audio_loader.py
ADDED
@@ -0,0 +1,179 @@
# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py

import os
import subprocess
from functools import lru_cache
from typing import Optional, Union
from scipy.io.wavfile import write
import tempfile

import numpy as np
import torch
import torch.nn.functional as F

def exact_div(x, y):
    assert x % y == 0
    return x // y

# hard-coded audio hyperparameters
SAMPLE_RATE = 16000
N_FFT = 400
HOP_LENGTH = 160
CHUNK_LENGTH = 30
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)  # 3000 frames in a mel spectrogram input

N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2  # the initial convolutions has stride 2
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
    """
    Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.

    Parameters
    ----------
    file: Union[str, np.ndarray]
        The audio file to open or a numpy array containing the audio data.

    sr: int
        The sample rate to resample the audio if necessary.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    if isinstance(file, np.ndarray):
        if file.dtype != np.float32:
            file = file.astype(np.float32)
        if file.ndim > 1:
            file = np.mean(file, axis=1)

        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
        write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
        temp_file_path = temp_file.name
        temp_file.close()
    else:
        temp_file_path = file

    try:
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads",
            "0",
            "-i",
            temp_file_path,
            "-f",
            "s16le",
            "-ac",
            "1",
            "-acodec",
            "pcm_s16le",
            "-ar",
            str(sr),
            "-",
        ]
        out = subprocess.run(cmd, capture_output=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
    finally:
        if isinstance(file, np.ndarray):
            os.remove(temp_file_path)

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0


def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
    """
    Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
    """
    if torch.is_tensor(array):
        if array.shape[axis] > length:
            array = array.index_select(
                dim=axis, index=torch.arange(length, device=array.device)
            )

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
    else:
        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)

    return array


@lru_cache(maxsize=None)
def mel_filters(device, n_mels: int) -> torch.Tensor:
    """
    load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
    Allows decoupling librosa dependency; saved using:

        np.savez_compressed(
            "mel_filters.npz",
            mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
        )
    """
    assert n_mels in [80, 128], f"Unsupported n_mels: {n_mels}"
    with np.load(
        os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
    ) as f:
        return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)


def log_mel_spectrogram(
    audio: Union[str, np.ndarray, torch.Tensor],
    n_mels: int,
    padding: int = 0,
    device: Optional[Union[str, torch.device]] = None,
):
    """
    Compute the log-Mel spectrogram of

    Parameters
    ----------
    audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
        The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz

    n_mels: int
        The number of Mel-frequency filters, only 80 is supported

    padding: int
        Number of zero samples to pad to the right

    device: Optional[Union[str, torch.device]]
        If given, the audio tensor is moved to this device before STFT

    Returns
    -------
    torch.Tensor, shape = (80, n_frames)
        A Tensor that contains the Mel spectrogram
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    magnitudes = stft[..., :-1].abs() ** 2

    filters = mel_filters(audio.device, n_mels)
    mel_spec = filters @ magnitudes

    log_spec = torch.clamp(mel_spec, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    log_spec = (log_spec + 4.0) / 4.0
    return log_spec
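For orientation, a minimal usage sketch of the helpers above. The file name "audio.wav" is a placeholder, and it assumes the "assets/mel_filters.npz" file ships next to this module as mel_filters() expects:

# Hypothetical usage sketch; file name and n_mels choice are assumptions.
from modules.diarize.audio_loader import load_audio, pad_or_trim, log_mel_spectrogram, N_SAMPLES

waveform = load_audio("audio.wav")            # float32 mono waveform at 16 kHz, decoded via ffmpeg
chunk = pad_or_trim(waveform, N_SAMPLES)      # pad/trim to one 30-second window (480000 samples)
mel = log_mel_spectrogram(chunk, n_mels=80)   # (80, 3000) log-Mel tensor, as expected by the encoder
print(mel.shape)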
modules/diarize/diarize_pipeline.py
ADDED
@@ -0,0 +1,94 @@
# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py

import numpy as np
import pandas as pd
import os
from pyannote.audio import Pipeline
from typing import Optional, Union
import torch

from modules.diarize.audio_loader import load_audio, SAMPLE_RATE


class DiarizationPipeline:
    def __init__(
        self,
        model_name="pyannote/speaker-diarization-3.1",
        cache_dir: str = os.path.join("models", "Diarization"),
        use_auth_token=None,
        device: Optional[Union[str, torch.device]] = "cpu",
    ):
        if isinstance(device, str):
            device = torch.device(device)
        self.model = Pipeline.from_pretrained(
            model_name,
            use_auth_token=use_auth_token,
            cache_dir=cache_dir
        ).to(device)

    def __call__(self, audio: Union[str, np.ndarray], min_speakers=None, max_speakers=None):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio_data = {
            'waveform': torch.from_numpy(audio[None, :]),
            'sample_rate': SAMPLE_RATE
        }
        segments = self.model(audio_data, min_speakers=min_speakers, max_speakers=max_speakers)
        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True), columns=['segment', 'label', 'speaker'])
        diarize_df['start'] = diarize_df['segment'].apply(lambda x: x.start)
        diarize_df['end'] = diarize_df['segment'].apply(lambda x: x.end)
        return diarize_df


def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
    transcript_segments = transcript_result["segments"]
    for seg in transcript_segments:
        # assign speaker to segment (if any)
        diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
                                                                                            seg['start'])
        diarize_df['union'] = np.maximum(diarize_df['end'], seg['end']) - np.minimum(diarize_df['start'], seg['start'])

        intersected = diarize_df[diarize_df["intersection"] > 0]

        speaker = None
        if len(intersected) > 0:
            # Choosing most strong intersection
            speaker = intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
        elif fill_nearest:
            # Otherwise choosing closest
            speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]

        if speaker is not None:
            seg["speaker"] = speaker

        # assign speaker to words
        if 'words' in seg:
            for word in seg['words']:
                if 'start' in word:
                    diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
                        diarize_df['start'], word['start'])
                    diarize_df['union'] = np.maximum(diarize_df['end'], word['end']) - np.minimum(diarize_df['start'],
                                                                                                  word['start'])

                    intersected = diarize_df[diarize_df["intersection"] > 0]

                    word_speaker = None
                    if len(intersected) > 0:
                        # Choosing most strong intersection
                        word_speaker = \
                            intersected.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
                    elif fill_nearest:
                        # Otherwise choosing closest
                        word_speaker = diarize_df.sort_values(by=["intersection"], ascending=False)["speaker"].values[0]

                    if word_speaker is not None:
                        word["speaker"] = word_speaker

    return transcript_result


class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker
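A short usage sketch of the pipeline above. The Hugging Face token and audio file name are placeholders, and downloading pyannote/speaker-diarization-3.1 requires accepting its terms first:

# Hypothetical usage; "hf_xxx" and "meeting.wav" are placeholders.
pipeline = DiarizationPipeline(use_auth_token="hf_xxx", device="cpu")
diarize_df = pipeline("meeting.wav", min_speakers=2, max_speakers=4)   # DataFrame with start/end/speaker columns

transcript = {"segments": [{"start": 0.0, "end": 2.5, "text": "hello there"}]}
result = assign_word_speakers(diarize_df, transcript)
print(result["segments"][0].get("speaker"))   # e.g. "SPEAKER_00" when an overlap is found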
modules/diarize/diarizer.py
ADDED
@@ -0,0 +1,132 @@
import os
import torch
from typing import List, Union, BinaryIO, Optional
import numpy as np
import time
import logging

from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio


class Diarizer:
    def __init__(self,
                 model_dir: str = os.path.join("models", "Diarization")
                 ):
        self.device = self.get_device()
        self.available_device = self.get_available_device()
        self.compute_type = "float16"
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        self.pipe = None

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            transcribed_result: List[dict],
            use_auth_token: str,
            device: Optional[str] = None
            ):
        """
        Diarize transcribed result as a post-processing

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be file path or binary type.
        transcribed_result: List[dict]
            transcribed result through whisper.
        use_auth_token: str
            Huggingface token with READ permission. This is only needed the first time you download the model.
            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
        device: Optional[str]
            Device for diarization.

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for running
        """
        start_time = time.time()

        if device is None:
            device = self.device

        if device != self.device or self.pipe is None:
            self.update_pipe(
                device=device,
                use_auth_token=use_auth_token
            )

        audio = load_audio(audio)

        diarization_segments = self.pipe(audio)
        diarized_result = assign_word_speakers(
            diarization_segments,
            {"segments": transcribed_result}
        )

        for segment in diarized_result["segments"]:
            speaker = "None"
            if "speaker" in segment:
                speaker = segment["speaker"]
            segment["text"] = speaker + "|" + segment["text"].strip()

        elapsed_time = time.time() - start_time
        return diarized_result["segments"], elapsed_time

    def update_pipe(self,
                    use_auth_token: str,
                    device: str
                    ):
        """
        Set pipeline for diarization

        Parameters
        ----------
        use_auth_token: str
            Huggingface token with READ permission. This is only needed the first time you download the model.
            You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
        device: str
            Device for diarization.
        """
        self.device = device

        os.makedirs(self.model_dir, exist_ok=True)

        if (not os.listdir(self.model_dir) and
                not use_auth_token):
            print(
                "\nFailed to diarize. You need huggingface token and agree to their requirements to download the diarization model.\n"
                "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow their instructions to download the model.\n"
            )
            return

        logger = logging.getLogger("speechbrain.utils.train_logger")
        # Disable redundant torchvision warning message
        logger.disabled = True
        self.pipe = DiarizationPipeline(
            use_auth_token=use_auth_token,
            device=device,
            cache_dir=self.model_dir
        )
        logger.disabled = False

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def get_available_device():
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        elif torch.backends.mps.is_available():
            devices.append("mps")
        return devices
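A minimal sketch of how Diarizer.run() post-processes a Whisper-style segment list. The token, file name, and segment contents are placeholders:

# Hypothetical usage; the token and segment values are made up.
diarizer = Diarizer()
segments = [{"start": 0.0, "end": 3.2, "text": " Hello, everyone."}]
diarized, elapsed = diarizer.run(
    audio="meeting.wav",
    transcribed_result=segments,
    use_auth_token="hf_xxx",
)
print(diarized[0]["text"])    # e.g. "SPEAKER_00|Hello, everyone."
print(f"took {elapsed:.1f}s")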
modules/translation/__init__.py
ADDED
File without changes
modules/translation/deepl_api.py
ADDED
@@ -0,0 +1,201 @@
import requests
import time
import os
from datetime import datetime
import gradio as gr

from modules.utils.subtitle_manager import *

"""
This is written with reference to the DeepL API documentation.
If you want to know the information of the DeepL API, see here: https://www.deepl.com/docs-api/documents
"""

DEEPL_AVAILABLE_TARGET_LANGS = {
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
    'German': 'DE',
    'Greek': 'EL',
    'English': 'EN',
    'English (British)': 'EN-GB',
    'English (American)': 'EN-US',
    'Spanish': 'ES',
    'Estonian': 'ET',
    'Finnish': 'FI',
    'French': 'FR',
    'Hungarian': 'HU',
    'Indonesian': 'ID',
    'Italian': 'IT',
    'Japanese': 'JA',
    'Korean': 'KO',
    'Lithuanian': 'LT',
    'Latvian': 'LV',
    'Norwegian (Bokmål)': 'NB',
    'Dutch': 'NL',
    'Polish': 'PL',
    'Portuguese': 'PT',
    'Portuguese (Brazilian)': 'PT-BR',
    'Portuguese (all Portuguese varieties excluding Brazilian Portuguese)': 'PT-PT',
    'Romanian': 'RO',
    'Russian': 'RU',
    'Slovak': 'SK',
    'Slovenian': 'SL',
    'Swedish': 'SV',
    'Turkish': 'TR',
    'Ukrainian': 'UK',
    'Chinese (simplified)': 'ZH'
}

DEEPL_AVAILABLE_SOURCE_LANGS = {
    'Automatic Detection': None,
    'Bulgarian': 'BG',
    'Czech': 'CS',
    'Danish': 'DA',
    'German': 'DE',
    'Greek': 'EL',
    'English': 'EN',
    'Spanish': 'ES',
    'Estonian': 'ET',
    'Finnish': 'FI',
    'French': 'FR',
    'Hungarian': 'HU',
    'Indonesian': 'ID',
    'Italian': 'IT',
    'Japanese': 'JA',
    'Korean': 'KO',
    'Lithuanian': 'LT',
    'Latvian': 'LV',
    'Norwegian (Bokmål)': 'NB',
    'Dutch': 'NL',
    'Polish': 'PL',
    'Portuguese (all Portuguese varieties mixed)': 'PT',
    'Romanian': 'RO',
    'Russian': 'RU',
    'Slovak': 'SK',
    'Slovenian': 'SL',
    'Swedish': 'SV',
    'Turkish': 'TR',
    'Ukrainian': 'UK',
    'Chinese': 'ZH'
}


class DeepLAPI:
    def __init__(self,
                 output_dir: str = os.path.join("outputs", "translations")
                 ):
        self.api_interval = 1
        self.max_text_batch_size = 50
        self.available_target_langs = DEEPL_AVAILABLE_TARGET_LANGS
        self.available_source_langs = DEEPL_AVAILABLE_SOURCE_LANGS
        self.output_dir = output_dir

    def translate_deepl(self,
                        auth_key: str,
                        fileobjs: list,
                        source_lang: str,
                        target_lang: str,
                        is_pro: bool,
                        add_timestamp: bool,
                        progress=gr.Progress()) -> list:
        """
        Translate subtitle files using DeepL API
        Parameters
        ----------
        auth_key: str
            API Key for DeepL from gr.Textbox()
        fileobjs: list
            List of files to transcribe from gr.Files()
        source_lang: str
            Source language of the file to transcribe from gr.Dropdown()
        target_lang: str
            Target language of the file to transcribe from gr.Dropdown()
        is_pro: str
            Boolean value that is about pro user or not from gr.Checkbox().
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        A List of
        String to return to gr.Textbox()
        Files to return to gr.Files()
        """

        files_info = {}
        for fileobj in fileobjs:
            file_path = fileobj.name
            file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))

            if file_ext == ".srt":
                parsed_dicts = parse_srt(file_path=file_path)

                batch_size = self.max_text_batch_size
                for batch_start in range(0, len(parsed_dicts), batch_size):
                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                    target_lang, is_pro)
                    for i, translated_text in enumerate(translated_texts):
                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_srt(parsed_dicts)

            elif file_ext == ".vtt":
                parsed_dicts = parse_vtt(file_path=file_path)

                batch_size = self.max_text_batch_size
                for batch_start in range(0, len(parsed_dicts), batch_size):
                    batch_end = min(batch_start + batch_size, len(parsed_dicts))
                    sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
                    translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
                                                                    target_lang, is_pro)
                    for i, translated_text in enumerate(translated_texts):
                        parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_vtt(parsed_dicts)

            if add_timestamp:
                timestamp = datetime.now().strftime("%m%d%H%M%S")
                file_name += f"-{timestamp}"

            output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
            write_file(subtitle, output_path)

            files_info[file_name] = {"subtitle": subtitle, "path": output_path}

        total_result = ''
        for file_name, info in files_info.items():
            total_result += '------------------------------------\n'
            total_result += f'{file_name}\n\n'
            total_result += f'{info["subtitle"]}'
        gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"

        output_file_paths = [item["path"] for key, item in files_info.items()]
        return [gr_str, output_file_paths]

    def request_deepl_translate(self,
                                auth_key: str,
                                text: list,
                                source_lang: str,
                                target_lang: str,
                                is_pro: bool):
        """Request API response to DeepL server"""

        url = 'https://api.deepl.com/v2/translate' if is_pro else 'https://api-free.deepl.com/v2/translate'
        headers = {
            'Authorization': f'DeepL-Auth-Key {auth_key}'
        }
        data = {
            'text': text,
            'source_lang': DEEPL_AVAILABLE_SOURCE_LANGS[source_lang],
            'target_lang': DEEPL_AVAILABLE_TARGET_LANGS[target_lang]
        }
        response = requests.post(url, headers=headers, data=data).json()
        time.sleep(self.api_interval)
        return response["translations"]
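A minimal sketch of the low-level request helper above, outside the Gradio UI. The API key is a placeholder and a free-tier key is assumed (is_pro=False selects the api-free endpoint):

# Hypothetical usage; "your-deepl-key" is a placeholder.
deepl = DeepLAPI()
translations = deepl.request_deepl_translate(
    auth_key="your-deepl-key",
    text=["Hello world", "How are you?"],
    source_lang="English",
    target_lang="German",
    is_pro=False,
)
print(translations[0]["text"])   # each entry is a dict with a "text" field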
modules/translation/nllb_inference.py
ADDED
@@ -0,0 +1,276 @@
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import os

from modules.translation.translation_base import TranslationBase


class NLLBInference(TranslationBase):
    def __init__(self,
                 model_dir: str = os.path.join("models", "NLLB"),
                 output_dir: str = os.path.join("outputs", "translations")
                 ):
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir
        )
        self.tokenizer = None
        self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
        self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.pipeline = None

    def translate(self,
                  text: str,
                  max_length: int
                  ):
        result = self.pipeline(
            text,
            max_length=max_length
        )
        return result[0]['translation_text']

    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress
                     ):
        if model_size != self.current_model_size or self.model is None:
            print("\nInitializing NLLB Model..\n")
            progress(0, desc="Initializing NLLB Model..")
            self.current_model_size = model_size
            local_files_only = self.is_model_exists(self.current_model_size)
            self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
                                                               cache_dir=self.model_dir,
                                                               local_files_only=local_files_only)
            self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
                                                           cache_dir=os.path.join(self.model_dir, "tokenizers"),
                                                           local_files_only=local_files_only)
        src_lang = NLLB_AVAILABLE_LANGS[src_lang]
        tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
        self.pipeline = pipeline("translation",
                                 model=self.model,
                                 tokenizer=self.tokenizer,
                                 src_lang=src_lang,
                                 tgt_lang=tgt_lang,
                                 device=self.device)

    def is_model_exists(self,
                        model_size: str):
        """Check if model exists or not (Only facebook model)"""
        prefix = "models--facebook--"
        _id, model_size_name = model_size.split("/")
        model_dir_name = prefix + model_size_name
        model_dir_path = os.path.join(self.model_dir, model_dir_name)
        if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
            return True
        return False


NLLB_AVAILABLE_LANGS = {
    "Acehnese (Arabic script)": "ace_Arab",
    "Acehnese (Latin script)": "ace_Latn",
    "Mesopotamian Arabic": "acm_Arab",
    "Ta’izzi-Adeni Arabic": "acq_Arab",
    "Tunisian Arabic": "aeb_Arab",
    "Afrikaans": "afr_Latn",
    "South Levantine Arabic": "ajp_Arab",
    "Akan": "aka_Latn",
    "Amharic": "amh_Ethi",
    "North Levantine Arabic": "apc_Arab",
    "Modern Standard Arabic": "arb_Arab",
    "Modern Standard Arabic (Romanized)": "arb_Latn",
    "Najdi Arabic": "ars_Arab",
    "Moroccan Arabic": "ary_Arab",
    "Egyptian Arabic": "arz_Arab",
    "Assamese": "asm_Beng",
    "Asturian": "ast_Latn",
    "Awadhi": "awa_Deva",
    "Central Aymara": "ayr_Latn",
    "South Azerbaijani": "azb_Arab",
    "North Azerbaijani": "azj_Latn",
    "Bashkir": "bak_Cyrl",
    "Bambara": "bam_Latn",
    "Balinese": "ban_Latn",
    "Belarusian": "bel_Cyrl",
    "Bemba": "bem_Latn",
    "Bengali": "ben_Beng",
    "Bhojpuri": "bho_Deva",
    "Banjar (Arabic script)": "bjn_Arab",
    "Banjar (Latin script)": "bjn_Latn",
    "Standard Tibetan": "bod_Tibt",
    "Bosnian": "bos_Latn",
    "Buginese": "bug_Latn",
    "Bulgarian": "bul_Cyrl",
    "Catalan": "cat_Latn",
    "Cebuano": "ceb_Latn",
    "Czech": "ces_Latn",
    "Chokwe": "cjk_Latn",
    "Central Kurdish": "ckb_Arab",
    "Crimean Tatar": "crh_Latn",
    "Welsh": "cym_Latn",
    "Danish": "dan_Latn",
    "German": "deu_Latn",
    "Southwestern Dinka": "dik_Latn",
    "Dyula": "dyu_Latn",
    "Dzongkha": "dzo_Tibt",
    "Greek": "ell_Grek",
    "English": "eng_Latn",
    "Esperanto": "epo_Latn",
    "Estonian": "est_Latn",
    "Basque": "eus_Latn",
    "Ewe": "ewe_Latn",
    "Faroese": "fao_Latn",
    "Fijian": "fij_Latn",
    "Finnish": "fin_Latn",
    "Fon": "fon_Latn",
    "French": "fra_Latn",
    "Friulian": "fur_Latn",
    "Nigerian Fulfulde": "fuv_Latn",
    "Scottish Gaelic": "gla_Latn",
    "Irish": "gle_Latn",
    "Galician": "glg_Latn",
    "Guarani": "grn_Latn",
    "Gujarati": "guj_Gujr",
    "Haitian Creole": "hat_Latn",
    "Hausa": "hau_Latn",
    "Hebrew": "heb_Hebr",
    "Hindi": "hin_Deva",
    "Chhattisgarhi": "hne_Deva",
    "Croatian": "hrv_Latn",
    "Hungarian": "hun_Latn",
    "Armenian": "hye_Armn",
    "Igbo": "ibo_Latn",
    "Ilocano": "ilo_Latn",
    "Indonesian": "ind_Latn",
    "Icelandic": "isl_Latn",
    "Italian": "ita_Latn",
    "Javanese": "jav_Latn",
    "Japanese": "jpn_Jpan",
    "Kabyle": "kab_Latn",
    "Jingpho": "kac_Latn",
    "Kamba": "kam_Latn",
    "Kannada": "kan_Knda",
    "Kashmiri (Arabic script)": "kas_Arab",
    "Kashmiri (Devanagari script)": "kas_Deva",
    "Georgian": "kat_Geor",
    "Central Kanuri (Arabic script)": "knc_Arab",
    "Central Kanuri (Latin script)": "knc_Latn",
    "Kazakh": "kaz_Cyrl",
    "Kabiyè": "kbp_Latn",
    "Kabuverdianu": "kea_Latn",
    "Khmer": "khm_Khmr",
    "Kikuyu": "kik_Latn",
    "Kinyarwanda": "kin_Latn",
    "Kyrgyz": "kir_Cyrl",
    "Kimbundu": "kmb_Latn",
    "Northern Kurdish": "kmr_Latn",
    "Kikongo": "kon_Latn",
    "Korean": "kor_Hang",
    "Lao": "lao_Laoo",
    "Ligurian": "lij_Latn",
    "Limburgish": "lim_Latn",
    "Lingala": "lin_Latn",
    "Lithuanian": "lit_Latn",
    "Lombard": "lmo_Latn",
    "Latgalian": "ltg_Latn",
    "Luxembourgish": "ltz_Latn",
    "Luba-Kasai": "lua_Latn",
    "Ganda": "lug_Latn",
    "Luo": "luo_Latn",
    "Mizo": "lus_Latn",
    "Standard Latvian": "lvs_Latn",
    "Magahi": "mag_Deva",
    "Maithili": "mai_Deva",
    "Malayalam": "mal_Mlym",
    "Marathi": "mar_Deva",
    "Minangkabau (Arabic script)": "min_Arab",
    "Minangkabau (Latin script)": "min_Latn",
    "Macedonian": "mkd_Cyrl",
    "Plateau Malagasy": "plt_Latn",
    "Maltese": "mlt_Latn",
    "Meitei (Bengali script)": "mni_Beng",
    "Halh Mongolian": "khk_Cyrl",
    "Mossi": "mos_Latn",
    "Maori": "mri_Latn",
    "Burmese": "mya_Mymr",
    "Dutch": "nld_Latn",
    "Norwegian Nynorsk": "nno_Latn",
    "Norwegian Bokmål": "nob_Latn",
    "Nepali": "npi_Deva",
    "Northern Sotho": "nso_Latn",
    "Nuer": "nus_Latn",
    "Nyanja": "nya_Latn",
    "Occitan": "oci_Latn",
    "West Central Oromo": "gaz_Latn",
    "Odia": "ory_Orya",
    "Pangasinan": "pag_Latn",
    "Eastern Panjabi": "pan_Guru",
    "Papiamento": "pap_Latn",
    "Western Persian": "pes_Arab",
    "Polish": "pol_Latn",
    "Portuguese": "por_Latn",
    "Dari": "prs_Arab",
    "Southern Pashto": "pbt_Arab",
    "Ayacucho Quechua": "quy_Latn",
    "Romanian": "ron_Latn",
    "Rundi": "run_Latn",
    "Russian": "rus_Cyrl",
    "Sango": "sag_Latn",
    "Sanskrit": "san_Deva",
    "Santali": "sat_Olck",
    "Sicilian": "scn_Latn",
    "Shan": "shn_Mymr",
    "Sinhala": "sin_Sinh",
    "Slovak": "slk_Latn",
    "Slovenian": "slv_Latn",
    "Samoan": "smo_Latn",
    "Shona": "sna_Latn",
    "Sindhi": "snd_Arab",
    "Somali": "som_Latn",
    "Southern Sotho": "sot_Latn",
    "Spanish": "spa_Latn",
    "Tosk Albanian": "als_Latn",
    "Sardinian": "srd_Latn",
    "Serbian": "srp_Cyrl",
    "Swati": "ssw_Latn",
    "Sundanese": "sun_Latn",
    "Swedish": "swe_Latn",
    "Swahili": "swh_Latn",
    "Silesian": "szl_Latn",
    "Tamil": "tam_Taml",
    "Tatar": "tat_Cyrl",
    "Telugu": "tel_Telu",
    "Tajik": "tgk_Cyrl",
    "Tagalog": "tgl_Latn",
    "Thai": "tha_Thai",
    "Tigrinya": "tir_Ethi",
    "Tamasheq (Latin script)": "taq_Latn",
    "Tamasheq (Tifinagh script)": "taq_Tfng",
    "Tok Pisin": "tpi_Latn",
    "Tswana": "tsn_Latn",
    "Tsonga": "tso_Latn",
    "Turkmen": "tuk_Latn",
    "Tumbuka": "tum_Latn",
    "Turkish": "tur_Latn",
    "Twi": "twi_Latn",
    "Central Atlas Tamazight": "tzm_Tfng",
    "Uyghur": "uig_Arab",
    "Ukrainian": "ukr_Cyrl",
    "Umbundu": "umb_Latn",
    "Urdu": "urd_Arab",
    "Northern Uzbek": "uzn_Latn",
    "Venetian": "vec_Latn",
    "Vietnamese": "vie_Latn",
    "Waray": "war_Latn",
    "Wolof": "wol_Latn",
    "Xhosa": "xho_Latn",
    "Eastern Yiddish": "ydd_Hebr",
    "Yoruba": "yor_Latn",
    "Yue Chinese": "yue_Hant",
    "Chinese (Simplified)": "zho_Hans",
    "Chinese (Traditional)": "zho_Hant",
    "Standard Malay": "zsm_Latn",
    "Zulu": "zul_Latn",
}
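A rough usage sketch of NLLBInference outside the Gradio app. The model choice and max_length are assumptions, and passing gr.Progress() outside a Gradio event is an assumption that may only print rather than render a progress bar:

# Hypothetical usage sketch; parameters are assumptions, not the app's defaults.
import gradio as gr

nllb = NLLBInference()
nllb.update_model(
    model_size="facebook/nllb-200-distilled-600M",
    src_lang="English",
    tgt_lang="Korean",
    progress=gr.Progress(),
)
print(nllb.translate("Subtitles make videos accessible.", max_length=200))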
modules/translation/translation_base.py
ADDED
@@ -0,0 +1,151 @@
import os
import torch
import gradio as gr
from abc import ABC, abstractmethod
from typing import List
from datetime import datetime

from modules.whisper.whisper_parameter import *
from modules.utils.subtitle_manager import *


class TranslationBase(ABC):
    def __init__(self,
                 model_dir: str = os.path.join("models", "NLLB"),
                 output_dir: str = os.path.join("outputs", "translations")
                 ):
        super().__init__()
        self.model = None
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.model_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)
        self.current_model_size = None
        self.device = self.get_device()

    @abstractmethod
    def translate(self,
                  text: str,
                  max_length: int
                  ):
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     src_lang: str,
                     tgt_lang: str,
                     progress: gr.Progress
                     ):
        pass

    def translate_file(self,
                       fileobjs: list,
                       model_size: str,
                       src_lang: str,
                       tgt_lang: str,
                       max_length: int,
                       add_timestamp: bool,
                       progress=gr.Progress()) -> list:
        """
        Translate subtitle file from source language to target language

        Parameters
        ----------
        fileobjs: list
            List of files to transcribe from gr.Files()
        model_size: str
            Whisper model size from gr.Dropdown()
        src_lang: str
            Source language of the file to translate from gr.Dropdown()
        tgt_lang: str
            Target language of the file to translate from gr.Dropdown()
        max_length: int
            Max length per line to translate
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
            I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback

        Returns
        ----------
        A List of
        String to return to gr.Textbox()
        Files to return to gr.Files()
        """
        try:
            self.update_model(model_size=model_size,
                              src_lang=src_lang,
                              tgt_lang=tgt_lang,
                              progress=progress)

            files_info = {}
            for fileobj in fileobjs:
                file_path = fileobj.name
                file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
                if file_ext == ".srt":
                    parsed_dicts = parse_srt(file_path=file_path)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
                        translated_text = self.translate(dic["sentence"], max_length=max_length)
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_srt(parsed_dicts)

                elif file_ext == ".vtt":
                    parsed_dicts = parse_vtt(file_path=file_path)
                    total_progress = len(parsed_dicts)
                    for index, dic in enumerate(parsed_dicts):
                        progress(index / total_progress, desc="Translating..")
                        translated_text = self.translate(dic["sentence"], max_length=max_length)
                        dic["sentence"] = translated_text
                    subtitle = get_serialized_vtt(parsed_dicts)

                if add_timestamp:
                    timestamp = datetime.now().strftime("%m%d%H%M%S")
                    file_name += f"-{timestamp}"

                output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
                write_file(subtitle, output_path)

                files_info[file_name] = {"subtitle": subtitle, "path": output_path}

            total_result = ''
            for file_name, info in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{info["subtitle"]}'
            gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"

            output_file_paths = [item["path"] for key, item in files_info.items()]
            return [gr_str, output_file_paths]

        except Exception as e:
            print(f"Error: {str(e)}")
        finally:
            self.release_cuda_memory()

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
modules/utils/__init__.py
ADDED
File without changes

modules/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (173 Bytes)

modules/utils/__pycache__/files_manager.cpython-310.pyc
ADDED
Binary file (1.43 kB)

modules/utils/__pycache__/subtitle_manager.cpython-310.pyc
ADDED
Binary file (3.38 kB)

modules/utils/__pycache__/youtube_manager.cpython-310.pyc
ADDED
Binary file (748 Bytes)
modules/utils/files_manager.py
ADDED
@@ -0,0 +1,39 @@
import os
import fnmatch

from gradio.utils import NamedString


def get_media_files(folder_path, include_sub_directory=False):
    video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
    audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a']
    media_extensions = video_extensions + audio_extensions

    media_files = []

    if include_sub_directory:
        for root, _, files in os.walk(folder_path):
            for extension in media_extensions:
                media_files.extend(
                    os.path.join(root, file) for file in fnmatch.filter(files, extension)
                    if os.path.exists(os.path.join(root, file))
                )
    else:
        for extension in media_extensions:
            media_files.extend(
                os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension)
                if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file))
            )

    return media_files


def format_gradio_files(files: list):
    if not files:
        return files

    gradio_files = []
    for file in files:
        gradio_files.append(NamedString(file))
    return gradio_files
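A quick sketch of how these two helpers chain together; the "inputs" folder name is a placeholder:

# Hypothetical usage; "inputs" is a placeholder folder.
media = get_media_files("inputs", include_sub_directory=True)
gradio_files = format_gradio_files(media)   # wrap plain paths so Gradio file components accept them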
modules/utils/subtitle_manager.py
ADDED
@@ -0,0 +1,135 @@
import re


def timeformat_srt(time):
    hours = time // 3600
    minutes = (time - hours * 3600) // 60
    seconds = time - hours * 3600 - minutes * 60
    milliseconds = (time - int(time)) * 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"


def timeformat_vtt(time):
    hours = time // 3600
    minutes = (time - hours * 3600) // 60
    seconds = time - hours * 3600 - minutes * 60
    milliseconds = (time - int(time)) * 1000
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"


def write_file(subtitle, output_file):
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write(subtitle)


def get_srt(segments):
    output = ""
    for i, segment in enumerate(segments):
        output += f"{i + 1}\n"
        output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n\n"
    return output


def get_vtt(segments):
    output = "WebVTT\n\n"
    for i, segment in enumerate(segments):
        output += f"{i + 1}\n"
        output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n\n"
    return output


def get_txt(segments):
    output = ""
    for i, segment in enumerate(segments):
        if segment['text'].startswith(' '):
            segment['text'] = segment['text'][1:]
        output += f"{segment['text']}\n"
    return output


def parse_srt(file_path):
    """Reads SRT file and returns as dict"""
    with open(file_path, 'r', encoding='utf-8') as file:
        srt_data = file.read()

    data = []
    blocks = srt_data.split('\n\n')

    for block in blocks:
        if block.strip() != '':
            lines = block.strip().split('\n')
            index = lines[0]
            timestamp = lines[1]
            sentence = ' '.join(lines[2:])

            data.append({
                "index": index,
                "timestamp": timestamp,
                "sentence": sentence
            })
    return data


def parse_vtt(file_path):
    """Reads WebVTT file and returns as dict"""
    with open(file_path, 'r', encoding='utf-8') as file:
        webvtt_data = file.read()

    data = []
    blocks = webvtt_data.split('\n\n')

    for block in blocks:
        if block.strip() != '' and not block.strip().startswith("WebVTT"):
            lines = block.strip().split('\n')
            index = lines[0]
            timestamp = lines[1]
            sentence = ' '.join(lines[2:])

            data.append({
                "index": index,
                "timestamp": timestamp,
                "sentence": sentence
            })

    return data


def get_serialized_srt(dicts):
    output = ""
    for dic in dicts:
        output += f'{dic["index"]}\n'
        output += f'{dic["timestamp"]}\n'
        output += f'{dic["sentence"]}\n\n'
    return output


def get_serialized_vtt(dicts):
    output = "WebVTT\n\n"
    for dic in dicts:
        output += f'{dic["index"]}\n'
        output += f'{dic["timestamp"]}\n'
        output += f'{dic["sentence"]}\n\n'
    return output


def safe_filename(name):
    from app import _args
    INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
    safe_name = re.sub(INVALID_FILENAME_CHARS, '_', name)
    if not _args.colab:
        return safe_name
    # Truncate the filename if it exceeds the max_length (20)
    if len(safe_name) > 20:
        file_extension = safe_name.split('.')[-1]
        if len(file_extension) + 1 < 20:
            truncated_name = safe_name[:20 - len(file_extension) - 1]
            safe_name = truncated_name + '.' + file_extension
        else:
            safe_name = safe_name[:20]
    return safe_name
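A minimal sketch of the segment-to-SRT round trip these helpers implement; the segment values are made up for illustration:

# Sketch only: write segments as SRT, then parse and re-serialize them.
segments = [
    {"start": 0.0, "end": 2.5, "text": " Hello there."},
    {"start": 2.5, "end": 4.0, "text": " Welcome back."},
]
srt_text = get_srt(segments)                  # numbered cues with "HH:MM:SS,mmm --> ..." lines
write_file(srt_text, "example.srt")
parsed = parse_srt("example.srt")             # list of {"index", "timestamp", "sentence"} dicts
assert get_serialized_srt(parsed).strip() == srt_text.strip()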
modules/utils/youtube_manager.py
ADDED
@@ -0,0 +1,15 @@
from pytubefix import YouTube
import os


def get_ytdata(link):
    return YouTube(link)


def get_ytmetas(link):
    yt = YouTube(link)
    return yt.thumbnail_url, yt.title, yt.description


def get_ytaudio(ytdata: YouTube):
    return ytdata.streams.get_audio_only().download(filename=os.path.join("modules", "yt_tmp.wav"))
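A quick sketch of chaining these helpers (not part of the commit; the video URL is hypothetical):

from modules.utils.youtube_manager import get_ytdata, get_ytmetas, get_ytaudio

link = "https://www.youtube.com/watch?v=XXXXXXXXXXX"   # hypothetical URL
thumbnail_url, title, description = get_ytmetas(link)

yt = get_ytdata(link)
audio_path = get_ytaudio(yt)   # downloads the audio-only stream to modules/yt_tmp.wav
print(title, audio_path)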
modules/vad/__init__.py
ADDED
File without changes
modules/vad/silero_vad.py
ADDED
@@ -0,0 +1,264 @@
# Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py

from faster_whisper.vad import VadOptions, get_vad_model
import numpy as np
from typing import BinaryIO, Union, List, Optional, Tuple
import warnings
import faster_whisper
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
import gradio as gr


class SileroVAD:
    def __init__(self):
        self.sampling_rate = 16000
        self.window_size_samples = 512
        self.model = None

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            vad_parameters: VadOptions,
            progress: gr.Progress = gr.Progress()
            ) -> Tuple[np.ndarray, List[dict]]:
        """
        Run VAD

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        vad_parameters:
            Options for VAD processing.
        progress: gr.Progress
            Indicator to show progress directly in gradio.

        Returns
        ----------
        np.ndarray
            Pre-processed audio with VAD
        List[dict]
            Chunks of speeches to be used to restore the timestamps later
        """

        sampling_rate = self.sampling_rate

        if not isinstance(audio, np.ndarray):
            audio = faster_whisper.decode_audio(audio, sampling_rate=sampling_rate)

        duration = audio.shape[0] / sampling_rate
        duration_after_vad = duration

        if vad_parameters is None:
            vad_parameters = VadOptions()
        elif isinstance(vad_parameters, dict):
            vad_parameters = VadOptions(**vad_parameters)
        speech_chunks = self.get_speech_timestamps(
            audio=audio,
            vad_options=vad_parameters,
            progress=progress
        )
        audio = self.collect_chunks(audio, speech_chunks)
        duration_after_vad = audio.shape[0] / sampling_rate

        return audio, speech_chunks

    def get_speech_timestamps(
            self,
            audio: np.ndarray,
            vad_options: Optional[VadOptions] = None,
            progress: gr.Progress = gr.Progress(),
            **kwargs,
    ) -> List[dict]:
        """This method is used for splitting long audios into speech chunks using silero VAD.

        Args:
            audio: One dimensional float array.
            vad_options: Options for VAD processing.
            kwargs: VAD options passed as keyword arguments for backward compatibility.
            progress: Gradio progress to indicate progress.

        Returns:
            List of dicts containing begin and end samples of each speech chunk.
        """

        if self.model is None:
            self.update_model()

        if vad_options is None:
            vad_options = VadOptions(**kwargs)

        threshold = vad_options.threshold
        min_speech_duration_ms = vad_options.min_speech_duration_ms
        max_speech_duration_s = vad_options.max_speech_duration_s
        min_silence_duration_ms = vad_options.min_silence_duration_ms
        window_size_samples = self.window_size_samples
        speech_pad_ms = vad_options.speech_pad_ms
        sampling_rate = 16000
        min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
        speech_pad_samples = sampling_rate * speech_pad_ms / 1000
        max_speech_samples = (
            sampling_rate * max_speech_duration_s
            - window_size_samples
            - 2 * speech_pad_samples
        )
        min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
        min_silence_samples_at_max_speech = sampling_rate * 98 / 1000

        audio_length_samples = len(audio)

        state, context = self.model.get_initial_states(batch_size=1)

        speech_probs = []
        for current_start_sample in range(0, audio_length_samples, window_size_samples):
            progress(current_start_sample / audio_length_samples, desc="Detecting speeches only using VAD...")

            chunk = audio[current_start_sample: current_start_sample + window_size_samples]
            if len(chunk) < window_size_samples:
                chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
            speech_prob, state, context = self.model(chunk, state, context, sampling_rate)
            speech_probs.append(speech_prob)

        triggered = False
        speeches = []
        current_speech = {}
        neg_threshold = threshold - 0.15

        # to save potential segment end (and tolerate some silence)
        temp_end = 0
        # to save potential segment limits in case of maximum segment size reached
        prev_end = next_start = 0

        for i, speech_prob in enumerate(speech_probs):
            if (speech_prob >= threshold) and temp_end:
                temp_end = 0
                if next_start < prev_end:
                    next_start = window_size_samples * i

            if (speech_prob >= threshold) and not triggered:
                triggered = True
                current_speech["start"] = window_size_samples * i
                continue

            if (
                triggered
                and (window_size_samples * i) - current_speech["start"] > max_speech_samples
            ):
                if prev_end:
                    current_speech["end"] = prev_end
                    speeches.append(current_speech)
                    current_speech = {}
                    # previously reached silence (< neg_thres) and is still not speech (< thres)
                    if next_start < prev_end:
                        triggered = False
                    else:
                        current_speech["start"] = next_start
                    prev_end = next_start = temp_end = 0
                else:
                    current_speech["end"] = window_size_samples * i
                    speeches.append(current_speech)
                    current_speech = {}
                    prev_end = next_start = temp_end = 0
                    triggered = False
                    continue

            if (speech_prob < neg_threshold) and triggered:
                if not temp_end:
                    temp_end = window_size_samples * i
                # condition to avoid cutting in very short silence
                if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
                    prev_end = temp_end
                if (window_size_samples * i) - temp_end < min_silence_samples:
                    continue
                else:
                    current_speech["end"] = temp_end
                    if (
                        current_speech["end"] - current_speech["start"]
                    ) > min_speech_samples:
                        speeches.append(current_speech)
                    current_speech = {}
                    prev_end = next_start = temp_end = 0
                    triggered = False
                    continue

        if (
            current_speech
            and (audio_length_samples - current_speech["start"]) > min_speech_samples
        ):
            current_speech["end"] = audio_length_samples
            speeches.append(current_speech)

        for i, speech in enumerate(speeches):
            if i == 0:
                speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
            if i != len(speeches) - 1:
                silence_duration = speeches[i + 1]["start"] - speech["end"]
                if silence_duration < 2 * speech_pad_samples:
                    speech["end"] += int(silence_duration // 2)
                    speeches[i + 1]["start"] = int(
                        max(0, speeches[i + 1]["start"] - silence_duration // 2)
                    )
                else:
                    speech["end"] = int(
                        min(audio_length_samples, speech["end"] + speech_pad_samples)
                    )
                    speeches[i + 1]["start"] = int(
                        max(0, speeches[i + 1]["start"] - speech_pad_samples)
                    )
            else:
                speech["end"] = int(
                    min(audio_length_samples, speech["end"] + speech_pad_samples)
                )

        return speeches

    def update_model(self):
        self.model = get_vad_model()

    @staticmethod
    def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
        """Collects and concatenates audio chunks."""
        if not chunks:
            return np.array([], dtype=np.float32)

        return np.concatenate([audio[chunk["start"]: chunk["end"]] for chunk in chunks])

    @staticmethod
    def format_timestamp(
        seconds: float,
        always_include_hours: bool = False,
        decimal_marker: str = ".",
    ) -> str:
        assert seconds >= 0, "non-negative timestamp expected"
        milliseconds = round(seconds * 1000.0)

        hours = milliseconds // 3_600_000
        milliseconds -= hours * 3_600_000

        minutes = milliseconds // 60_000
        milliseconds -= minutes * 60_000

        seconds = milliseconds // 1_000
        milliseconds -= seconds * 1_000

        hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
        return (
            f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
        )

    def restore_speech_timestamps(
            self,
            segments: List[dict],
            speech_chunks: List[dict],
            sampling_rate: Optional[int] = None,
    ) -> List[dict]:
        if sampling_rate is None:
            sampling_rate = self.sampling_rate

        ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

        for segment in segments:
            segment["start"] = ts_map.get_original_time(segment["start"])
            segment["end"] = ts_map.get_original_time(segment["end"])

        return segments
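A minimal sketch of running the VAD on its own (not part of the commit; the audio path is hypothetical and the VadOptions values are only illustrative):

from faster_whisper.vad import VadOptions
from modules.vad.silero_vad import SileroVAD

vad = SileroVAD()
options = VadOptions(threshold=0.5, min_speech_duration_ms=250, speech_pad_ms=400)

# Returns speech-only audio plus the chunk boundaries that
# restore_speech_timestamps() uses later to map times back to the original audio.
speech_audio, speech_chunks = vad.run("sample.wav", vad_parameters=options)   # hypothetical file
# Note: the default gr.Progress() callback is only meaningful inside a Gradio event.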
modules/whisper/__init__.py
ADDED
File without changes
modules/whisper/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (175 Bytes). View file
modules/whisper/__pycache__/faster_whisper_inference.cpython-310.pyc
ADDED
Binary file (6.51 kB). View file
modules/whisper/__pycache__/whisper_base.cpython-310.pyc
ADDED
Binary file (12.9 kB). View file
modules/whisper/__pycache__/whisper_factory.cpython-310.pyc
ADDED
Binary file (2.87 kB). View file
modules/whisper/__pycache__/whisper_parameter.cpython-310.pyc
ADDED
Binary file (3.68 kB). View file
modules/whisper/faster_whisper_inference.py
ADDED
@@ -0,0 +1,191 @@
import os
import time
import numpy as np
import torch
from typing import BinaryIO, Union, Tuple, List
import faster_whisper
from faster_whisper.vad import VadOptions
import ast
import ctranslate2
import whisper
import gradio as gr
from argparse import Namespace

from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase


class FasterWhisperInference(WhisperBase):
    def __init__(self,
                 model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
                 diarization_model_dir: str = os.path.join("models", "Diarization"),
                 output_dir: str = os.path.join("outputs"),
                 ):
        super().__init__(
            model_dir=model_dir,
            diarization_model_dir=diarization_model_dir,
            output_dir=output_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        self.model_paths = self.get_model_paths()
        self.device = self.get_device()
        self.available_models = self.model_paths.keys()
        self.available_compute_types = ctranslate2.get_supported_compute_types(
            "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")

    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for faster-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()

        params = WhisperParameters.as_value(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
        if not params.initial_prompt:
            params.initial_prompt = None
        if not params.prefix:
            params.prefix = None
        if not params.hotwords:
            params.hotwords = None

        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)

        segments, info = self.model.transcribe(
            audio=audio,
            language=params.lang,
            task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
            beam_size=params.beam_size,
            log_prob_threshold=params.log_prob_threshold,
            no_speech_threshold=params.no_speech_threshold,
            best_of=params.best_of,
            patience=params.patience,
            temperature=params.temperature,
            initial_prompt=params.initial_prompt,
            compression_ratio_threshold=params.compression_ratio_threshold,
            length_penalty=params.length_penalty,
            repetition_penalty=params.repetition_penalty,
            no_repeat_ngram_size=params.no_repeat_ngram_size,
            prefix=params.prefix,
            suppress_blank=params.suppress_blank,
            suppress_tokens=params.suppress_tokens,
            max_initial_timestamp=params.max_initial_timestamp,
            word_timestamps=params.word_timestamps,
            prepend_punctuations=params.prepend_punctuations,
            append_punctuations=params.append_punctuations,
            max_new_tokens=params.max_new_tokens,
            chunk_length=params.chunk_length,
            hallucination_silence_threshold=params.hallucination_silence_threshold,
            hotwords=params.hotwords,
            language_detection_threshold=params.language_detection_threshold,
            language_detection_segments=params.language_detection_segments,
            prompt_reset_on_temperature=params.prompt_reset_on_temperature,
        )
        progress(0, desc="Loading audio..")

        segments_result = []
        for segment in segments:
            progress(segment.start / info.duration, desc="Transcribing..")
            segments_result.append({
                "start": segment.start,
                "end": segment.end,
                "text": segment.text
            })

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_model_size = self.model_paths[model_size]
        self.current_compute_type = compute_type
        self.model = faster_whisper.WhisperModel(
            device=self.device,
            model_size_or_path=self.current_model_size,
            download_root=self.model_dir,
            compute_type=self.current_compute_type
        )

    def get_model_paths(self):
        """
        Get available models from models path including fine-tuned model.

        Returns
        ----------
        Name list of models
        """
        model_paths = {model: model for model in whisper.available_models()}
        faster_whisper_prefix = "models--Systran--faster-whisper-"

        existing_models = os.listdir(self.model_dir)
        wrong_dirs = [".locks"]
        existing_models = list(set(existing_models) - set(wrong_dirs))

        webui_dir = os.getcwd()

        for model_name in existing_models:
            if faster_whisper_prefix in model_name:
                model_name = model_name[len(faster_whisper_prefix):]

            if model_name not in whisper.available_models():
                model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
        return model_paths

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        else:
            return "auto"

    @staticmethod
    def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
        try:
            suppress_tokens = ast.literal_eval(suppress_tokens_str)
            if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
                raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
            return suppress_tokens
        except Exception as e:
            raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
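A small sketch of the suppress-tokens parsing above (not part of the commit):

from modules.whisper.faster_whisper_inference import FasterWhisperInference

FasterWhisperInference.format_suppress_tokens_str("[-1]")        # -> [-1]
FasterWhisperInference.format_suppress_tokens_str("not a list")  # raises ValueError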
modules/whisper/insanely_fast_whisper_inference.py
ADDED
@@ -0,0 +1,185 @@
import os
import time
import numpy as np
from typing import BinaryIO, Union, Tuple, List
import torch
from transformers import pipeline
from transformers.utils import is_flash_attn_2_available
import gradio as gr
from huggingface_hub import hf_hub_download
import whisper
from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
from argparse import Namespace

from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase


class InsanelyFastWhisperInference(WhisperBase):
    def __init__(self,
                 model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
                 diarization_model_dir: str = os.path.join("models", "Diarization"),
                 output_dir: str = os.path.join("outputs"),
                 ):
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir
        )
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)

        openai_models = whisper.available_models()
        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
        self.available_models = openai_models + distil_models
        self.available_compute_types = ["float16"]

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for insanely-fast-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperParameters.as_value(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        progress(0, desc="Transcribing...Progress is not shown in insanely-fast-whisper.")
        with Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(style="yellow1", pulse_style="white"),
            TimeElapsedColumn(),
        ) as progress:
            progress.add_task("[yellow]Transcribing...", total=None)

            segments = self.model(
                inputs=audio,
                return_timestamps=True,
                chunk_length_s=params.chunk_length_s,
                batch_size=params.batch_size,
                generate_kwargs={
                    "language": params.lang,
                    "task": "translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
                    "no_speech_threshold": params.no_speech_threshold,
                    "temperature": params.temperature,
                    "compression_ratio_threshold": params.compression_ratio_threshold
                }
            )

        segments_result = self.format_result(
            transcribed_result=segments,
        )
        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress,
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        model_path = os.path.join(self.model_dir, model_size)
        if not os.path.isdir(model_path) or not os.listdir(model_path):
            self.download_model(
                model_size=model_size,
                download_root=model_path,
                progress=progress
            )

        self.current_compute_type = compute_type
        self.current_model_size = model_size
        self.model = pipeline(
            "automatic-speech-recognition",
            model=os.path.join(self.model_dir, model_size),
            torch_dtype=self.current_compute_type,
            device=self.device,
            model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
        )

    @staticmethod
    def format_result(
            transcribed_result: dict
    ) -> List[dict]:
        """
        Format the transcription result of insanely_fast_whisper to match the other implementations.

        Parameters
        ----------
        transcribed_result: dict
            Transcription result of the insanely_fast_whisper

        Returns
        ----------
        result: List[dict]
            Formatted result matching the other implementations
        """
        result = transcribed_result["chunks"]
        for item in result:
            start, end = item["timestamp"][0], item["timestamp"][1]
            if end is None:
                end = start
            item["start"] = start
            item["end"] = end
        return result

    @staticmethod
    def download_model(
            model_size: str,
            download_root: str,
            progress: gr.Progress
    ):
        progress(0, 'Initializing model..')
        print(f'Downloading {model_size} to "{download_root}"....')

        os.makedirs(download_root, exist_ok=True)
        download_list = [
            "model.safetensors",
            "config.json",
            "generation_config.json",
            "preprocessor_config.json",
            "tokenizer.json",
            "tokenizer_config.json",
            "added_tokens.json",
            "special_tokens_map.json",
            "vocab.json",
        ]

        if model_size.startswith("distil"):
            repo_id = f"distil-whisper/{model_size}"
        else:
            repo_id = f"openai/whisper-{model_size}"
        for item in download_list:
            hf_hub_download(repo_id=repo_id, filename=item, local_dir=download_root)
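A sketch of what format_result does to the Hugging Face pipeline output (not part of the commit; the chunk values are made up):

from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference

pipeline_output = {
    "chunks": [
        {"timestamp": (0.0, 3.2), "text": " Hello."},
        {"timestamp": (3.2, None), "text": " Trailing chunk without an end time."},
    ]
}

segments = InsanelyFastWhisperInference.format_result(pipeline_output)
# Each chunk now also has "start" and "end" keys; a missing end falls back to the start.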
modules/whisper/whisper_Inference.py
ADDED
@@ -0,0 +1,101 @@
import whisper
import gradio as gr
import time
from typing import BinaryIO, Union, Tuple, List
import numpy as np
import torch
import os
from argparse import Namespace

from modules.whisper.whisper_base import WhisperBase
from modules.whisper.whisper_parameter import *


class WhisperInference(WhisperBase):
    def __init__(self,
                 model_dir: str = os.path.join("models", "Whisper"),
                 diarization_model_dir: str = os.path.join("models", "Diarization"),
                 output_dir: str = os.path.join("outputs"),
                 ):
        super().__init__(
            model_dir=model_dir,
            output_dir=output_dir,
            diarization_model_dir=diarization_model_dir
        )

    def transcribe(self,
                   audio: Union[str, np.ndarray, torch.Tensor],
                   progress: gr.Progress,
                   *whisper_params,
                   ) -> Tuple[List[dict], float]:
        """
        transcribe method for openai-whisper.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio path or file binary or Audio numpy array
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for transcription
        """
        start_time = time.time()
        params = WhisperParameters.as_value(*whisper_params)

        if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
            self.update_model(params.model_size, params.compute_type, progress)

        def progress_callback(progress_value):
            progress(progress_value, desc="Transcribing..")

        segments_result = self.model.transcribe(audio=audio,
                                                language=params.lang,
                                                verbose=False,
                                                beam_size=params.beam_size,
                                                logprob_threshold=params.log_prob_threshold,
                                                no_speech_threshold=params.no_speech_threshold,
                                                task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
                                                fp16=True if params.compute_type == "float16" else False,
                                                best_of=params.best_of,
                                                patience=params.patience,
                                                temperature=params.temperature,
                                                compression_ratio_threshold=params.compression_ratio_threshold,
                                                progress_callback=progress_callback,)["segments"]
        elapsed_time = time.time() - start_time

        return segments_result, elapsed_time

    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress,
                     ):
        """
        Update current model setting

        Parameters
        ----------
        model_size: str
            Size of whisper model
        compute_type: str
            Compute type for transcription.
            see more info : https://opennmt.net/CTranslate2/quantization.html
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        """
        progress(0, desc="Initializing Model..")
        self.current_compute_type = compute_type
        self.current_model_size = model_size
        self.model = whisper.load_model(
            name=model_size,
            device=self.device,
            download_root=self.model_dir
        )
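All three implementations share the same segment schema and the same task selection. A condensed, hypothetical helper showing that selection logic (not part of the commit; the translatable sizes come from WhisperBase below):

translatable_models = ["large", "large-v1", "large-v2", "large-v3"]

def select_task(is_translate: bool, model_size: str) -> str:
    # "translate" is only used when translation was requested AND the model supports it.
    return "translate" if is_translate and model_size in translatable_models else "transcribe"

select_task(True, "large-v2")   # "translate"
select_task(True, "base")       # "transcribe"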
modules/whisper/whisper_base.py
ADDED
@@ -0,0 +1,436 @@
import os
import torch
import whisper
import gradio as gr
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple

from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD


class WhisperBase(ABC):
    def __init__(self,
                 model_dir: str = os.path.join("models", "Whisper"),
                 diarization_model_dir: str = os.path.join("models", "Diarization"),
                 output_dir: str = os.path.join("outputs"),
                 ):
        self.model_dir = model_dir
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
        os.makedirs(self.model_dir, exist_ok=True)
        self.diarizer = Diarizer(
            model_dir=diarization_model_dir
        )
        self.vad = SileroVAD()

        self.model = None
        self.current_model_size = None
        self.available_models = whisper.available_models()
        self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
        self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
        self.device = self.get_device()
        self.available_compute_types = ["float16", "float32"]
        self.current_compute_type = "float16" if self.device == "cuda" else "float32"

    @abstractmethod
    def transcribe(self,
                   audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress,
                   *whisper_params,
                   ):
        """Inference whisper model to transcribe"""
        pass

    @abstractmethod
    def update_model(self,
                     model_size: str,
                     compute_type: str,
                     progress: gr.Progress
                     ):
        """Initialize whisper model"""
        pass

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            progress: gr.Progress,
            *whisper_params,
            ) -> Tuple[List[dict], float]:
        """
        Run transcription with conditional pre-processing and post-processing.
        The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
        The diarization will be performed in post-processing, if enabled.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be file path or binary type.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        segments_result: List[dict]
            list of dicts that includes start, end timestamps and transcribed text
        elapsed_time: float
            elapsed time for running
        """
        params = WhisperParameters.as_value(*whisper_params)

        if params.lang == "Automatic Detection":
            params.lang = None
        else:
            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
            params.lang = language_code_dict[params.lang]

        speech_chunks = None
        if params.vad_filter:
            # Explicit value set for float('inf') from gr.Number()
            if params.max_speech_duration_s >= 9999:
                params.max_speech_duration_s = float('inf')

            vad_options = VadOptions(
                threshold=params.threshold,
                min_speech_duration_ms=params.min_speech_duration_ms,
                max_speech_duration_s=params.max_speech_duration_s,
                min_silence_duration_ms=params.min_silence_duration_ms,
                speech_pad_ms=params.speech_pad_ms
            )

            audio, speech_chunks = self.vad.run(
                audio=audio,
                vad_parameters=vad_options,
                progress=progress
            )

        result, elapsed_time = self.transcribe(
            audio,
            progress,
            *astuple(params)
        )

        if params.vad_filter:
            result = self.vad.restore_speech_timestamps(
                segments=result,
                speech_chunks=speech_chunks,
            )

        if params.is_diarize:
            result, elapsed_time_diarization = self.diarizer.run(
                audio=audio,
                use_auth_token=params.hf_token,
                transcribed_result=result,
            )
            elapsed_time += elapsed_time_diarization
        return result, elapsed_time

    def transcribe_file(self,
                        files: list,
                        input_folder_path: str,
                        file_format: str,
                        add_timestamp: bool,
                        progress=gr.Progress(),
                        *whisper_params,
                        ) -> list:
        """
        Write subtitle file from Files

        Parameters
        ----------
        files: list
            List of files to transcribe from gr.Files()
        input_folder_path: str
            Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
            this will be used instead.
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            if input_folder_path:
                files = get_media_files(input_folder_path)
                files = format_gradio_files(files)

            files_info = {}
            for file in files:
                transcribed_segments, time_for_task = self.run(
                    file.name,
                    progress,
                    *whisper_params,
                )

                file_name, file_ext = os.path.splitext(os.path.basename(file.name))
                subtitle, file_path = self.generate_and_write_file(
                    file_name=file_name,
                    transcribed_segments=transcribed_segments,
                    add_timestamp=add_timestamp,
                    file_format=file_format,
                    output_dir=self.output_dir
                )
                files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}

            total_result = ''
            total_time = 0
            for file_name, info in files_info.items():
                total_result += '------------------------------------\n'
                total_result += f'{file_name}\n\n'
                total_result += f'{info["subtitle"]}'
                total_time += info["time_for_task"]

            result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
            result_file_path = [info['path'] for info in files_info.values()]

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()
            if not files:
                self.remove_input_files([file.name for file in files])

    def transcribe_mic(self,
                       mic_audio: str,
                       file_format: str,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
        """
        Write subtitle file from microphone

        Parameters
        ----------
        mic_audio: str
            Audio file path from gr.Microphone()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio..")
            transcribed_segments, time_for_task = self.run(
                mic_audio,
                progress,
                *whisper_params,
            )
            progress(1, desc="Completed!")

            subtitle, result_file_path = self.generate_and_write_file(
                file_name="Mic",
                transcribed_segments=transcribed_segments,
                add_timestamp=True,
                file_format=file_format,
                output_dir=self.output_dir
            )

            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
            return [result_str, result_file_path]
        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            self.release_cuda_memory()
            self.remove_input_files([mic_audio])

    def transcribe_youtube(self,
                           youtube_link: str,
                           file_format: str,
                           add_timestamp: bool,
                           progress=gr.Progress(),
                           *whisper_params,
                           ) -> list:
        """
        Write subtitle file from Youtube

        Parameters
        ----------
        youtube_link: str
            URL of the Youtube video to transcribe from gr.Textbox()
        file_format: str
            Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
        add_timestamp: bool
            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.
        *whisper_params: tuple
            Parameters related with whisper. This will be dealt with "WhisperParameters" data class

        Returns
        ----------
        result_str:
            Result of transcription to return to gr.Textbox()
        result_file_path:
            Output file path to return to gr.Files()
        """
        try:
            progress(0, desc="Loading Audio from Youtube..")
            yt = get_ytdata(youtube_link)
            audio = get_ytaudio(yt)

            transcribed_segments, time_for_task = self.run(
                audio,
                progress,
                *whisper_params,
            )

            progress(1, desc="Completed!")

            file_name = safe_filename(yt.title)
            subtitle, result_file_path = self.generate_and_write_file(
                file_name=file_name,
                transcribed_segments=transcribed_segments,
                add_timestamp=add_timestamp,
                file_format=file_format,
                output_dir=self.output_dir
            )
            result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"

            return [result_str, result_file_path]

        except Exception as e:
            print(f"Error transcribing file: {e}")
        finally:
            try:
                if 'yt' not in locals():
                    yt = get_ytdata(youtube_link)
                    file_path = get_ytaudio(yt)
                else:
                    file_path = get_ytaudio(yt)

                self.release_cuda_memory()
                self.remove_input_files([file_path])
            except Exception as cleanup_error:
                pass

    @staticmethod
    def generate_and_write_file(file_name: str,
                                transcribed_segments: list,
                                add_timestamp: bool,
                                file_format: str,
                                output_dir: str
                                ) -> Tuple[str, str]:
        """
        Writes subtitle file

        Parameters
        ----------
        file_name: str
            Output file name
        transcribed_segments: list
            Text segments transcribed from audio
        add_timestamp: bool
            Determines whether to add a timestamp to the end of the filename.
        file_format: str
            File format to write. Supported formats: [SRT, WebVTT, txt]
        output_dir: str
            Directory path of the output

        Returns
        ----------
        content: str
            Result of the transcription
        output_path: str
            output file path
        """
        if add_timestamp:
            timestamp = datetime.now().strftime("%m%d%H%M%S")
            output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
        else:
            output_path = os.path.join(output_dir, f"{file_name}")

        if file_format == "SRT":
            content = get_srt(transcribed_segments)
            output_path += '.srt'

        elif file_format == "WebVTT":
            content = get_vtt(transcribed_segments)
            output_path += '.vtt'

        elif file_format == "txt":
            content = get_txt(transcribed_segments)
            output_path += '.txt'

        write_file(content, output_path)
        return content, output_path

    @staticmethod
    def format_time(elapsed_time: float) -> str:
        """
        Get {hours} {minutes} {seconds} time format string

        Parameters
        ----------
        elapsed_time: float
            Elapsed time for transcription

        Returns
        ----------
        Time format string
        """
        hours, rem = divmod(elapsed_time, 3600)
        minutes, seconds = divmod(rem, 60)

        time_str = ""
        if hours:
            time_str += f"{hours} hours "
        if minutes:
            time_str += f"{minutes} minutes "
        seconds = round(seconds)
        time_str += f"{seconds} seconds"

        return time_str.strip()

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def release_cuda_memory():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated()

    @staticmethod
    def remove_input_files(file_paths: List[str]):
        if not file_paths:
            return

        for file_path in file_paths:
            if file_path and os.path.exists(file_path):
                os.remove(file_path)
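A small sketch of the format_time helper above (not part of the commit). Non-zero hours and minutes come out of divmod as floats, so they print with a decimal point:

from modules.whisper.whisper_base import WhisperBase

WhisperBase.format_time(125.0)   # "2.0 minutes 5 seconds"
WhisperBase.format_time(42.3)    # "42 seconds"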
modules/whisper/whisper_factory.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional
import os

from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.whisper.whisper_base import WhisperBase


class WhisperFactory:
    @staticmethod
    def create_whisper_inference(
        whisper_type: str,
        whisper_model_dir: str = os.path.join("models", "Whisper"),
        faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
        insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
        diarization_model_dir: str = os.path.join("models", "Diarization"),
        output_dir: str = os.path.join("outputs"),
    ) -> "WhisperBase":
        """
        Create a whisper inference class based on the provided whisper_type.

        Parameters
        ----------
        whisper_type : str
            The type of Whisper implementation to use. Supported values (case-insensitive):
            - "faster-whisper": https://github.com/SYSTRAN/faster-whisper
            - "whisper": https://github.com/openai/whisper
            - "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper
        whisper_model_dir : str
            Directory path for the Whisper model.
        faster_whisper_model_dir : str
            Directory path for the Faster Whisper model.
        insanely_fast_whisper_model_dir : str
            Directory path for the Insanely Fast Whisper model.
        diarization_model_dir : str
            Directory path for the diarization model.
        output_dir : str
            Directory path where output files will be saved.

        Returns
        -------
        WhisperBase
            An instance of the appropriate whisper inference class based on the whisper_type.
        """
        # Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

        whisper_type = whisper_type.lower().strip()

        faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
        whisper_typos = ["whisper"]
        insanely_fast_whisper_typos = [
            "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
            "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
        ]

        if whisper_type in faster_whisper_typos:
            return FasterWhisperInference(
                model_dir=faster_whisper_model_dir,
                output_dir=output_dir,
                diarization_model_dir=diarization_model_dir
            )
        elif whisper_type in whisper_typos:
            return WhisperInference(
                model_dir=whisper_model_dir,
                output_dir=output_dir,
                diarization_model_dir=diarization_model_dir
            )
        elif whisper_type in insanely_fast_whisper_typos:
            return InsanelyFastWhisperInference(
                model_dir=insanely_fast_whisper_model_dir,
                output_dir=output_dir,
                diarization_model_dir=diarization_model_dir
            )
        else:
            return FasterWhisperInference(
                model_dir=faster_whisper_model_dir,
                output_dir=output_dir,
                diarization_model_dir=diarization_model_dir
            )
|
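A minimal sketch of how the factory above might be called; the backend string and output directory are illustrative. The string is normalized with .lower().strip() and matched against the typo lists, and any unrecognized value falls back to FasterWhisperInference.

# Sketch only, based on the factory code above.
from modules.whisper.whisper_factory import WhisperFactory

whisper_inferencer = WhisperFactory.create_whisper_inference(
    whisper_type="Faster-Whisper",   # normalized to "faster-whisper"
    output_dir="outputs",
)
print(type(whisper_inferencer).__name__)   # FasterWhisperInference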
modules/whisper/whisper_parameter.py
ADDED
|
@@ -0,0 +1,277 @@
|
from dataclasses import dataclass, fields
import gradio as gr
from typing import Optional


@dataclass
class WhisperParameters:
    model_size: gr.Dropdown
    lang: gr.Dropdown
    is_translate: gr.Checkbox
    beam_size: gr.Number
    log_prob_threshold: gr.Number
    no_speech_threshold: gr.Number
    compute_type: gr.Dropdown
    best_of: gr.Number
    patience: gr.Number
    condition_on_previous_text: gr.Checkbox
    prompt_reset_on_temperature: gr.Slider
    initial_prompt: gr.Textbox
    temperature: gr.Slider
    compression_ratio_threshold: gr.Number
    vad_filter: gr.Checkbox
    threshold: gr.Slider
    min_speech_duration_ms: gr.Number
    max_speech_duration_s: gr.Number
    min_silence_duration_ms: gr.Number
    speech_pad_ms: gr.Number
    chunk_length_s: gr.Number
    batch_size: gr.Number
    is_diarize: gr.Checkbox
    hf_token: gr.Textbox
    diarization_device: gr.Dropdown
    length_penalty: gr.Number
    repetition_penalty: gr.Number
    no_repeat_ngram_size: gr.Number
    prefix: gr.Textbox
    suppress_blank: gr.Checkbox
    suppress_tokens: gr.Textbox
    max_initial_timestamp: gr.Number
    word_timestamps: gr.Checkbox
    prepend_punctuations: gr.Textbox
    append_punctuations: gr.Textbox
    max_new_tokens: gr.Number
    chunk_length: gr.Number
    hallucination_silence_threshold: gr.Number
    hotwords: gr.Textbox
    language_detection_threshold: gr.Number
    language_detection_segments: gr.Number
    """
    A data class for Gradio components of the Whisper parameters. Use it "before" Gradio pre-processing.
    This data class is used to mitigate the key-value problem between Gradio components and function parameters.
    Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
    See more about Gradio pre-processing: https://www.gradio.app/docs/components

    Attributes
    ----------
    model_size: gr.Dropdown
        Whisper model size.

    lang: gr.Dropdown
        Source language of the file to transcribe.

    is_translate: gr.Checkbox
        Boolean value that determines whether to translate to English.
        It's Whisper's feature to translate speech from another language directly into English end-to-end.

    beam_size: gr.Number
        Int value that is used as a decoding option.

    log_prob_threshold: gr.Number
        If the average log probability over sampled tokens is below this value, treat as failed.

    no_speech_threshold: gr.Number
        If the no_speech probability is higher than this value AND
        the average log probability over sampled tokens is below `log_prob_threshold`,
        consider the segment as silent.

    compute_type: gr.Dropdown
        Compute type for transcription.
        See more info: https://opennmt.net/CTranslate2/quantization.html

    best_of: gr.Number
        Number of candidates when sampling with non-zero temperature.

    patience: gr.Number
        Beam search patience factor.

    condition_on_previous_text: gr.Checkbox
        If True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    initial_prompt: gr.Textbox
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.

    temperature: gr.Slider
        Temperature for sampling. It can be a tuple of temperatures,
        which will be successively used upon failures according to either
        `compression_ratio_threshold` or `log_prob_threshold`.

    compression_ratio_threshold: gr.Number
        If the gzip compression ratio is above this value, treat as failed.

    vad_filter: gr.Checkbox
        Enable the voice activity detection (VAD) to filter out parts of the audio
        without speech. This step uses the Silero VAD model:
        https://github.com/snakers4/silero-vad.

    threshold: gr.Slider
        This parameter is related to Silero VAD. Speech threshold.
        Silero VAD outputs speech probabilities for each audio chunk;
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    min_speech_duration_ms: gr.Number
        This parameter is related to Silero VAD. Final speech chunks shorter than min_speech_duration_ms are thrown out.

    max_speech_duration_s: gr.Number
        This parameter is related to Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: gr.Number
        This parameter is related to Silero VAD. At the end of each speech chunk, wait for min_silence_duration_ms
        before separating it.

    speech_pad_ms: gr.Number
        This parameter is related to Silero VAD. Final speech chunks are padded by speech_pad_ms on each side.

    chunk_length_s: gr.Number
        This parameter is related to the insanely-fast-whisper pipe.
        Maximum length of each chunk.

    batch_size: gr.Number
        This parameter is related to the insanely-fast-whisper pipe. Batch size to pass to the pipe.

    is_diarize: gr.Checkbox
        This parameter is related to whisperx. Boolean value that determines whether to diarize or not.

    hf_token: gr.Textbox
        This parameter is related to whisperx. A Hugging Face token is needed to download diarization models.
        Read more at: https://huggingface.co/pyannote/speaker-diarization-3.1#requirements

    diarization_device: gr.Dropdown
        This parameter is related to whisperx. Device to run the diarization model on.

    length_penalty:
        This parameter is related to faster-whisper. Exponential length penalty constant.

    repetition_penalty:
        This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).

    no_repeat_ngram_size:
        This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).

    prefix:
        This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.

    suppress_blank:
        This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.

    suppress_tokens:
        This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
        of symbols as defined in the model config.json file.

    max_initial_timestamp:
        This parameter is related to faster-whisper. The initial timestamp cannot be later than this.

    word_timestamps:
        This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
        and dynamic time warping, and include the timestamps for each word in each segment.

    prepend_punctuations:
        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
        with the next word.

    append_punctuations:
        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
        with the previous word.

    max_new_tokens:
        This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
        the maximum will be set by the default max_length.

    chunk_length:
        This parameter is related to faster-whisper. The length of audio segments. If it is not None, it will overwrite the
        default chunk_length of the FeatureExtractor.

    hallucination_silence_threshold:
        This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
        (in seconds) when a possible hallucination is detected.

    hotwords:
        This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.

    language_detection_threshold:
        This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.

    language_detection_segments:
        This parameter is related to faster-whisper. Number of segments to consider for the language detection.
    """

    def as_list(self) -> list:
        """
        Converts the data class attributes into a list. Use in the Gradio UI before Gradio pre-processing.
        See more about Gradio pre-processing: https://www.gradio.app/docs/components

        Returns
        ----------
        A list of Gradio components
        """
        return [getattr(self, f.name) for f in fields(self)]

    @staticmethod
    def as_value(*args) -> 'WhisperValues':
        """
        To use Whisper parameters in a function after Gradio post-processing.
        See more about Gradio post-processing: https://www.gradio.app/docs/components

        Returns
        ----------
        WhisperValues
            Data class that has the values of the parameters
        """
        return WhisperValues(*args)


@dataclass
class WhisperValues:
    model_size: str
    lang: str
    is_translate: bool
    beam_size: int
    log_prob_threshold: float
    no_speech_threshold: float
    compute_type: str
    best_of: int
    patience: float
    condition_on_previous_text: bool
    prompt_reset_on_temperature: float
    initial_prompt: Optional[str]
    temperature: float
    compression_ratio_threshold: float
    vad_filter: bool
    threshold: float
    min_speech_duration_ms: int
    max_speech_duration_s: float
    min_silence_duration_ms: int
    speech_pad_ms: int
    chunk_length_s: int
    batch_size: int
    is_diarize: bool
    hf_token: str
    diarization_device: str
    length_penalty: float
    repetition_penalty: float
    no_repeat_ngram_size: int
    prefix: Optional[str]
    suppress_blank: bool
    suppress_tokens: Optional[str]
    max_initial_timestamp: float
    word_timestamps: bool
    prepend_punctuations: Optional[str]
    append_punctuations: Optional[str]
    max_new_tokens: Optional[int]
    chunk_length: Optional[int]
    hallucination_silence_threshold: Optional[float]
    hotwords: Optional[str]
    language_detection_threshold: Optional[float]
    language_detection_segments: int
    """
    A data class to use Whisper parameters.
    """
|
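A minimal sketch of the positional contract the docstring describes: as_list() emits the Gradio components in field order, Gradio passes their values back in that same order, and as_value(*args) packs them into WhisperValues. The placeholder values below are assumptions for illustration, not realistic settings.

# Sketch only: exercises the field-order contract without a real Gradio app.
from dataclasses import fields
from modules.whisper.whisper_parameter import WhisperParameters, WhisperValues

names_in_order = [f.name for f in fields(WhisperValues)]
print(names_in_order[:3])                   # ['model_size', 'lang', 'is_translate']

# One placeholder value per field, in the same order Gradio would return them.
placeholder_args = [None] * len(names_in_order)
values = WhisperParameters.as_value(*placeholder_args)
print(values.model_size, values.beam_size)  # None None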
outputs/outputs are saved here.txt
ADDED
|
File without changes
|
outputs/translations/outputs for translation are saved here.txt
ADDED
|
File without changes
|