| """ |
| Component 7: Inference engine for local code generation. |
| |
| Features: |
| - Deterministic low-temperature greedy mode. |
| - Stop rules for clean function completion. |
| - Syntax-aware retry with up to 3 attempts. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import ast |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
|
|
| from src.evaluation_system.code_eval import restore_code_from_structured |
| from src.model_architecture.code_transformer import CodeTransformerLM |
| from src.tokenizer.code_tokenizer import CodeTokenizer |
|
|
|
|
@dataclass
class DecodingConfig:
    """Hyper-parameters for decoding and the syntax-aware retry ladder."""

    # Hard cap on tokens generated per attempt.
    max_new_tokens: int = 300

    # Attempt 1: temperature <= 0 means pure argmax (deterministic greedy).
    greedy_temperature: float = 0.0

    # Attempt 2: mild nucleus (top-p) sampling.
    retry2_temperature: float = 0.25
    retry2_top_p: float = 0.85

    # Attempt 3: slightly more exploratory nucleus sampling.
    retry3_temperature: float = 0.35
    retry3_top_p: float = 0.90
    # Number of attempts to run; effectively capped at 3 by the ladder above.
    max_retries: int = 3
    # Skip the (decode + parse) early-stop check until this many tokens have
    # been generated; after that it runs periodically, not every step.
    min_tokens_before_stop_check: int = 64

    # A generated function counts as "complete" once its body holds at least
    # this many statements.
    min_function_body_statements: int = 2
|
|
|
|
class InferenceEngine:
    """Syntax-aware autoregressive decoder for a local code-generation model.

    Wraps a trained ``CodeTransformerLM`` and its ``CodeTokenizer``.
    Decoding starts deterministic (greedy argmax) and escalates to nucleus
    sampling on retries when the output fails the Python syntax check
    (see :meth:`generate_with_retry`).
    """

    def __init__(self, model: CodeTransformerLM, tokenizer: CodeTokenizer, device: torch.device) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Inference only: put the model in eval mode (freezes dropout etc.).
        self.model.eval()

    @staticmethod
    def _syntax_ok_python(code: str) -> bool:
        """Return True iff *code* parses as valid Python source."""
        try:
            ast.parse(code)
            return True
        # Deliberately broad: ast.parse can raise SyntaxError, ValueError
        # (null bytes) or RecursionError on pathological generations; all of
        # those simply mean "not valid syntax" here.
        except Exception:
            return False

    @staticmethod
    def _function_completion_score(code: str) -> int:
        """Heuristic completeness score for generated code (higher = better).

        The score is the statement count of the *last* top-level function's
        body, plus 2 when the function contains a ``return`` anywhere.
        Unparsable code, or code with no top-level function, scores 0.
        """
        try:
            tree = ast.parse(code)
        except Exception:
            return 0
        funcs = [n for n in tree.body if isinstance(n, ast.FunctionDef)]
        if not funcs:
            return 0
        fn = funcs[-1]
        body_len = len(fn.body)
        # ast.walk also descends into nested defs, so a return inside a
        # nested helper counts as well.
        has_return = any(isinstance(n, ast.Return) for n in ast.walk(fn))
        return body_len + (2 if has_return else 0)

    def _looks_complete_function(self, code: str, min_body_statements: int) -> bool:
        """Early-stop rule: does *code* already contain a plausible function?

        True when the code parses and its last top-level function body has
        at least *min_body_statements* statements.
        """
        if "def " not in code:
            return False
        try:
            tree = ast.parse(code)
        except Exception:
            return False
        funcs = [n for n in tree.body if isinstance(n, ast.FunctionDef)]
        if not funcs:
            return False
        fn = funcs[-1]
        if len(fn.body) < min_body_statements:
            return False
        return True

    def _sample_next(
        self,
        logits: torch.Tensor,
        temperature: float,
        top_p: float,
    ) -> torch.Tensor:
        """Pick the next token id from *logits* of shape ``(batch, vocab)``.

        ``temperature <= 0`` selects the argmax (deterministic greedy);
        otherwise nucleus (top-p) sampling after temperature scaling.
        Returns a ``(batch, 1)`` long tensor of token ids.
        """
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        logits = logits / temperature
        probs = torch.softmax(logits, dim=-1)

        # Nucleus filter: keep the smallest prefix of sorted probabilities
        # whose cumulative mass exceeds top_p.
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        cutoff = cumulative > top_p
        # Shift the mask right so the token that crosses the threshold is
        # kept, and always keep the highest-probability token.
        cutoff[..., 1:] = cutoff[..., :-1].clone()
        cutoff[..., 0] = False
        sorted_probs[cutoff] = 0.0
        # clamp_min avoids 0/0 if rounding ever zeroes the whole row.
        denom = sorted_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        sorted_probs = sorted_probs / denom
        sampled = torch.multinomial(sorted_probs, num_samples=1)
        # Map the position in sorted order back to the vocabulary index.
        return sorted_idx.gather(-1, sampled)

    @torch.no_grad()
    def _generate_once(
        self,
        prompt: str,
        language: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        min_tokens_before_stop_check: int,
        min_function_body_statements: int,
    ) -> Dict[str, object]:
        """Run one decoding attempt and return a result record.

        The dict holds the restored ``code`` plus metadata: ``syntax_ok``,
        ``generated_tokens``, ``temperature``, ``top_p`` and
        ``completion_score`` (syntax/completion checks are Python-only).
        """
        # Reuse the training-time formatting with an empty code slot, then
        # strip the newline markers the empty slot leaves behind.
        prompt_text = self.tokenizer.format_training_sample(prompt=prompt, code="", language=language)
        prompt_text = prompt_text.replace(" <NL>", "").strip()

        ids = self.tokenizer.encode(prompt_text)
        eos_id = self.tokenizer.special_token_ids.get("<EOS>")

        # Drop a trailing <EOS> appended by encode(): generation must
        # continue past the prompt, not stop immediately.
        if eos_id is not None and len(ids) > 1 and ids[-1] == int(eos_id):
            ids = ids[:-1]

        input_ids = torch.tensor([ids], dtype=torch.long, device=self.device)

        generated_steps = 0
        for _ in range(max_new_tokens):
            out = self.model(input_ids=input_ids)
            logits = out["logits"][:, -1, :]
            next_id = self._sample_next(logits, temperature=temperature, top_p=top_p)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            generated_steps += 1

            # Hard stop on end-of-sequence.
            if eos_id is not None and int(next_id.item()) == int(eos_id):
                break

            # Soft stop: after a warm-up, periodically (every 12th step)
            # decode and check whether a complete function already exists.
            if generated_steps >= min_tokens_before_stop_check and (generated_steps % 12 == 0):
                decoded = self.tokenizer.decode(input_ids[0].tolist())
                code = restore_code_from_structured(decoded)
                if self._looks_complete_function(code, min_body_statements=min_function_body_statements):
                    break

        decoded = self.tokenizer.decode(input_ids[0].tolist())
        code = restore_code_from_structured(decoded)
        # Syntax / completion scoring only makes sense for Python output.
        syntax_ok = self._syntax_ok_python(code) if language == "python" else True
        completion_score = self._function_completion_score(code) if language == "python" else 0
        return {
            "code": code,
            "syntax_ok": syntax_ok,
            "generated_tokens": generated_steps,
            "temperature": temperature,
            "top_p": top_p,
            "completion_score": completion_score,
        }

    @torch.no_grad()
    def generate_with_retry(
        self,
        prompt: str,
        language: str = "python",
        cfg: Optional[DecodingConfig] = None,
    ) -> Dict[str, object]:
        """Generate code, retrying with higher temperature on bad syntax.

        Returns ``{"final": result, "attempts": [...], "used_retry": bool}``.
        The first syntactically valid attempt is returned immediately; if
        all attempts fail, the one with the highest completion score (ties
        broken by generated-token count) is chosen.
        """
        cfg = cfg or DecodingConfig()

        # Escalating (temperature, top_p) ladder; attempt 1 is greedy.
        attempts: List[Tuple[float, float]] = [
            (cfg.greedy_temperature, 1.0),
            (cfg.retry2_temperature, cfg.retry2_top_p),
            (cfg.retry3_temperature, cfg.retry3_top_p),
        ]

        # Fix: always run at least one attempt. A non-positive max_retries
        # previously left `results` empty and crashed with IndexError on the
        # best-attempt selection below.
        num_attempts = max(1, min(cfg.max_retries, len(attempts)))

        results = []
        for i in range(num_attempts):
            temp, top_p = attempts[i]
            res = self._generate_once(
                prompt=prompt,
                language=language,
                max_new_tokens=cfg.max_new_tokens,
                temperature=temp,
                top_p=top_p,
                min_tokens_before_stop_check=cfg.min_tokens_before_stop_check,
                min_function_body_statements=cfg.min_function_body_statements,
            )
            res["attempt"] = i + 1
            results.append(res)

            # First syntactically clean attempt wins outright.
            if bool(res["syntax_ok"]):
                return {
                    "final": res,
                    "attempts": results,
                    "used_retry": i > 0,
                }

        # No attempt parsed: fall back to the "most complete" generation.
        best = sorted(
            results,
            key=lambda x: (int(x.get("completion_score", 0)), int(x.get("generated_tokens", 0))),
            reverse=True,
        )[0]
        return {
            "final": best,
            "attempts": results,
            "used_retry": True,
        }
|
|