# Source: mindi-backup/src/inference_engine/inference_engine.py
# (GitHub UI residue removed: repo "Mindigenous", commit 53f0cc2,
#  "Initial full project backup with Git LFS".)
"""
Component 7: Inference engine for local code generation.
Features:
- Deterministic low-temperature greedy mode.
- Stop rules for clean function completion.
- Syntax-aware retry with up to 3 attempts.
"""
from __future__ import annotations
import ast
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
from src.evaluation_system.code_eval import restore_code_from_structured
from src.model_architecture.code_transformer import CodeTransformerLM
from src.tokenizer.code_tokenizer import CodeTokenizer
@dataclass
class DecodingConfig:
    """Hyper-parameters controlling generation length, sampling, and retries."""

    # Hard cap on tokens generated per attempt.
    max_new_tokens: int = 300
    # Attempt 1: temperature <= 0 selects deterministic greedy (argmax) decoding.
    greedy_temperature: float = 0.0
    # Attempt 2: mild nucleus (top-p) sampling.
    retry2_temperature: float = 0.25
    retry2_top_p: float = 0.85
    # Attempt 3: slightly warmer, more exploratory nucleus sampling.
    retry3_temperature: float = 0.35
    retry3_top_p: float = 0.90
    # Maximum number of generation attempts (also capped by the attempt table).
    max_retries: int = 3
    # Skip the decode + parse early-stop check until this many tokens exist.
    min_tokens_before_stop_check: int = 64
    # Stop early only when the function body has at least this many statements.
    min_function_body_statements: int = 2
