| """ |
| Beam search and greedy decoders for Singlish β Sinhala transliteration. |
| """ |
|
|
| import math |
| import re |
| import torch |
| import pickle |
| import logging |
| from typing import List, Tuple, Dict, Optional, Set |
|
|
| from transformers import AutoTokenizer, AutoModelForMaskedLM |
|
|
| from core.constants import ( |
| DEFAULT_MODEL_NAME, DEFAULT_DICTIONARY_PATH, |
| DEFAULT_BEAM_WIDTH, MAX_CANDIDATES, MIN_ENGLISH_LEN, |
| PUNCT_PATTERN, |
| ) |
| from core.mappings import COMMON_WORDS, CONTEXT_WORDS_STANDALONE |
| from core.english import ENGLISH_VOCAB |
| from core.scorer import CandidateScorer, ScoredCandidate, WordDiagnostic |
| from core.dictionary import DictionaryAdapter |
|
|
| logger = logging.getLogger(__name__) |
|
|
| |
| _SINHALA_RE = re.compile(r"[\u0D80-\u0DFF]") |
|
|
|
|
| def _is_sinhala(text: str) -> bool: |
| """Return True if the text already contains Sinhala script characters.""" |
| return bool(_SINHALA_RE.search(text)) |
|
|
|
|
class BeamSearchDecoder:
    """
    Contextual beam-search decoder for Singlish → Sinhala transliteration.

    For each word position the decoder:
      1. Generates candidates (dictionary + rule engine)
      2. Scores them with XLM-R MLM in sentence context
      3. Combines MLM score with fidelity & rank via CandidateScorer
      4. Prunes to the top-k (beam width) hypotheses

    Two decoding strategies are exposed: :meth:`greedy_decode_with_diagnostics`
    (left context is the output built so far) and
    :meth:`decode_with_diagnostics` (beam search over a fixed rule-based
    context). :meth:`decode` dispatches between them.
    """
