Mgolo committed
Commit eafa517 · verified · 1 Parent(s): bbd3488

Update app.py

Files changed (1):
app.py (+158 -117)
app.py CHANGED
@@ -10,7 +10,7 @@ import os
 import re
 import logging
 import tempfile
-from typing import Optional, Dict, Tuple, Any, Union
+from typing import Optional, Dict, Tuple, Any, Union, List
 from pathlib import Path
 from dataclasses import dataclass
 from enum import Enum
@@ -131,26 +131,33 @@ class ModelManager:
 
         # Authenticate with Hugging Face if token provided
         if hf_token := os.getenv("hffff"):
-            login(token=hf_token)
-
-            model = AutoModelForSeq2SeqLM.from_pretrained(
-                config.model_name,
-                token=hf_token
-            ).to(self._get_device())
-
-            tokenizer = MarianTokenizer.from_pretrained(
-                config.model_name,
-                token=hf_token
-            )
-
-            self._translation_pipeline = pipeline(
-                "translation",
-                model=model,
-                tokenizer=tokenizer,
-                device=0 if self._get_device().type == "cuda" else -1
-            )
-
-            self._current_model_name = config.model_name
+            try:
+                login(token=hf_token)
+            except Exception as e:
+                logger.warning(f"HF login failed: {e}")
+
+        try:
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                config.model_name,
+                token=hf_token if hf_token else None
+            ).to(self._get_device())
+
+            tokenizer = MarianTokenizer.from_pretrained(
+                config.model_name,
+                token=hf_token if hf_token else None
+            )
+
+            self._translation_pipeline = pipeline(
+                "translation",
+                model=model,
+                tokenizer=tokenizer,
+                device=0 if self._get_device().type == "cuda" else -1
+            )
+
+            self._current_model_name = config.model_name
+        except Exception as e:
+            logger.error(f"Failed to load model {config.model_name}: {e}")
+            raise
 
         return self._translation_pipeline, config.language_tag
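Note: the loading block now runs outside the token check, so the model also loads when no HF token is set. For reference, a minimal sketch of how such a pipeline is driven once built — the model name and `>>fra<<` tag below are illustrative assumptions, not values from this app (the real tag comes from `config.language_tag`):

```python
# Sketch only: driving a multilingual MarianMT translation pipeline the same
# way self._translation_pipeline is used above. Model name and target-language
# tag are illustrative assumptions.
from transformers import pipeline

translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-en-roa",  # assumed example model
    device=-1,  # -1 = CPU, 0 = first CUDA device (mirrors the diff's device logic)
)
print(translator(">>fra<< How are you?")[0]["translation_text"])
```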
 
@@ -163,7 +170,11 @@ class ModelManager:
         """
         if self._whisper_model is None:
             logger.info("Loading Whisper base model...")
-            self._whisper_model = whisper.load_model("base")
+            try:
+                self._whisper_model = whisper.load_model("base")
+            except Exception as e:
+                logger.error(f"Failed to load Whisper model: {e}")
+                raise
         return self._whisper_model
 
     def _get_device(self) -> torch.device:
@@ -196,20 +207,18 @@ class ContentProcessor:
         extension = file_path.suffix.lower()
 
         try:
-            content = file_path.read_bytes()
-
             if extension == ".pdf":
-                return ContentProcessor._extract_pdf_text(content)
+                return ContentProcessor._extract_pdf_text(file_path)
             elif extension == ".docx":
                 return ContentProcessor._extract_docx_text(file_path)
             elif extension in (".html", ".htm"):
-                return ContentProcessor._extract_html_text(content)
+                return ContentProcessor._extract_html_text(file_path)
             elif extension == ".md":
-                return ContentProcessor._extract_markdown_text(content)
+                return ContentProcessor._extract_markdown_text(file_path)
            elif extension == ".srt":
-                return ContentProcessor._extract_srt_text(content)
+                return ContentProcessor._extract_srt_text(file_path)
            elif extension in (".txt", ".text"):
-                return ContentProcessor._extract_plain_text(content)
+                return ContentProcessor._extract_plain_text(file_path)
            else:
                raise ValueError(f"Unsupported file type: {extension}")
@@ -218,28 +227,30 @@ class ContentProcessor:
             raise
 
     @staticmethod
-    def _extract_pdf_text(content: bytes) -> str:
+    def _extract_pdf_text(file_path: Path) -> str:
         """Extract text from PDF file."""
-        with fitz.open(stream=content, filetype="pdf") as doc:
+        with fitz.open(file_path) as doc:
             return "\n".join(page.get_text() for page in doc)
 
     @staticmethod
     def _extract_docx_text(file_path: Path) -> str:
         """Extract text from DOCX file."""
-        doc = docx.Document(str(file_path))
+        doc = docx.Document(file_path)
         return "\n".join(paragraph.text for paragraph in doc.paragraphs)
 
     @staticmethod
-    def _extract_html_text(content: bytes) -> str:
+    def _extract_html_text(file_path: Path) -> str:
         """Extract text from HTML file."""
+        content = file_path.read_bytes()
         encoding = chardet.detect(content)["encoding"] or "utf-8"
         text = content.decode(encoding, errors="ignore")
         soup = BeautifulSoup(text, "html.parser")
         return soup.get_text()
 
     @staticmethod
-    def _extract_markdown_text(content: bytes) -> str:
+    def _extract_markdown_text(file_path: Path) -> str:
         """Extract text from Markdown file."""
+        content = file_path.read_bytes()
         encoding = chardet.detect(content)["encoding"] or "utf-8"
         text = content.decode(encoding, errors="ignore")
         html = markdown(text)
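Note: each byte-based extractor now reads the file itself and shares the same detect-then-decode pattern. In isolation, that pattern looks like this (a minimal sketch; the sample bytes are invented):

```python
# Sketch of the detect-then-decode pattern shared by the extractors above.
import chardet

raw = "café, señor".encode("latin-1")                    # invented non-UTF-8 sample
encoding = chardet.detect(raw)["encoding"] or "utf-8"    # fall back if detection fails
text = raw.decode(encoding, errors="ignore")             # skip undecodable bytes instead of raising
print(encoding, text)
```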
@@ -247,16 +258,18 @@ class ContentProcessor:
         return soup.get_text()
 
     @staticmethod
-    def _extract_srt_text(content: bytes) -> str:
+    def _extract_srt_text(file_path: Path) -> str:
         """Extract text from SRT subtitle file."""
+        content = file_path.read_bytes()
         encoding = chardet.detect(content)["encoding"] or "utf-8"
         text = content.decode(encoding, errors="ignore")
         # Remove timestamp lines
         return re.sub(r"\d+\n\d{2}:\d{2}:\d{2},\d{3} --> .*?\n", "", text)
 
     @staticmethod
-    def _extract_plain_text(content: bytes) -> str:
+    def _extract_plain_text(file_path: Path) -> str:
         """Extract text from plain text file."""
+        content = file_path.read_bytes()
         encoding = chardet.detect(content)["encoding"] or "utf-8"
         return content.decode(encoding, errors="ignore")
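Note: a quick sanity check of the timestamp-stripping regex in `_extract_srt_text` (the subtitle text is invented):

```python
# Sanity check for the SRT cue-stripping regex: it removes each cue's counter
# line and timing line in one pass, leaving only the subtitle text.
import re

sample = (
    "1\n00:00:01,000 --> 00:00:04,000\nHello there\n\n"
    "2\n00:00:05,000 --> 00:00:08,000\nGeneral greeting\n"
)
cleaned = re.sub(r"\d+\n\d{2}:\d{2}:\d{2},\d{3} --> .*?\n", "", sample)
print(cleaned)  # -> "Hello there\n\nGeneral greeting\n"
```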
 
@@ -304,11 +317,15 @@ class TranslationService:
         target_lang: Language
     ) -> str:
         """Perform direct translation using available model."""
-        pipeline_obj, lang_tag = self.model_manager.get_translation_pipeline(
-            source_lang, target_lang
-        )
-
-        return self._process_text_with_pipeline(text, pipeline_obj, lang_tag)
+        try:
+            pipeline_obj, lang_tag = self.model_manager.get_translation_pipeline(
+                source_lang, target_lang
+            )
+
+            return self._process_text_with_pipeline(text, pipeline_obj, lang_tag)
+        except Exception as e:
+            logger.error(f"Direct translation error: {e}")
+            return f"Translation error: {str(e)}"
 
     def _chained_translate(
         self,
@@ -327,17 +344,21 @@ class TranslationService:
         Returns:
             Translated text through chaining
         """
-        # First: source_lang -> English
-        intermediate_text = self._direct_translate(
-            text, source_lang, Language.ENGLISH
-        )
-
-        # Second: English -> target_lang
-        final_text = self._direct_translate(
-            intermediate_text, Language.ENGLISH, target_lang
-        )
-
-        return final_text
+        try:
+            # First: source_lang -> English
+            intermediate_text = self._direct_translate(
+                text, source_lang, Language.ENGLISH
+            )
+
+            # Second: English -> target_lang
+            final_text = self._direct_translate(
+                intermediate_text, Language.ENGLISH, target_lang
+            )
+
+            return final_text
+        except Exception as e:
+            logger.error(f"Chained translation error: {e}")
+            return f"Chained translation error: {str(e)}"
 
     def _process_text_with_pipeline(
         self,
@@ -362,30 +383,38 @@ class TranslationService:
                 if s.strip()
             ]
 
+            if not sentences:
+                translated_paragraphs.append("")
+                continue
+
             # Add language tag to each sentence
             formatted_sentences = [
                 f"{lang_tag} {sentence}"
                 for sentence in sentences
             ]
 
-            # Perform translation
-            results = pipeline_obj(
-                formatted_sentences,
-                max_length=5000,
-                num_beams=5,
-                early_stopping=True,
-                no_repeat_ngram_size=3,
-                repetition_penalty=1.5,
-                length_penalty=1.2
-            )
-
-            # Process results
-            translated_sentences = [
-                result["translation_text"].capitalize()
-                for result in results
-            ]
-
-            translated_paragraphs.append(". ".join(translated_sentences))
+            try:
+                # Perform translation
+                results = pipeline_obj(
+                    formatted_sentences,
+                    max_length=5000,
+                    num_beams=5,
+                    early_stopping=True,
+                    no_repeat_ngram_size=3,
+                    repetition_penalty=1.5,
+                    length_penalty=1.2
+                )
+
+                # Process results
+                translated_sentences = [
+                    result["translation_text"].capitalize()
+                    for result in results
+                ]
+
+                translated_paragraphs.append(". ".join(translated_sentences))
+            except Exception as e:
+                logger.error(f"Pipeline processing error: {e}")
+                translated_paragraphs.append(f"[Translation Error: {str(e)}]")
 
         return "\n".join(translated_paragraphs)
@@ -409,9 +438,13 @@ class AudioProcessor:
         Returns:
             Transcribed text
         """
-        model = self.model_manager.get_whisper_model()
-        result = model.transcribe(audio_file_path)
-        return result["text"]
+        try:
+            model = self.model_manager.get_whisper_model()
+            result = model.transcribe(audio_file_path)
+            return result["text"]
+        except Exception as e:
+            logger.error(f"Transcription error: {e}")
+            return f"Transcription error: {str(e)}"
 
 # ================================
 # Main Application
@@ -432,7 +465,7 @@ class TranslationApp:
         source_lang: Language,
         text_input: str,
         audio_file: Optional[str],
-        file_obj: Optional[gr.FileData]
+        file_obj: Optional[Any]
     ) -> str:
         """
         Process input based on selected mode.
@@ -447,22 +480,29 @@ class TranslationApp:
         Returns:
             Processed text content
         """
-        if mode == InputMode.TEXT:
-            return text_input
-
-        elif mode == InputMode.AUDIO:
-            if source_lang != Language.ENGLISH:
-                raise ValueError("Audio input must be in English.")
-            if not audio_file:
-                raise ValueError("No audio file provided.")
-            return self.audio_processor.transcribe(audio_file)
-
-        elif mode == InputMode.FILE:
-            if not file_obj:
-                raise ValueError("No file uploaded.")
-            return self.content_processor.extract_text_from_file(file_obj.name)
-
-        return ""
+        try:
+            if mode == InputMode.TEXT:
+                return text_input
+
+            elif mode == InputMode.AUDIO:
+                if source_lang != Language.ENGLISH:
+                    return "Audio input must be in English."
+                if not audio_file:
+                    return "No audio file provided."
+                return self.audio_processor.transcribe(audio_file)
+
+            elif mode == InputMode.FILE:
+                if not file_obj:
+                    return "No file uploaded."
+
+                # Handle Gradio file object (could be a string path or a file-like object)
+                file_path = file_obj.name if hasattr(file_obj, 'name') else file_obj
+                return self.content_processor.extract_text_from_file(file_path)
+
+            return ""
+        except Exception as e:
+            logger.error(f"Input processing error: {e}")
+            return f"Input processing error: {str(e)}"
 
     def create_interface(self) -> gr.Blocks:
         """Create and return the Gradio interface."""
@@ -471,6 +511,22 @@ class TranslationApp:
             title="LocaleNLP Translation Service",
             theme=gr.themes.Monochrome()
         ) as interface:
+            # Custom CSS for black button
+            gr.HTML("""
+                <style>
+                .gr-button-secondary {
+                    background-color: #000000 !important;
+                    border-color: #000000 !important;
+                    color: white !important;
+                }
+                .gr-button-secondary:hover {
+                    background-color: #333333 !important;
+                    border-color: #333333 !important;
+                    color: white !important;
+                }
+                </style>
+            """)
+
             # Header
             gr.Markdown("""
             # 🌍 LocaleNLP Translation Service
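Note: with the `interface.load(..., _js=...)` hook removed below, the styles now ship inside the layout itself. An equivalent alternative, assuming a Gradio version that supports the `css=` argument on `gr.Blocks`, would look like this (a sketch, not what this commit does):

```python
# Alternative sketch: the same rules via gr.Blocks(css=...) instead of gr.HTML.
import gradio as gr

custom_css = """
.gr-button-secondary { background-color: #000 !important; border-color: #000 !important; color: white !important; }
.gr-button-secondary:hover { background-color: #333 !important; border-color: #333 !important; }
"""

with gr.Blocks(css=custom_css, title="LocaleNLP Translation Service") as interface:
    gr.Markdown("demo")
```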
@@ -536,22 +592,26 @@ class TranslationApp:
             )
 
             # Event handlers
-            def update_visibility(mode: str) -> Dict[str, Any]:
+            def update_visibility(mode: str) -> List[Dict[str, Any]]:
                 """Update component visibility based on input mode."""
-                return {
-                    input_text: gr.update(visible=(mode == InputMode.TEXT.value)),
-                    audio_input: gr.update(visible=(mode == InputMode.AUDIO.value)),
-                    file_input: gr.update(visible=(mode == InputMode.FILE.value)),
-                    extracted_text: gr.update(value="", visible=True),
-                    output_text: gr.update(value="")
-                }
+                visibility_text = mode == InputMode.TEXT.value
+                visibility_audio = mode == InputMode.AUDIO.value
+                visibility_file = mode == InputMode.FILE.value
+
+                return [
+                    gr.update(visible=visibility_text),
+                    gr.update(visible=visibility_audio),
+                    gr.update(visible=visibility_file),
+                    gr.update(value="", visible=True),
+                    gr.update(value="")
+                ]
 
             def handle_process(
                 mode: str,
                 source_lang: str,
                 text_input: str,
                 audio_file: Optional[str],
-                file_obj: Optional[gr.FileData]
+                file_obj: Optional[Any]
             ) -> Tuple[str, str]:
                 """Handle initial input processing."""
                 try:
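Note: `update_visibility` now returns a positional list instead of a component-keyed dict, so the event binding must list its five outputs in the same order. The binding itself sits outside this hunk; a sketch of the expected wiring, where `input_mode` is a guessed name for the mode-selector component:

```python
# Assumed wiring for the positional update_visibility handler. The actual
# .change() call is outside the hunks shown; "input_mode" is hypothetical.
input_mode.change(
    fn=update_visibility,
    inputs=input_mode,
    outputs=[input_text, audio_input, file_input, extracted_text, output_text],
)
```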
@@ -601,25 +661,6 @@ class TranslationApp:
                 inputs=[extracted_text, input_lang, output_lang],
                 outputs=output_text
             )
-
-            # Custom CSS for black button (applied after interface creation)
-            interface.load(lambda: None, None, None, _js="""
-                () => {
-                    const style = document.createElement('style');
-                    style.textContent = `
-                        .gr-button-secondary {
-                            background-color: #000000 !important;
-                            border-color: #000000 !important;
-                            color: white !important;
-                        }
-                        .gr-button-secondary:hover {
-                            background-color: #333333 !important;
-                            border-color: #333333 !important;
-                        }
-                    `;
-                    document.head.appendChild(style);
-                }
-            """)
 
         return interface
 
 
666