import os
import tempfile
import time

import numpy as np
import soundfile as sf
import torch
import yt_dlp as youtube_dl
import gradio as gr
from faster_whisper import WhisperModel
from pydub import AudioSegment
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers.pipelines.audio_utils import ffmpeg_read
from gradio.components import Audio, Dropdown, Radio, Textbox

# Avoid the tokenizers fork warning when Gradio spawns worker processes
os.environ["TOKENIZERS_PARALLELISM"] = "false"
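
# Pipeline overview: audio from the microphone, an uploaded file, or a YouTube
# URL is transcribed with faster-whisper, each timestamped segment is translated
# with an NLLB-200 model, and both texts are displayed in a Gradio interface.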

FILE_LIMIT_MB = 1000  # maximum size of an uploaded audio file, in MB
YT_LENGTH_LIMIT_S = 3600  # maximum YouTube video length, in seconds (1 hour)

# Maps human-readable language names to the FLORES-200 codes used by NLLB
from flores200_codes import flores_codes


def set_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = set_device()
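
# --- Translation models (NLLB-200) ----------------------------------------
# Only the distilled 600M checkpoint is loaded below; further NLLB checkpoints
# could be added to model_name_dict and loaded the same way.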
model_dict = {}


def load_models():
    """Load the NLLB translation model(s) and tokenizer(s) into model_dict."""
    global model_dict
    if not model_dict:
        model_name_dict = {
            'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
        }
        for call_name, real_name in model_name_dict.items():
            model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
            tokenizer = AutoTokenizer.from_pretrained(real_name)
            model_dict[call_name + '_model'] = model
            model_dict[call_name + '_tokenizer'] = tokenizer


load_models()
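
# --- Speech-to-text model (faster-whisper) ---------------------------------
# WhisperModel defaults to device="auto", so it should use the GPU when one is
# available.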
model_size = "large-v2"
model = WhisperModel(model_size)
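
# transcribe_audio() returns (timestamp, text) pairs, e.g. (illustrative values
# only): ("[0.00s -> 3.52s]", " Bonjour à tous").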
def transcribe_audio(audio_file):
    """Transcribe an audio file and return a list of (timestamp, text) tuples."""
    global model
    segments, _ = model.transcribe(audio_file, beam_size=1)
    transcriptions = [("[%.2fs -> %.2fs]" % (seg.start, seg.end), seg.text) for seg in segments]
    return transcriptions


def traduction(text, source_lang, target_lang):
    """Translate text between the given language names using NLLB-200."""
    if source_lang not in flores_codes or target_lang not in flores_codes:
        print(f"Language code not found: {source_lang} or {target_lang}")
        return ""

    src_code = flores_codes[source_lang]
    tgt_code = flores_codes[target_lang]

    model_name = "nllb-distilled-600M"
    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]
    # Reuse the global device so translation runs on the GPU when available
    translator = pipeline("translation", model=model, tokenizer=tokenizer, device=device)

    return translator(text, src_lang=src_code, tgt_lang=tgt_code)[0]["translation_text"]
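
# Example (language names must be keys of flores_codes, e.g. 'French' maps to
# 'fra_Latn'):
#   traduction("Bonjour tout le monde", "French", "English")
#   # -> roughly "Hello everyone"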


def full_transcription_and_translation(audio_input, source_lang, target_lang):
    """Transcribe and translate audio given as a YouTube URL, a dict of raw
    samples, or a path to a local audio file."""
    temp_file_created = False

    if isinstance(audio_input, str) and audio_input.startswith("http"):
        # YouTube URL: download the audio to a temporary file first
        audio_file = os.path.join(tempfile.mkdtemp(), "audio.mp4")
        download_yt_audio(audio_input, audio_file)
        temp_file_created = True
    elif isinstance(audio_input, dict) and "array" in audio_input and "sampling_rate" in audio_input:
        # Raw samples (e.g. from ffmpeg_read): write them to a temporary WAV file
        audio_array = audio_input["array"]
        sampling_rate = audio_input["sampling_rate"]
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as f:
            sf.write(f.name, audio_array, sampling_rate)
            audio_file = f.name
        temp_file_created = True
    else:
        # Assume a path to a local audio file
        audio_file = audio_input

    transcriptions = transcribe_audio(audio_file)
    translations = [(timestamp, traduction(text, source_lang, target_lang)) for timestamp, text in transcriptions]

    if temp_file_created:
        os.remove(audio_file)

    return transcriptions, translations
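
# Used by both UI handlers below: gradio_interface() passes a file path (or a
# URL), while yt_transcribe() passes the {"array", "sampling_rate"} dict
# produced by ffmpeg_read.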

# Unused earlier single-argument downloader, superseded by download_yt_audio below:
"""def download_yt_audio(yt_url):
    with tempfile.NamedTemporaryFile(suffix='.mp3') as f:
        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': f.name,
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'mp3',
                'preferredquality': '192',
            }],
        }
        with youtube_dl.YoutubeDL(ydl_opts) as ydl:
            ydl.download([yt_url])
        return f.name"""
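
# Language names offered in the dropdowns; they double as keys into flores_codes.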
lang_codes = list(flores_codes.keys())


def gradio_interface(audio_file, source_lang, target_lang):
    """Gradio handler for the Microphone and Audio file tabs."""
    if isinstance(audio_file, str) and audio_file.startswith("http"):
        yt_url = audio_file
        audio_file = os.path.join(tempfile.mkdtemp(), "audio.mp4")
        download_yt_audio(yt_url, audio_file)
    transcriptions, translations = full_transcription_and_translation(audio_file, source_lang, target_lang)
    transcribed_text = '\n'.join([f"{timestamp}: {text}" for timestamp, text in transcriptions])
    translated_text = '\n'.join([f"{timestamp}: {text}" for timestamp, text in translations])
    return transcribed_text, translated_text


def _return_yt_html_embed(yt_url):
    video_id = yt_url.split("?v=")[-1]
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
        " </center>"
    )
    return HTML_str
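
# Note: the embed assumes a standard "watch?v=<id>" URL; shortened youtu.be
# links would need different parsing.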


def download_yt_audio(yt_url, filename):
    info_loader = youtube_dl.YoutubeDL()

    try:
        info = info_loader.extract_info(yt_url, download=False)
    except youtube_dl.utils.DownloadError as err:
        raise gr.Error(str(err))

    file_length = info["duration_string"]
    file_h_m_s = file_length.split(":")
    file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]

    # Pad the duration to [hours, minutes, seconds]
    if len(file_h_m_s) == 1:
        file_h_m_s.insert(0, 0)
    if len(file_h_m_s) == 2:
        file_h_m_s.insert(0, 0)
    file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]

    if file_length_s > YT_LENGTH_LIMIT_S:
        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
        raise gr.Error(f"Maximum YouTube video length is {yt_length_limit_hms}, but this video is {file_length_hms} long.")

    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}

    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
        try:
            ydl.download([yt_url])
        except youtube_dl.utils.ExtractorError as err:
            raise gr.Error(str(err))
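
# The format string above requests the smallest mp4 video stream plus the best
# m4a audio (with single-file mp4 fallbacks); yt_transcribe() later feeds the
# downloaded file to ffmpeg_read for decoding.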


def yt_transcribe(yt_url, source_lang, target_lang, max_filesize=75.0):
    """Gradio handler for the YouTube tab: embed the video, then transcribe and
    translate its audio."""
    html_embed_str = _return_yt_html_embed(yt_url)
    global model

    with tempfile.TemporaryDirectory() as tmpdirname:
        filepath = os.path.join(tmpdirname, "video.mp4")
        download_yt_audio(yt_url, filepath)
        with open(filepath, "rb") as f:
            inputs = f.read()

    inputs = ffmpeg_read(inputs, model.feature_extractor.sampling_rate)
    inputs = {"array": inputs, "sampling_rate": model.feature_extractor.sampling_rate}

    transcriptions, translations = full_transcription_and_translation(inputs, source_lang, target_lang)
    transcribed_text = '\n'.join([f"{timestamp}: {text}" for timestamp, text in transcriptions])
    translated_text = '\n'.join([f"{timestamp}: {text}" for timestamp, text in translations])
    return html_embed_str, transcribed_text, translated_text


demo = gr.Blocks()

with demo:
    with gr.Tab("Microphone"):
        gr.Interface(
            fn=gradio_interface,
            inputs=[
                gr.Audio(sources=["microphone"], type="filepath"),
                gr.Dropdown(lang_codes, value='French', label='Source Language'),
                gr.Dropdown(lang_codes, value='English', label='Target Language'),
            ],
            outputs=[gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Translated Text")],
        )

    with gr.Tab("Audio file"):
        gr.Interface(
            fn=gradio_interface,
            inputs=[
                gr.Audio(type="filepath", label="Audio file"),
                gr.Dropdown(lang_codes, value='French', label='Source Language'),
                gr.Dropdown(lang_codes, value='English', label='Target Language'),
            ],
            outputs=[gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Translated Text")],
        )

    with gr.Tab("YouTube"):
        gr.Interface(
            fn=yt_transcribe,
            inputs=[
                gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
                gr.Dropdown(lang_codes, value='French', label='Source Language'),
                gr.Dropdown(lang_codes, value='English', label='Target Language'),
            ],
            outputs=["html", gr.Textbox(label="Transcribed Text"), gr.Textbox(label="Translated Text")],
        )


demo.launch()