Spaces:
Paused
Paused
Upload 4 files
Browse files- __init__.py +6 -0
- denoiser.py +159 -0
- transcriber.py +135 -0
- translator.py +151 -0
__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# services/__init__.py
"""Public API of the services package: one class per pipeline department."""

from .denoiser import Denoiser
from .transcriber import Transcriber
from .translator import Translator

__all__ = ["Denoiser", "Transcriber", "Translator"]
|
denoiser.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Department 1 — Denoiser

Uses DeepFilterNet3 for professional noise removal.
Processing order:
convert format → read → stereo→mono → resample →
gentle pre-boost (max 3×) → AI denoise → EBU R128 normalize → save
"""

import os
import time
import subprocess
import tempfile
import numpy as np
import soundfile as sf
import logging

logger = logging.getLogger(__name__)

# Pipeline-wide tuning constants.
TARGET_SR = 48_000        # DeepFilterNet3 native sample rate
TARGET_LOUDNESS = -23.0   # EBU R128 target LUFS
PRE_BOOST_MAX = 3.0       # max linear gain before denoise
| 23 |
+
|
| 24 |
+
class Denoiser:
    """Audio clean-up stage built around DeepFilterNet3.

    ``process()`` runs: convert to WAV → load → stereo→mono → resample to
    48 kHz → gentle pre-boost (max 3×) → AI denoise → EBU R128
    normalisation → save as 16-bit PCM WAV.

    Optional dependencies degrade gracefully: without DeepFilterNet3 the
    denoise step is skipped; without resampy/librosa a linear-interpolation
    resampler is used; without pyloudnorm an RMS normaliser is used.
    """

    def __init__(self):
        print("[Denoiser] Initialising DeepFilterNet3…")
        try:
            from df.enhance import enhance, init_df, load_audio, save_audio
            self._enhance = enhance
            self._init_df = init_df
            self._load_audio = load_audio
            self._save_audio = save_audio

            # Load model once; reused across every process() call.
            self.model, self.df_state, _ = init_df()
            print("[Denoiser] ✅ DeepFilterNet3 loaded")
        except Exception as e:
            logger.warning(f"[Denoiser] DeepFilterNet3 not available: {e}")
            self.model = None

    # ── Public ──────────────────────────────────────────────────────
    def process(self, audio_path: str, out_dir: str) -> str:
        """
        Full pipeline: convert → load → mono → resample →
        pre-boost → denoise → normalise → save.

        Args:
            audio_path: Input audio file (any ffmpeg-readable format).
            out_dir: Directory for intermediate and output WAVs.

        Returns:
            Path to the denoised WAV (``denoised.wav`` inside out_dir).
        """
        t0 = time.time()

        # Step 1: Convert any format → WAV via ffmpeg
        wav_path = os.path.join(out_dir, "input.wav")
        self._convert_to_wav(audio_path, wav_path)

        # Step 2: Read audio
        audio, sr = sf.read(wav_path, always_2d=True)  # shape (samples, channels)

        # Step 3: Stereo → mono (average the channels)
        if audio.ndim > 1 and audio.shape[1] > 1:
            audio = audio.mean(axis=1)
        else:
            audio = audio.squeeze()

        # Step 4: Resample to 48 kHz (DeepFilterNet3's native rate)
        if sr != TARGET_SR:
            audio = self._resample(audio, sr, TARGET_SR)
            sr = TARGET_SR

        # Step 5: Gentle pre-boost (max 3×) so quiet recordings reach a
        # usable level before the denoiser sees them.
        peak = np.abs(audio).max()
        if peak > 0 and peak < 1.0:
            boost = min(PRE_BOOST_MAX, 1.0 / peak)
            audio = audio * boost

        # Clip to [-1, 1] after boost
        audio = np.clip(audio, -1.0, 1.0).astype(np.float32)

        # Step 6: AI Denoise (skipped when the model failed to load)
        if self.model is not None:
            try:
                # DeepFilterNet3 expects a (1, samples) tensor
                import torch
                tensor = torch.from_numpy(audio).unsqueeze(0)
                enhanced = self._enhance(self.model, self.df_state, tensor)
                # detach + cpu before .numpy(): the returned tensor may be
                # graph-attached or live on a GPU device.
                audio = enhanced.squeeze(0).detach().cpu().numpy()
            except Exception as e:
                logger.warning(f"[Denoiser] DeepFilterNet3 enhance failed, using raw: {e}")

        # Step 7: EBU R128 loudness normalisation
        audio = self._normalise_loudness(audio, sr)

        # Step 8: Save
        out_path = os.path.join(out_dir, "denoised.wav")
        sf.write(out_path, audio, sr, subtype="PCM_16")

        logger.info(f"[Denoiser] Done in {time.time()-t0:.2f}s → {out_path}")
        return out_path

    # ── Private helpers ──────────────────────────────────────────────
    def _convert_to_wav(self, src: str, dst: str):
        """Convert any audio format to 16-bit PCM WAV using ffmpeg.

        Falls back to a direct soundfile read/write when ffmpeg is missing
        from the system or exits non-zero. Raises RuntimeError when neither
        path can read the file.
        """
        cmd = [
            "ffmpeg", "-y", "-i", src,
            "-acodec", "pcm_s16le",
            "-ar", str(TARGET_SR),
            "-ac", "1",
            dst
        ]
        ffmpeg_error = ""
        try:
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode == 0:
                return
            ffmpeg_error = result.stderr[:200]
        except OSError as e:
            # ffmpeg binary not installed / not executable — previously this
            # exception escaped and the soundfile fallback was never tried.
            ffmpeg_error = str(e)

        # ffmpeg failed — try soundfile direct read as fallback
        logger.warning(f"[Denoiser] ffmpeg conversion failed, trying soundfile direct read")
        try:
            data, sr_in = sf.read(src, always_2d=True)
            sf.write(dst, data, sr_in, subtype="PCM_16")
        except Exception as e2:
            raise RuntimeError(
                f"Could not read audio file '{os.path.basename(src)}'. "
                f"ffmpeg error: {ffmpeg_error}"
            ) from e2

    def _resample(self, audio: np.ndarray, src_sr: int, tgt_sr: int) -> np.ndarray:
        """Resample 1-D audio from src_sr to tgt_sr.

        Prefers resampy, then librosa; falls back to linear interpolation
        when neither optional dependency is installed.
        """
        try:
            import resampy
            return resampy.resample(audio, src_sr, tgt_sr)
        except ImportError:
            pass
        try:
            import librosa
            return librosa.resample(audio, orig_sr=src_sr, target_sr=tgt_sr)
        except ImportError:
            pass
        # Simple linear interpolation fallback
        ratio = tgt_sr / src_sr
        n_samples = int(len(audio) * ratio)
        indices = np.linspace(0, len(audio) - 1, n_samples)
        return np.interp(indices, np.arange(len(audio)), audio).astype(np.float32)

    def _normalise_loudness(self, audio: np.ndarray, sr: int) -> np.ndarray:
        """
        EBU R128 normalisation.
        Targets TARGET_LOUDNESS LUFS; falls back to RMS normalisation if
        pyloudnorm is unavailable.
        """
        try:
            import pyloudnorm as pyln
            meter = pyln.Meter(sr)
            loudness = meter.integrated_loudness(audio)
            # Only adjust when the measurement is meaningful (finite, < 0 LUFS).
            if np.isfinite(loudness) and loudness < 0:
                audio = pyln.normalize.loudness(audio, loudness, TARGET_LOUDNESS)
            return np.clip(audio, -1.0, 1.0).astype(np.float32)
        except Exception:
            pass

        # RMS fallback: scale so RMS matches the dBFS equivalent of the target
        rms = np.sqrt(np.mean(audio ** 2))
        if rms > 1e-9:
            target_rms = 10 ** (TARGET_LOUDNESS / 20.0)
            audio = audio * (target_rms / rms)
        return np.clip(audio, -1.0, 1.0).astype(np.float32)
|
transcriber.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Department 2 — Transcriber
Primary : Groq API (Whisper large-v3 on H100) — free tier 14 400 s/day
Fallback : faster-whisper (local, small model) if Groq fails or limit reached
"""