|
|
| def __init__( |
| self, |
| model_name: str = DEFAULT_MODEL_NAME, |
| dictionary_path: str = DEFAULT_DICTIONARY_PATH, |
| device: Optional[str] = None, |
| ): |
| self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| logger.info("Loading tokenizer & model: %s", model_name) |
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
| self.model = AutoModelForMaskedLM.from_pretrained(model_name) |
| self.model.to(self.device) |
| self.model.eval() |
|
|
| logger.info("Loading dictionary: %s", dictionary_path) |
| with open(dictionary_path, "rb") as f: |
| d_data = pickle.load(f) |
| self.adapter = DictionaryAdapter(d_data) |
| self.scorer = CandidateScorer() |
|
|
| |
|
|
| @staticmethod |
| def _softmax_normalize(raw_scores: List[float]) -> List[float]: |
| """ |
| Normalize raw log-probability scores to [0, 1] via softmax. |
| |
| Unlike min-max (which maps bestβ1.0, worstβ0.0 regardless of |
| the actual difference), softmax preserves the model's relative |
| confidence. When all candidates have similar log-probs the |
| output values cluster together; when the model is very |
| confident they spread apart. |
| |
| The raw scores are already log-probs (negative), so we use |
| them directly as logits for softmax. |
| """ |
| if not raw_scores: |
| return [] |
| if len(raw_scores) == 1: |
| return [1.0] |
|
|
| |
| max_s = max(raw_scores) |
| exps = [math.exp(s - max_s) for s in raw_scores] |
| total = sum(exps) |
| return [e / total for e in exps] |
|
|
| |
|
|
    def _batch_mlm_score(
        self,
        left_contexts: List[str],
        right_contexts: List[str],
        candidates: List[str],
    ) -> List[float]:
        """
        Score each candidate using masked LM log-probability with proper
        multi-mask scoring for multi-subword candidates.

        Instead of placing a single <mask> and summing subword log-probs
        at that one position, this method creates one <mask> per subword
        token and scores each subword at its own position:

            score = (1/N) * Σ_i log P(t_i | mask_position_i, context)

        All masks are filled in a single forward pass, so each subword is
        predicted with its sibling positions still masked (a standard
        pseudo-log-likelihood approximation, not an exact joint prob).

        Args:
            left_contexts: Text to the left of the slot, one per candidate.
            right_contexts: Text to the right of the slot, one per candidate.
            candidates: Candidate strings to score in their slots.

        Returns:
            One length-normalized log-prob per candidate; -100.0 when a
            candidate's masks cannot be located (e.g. truncated away).
        """
        if not candidates:
            return []

        mask = self.tokenizer.mask_token
        mask_token_id = self.tokenizer.mask_token_id

        # Tokenize each candidate to know how many <mask> slots it needs;
        # an empty encoding falls back to a single UNK token.
        cand_token_ids: List[List[int]] = []
        for c in candidates:
            ids = self.tokenizer.encode(c, add_special_tokens=False)
            cand_token_ids.append(ids if ids else [self.tokenizer.unk_token_id])

        # Build one input text per candidate: left ctx, N masks, right ctx.
        # Empty context parts are dropped to avoid stray double spaces.
        batch_texts: List[str] = []
        for i in range(len(candidates)):
            n_masks = len(cand_token_ids[i])
            mask_str = " ".join([mask] * n_masks)
            parts = [p for p in [left_contexts[i], mask_str, right_contexts[i]] if p]
            batch_texts.append(" ".join(parts))

        inputs = self.tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits

        scores: List[float] = []
        for i, target_ids in enumerate(cand_token_ids):
            token_ids = inputs.input_ids[i]
            # Indices of every <mask> token in this row, in order.
            mask_positions = (token_ids == mask_token_id).nonzero(as_tuple=True)[0]

            # Sentinel for rows whose masks were lost (e.g. truncation).
            if mask_positions.numel() == 0 or not target_ids:
                scores.append(-100.0)
                continue

            # Guard against a mismatch between planned masks and the
            # positions actually present after tokenization/truncation.
            n = min(len(target_ids), mask_positions.numel())
            total = 0.0
            for j in range(n):
                pos = mask_positions[j].item()
                log_probs = torch.log_softmax(logits[i, pos, :], dim=0)
                total += log_probs[target_ids[j]].item()

            # Length-normalize so multi-subword candidates are comparable
            # with single-subword ones.
            scores.append(total / n)

        return scores
