SuhasGholkar committed on
Commit
ffad3da
·
verified ·
1 Parent(s): 108fe65

Update src/translate.py

Browse files
Files changed (1) hide show
  1. src/translate.py +136 -89
src/translate.py CHANGED
@@ -2,110 +2,157 @@
2
  from __future__ import annotations
3
  import os
4
  import re
5
- from typing import List, Optional
6
 
7
- import torch
8
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
9
- from IndicTransToolkit.processor import IndicProcessor
10
 
11
- # -------- Model choices (CPU-friendly distilled) ----------
12
- # You can switch to 1B by replacing the *_CKPT names below with:
13
- # ai4bharat/indictrans2-indic-en-1B
14
- # ai4bharat/indictrans2-en-indic-1B
15
- INDIC_EN_CKPT = os.getenv("INDIC_EN_MODEL", "ai4bharat/indictrans2-indic-en-dist-200M")
16
- EN_INDIC_CKPT = os.getenv("EN_INDIC_MODEL", "ai4bharat/indictrans2-en-indic-dist-200M")
 
 
 
 
 
17
 
18
- _DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
19
 
20
- # We keep singletons so the large models load once per process
21
- _ip: Optional[IndicProcessor] = None
22
- _tok_indic_en: Optional[AutoTokenizer] = None
23
- _mod_indic_en: Optional[AutoModelForSeq2SeqLM] = None
24
- _tok_en_indic: Optional[AutoTokenizer] = None
25
- _mod_en_indic: Optional[AutoModelForSeq2SeqLM] = None
26
 
27
- # Hindi (Devanagari) target code for IndicTrans2
28
- HINDI = "hin_Deva"
29
- ENGLISH = "eng_Latn"
30
 
31
- def _iproc() -> IndicProcessor:
32
- global _ip
33
- if _ip is None:
34
- _ip = IndicProcessor(inference=True)
35
- return _ip
 
 
36
 
37
- def _load_indic_en():
38
- global _tok_indic_en, _mod_indic_en
39
- if _tok_indic_en is None or _mod_indic_en is None:
40
- _tok_indic_en = AutoTokenizer.from_pretrained(INDIC_EN_CKPT, trust_remote_code=True)
41
- _mod_indic_en = AutoModelForSeq2SeqLM.from_pretrained(
42
- INDIC_EN_CKPT, trust_remote_code=True
43
- ).to(_DEVICE)
44
- return _tok_indic_en, _mod_indic_en
45
 
46
- def _load_en_indic():
47
- global _tok_en_indic, _mod_en_indic
48
- if _tok_en_indic is None or _mod_en_indic is None:
49
- _tok_en_indic = AutoTokenizer.from_pretrained(EN_INDIC_CKPT, trust_remote_code=True)
50
- _mod_en_indic = AutoModelForSeq2SeqLM.from_pretrained(
51
- EN_INDIC_CKPT, trust_remote_code=True
52
- ).to(_DEVICE)
53
- return _tok_en_indic, _mod_en_indic
 
 
54
 
55
- _DEVANAGARI_RE = re.compile(r"[\u0900-\u097F]")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
- def looks_devanagari(text: str) -> bool:
58
- """Heuristic: any Devanagari char treat as Hindi."""
59
- return bool(_DEVANAGARI_RE.search(text or ""))
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
- def _batch_decode(
62
- model, tok, inputs: List[str], src_lang: str, tgt_lang: str, max_new_tokens=256
63
- ) -> List[str]:
64
- ip = _iproc()
65
- # Pre-process
66
- sents, srcl = ip.preprocess_batch(inputs, src_lang=src_lang, tgt_lang=tgt_lang)
67
- enc = tok(sents, return_tensors="pt", padding=True, truncation=True).to(_DEVICE)
68
- with torch.no_grad():
69
- gen = model.generate(
70
- **enc,
71
- max_new_tokens=max_new_tokens,
72
- num_beams=4,
73
- length_penalty=1.0,
74
  )
75
- out = tok.batch_decode(gen, skip_special_tokens=True)
76
- # Post-process
77
- return ip.postprocess_batch(out, lang=tgt_lang)
 
 
 
 
 
 
 
78
 
