# Hugging Face Space app (page status when captured: Sleeping)
import gradio as gr | |
from transformers import pipeline | |
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast | |
from utils import lang_ids | |
import nltk | |
# Sentence-splitter data used by nltk.tokenize.sent_tokenize in translate_audio.
# NLTK >= 3.8.2 loads the "punkt_tab" resource rather than the pickled "punkt"
# one, so fetch both to work across NLTK versions. nltk.download reports
# failures instead of raising, so an unavailable resource is not fatal here.
for _nltk_resource in ("punkt", "punkt_tab"):
    nltk.download(_nltk_resource)

# Whisper checkpoint used for speech recognition / speech-to-English translation.
MODEL_NAME = "Pranjal12345/pranjal_whisper_medium"
# mBART-50 many-to-many checkpoint used for English -> target-language translation.
MBART_MODEL_NAME = "sanjitaa/mbart-many-to-many"
BATCH_SIZE = 10
# NOTE(review): not referenced elsewhere in this file as seen; confirm before relying on it.
FILE_LIMIT_MB = 1000

pipe = pipeline(
    task="automatic-speech-recognition",
    model=MODEL_NAME,
    chunk_length_s=30,  # chunked inference: long audio is processed in 30 s windows
    device="cpu",
)

# Load the mBART model and its tokenizer (downloaded on first use).
model = MBartForConditionalGeneration.from_pretrained(MBART_MODEL_NAME)
tokenizer = MBart50TokenizerFast.from_pretrained(MBART_MODEL_NAME)

# Language display names offered in the target-language dropdown.
lang_list = list(lang_ids.keys())
def translate_audio(inputs, target_language):
    """Transcribe an audio file to English and optionally translate it further.

    The Whisper pipeline is run with ``task="translate"``, which yields English
    text. If ``target_language`` is ``"English"`` that text is returned as is;
    otherwise it is split into sentences and each sentence is translated with
    mBART-50 into the language code ``lang_ids[target_language]``.

    Args:
        inputs: Path of the uploaded audio file (the Gradio Audio input uses
            ``type="filepath"``), or ``None`` when nothing was uploaded.
        target_language: Display name of the desired output language; must be
            a key of ``lang_ids``.

    Returns:
        The translated text, with translated sentences separated by a single
        space.

    Raises:
        gr.Error: If no audio file was submitted.
        KeyError: If ``target_language`` is not a key of ``lang_ids``.
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")

    text = pipe(
        inputs,
        batch_size=BATCH_SIZE,
        generate_kwargs={"task": "translate"},
        return_timestamps=True,
    )["text"]

    if target_language == "English":
        return text

    target_lang = lang_ids[target_language]
    # Whisper's "translate" task outputs English, so the mBART source is English.
    tokenizer.src_lang = "en_XX"
    # mBART-50 starts decoding with the target-language token; it is the same
    # for every sentence, so look it up once outside the loop.
    target_token_id = tokenizer.lang_code_to_id[target_lang]

    translated_sentences = []
    # Translate sentence by sentence so each model input stays short.
    for sentence in nltk.tokenize.sent_tokenize(text):
        encoded = tokenizer(sentence, return_tensors="pt")
        generated_tokens = model.generate(**encoded, forced_bos_token_id=target_token_id)
        decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        translated_sentences.append(decoded[0])

    # sent_tokenize drops the whitespace between sentences; put a space back so
    # the translated sentences do not run together.
    return " ".join(translated_sentences)
# Gradio UI: an uploaded audio file plus a target-language dropdown, with the
# result of translate_audio shown as plain text.
inputs = [
    # The `gr.inputs.*` namespace was removed in Gradio 4 and `source=` became
    # `sources=[...]`; use the top-level component like gr.Dropdown below.
    gr.Audio(sources=["upload"], type="filepath", label="Audio file"),
    gr.Dropdown(lang_list, value="English", label="Target Language"),
]
description = "Audio translation"
translation_interface = gr.Interface(
    fn=translate_audio,
    inputs=inputs,
    outputs="text",
    title="Speech Translation",
    description=description,
)

if __name__ == "__main__":
    translation_interface.launch()