# try / app.py
# gauravchand11's picture
# Update app.py
# 90c759f verified
# raw
# history blame
# 12.9 kB
import io
import os
import re
import sys
import tempfile
import warnings
from datetime import datetime, timezone
from pathlib import Path
from typing import Union, Tuple

import docx
import PyPDF2
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
# Filter out known-noisy tokenizer warnings raised inside transformers.
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.convert_slow_tokenizer')
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.tokenization_utils_base')

# Page configuration; kept ahead of every other Streamlit call in the script.
st.set_page_config(
    page_title="Document Translation App",
    page_icon="🌐",
    layout="wide",
)

# Display current information in the sidebar. Each field is its own Markdown
# paragraph (blank-line separated): a single newline is only a soft line break
# in Markdown, which rendered the time and the user on one run-on line.
current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
current_user = os.environ.get('USER', 'gauravchand')
st.sidebar.markdown(
    "### System Information\n\n"
    f"**Current UTC Time:** {current_time}\n\n"
    f"**User:** {current_user}"
)

# Hugging Face access token, passed as `token=` to every from_pretrained call.
# Script execution stops here when it is missing.
HF_TOKEN = os.environ.get('HF_TOKEN')
if not HF_TOKEN:
    st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
    st.stop()
# One row per supported language: display name, NLLB code, mT5 code.
# Both lookup tables below are derived from it so they cannot drift apart.
_LANGUAGE_TABLE = (
    ('English', 'eng_Latn', 'en'),
    ('Hindi', 'hin_Deva', 'hi'),
    ('Marathi', 'mar_Deva', 'mr'),
)
# Display name -> NLLB language code (insertion order drives the select boxes).
SUPPORTED_LANGUAGES = {name: nllb_code for name, nllb_code, _ in _LANGUAGE_TABLE}
# NLLB language code -> MT5 language code (selects the grammar-correction prompt).
MT5_LANG_CODES = {nllb_code: mt5_code for _, nllb_code, mt5_code in _LANGUAGE_TABLE}
@st.cache_resource
def load_models():
    """Load and cache the context-interpretation, translation and grammar models.

    Decorated with ``st.cache_resource`` so the weights are not reloaded on
    every Streamlit rerun.

    Returns:
        Three ``(tokenizer, model)`` pairs, in order: Gemma-2B (context
        interpretation), NLLB-200 distilled 600M (translation) and mT5-base
        (grammar correction).

    Raises:
        Exception: Any loading error, re-raised after it and the Python/PyTorch
            versions are reported in the UI.
    """
    try:
        use_cuda = torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"
        # float16 only on GPU: many half-precision kernels are missing or very
        # slow on CPU (older PyTorch raises "not implemented for 'Half'").
        dtype = torch.float16 if use_cuda else torch.float32
        # With CUDA, device_map="auto" lets transformers place the weights;
        # otherwise each model is moved to `device` explicitly.
        device_map = "auto" if use_cuda else None

        def _load_pair(repo_id, model_cls, **tokenizer_kwargs):
            """Load the tokenizer and model of one Hub repo with shared settings."""
            # NOTE(review): trust_remote_code=True allows the repo to execute its
            # own Python code; these checkpoints presumably load with built-in
            # transformers classes, so it can likely be dropped -- confirm.
            tokenizer = AutoTokenizer.from_pretrained(
                repo_id,
                token=HF_TOKEN,
                trust_remote_code=True,
                **tokenizer_kwargs,
            )
            model = model_cls.from_pretrained(
                repo_id,
                token=HF_TOKEN,
                torch_dtype=dtype,
                device_map=device_map,
                trust_remote_code=True,
            )
            if device_map is None:
                model = model.to(device)
            return tokenizer, model

        # Gemma: context interpretation.
        gemma = _load_pair("google/gemma-2b", AutoModelForCausalLM)
        # NLLB: translation. src_lang is only the tokenizer's initial source language.
        nllb = _load_pair(
            "facebook/nllb-200-distilled-600M",
            AutoModelForSeq2SeqLM,
            src_lang="eng_Latn",
        )
        # mT5: grammar correction.
        # NOTE(review): T5-family models are reported to overflow in float16;
        # confirm GPU output is sane or load this one in float32.
        mt5 = _load_pair("google/mt5-base", MT5ForConditionalGeneration)
        return gemma, nllb, mt5
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        st.error("Detailed error information:")
        st.error(f"Python version: {sys.version}")
        st.error(f"PyTorch version: {torch.__version__}")
        raise
