# Author: erastorgueva-nv
# Last change: "update Audio param to sources" (commit bcd5b7d)
import json
import os
import subprocess
import tempfile
import uuid

import gradio as gr
import jieba
import librosa
import nemo.collections.asr as nemo_asr
import soundfile
from nemo.collections.asr.models import ASRModel
from nemo.utils import logging

from align import main, AlignmentConfig, ASSFileConfig
# Sample rate (Hz) that every input clip is converted to before it is written
# out for transcription/alignment.
SAMPLE_RATE = 16000

# Pre-download every model this app can select, so the checkpoints sit in the
# on-disk cache before the first request arrives. Loading on CPU is enough to
# trigger the download; the returned model is not kept. NeMo logging is turned
# down to ERROR meanwhile (presumably to quiet the loading output) and set back
# to INFO afterwards.
logging.setLevel(logging.ERROR)
for tmp_model_name in (
    "stt_en_fastconformer_hybrid_large_pc",
    "stt_de_fastconformer_hybrid_large_pc",
    "stt_es_fastconformer_hybrid_large_pc",
    "stt_fr_conformer_ctc_large",
    "stt_zh_citrinet_1024_gamma_0_25",
):
    ASRModel.from_pretrained(tmp_model_name, map_location='cpu')
logging.setLevel(logging.INFO)
def get_audio_data_and_duration(file):
    """Load `file` as mono audio at SAMPLE_RATE.

    Args:
        file: path (or file-like object) accepted by ``librosa.load``.

    Returns:
        ``(data, duration)`` where ``data`` is the mono sample array at
        SAMPLE_RATE and ``duration`` is its length in seconds.
    """
    # Let librosa resample and downmix during the load itself. Loading at
    # librosa's default rate and then calling ``librosa.resample`` again would
    # interpolate the signal twice.
    data, _ = librosa.load(file, sr=SAMPLE_RATE, mono=True)
    duration = librosa.get_duration(y=data, sr=SAMPLE_RATE)
    return data, duration
def get_char_tokens(text, model):
    """Map each character of `text` to its index in ``model.decoder.vocabulary``.

    Characters missing from the vocabulary map to ``len(vocabulary)``, which is
    used here as the unk id (same as the blank token's id). If a symbol appears
    more than once in the vocabulary, its first index is used, matching
    ``list.index``.

    Args:
        text: the text to tokenize character by character.
        model: a model exposing ``model.decoder.vocabulary``.

    Returns:
        A list of int token ids, one per character of `text`.
    """
    vocabulary = model.decoder.vocabulary
    unk_id = len(vocabulary)  # return unk token (same as blank token)
    # Build the lookup once. The previous per-character ``in`` + ``.index``
    # scanned the whole vocabulary twice for every character.
    char_to_id = {}
    for idx, symbol in enumerate(vocabulary):
        char_to_id.setdefault(symbol, idx)  # keep first occurrence
    return [char_to_id.get(character, unk_id) for character in text]
def get_S_prime_and_T(text, model_name, model, audio_duration):
    """Return ``(S_prime, T)`` used to check whether `text` fits the audio.

    ``T`` estimates how many output frames the model emits for
    `audio_duration` seconds, based on the frame duration implied by
    `model_name`. ``S_prime`` is the number of tokens in `text` plus the
    number of adjacent identical token pairs. A repeated token needs an extra
    frame in a CTC alignment, so the text is only alignable when
    ``S_prime <= T``.

    Raises:
        RuntimeError: if the model family cannot be inferred from `model_name`,
            or if `model` has neither a tokenizer nor a character vocabulary.
    """
    # Checked in order: the first matching rule wins.
    frame_rules = (
        (("citrinet", "_fastconformer_"), 0.08),
        (("_conformer_",), 0.04),
        (("quartznet",), 0.02),
    )
    for markers, seconds_per_frame in frame_rules:
        if any(marker in model_name for marker in markers):
            break
    else:
        raise RuntimeError("unexpected model name")
    T = int(audio_duration / seconds_per_frame) + 1

    # Tokenize with the model's subword tokenizer if it has one, otherwise fall
    # back to plain character-level tokenization.
    if hasattr(model, 'tokenizer'):
        token_ids = model.tokenizer.text_to_ids(text)
    elif hasattr(model.decoder, "vocabulary"):
        token_ids = get_char_tokens(text, model)
    else:
        raise RuntimeError("cannot obtain tokens from this model")

    repeats = sum(1 for prev, cur in zip(token_ids, token_ids[1:]) if prev == cur)
    S_prime = len(token_ids) + repeats
    return S_prime, T