79
- def indic_to_en(text: str, src_lang: str = HINDI) -> str:
80
- """Translate Indic→English; default assumes Hindi (hin_Deva)."""
81
- if not text:
 
 
 
82
  return text
83
- tok, mod = _load_indic_en()
84
- return _batch_decode(mod, tok, [text], src_lang=src_lang, tgt_lang=ENGLISH)[0]
85
-
86
- def en_to_lang(text: str, tgt_lang: str = HINDI) -> str:
87
- """Translate English→Indic (Hindi by default)."""
88
- if not text:
 
 
 
 
 
 
 
 
89
  return text
90
- tok, mod = _load_en_indic()
91
- return _batch_decode(mod, tok, [text], src_lang=ENGLISH, tgt_lang=tgt_lang)[0]
92
 
93
- def ensure_english(text: str, src_hint: Optional[str] = None) -> tuple[str, Optional[str]]:
94
  """
95
- If text seems Hindi (or src_hint given), translate to English.
96
- Returns (english_text, original_lang_code_or_None).
97
  """
98
- orig_lang = None
99
- # If the caller knows it's Hindi, pass src_hint="hin_Deva"
100
- if src_hint:
101
- if src_hint != ENGLISH:
102
- return indic_to_en(text, src_lang=src_hint), src_hint
103
- return text, None
104
-
105
- if looks_devanagari(text):
106
- try:
107
- return indic_to_en(text, src_lang=HINDI), HINDI
108
- except Exception:
109
- # Fallback: return original text if translation fails
110
- return text, None
111
- return text, None
 
 
 
2
from __future__ import annotations

import logging
import os
import re
from typing import Optional, List
6
 
7
# Public language constants (short ISO-style codes exposed to callers).
ENGLISH = "en"
HINDI = "hi"

# Environment knobs.
# Translation is enabled by default. Any of 1/true/yes/on (case-insensitive)
# turns it on; every other value turns it off.
ENABLE_TRANSLATION = os.getenv("ENABLE_TRANSLATION", "1").strip().lower() in {
    "1",
    "true",
    "yes",
    "on",
}

# Default checkpoints are the distilled IndicTrans2 models as named on the
# Hugging Face Hub ("dist-200M"); override through the environment to use a
# different checkpoint.
MODEL_ID_EN2INDIC = os.getenv(
    "INDICTRANS2_EN2INDIC_MODEL",
    "ai4bharat/indictrans2-en-indic-dist-200M",
)
# Reverse direction (Indic -> English), used when Hindi input is detected.
MODEL_ID_INDIC2EN = os.getenv(
    "INDICTRANS2_INDIC2EN_MODEL",
    "ai4bharat/indictrans2-indic-en-dist-200M",
)

# Lazily populated singletons so each heavy object is built once per process.
_MODEL_EN2INDIC = None
_TOKENIZER_EN2INDIC = None
_MODEL_INDIC2EN = None
_TOKENIZER_INDIC2EN = None
_IPROCESSOR = None  # Indic pre/post processor

# Light Hindi detection: matches any code point in the Devanagari block.
_RE_DEVANAGARI = re.compile(r"[\u0900-\u097F]")
 
 
 
 
32
 
33
def _likely_hindi(text: str) -> bool:
    """Report whether ``text`` contains at least one Devanagari character.

    ``None`` and the empty string are treated as "not Hindi".
    """
    candidate = text if text else ""
    hit = _RE_DEVANAGARI.search(candidate)
    return hit is not None
 
35
 
36
def _try_imports():
    """Import the heavy translation dependencies on demand.

    The three names are bound at module level (through ``global``) and are
    also returned as a ``(transformers, torch, IndicProcessor)`` tuple.
    Any import error propagates to the caller.
    """
    global transformers, torch, IndicProcessor
    import transformers  # type: ignore
    import torch  # type: ignore
    from IndicTransToolkit.processor import IndicProcessor  # type: ignore

    heavy_libs = (transformers, torch, IndicProcessor)
    return heavy_libs
43
 
44
def _device():
    """Return the device name that models and input tensors are moved to.

    Always ``"cpu"``: CPU is forced as the safe default on Spaces.
    """
    target = "cpu"
    return target
 
 
 
 
 
47
 
