# Demo / modules / whisper / whisper_base.py
# LAP-DEV's picture
# Update modules/whisper/whisper_base.py
# 6e270b6 verified
import os
import torch
import whisper
import gradio as gr
import torchaudio
from abc import ABC, abstractmethod
from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple
import gc
from copy import deepcopy
from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR)
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_csv, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
from modules.translation.nllb_inference import NLLBInference
from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
class WhisperBase(ABC):
def __init__(self,
             model_dir: str = WHISPER_MODELS_DIR,
             diarization_model_dir: str = DIARIZATION_MODELS_DIR,
             uvr_model_dir: str = UVR_MODELS_DIR,
             output_dir: str = OUTPUT_DIR,
             ):
    """
    Create the working directories, the helper components and the model bookkeeping attributes.

    Parameters
    ----------
    model_dir: str
        Directory for whisper models. Created if it does not exist.
    diarization_model_dir: str
        Model directory handed to the Diarizer.
    uvr_model_dir: str
        Model directory handed to the MusicSeparator.
    output_dir: str
        Directory for output files. Created if it does not exist; the MusicSeparator
        is given the "UVR" sub-directory of it.
    """
    self.model_dir = model_dir
    self.output_dir = output_dir
    # Same creation order as before: output directory first, then model directory.
    for directory in (self.output_dir, self.model_dir):
        os.makedirs(directory, exist_ok=True)

    # Pre/post-processing helpers.
    self.diarizer = Diarizer(model_dir=diarization_model_dir)
    self.vad = SileroVAD()
    uvr_output_dir = os.path.join(output_dir, "UVR")
    self.music_separator = MusicSeparator(
        model_dir=uvr_model_dir,
        output_dir=uvr_output_dir,
    )

    # No whisper model is loaded at construction time.
    self.model = None
    self.current_model_size = None

    self.available_models = whisper.available_models()
    self.available_langs = sorted(whisper.tokenizer.LANGUAGES.values())
    # Previously limited to ["large", "large-v1", "large-v2", "large-v3"];
    # now every available model is listed as translatable.
    self.translatable_models = whisper.available_models()

    self.device = self.get_device()
    self.available_compute_types = ["float16", "float32"]
    if self.device == "cuda":
        self.current_compute_type = "float16"
    else:
        self.current_compute_type = "float32"
@abstractmethod
def transcribe(self,
               audio: Union[str, BinaryIO, np.ndarray],
               progress: gr.Progress = gr.Progress(),
               *whisper_params,
               ):
    """
    Inference whisper model to transcribe. Must be implemented by subclasses.

    Parameters
    ----------
    audio: Union[str, BinaryIO, np.ndarray]
        Audio input to transcribe.
    progress: gr.Progress
        Indicator to show progress directly in gradio.
    *whisper_params: tuple
        Positional whisper parameters; `run()` passes `astuple(params)` here.

    Returns
    ----------
    `run()` unpacks the return value as a 2-tuple `(result, elapsed_time)`.
    """
@abstractmethod
def update_model(self,
                 model_size: str,
                 compute_type: str,
                 progress: gr.Progress = gr.Progress()
                 ):
    """
    Initialize whisper model. Must be implemented by subclasses.

    Parameters
    ----------
    model_size: str
        Name/size of the whisper model to initialize.
    compute_type: str
        Compute type to initialize the model with.
    progress: gr.Progress
        Indicator to show progress directly in gradio.
    """