|
|
| |
|
|
| def decode( |
| self, |
| sentence: str, |
| beam_width: int = DEFAULT_BEAM_WIDTH, |
| mode: str = "greedy", |
| ) -> Tuple[str, List[str]]: |
| """ |
| Transliterate a full Singlish sentence into Sinhala script. |
| |
| Args: |
| mode: "greedy" (accurate, uses dynamic context) or |
| "beam" (uses fixed rule-based context) |
| |
| Returns: |
| result β the best transliteration string |
| trace_logs β per-step markdown logs for the debug UI |
| """ |
| if mode == "greedy": |
| result, trace_logs, _ = self.greedy_decode_with_diagnostics(sentence) |
| else: |
| result, trace_logs, _ = self.decode_with_diagnostics( |
| sentence=sentence, |
| beam_width=beam_width, |
| ) |
| return result, trace_logs |
|
|
| |
|
|
    def greedy_decode_with_diagnostics(
        self,
        sentence: str,
    ) -> Tuple[str, List[str], List[WordDiagnostic]]:
        """
        Greedy word-by-word decode using actual selected outputs as
        left context for subsequent MLM scoring.

        More accurate than beam search with fixed context because XLM-R
        sees the real transliteration built so far, not rule-based guesses.

        Args:
            sentence: Raw Singlish input; split on whitespace.

        Returns:
            (result, trace_logs, diagnostics) — joined output string,
            per-step markdown logs, and one WordDiagnostic per word.
        """
        words = sentence.split()
        if not words:
            return "", [], []

        # ---- Phase 1: per-word candidate generation -------------------
        word_infos: List[dict] = []

        for raw in words:
            # Split surrounding punctuation from the transliterable core.
            match = PUNCT_PATTERN.match(raw)
            prefix, core, suffix = match.groups() if match else ("", raw, "")

            # Pure punctuation: passes through unchanged.
            if not core:
                word_infos.append({
                    "candidates": [raw],
                    "rule_output": raw,
                    "english_flags": [False],
                    "dict_flags": [False],
                    "prefix": prefix,
                    "suffix": suffix,
                    "sinhala_passthrough": False,
                })
                continue

            # Already Sinhala script: keep verbatim, skip scoring later.
            if _is_sinhala(core):
                word_infos.append({
                    "candidates": [raw],
                    "rule_output": raw,
                    "english_flags": [False],
                    "dict_flags": [False],
                    "prefix": prefix,
                    "suffix": suffix,
                    "sinhala_passthrough": True,
                })
                continue

            rule_output = self.adapter.get_rule_output(core)
            cands = self.adapter.get_candidates(core, rule_output)

            # Dictionary entries for the core (exact match wins over
            # lowercase match — note: elif, not a merge of both).
            dict_entries: Set[str] = set()
            if core in self.adapter.dictionary:
                dict_entries.update(self.adapter.dictionary[core])
            elif core.lower() in self.adapter.dictionary:
                dict_entries.update(self.adapter.dictionary[core.lower()])

            # Ensure the rule engine's own output is always a candidate.
            if rule_output and rule_output not in cands:
                cands.append(rule_output)
            if not cands:
                cands = [rule_output]

            english_flags = [c.lower() in ENGLISH_VOCAB for c in cands]
            dict_flags = [c in dict_entries for c in cands]

            # Re-attach the punctuation around every candidate.
            full_cands = [prefix + c + suffix for c in cands]

            word_infos.append({
                "candidates": full_cands[:MAX_CANDIDATES],
                "rule_output": prefix + rule_output + suffix,
                "core_rule_output": rule_output,
                "n_dict_entries": len(dict_entries),
                "dict_entries": dict_entries,
                "english_flags": english_flags[:MAX_CANDIDATES],
                "dict_flags": dict_flags[:MAX_CANDIDATES],
                "prefix": prefix,
                "suffix": suffix,
                "sinhala_passthrough": False,
            })

        # ---- Phase 2: fixed right context ----------------------------
        # For words not yet decoded, the right context uses the first
        # English candidate if any, else the rule output.
        stable_right: List[str] = []
        for info in word_infos:
            eng_cands = [
                c for c, e in zip(info["candidates"], info["english_flags"]) if e
            ]
            stable_right.append(
                eng_cands[0] if eng_cands else info["rule_output"]
            )

        # ---- Phase 3: left-to-right greedy selection ------------------
        selected_words: List[str] = []
        trace_logs: List[str] = []
        diagnostics: List[WordDiagnostic] = []

        for t, info in enumerate(word_infos):
            candidates = info["candidates"]
            eng_flags = info["english_flags"]
            d_flags = info.get("dict_flags", [False] * len(candidates))
            rule_out = info["rule_output"]
            prefix = info.get("prefix", "")
            suffix = info.get("suffix", "")
            total_cands = len(candidates)

            # Shortcut 1: already Sinhala — copy through, no scoring.
            if info.get("sinhala_passthrough"):
                selected_words.append(words[t])
                trace_logs.append(
                    f"**Step {t + 1}: `{words[t]}`** β "
                    f"`{words[t]}` (Sinhala passthrough)\n"
                )
                diagnostics.append(WordDiagnostic(
                    step_index=t,
                    input_word=words[t],
                    rule_output=rule_out,
                    selected_candidate=words[t],
                    beam_score=0.0,
                    candidate_breakdown=[],
                ))
                continue

            # Shortcut 2: hard override table for common words.
            # NOTE(review): core_lower is the full raw token lowercased,
            # punctuation included — so these overrides only fire for
            # punctuation-free tokens (lookup fails otherwise); confirm
            # that is intended.
            core_lower = words[t].lower().strip()
            if core_lower in COMMON_WORDS:
                override = prefix + COMMON_WORDS[core_lower] + suffix
                selected_words.append(override)
                trace_logs.append(
                    f"**Step {t + 1}: `{words[t]}`** β "
                    f"`{override}` (common-word override)\n"
                )
                diagnostics.append(WordDiagnostic(
                    step_index=t,
                    input_word=words[t],
                    rule_output=rule_out,
                    selected_candidate=override,
                    beam_score=0.0,
                    candidate_breakdown=[],
                ))
                continue

            # Shortcut 3: context words that only apply when the
            # previous word was NOT English (e.g. sentence-initial).
            if core_lower in CONTEXT_WORDS_STANDALONE:
                prev_word_lower = words[t - 1].lower() if t > 0 else ""
                prev_common_val = COMMON_WORDS.get(prev_word_lower, "")
                prev_is_english = (
                    t > 0
                    and (
                        prev_word_lower in ENGLISH_VOCAB
                        # Previous word's COMMON_WORDS value being ASCII
                        # is treated as "previous word is English".
                        or prev_common_val.isascii() and prev_common_val != ""
                    )
                )
                if not prev_is_english:
                    override = prefix + CONTEXT_WORDS_STANDALONE[core_lower] + suffix
                    selected_words.append(override)
                    trace_logs.append(
                        f"**Step {t + 1}: `{words[t]}`** β "
                        f"`{override}` (standalone override)\n"
                    )
                    diagnostics.append(WordDiagnostic(
                        step_index=t,
                        input_word=words[t],
                        rule_output=rule_out,
                        selected_candidate=override,
                        beam_score=0.0,
                        candidate_breakdown=[],
                    ))
                    continue

            # Shortcut 4: preserve real English words, unless the token
            # is "semantically ambiguous": its rule output also appears
            # in a dictionary entry set of 3+ alternatives, meaning it
            # likely has a legitimate Sinhala reading worth scoring.
            core_rule = info.get("core_rule_output", "")
            core_dict = info.get("dict_entries", set())
            is_semantically_ambiguous = (
                core_rule in core_dict and len(core_dict) >= 3
            )
            if (
                len(core_lower) >= MIN_ENGLISH_LEN
                and core_lower in ENGLISH_VOCAB
                and not is_semantically_ambiguous
            ):
                selected_words.append(words[t])
                trace_logs.append(
                    f"**Step {t + 1}: `{words[t]}`** β "
                    f"`{words[t]}` (English preserved)\n"
                )
                diagnostics.append(WordDiagnostic(
                    step_index=t,
                    input_word=words[t],
                    rule_output=rule_out,
                    selected_candidate=words[t],
                    beam_score=0.0,
                    candidate_breakdown=[],
                ))
                continue

            # Full MLM scoring: left context is the REAL output so far,
            # right context is the precomputed stable guess.
            left_ctx = " ".join(selected_words) if selected_words else ""
            right_ctx = " ".join(stable_right[t + 1:]) if t + 1 < len(words) else ""

            # Same context replicated for every candidate in the batch.
            batch_left = [left_ctx] * total_cands
            batch_right = [right_ctx] * total_cands

            mlm_scores = self._batch_mlm_score(batch_left, batch_right, candidates)

            # Normalize raw log-probs to [0, 1] before combining.
            mlm_scores = self._softmax_normalize(mlm_scores)

            # Best normalized score among non-English candidates; used
            # below to stop the MLM from unfairly favoring the English
            # identity candidate.
            best_nonenglish_mlm = -1e9
            if not is_semantically_ambiguous:
                for i, mlm in enumerate(mlm_scores):
                    is_eng = eng_flags[i] if i < len(eng_flags) else False
                    if not is_eng and mlm > best_nonenglish_mlm:
                        best_nonenglish_mlm = mlm

            step_log = f"**Step {t + 1}: `{words[t]}`** (rule β `{rule_out}`)\n\n"
            best_scored: Optional[ScoredCandidate] = None
            candidate_breakdown: List[ScoredCandidate] = []

            for i, mlm in enumerate(mlm_scores):
                cand = candidates[i]
                is_eng = eng_flags[i] if i < len(eng_flags) else False
                is_dict = d_flags[i] if i < len(d_flags) else False

                # Floor the English identity candidate at the best
                # non-English MLM score so the final choice is decided by
                # fidelity/rank rather than MLM bias toward English.
                effective_mlm = mlm
                if is_eng and cand.lower() == words[t].lower() and not is_semantically_ambiguous:
                    effective_mlm = max(mlm, best_nonenglish_mlm)

                scored = self.scorer.score(
                    mlm_score=effective_mlm,
                    candidate=cand,
                    rule_output=rule_out,
                    rank=i,
                    total_candidates=total_cands,
                    is_english=is_eng,
                    original_input=words[t],
                    is_from_dict=is_dict,
                    is_ambiguous=is_semantically_ambiguous,
                )
                candidate_breakdown.append(scored)

                if best_scored is None or scored.combined_score > best_scored.combined_score:
                    best_scored = scored

                # NOTE(review): mlm is softmax-normalized to [0, 1] here,
                # so this -25.0 guard always passes — it looks like a
                # leftover from raw log-prob scoring; confirm.
                if mlm > -25.0:
                    eng_tag = " π€" if is_eng else ""
                    step_log += (
                        f"- `{cand}`{eng_tag} "
                        f"MLM={scored.mlm_score:.2f} "
                        f"Fid={scored.fidelity_score:.2f} "
                        f"Rank={scored.rank_score:.2f} β "
                        f"**{scored.combined_score:.2f}**\n"
                    )

            trace_logs.append(step_log)

            # Greedy commit: the chosen word becomes left context for
            # all following steps.
            selected = best_scored.text if best_scored else rule_out
            selected_words.append(selected)

            candidate_breakdown.sort(key=lambda s: s.combined_score, reverse=True)
            diagnostics.append(WordDiagnostic(
                step_index=t,
                input_word=words[t],
                rule_output=rule_out,
                selected_candidate=selected,
                beam_score=best_scored.combined_score if best_scored else 0.0,
                candidate_breakdown=candidate_breakdown,
            ))

        result = " ".join(selected_words)
        return result, trace_logs, diagnostics