def hex_to_rgb_list(hex_string):
    """Convert a hex color string to an ``[r, g, b]`` list of ints in 0-255.

    A leading ``#`` is optional. Accepted forms are ``RRGGBB``, the CSS
    shorthand ``RGB`` (each digit doubled), and ``RRGGBBAA`` (the alpha byte is
    ignored rather than being folded into blue).

    Raises:
        ValueError: if the string has any other length or contains non-hex
            characters.
    """
    digits = hex_string.lstrip("#")
    if len(digits) == 3:
        digits = "".join(ch * 2 for ch in digits)
    if len(digits) not in (6, 8):
        raise ValueError(f"expected a hex color like '#rrggbb', got {hex_string!r}")
    return [int(digits[start:start + 2], 16) for start in (0, 2, 4)]
def delete_mp4s_except_given_filepath(filepath):
    """Delete every ``.mp4`` file in the current working directory except `filepath`.

    `filepath` is compared against bare directory-entry names, so it should be
    a filename relative to the cwd (as `align` constructs it). Entries ending
    in ``.mp4`` that are directories are left alone.

    A file that disappears between listing and deletion is skipped instead of
    raising. For example, another event handler may already have removed it.
    """
    for entry in os.listdir():
        if not entry.endswith(".mp4") or entry == filepath:
            continue
        if os.path.isdir(entry):
            continue
        try:
            os.remove(entry)
        except FileNotFoundError:
            # Already gone: the goal (no stray mp4 left) is met.
            pass
def _model_name_for_lang(lang):
    """Return the pretrained NeMo ASR model name used for language code `lang`.

    Raises:
        gr.Error: if `lang` is not one of the supported codes. Without this, an
            unknown code would leave the model name unset and crash later.
    """
    if lang in ["en", "de", "es"]:
        return f"stt_{lang}_fastconformer_hybrid_large_pc"
    if lang in ["fr"]:
        return f"stt_{lang}_conformer_ctc_large"
    if lang in ["zh"]:
        return f"stt_{lang}_citrinet_1024_gamma_0_25"
    raise gr.Error(f"Unsupported audio language: {lang!r}")


def _render_video_with_subtitles(audio_path, ass_filepath, output_video_filepath):
    """Render the ASS subtitles onto a white 1280x720 background with ffmpeg.

    The audio at `audio_path` becomes the soundtrack, and the video is written
    to `output_video_filepath`. ffmpeg runs without a shell, so paths with
    spaces or shell metacharacters are passed through intact.

    Raises:
        gr.Error: if ffmpeg cannot be started or exits with a non-zero status.
    """
    ffmpeg_command = [
        "ffmpeg", "-y",
        "-i", audio_path,
        "-f", "lavfi", "-i", "color=c=white:s=1280x720:r=50",
        "-crf", "1", "-shortest", "-vcodec", "libx264", "-pix_fmt", "yuv420p",
        "-vf", f"ass={ass_filepath}",
        output_video_filepath,
    ]
    try:
        completed = subprocess.run(ffmpeg_command, check=False)
    except FileNotFoundError as err:
        raise gr.Error("ffmpeg is not available on the server, so the output video could not be generated.") from err
    if completed.returncode != 0:
        # os.system's status used to be ignored, so a failed render still
        # returned a path to a video that did not exist.
        raise gr.Error("ffmpeg failed to generate the output video - please check the server logs.")