def extract_text_from_file(uploaded_file) -> str:
    """Extract the text content of an uploaded PDF, DOCX or TXT file.

    The format is chosen from the file-name extension, case-insensitively.

    Args:
        uploaded_file: Uploaded-file object exposing ``name``; ``.txt`` uploads
            must also provide ``getvalue()`` returning the raw bytes.

    Returns:
        The extracted text.

    Raises:
        ValueError: If the extension is not .pdf, .docx or .txt.
        UnicodeDecodeError: If a .txt upload is not valid UTF-8.
    """
    file_extension = Path(uploaded_file.name).suffix.lower()
    if file_extension == '.pdf':
        return extract_from_pdf(uploaded_file)
    if file_extension == '.docx':
        return extract_from_docx(uploaded_file)
    if file_extension == '.txt':
        # utf-8-sig decodes plain UTF-8 unchanged but drops a leading byte-order
        # mark, which plain 'utf-8' would keep as a U+FEFF character in the text.
        return uploaded_file.getvalue().decode('utf-8-sig')
    raise ValueError(f"Unsupported file format: {file_extension}")
def extract_from_pdf(file) -> str:
    """Extract the text of every page of a PDF.

    Args:
        file: Anything ``PyPDF2.PdfReader`` accepts; here, the uploaded file.

    Returns:
        Page texts joined with newlines, with surrounding whitespace stripped.
    """
    pdf_reader = PyPDF2.PdfReader(file)
    # str.join rather than += in a loop, which re-copies the growing string.
    return "\n".join(page.extract_text() for page in pdf_reader.pages).strip()
def extract_from_docx(file) -> str:
    """Extract paragraph text from a DOCX file.

    Args:
        file: Anything ``docx.Document`` accepts; here, the uploaded file.

    Returns:
        The texts of ``Document.paragraphs`` joined with newlines, with
        surrounding whitespace stripped.
    """
    document = docx.Document(file)
    # str.join rather than += in a loop, which re-copies the growing string.
    return "\n".join(paragraph.text for paragraph in document.paragraphs).strip()
def batch_process_text(text: str, max_length: int = 512) -> list:
    """Split *text* into space-joined chunks of at most *max_length* characters.

    Words are never split: a word longer than *max_length* becomes a chunk of
    its own (and is the only chunk that may exceed the limit). All whitespace,
    including newlines, is collapsed to single spaces.

    Args:
        text: Text to split; may be empty.
        max_length: Maximum character length of each chunk.

    Returns:
        List of non-empty chunks; empty when *text* has no words.
    """
    batches = []
    current_batch = []
    current_length = 0  # exact length of " ".join(current_batch)
    for word in text.split():
        # Joining adds one separator space unless the batch is still empty.
        added = len(word) + (1 if current_batch else 0)
        # Only flush a non-empty batch; flushing an empty one emitted "" chunks
        # whenever the first word alone exceeded max_length.
        if current_batch and current_length + added > max_length:
            batches.append(" ".join(current_batch))
            current_batch = [word]
            current_length = len(word)
        else:
            current_batch.append(word)
            current_length += added
    if current_batch:
        batches.append(" ".join(current_batch))
    return batches