class InferenceEngine:
    """Generate code with a local LM, retrying with warmer sampling on syntax errors.

    Attempt 1 is deterministic greedy decoding; attempts 2 and 3 use nucleus
    (top-p) sampling at increasing temperatures. Generation stops on the EOS
    token or as soon as the decoded text contains a complete, parseable
    function with a non-trivial body.
    """

    def __init__(self, model: CodeTransformerLM, tokenizer: CodeTokenizer, device: torch.device) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        # Inference only: put the model into eval mode (disables dropout etc.).
        self.model.eval()

    @staticmethod
    def _last_top_level_function(code: str) -> Optional[ast.AST]:
        """Return the last top-level function defined in *code*, or None.

        None is returned when the code does not parse or defines no function.
        Async functions count as functions (generalizes the previous
        FunctionDef-only check).
        """
        try:
            tree = ast.parse(code)
        except Exception:
            # ast.parse raises SyntaxError, and ValueError for e.g. null bytes.
            return None
        funcs = [
            node
            for node in tree.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
        ]
        return funcs[-1] if funcs else None

    @staticmethod
    def _syntax_ok_python(code: str) -> bool:
        """Return True iff *code* is syntactically valid Python."""
        try:
            ast.parse(code)
        except Exception:
            return False
        return True

    @staticmethod
    def _function_completion_score(code: str) -> int:
        """Score how complete/usable the generated function is.

        The score is the number of statements in the last function's body,
        plus 2 when the function contains a return statement anywhere.
        Unparseable code, or code without a function, scores 0.
        """
        fn = InferenceEngine._last_top_level_function(code)
        if fn is None:
            return 0
        has_return = any(isinstance(node, ast.Return) for node in ast.walk(fn))
        return len(fn.body) + (2 if has_return else 0)

    def _looks_complete_function(self, code: str, min_body_statements: int) -> bool:
        """True when *code* holds a parseable function with a non-trivial body."""
        if "def " not in code:
            # Cheap substring pre-filter before paying for a full parse.
            return False
        fn = self._last_top_level_function(code)
        return fn is not None and len(fn.body) >= min_body_statements

    def _sample_next(
        self,
        logits: torch.Tensor,
        temperature: float,
        top_p: float,
    ) -> torch.Tensor:
        """Select the next token ids from *logits* of shape (batch, vocab).

        temperature <= 0 yields deterministic argmax decoding; otherwise
        nucleus (top-p) sampling is applied after temperature scaling.
        Returns a (batch, 1) long tensor of token ids.
        """
        if temperature <= 0:
            # Greedy mode: deterministic output.
            return torch.argmax(logits, dim=-1, keepdim=True)
        probs = torch.softmax(logits / temperature, dim=-1)
        sorted_probs, sorted_idx = torch.sort(probs, descending=True)
        cumulative = torch.cumsum(sorted_probs, dim=-1)
        # Zero out tokens past the nucleus. The mask is shifted right one
        # position so the token that crosses the top_p threshold is kept.
        cutoff = cumulative > top_p
        cutoff[..., 1:] = cutoff[..., :-1].clone()
        cutoff[..., 0] = False
        sorted_probs[cutoff] = 0.0
        # Renormalize; the clamp guards against an all-zero row.
        denom = sorted_probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        sampled = torch.multinomial(sorted_probs / denom, num_samples=1)
        # Map positions in the sorted order back to vocabulary ids.
        return sorted_idx.gather(-1, sampled)

    @torch.no_grad()
    def _generate_once(
        self,
        prompt: str,
        language: str,
        max_new_tokens: int,
        temperature: float,
        top_p: float,
        min_tokens_before_stop_check: int,
        min_function_body_statements: int,
    ) -> Dict[str, object]:
        """Run one decoding attempt and return the code plus metadata.

        Stops on the EOS token, on reaching *max_new_tokens*, or — checked
        every 12 tokens once *min_tokens_before_stop_check* tokens exist —
        when the decoded text already contains a complete, parseable
        function with a non-trivial body.
        """
        prompt_text = self.tokenizer.format_training_sample(prompt=prompt, code="", language=language)
        # The training template inserts structured newline markers; strip
        # them from the bare prompt.
        prompt_text = prompt_text.replace(" <NL>", "").strip()
        ids = self.tokenizer.encode(prompt_text)
        eos_id = self.tokenizer.special_token_ids.get("<EOS>")
        # Drop a trailing EOS so generation can continue past the prompt.
        if eos_id is not None and len(ids) > 1 and ids[-1] == int(eos_id):
            ids = ids[:-1]
        input_ids = torch.tensor([ids], dtype=torch.long, device=self.device)
        generated_steps = 0
        for _ in range(max_new_tokens):
            out = self.model(input_ids=input_ids)
            logits = out["logits"][:, -1, :]
            next_id = self._sample_next(logits, temperature=temperature, top_p=top_p)
            input_ids = torch.cat([input_ids, next_id], dim=1)
            generated_steps += 1
            # Primary stop rule: the EOS token.
            if eos_id is not None and int(next_id.item()) == int(eos_id):
                break
            # Secondary stop rule: a complete parseable function. Decoding
            # and parsing is expensive, so only check every 12th token.
            if generated_steps >= min_tokens_before_stop_check and generated_steps % 12 == 0:
                decoded = self.tokenizer.decode(input_ids[0].tolist())
                code = restore_code_from_structured(decoded)
                if self._looks_complete_function(code, min_body_statements=min_function_body_statements):
                    break
        decoded = self.tokenizer.decode(input_ids[0].tolist())
        code = restore_code_from_structured(decoded)
        # Syntax/completion metrics are only meaningful for Python output.
        syntax_ok = self._syntax_ok_python(code) if language == "python" else True
        completion_score = self._function_completion_score(code) if language == "python" else 0
        return {
            "code": code,
            "syntax_ok": syntax_ok,
            "generated_tokens": generated_steps,
            "temperature": temperature,
            "top_p": top_p,
            "completion_score": completion_score,
        }

    @torch.no_grad()
    def generate_with_retry(
        self,
        prompt: str,
        language: str = "python",
        cfg: Optional[DecodingConfig] = None,
    ) -> Dict[str, object]:
        """Generate code, retrying with warmer sampling while syntax is invalid.

        Returns a dict with the chosen result under "final", every attempt's
        result under "attempts", and "used_retry" indicating whether more
        than one attempt was needed.
        """
        cfg = cfg or DecodingConfig()
        # Attempt ladder: deterministic greedy first, then two nucleus-sampled
        # retries at progressively higher temperature.
        attempts: List[Tuple[float, float]] = [
            (cfg.greedy_temperature, 1.0),
            (cfg.retry2_temperature, cfg.retry2_top_p),
            (cfg.retry3_temperature, cfg.retry3_top_p),
        ]
        results: List[Dict[str, object]] = []
        for i in range(min(cfg.max_retries, len(attempts))):
            temp, top_p = attempts[i]
            res = self._generate_once(
                prompt=prompt,
                language=language,
                max_new_tokens=cfg.max_new_tokens,
                temperature=temp,
                top_p=top_p,
                min_tokens_before_stop_check=cfg.min_tokens_before_stop_check,
                min_function_body_statements=cfg.min_function_body_statements,
            )
            res["attempt"] = i + 1
            results.append(res)
            # Syntax-aware retry: stop as soon as the output parses.
            if bool(res["syntax_ok"]):
                return {
                    "final": res,
                    "attempts": results,
                    "used_retry": i > 0,
                }
        # All attempts failed syntax: pick the highest completion score,
        # breaking ties with the longest generation. max() keeps the first
        # maximal element, matching the previous stable-sort behavior.
        best = max(
            results,
            key=lambda r: (int(r.get("completion_score", 0)), int(r.get("generated_tokens", 0))),
        )
        return {
            "final": best,
            "attempts": results,
            "used_retry": True,
        }