Spaces:
Sleeping
Sleeping
| """ | |
| inference.py | |
| Inference (translation) for English→Bengali with full calculation logging. | |
| Supports greedy decoding and beam search, showing every step. | |
| """ | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import math | |
| from typing import Dict, List, Tuple, Optional | |
| from transformer import Transformer, CalcLog | |
| from vocab import get_vocabs, PAD_IDX, BOS_IDX, EOS_IDX | |
| # ───────────────────────────────────────────── | |
| # Greedy decoding with full logging | |
| # ───────────────────────────────────────────── | |
def greedy_decode(
    model: Transformer,
    src: torch.Tensor,
    max_len: int = 20,
    device: str = "cpu",
    log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
    """Translate ``src`` by greedy auto-regressive decoding, logging each step.

    The source is encoded once. At every step the decoder is re-run on the
    whole prefix generated so far and the highest-probability next token is
    appended, until EOS is produced or ``max_len`` steps have run.

    Args:
        model: The Transformer to decode with; switched to eval mode here.
        src: Source token ids. Only batch element 0 is read when building
            logs; presumably shape (1, T_src) — TODO confirm with callers.
        max_len: Maximum number of tokens generated after BOS.
        device: Device on which the target-prefix tensor is created.
        log: Optional CalcLog. When given, the source positional encoding,
            encoder layer 0, and decoder layer 0 for the first 3 steps are
            logged, plus the top-5 candidates for the first 3 steps.

    Returns:
        ``(generated, step_logs)``: ``generated`` starts with BOS and ends
        with EOS only if EOS was produced within ``max_len`` steps;
        ``step_logs`` has one dict per step with the prefix tokens, top-5
        candidates, the chosen token/probability, and decoder-layer-0
        cross-attention for batch element 0.
    """
    model.eval()
    _, tgt_v = get_vocabs()
    with torch.no_grad():
        src_mask = model.make_src_mask(src)

        # ── Encode once ──────────────────────
        src_emb = model.src_embed(src) * math.sqrt(model.d_model)
        enc_x = model.src_pe(src_emb, log=log)
        for i, layer in enumerate(model.encoder_layers):
            # Only the first encoder layer is logged to keep the log readable.
            enc_x, _ = layer(enc_x, src_mask=src_mask,
                             log=log if i == 0 else None, layer_idx=i)
        if log is not None:
            log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
                    note="Encoder finished. Output K,V will be reused for every decoder step.")

        # ── Auto-regressive decode ────────────
        generated = [BOS_IDX]
        step_logs = []
        for step in range(max_len):
            # No incremental cache: the full prefix is re-decoded each step.
            tgt_so_far = torch.tensor([generated], dtype=torch.long, device=device)
            tgt_mask = model.make_tgt_mask(tgt_so_far)
            tgt_emb = model.tgt_embed(tgt_so_far) * math.sqrt(model.d_model)
            dec_x = model.tgt_pe(tgt_emb)

            layer0_cross = None  # cross-attention of decoder layer 0 (numpy)
            for i, layer in enumerate(model.decoder_layers):
                do_log = (log is not None) and (step < 3) and (i == 0)
                if do_log:
                    log.log(f"INFERENCE_step{step}_dec_input", dec_x[0, :, :8],
                            note=f"Decoder input at step {step}: tokens so far = "
                                 f"{tgt_v.tokens(generated)}")
                dec_x, _, cw = layer(
                    dec_x, enc_x,
                    tgt_mask=tgt_mask, src_mask=src_mask,
                    log=log if do_log else None,
                    layer_idx=i,
                )
                # Only layer 0's attention is reported, so only it is copied to CPU.
                if i == 0:
                    layer0_cross = cw.cpu().numpy()

            # Only the last position predicts the next token.
            last_logits = model.output_linear(dec_x[:, -1, :])  # (1, V)
            probs = F.softmax(last_logits, dim=-1)[0]

            # Top-5 predictions; clamp k because topk raises if k > vocab size.
            top5_probs, top5_ids = probs.topk(min(5, probs.size(-1)))
            top5 = [
                {"token": tgt_v.idx2token.get(idx.item(), "?"),
                 "id": idx.item(),
                 "prob": round(prob.item(), 4)}
                for prob, idx in zip(top5_probs, top5_ids)
            ]

            # Greedy: topk returns values sorted descending, so index 0 is the argmax.
            next_token = top5_ids[0].item()
            step_info = {
                "step": step,
                "tokens_so_far": tgt_v.tokens(generated),
                "top5": top5,
                "chosen_token": tgt_v.idx2token.get(next_token, "?"),
                "chosen_id": next_token,
                "chosen_prob": round(top5_probs[0].item(), 4),
                # Batch element 0 of decoder layer 0's cross-attention.
                "cross_attn": layer0_cross[0].tolist()
                if layer0_cross is not None else None,
            }
            step_logs.append(step_info)
            if log is not None and step < 3:
                log.log(f"INFERENCE_step{step}_top5", top5,
                        formula="P(next_token) = softmax(W_out · dec_out[-1])",
                        note=f"Step {step}: top-5 candidates. Chosen: {step_info['chosen_token']} ({step_info['chosen_prob']:.4f})")

            generated.append(next_token)
            if next_token == EOS_IDX:
                break
    return generated, step_logs