def align(lang, Microphone, File_Upload, text, col1, col2, col3, progress=gr.Progress()):
    """Force-align text to audio with NFA and render a highlighted-text video.

    Args:
        lang: language code chosen in the UI ("de", "en", "es", "fr" or "zh").
        Microphone: filepath of the microphone recording, or None.
        File_Upload: filepath of the uploaded audio file, or None.
        text: reference text. If empty, the ASR model's transcription is used.
            '|' characters mark which text should appear together.
        col1, col2, col3: hex colors for text already spoken / being spoken /
            not yet spoken.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, gr.update(...) for the info textbox, video_path)``. The
        path is returned twice so the UI can also keep it in a ``gr.State``.

    Raises:
        gr.Error: on invalid input, audio of 4 mins or more, text too long
            for the audio, no speech detected, or a failed video render.
    """
    # Create utt_id, specify output_video_filepath and delete any MP4s
    # that are not that filepath. These stray MP4s can be created
    # if a user refreshes or exits the page while this 'align' function is executing.
    # This deletion will not delete any other users' video as long as this 'align' function
    # is run one at a time.
    utt_id = uuid.uuid4()
    output_video_filepath = f"{utt_id}.mp4"
    delete_mp4s_except_given_filepath(output_video_filepath)

    output_info = ""

    progress(0, desc="Validating input")

    # choose model
    model_name = _model_name_for_lang(lang)

    # decide which of Mic / File_Upload is used as input & do error handling
    if (Microphone is not None) and (File_Upload is not None):
        raise gr.Error("Please use either the microphone or file upload input - not both")
    elif (Microphone is None) and (File_Upload is None):
        raise gr.Error("You have to either use the microphone or upload an audio file")
    elif Microphone is not None:
        file = Microphone
    else:
        file = File_Upload

    # check audio is not too long
    audio_data, duration = get_audio_data_and_duration(file)

    if duration > 4 * 60:
        raise gr.Error(
            f"Detected that uploaded audio has duration {duration/60:.1f} mins - please only upload audio of less than 4 mins duration"
        )

    # loading model
    progress(0.1, desc="Loading speech recognition model")
    model = ASRModel.from_pretrained(model_name)

    if text:  # check input text is not too long compared to audio
        S_prime, T = get_S_prime_and_T(text, model_name, model, duration)

        if S_prime > T:
            raise gr.Error(
                "The number of tokens in the input text is too long compared to the duration of the audio."
                f" This model can handle {T} tokens + token repetitions at most. You have provided {S_prime} tokens + token repetitions. "
                " (Adjacent tokens that are not in the model's vocabulary are also counted as a token repetition.)"
            )

    with tempfile.TemporaryDirectory() as tmpdir:
        audio_path = os.path.join(tmpdir, f'{utt_id}.wav')
        soundfile.write(audio_path, audio_data, SAMPLE_RATE)

        # getting the text if it hasn't been provided
        if not text:
            progress(0.2, desc="Transcribing audio")
            text = model.transcribe([audio_path])[0]
            if 'hybrid' in model_name:
                # Hybrid models need one more level of indexing to reach the
                # transcription string (see `text[0]`).
                text = text[0]

            if text == "":
                raise gr.Error(
                    "ERROR: the ASR model did not detect any speech in the input audio. Please upload audio with speech."
                )

            output_info += (
                "You did not enter any input text, so the ASR model's transcription will be used:\n"
                "--------------------------\n"
                f"{text}\n"
                "--------------------------\n"
                "You could try pasting the transcription into the text input box, correcting any"
                " transcription errors, and clicking 'Submit' again."
            )

        if lang == "zh" and " " not in text:
            # use jieba to add spaces between zh characters
            text = " ".join(jieba.cut(text))

        data = {
            "audio_filepath": audio_path,
            "text": text,
        }
        manifest_path = os.path.join(tmpdir, f"{utt_id}_manifest.json")
        with open(manifest_path, 'w') as fout:
            fout.write(f"{json.dumps(data)}\n")

        # run alignment
        # Only let NFA re-segment the text itself when the user gave no '|'
        # separators.
        resegment_text_to_fill_space = "|" not in text

        alignment_config = AlignmentConfig(
            pretrained_name=model_name,
            manifest_filepath=manifest_path,
            output_dir=f"{tmpdir}/nfa_output/",
            audio_filepath_parts_in_utt_id=1,
            batch_size=1,
            use_local_attention=True,
            additional_segment_grouping_separator="|",
            # transcribe_device='cpu',
            # viterbi_device='cpu',
            save_output_file_formats=["ass"],
            ass_file_config=ASSFileConfig(
                fontsize=45,
                resegment_text_to_fill_space=resegment_text_to_fill_space,
                max_lines_per_segment=4,
                text_already_spoken_rgb=hex_to_rgb_list(col1),
                text_being_spoken_rgb=hex_to_rgb_list(col2),
                text_not_yet_spoken_rgb=hex_to_rgb_list(col3),
            ),
        )

        progress(0.5, desc="Aligning audio")

        main(alignment_config)

        progress(0.95, desc="Saving generated alignments")

        if lang == "zh":
            # make video file from the token-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/tokens/{utt_id}.ass"
        else:
            # make video file from the word-level ASS file
            ass_file_for_video = f"{tmpdir}/nfa_output/ass/words/{utt_id}.ass"

        # Must run inside the `with`: both inputs live in tmpdir.
        _render_video_with_subtitles(audio_path, ass_file_for_video, output_video_filepath)

    return output_video_filepath, gr.update(value=output_info, visible=True), output_video_filepath