def run(self,
        audio: Union[str, BinaryIO, np.ndarray],
        progress: gr.Progress = gr.Progress(),
        add_timestamp: bool = True,
        *whisper_params,
        ) -> Tuple[List[dict], float]:
    """
    Run transcription with conditional pre-processing and post-processing.
    Background-music separation and VAD are performed in pre-processing, if enabled.
    The diarization will be performed in post-processing, if enabled.

    Parameters
    ----------
    audio: Union[str, BinaryIO, np.ndarray]
        Audio input. This can be file path or binary type.
    progress: gr.Progress
        Indicator to show progress directly in gradio.
    add_timestamp: bool
        Accepted for signature compatibility with the callers; not used inside this method.
    *whisper_params: tuple
        Parameters related with whisper. This will be dealt with "WhisperParameters" data class

    Returns
    ----------
    segments_result: List[dict]
        list of dicts that includes start, end timestamps and transcribed text
    elapsed_time: float
        Total elapsed time of this call in seconds (pre-processing, transcription
        and post-processing together).
    """
    start_time = datetime.now()
    params = WhisperParameters.as_value(*whisper_params)

    # The offload switches are read from the default-parameter YAML, not from `params`.
    # Kept under their own names so the `*whisper_params` argument is not shadowed.
    default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
    whisper_enable_offload = default_params["whisper"]["enable_offload"]
    diarization_enable_offload = default_params["diarization"]["enable_offload"]

    # Map the UI language name to a whisper language code; None means auto-detect.
    if params.lang == "Automatic Detection":
        params.lang = None
    elif params.lang is not None:
        language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
        params.lang = language_code_dict[params.lang]

    if params.is_bgm_separate:
        # Only the second returned item is used as the audio to transcribe.
        _, audio, _ = self.music_separator.separate(
            audio=audio,
            model_name=params.uvr_model_size,
            device=params.uvr_device,
            segment_size=params.uvr_segment_size,
            save_file=params.uvr_save_file,
            progress=progress
        )

        # NOTE(review): block nesting was reconstructed from whitespace-stripped source —
        # confirm that the resample step belongs inside the multi-channel branch.
        if audio.ndim >= 2:
            # Average over axis 1 to get a single channel, then resample to whisper's rate.
            audio = audio.mean(axis=1)
            if self.music_separator.audio_info is None:
                origin_sample_rate = 16000
            else:
                origin_sample_rate = self.music_separator.audio_info.sample_rate
            audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)

        if params.uvr_enable_offload:
            self.music_separator.offload()

    # Diarization works on the audio as it was before VAD filtering.
    origin_audio = deepcopy(audio)

    if params.vad_filter:
        # Explicit value set for float('inf') from gr.Number()
        if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
            params.max_speech_duration_s = float('inf')

        progress(0, desc="Filtering silent parts from audio...")
        vad_options = VadOptions(
            threshold=params.threshold,
            min_speech_duration_ms=params.min_speech_duration_ms,
            max_speech_duration_s=params.max_speech_duration_s,
            min_silence_duration_ms=params.min_silence_duration_ms,
            speech_pad_ms=params.speech_pad_ms
        )
        vad_processed, speech_chunks = self.vad.run(
            audio=audio,
            vad_parameters=vad_options,
            progress=progress
        )

        if vad_processed.size > 0:
            audio = vad_processed
        else:
            # Nothing left after filtering: transcribe the unfiltered audio and
            # skip the timestamp restoration below.
            params.vad_filter = False

    # The transcription-only time is not needed: the total time is reported instead.
    result, _ = self.transcribe(
        audio,
        progress,
        *astuple(params)
    )

    if whisper_enable_offload:
        self.offload()

    if params.vad_filter:
        restored_result = self.vad.restore_speech_timestamps(
            segments=result,
            speech_chunks=speech_chunks,
        )
        if restored_result:
            result = restored_result
        else:
            print("VAD detected no speech segments in the audio.")

    if params.is_diarize:
        progress(0.99, desc="Diarizing speakers...")
        result, _ = self.diarizer.run(
            audio=origin_audio,
            use_auth_token=params.hf_token,
            transcribed_result=result,
            device=params.diarization_device
        )
        if diarization_enable_offload:
            self.diarizer.offload()

    if not result:
        print("Whisper did not detected any speech segments in the audio.")
        result = list()

    progress(1.0, desc="Processing done!")
    total_elapsed_time = datetime.now() - start_time
    return result, total_elapsed_time.total_seconds()