| # ───────────────────────────────────────────── | |
| # Beam search | |
| # ───────────────────────────────────────────── | |
def beam_search(
    model: Transformer,
    src: torch.Tensor,
    beam_size: int = 3,
    max_len: int = 20,
    device: str = "cpu",
    log: Optional[CalcLog] = None,
) -> Tuple[List[int], List[Dict]]:
    """Translate ``src`` with beam search, logging the first few steps.

    The source is encoded once. At each step every live beam is re-decoded
    on its full prefix and extended with its ``beam_size`` most likely next
    tokens; the pooled candidates are ranked by cumulative log-probability.
    The top ``2 * beam_size`` candidates are scanned: EOS-terminated ones go
    to ``completed`` (scored by log-prob / length, BOS included) and the rest
    refill up to ``beam_size`` live beams. Search stops once ``beam_size``
    hypotheses have completed, no live beams remain, or ``max_len`` steps ran.

    Args:
        model: The Transformer to decode with; switched to eval mode here.
        src: Source token ids; presumably shape (1, T_src) — TODO confirm.
        beam_size: Number of live hypotheses kept per step.
        max_len: Maximum number of decode steps.
        device: Device on which target-prefix tensors are created.
        log: Optional CalcLog; when given, encoder layer 0, decoder layer 0
            of the top beam and the top candidates for the first 3 steps
            are logged.

    Returns:
        ``(token_ids, step_logs)``. ``token_ids`` is the best completed
        hypothesis if any, else the best live beam with EOS appended, else
        ``[BOS, EOS]``. ``step_logs`` uses the same dict layout as
        :func:`greedy_decode` so the same renderer can display it.
    """
    model.eval()
    _, tgt_v = get_vocabs()
    with torch.no_grad():
        src_mask = model.make_src_mask(src)

        # Encode once (with logging, same as greedy)
        src_emb = model.src_embed(src) * math.sqrt(model.d_model)
        enc_x = model.src_pe(src_emb, log=log)
        for i, layer in enumerate(model.encoder_layers):
            enc_x, _ = layer(enc_x, src_mask=src_mask,
                             log=log if i == 0 else None, layer_idx=i)
        if log is not None:
            log.log("INFERENCE_ENCODER_done", enc_x[0, :, :8],
                    note="Encoder done. K,V reused for every beam decode step.")

        # Live beams: (cumulative log-prob, token_ids), best first.
        beams = [(0.0, [BOS_IDX])]
        completed = []   # (length-normalised score, token_ids)
        step_logs = []   # greedy-compatible format for decode_steps_html

        for step in range(max_len):
            if not beams:
                break

            # Candidates carry their parent beam index so the step log can
            # report the prefix (and attention) the chosen token really extends.
            candidates = []  # (cumulative score, token_ids, parent beam index)
            beam_cross = []  # decoder-layer-0 cross-attention per live beam
            for beam_idx, (score, tokens) in enumerate(beams):
                tgt_t = torch.tensor([tokens], dtype=torch.long, device=device)
                tgt_mask = model.make_tgt_mask(tgt_t)
                tgt_emb = model.tgt_embed(tgt_t) * math.sqrt(model.d_model)
                dec_x = model.tgt_pe(tgt_emb)
                layer0_cross = None
                for i, layer in enumerate(model.decoder_layers):
                    do_log = (log is not None) and (step < 3) and (i == 0) and (beam_idx == 0)
                    dec_x, _, cw = layer(dec_x, enc_x,
                                         tgt_mask=tgt_mask, src_mask=src_mask,
                                         log=log if do_log else None, layer_idx=i)
                    if i == 0:
                        layer0_cross = cw.cpu().numpy()
                beam_cross.append(layer0_cross)

                last_logits = model.output_linear(dec_x[:, -1, :])
                log_probs = F.log_softmax(last_logits, dim=-1)[0]
                # Clamp k: topk raises if k exceeds the vocabulary size.
                top_lp, top_id = log_probs.topk(min(beam_size, log_probs.size(-1)))
                for lp, tid in zip(top_lp, top_id):
                    candidates.append((score + lp.item(), tokens + [tid.item()], beam_idx))

            # Rank all candidates by cumulative log-probability (stable sort).
            candidates.sort(key=lambda cand: cand[0], reverse=True)

            # Build greedy-compatible step_info from the top candidates.
            # "prob" is the per-token geometric-mean probability of the whole
            # hypothesis (BOS excluded), floored at e^-20; all candidates at a
            # step share one length, so this preserves the ranking order.
            top5 = [
                {
                    "token": tgt_v.idx2token.get(toks[-1], "?"),
                    "id": toks[-1],
                    "prob": round(math.exp(max(sc / max(len(toks) - 1, 1), -20)), 4),
                }
                for sc, toks, _ in candidates[:5]
            ]
            if candidates:
                _, best_toks, parent_idx = candidates[0]
            else:
                best_toks, parent_idx = [BOS_IDX, EOS_IDX], 0
            chosen_id = best_toks[-1]

            # Decoder layer 0, batch element 0 of the parent beam's
            # cross-attention — presumably (heads, prefix_len, T_src).
            parent_cross = beam_cross[parent_idx]
            cross_attn = parent_cross[0].tolist() if parent_cross is not None else None

            step_logs.append({
                "step": step,
                "tokens_so_far": tgt_v.tokens(beams[parent_idx][1]),
                "top5": top5,
                "chosen_token": tgt_v.idx2token.get(chosen_id, "?"),
                "chosen_id": chosen_id,
                "chosen_prob": top5[0]["prob"] if top5 else 0.0,
                "cross_attn": cross_attn,
            })
            if log is not None and step < 3:
                log.log(f"BEAM_step{step}_top_candidates", top5,
                        formula="score = Σ log P(token_i | prev, src)",
                        note=f"Step {step}: top beam candidates. Best: '{top5[0]['token'] if top5 else '?'}'")

            # Prune into next beams: finished hypotheses are set aside with a
            # length-normalised score; the rest refill up to beam_size beams.
            beams = []
            for sc, toks, _ in candidates[:beam_size * 2]:
                if toks[-1] == EOS_IDX:
                    completed.append((sc / len(toks), toks))
                elif len(beams) < beam_size:
                    beams.append((sc, toks))
            if len(completed) >= beam_size:
                break

    if completed:
        best = max(completed, key=lambda x: x[0])
        return best[1], step_logs
    elif beams:
        return beams[0][1] + [EOS_IDX], step_logs
    else:
        return [BOS_IDX, EOS_IDX], step_logs