def delete_non_tmp_video(video_path):
    """Delete the rendered video at `video_path`, if one was produced.

    A falsy `video_path` (None, "" or the initial empty-list state) is a no-op.
    A file that is already gone is ignored. It may have been removed by
    `delete_mp4s_except_given_filepath` between events, which a separate
    exists-then-remove check could not guard against.

    Returns:
        None, always.
    """
    if video_path:
        try:
            os.remove(video_path)
        except FileNotFoundError:
            pass
    return None
# Build the Gradio UI: inputs on the left, outputs on the right, and footer
# links. Submitting runs `align`, then deletes the rendered video file once
# Gradio has received it.
with gr.Blocks(title="NeMo Forced Aligner", theme="huggingface") as demo:
    non_tmp_output_video_filepath = gr.State([])

    with gr.Row():
        with gr.Column():
            gr.Markdown("# NeMo Forced Aligner")
            gr.Markdown(
                "Demo for [NeMo Forced Aligner](https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner) (NFA). "
                "Upload audio and (optionally) the text spoken in the audio to generate a video where each part of the text will be highlighted as it is spoken. ",
            )

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Input")
            lang_drop = gr.Dropdown(choices=["de", "en", "es", "fr", "zh"], value="en", label="Audio language",)

            mic_in = gr.Audio(sources=["microphone"], type='filepath', label="Microphone input (max 4 mins)")
            audio_file_in = gr.Audio(sources=["upload"], type='filepath', label="File upload (max 4 mins)")
            ref_text = gr.Textbox(
                label="[Optional] The reference text. Use '|' separators to specify which text will appear together. "
                "Leave this field blank to use an ASR model's transcription as the reference text instead."
            )

            gr.Markdown("[Optional] For fun - adjust the colors of the text in the output video")
            with gr.Row():
                col1 = gr.ColorPicker(label="text already spoken", value="#fcba03")
                col2 = gr.ColorPicker(label="text being spoken", value="#bf45bf")
                col3 = gr.ColorPicker(label="text to be spoken", value="#3e1af0")

            submit_button = gr.Button("Submit")

        with gr.Column(scale=1):
            gr.Markdown("## Output")
            video_out = gr.Video(label="output video")
            text_out = gr.Textbox(label="output info", visible=False)

    with gr.Row():
        # Footer links. The emoji were mis-decoded (mojibake) in the original
        # text. They are written as escapes here: rocket, books, and
        # woman technologist (woman + zero-width joiner + laptop).
        gr.HTML(
            "<p style='text-align: center'>"
            "Tutorial: <a href='https://colab.research.google.com/github/NVIDIA/NeMo/blob/main/tutorials/tools/NeMo_Forced_Aligner_Tutorial.ipynb' target='_blank'>\"How to use NFA?\"</a> \U0001F680 | "
            "Blog post: <a href='https://nvidia.github.io/NeMo/blogs/2023/2023-08-forced-alignment/' target='_blank'>\"How does forced alignment work?\"</a> \U0001F4DA | "
            "NFA <a href='https://github.com/NVIDIA/NeMo/tree/main/tools/nemo_forced_aligner/' target='_blank'>Github page</a> \U0001F469\u200d\U0001F4BB"
            "</p>"
        )

    submit_button.click(
        fn=align,
        inputs=[lang_drop, mic_in, audio_file_in, ref_text, col1, col2, col3,],
        outputs=[video_out, text_out, non_tmp_output_video_filepath],
    ).then(
        fn=delete_non_tmp_video, inputs=[non_tmp_output_video_filepath], outputs=None,
    )

demo.queue()
demo.launch()