import os
import time
import logging

logger = logging.getLogger(__name__)

# Whisper language codes that map to our short codes
# ("auto" maps to None so the hint is omitted and Whisper auto-detects).
LANG_TO_WHISPER = {
    "auto": None,
    "en": "en",
    "te": "te",
    "hi": "hi",
    "ta": "ta",
    "kn": "kn",
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Transcriber:
    """Speech-to-text stage.

    Primary engine is Groq's hosted Whisper large-v3 (used when a
    GROQ_API_KEY is configured); a local faster-whisper "small" model
    (CPU, int8) serves as the fallback when Groq is unavailable or a
    call fails.
    """

    def __init__(self):
        self.groq_key = os.environ.get("GROQ_API_KEY", "")
        self._groq_client = None
        self._local_model = None

        if not self.groq_key:
            print("[Transcriber] ⚠️ No GROQ_API_KEY — falling back to local Whisper small")
            self._init_local()
        else:
            print("[Transcriber] Groq API key found — primary = Groq Whisper large-v3")
            self._init_groq()

    # ── Public ──────────────────────────────────────────────────────
    def transcribe(self, audio_path: str, language: str = "auto"):
        """
        Returns (transcript_text, detected_language_code, method_label)
        """
        hint = LANG_TO_WHISPER.get(language, None)

        if self._groq_client is not None:
            try:
                return self._transcribe_groq(audio_path, hint)
            except Exception as e:
                logger.warning(f"[Transcriber] Groq failed ({e}), falling back to local…")
                # Groq was the primary, so the local model may not be
                # loaded yet — load it lazily before falling through.
                if self._local_model is None:
                    self._init_local()

        return self._transcribe_local(audio_path, hint)

    # ── Groq ─────────────────────────────────────────────────────────
    def _init_groq(self):
        """Create the Groq client; on any failure, arm the local fallback."""
        try:
            from groq import Groq
            self._groq_client = Groq(api_key=self.groq_key)
            print("[Transcriber] ✅ Groq client initialised")
        except Exception as e:
            logger.warning(f"[Transcriber] Groq init failed: {e}")
            self._groq_client = None
            self._init_local()

    def _transcribe_groq(self, audio_path: str, language=None):
        """Transcribe via Groq's hosted Whisper large-v3 endpoint."""
        started = time.time()
        request = {
            "model": "whisper-large-v3",
            "response_format": "verbose_json",
            "temperature": 0.0,
        }
        if language:
            request["language"] = language

        with open(audio_path, "rb") as audio_file:
            resp = self._groq_client.audio.transcriptions.create(
                file=audio_file, **request
            )

        transcript = resp.text.strip()
        detected = getattr(resp, "language", language or "en") or "en"
        # Groq returns full names like "english" — normalise
        detected = self._normalise_lang(detected)

        logger.info(f"[Transcriber] Groq done in {time.time()-started:.2f}s, lang={detected}")
        return transcript, detected, "Groq Whisper large-v3"

    # ── Local Whisper ────────────────────────────────────────────────
    def _init_local(self):
        """Load faster-whisper "small" on CPU (int8); leave None on failure."""
        try:
            from faster_whisper import WhisperModel
            print("[Transcriber] Loading faster-whisper small (CPU)…")
            self._local_model = WhisperModel(
                "small", device="cpu", compute_type="int8"
            )
            print("[Transcriber] ✅ faster-whisper small ready")
        except Exception as e:
            logger.error(f"[Transcriber] Local Whisper init failed: {e}")
            self._local_model = None

    def _transcribe_local(self, audio_path: str, language=None):
        """Transcribe with the local faster-whisper model (VAD-filtered)."""
        started = time.time()
        if self._local_model is None:
            raise RuntimeError("No transcription engine available.")

        segments, info = self._local_model.transcribe(
            audio_path,
            language=language,
            beam_size=5,
            vad_filter=True,
        )
        pieces = [seg.text.strip() for seg in segments]
        transcript = " ".join(pieces).strip()
        detected = info.language or language or "en"

        logger.info(f"[Transcriber] Local done in {time.time()-started:.2f}s, lang={detected}")
        return transcript, detected, "local Whisper small (fallback)"

    # ── Helpers ──────────────────────────────────────────────────────
    @staticmethod
    def _normalise_lang(raw: str) -> str:
        """Convert Groq full language names to 2-letter codes."""
        known = {
            "english": "en",
            "telugu": "te",
            "hindi": "hi",
            "tamil": "ta",
            "kannada": "kn",
            "spanish": "es",
            "french": "fr",
            "german": "de",
            "japanese": "ja",
            "chinese": "zh",
        }
        key = raw.lower()
        if key in known:
            return known[key]
        # Unknown name: best-effort 2-letter prefix (lower-cased).
        return key[:2] if len(raw) >= 2 else raw
|
translator.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
Department 3 - Translator
Primary : NLLB-200-distilled-600M (Meta, offline on ZeroGPU)
Fallback : deep-translator (Google Translate) if NLLB fails
"""

import time
import logging

logger = logging.getLogger(__name__)

# Map simple 2-letter UI codes to NLLB-200 language codes
# (FLORES-200 style: ISO 639-3 code + script tag).
NLLB_CODES = {
    "en": "eng_Latn",
    "te": "tel_Telu",
    "hi": "hin_Deva",
    "ta": "tam_Taml",
    "kn": "kan_Knda",
    "es": "spa_Latn",
    "fr": "fra_Latn",
    "de": "deu_Latn",
    "ja": "jpn_Jpan",
    "zh": "zho_Hans",
    "ar": "arb_Arab",
    "pt": "por_Latn",
    "ru": "rus_Cyrl",
}

MODEL_ID = "facebook/nllb-200-distilled-600M"
MAX_LENGTH = 512  # token cap used for both tokenizer truncation and generation
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Translator:
    """Text translation stage.

    Primary engine is Meta's NLLB-200-distilled-600M, loaded locally via
    transformers (pipeline first, then a manual tokenizer+model load).
    deep-translator's Google Translate wrapper is the network fallback
    when NLLB is unavailable or a translation attempt fails.
    """

    def __init__(self):
        self._pipeline = None
        self._tokenizer = None
        self._model = None
        print(f"[Translator] Loading {MODEL_ID}...")
        self._init_nllb()

    # ----------------------------------------------------------------
    # Public
    # ----------------------------------------------------------------
    def translate(self, text: str, src_lang: str, tgt_lang: str):
        """
        Returns (translated_text, method_label).
        src_lang / tgt_lang are 2-letter codes (en, te, hi, ...).
        """
        if not text or not text.strip():
            return "", "skipped (empty)"

        if self._pipeline is not None or self._model is not None:
            try:
                return self._translate_nllb(text, src_lang, tgt_lang)
            except Exception as e:
                logger.warning(f"[Translator] NLLB failed ({e}), trying Google...")

        return self._translate_google(text, src_lang, tgt_lang)

    # ----------------------------------------------------------------
    # NLLB-200
    # ----------------------------------------------------------------
    def _init_nllb(self):
        """Try the high-level transformers pipeline; else manual load."""
        try:
            from transformers import pipeline as hf_pipeline
            self._pipeline = hf_pipeline(
                "translation",
                model=MODEL_ID,
                device_map="auto",
                max_length=MAX_LENGTH,
            )
            print("[Translator] NLLB-200-distilled-600M loaded via pipeline")
        except Exception as e:
            logger.warning(f"[Translator] pipeline init failed, trying manual load: {e}")
            self._init_nllb_manual()

    def _init_nllb_manual(self):
        """Fallback: load tokenizer + seq2seq model directly (fp16 on GPU)."""
        try:
            from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
            import torch
            self._tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
            self._model = AutoModelForSeq2SeqLM.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            if torch.cuda.is_available():
                self._model = self._model.cuda()
            self._model.eval()
            print("[Translator] NLLB-200 loaded manually")
        except Exception as e:
            logger.error(f"[Translator] NLLB manual load also failed: {e}")
            self._model = None

    def _translate_nllb(self, text: str, src_lang: str, tgt_lang: str):
        """Translate with NLLB via the pipeline or the manually loaded model.

        Unknown language codes default to eng_Latn (source) / tel_Telu (target).
        """
        t0 = time.time()
        src_code = NLLB_CODES.get(src_lang, "eng_Latn")
        tgt_code = NLLB_CODES.get(tgt_lang, "tel_Telu")

        if self._pipeline is not None:
            result = self._pipeline(
                text,
                src_lang=src_code,
                tgt_lang=tgt_code,
                max_length=MAX_LENGTH,
            )
            translated = result[0]["translation_text"]
        else:
            import torch
            # BUGFIX: the multilingual NLLB tokenizer must be told the
            # source language so it prepends the correct language token;
            # without this it assumes its default (eng_Latn) and
            # mistranslates non-English input.
            self._tokenizer.src_lang = src_code
            inputs = self._tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=MAX_LENGTH,
            )
            if torch.cuda.is_available():
                inputs = {k: v.cuda() for k, v in inputs.items()}

            # Force decoding to begin with the target-language token.
            tgt_lang_id = self._tokenizer.convert_tokens_to_ids(tgt_code)
            with torch.no_grad():
                output_ids = self._model.generate(
                    **inputs,
                    forced_bos_token_id=tgt_lang_id,
                    max_length=MAX_LENGTH,
                    num_beams=4,
                    early_stopping=True,
                )
            translated = self._tokenizer.batch_decode(
                output_ids, skip_special_tokens=True
            )[0]

        elapsed = time.time() - t0
        logger.info(f"[Translator] NLLB done in {elapsed:.2f}s: {src_code} -> {tgt_code}")
        return translated, "NLLB-200-distilled-600M"

    # ----------------------------------------------------------------
    # Google Translate fallback
    # ----------------------------------------------------------------
    def _translate_google(self, text: str, src_lang: str, tgt_lang: str):
        """Best-effort Google Translate via deep-translator.

        Never raises: on failure it returns an error-marker string with
        method label "error" so the caller can surface the problem.
        """
        t0 = time.time()
        try:
            from deep_translator import GoogleTranslator
            translated = GoogleTranslator(
                source=src_lang if src_lang != "auto" else "auto",
                target=tgt_lang,
            ).translate(text)
            logger.info(f"[Translator] Google done in {time.time()-t0:.2f}s")
            return translated, "Google Translate (fallback)"
        except Exception as e:
            logger.error(f"[Translator] Google fallback also failed: {e}")
            return f"[Translation failed: {str(e)}]", "error"
|