| # ───────────────────────────────────────────── | |
| # Full inference pipeline with visualization | |
| # ───────────────────────────────────────────── | |
def visualize_inference(
    model: Transformer,
    en_sentence: str,
    device: str = "cpu",
    decode_method: str = "greedy",
    beam_size: int = 3,
    max_len: int = 20,
) -> Dict:
    """Translate one English sentence and collect everything needed to visualise it.

    Args:
        model: The Transformer used for decoding.
        en_sentence: English input sentence, encoded with the source vocab.
        device: Device on which the source tensor is created.
        decode_method: ``"beam"`` selects beam search; any other value
            falls back to greedy decoding.
        beam_size: Beam width used when ``decode_method == "beam"``.
        max_len: Maximum number of decode steps for either method.

    Returns:
        Dict with the input sentence, decoded translation, output tokens
        and ids, source tokens, per-step decode logs, the serialised
        calculation log (``CalcLog.to_dict()``), and the decode method.
    """
    src_v, tgt_v = get_vocabs()
    log = CalcLog()

    src_ids = src_v.encode(en_sentence)
    log.log("INFERENCE_TOKENIZATION", {
        "sentence": en_sentence,
        "tokens": en_sentence.lower().split(),
        "ids": src_ids,
    }, formula="word → vocab_id lookup",
       note="No ground-truth Bengali needed — model generates from scratch")

    src = torch.tensor([src_ids], dtype=torch.long, device=device)

    if decode_method == "beam":
        output_ids, step_logs = beam_search(model, src, beam_size=beam_size,
                                            max_len=max_len, device=device, log=log)
        log.log("BEAM_SEARCH_complete", {
            "method": f"beam search (beam={beam_size})",
            "note": "Explores multiple hypotheses simultaneously — generally better quality"
        })
    else:
        output_ids, step_logs = greedy_decode(model, src, max_len=max_len,
                                              device=device, log=log)
        log.log("GREEDY_complete", {
            "method": "greedy decoding",
            "note": "Always picks highest probability token — fast but can miss optimal sequences"
        })

    translation = tgt_v.decode(output_ids)
    output_tokens = tgt_v.tokens(output_ids)

    log.log("FINAL_TRANSLATION", {
        "input": en_sentence,
        "output_ids": output_ids,
        "output_tokens": output_tokens,
        "translation": translation,
    }, note="Complete English→Bengali translation")

    return {
        "en_sentence": en_sentence,
        "translation": translation,
        "output_tokens": output_tokens,
        "output_ids": output_ids,
        "src_tokens": src_v.tokens(src_ids),
        "step_logs": step_logs,
        "calc_log": log.to_dict(),
        "decode_method": decode_method,
    }