| """
|
| Evaluation module β greedy / beam-search decoding + chrF scoring.
|
| =================================================================
|
| Provides:
|
| β’ ``greedy_decode`` β auto-regressive greedy decoding.
|
| β’ ``beam_search_decode`` β beam search with length normalisation.
|
| β’ ``translate`` β end-to-end: raw English string β Malay string.
|
| β’ ``compute_chrf`` β corpus-level chrF score via *sacrebleu*.
|
| β’ ``evaluate`` β decode the full validation set, compute chrF,
|
| and print sample translations.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import re
|
| from typing import List, Optional
|
|
|
| import torch
|
| import torch.nn as nn
|
| from tokenizers import Tokenizer
|
|
|
| import sacrebleu
|
|
|
|
|
|
|
|
|
|
|
def postprocess_translation(text: str) -> str:
    """
    Clean up raw tokenizer decode output:
      1. Remove spaces before punctuation (e.g. ", tuan ." → ", tuan.")
      2. Remove spaces after opening brackets/quotes
      3. Remove spaces before closing brackets/quotes
      4. Collapse spaces around hyphens
      5. Collapse multiple spaces
      6. Capitalise the first letter
    """
    # 1 & 3: no space before punctuation, closing brackets, or quotes.
    text = re.sub(r'\s+([.,?!;:)\]}"\'…])', r'\1', text)
    # 2: no space after opening brackets or quotes.
    text = re.sub(r'([(\[{"\'])\s+', r'\1', text)
    # 4: rejoin hyphenated tokens ("kira - kira" → "kira-kira").
    text = re.sub(r'\s*-\s*', '-', text)
    # 5: collapse runs of whitespace.
    text = re.sub(r'\s{2,}', ' ', text)
    text = text.strip()
    # 6: sentence-initial capital.
    if text:
        text = text[0].upper() + text[1:]
    return text
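
# Example (behaviour sketch; the input string below is invented purely for
# illustration):
#   >>> postprocess_translation("selamat pagi , tuan .")
#   'Selamat pagi, tuan.'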


@torch.no_grad()
def greedy_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
) -> torch.Tensor:
    """
    Auto-regressive greedy decoding for a single source sequence.

    Parameters
    ----------
    model : TransformerTranslator
    src : (1, src_len) source token IDs.
    bos_id : beginning-of-sentence token ID.
    eos_id : end-of-sentence token ID.
    pad_id : padding token ID.
    max_len : maximum decoding steps.

    Returns
    -------
    (1, out_len) generated token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode the source once; reuse the memory for every decoding step.
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    # Start from [BOS] and repeatedly append the most probable next token.
    ys = torch.tensor([[bos_id]], dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        logits = model.decode(
            ys, memory,
            memory_key_padding_mask=src_pad_mask,
        )
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ys = torch.cat([ys, next_token], dim=1)

        if next_token.item() == eos_id:
            break

    return ys
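
# Usage sketch (``model`` is assumed to be a trained TransformerTranslator and
# ``src`` a (1, src_len) LongTensor of source token IDs; both names are
# placeholders):
#   out = greedy_decode(model, src, bos_id=5, eos_id=6, pad_id=0, max_len=128)
#   tokens = out.squeeze(0).tolist()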


@torch.no_grad()
def beam_search_decode(
    model: nn.Module,
    src: torch.Tensor,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    beam_width: int = 5,
    length_penalty: float = 0.6,
) -> torch.Tensor:
    """
    Beam-search decoding for a single source sequence.

    Parameters
    ----------
    model : TransformerTranslator
    src : (1, src_len) source token IDs.
    bos_id, eos_id, pad_id : special token IDs.
    max_len : maximum decoding steps.
    beam_width : number of beams to keep at each step.
    length_penalty : α for length normalisation: score / len^α.

    Returns
    -------
    (1, out_len) best hypothesis token IDs (including [BOS], up to [EOS]).
    """
    device = src.device
    model.eval()

    # Encode the source once; all beams share the same encoder memory.
    src_pad_mask = (src == pad_id)
    memory = model.encode(src, src_key_padding_mask=src_pad_mask)

    # Each beam is (cumulative log-prob, token list); start with [BOS] only.
    beams = [(0.0, [bos_id])]
    completed = []

    for _ in range(max_len - 1):
        candidates = []
        for score, tokens in beams:
            # Finished beams are set aside and not expanded further.
            if tokens[-1] == eos_id:
                completed.append((score, tokens))
                continue

            ys = torch.tensor([tokens], dtype=torch.long, device=device)
            logits = model.decode(
                ys, memory,
                memory_key_padding_mask=src_pad_mask,
            )
            log_probs = torch.log_softmax(logits[:, -1, :], dim=-1).squeeze(0)

            # Expand this beam with its top-k next tokens.
            topk_log_probs, topk_ids = log_probs.topk(beam_width)
            for k in range(beam_width):
                new_score = score + topk_log_probs[k].item()
                new_tokens = tokens + [topk_ids[k].item()]
                candidates.append((new_score, new_tokens))

        if not candidates:
            break

        # Keep the beam_width best candidates under length normalisation.
        candidates.sort(
            key=lambda x: x[0] / (len(x[1]) ** length_penalty),
            reverse=True,
        )
        beams = candidates[:beam_width]

        # Stop early once every surviving beam has emitted [EOS].
        if all(b[1][-1] == eos_id for b in beams):
            break

    # Any still-active beams compete with the completed hypotheses.
    completed.extend(beams)

    # Pick the hypothesis with the best length-normalised score.
    best = max(
        completed,
        key=lambda x: x[0] / (len(x[1]) ** length_penalty),
    )
    return torch.tensor([best[1]], dtype=torch.long, device=device)
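
# Worked example of the length-normalised score (illustrative numbers only):
# with length_penalty = 0.6, a short beam with sum-log-prob -3.8 over 4 tokens
# scores -3.8 / 4**0.6 ≈ -1.65, while a longer beam with -5.5 over 10 tokens
# scores -5.5 / 10**0.6 ≈ -1.38, so the longer hypothesis wins after
# normalisation even though its raw log-probability is lower.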


def translate(
    model: nn.Module,
    sentence: str,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    bos_id: int,
    eos_id: int,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> str:
    """Translate a single English sentence to Malay.

    Set ``beam_width=1`` for greedy decoding, ``>1`` for beam search.
    """
    if device is None:
        device = next(model.parameters()).device

    # Tokenise the raw source sentence into a (1, src_len) ID tensor.
    src_ids = src_tokenizer.encode(sentence).ids
    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    if beam_width > 1:
        out_ids = beam_search_decode(
            model, src, bos_id, eos_id, pad_id, max_len,
            beam_width=beam_width, length_penalty=length_penalty,
        )
    else:
        out_ids = greedy_decode(model, src, bos_id, eos_id, pad_id, max_len)

    # Detokenise and clean up spacing/capitalisation.
    raw = tgt_tokenizer.decode(out_ids.squeeze(0).tolist(), skip_special_tokens=True)
    return postprocess_translation(raw)
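
# Usage sketch (the model and tokenizer objects are placeholders for whatever
# the training pipeline produced):
#   ms = translate(model, "How are you?", src_tok, tgt_tok,
#                  bos_id=5, eos_id=6, beam_width=5)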


def compute_chrf(hypotheses: List[str], references: List[str]) -> sacrebleu.CHRFScore:
    """
    Compute the corpus-level chrF score.

    Parameters
    ----------
    hypotheses : list[str]
        System outputs (decoded translations).
    references : list[str]
        Gold reference translations.

    Returns
    -------
    sacrebleu.CHRFScore – has a ``.score`` attribute (0–100 scale).
    """
    # sacrebleu expects a list of reference *streams*, hence the extra nesting.
    return sacrebleu.corpus_chrf(hypotheses, [references])
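
# Example (the strings are invented; ``.score`` is on a 0–100 scale, higher is
# better):
#   chrf = compute_chrf(["saya suka kucing"], ["saya suka kucing itu"])
#   print(chrf.score)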


def evaluate(
    model: nn.Module,
    hf_dataset,
    src_tokenizer: Tokenizer,
    tgt_tokenizer: Tokenizer,
    src_lang: str = "en",
    tgt_lang: str = "ms",
    bos_id: int = 5,
    eos_id: int = 6,
    pad_id: int = 0,
    max_len: int = 128,
    device: Optional[torch.device] = None,
    num_samples: int = 5,
    beam_width: int = 1,
    length_penalty: float = 0.6,
) -> float:
    """
    Decode every example in *hf_dataset*, compute corpus chrF, and
    print ``num_samples`` side-by-side translations.

    Set ``beam_width=1`` for greedy decoding, ``>1`` for beam search.

    Returns
    -------
    chrf_score : float (0–100)
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    hypotheses: List[str] = []
    references: List[str] = []

    # Translate every example and collect hypothesis/reference pairs.
    for example in hf_dataset:
        src_text = example["translation"][src_lang]
        ref_text = example["translation"][tgt_lang]

        hyp_text = translate(
            model, src_text,
            src_tokenizer, tgt_tokenizer,
            bos_id, eos_id, pad_id, max_len, device,
            beam_width=beam_width,
            length_penalty=length_penalty,
        )

        hypotheses.append(hyp_text)
        references.append(ref_text)

    chrf = compute_chrf(hypotheses, references)

    # Report the corpus score and a few sample translations.
    print(f"\n{'='*60}")
    print(f"chrF Score: {chrf.score:.2f}")
    print(f"{'='*60}")
    for i in range(min(num_samples, len(hypotheses))):
        src_text = hf_dataset[i]["translation"][src_lang]
        print(f"\n[{i}] SRC: {src_text[:120]}")
        print(f"    REF: {references[i][:120]}")
        print(f"    HYP: {hypotheses[i][:120]}")

    return chrf.score
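

# End-to-end usage sketch (commented out: checkpoint and tokenizer loading are
# project-specific, and every name below is a placeholder rather than part of
# this module):
#
#     model = TransformerTranslator(...)   # trained model
#     valid_ds = ...                       # HF dataset with "translation" dicts
#     chrf = evaluate(model, valid_ds, src_tok, tgt_tok,
#                     bos_id=5, eos_id=6, beam_width=5)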