Yilin0601's picture
Update app.py
c7f56a8 verified
raw
history blame
7.08 kB
import gradio as gr
import torch
import numpy as np
import librosa
from transformers import pipeline
# --------------------------------------------------
# Speech recognition: English audio -> English text
# --------------------------------------------------
# Built once when the module is imported and reused for every request.
asr = pipeline(
    task="automatic-speech-recognition",
    model="facebook/wav2vec2-large-960h-lv60-self",
)
# --------------------------------------------------
# Target languages -> English-to-X translation checkpoints
# --------------------------------------------------
# Every checkpoint follows the Helsinki-NLP "opus-mt-en-<code>" naming scheme,
# so the mapping is generated from (language, code) pairs. The pair order is
# also the order shown in the UI dropdown.
# NOTE(review): confirm each repo actually exists on the Hugging Face Hub
# (e.g. the pt / ja / ko variants) before relying on it.
translation_models = {
    language: f"Helsinki-NLP/opus-mt-en-{code}"
    for language, code in (
        ("Spanish", "es"),
        ("French", "fr"),
        ("German", "de"),
        ("Chinese", "zh"),
        ("Russian", "ru"),
        ("Arabic", "ar"),
        ("Portuguese", "pt"),
        ("Japanese", "ja"),
        ("Italian", "it"),
        ("Korean", "ko"),
    )
}
# Pipeline task names use the short language code rather than the English
# language name ("translation_en_to_zh", not "translation_en_to_chinese").
# Keys mirror translation_models and keep the same order.
translation_tasks = {
    language: f"translation_en_to_{code}"
    for language, code in (
        ("Spanish", "es"),
        ("French", "fr"),
        ("German", "de"),
        ("Chinese", "zh"),
        ("Russian", "ru"),
        ("Arabic", "ar"),
        ("Portuguese", "pt"),
        ("Japanese", "ja"),
        ("Italian", "it"),
        ("Korean", "ko"),
    )
}
# Text-to-speech model identifiers per target language; get_tts passes these
# to pipeline("text-to-speech", model=...).
# NOTE(review): these strings look like Coqui-TTS-style names rather than
# Hugging Face Hub repo IDs ("org/name"), so loading them through the
# transformers pipeline will likely fail -- verify each one.
tts_models = dict(
    Spanish="tts_models/es/tacotron2-DDC",
    French="tts_models/fr/tacotron2",
    German="tts_models/de/tacotron2",
    Chinese="tts_models/zh/tacotron2",
    Russian="tts_models/ru/tacotron2",
    Arabic="tts_models/ar/tacotron2",
    Portuguese="tts_models/pt/tacotron2",
    Japanese="tts_models/ja/tacotron2",
    Italian="tts_models/it/tacotron2",
    Korean="tts_models/ko/tacotron2",
)
# --------------------------------------------------
# Lazily filled pipeline caches, keyed by target-language name
# --------------------------------------------------
translator_cache = dict()  # language -> translation pipeline (filled by get_translator)
tts_cache = dict()  # language -> text-to-speech pipeline (filled by get_tts)
def get_translator(target_language):
    """
    Return the translation pipeline for ``target_language``, building it on first use.

    The pipeline is created from ``translation_models`` (checkpoint) and
    ``translation_tasks`` (task name) and memoized in ``translator_cache``.

    Raises:
        KeyError: if ``target_language`` is missing from the mapping tables.
    """
    try:
        return translator_cache[target_language]
    except KeyError:
        pass  # not built yet -- fall through and create it

    checkpoint = translation_models[target_language]
    task = translation_tasks[target_language]
    built = pipeline(task, model=checkpoint)
    translator_cache[target_language] = built
    return built
def get_tts(target_language):
    """
    Return the text-to-speech pipeline for ``target_language``, loading it on first use.

    Successfully loaded pipelines are memoized in ``tts_cache``. Failed loads
    are not cached, so a later request retries the load.

    Raises:
        ValueError: if no model is mapped for the language in ``tts_models``,
            or if loading the mapped model fails. In the latter case the
            underlying exception is chained as ``__cause__`` so the real
            traceback is preserved.
    """
    if target_language in tts_cache:
        return tts_cache[target_language]

    model_name = tts_models.get(target_language)
    if model_name is None:
        raise ValueError(f"No TTS model available for {target_language}.")

    try:
        tts_pipeline = pipeline("text-to-speech", model=model_name)
    except Exception as e:
        # Re-raise as ValueError (what callers expect) but keep the cause.
        raise ValueError(
            f"Failed to load TTS model for {target_language}. "
            f"Make sure '{model_name}' exists on Hugging Face.\nError: {e}"
        ) from e

    tts_cache[target_language] = tts_pipeline
    return tts_pipeline
