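"""Streamlit app for a Hugging Face Space that translates PDF, DOCX, and TXT
documents between English, Hindi, and Marathi using a three-stage pipeline:
Gemma for context interpretation, NLLB-200 for translation, and mT5 for
grammar correction."""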
import streamlit as st
import PyPDF2
import docx
import io
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
import torch
from pathlib import Path
from typing import Tuple
import os
import sys
from datetime import datetime, timezone
import warnings

# Filter out specific warnings
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.convert_slow_tokenizer')
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.tokenization_utils_base')
# Custom styling
st.set_page_config(
    page_title="Document Translation App",
    page_icon="🌐",
    layout="wide"
)
# Display current information in sidebar
current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
st.sidebar.markdown("""
### System Information
**Current UTC Time:** {}
**User:** {}
""".format(current_time, os.environ.get('USER', 'gauravchand')))

# Get Hugging Face token from environment variables
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()

# Define supported languages and their codes
SUPPORTED_LANGUAGES = {
    'English': 'eng_Latn',
    'Hindi': 'hin_Deva',
    'Marathi': 'mar_Deva'
}

# Language codes for MT5
MT5_LANG_CODES = {
    'eng_Latn': 'en',
    'hin_Deva': 'hi',
    'mar_Deva': 'mr'
}
@st.cache_resource
def load_models():
    """Load and cache the translation and context-interpretation models."""
    try:
        # Set device
        device = "cuda" if torch.cuda.is_available() else "cpu"

        # Load Gemma model for context interpretation
        gemma_tokenizer = AutoTokenizer.from_pretrained(
            "google/gemma-2b",
            token=HF_TOKEN,
            trust_remote_code=True
        )
        gemma_model = AutoModelForCausalLM.from_pretrained(
            "google/gemma-2b",
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )

        # Load NLLB model for translation
        nllb_tokenizer = AutoTokenizer.from_pretrained(
            "facebook/nllb-200-distilled-600M",
            token=HF_TOKEN,
            src_lang="eng_Latn",
            trust_remote_code=True
        )
        nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
            "facebook/nllb-200-distilled-600M",
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )

        # Load MT5 model for grammar correction
        mt5_tokenizer = AutoTokenizer.from_pretrained(
            "google/mt5-base",  # Changed to base model for better performance
            token=HF_TOKEN,
            trust_remote_code=True
        )
        mt5_model = MT5ForConditionalGeneration.from_pretrained(
            "google/mt5-base",  # Changed to base model for better performance
            token=HF_TOKEN,
            torch_dtype=torch.float16,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True
        )

        # Move models to device if not using device_map="auto"
        if not torch.cuda.is_available():
            gemma_model = gemma_model.to(device)
            nllb_model = nllb_model.to(device)
            mt5_model = mt5_model.to(device)

        return (gemma_tokenizer, gemma_model), (nllb_tokenizer, nllb_model), (mt5_tokenizer, mt5_model)
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Detailed error information:")
        st.error(f"Python version: {sys.version}")
        st.error(f"PyTorch version: {torch.__version__}")
        raise e
def extract_text_from_file(uploaded_file) -> str:
    """Extract text content from an uploaded file based on its type."""
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    elif file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    elif file_extension == '.txt':
        return uploaded_file.getvalue().decode('utf-8')
    else:
        raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Extract text from a PDF file."""
    pdf_reader = PyPDF2.PdfReader(file)
    text = ""
    for page in pdf_reader.pages:
        # extract_text() can come back empty for image-only pages; guard
        # against a None return so concatenation never fails
        text += (page.extract_text() or "") + "\n"
    return text.strip()
def extract_from_docx(file) -> str:
    """Extract text from a DOCX file."""
    doc = docx.Document(file)
    text = ""
    for paragraph in doc.paragraphs:
        text += paragraph.text + "\n"
    return text.strip()
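# Note: doc.paragraphs only yields top-level body paragraphs, so text inside
# tables, headers, and footers is not captured by extract_from_docx above.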
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Split text into whitespace-delimited batches of roughly max_length
    characters (measured in characters, not tokens)."""
    words = text.split()
    batches = []
    current_batch = []
    current_length = 0
    for word in words:
        if current_length + len(word) + 1 > max_length:  # +1 for the joining space
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length += len(word) + 1
    if current_batch:
        batches.append(" ".join(current_batch))
    return batches
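# Example (illustrative): batch_process_text("hello world foo", max_length=10)
# returns ["hello", "world foo"]; adding "world" to the first batch would
# push it past 10 characters, so it starts a new batch.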
def interpret_context(text: str, gemma_tuple: Tuple) -> str:
    """Use the Gemma model to interpret context and understand regional nuances."""
    tokenizer, model = gemma_tuple
    batches = batch_process_text(text)
    interpreted_batches = []
    for batch in batches:
        prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,  # max_length would count the prompt tokens too
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1
        )
        # Decode only the newly generated tokens so the prompt is not echoed back
        generated = outputs[0][inputs["input_ids"].shape[1]:]
        interpreted_text = tokenizer.decode(generated, skip_special_tokens=True).strip()
        interpreted_batches.append(interpreted_text)
    return " ".join(interpreted_batches)
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate text using the NLLB model."""
    tokenizer, model = nllb_tuple
    # Set the source language per call; otherwise the default hardcoded in
    # load_models() would be used for every input
    tokenizer.src_lang = source_lang
    batches = batch_process_text(text)
    translated_batches = []
    for batch in batches:
        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            # NLLB selects the output language via the first decoder token
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_lang),
            max_length=512,
            do_sample=True,
            temperature=0.7,
            num_beams=5,
            num_return_sequences=1
        )
        translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        translated_batches.append(translated_text)
    return " ".join(translated_batches)
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """Correct grammar using the MT5 model for all supported languages."""
    tokenizer, model = mt5_tuple
    lang_code = MT5_LANG_CODES[target_lang]
    # Language-specific prompts for grammar correction
    prompts = {
        'en': "Fix grammar: ",
        'hi': "व्याकरण: ",
        'mr': "व्याकरण: "
    }
    batches = batch_process_text(text)
    corrected_batches = []
    for batch in batches:
        # Prepare input with target language prefix
        input_text = f"{prompts[lang_code]}{batch}"
        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True,
            do_sample=False  # Disable sampling for more stable output
        )
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip any echoed prompt, plus the T5 sentinel tokens (<extra_id_N>)
        # that a pretrained mT5 checkpoint tends to emit
        for prefix in prompts.values():
            corrected_text = corrected_text.replace(prefix, "")
        corrected_text = corrected_text.replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
        corrected_batches.append(corrected_text)
    return " ".join(corrected_batches)
def save_as_docx(text: str) -> io.BytesIO:
    """Save translated text as a DOCX file."""
    doc = docx.Document()
    doc.add_paragraph(text)
    docx_buffer = io.BytesIO()
    doc.save(docx_buffer)
    docx_buffer.seek(0)
    return docx_buffer
def main():
    st.title("🌐 Document Translation App")
    # Load models
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
            return

    # File upload
    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt']
    )

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox(
            "Source Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=0
        )
    with col2:
        target_language = st.selectbox(
            "Target Language",
            options=list(SUPPORTED_LANGUAGES.keys()),
            index=1
        )

    if uploaded_file and st.button("Translate", type="primary"):
        try:
            progress_bar = st.progress(0)

            # Extract text
            text = extract_text_from_file(uploaded_file)
            progress_bar.progress(20)

            # Interpret context
            with st.spinner("Interpreting context..."):
                interpreted_text = interpret_context(text, gemma_tuple)
                progress_bar.progress(40)

            # Translate
            with st.spinner("Translating..."):
                translated_text = translate_text(
                    interpreted_text,
                    SUPPORTED_LANGUAGES[source_language],
                    SUPPORTED_LANGUAGES[target_language],
                    nllb_tuple
                )
                progress_bar.progress(70)

            # Grammar correction
            with st.spinner("Correcting grammar..."):
                corrected_text = correct_grammar(
                    translated_text,
                    SUPPORTED_LANGUAGES[target_language],
                    mt5_tuple
                )
                progress_bar.progress(90)

            # Display result
            st.markdown("### Translation Result")
            st.text_area(
                label="Translated Text",
                value=corrected_text,
                height=200,
                key="translation_result"
            )

            # Download options
            st.markdown("### Download Options")
            col1, col2 = st.columns(2)
            with col1:
                # Text file download
                text_buffer = io.BytesIO()
                text_buffer.write(corrected_text.encode('utf-8'))
                text_buffer.seek(0)
                st.download_button(
                    label="Download as TXT",
                    data=text_buffer,
                    file_name="translated_document.txt",
                    mime="text/plain"
                )
            with col2:
                # DOCX file download
                docx_buffer = save_as_docx(corrected_text)
                st.download_button(
                    label="Download as DOCX",
                    data=docx_buffer,
                    file_name="translated_document.docx",
                    mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
                )

            progress_bar.progress(100)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()
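# To run locally (assuming this file is saved as app.py and the streamlit,
# transformers, torch, PyPDF2, and python-docx packages are installed):
#   streamlit run app.py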