@staticmethod
def _detect_file_language(model, file: str) -> Tuple[str, str]:
    """Return (capitalized language name, probability in percent as a string) for one file.

    Both items are empty strings when the most probable code is not in whisper's language table.
    """
    mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
    _, probs = model.detect_language(mel)
    best_code = max(probs, key=probs.get)
    language_name = whisper.tokenizer.LANGUAGES.get(str(best_code))
    if language_name is None:
        return "", ""
    return language_name.capitalize(), str(round(probs[best_code] * 100))

def transcribe_file(self,
                    files: Optional[List] = None,
                    input_folder_path: Optional[str] = None,
                    file_format: str = "SRT",
                    add_timestamp: bool = True,
                    translate_output: bool = False,
                    translate_model: str = "",
                    target_lang: str = "",
                    progress=gr.Progress(),
                    *whisper_params,
                    ) -> list:
    """
    Write subtitle file from Files

    Parameters
    ----------
    files: list
        List of files to transcribe from gr.Files()
    input_folder_path: str
        Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
        this will be used instead.
    file_format: str
        Subtitle File format from gr.Dropdown(). Not consulted by this method: a txt, an srt and a csv
        file are always written for every input.
    add_timestamp: bool
        Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
    translate_output: bool
        Translate output
    translate_model: str
        Translation model to use
    target_lang: str
        Target language to use
    progress: gr.Progress
        Indicator to show progress directly in gradio.
    *whisper_params: tuple
        Parameters related with whisper. This will be dealt with "WhisperParameters" data class

    Returns
    ----------
    A list `[result_str, result_file_path, total_info]`:
    result_str:
        Result of transcription to return to gr.Textbox()
    result_file_path:
        List of output file paths to return to gr.Files()
    total_info:
        Per-file language/translation summary followed by the total processing time.
    If any exception occurs it is printed and None is returned implicitly.
    """
    try:
        if input_folder_path:
            files = get_media_files(input_folder_path)
        if isinstance(files, str):
            files = [files]
        if files and isinstance(files[0], gr.utils.NamedString):
            files = [file.name for file in files]

        ## Initialization variables & start time
        # Lists (not dicts keyed by file stem) so inputs sharing a stem do not overwrite each other.
        files_info = []
        files_to_download = []
        time_start = datetime.now()

        ## Load parameters related with whisper
        params = WhisperParameters.as_value(*whisper_params)

        ## Load model to detect language
        model = whisper.load_model("base")
        try:
            for file in files:
                ## Detect language
                file_language, file_lang_probs = self._detect_file_language(model, file)

                transcribed_segments, time_for_task = self.run(
                    file,
                    progress,
                    add_timestamp,
                    *whisper_params,
                )

                # Define source language
                source_lang = file_language

                # Translate to English using Whisper built-in functionality
                transcription_note = ""
                if params.is_translate:
                    if source_lang != "English":
                        transcription_note = "To English"
                        source_lang = "English"
                    else:
                        transcription_note = "Already in English"

                # Translate the transcribed segments
                translation_note = ""
                if translate_output:
                    if source_lang != target_lang:
                        self.nllb_inf = NLLBInference()
                        if source_lang in NLLB_AVAILABLE_LANGS:
                            transcribed_segments = self.nllb_inf.translate_text(
                                input_list_dict=transcribed_segments,
                                model_size=translate_model,
                                src_lang=source_lang,
                                tgt_lang=target_lang,
                                speaker_diarization=params.is_diarize
                            )
                            translation_note = "To " + target_lang
                        else:
                            translation_note = source_lang + " not supported"
                    else:
                        translation_note = "Already in " + target_lang

                ## Get preview as txt
                file_name, file_ext = os.path.splitext(os.path.basename(file))
                files_info.append({
                    "subtitle": get_txt(transcribed_segments),
                    "time_for_task": time_for_task,
                    "lang": file_language,
                    "lang_prob": file_lang_probs,
                    "input_source_file": file_name + file_ext,
                    "translation": translation_note,
                    "transcription": transcription_note,
                })

                ## Add output files as txt, srt and csv
                for output_format in ("txt", "srt", "csv"):
                    _, file_path = self.generate_and_write_file(
                        file_name=file_name,
                        transcribed_segments=transcribed_segments,
                        add_timestamp=add_timestamp,
                        file_format=output_format,
                        output_dir=self.output_dir
                    )
                    files_to_download.append(file_path)
        finally:
            # Drop the detection model here so it is no longer referenced when the
            # outer `finally` empties the CUDA cache.
            del model

        info_chunks = []
        for info in files_info:
            info_chunks.append(
                f'Input file:\t\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'
            )
            if params.is_translate:
                info_chunks.append(f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n')
            if translate_output:
                info_chunks.append(f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n')

        time_end = datetime.now()
        info_chunks.append(
            f"\nTotal processing time: {self.format_time((time_end-time_start).total_seconds())}"
        )
        total_info = "".join(info_chunks)

        total_result = "".join(f'{info["subtitle"]}' for info in files_info)
        result_str = total_result.rstrip("\n")
        # Order-preserving de-duplication: two inputs may have been written to the same path.
        result_file_path = list(dict.fromkeys(files_to_download))

        return [result_str, result_file_path, total_info]

    except Exception as e:
        print(f"Error transcribing file: {e}")
    finally:
        self.release_cuda_memory()
def transcribe_mic(self,
                   mic_audio: str,
                   file_format: str = "SRT",
                   add_timestamp: bool = True,
                   progress=gr.Progress(),
                   *whisper_params,
                   ) -> list:
    """
    Write subtitle file from microphone

    Parameters
    ----------
    mic_audio: str
        Audio file path from gr.Microphone()
    file_format: str
        Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
    add_timestamp: bool
        Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
    progress: gr.Progress
        Indicator to show progress directly in gradio.
    *whisper_params: tuple
        Parameters related with whisper. This will be dealt with "WhisperParameters" data class

    Returns
    ----------
    A list `[result_str, result_file_path]`:
    result_str:
        Result of transcription to return to gr.Textbox()
    result_file_path:
        Output file path to return to gr.Files()
    If any exception occurs it is printed and None is returned implicitly.
    """
    try:
        progress(0, desc="Loading Audio...")
        segments, time_for_task = self.run(mic_audio, progress, add_timestamp, *whisper_params)
        progress(1, desc="Completed!")

        # The output file is always named "Mic" (plus the optional timestamp).
        write_options = {
            "file_name": "Mic",
            "transcribed_segments": segments,
            "add_timestamp": add_timestamp,
            "file_format": file_format,
            "output_dir": self.output_dir,
        }
        subtitle, result_file_path = self.generate_and_write_file(**write_options)

        result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
        return [result_str, result_file_path]
    except Exception as e:
        print(f"Error transcribing file: {e}")
    finally:
        # Runs on success and on failure.
        self.release_cuda_memory()
def transcribe_youtube(self,
                       youtube_link: str,
                       file_format: str = "SRT",
                       add_timestamp: bool = True,
                       progress=gr.Progress(),
                       *whisper_params,
                       ) -> list:
    """
    Write subtitle file from Youtube

    Parameters
    ----------
    youtube_link: str
        URL of the Youtube video to transcribe from gr.Textbox()
    file_format: str
        Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
    add_timestamp: bool
        Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
    progress: gr.Progress
        Indicator to show progress directly in gradio.
    *whisper_params: tuple
        Parameters related with whisper. This will be dealt with "WhisperParameters" data class

    Returns
    ----------
    A list `[result_str, result_file_path]`:
    result_str:
        Result of transcription to return to gr.Textbox()
    result_file_path:
        Output file path to return to gr.Files()
    If any exception occurs it is printed and None is returned implicitly.
    """
    # Set before the try block so the cleanup in `finally` can tell whether a download happened.
    audio = None
    try:
        progress(0, desc="Loading Audio from Youtube...")
        yt = get_ytdata(youtube_link)
        audio = get_ytaudio(yt)

        transcribed_segments, time_for_task = self.run(
            audio,
            progress,
            add_timestamp,
            *whisper_params,
        )

        progress(1, desc="Completed!")

        file_name = safe_filename(yt.title)
        subtitle, result_file_path = self.generate_and_write_file(
            file_name=file_name,
            transcribed_segments=transcribed_segments,
            add_timestamp=add_timestamp,
            file_format=file_format,
            output_dir=self.output_dir
        )
        result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
        return [result_str, result_file_path]

    except Exception as e:
        print(f"Error transcribing file: {e}")
    finally:
        # Remove the downloaded audio on success AND on failure; a failed removal
        # is reported but must not discard an otherwise successful result.
        if audio and os.path.exists(audio):
            try:
                os.remove(audio)
            except OSError as e:
                print(f"Error removing downloaded audio: {e}")
        self.release_cuda_memory()
@staticmethod
def generate_and_write_file(file_name: str,
                            transcribed_segments: list,
                            add_timestamp: bool,
                            file_format: str,
                            output_dir: str
                            ) -> Tuple[str, str]:
    """
    Writes subtitle file

    Parameters
    ----------
    file_name: str
        Output file name
    transcribed_segments: list
        Text segments transcribed from audio
    add_timestamp: bool
        Determines whether to add a timestamp to the end of the filename.
    file_format: str
        File format to write, matched case-insensitively after stripping whitespace.
        Supported formats: [SRT, WebVTT (or VTT), txt, csv]
    output_dir: str
        Directory path of the output

    Returns
    ----------
    content: str
        Result of the transcription
    output_path: str
        output file path

    Raises
    ----------
    ValueError
        If `file_format` is not one of the supported formats. Nothing is written in that case.
    """
    if add_timestamp:
        timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
        output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
    else:
        output_path = os.path.join(output_dir, f"{file_name}")

    normalized_format = file_format.strip().lower()
    if normalized_format == "srt":
        content = get_srt(transcribed_segments)
        output_path += '.srt'
    elif normalized_format in ("webvtt", "vtt"):
        content = get_vtt(transcribed_segments)
        output_path += '.vtt'
    elif normalized_format == "txt":
        content = get_txt(transcribed_segments)
        output_path += '.txt'
    elif normalized_format == "csv":
        content = get_csv(transcribed_segments)
        output_path += '.csv'
    else:
        # Fail with a clear message instead of an UnboundLocalError on `content`.
        raise ValueError(
            f"Unsupported file format: {file_format!r}. Supported formats: SRT, WebVTT, txt, csv"
        )

    write_file(content, output_path)
    return content, output_path
def offload(self):
    """
    Offload the model and free up the memory.

    Drops the reference held in `self.model`, runs the garbage collector and,
    when `self.device` is "cuda", releases the CUDA cache afterwards.
    """
    if self.model is not None:
        del self.model
        self.model = None
    # Collect BEFORE emptying the CUDA cache: objects kept alive only by reference
    # cycles are freed by the collector, and their memory can only be handed back
    # by the cache release if that release happens afterwards.
    gc.collect()
    if self.device == "cuda":
        self.release_cuda_memory()
@staticmethod
def format_time(elapsed_time: float) -> str:
    """
    Get {hours} {minutes} {seconds} time format string

    Parameters
    ----------
    elapsed_time: float
        Elapsed time in seconds. Negative values are treated as 0.

    Returns
    ----------
    Time format string, e.g. "1 hour 2 minutes 5 seconds". Hours and minutes are
    omitted when zero; seconds are always present.
    """
    # Round the total once, then split it. Rounding each component separately
    # could produce values such as "60 seconds" or "59 minutes 60 seconds".
    total_seconds = max(0, int(round(elapsed_time)))
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    parts = []
    if hours:
        parts.append(f"{hours} hour" if hours == 1 else f"{hours} hours")
    if minutes:
        parts.append(f"{minutes} minute" if minutes == 1 else f"{minutes} minutes")
    parts.append(f"{seconds} second" if seconds == 1 else f"{seconds} seconds")
    return " ".join(parts)
@staticmethod
def get_device():
    """
    Pick the torch device name to use.

    Returns "cuda" when CUDA is available; otherwise "mps" when MPS is available
    and the sparse-tensor probe succeeds; otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    if not torch.backends.mps.is_available():
        return "cpu"
    # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
    return "mps" if WhisperBase.is_sparse_api_supported() else "cpu"
@staticmethod
def is_sparse_api_supported():
    """
    Report whether a sparse COO tensor can be created on the "mps" device.

    Returns False when MPS is unavailable or when creating the tensor raises
    RuntimeError; True otherwise.
    """
    if not torch.backends.mps.is_available():
        return False

    try:
        # The tensor itself is not needed; only whether its creation succeeds.
        torch.sparse_coo_tensor(
            indices=torch.tensor([[0, 1], [2, 3]]),
            values=torch.tensor([1, 2]),
            size=(4, 4),
            device=torch.device("mps"),
        )
    except RuntimeError:
        return False
    return True
@staticmethod
def release_cuda_memory():
    """Release cached CUDA memory and reset the peak-memory statistics. Does nothing without CUDA."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # `reset_max_memory_allocated()` is deprecated and simply forwards to
        # `reset_peak_memory_stats()`; call the current API directly.
        torch.cuda.reset_peak_memory_stats()
@staticmethod
def remove_input_files(file_paths: List[str]):
    """
    Remove gradio cached files.

    Empty or missing entries are skipped, as are paths that do not exist.
    A falsy `file_paths` (None or empty) is a no-op.
    """
    for path in file_paths or ():
        if not path:
            continue
        if os.path.exists(path):
            os.remove(path)
@staticmethod
def cache_parameters(
    whisper_params: WhisperValues,
    add_timestamp: bool
):
    """
    Cache parameters to the yaml file.

    The YAML at DEFAULT_PARAMETERS_CONFIG_PATH is loaded, its top-level keys are
    overridden by `whisper_params.to_yaml()`, `add_timestamp` is stored under the
    "whisper" section and the result is saved back to the same path.
    """
    merged = dict(load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH))
    # Top-level keys from the current parameters win over the stored ones.
    merged.update(whisper_params.to_yaml())
    merged["whisper"]["add_timestamp"] = add_timestamp
    save_yaml(merged, DEFAULT_PARAMETERS_CONFIG_PATH)
@staticmethod
def resample_audio(audio: Union[str, np.ndarray],
                   new_sample_rate: int = 16000,
                   original_sample_rate: Optional[int] = None,) -> np.ndarray:
    """
    Resamples audio to 16k sample rate (the default `new_sample_rate`), standard on Whisper model.

    Parameters
    ----------
    audio: Union[str, np.ndarray]
        A file path, loaded with torchaudio (the file's own sample rate is used),
        or a numpy array.
    new_sample_rate: int
        Target sample rate.
    original_sample_rate: Optional[int]
        Sample rate of `audio`; required when `audio` is a numpy array.

    Returns
    ----------
    Resampled audio as a numpy array.

    Raises
    ----------
    ValueError
        If `audio` is not a path and `original_sample_rate` is None.
    """
    if isinstance(audio, str):
        waveform, original_sample_rate = torchaudio.load(audio)
    elif original_sample_rate is None:
        raise ValueError("original_sample_rate must be provided when audio is numpy array.")
    else:
        waveform = torch.from_numpy(audio)

    transform = torchaudio.transforms.Resample(
        orig_freq=original_sample_rate,
        new_freq=new_sample_rate,
    )
    return transform(waveform).numpy()