|
|
| |
|
|
    def decode_with_diagnostics(
        self,
        sentence: str,
        beam_width: int = DEFAULT_BEAM_WIDTH,
    ) -> Tuple[str, List[str], List[WordDiagnostic]]:
        """
        Decode sentence using beam search and return detailed diagnostics.

        Uses fixed rule-based context for all beam paths. Kept for
        comparison with greedy decode in evaluation.

        Args:
            sentence: Raw Singlish input; split on whitespace.
            beam_width: Maximum number of hypotheses retained per step.

        Returns:
            (result, trace_logs, diagnostics) — best path joined to a
            string, per-step markdown logs, one WordDiagnostic per word.
        """
        words = sentence.split()
        if not words:
            return "", [], []

        # ---- Phase 1: per-word candidate generation (mirrors the
        # greedy decoder's phase 1) -------------------------------------
        word_infos: List[dict] = []

        for raw in words:
            # Split surrounding punctuation from the transliterable core.
            match = PUNCT_PATTERN.match(raw)
            prefix, core, suffix = match.groups() if match else ("", raw, "")

            # Pure punctuation: passes through unchanged.
            # (No "dict_flags" key here — the step loop defaults it.)
            if not core:
                word_infos.append({
                    "candidates": [raw],
                    "rule_output": raw,
                    "english_flags": [False],
                    "prefix": prefix,
                    "suffix": suffix,
                    "sinhala_passthrough": False,
                })
                continue

            # Already Sinhala script: keep verbatim, skip scoring later.
            if _is_sinhala(core):
                word_infos.append({
                    "candidates": [raw],
                    "rule_output": raw,
                    "english_flags": [False],
                    "prefix": prefix,
                    "suffix": suffix,
                    "sinhala_passthrough": True,
                })
                continue

            rule_output = self.adapter.get_rule_output(core)
            cands = self.adapter.get_candidates(core, rule_output)

            # Exact-case dictionary hit wins over the lowercase one.
            dict_entries: Set[str] = set()
            if core in self.adapter.dictionary:
                dict_entries.update(self.adapter.dictionary[core])
            elif core.lower() in self.adapter.dictionary:
                dict_entries.update(self.adapter.dictionary[core.lower()])

            # Rule output is always among the candidates.
            if rule_output and rule_output not in cands:
                cands.append(rule_output)
            if not cands:
                cands = [rule_output]

            english_flags = [c.lower() in ENGLISH_VOCAB for c in cands]
            dict_flags = [c in dict_entries for c in cands]
            full_cands = [prefix + c + suffix for c in cands]

            word_infos.append({
                "candidates": full_cands[:MAX_CANDIDATES],
                "rule_output": prefix + rule_output + suffix,
                "core_rule_output": rule_output,
                "n_dict_entries": len(dict_entries),
                "dict_entries": dict_entries,
                "english_flags": english_flags[:MAX_CANDIDATES],
                "dict_flags": dict_flags[:MAX_CANDIDATES],
                "prefix": prefix,
                "suffix": suffix,
                "sinhala_passthrough": False,
            })

        # ---- Phase 2: fixed rule-based context for ALL positions ------
        # First English candidate if any, else the rule output.
        stable_context: List[str] = []
        for info in word_infos:
            eng_cands = [
                c for c, e in zip(info["candidates"], info["english_flags"]) if e
            ]
            stable_context.append(
                eng_cands[0] if eng_cands else info["rule_output"]
            )

        # ---- Phase 3: beam search -------------------------------------
        # Each hypothesis is (list of chosen words, cumulative score).
        beam: List[Tuple[List[str], float]] = [([], 0.0)]
        trace_logs: List[str] = []
        diagnostics: List[WordDiagnostic] = []

        for t, info in enumerate(word_infos):
            candidates = info["candidates"]
            eng_flags = info["english_flags"]
            d_flags = info.get("dict_flags", [False] * len(candidates))
            rule_out = info["rule_output"]
            prefix = info.get("prefix", "")
            suffix = info.get("suffix", "")
            total_cands = len(candidates)

            # Shortcut 1: already Sinhala — extend every path verbatim.
            if info.get("sinhala_passthrough"):
                next_beam_si = [(path + [words[t]], sc) for path, sc in beam]
                beam = next_beam_si[:beam_width]
                trace_logs.append(
                    f"**Step {t + 1}: `{words[t]}`** β "
                    f"`{words[t]}` (Sinhala passthrough)\n"
                )
                diagnostics.append(WordDiagnostic(
                    step_index=t,
                    input_word=words[t],
                    rule_output=rule_out,
                    selected_candidate=words[t],
                    beam_score=beam[0][1] if beam else 0.0,
                    candidate_breakdown=[],
                ))
                continue

            # Shortcut 2: hard override table for common words.
            # NOTE(review): lookup uses the full raw token lowercased
            # (punctuation included), same caveat as the greedy decoder.
            core_lower = words[t].lower().strip()
            if core_lower in COMMON_WORDS:
                override = prefix + COMMON_WORDS[core_lower] + suffix
                next_beam_cw = [(path + [override], sc) for path, sc in beam]
                beam = next_beam_cw[:beam_width]
                trace_logs.append(
                    f"**Step {t + 1}: `{words[t]}`** β "
                    f"`{override}` (common-word override)\n"
                )
                diagnostics.append(WordDiagnostic(
                    step_index=t,
                    input_word=words[t],
                    rule_output=rule_out,
                    selected_candidate=override,
                    beam_score=beam[0][1] if beam else 0.0,
                    candidate_breakdown=[],
                ))
                continue

            # Shortcut 3: standalone context words apply only when the
            # previous word was not English.
            if core_lower in CONTEXT_WORDS_STANDALONE:
                prev_word_lower = words[t - 1].lower() if t > 0 else ""
                prev_common_val = COMMON_WORDS.get(prev_word_lower, "")
                prev_is_english = (
                    t > 0
                    and (
                        prev_word_lower in ENGLISH_VOCAB
                        # ASCII COMMON_WORDS value => previous is English.
                        or prev_common_val.isascii() and prev_common_val != ""
                    )
                )
                if not prev_is_english:
                    override = prefix + CONTEXT_WORDS_STANDALONE[core_lower] + suffix
                    next_beam_ctx = [(path + [override], sc) for path, sc in beam]
                    beam = next_beam_ctx[:beam_width]
                    trace_logs.append(
                        f"**Step {t + 1}: `{words[t]}`** β "
                        f"`{override}` (standalone override)\n"
                    )
                    diagnostics.append(WordDiagnostic(
                        step_index=t,
                        input_word=words[t],
                        rule_output=rule_out,
                        selected_candidate=override,
                        beam_score=beam[0][1] if beam else 0.0,
                        candidate_breakdown=[],
                    ))
                    continue

            # Shortcut 4: preserve real English words, unless the rule
            # output sits in a dictionary entry set of 3+ alternatives
            # ("semantically ambiguous" — worth full scoring).
            core_rule = info.get("core_rule_output", "")
            core_dict = info.get("dict_entries", set())
            is_semantically_ambiguous = (
                core_rule in core_dict and len(core_dict) >= 3
            )
            if (
                len(core_lower) >= MIN_ENGLISH_LEN
                and core_lower in ENGLISH_VOCAB
                and not is_semantically_ambiguous
            ):
                eng_word = words[t]
                next_beam_eng = [(path + [eng_word], sc) for path, sc in beam]
                beam = next_beam_eng[:beam_width]
                trace_logs.append(
                    f"**Step {t + 1}: `{words[t]}`** β "
                    f"`{eng_word}` (English preserved)\n"
                )
                diagnostics.append(WordDiagnostic(
                    step_index=t,
                    input_word=words[t],
                    rule_output=rule_out,
                    selected_candidate=eng_word,
                    beam_score=beam[0][1] if beam else 0.0,
                    candidate_breakdown=[],
                ))
                continue

            # Build the scoring batch: one entry per (path, candidate).
            batch_left: List[str] = []
            batch_right: List[str] = []
            batch_tgt: List[str] = []
            batch_meta: List[Tuple[int, int]] = []

            # NOTE(review): contexts come from stable_context, not from
            # `path` — so every beam path is scored against the SAME
            # fixed context (matches the docstring). `future` is also
            # loop-invariant here; hoisting it would be a free win but
            # changes nothing observable.
            for p_idx, (path, _) in enumerate(beam):
                for c_idx, cand in enumerate(candidates):
                    future = stable_context[t + 1:] if t + 1 < len(words) else []
                    batch_left.append(" ".join(stable_context[:t]))
                    batch_right.append(" ".join(future))
                    batch_tgt.append(cand)
                    batch_meta.append((p_idx, c_idx))

            if not batch_tgt:
                continue

            mlm_scores = self._batch_mlm_score(batch_left, batch_right, batch_tgt)

            # Normalize raw log-probs to [0, 1] across the whole batch.
            mlm_scores = self._softmax_normalize(mlm_scores)

            # Per-path best non-English MLM score, used to floor the
            # English identity candidate below.
            best_nonenglish_mlm: Dict[int, float] = {}
            if not is_semantically_ambiguous:
                for i, mlm in enumerate(mlm_scores):
                    p_idx, c_idx = batch_meta[i]
                    is_eng = eng_flags[c_idx] if c_idx < len(eng_flags) else False
                    if not is_eng:
                        prev = best_nonenglish_mlm.get(p_idx, -1e9)
                        if mlm > prev:
                            best_nonenglish_mlm[p_idx] = mlm

            # Expand every (path, candidate) pair into a new hypothesis.
            next_beam: List[Tuple[List[str], float]] = []
            all_step_scores: List[Tuple[int, ScoredCandidate, float]] = []
            step_log = f"**Step {t + 1}: `{words[t]}`** (rule β `{rule_out}`)\n\n"

            for i, mlm in enumerate(mlm_scores):
                p_idx, c_idx = batch_meta[i]
                orig_path, orig_score = beam[p_idx]
                cand = batch_tgt[i]
                is_eng = eng_flags[c_idx] if c_idx < len(eng_flags) else False
                is_dict = d_flags[c_idx] if c_idx < len(d_flags) else False

                # Floor the English identity candidate at this path's
                # best non-English MLM score (anti-English-bias).
                effective_mlm = mlm
                if is_eng and cand.lower() == words[t].lower() and not is_semantically_ambiguous:
                    floor = best_nonenglish_mlm.get(p_idx, mlm)
                    effective_mlm = max(mlm, floor)

                scored = self.scorer.score(
                    mlm_score=effective_mlm,
                    candidate=cand,
                    rule_output=rule_out,
                    rank=c_idx,
                    total_candidates=total_cands,
                    is_english=is_eng,
                    original_input=words[t],
                    is_from_dict=is_dict,
                    is_ambiguous=is_semantically_ambiguous,
                )

                # Cumulative path score: previous total + step score.
                new_total = orig_score + scored.combined_score
                next_beam.append((orig_path + [cand], new_total))
                all_step_scores.append((p_idx, scored, new_total))

                # NOTE(review): mlm is softmax-normalized to [0, 1], so
                # this -25.0 guard always passes — likely a leftover from
                # raw log-prob scoring; confirm.
                if mlm > -25.0:
                    eng_tag = " π€" if is_eng else ""
                    step_log += (
                        f"- `{cand}`{eng_tag} "
                        f"MLM={scored.mlm_score:.2f} "
                        f"Fid={scored.fidelity_score:.2f} "
                        f"Rank={scored.rank_score:.2f} β "
                        f"**{scored.combined_score:.2f}**\n"
                    )

            trace_logs.append(step_log)

            # Prune to the top beam_width hypotheses by cumulative score.
            beam = sorted(next_beam, key=lambda x: x[1], reverse=True)[:beam_width]

            # Diagnostics report candidates expanded from beam path 0
            # only, sorted by cumulative score.
            root_scores = [item for item in all_step_scores if item[0] == 0]
            root_scores_sorted = sorted(root_scores, key=lambda x: x[2], reverse=True)

            selected = beam[0][0][t] if beam and beam[0][0] else ""
            selected_total = beam[0][1] if beam else float("-inf")
            candidate_breakdown = [item[1] for item in root_scores_sorted]

            diagnostics.append(WordDiagnostic(
                step_index=t,
                input_word=words[t],
                rule_output=rule_out,
                selected_candidate=selected,
                beam_score=selected_total,
                candidate_breakdown=candidate_breakdown,
            ))

        # Best surviving hypothesis is the final transliteration.
        result = " ".join(beam[0][0]) if beam else ""
        return result, trace_logs, diagnostics
|
|