48
def _load_iprocessor():
    """Return the shared ``IndicProcessor``, creating it on first use.

    Returns:
        The cached processor, or ``None`` when the toolkit cannot be
        imported or instantiated. A failed attempt is not cached, so the
        next call tries again.
    """
    global _IPROCESSOR
    if _IPROCESSOR is not None:
        return _IPROCESSOR
    try:
        _, _, processor_cls = _try_imports()
        _IPROCESSOR = processor_cls(inference=True)
    except Exception:
        # Best-effort by design: callers skip pre/post-processing when this
        # is None. Log the cause so the degraded mode is diagnosable.
        logging.getLogger(__name__).warning(
            "IndicProcessor unavailable; continuing without pre/post-processing",
            exc_info=True,
        )
        _IPROCESSOR = None
    return _IPROCESSOR
58
 
59
def _load_en2indic():
    """Load the English -> Indic model and tokenizer once per process.

    Returns:
        ``(model, tokenizer)``. Both are ``None`` when loading fails; the
        failure is logged and not cached, so a later call retries.
    """
    global _MODEL_EN2INDIC, _TOKENIZER_EN2INDIC
    if _MODEL_EN2INDIC is not None:
        return _MODEL_EN2INDIC, _TOKENIZER_EN2INDIC
    try:
        transformers, _, _ = _try_imports()
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to run; MODEL_ID_EN2INDIC comes from the environment,
        # so only point it at repositories you trust.
        tok = transformers.AutoTokenizer.from_pretrained(
            MODEL_ID_EN2INDIC, trust_remote_code=True
        )
        model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_ID_EN2INDIC, trust_remote_code=True
        )
        model.to(_device())
        model.eval()
        # Publish both globals together, only after everything succeeded.
        _MODEL_EN2INDIC, _TOKENIZER_EN2INDIC = model, tok
    except Exception:
        # Best-effort: translation degrades to a pass-through, but say why.
        logging.getLogger(__name__).warning(
            "Could not load en->indic model %r", MODEL_ID_EN2INDIC, exc_info=True
        )
        _MODEL_EN2INDIC, _TOKENIZER_EN2INDIC = None, None
    return _MODEL_EN2INDIC, _TOKENIZER_EN2INDIC
74
 
75
def _load_indic2en():
    """Load the Indic -> English model and tokenizer once (only if needed).

    Returns:
        ``(model, tokenizer)``. Both are ``None`` when loading fails; the
        failure is logged and not cached, so a later call retries.
    """
    global _MODEL_INDIC2EN, _TOKENIZER_INDIC2EN
    if _MODEL_INDIC2EN is not None:
        return _MODEL_INDIC2EN, _TOKENIZER_INDIC2EN
    try:
        transformers, _, _ = _try_imports()
        # NOTE(review): trust_remote_code=True allows code shipped with the
        # checkpoint to run; MODEL_ID_INDIC2EN comes from the environment,
        # so only point it at repositories you trust.
        tok = transformers.AutoTokenizer.from_pretrained(
            MODEL_ID_INDIC2EN, trust_remote_code=True
        )
        model = transformers.AutoModelForSeq2SeqLM.from_pretrained(
            MODEL_ID_INDIC2EN, trust_remote_code=True
        )
        model.to(_device())
        model.eval()
        # Publish both globals together, only after everything succeeded.
        _MODEL_INDIC2EN, _TOKENIZER_INDIC2EN = model, tok
    except Exception:
        # Best-effort: translation degrades to a pass-through, but say why.
        logging.getLogger(__name__).warning(
            "Could not load indic->en model %r", MODEL_ID_INDIC2EN, exc_info=True
        )
        _MODEL_INDIC2EN, _TOKENIZER_INDIC2EN = None, None
    return _MODEL_INDIC2EN, _TOKENIZER_INDIC2EN
90
 
91
def _generate(model, tokenizer, inputs: List[str], max_new_tokens=256) -> List[str]:
    """Run deterministic generation on a small batch of strings.

    Args:
        model: Seq2seq model exposing ``generate``; may be ``None``.
        tokenizer: Matching tokenizer; may be ``None``.
        inputs: Already pre-processed source sentences.
        max_new_tokens: Upper bound on generated tokens per sentence.

    Returns:
        The decoded outputs. When the model or tokenizer is missing, or
        generation raises, the very same ``inputs`` list object is returned
        so callers can fall back gracefully.
    """
    if model is None or tokenizer is None:
        return inputs  # graceful fallback
    try:
        import torch  # local import keeps module import cheap

        device = _device()
        enc = tokenizer(
            inputs,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        enc = {key: tensor.to(device) for key, tensor in enc.items()}
        with torch.no_grad():
            outs = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )
        return tokenizer.batch_decode(outs, skip_special_tokens=True)
    except Exception:
        # Best-effort: hand the input back, but record why generation failed.
        logging.getLogger(__name__).warning(
            "Generation failed; returning inputs unchanged", exc_info=True
        )
        return inputs
