Build error
Update app.py
app.py CHANGED
@@ -1,426 +1,146 @@
 import streamlit as st
-import PyPDF2
 import docx
-import io
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
-import torch
-from pathlib import Path
-from typing import Union, Tuple, List, Dict
 import os
-import sys
-from datetime import datetime, timezone
-import warnings
 import re

-#
-

-#
-st.
-
-
-
-)

-#
-
-
-
-"
-
-
-
-
-
-
-
-    "MT5_LANG_CODES": {
-        'eng_Latn': 'en',
-        'hin_Deva': 'hi',
-        'mar_Deva': 'mr'
-    },
-    "GRAMMAR_PROMPTS": {
-        'en': "Fix grammar and improve fluency: ",
-        'hi': "व्याकरण और प्रवाह सुधारें: ",
-        'mr': "व्याकरण आणि प्रवाह सुधारा: "
-    }
-}
-
-class DocumentProcessor:
-    """Handles document processing and text extraction"""
-
-    @staticmethod
-    def extract_text_from_file(uploaded_file) -> str:
-        file_extension = Path(uploaded_file.name).suffix.lower()
-
-        extractors = {
-            '.pdf': DocumentProcessor._extract_from_pdf,
-            '.docx': DocumentProcessor._extract_from_docx,
-            '.txt': lambda f: f.getvalue().decode('utf-8')
-        }
-
-        if file_extension not in extractors:
-            raise ValueError(f"Unsupported file format: {file_extension}")
-
-        return extractors[file_extension](uploaded_file)
-
-    @staticmethod
-    def _extract_from_pdf(file) -> str:
-        pdf_reader = PyPDF2.PdfReader(file)
-        return "\n".join(page.extract_text() for page in pdf_reader.pages).strip()
-
-    @staticmethod
-    def _extract_from_docx(file) -> str:
         doc = docx.Document(file)
-
-
-
-

-
-
-        sentences = TextBatcher._split_into_sentences(text)
-        batches = []
-        current_batch = []
-        current_length = 0
-
-        for sentence in sentences:
-            sentence_length = len(sentence)
-
-            if current_length + sentence_length > max_length:
-                if current_batch:
-                    batches.append(" ".join(current_batch))
-                current_batch = [sentence]
-                current_length = sentence_length
-            else:
-                current_batch.append(sentence)
-                current_length += sentence_length
-
-        if current_batch:
-            batches.append(" ".join(current_batch))
-
-        return batches

-
-
-        """Split text into sentences with improved boundary detection"""
-        delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
-        sentences = []
-        current = text
-
-        for delimiter in delimiters:
-            parts = current.split(delimiter)
-            current = parts[0]
-            for part in parts[1:]:
-                if len(current.strip()) > 0:
-                    sentences.append(current.strip() + delimiter.strip())
-                current = part
-
-        if len(current.strip()) > 0:
-            sentences.append(current.strip())
-
-        return sentences

-
-
-
-
-    def load_models():
-        try:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-
-            models = {
-                "gemma": ModelManager._load_gemma_model(),
-                "nllb": ModelManager._load_nllb_model(),
-                "mt5": ModelManager._load_mt5_model()
-            }
-
-            if not torch.cuda.is_available():
-                for model_tuple in models.values():
-                    model_tuple[1].to(device)
-
-            return models
-
-        except Exception as e:
-            st.error(f"Error loading models: {str(e)}")
-            st.error(f"Python version: {sys.version}")
-            st.error(f"PyTorch version: {torch.__version__}")
-            raise e
-
-    @staticmethod
-    def _load_gemma_model():
-        tokenizer = AutoTokenizer.from_pretrained(
-            "google/gemma-2b",
-            token=os.environ.get('HF_TOKEN'),
-            trust_remote_code=True
-        )
-        model = AutoModelForCausalLM.from_pretrained(
-            "google/gemma-2b",
-            token=os.environ.get('HF_TOKEN'),
-            torch_dtype=torch.float16,
-            device_map="auto" if torch.cuda.is_available() else None,
-            trust_remote_code=True
-        )
-        return (tokenizer, model)
-
-    @staticmethod
-    def _load_nllb_model():
-        tokenizer = AutoTokenizer.from_pretrained(
-            "facebook/nllb-200-distilled-600M",
-            token=os.environ.get('HF_TOKEN'),
-            use_fast=False,
-            trust_remote_code=True
-        )
-        model = AutoModelForSeq2SeqLM.from_pretrained(
-            "facebook/nllb-200-distilled-600M",
-            token=os.environ.get('HF_TOKEN'),
-            torch_dtype=torch.float16,
-            device_map="auto" if torch.cuda.is_available() else None,
-            trust_remote_code=True
-        )
-        return (tokenizer, model)
-
-    @staticmethod
-    def _load_mt5_model():
-        tokenizer = AutoTokenizer.from_pretrained(
-            "google/mt5-base",
-            token=os.environ.get('HF_TOKEN'),
-            trust_remote_code=True
-        )
-        model = MT5ForConditionalGeneration.from_pretrained(
-            "google/mt5-base",
-            token=os.environ.get('HF_TOKEN'),
-            torch_dtype=torch.float16,
-            device_map="auto" if torch.cuda.is_available() else None,
-            trust_remote_code=True
-        )
-        return (tokenizer, model)

-
-"""
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            target_lang
         )
-
-
-
-        # Clean up the final text
-        final_text = " ".join(final_results)
-        return self._clean_text(final_text)

-
-        tokenizer, model = self.models["gemma"]
-
-        prompt = f"""Analyze and provide context for translation:
-        Text: {text}
-        Key points to consider:
-        - Main topic and subject matter
-        - Cultural context and nuances
-        - Technical terminology if any
-        - Tone and style of writing

-
-
-
-
-
-
-        inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-        outputs = model.generate(
-            **inputs,
-            max_length=CONFIG["MAX_BATCH_LENGTH"],
-            do_sample=True,
-            temperature=CONFIG["CONTEXT_TEMPERATURE"],
-            pad_token_id=tokenizer.eos_token_id,
-            num_return_sequences=1
-        )
-
-        context = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return context.replace(prompt, "").strip()
-
-    def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
-        tokenizer, model = self.models["nllb"]
-
-        target_lang_token = f"___{target_lang}___"
-        inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-        target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
-
-        outputs = model.generate(
-            **inputs,
-            forced_bos_token_id=target_lang_id,
-            max_length=CONFIG["MAX_BATCH_LENGTH"],
-            do_sample=True,
-            temperature=CONFIG["TRANSLATION_TEMPERATURE"],
-            num_beams=CONFIG["NUM_BEAMS"],
-            num_return_sequences=1,
-            length_penalty=1.0,
-            repetition_penalty=1.2
-        )
-
-        return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    def _correct_grammar(self, text: str, target_lang: str) -> str:
-        tokenizer, model = self.models["mt5"]
-        lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
-        prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
-
-        input_text = f"{prompt}{text}"
-        inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
-        inputs = {k: v.to(model.device) for k, v in inputs.items()}
-
-        outputs = model.generate(
-            **inputs,
-            max_length=CONFIG["MAX_BATCH_LENGTH"],
-            num_beams=CONFIG["NUM_BEAMS"],
-            length_penalty=1.0,
-            early_stopping=True,
-            no_repeat_ngram_size=2,
-            do_sample=False
-        )
-
-        corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
-        return self._clean_text(corrected.replace(prompt, "").strip())
-
-    def _clean_text(self, text: str) -> str:
-        """Clean up the text by removing special tokens and fixing formatting"""
-        # Remove MT5 special tokens
-        text = re.sub(r'<extra_id_\d+>', '', text)
-
-        # Fix multiple spaces
-        text = re.sub(r'\s+', ' ', text)
-
-        # Fix punctuation spacing
-        text = re.sub(r'\s+([.,!?।॥])', r'\1', text)
-
-        return text.strip()

-
-
-
-
-
-
-
-
-
-
-
-

 def main():
-    st.title("
-
-    # Display system info
-    st.sidebar.markdown(f"""
-    ### System Information
-    **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
-    **User:** {os.environ.get('USER', 'gauravchand')}
-    """)

-    #
-
-        st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
-        st.stop()

-    #
-
-    try:
-        models = ModelManager.load_models()
-        pipeline = TranslationPipeline(models)
-    except Exception as e:
-        st.error(f"Error initializing translation pipeline: {str(e)}")
-        return
-
-    # File upload
-    uploaded_file = st.file_uploader(
-        "Upload your document (PDF, DOCX, or TXT)",
-        type=['pdf', 'docx', 'txt']
-    )

     # Language selection
     col1, col2 = st.columns(2)
     with col1:
-
-            "Source Language",
-            options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
-            index=0
-        )
-
     with col2:
-
-            "Target Language",
-            options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
-            index=1
-        )

-    if uploaded_file and st.button("Translate"
-
-
-        status_text = st.empty()
-
-        # Process document
-        status_text.text("Extracting text from document...")
-        text = DocumentProcessor.extract_text_from_file(uploaded_file)
-        progress_bar.progress(20)
-
-        # Perform translation
-        status_text.text("Translating document with context understanding...")
-        final_text = pipeline.process_text(
-            text,
-            CONFIG["SUPPORTED_LANGUAGES"][source_language],
-            CONFIG["SUPPORTED_LANGUAGES"][target_language]
-        )
-        progress_bar.progress(90)

         # Display result
-        st.
-        st.text_area(
-            label="Translated Text",
-            value=final_text,
-            height=200,
-            key="translation_result"
-        )
-
-        # Download option
-        st.markdown("### Download Option")
-        st.download_button(
-            label="Download as DOCX",
-            data=DocumentExporter.save_as_docx(final_text),
-            file_name="translated_document.docx",
-            mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-        )
-
-        status_text.text("Translation completed successfully!")
-        progress_bar.progress(100)

-
-

 if __name__ == "__main__":
     main()

+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import streamlit as st
+from PyPDF2 import PdfReader
 import docx
 import os
 import re

+# Load NLLB model and tokenizer
+@st.cache_resource
+def load_translation_model():
+    model_name = "facebook/nllb-200-distilled-600M"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+    return tokenizer, model

+# Initialize model
+@st.cache_resource
+def initialize_models():
+    tokenizer, model = load_translation_model()
+    return {"nllb": (tokenizer, model)}

+# Function to extract text from different file types
+def extract_text(file):
+    ext = os.path.splitext(file.name)[1].lower()
+
+    if ext == ".pdf":
+        reader = PdfReader(file)
+        text = ""
+        for page in reader.pages:
+            text += page.extract_text() + "\n"
+        return text
+
+    elif ext == ".docx":
         doc = docx.Document(file)
+        text = ""
+        for para in doc.paragraphs:
+            text += para.text + "\n"
+        return text

+    elif ext == ".txt":
+        return file.read().decode("utf-8")

+    else:
+        raise ValueError("Unsupported file format. Please upload PDF, DOCX, or TXT files.")

+# Translation function
+def translate_text(text, src_lang, tgt_lang, models):
+    if src_lang == tgt_lang:
+        return text

+    # Language codes for NLLB
+    lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}
+
+    if src_lang not in lang_map or tgt_lang not in lang_map:
+        return "Error: Unsupported language combination"
+
+    tgt_lang_code = lang_map[tgt_lang]
+
+    tokenizer, model = models["nllb"]
+
+    # Preprocess for idioms
+    preprocessed_text = preprocess_idioms(text, src_lang, tgt_lang)
+
+    # Split text into manageable chunks
+    sentences = preprocessed_text.split("\n")
+    translated_text = ""
+
+    for sentence in sentences:
+        if sentence.strip():
+            inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=512)
+            # Use lang_code_to_id instead of get_lang_id
+            translated = model.generate(
+                **inputs,
+                forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang_code],
+                max_length=512
             )
+            translated_sentence = tokenizer.decode(translated[0], skip_special_tokens=True)
+            translated_text += translated_sentence + "\n"

+    return translated_text
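
Two runtime hazards hide in translate_text as committed: preprocess_idioms is called but never defined anywhere in this version of app.py, so the first translation raises a NameError, and recent transformers releases drop the NLLB tokenizer's lang_code_to_id mapping in favor of convert_tokens_to_ids (which the deleted version above already used). A minimal sketch of both repairs; the idiom table is a hypothetical placeholder, not something from this commit:

# Hypothetical stand-in for the undefined helper: replace known idioms with
# plainer phrasings before translation; unknown pairs pass through unchanged.
IDIOMS = {("en", "hi"): {"piece of cake": "very easy task"}}  # placeholder entry

def preprocess_idioms(text, src_lang, tgt_lang):
    for idiom, plain in IDIOMS.get((src_lang, tgt_lang), {}).items():
        text = text.replace(idiom, plain)
    return text

def nllb_forced_bos_id(tokenizer, lang_code):
    # NLLB language codes such as "hin_Deva" are ordinary special tokens in the
    # vocabulary, so convert_tokens_to_ids works on old and new transformers alike.
    return tokenizer.convert_tokens_to_ids(lang_code)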

+# Function to save text as a file
+def save_text_to_file(text, original_filename, prefix="translated"):
+    output_filename = f"{prefix}_{os.path.basename(original_filename)}.txt"
+    with open(output_filename, "w", encoding="utf-8") as f:
+        f.write(text)
+    return output_filename
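
save_text_to_file writes into the working directory, which works on a Space but is ephemeral storage, and the file exists only so main() can read it back for the download button. st.download_button also accepts a str or bytes payload directly, so the disk round-trip is avoidable; a sketch of the in-memory variant, assuming the same filename convention:

# In-memory alternative: pass the translated string straight to the download
# button; Streamlit encodes str payloads as UTF-8.
st.download_button(
    label="Download Translated Document",
    data=result_text,
    file_name=f"translated_{os.path.basename(uploaded_file.name)}.txt",
    mime="text/plain",
)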

+# Main processing function
+def process_document(file, source_lang, target_lang, models):
+    try:
+        # Extract text from uploaded file
+        text = extract_text(file)
+
+        # Translate the text
+        translated_text = translate_text(text, source_lang, target_lang, models)
+
+        # Save the result (success or error) to a file
+        if translated_text.startswith("Error:"):
+            output_file = save_text_to_file(translated_text, file.name, prefix="error")
+        else:
+            output_file = save_text_to_file(translated_text, file.name)
+
+        return output_file, translated_text
+    except Exception as e:
+        # Save error message to a file
+        error_message = f"Error: {str(e)}"
+        output_file = save_text_to_file(error_message, file.name, prefix="error")
+        return output_file, error_message

+# Streamlit interface
 def main():
+    st.title("Document Translator (NLLB-200)")
+    st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages (English, Hindi, Marathi).")

+    # Initialize models
+    models = initialize_models()

+    # File uploader
+    uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])

     # Language selection
     col1, col2 = st.columns(2)
     with col1:
+        source_lang = st.selectbox("Source Language", ["en", "hi", "mr"], index=0)
     with col2:
+        target_lang = st.selectbox("Target Language", ["en", "hi", "mr"], index=1)

+    if uploaded_file is not None and st.button("Translate"):
+        with st.spinner("Translating..."):
+            output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)

         # Display result
+        st.text_area("Translated Text", result_text, height=300)

+        # Provide download button
+        with open(output_file, "rb") as file:
+            st.download_button(
+                label="Download Translated Document",
+                data=file,
+                file_name=os.path.basename(output_file),
+                mime="text/plain"
+            )

 if __name__ == "__main__":
     main()