@torch.no_grad()
def interpret_context(text: str, gemma_tuple: Tuple, max_new_tokens: int = 512) -> str:
    """Run each chunk of *text* through the Gemma model and collect its output.

    Args:
        text: Document text; split into chunks with batch_process_text.
        gemma_tuple: ``(tokenizer, model)`` pair for a causal language model.
        max_new_tokens: Upper bound on tokens generated per chunk.

    Returns:
        The generated continuations of all chunks, joined with single spaces.
        Only newly generated tokens are returned; the prompt, which contains
        the original chunk, is not part of the result.
    """
    tokenizer, model = gemma_tuple
    interpreted_batches = []
    for batch in batch_process_text(text):
        prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""
        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        prompt_length = inputs["input_ids"].shape[-1]
        outputs = model.generate(
            **inputs,
            # max_new_tokens, not max_length: max_length counts the prompt,
            # so a prompt near 512 tokens left (almost) no room to generate.
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.3,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=1,
        )
        # Decode only the tokens after the prompt. Removing the prompt with
        # str.replace failed whenever it was truncated above or did not
        # round-trip exactly through decode().
        generated_ids = outputs[0][prompt_length:]
        interpreted_batches.append(tokenizer.decode(generated_ids, skip_special_tokens=True).strip())
    return " ".join(interpreted_batches)
@torch.no_grad()
def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
    """Translate *text* from *source_lang* to *target_lang* with an NLLB model.

    Args:
        text: Text to translate; split into chunks with batch_process_text.
        source_lang: NLLB language code of the input, e.g. ``"eng_Latn"``.
        target_lang: NLLB language code of the output, e.g. ``"hin_Deva"``.
        nllb_tuple: ``(tokenizer, model)`` pair for an NLLB seq2seq model.

    Returns:
        The translated chunks joined with single spaces.
    """
    tokenizer, model = nllb_tuple
    # The tokenizer prefixes inputs with its src_lang token. Without this line
    # source_lang was ignored and every input was tagged with the load-time
    # default language.
    # NOTE(review): the tokenizer object is shared through st.cache_resource,
    # so concurrent sessions with different source languages could race here.
    tokenizer.src_lang = source_lang
    # convert_tokens_to_ids works across transformers versions, unlike the
    # lang_code_to_id mapping, which newer releases no longer provide.
    target_token_id = tokenizer.convert_tokens_to_ids(target_lang)
    translated_batches = []
    for batch in batch_process_text(text):
        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            forced_bos_token_id=target_token_id,
            max_length=512,
            # Deterministic beam search; sampling made the same document
            # translate differently on every run.
            num_beams=5,
            num_return_sequences=1,
        )
        translated_batches.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
    return " ".join(translated_batches)
@torch.no_grad()
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
    """Post-process translated text with the mT5 model using a language-specific prompt.

    Args:
        text: Text to correct; split into chunks with batch_process_text.
        target_lang: NLLB language code of *text*; must be a key of MT5_LANG_CODES.
        mt5_tuple: ``(tokenizer, model)`` pair for an mT5 seq2seq model.

    Returns:
        The corrected chunks joined with single spaces, with prompt text and
        ``<extra_id_N>`` sentinel tokens removed.

    Raises:
        KeyError: If *target_lang* has no entry in MT5_LANG_CODES.
    """
    # NOTE(review): google/mt5-base is a pretrained-only checkpoint (its model
    # card says it must be fine-tuned before downstream use), so it was not
    # trained to follow these prompts; confirm this step improves the output.
    tokenizer, model = mt5_tuple
    lang_code = MT5_LANG_CODES[target_lang]
    # Hindi and Marathi share the Devanagari word "vyakaran" ("grammar").
    # Written with \u escapes so it survives re-encoding; the previous literal
    # was mojibake (its UTF-8 bytes decoded as a Greek code page).
    devanagari_prompt = "\u0935\u094d\u092f\u093e\u0915\u0930\u0923: "
    prompts = {
        'en': "Fix grammar: ",
        'hi': devanagari_prompt,
        'mr': devanagari_prompt,
    }
    prompt = prompts[lang_code]
    corrected_batches = []
    for batch in batch_process_text(text):
        inputs = tokenizer(f"{prompt}{batch}", return_tensors="pt", max_length=512, truncation=True)
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=5,
            length_penalty=1.0,
            early_stopping=True,
            do_sample=False,  # deterministic output
        )
        corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Strip any echoed prompt and every mT5 sentinel token, not just
        # <extra_id_0>/<extra_id_1>.
        for prefix in prompts.values():
            corrected_text = corrected_text.replace(prefix, "")
        corrected_text = re.sub(r"<extra_id_\d+>", "", corrected_text).strip()
        corrected_batches.append(corrected_text)
    return " ".join(corrected_batches)