114
 
115
def ensure_english(text: str) -> str:
    """Translate Hindi input to English; return any other text as is.

    Detection is deliberately light: text containing Devanagari characters
    is treated as Hindi and sent through the Indic -> English model.

    Args:
        text: User text in any language.

    Returns:
        The English translation, or ``text`` unchanged when translation is
        disabled, the text does not look Hindi, or any part of the
        translation stack is unavailable or fails.
    """
    if not ENABLE_TRANSLATION:
        return text
    try:
        if not _likely_hindi(text):
            return text
        model, tok = _load_indic2en()
        if model is None or tok is None:
            # Nothing can be translated; bail out before pre-processing so
            # the caller gets the untouched input back.
            return text
        ip = _load_iprocessor()
        src = text
        if ip:
            # IndicProcessor expects FLORES-style language tags, not the
            # short codes stored in HINDI / ENGLISH.
            src = ip.preprocess_batch([src], src_lang="hin_Deva", tgt_lang="eng_Latn")[0]
        batch = [src]
        outs = _generate(model, tok, batch)
        if outs is batch:
            # _generate returns its input list when generation fails; do not
            # leak the pre-processed string to the caller.
            return text
        out = outs[0]
        if ip:
            out = ip.postprocess_batch([out], lang="eng_Latn")[0]
        return out
    except Exception:
        # Best-effort: never break the caller, but record the failure.
        logging.getLogger(__name__).warning(
            "Hindi->English translation failed; returning input unchanged",
            exc_info=True,
        )
        return text
 
 
137
 
138
def en_to_lang(text: str, tgt_lang: str = HINDI) -> str:
    """Translate English text into the target Indic language (default: Hindi).

    Args:
        text: English source text.
        tgt_lang: Target language, either a short code such as ``"hi"`` or a
            FLORES-style tag such as ``"hin_Deva"``.

    Returns:
        The translated text, or ``text`` unchanged when translation is
        disabled, ``text`` is empty, or the translation stack is
        unavailable or fails.
    """
    if not ENABLE_TRANSLATION:
        return text
    if not text:
        return text
    # IndicProcessor expects FLORES-style tags; accept the short codes this
    # module exposes and pass anything else through untouched.
    flores_by_code = {
        "en": "eng_Latn",
        "hi": "hin_Deva",
        "as": "asm_Beng",
        "bn": "ben_Beng",
        "gu": "guj_Gujr",
        "kn": "kan_Knda",
        "ml": "mal_Mlym",
        "mr": "mar_Deva",
        "ne": "npi_Deva",
        "or": "ory_Orya",
        "pa": "pan_Guru",
        "ta": "tam_Taml",
        "te": "tel_Telu",
        "ur": "urd_Arab",
    }
    tgt_tag = flores_by_code.get(tgt_lang, tgt_lang)
    try:
        model, tok = _load_en2indic()
        if model is None or tok is None:
            # Nothing can be translated; bail out before pre-processing so
            # the caller gets the untouched input back.
            return text
        ip = _load_iprocessor()
        src = text
        if ip:
            src = ip.preprocess_batch([src], src_lang="eng_Latn", tgt_lang=tgt_tag)[0]
        batch = [src]
        outs = _generate(model, tok, batch)
        if outs is batch:
            # _generate returns its input list when generation fails; do not
            # leak the pre-processed string to the caller.
            return text
        out = outs[0]
        if ip:
            out = ip.postprocess_batch([out], lang=tgt_tag)[0]
        return out
    except Exception:
        # Best-effort: never break the caller, but record the failure.
        logging.getLogger(__name__).warning(
            "English->%s translation failed; returning input unchanged",
            tgt_lang,
            exc_info=True,
        )
        return text