# --------------------------------------------------
# Prediction Function
# --------------------------------------------------
def _prepare_asr_input(audio, target_sr=16000):
    """
    Convert a Gradio ``(sample_rate, samples)`` audio tuple into ASR pipeline input.

    Integer PCM is scaled to float32 in roughly [-1, 1]. Multi-channel audio is
    averaged to mono, and ``(n, 1)`` arrays are also flattened to 1-D. The
    signal is then resampled to ``target_sr``, the rate fed to the module-level
    wav2vec2 ASR model.

    Returns:
        ``{"array": 1-D float32 ndarray, "sampling_rate": target_sr}``, or
        ``None`` when the recording contains no samples.
    """
    sample_rate, samples = audio
    samples = np.asarray(samples)
    if samples.size == 0:
        return None

    if np.issubdtype(samples.dtype, np.integer):
        # Integer PCM (Gradio typically delivers int16) -> float32 near [-1, 1].
        full_scale = float(np.iinfo(samples.dtype).max)
        samples = samples.astype(np.float32) / full_scale
    else:
        samples = samples.astype(np.float32, copy=False)

    if samples.ndim > 1:
        # Gradio lays audio out as (samples, channels). Averaging the channel
        # axis always yields a 1-D signal, including for single-channel input.
        samples = samples.mean(axis=1)

    if sample_rate != target_sr:
        samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=target_sr)
    return {"array": samples, "sampling_rate": target_sr}


def _tts_result_to_audio(tts_result):
    """
    Convert a text-to-speech pipeline result into Gradio's ``(sample_rate, samples)`` tuple.

    Reads the transformers keys ``"audio"`` / ``"sampling_rate"``. For
    compatibility it also accepts ``"wav"`` / ``"sample_rate"``. Singleton axes
    are squeezed away. A remaining 2-D ``(channels, samples)`` waveform is
    transposed to the ``(samples, channels)`` layout Gradio expects.

    Raises:
        KeyError: if neither key pair is present.
    """
    waveform = tts_result.get("audio", tts_result.get("wav"))
    sampling_rate = tts_result.get("sampling_rate", tts_result.get("sample_rate"))
    if waveform is None or sampling_rate is None:
        raise KeyError(f"Unexpected TTS output keys: {sorted(tts_result)}")

    waveform = np.squeeze(np.asarray(waveform))
    if waveform.ndim == 2:
        waveform = waveform.T
    return int(sampling_rate), waveform


def predict(audio, text, target_language):
    """
    Run the transcribe -> translate -> synthesize chain for one Gradio request.

    1. Obtain English text. Typed text takes priority; otherwise ASR runs on the audio.
    2. Translate English -> ``target_language``.
    3. Synthesize speech in ``target_language``.

    Args:
        audio: Gradio numpy audio ``(sample_rate, samples)`` or ``None``.
        text: Typed English text. ``None`` or blank means "use the audio".
        target_language: A key of ``translation_models`` / ``tts_models``.

    Returns:
        ``(english_text, translated_text, synthesized_audio)``. If a stage
        fails, the error is reported in a text output and the audio output is
        ``None``. This keeps earlier results visible and never passes a string
        to the ``gr.Audio`` output.
    """
    # 1. English text: typed input wins; otherwise transcribe the audio.
    typed = text.strip() if text else ""
    if typed:
        english_text = typed
    elif audio is not None:
        asr_input = _prepare_asr_input(audio)
        if asr_input is None:
            return "No input provided.", "", None
        try:
            english_text = asr(asr_input)["text"]
        except Exception as e:
            return f"Transcription error: {e}", "", None
    else:
        return "No input provided.", "", None

    # 2. Translation. Loading the model is inside the try as well, so an
    #    unavailable checkpoint yields a partial result instead of a crash.
    try:
        translator = get_translator(target_language)
        translated_text = translator(english_text)[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # 3. Speech synthesis from the translated text.
    try:
        tts_pipeline = get_tts(target_language)
        synthesized_audio = _tts_result_to_audio(tts_pipeline(translated_text))
    except Exception as e:
        # The audio slot must hold audio or None, so surface the error in the text output.
        return english_text, f"{translated_text}\n\n[TTS error: {e}]", None

    return english_text, translated_text, synthesized_audio
# --------------------------------------------------
# Gradio Interface Setup
# --------------------------------------------------
# Input components, in the positional order predict(audio, text, target_language) takes.
_english_audio_in = gr.Audio(type="numpy", label="Record/Upload English Audio (optional)")
_english_text_in = gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)")
_language_picker = gr.Dropdown(choices=list(translation_models), value="Spanish", label="Target Language")

# Output components, matching the three values predict returns.
_transcript_out = gr.Textbox(label="English Transcription")
_translation_out = gr.Textbox(label="Translation (Target Language)")
_speech_out = gr.Audio(label="Synthesized Speech in Target Language")

_APP_DESCRIPTION = (
    "This app helps language learners by providing three outputs:\n"
    "1. English transcription (from ASR or text input),\n"
    "2. Translation to a target language (using Helsinki-NLP models), and\n"
    "3. Synthetic speech in the target language.\n\n"
    "Select one of the top 10 commonly used languages from the dropdown.\n"
    "Either record/upload an English audio sample or enter English text directly.\n\n"
    "Note: Some TTS models may not exist or be unstable for certain languages."
)

iface = gr.Interface(
    fn=predict,
    inputs=[_english_audio_in, _english_text_in, _language_picker],
    outputs=[_transcript_out, _translation_out, _speech_out],
    title="Multimodal Language Learning Aid",
    description=_APP_DESCRIPTION,
    allow_flagging="never",
)

if __name__ == "__main__":
    # Serve the UI only when executed as a script, not when imported.
    iface.launch()