Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline, MarianTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
import re | |
import whisper | |
import tempfile | |
import os | |
import nltk | |
nltk.download('punkt') | |
from nltk.tokenize import sent_tokenize | |
import os | |
# Additions for file processing | |
import fitz # PyMuPDF for PDF | |
import docx | |
from bs4 import BeautifulSoup | |
import markdown2 | |
import chardet | |
# --- Device selection --- | |
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Hugging Face access token read from the environment (None when unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
# --- Load translation models --- | |
# Translation pipelines keyed by language code, filled on first use so the
# checkpoints are not reloaded on every request.
_TRANSLATOR_CACHE = {}


def load_models():
    """Return the English->Darija, English->Hausa and English->Wolof pipelines.

    Each checkpoint is loaded with ``AutoModelForSeq2SeqLM`` /
    ``MarianTokenizer`` (authenticated with ``HF_TOKEN``), moved to the
    module-level ``device`` and wrapped in a ``"translation"`` pipeline.
    Pipelines are built once and reused on later calls.

    Returns:
        tuple: ``(en_dar_translator, en_hau_translator, en_wol_translator)``,
        in that order.
    """
    # Each language code maps to the checkpoint named after it. The Darija
    # and Hausa paths used to be swapped, so selecting one language ran the
    # other language's model.
    model_paths = {
        "dar": "LocaleNLP/english_darija",
        "hau": "LocaleNLP/english_hausa",
        "wol": "LocaleNLP/eng_wolof",
    }
    # transformers pipelines take a device index: 0 for the first GPU, -1 for CPU.
    pipeline_device = 0 if device.type == 'cuda' else -1

    for code, path in model_paths.items():
        if code in _TRANSLATOR_CACHE:
            continue
        model = AutoModelForSeq2SeqLM.from_pretrained(path, token=HF_TOKEN).to(device)
        tokenizer = MarianTokenizer.from_pretrained(path, token=HF_TOKEN)
        _TRANSLATOR_CACHE[code] = pipeline(
            "translation",
            model=model,
            tokenizer=tokenizer,
            device=pipeline_device,
        )

    return (
        _TRANSLATOR_CACHE["dar"],
        _TRANSLATOR_CACHE["hau"],
        _TRANSLATOR_CACHE["wol"],
    )
def load_whisper_model():
    """Load and return the Whisper ``"base"`` speech-recognition model."""
    checkpoint_name = "base"
    return whisper.load_model(checkpoint_name)
def transcribe_audio(audio_file):
    """Transcribe an audio input to text with Whisper.

    Args:
        audio_file: Either a filesystem path (``str``) to the audio, or a
            binary file-like object exposing ``read()``.

    Returns:
        The ``"text"`` entry of the Whisper transcription result.
    """
    model = load_whisper_model()

    if isinstance(audio_file, str):
        return model.transcribe(audio_file)["text"]

    # File-like input: spool the bytes to a named temp file so a path can be
    # handed to transcribe(). delete=False keeps the file after it is closed,
    # so it must be removed explicitly.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp:
            tmp.write(audio_file.read())
        return model.transcribe(tmp.name)["text"]
    finally:
        # Runs even when the write or the transcription raises, so the temp
        # file never leaks.
        os.remove(tmp.name)
def translate(text, target_lang):
    """Translate English ``text`` into the selected target language.

    The text is split into paragraphs on newlines and into sentences on
    ``". "``; each sentence is prefixed with the target-language tag and sent
    through the matching translation pipeline. Blank lines are preserved.

    Args:
        text: English source text.
        target_lang: One of ``"Darija (Morocco)"``, ``"Hausa (Nigeria)"`` or
            ``"Wolof (Senegal)"``.

    Returns:
        The translated text with the original paragraph layout.

    Raises:
        ValueError: If ``target_lang`` is not a supported language.
    """
    lang_tags = {
        "Darija (Morocco)": ">>dar<<",
        "Hausa (Nigeria)": ">>hau<<",
        "Wolof (Senegal)": ">>wol<<",
    }
    # Validate before the (expensive) model load.
    if target_lang not in lang_tags:
        raise ValueError("Unsupported target language")

    en_dar_translator, en_hau_translator, en_wol_translator = load_models()
    translators = {
        "Darija (Morocco)": en_dar_translator,
        "Hausa (Nigeria)": en_hau_translator,
        "Wolof (Senegal)": en_wol_translator,
    }
    translator = translators[target_lang]
    # Use the tag of the selected language only; interpolating the whole
    # mapping would prepend the dict's repr to every sentence.
    tag = lang_tags[target_lang]

    translated_output = []
    with torch.no_grad():
        for para in text.split("\n"):
            if not para.strip():
                translated_output.append("")
                continue
            sentences = [s.strip() for s in para.split('. ') if s.strip()]
            if not sentences:
                # Nothing translatable (e.g. a lone " . "); skip the model
                # call rather than sending it an empty batch.
                translated_output.append("")
                continue
            formatted = [f"{tag} {s}" for s in sentences]
            results = translator(
                formatted,
                max_length=5000,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5,
                length_penalty=1.2,
            )
            # NOTE(review): str.capitalize() also lowercases the rest of each
            # sentence, which affects proper nouns - confirm this is intended.
            translated_sentences = [r['translation_text'].capitalize() for r in results]
            translated_output.append('. '.join(translated_sentences))
    return "\n".join(translated_output)
# --- Extract text from file --- | |
# File extensions extract_text_from_file() knows how to handle.
_SUPPORTED_FILE_TYPES = frozenset(
    {"pdf", "docx", "html", "htm", "md", "srt", "txt", "text"}
)


def _decode_text(raw):
    """Decode raw bytes to ``str``: strict UTF-8 first, then chardet's guess."""
    try:
        # "utf-8-sig" is plain UTF-8 that also drops a leading BOM.
        return raw.decode("utf-8-sig")
    except UnicodeDecodeError:
        pass
    # chardet reports encoding=None when it cannot decide; fall back to UTF-8
    # so the caller always receives str, never bytes.
    encoding = chardet.detect(raw)['encoding'] or "utf-8"
    try:
        return raw.decode(encoding, errors='ignore')
    except LookupError:
        # chardet named a codec this Python build does not provide.
        return raw.decode("utf-8", errors='ignore')


def extract_text_from_file(uploaded_file):
    """Extract plain text from an uploaded document.

    Supported extensions: pdf, docx, html/htm, md, srt, txt/text.

    Args:
        uploaded_file: Either a filesystem path (``str``) or a file-like
            object exposing ``name`` and ``read()``.

    Returns:
        The document's text content as ``str``.

    Raises:
        ValueError: If the file extension is not supported.
    """
    import io  # local import: only this function needs it

    if isinstance(uploaded_file, str):
        file_name = uploaded_file
    else:
        file_name = uploaded_file.name
    file_type = file_name.rsplit('.', 1)[-1].lower()
    # Reject unknown types before reading or decoding anything.
    if file_type not in _SUPPORTED_FILE_TYPES:
        raise ValueError("Unsupported file type")

    if isinstance(uploaded_file, str):
        with open(uploaded_file, "rb") as f:
            content = f.read()
    else:
        content = uploaded_file.read()

    if file_type == "pdf":
        with fitz.open(stream=content, filetype="pdf") as doc:
            return "\n".join([page.get_text() for page in doc])

    if file_type == "docx":
        # Parse from the bytes already read: a file-like upload has been
        # consumed by read() above, so handing it to docx again would fail.
        doc = docx.Document(io.BytesIO(content))
        return "\n".join([para.text for para in doc.paragraphs])

    # Text-based formats: make sure we work on str from here on.
    text = content if isinstance(content, str) else _decode_text(content)

    if file_type in ("html", "htm"):
        return BeautifulSoup(text, "html.parser").get_text()
    if file_type == "md":
        html = markdown2.markdown(text)
        return BeautifulSoup(html, "html.parser").get_text()
    if file_type == "srt":
        # Drop each cue's index line and timestamp line, keeping the subtitle
        # text; \r?\n accepts both LF and CRLF line endings.
        return re.sub(r"\d+\r?\n\d{2}:\d{2}:\d{2},\d{3} --> .*?\r?\n", "", text)
    # "txt" / "text"
    return text
# --- Main Function --- | |
def process(target_lang, text_input, audio_input, file_input):
    """Gradio click handler: collect the input text and translate it.

    Input priority is typed text, then audio (transcribed), then an uploaded
    document (text-extracted).

    Args:
        target_lang: Target language label selected in the dropdown; may be
            empty/None when nothing has been selected.
        text_input: Text typed by the user, or None/empty.
        audio_input: Audio input for transcription, or None.
        file_input: Uploaded document, or None.

    Returns:
        tuple: ``(source_text, translated_text)``. When the request cannot be
        served, ``source_text`` is ``""`` and the second element carries a
        message for the user.
    """
    has_text = bool(text_input and text_input.strip())
    if not (has_text or audio_input or file_input):
        return "", "No valid input provided."
    # Check the language before the slow transcription/extraction step, and
    # report it in the UI instead of raising from translate().
    if not target_lang:
        return "", "Please select a target language."

    if has_text:
        input_text = text_input
    elif audio_input:
        input_text = transcribe_audio(audio_input)
    else:
        input_text = extract_text_from_file(file_input)

    # Audio or documents may yield no usable text.
    if not input_text or not input_text.strip():
        return "", "No valid input provided."

    translated_text = translate(input_text, target_lang)
    return input_text, translated_text
# --- Gradio Interface --- | |
# UI layout: language selector, one row of inputs (text / audio / document),
# one row of outputs (source text / translation) and a button that runs
# process(). Labels are plain text: the previous labels contained
# mis-encoded characters left over from emoji/punctuation.
with gr.Blocks() as demo:
    gr.Markdown("## LocaleNLP Translator – English → Darija / Hausa / Wolof")
    target_lang = gr.Dropdown(
        ["Darija (Morocco)", "Hausa (Nigeria)", "Wolof (Senegal)"],
        label="Select target language",
    )
    with gr.Row():
        text_input = gr.Textbox(label="Enter English text", lines=10)
        audio_input = gr.Audio(type="filepath", label="Upload Audio")
        file_input = gr.File(label="Upload Document")
    with gr.Row():
        extracted_text = gr.Textbox(label="Extracted / Transcribed Text", lines=10)
        translated_output = gr.Textbox(label="Translated Text", lines=10)
    run_btn = gr.Button("Translate")
    # The order of inputs/outputs must match process()'s parameters and its
    # (source_text, translated_text) return value.
    run_btn.click(
        process,
        inputs=[target_lang, text_input, audio_input, file_input],
        outputs=[extracted_text, translated_output],
    )

if __name__ == "__main__":
    demo.launch()