# capabilities/coding.py
"""
Vibe Coding Module for MiniMind Max2
Fill-in-the-Middle (FIM) and intelligent code completion.
"""
import json
import random
import re
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset
@dataclass
class CodeCompletionConfig:
"""Configuration for code completion and FIM."""
# FIM tokens
fim_prefix_token: str = "<fim_prefix>"
fim_middle_token: str = "<fim_middle>"
fim_suffix_token: str = "<fim_suffix>"
fim_pad_token: str = "<fim_pad>"
# Code tokens
code_start_token: str = "<code>"
code_end_token: str = "</code>"
# FIM training settings
fim_rate: float = 0.5 # Probability of using FIM vs standard LM
fim_spm_rate: float = 0.5 # Suffix-Prefix-Middle vs Prefix-Suffix-Middle
# Context settings
max_prefix_tokens: int = 4096
max_suffix_tokens: int = 2048
max_middle_tokens: int = 1024
# Language support
supported_languages: List[str] = field(default_factory=lambda: [
"python", "javascript", "typescript", "rust", "go", "java", "cpp", "c"
])
# Code quality
enforce_syntax: bool = True
use_tree_sitter: bool = False # For syntax-aware completion
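
# Example (illustrative): a config biased toward FIM training.
#   config = CodeCompletionConfig(fim_rate=0.9, max_prefix_tokens=2048)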
class FIMTokenizer:
"""Handle Fill-in-the-Middle tokenization."""
def __init__(self, config: CodeCompletionConfig):
self.config = config
def create_fim_example(
self,
code: str,
split_point: Optional[int] = None,
mode: str = "PSM", # PSM or SPM
) -> Tuple[str, str]:
"""
Create a FIM training example from code.
Args:
code: Full code string
split_point: Where to split (random if None)
mode: PSM (Prefix-Suffix-Middle) or SPM (Suffix-Prefix-Middle)
Returns:
Tuple of (fim_input, target_middle)
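
        Example (illustrative; real split points are chosen at random):
            code = "a = 1\nb = 2\nc = 3"
            PSM input: "<fim_prefix>a = 1<fim_suffix>\nc = 3<fim_middle>"
            target:    "\nb = 2"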
"""
        if split_point is None:
            # Random split point
            split_point = random.randint(
                len(code) // 4,
                3 * len(code) // 4,
            )
        # Snap the split point forward to the end of its line
        split_point = max(0, min(split_point, len(code)))
        while split_point < len(code) and code[split_point] != '\n':
            split_point += 1
# Determine middle span
middle_start = split_point
middle_end = min(
middle_start + random.randint(50, 500),
len(code),
)
# Find end of middle span (end of line)
while middle_end < len(code) and code[middle_end] != '\n':
middle_end += 1
prefix = code[:middle_start]
middle = code[middle_start:middle_end]
suffix = code[middle_end:]
cfg = self.config
if mode == "PSM":
# Prefix-Suffix-Middle
fim_input = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}"
else:
# Suffix-Prefix-Middle
fim_input = f"{cfg.fim_suffix_token}{suffix}{cfg.fim_prefix_token}{prefix}{cfg.fim_middle_token}"
return fim_input, middle
def format_completion_prompt(
self,
prefix: str,
suffix: str = "",
language: str = "python",
) -> str:
"""Format a completion prompt."""
cfg = self.config
if suffix:
# FIM mode
prompt = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}"
else:
# Standard completion
prompt = prefix
return prompt
class CodeProcessor:
"""Process code for training and inference."""
# Language-specific patterns
LANGUAGE_PATTERNS = {
"python": {
"comment": r"#.*$",
"docstring": r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'',
"function": r"def\s+(\w+)\s*\(",
"class": r"class\s+(\w+)\s*[:\(]",
},
"javascript": {
"comment": r"//.*$|/\*[\s\S]*?\*/",
"function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
"class": r"class\s+(\w+)",
},
"typescript": {
"comment": r"//.*$|/\*[\s\S]*?\*/",
"function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
"class": r"class\s+(\w+)",
"interface": r"interface\s+(\w+)",
},
"rust": {
"comment": r"//.*$|/\*[\s\S]*?\*/",
"function": r"fn\s+(\w+)",
"struct": r"struct\s+(\w+)",
"impl": r"impl\s+(\w+)",
},
}
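    # Note: the single-line "comment" patterns use `$`, so apply them with
    # re.MULTILINE, e.g.:
    #   re.findall(LANGUAGE_PATTERNS["python"]["comment"], code, re.MULTILINE)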
@classmethod
def detect_language(cls, code: str, filename: Optional[str] = None) -> str:
"""Detect programming language from code or filename."""
if filename:
ext_map = {
".py": "python",
".js": "javascript",
".ts": "typescript",
".tsx": "typescript",
".rs": "rust",
".go": "go",
".java": "java",
".cpp": "cpp",
".c": "c",
}
for ext, lang in ext_map.items():
if filename.endswith(ext):
return lang
# Heuristic detection
if "def " in code and "import " in code:
return "python"
if "function " in code or "const " in code:
return "javascript"
if "fn " in code and "let " in code:
return "rust"
return "python" # Default
@classmethod
def extract_context(
cls,
code: str,
cursor_position: int,
context_lines: int = 50,
) -> Tuple[str, str]:
"""Extract prefix and suffix around cursor position."""
lines = code.split('\n')
        # Find the line and column of the cursor
        current_pos = 0
        cursor_line = len(lines) - 1
        for i, line in enumerate(lines):
            if current_pos + len(line) + 1 > cursor_position:
                cursor_line = i
                break
            current_pos += len(line) + 1
        column = max(0, min(cursor_position - current_pos, len(lines[cursor_line])))
        # Get context lines, splitting the cursor line at the cursor column
        # so its text is not dropped
        start_line = max(0, cursor_line - context_lines)
        end_line = min(len(lines), cursor_line + context_lines)
        prefix_lines = lines[start_line:cursor_line] + [lines[cursor_line][:column]]
        suffix_lines = [lines[cursor_line][column:]] + lines[cursor_line + 1:end_line]
        prefix = '\n'.join(prefix_lines)
        suffix = '\n'.join(suffix_lines)
        return prefix, suffix
class FIMModule(nn.Module):
"""
Fill-in-the-Middle module for code completion.
Enables intelligent middle-of-file completion.
"""
def __init__(self, config: CodeCompletionConfig, hidden_size: int):
super().__init__()
self.config = config
self.hidden_size = hidden_size
# FIM position embeddings
self.fim_position_embed = nn.Embedding(3, hidden_size) # prefix, middle, suffix
# Context combiner
self.context_combiner = nn.Sequential(
nn.Linear(hidden_size * 2, hidden_size),
nn.GELU(),
nn.Linear(hidden_size, hidden_size),
)
# Completion quality predictor
self.quality_predictor = nn.Sequential(
nn.Linear(hidden_size, hidden_size // 4),
nn.GELU(),
nn.Linear(hidden_size // 4, 1),
nn.Sigmoid(),
)
# Tokenizer helper
self.tokenizer = FIMTokenizer(config)
self.processor = CodeProcessor()
def forward(
self,
hidden_states: torch.Tensor,
fim_positions: Optional[torch.Tensor] = None,
prefix_mask: Optional[torch.Tensor] = None,
suffix_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Process hidden states with FIM awareness.
Args:
            hidden_states: [batch, seq_len, hidden_size]
            fim_positions: Long tensor [batch, seq_len] with a position type
                per token (0=prefix, 1=middle, 2=suffix)
            prefix_mask: Boolean mask [batch, seq_len] for prefix tokens
            suffix_mask: Boolean mask [batch, seq_len] for suffix tokens
Returns:
Enhanced hidden states and metrics
"""
batch_size, seq_len, _ = hidden_states.shape
# Add FIM position embeddings
if fim_positions is not None:
pos_embed = self.fim_position_embed(fim_positions)
hidden_states = hidden_states + pos_embed
# Combine context from prefix and suffix
if prefix_mask is not None and suffix_mask is not None:
# Average pool prefix and suffix representations
prefix_repr = (hidden_states * prefix_mask.unsqueeze(-1)).sum(1) / prefix_mask.sum(1, keepdim=True).clamp(min=1)
suffix_repr = (hidden_states * suffix_mask.unsqueeze(-1)).sum(1) / suffix_mask.sum(1, keepdim=True).clamp(min=1)
# Combine
context = self.context_combiner(torch.cat([prefix_repr, suffix_repr], dim=-1))
# Add context to middle tokens
middle_mask = ~(prefix_mask | suffix_mask)
if middle_mask.any():
context_expanded = context.unsqueeze(1).expand(-1, seq_len, -1)
hidden_states = hidden_states + context_expanded * middle_mask.unsqueeze(-1)
# Quality prediction
quality = self.quality_predictor(hidden_states.mean(1))
metrics = {
"completion_quality": quality,
}
return hidden_states, metrics
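
# Example FIMModule call (illustrative; masks must be boolean [batch, seq_len]):
#   fim = FIMModule(CodeCompletionConfig(), hidden_size=1024)
#   states, metrics = fim(hidden_states, prefix_mask=p_mask, suffix_mask=s_mask)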
class VibeCoder:
"""
High-level interface for "vibe coding" - intuitive code assistance.
"""
def __init__(
self,
model: nn.Module,
tokenizer,
config: Optional[CodeCompletionConfig] = None,
device: str = "cuda",
):
self.model = model
self.tokenizer = tokenizer
self.config = config or CodeCompletionConfig()
self.device = device
        # Infer hidden size from the model config, falling back to a default
        hidden_size = getattr(getattr(model, "config", None), "hidden_size", 1024)
self.fim_module = FIMModule(self.config, hidden_size).to(device)
self.fim_tokenizer = FIMTokenizer(self.config)
def complete(
self,
prefix: str,
suffix: str = "",
max_tokens: int = 100,
temperature: float = 0.2,
stop_tokens: Optional[List[str]] = None,
) -> str:
"""
Complete code given prefix and optional suffix.
Args:
prefix: Code before cursor
suffix: Code after cursor (for FIM)
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
            stop_tokens: Strings at which the completion is truncated
Returns:
Generated code completion
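
        Example (illustrative):
            coder.complete("def add(a, b):\n    return ", suffix="\n\nprint(add(2, 3))")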
"""
self.model.eval()
# Format prompt
prompt = self.fim_tokenizer.format_completion_prompt(prefix, suffix)
# Tokenize
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        # Generate (greedy decoding when temperature == 0)
        gen_kwargs = {"max_new_tokens": max_tokens}
        if temperature > 0:
            gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)
        else:
            gen_kwargs["do_sample"] = False
        with torch.no_grad():
            generated = self.model.generate(input_ids, **gen_kwargs)
# Decode
completion = self.tokenizer.decode(
generated[0][input_ids.shape[1]:],
skip_special_tokens=True,
)
# Stop at stop tokens
if stop_tokens:
for stop in stop_tokens:
if stop in completion:
completion = completion[:completion.index(stop)]
return completion
def complete_function(
self,
signature: str,
context: str = "",
language: str = "python",
) -> str:
"""Complete a function given its signature."""
        if language == "python":
            prompt = f"{context}\n\n{signature}\n    "
        else:
            # Brace-delimited languages (JavaScript, TypeScript, C-family, ...)
            prompt = f"{context}\n\n{signature} {{\n    "
return self.complete(prompt, max_tokens=500)
def explain_code(self, code: str, language: str = "python") -> str:
"""Generate explanation for code."""
prompt = f"# Explain the following {language} code:\n```{language}\n{code}\n```\n\n# Explanation:\n"
return self.complete(prompt, max_tokens=300, temperature=0.3)
def refactor(
self,
code: str,
instruction: str = "Refactor this code to be cleaner and more efficient",
language: str = "python",
) -> str:
"""Refactor code based on instruction."""
prompt = f"""# Original code:
```{language}
{code}
```
# Task: {instruction}
# Refactored code:
```{language}
"""
completion = self.complete(prompt, max_tokens=1000, temperature=0.2)
# Clean up
if "```" in completion:
completion = completion[:completion.index("```")]
return completion
def fix_bug(self, code: str, error: str = "", language: str = "python") -> str:
"""Fix a bug in code."""
prompt = f"""# Buggy code:
```{language}
{code}
```
# Error: {error if error else "Unknown bug"}
# Fixed code:
```{language}
"""
completion = self.complete(prompt, max_tokens=1000, temperature=0.1)
if "```" in completion:
completion = completion[:completion.index("```")]
return completion
class CodeDataset(Dataset):
"""Dataset for code training with FIM."""
def __init__(
self,
data_path: str,
tokenizer,
config: CodeCompletionConfig,
max_length: int = 2048,
):
self.tokenizer = tokenizer
self.config = config
self.max_length = max_length
self.fim_tokenizer = FIMTokenizer(config)
self.examples = []
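        # Each JSONL line is expected to look like (illustrative):
        #   {"code": "def f():\n    ...", "language": "python"}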
with open(data_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
self.examples.append(json.loads(line))
def __len__(self) -> int:
return len(self.examples)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
example = self.examples[idx]
code = example.get("code", example.get("content", ""))
language = example.get("language", "python")
# Decide FIM vs standard LM
use_fim = random.random() < self.config.fim_rate
if use_fim and len(code) > 100:
# Create FIM example
mode = "SPM" if random.random() < self.config.fim_spm_rate else "PSM"
fim_input, target = self.fim_tokenizer.create_fim_example(code, mode=mode)
text = fim_input + target
else:
# Standard LM
text = code
# Tokenize
encodings = self.tokenizer(
text,
max_length=self.max_length,
truncation=True,
padding="max_length",
return_tensors="pt",
)
        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # Mask padding out of the loss
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
def prepare_code_dataset(
raw_data_path: str,
output_path: str,
languages: Optional[List[str]] = None,
) -> int:
"""Prepare code dataset for training."""
languages = languages or ["python", "javascript", "typescript", "rust"]
processed = 0
with open(raw_data_path, 'r', encoding='utf-8') as fin, \
open(output_path, 'w', encoding='utf-8') as fout:
for line in fin:
if not line.strip():
continue
            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue
# Extract code and language
code = data.get("code", data.get("content", ""))
language = data.get("language", "")
# Filter by language
if languages and language not in languages:
continue
# Filter by quality (basic heuristics)
if len(code) < 50 or len(code) > 100000:
continue
processed_example = {
"code": code,
"language": language,
}
fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n")
processed += 1
return processed
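
# Minimal end-to-end sketch (illustrative; assumes an HF-style causal LM and a
# tokenizer that has the FIM special tokens in its vocabulary):
#
#   coder = VibeCoder(model, tokenizer, device="cuda")
#   completion = coder.complete(
#       prefix="def fibonacci(n):\n    ",
#       suffix="\n\nprint(fibonacci(10))",
#       max_tokens=64,
#   )
#   n_examples = prepare_code_dataset("raw.jsonl", "train.jsonl", languages=["python"])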