def save_as_docx(text: str) -> io.BytesIO:
    """Build an in-memory Word document holding *text* as a single paragraph.

    Returns:
        A ``BytesIO`` with the .docx bytes, rewound to the start so it can be
        handed directly to a download widget.
    """
    output = io.BytesIO()
    report = docx.Document()
    report.add_paragraph(text)
    report.save(output)
    output.seek(0)
    return output
def _offer_downloads(result: str) -> None:
    """Render TXT and DOCX download buttons for *result* in two columns."""
    st.markdown("### Download Options")
    txt_col, docx_col = st.columns(2)
    with txt_col:
        st.download_button(
            label="Download as TXT",
            data=result.encode("utf-8"),
            file_name="translated_document.txt",
            mime="text/plain",
        )
    with docx_col:
        st.download_button(
            label="Download as DOCX",
            data=save_as_docx(result),
            file_name="translated_document.docx",
            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        )


def main():
    """Streamlit entry point: load models, accept an upload and run the pipeline.

    Pipeline: extract text -> interpret context (Gemma) -> translate (NLLB)
    -> grammar correction (mT5). The result is then shown with download
    buttons. The translation step receives the Gemma output, not the raw
    extracted text.
    """
    st.title("🌐 Document Translation App")

    # Load models (load_models is cached, so this is slow only the first time).
    with st.spinner("Loading models... This may take a few minutes."):
        try:
            gemma_tuple, nllb_tuple, mt5_tuple = load_models()
        except Exception as e:
            st.error(f"Error loading models: {str(e)}")
            st.error("Please check if the HF_TOKEN is valid and has the necessary permissions.")
            return

    uploaded_file = st.file_uploader(
        "Upload your document (PDF, DOCX, or TXT)",
        type=['pdf', 'docx', 'txt'],
    )

    language_names = list(SUPPORTED_LANGUAGES.keys())
    col1, col2 = st.columns(2)
    with col1:
        source_language = st.selectbox("Source Language", options=language_names, index=0)
    with col2:
        target_language = st.selectbox("Target Language", options=language_names, index=1)

    # NOTE(review): st.button is True only on the rerun caused by its click, and
    # a download-button click triggers another rerun (Streamlit's default), so
    # the result disappears after one download. Persisting it in
    # st.session_state would avoid re-running the whole pipeline.
    if uploaded_file and st.button("Translate", type="primary"):
        try:
            progress_bar = st.progress(0)

            text = extract_text_from_file(uploaded_file)
            if not text.strip():
                # Nothing to translate (e.g. an image-only PDF): report it
                # instead of running three models on an empty string.
                st.warning("No text could be extracted from the uploaded document.")
                return
            progress_bar.progress(20)

            with st.spinner("Interpreting context..."):
                interpreted_text = interpret_context(text, gemma_tuple)
            progress_bar.progress(40)

            with st.spinner("Translating..."):
                translated_text = translate_text(
                    interpreted_text,
                    SUPPORTED_LANGUAGES[source_language],
                    SUPPORTED_LANGUAGES[target_language],
                    nllb_tuple,
                )
            progress_bar.progress(70)

            with st.spinner("Correcting grammar..."):
                corrected_text = correct_grammar(
                    translated_text,
                    SUPPORTED_LANGUAGES[target_language],
                    mt5_tuple,
                )
            progress_bar.progress(90)

            st.markdown("### Translation Result")
            st.text_area(
                label="Translated Text",
                value=corrected_text,
                height=200,
                key="translation_result",
            )
            _offer_downloads(corrected_text)
            progress_bar.progress(100)
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()