Update app.py

app.py CHANGED
@@ -12,33 +12,41 @@ import os
 import re
 import torch
 import numpy as np
+from datetime import datetime, timezone
 
 # Load models and tokenizers
 @st.cache_resource
 def load_models():
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    try:
+        # BERT model for context understanding
+        context_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
+        context_model = BertModel.from_pretrained('bert-base-multilingual-cased')
+
+        # NLLB model for translation
+        nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+        nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+
+        # Grammar correction model
+        grammar_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+        grammar_model = AutoModelForTokenClassification.from_pretrained(
+            'bert-base-cased',
+            num_labels=3  # Assuming 3 labels: keep, delete, replace
+        )
+
+        return {
+            "context": (context_tokenizer, context_model),
+            "nllb": (nllb_tokenizer, nllb_model),
+            "grammar": (grammar_tokenizer, grammar_model)
+        }
+    except Exception as e:
+        st.error(f"Error loading models: {str(e)}")
+        raise e
 
 def get_bert_embeddings(text, models):
     """Get contextual embeddings from BERT"""
     tokenizer, model = models["context"]
 
-    # Split text into smaller chunks
+    # Split text into smaller chunks
     max_length = 512
     chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
     contextual_embeddings = []
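
Note on the decorator: st.cache_resource is Streamlit's cache for unserializable global resources, so load_models() runs once per server process and every rerun or session reuses the same returned dict rather than reloading the three checkpoints. A minimal illustration with a hypothetical resource:

import streamlit as st

@st.cache_resource
def get_resource():
    print("loaded once per process")  # printed a single time in the server log
    return {"ready": True}

resource = get_resource()  # subsequent calls return the cached dict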
@@ -47,7 +55,7 @@ def get_bert_embeddings(text, models):
         inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
         with torch.no_grad():
             outputs = model(**inputs)
-        embeddings = outputs.last_hidden_state.mean(dim=1)
+        embeddings = outputs.last_hidden_state.mean(dim=1)
         contextual_embeddings.append(embeddings)
 
     # Combine embeddings from all chunks
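
Aside: outputs.last_hidden_state.mean(dim=1) averages over every token position, padding included. A self-contained sketch of the pooling step; the padding-aware masked mean shown here is a common refinement and my own assumption, not what the diff itself does:

import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
model = BertModel.from_pretrained("bert-base-multilingual-cased")

def sentence_embedding(text):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # shape (1, seq_len, 768)
    mask = inputs["attention_mask"].unsqueeze(-1)   # shape (1, seq_len, 1)
    # Average only over real tokens; padded positions contribute nothing.
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)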
@@ -55,7 +63,7 @@ def get_bert_embeddings(text, models):
     return combined_embedding
 
 def apply_grammar_correction(text, models):
-    """
+    """Basic grammar correction using BERT"""
     tokenizer, model = models["grammar"]
 
     sentences = re.split('([.!?।]+)', text)
@@ -63,22 +71,23 @@ def apply_grammar_correction(text, models):
 
     for sentence in sentences:
         if sentence.strip():
+            # Basic tokenization and prediction
             inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
             with torch.no_grad():
                 outputs = model(**inputs)
             predictions = torch.argmax(outputs.logits, dim=2)
 
-            # Convert predictions to corrected text
             tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
             corrected_tokens = []
 
             for token, pred in zip(tokens, predictions[0]):
-                if pred == 0
-
-
+                if pred == 0 or token in ['[CLS]', '[SEP]', '[PAD]']:
+                    if token not in ['[CLS]', '[SEP]', '[PAD]']:
+                        corrected_tokens.append(token)
 
             corrected_text = tokenizer.convert_tokens_to_string(corrected_tokens)
-
+            if corrected_text.strip():
+                corrected_sentences.append(corrected_text)
 
     return " ".join(corrected_sentences)
 
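
Two notes on this hunk. The removed line "if pred == 0" has no trailing colon, a SyntaxError consistent with the Space's build-error status. Also, because the grammar model loaded above is bert-base-cased with a freshly initialized 3-label head (no fine-tuned correction checkpoint is loaded), its predictions carry no learned signal, so in practice this pass acts as a token filter. A small sketch of the keep/drop logic in isolation, with predictions stubbed for illustration:

SPECIAL = {"[CLS]", "[SEP]", "[PAD]"}

def filter_tokens(tokens, preds):
    """Keep tokens labeled 0 ('keep'); never emit BERT special tokens."""
    kept = []
    for token, pred in zip(tokens, preds):
        if pred == 0 or token in SPECIAL:
            if token not in SPECIAL:
                kept.append(token)
    return kept

print(filter_tokens(["[CLS]", "The", "dog", "barks", "[SEP]"], [0, 0, 0, 1, 0]))
# -> ['The', 'dog']  (label 1 drops "barks"; special tokens never pass through)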
@@ -109,7 +118,6 @@ def translate_text(text, src_lang, tgt_lang, models):
     if src_lang == tgt_lang:
         return text
 
-    # Language codes for NLLB
     lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}
 
     if src_lang not in lang_map or tgt_lang not in lang_map:
@@ -118,62 +126,67 @@ def translate_text(text, src_lang, tgt_lang, models):
     tgt_lang_code = lang_map[tgt_lang]
     tokenizer, model = models["nllb"]
 
-
-
-
-
-
-
-
-
-
-    if
-        current_chunk
-
-
-
-
-
-
-
-
+    try:
+        # Get contextual embeddings
+        context_embedding = get_bert_embeddings(text, models)
+
+        # Split into chunks for translation
+        chunks = []
+        current_chunk = ""
+
+        for sentence in re.split('([.!?।]+)', text):
+            if sentence.strip():
+                if len(current_chunk) + len(sentence) < 450:
+                    current_chunk += sentence
+                else:
+                    if current_chunk:
+                        chunks.append(current_chunk)
+                    current_chunk = sentence
+
+        if current_chunk:
+            chunks.append(current_chunk)
+
+        translated_text = ""
+
+        for chunk in chunks:
+            if chunk.strip():
+                inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
+                # Use context embedding to modify attention
+                attention_mask = inputs['attention_mask'].float()
+                context_weight = 0.1 * torch.sigmoid(context_embedding.mean())
+                attention_mask = attention_mask * (1 + context_weight)
+
+                # Get target language token ID
+                tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
+
+                with torch.no_grad():
+                    translated = model.generate(
+                        input_ids=inputs['input_ids'],
+                        attention_mask=attention_mask,
+                        forced_bos_token_id=tgt_lang_id,
+                        max_length=512,
+                        num_beams=5,
+                        length_penalty=1.0,
+                        no_repeat_ngram_size=3,
+                        do_sample=True,
+                        temperature=0.7
+                    )
+                translated_chunk = tokenizer.decode(translated[0], skip_special_tokens=True)
+                translated_text += translated_chunk + " "
+
+        # Apply basic grammar correction
+        corrected_text = apply_grammar_correction(translated_text.strip(), models)
+
+        return corrected_text
 
-
-
-
-        inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-        # Add context embedding to attention mask
-        attention_mask = inputs['attention_mask'].float()
-        attention_mask = attention_mask * (1 + 0.1 * context_embedding.norm())
-
-        # Get target language token ID
-        tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
-
-        # Generate translation
-        with torch.no_grad():
-            translated = model.generate(
-                input_ids=inputs['input_ids'],
-                attention_mask=attention_mask,
-                forced_bos_token_id=tgt_lang_id,
-                max_length=512,
-                num_beams=5,
-                length_penalty=1.0,
-                no_repeat_ngram_size=3,
-                do_sample=True,
-                temperature=0.7
-            )
-        translated_chunk = tokenizer.decode(translated[0], skip_special_tokens=True)
-        translated_text += translated_chunk + " "
-
-        # Apply grammar correction
-        corrected_text = apply_grammar_correction(translated_text.strip(), models)
-
-        return corrected_text
+    except Exception as e:
+        st.error(f"Translation error: {str(e)}")
+        return f"Error during translation: {str(e)}"
 
 def save_text_to_file(text, original_filename, prefix="translated"):
-
+    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+    output_filename = f"{prefix}_{timestamp}_{os.path.basename(original_filename)}.txt"
     with open(output_filename, "w", encoding="utf-8") as f:
         f.write(text)
     return output_filename
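
Two details in this hunk are heuristics rather than documented NLLB behavior: attention_mask in transformers is normally a binary tensor, so scaling it by (1 + context_weight) is an ad-hoc way to fold the BERT context signal into generation, and do_sample=True with temperature=0.7 makes the beam search non-deterministic. For comparison, a minimal plain NLLB-200 call without those knobs, using the same checkpoint and language codes as the diff (the tokenizer's source language defaults to eng_Latn):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

def translate(text, tgt_code="hin_Deva"):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(tgt_code),
            max_length=512,
            num_beams=5,  # deterministic beam search, no sampling
        )
    return tokenizer.decode(out[0], skip_special_tokens=True)

print(translate("The weather is nice today."))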
@@ -183,15 +196,14 @@ def process_document(file, source_lang, target_lang, models):
         # Extract text from uploaded file
         text = extract_text(file)
 
-        #
-        st.sidebar.write("
+        # Add debugging information
+        st.sidebar.write("Processing document...")
+        st.sidebar.write(f"Source language: {source_lang}")
+        st.sidebar.write(f"Target language: {target_lang}")
 
-        # Translate the text
+        # Translate the text
         translated_text = translate_text(text, source_lang, target_lang, models)
 
-        # Log the output text for debugging
-        st.sidebar.write("Output text:", translated_text[:500] + "...")
-
         # Save the result
         if translated_text.startswith("Error:"):
             output_file = save_text_to_file(translated_text, file.name, prefix="error")
@@ -199,9 +211,10 @@ def process_document(file, source_lang, target_lang, models):
             output_file = save_text_to_file(translated_text, file.name)
 
         return output_file, translated_text
+
     except Exception as e:
        error_message = f"Error: {str(e)}"
-        st.error(
+        st.error(error_message)
         output_file = save_text_to_file(error_message, file.name, prefix="error")
         return output_file, error_message
 
@@ -209,10 +222,15 @@ def main():
     st.title("Advanced Document Translator")
     st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages.")
 
+    # Display current user and timestamp
+    st.sidebar.write(f"Current User: {os.getenv('USER', 'gauravchand')}")
+    st.sidebar.write(f"UTC Time: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
+
     try:
-        # Initialize models
+        # Initialize models with error handling
         with st.spinner("Loading models..."):
             models = load_models()
+        st.success("Models loaded successfully!")
 
         # File uploader
         uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])
@@ -224,9 +242,6 @@ def main():
         with col2:
             target_lang = st.selectbox("Target Language", ["en", "hi", "mr"], index=1)
 
-        # Add debug mode toggle
-        debug_mode = st.sidebar.checkbox("Enable Debug Mode")
-
         if uploaded_file is not None and st.button("Translate"):
             with st.spinner("Processing and Translating..."):
                 output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)
@@ -242,13 +257,10 @@ def main():
                         file_name=os.path.basename(output_file),
                         mime="text/plain"
                     )
-
-                if debug_mode:
-                    st.sidebar.write("Translation completed")
-                    st.sidebar.write("Output file:", output_file)
 
     except Exception as e:
-        st.error(f"
+        st.error(f"Application error: {str(e)}")
+        st.warning("Please try refreshing the page or contact support.")
 
 if __name__ == "__main__":
     main()