"""
Vibe Coding Module for MiniMind Max2

Fill-in-the-Middle (FIM) and intelligent code completion.
"""

from dataclasses import dataclass, field
from typing import List, Optional, Dict, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import json
import random


@dataclass
class CodeCompletionConfig:
    """Configuration for code completion and FIM."""

    # FIM sentinel tokens.
    fim_prefix_token: str = "<fim_prefix>"
    fim_middle_token: str = "<fim_middle>"
    fim_suffix_token: str = "<fim_suffix>"
    fim_pad_token: str = "<fim_pad>"

    # Optional code-region delimiters.
    code_start_token: str = "<code>"
    code_end_token: str = "</code>"

    # Fraction of training examples converted to FIM, and, of those,
    # the fraction using SPM (suffix-first) ordering instead of PSM.
    fim_rate: float = 0.5
    fim_spm_rate: float = 0.5

    # Context budgets (in tokens) for each FIM segment.
    max_prefix_tokens: int = 4096
    max_suffix_tokens: int = 2048
    max_middle_tokens: int = 1024

    supported_languages: List[str] = field(default_factory=lambda: [
        "python", "javascript", "typescript", "rust", "go", "java", "cpp", "c"
    ])

    # Syntax-checking options (tree-sitter integration is optional).
    enforce_syntax: bool = True
    use_tree_sitter: bool = False


class FIMTokenizer:
    """Handle Fill-in-the-Middle tokenization."""

    def __init__(self, config: CodeCompletionConfig):
        self.config = config

    def create_fim_example(
        self,
        code: str,
        split_point: Optional[int] = None,
        mode: str = "PSM",
    ) -> Tuple[str, str]:
        """
        Create a FIM training example from code.

        Args:
            code: Full code string
            split_point: Where to split (random if None)
            mode: "PSM" (Prefix-Suffix-Middle) or "SPM" (Suffix-Prefix-Middle)

        Returns:
            Tuple of (fim_input, target_middle)
        """
        if split_point is None:
            # Pick a split somewhere in the middle half of the file.
            split_point = random.randint(
                len(code) // 4,
                3 * len(code) // 4,
            )

        # Snap the split to the next newline so segments align on line boundaries.
        while split_point < len(code) and code[split_point] != '\n':
            split_point += 1

        # The middle span covers 50-500 characters, rounded up to a newline.
        middle_start = split_point
        middle_end = min(
            middle_start + random.randint(50, 500),
            len(code),
        )
        while middle_end < len(code) and code[middle_end] != '\n':
            middle_end += 1

        prefix = code[:middle_start]
        middle = code[middle_start:middle_end]
        suffix = code[middle_end:]

        cfg = self.config
        if mode == "PSM":
            # Prefix-Suffix-Middle: the canonical FIM ordering.
            fim_input = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}"
        else:
            # Suffix-Prefix-Middle: the suffix is presented first.
            fim_input = f"{cfg.fim_suffix_token}{suffix}{cfg.fim_prefix_token}{prefix}{cfg.fim_middle_token}"

        return fim_input, middle

    def format_completion_prompt(
        self,
        prefix: str,
        suffix: str = "",
        language: str = "python",  # currently unused; reserved for language-specific prompting
    ) -> str:
        """Format a completion prompt (FIM when a suffix is available)."""
        cfg = self.config
        if suffix:
            # FIM prompt: the model fills the gap between prefix and suffix.
            prompt = f"{cfg.fim_prefix_token}{prefix}{cfg.fim_suffix_token}{suffix}{cfg.fim_middle_token}"
        else:
            # Plain left-to-right completion.
            prompt = prefix
        return prompt


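# Illustrative sketch: what a PSM-formatted FIM example looks like with the
# default config. `_fim_demo` is a hypothetical helper for this sketch only,
# not part of the module's public API.
def _fim_demo() -> None:
    tok = FIMTokenizer(CodeCompletionConfig())
    code = "def add(a, b):\n    return a + b\n\nprint(add(1, 2))\n"
    fim_input, middle = tok.create_fim_example(code, split_point=0, mode="PSM")
    # fim_input == "<fim_prefix>" + prefix + "<fim_suffix>" + suffix + "<fim_middle>";
    # training teaches the model to emit `middle` after <fim_middle>.
    assert fim_input.startswith("<fim_prefix>")

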
class CodeProcessor:
    """Process code for training and inference."""

    # Per-language regexes; compile with re.MULTILINE so `$` matches line ends.
    LANGUAGE_PATTERNS = {
        "python": {
            "comment": r"#.*$",
            "docstring": r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'',
            "function": r"def\s+(\w+)\s*\(",
            "class": r"class\s+(\w+)\s*[:\(]",
        },
        "javascript": {
            "comment": r"//.*$|/\*[\s\S]*?\*/",
            "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
            "class": r"class\s+(\w+)",
        },
        "typescript": {
            "comment": r"//.*$|/\*[\s\S]*?\*/",
            "function": r"function\s+(\w+)|(\w+)\s*=\s*(?:async\s+)?(?:\([^)]*\)|[^=])\s*=>",
            "class": r"class\s+(\w+)",
            "interface": r"interface\s+(\w+)",
        },
        "rust": {
            "comment": r"//.*$|/\*[\s\S]*?\*/",
            "function": r"fn\s+(\w+)",
            "struct": r"struct\s+(\w+)",
            "impl": r"impl\s+(\w+)",
        },
    }

    @classmethod
    def detect_language(cls, code: str, filename: Optional[str] = None) -> str:
        """Detect programming language from code or filename."""
        if filename:
            ext_map = {
                ".py": "python",
                ".js": "javascript",
                ".ts": "typescript",
                ".tsx": "typescript",
                ".rs": "rust",
                ".go": "go",
                ".java": "java",
                ".cpp": "cpp",
                ".c": "c",
            }
            for ext, lang in ext_map.items():
                if filename.endswith(ext):
                    return lang

        # Fall back to keyword heuristics on the source text.
        if "def " in code and "import " in code:
            return "python"
        if "function " in code or "const " in code:
            return "javascript"
        if "fn " in code and "let " in code:
            return "rust"

        # Default when nothing matches.
        return "python"

    @classmethod
    def extract_context(
        cls,
        code: str,
        cursor_position: int,
        context_lines: int = 50,
    ) -> Tuple[str, str]:
        """Extract prefix and suffix around a character-offset cursor position."""
        lines = code.split('\n')

        # Locate the line and column containing the cursor; default to EOF
        # if the cursor is past the end of the buffer.
        current_pos = 0
        cursor_line = len(lines) - 1
        cursor_col = len(lines[-1])
        for i, line in enumerate(lines):
            if current_pos + len(line) + 1 > cursor_position:
                cursor_line = i
                cursor_col = cursor_position - current_pos
                break
            current_pos += len(line) + 1

        # Window of `context_lines` lines on each side of the cursor.
        start_line = max(0, cursor_line - context_lines)
        end_line = min(len(lines), cursor_line + context_lines)

        # Split the cursor line at the cursor column so no text is lost.
        prefix_lines = lines[start_line:cursor_line] + [lines[cursor_line][:cursor_col]]
        suffix_lines = [lines[cursor_line][cursor_col:]] + lines[cursor_line + 1:end_line]

        prefix = '\n'.join(prefix_lines)
        suffix = '\n'.join(suffix_lines)

        return prefix, suffix


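# Illustrative sketch: splitting a buffer at a character-offset cursor.
# `_context_demo` is a hypothetical helper for this sketch only.
def _context_demo() -> None:
    source = "import math\n\ndef area(r):\n    return math.pi * r \n"
    cursor = source.index("math.pi") + len("math.pi")  # cursor right after `math.pi`
    prefix, suffix = CodeProcessor.extract_context(source, cursor, context_lines=10)
    # prefix ends with "    return math.pi"; suffix begins with " * r ".
    assert prefix.endswith("math.pi")
    assert CodeProcessor.detect_language(source, filename="geometry.py") == "python"

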
class FIMModule(nn.Module):
    """
    Fill-in-the-Middle module for code completion.
    Enables intelligent middle-of-file completion.
    """

    def __init__(self, config: CodeCompletionConfig, hidden_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size

        # Embedding for the three FIM segment types (prefix/middle/suffix).
        self.fim_position_embed = nn.Embedding(3, hidden_size)

        # Fuses pooled prefix and suffix representations into one context vector.
        self.context_combiner = nn.Sequential(
            nn.Linear(hidden_size * 2, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # Scalar head estimating completion quality in [0, 1].
        self.quality_predictor = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.GELU(),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid(),
        )

        self.tokenizer = FIMTokenizer(config)
        self.processor = CodeProcessor()

    def forward(
        self,
        hidden_states: torch.Tensor,
        fim_positions: Optional[torch.Tensor] = None,
        prefix_mask: Optional[torch.Tensor] = None,
        suffix_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Process hidden states with FIM awareness.

        Args:
            hidden_states: [batch, seq_len, hidden_size]
            fim_positions: Segment type per token (0=prefix, 1=middle, 2=suffix)
            prefix_mask: Boolean mask for prefix tokens [batch, seq_len]
            suffix_mask: Boolean mask for suffix tokens [batch, seq_len]

        Returns:
            Enhanced hidden states and a metrics dict
        """
        batch_size, seq_len, _ = hidden_states.shape

        # Add segment-type embeddings so the model can tell FIM regions apart.
        if fim_positions is not None:
            pos_embed = self.fim_position_embed(fim_positions)
            hidden_states = hidden_states + pos_embed

        if prefix_mask is not None and suffix_mask is not None:
            # Cast defensively so the bitwise ops below work on any mask dtype.
            prefix_mask = prefix_mask.bool()
            suffix_mask = suffix_mask.bool()

            # Mean-pool prefix and suffix tokens (clamp avoids division by zero).
            prefix_repr = (hidden_states * prefix_mask.unsqueeze(-1)).sum(1) / prefix_mask.sum(1, keepdim=True).clamp(min=1)
            suffix_repr = (hidden_states * suffix_mask.unsqueeze(-1)).sum(1) / suffix_mask.sum(1, keepdim=True).clamp(min=1)

            # Fuse both sides into a single bidirectional context vector.
            context = self.context_combiner(torch.cat([prefix_repr, suffix_repr], dim=-1))

            # Inject the fused context only into middle (to-be-filled) positions.
            middle_mask = ~(prefix_mask | suffix_mask)
            if middle_mask.any():
                context_expanded = context.unsqueeze(1).expand(-1, seq_len, -1)
                hidden_states = hidden_states + context_expanded * middle_mask.unsqueeze(-1)

        # Sequence-level completion-quality estimate.
        quality = self.quality_predictor(hidden_states.mean(1))

        metrics = {
            "completion_quality": quality,
        }

        return hidden_states, metrics


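# Illustrative sketch: running FIMModule on random activations. The shapes
# are assumptions for the sketch; `_fim_module_demo` is hypothetical.
def _fim_module_demo() -> None:
    module = FIMModule(CodeCompletionConfig(), hidden_size=64)
    h = torch.randn(2, 16, 64)  # [batch, seq_len, hidden]
    prefix_mask = torch.zeros(2, 16, dtype=torch.bool)
    suffix_mask = torch.zeros(2, 16, dtype=torch.bool)
    prefix_mask[:, :6] = True   # first 6 tokens are prefix
    suffix_mask[:, 10:] = True  # last 6 tokens are suffix; 6:10 is the middle
    out, metrics = module(h, prefix_mask=prefix_mask, suffix_mask=suffix_mask)
    assert out.shape == h.shape
    assert metrics["completion_quality"].shape == (2, 1)

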
class VibeCoder:
    """
    High-level interface for "vibe coding" - intuitive code assistance.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        config: Optional[CodeCompletionConfig] = None,
        device: str = "cuda",
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config or CodeCompletionConfig()
        self.device = device

        # Infer the hidden size from the model config when available.
        if hasattr(model, 'config'):
            hidden_size = model.config.hidden_size
        else:
            hidden_size = 1024  # fallback default

        self.fim_module = FIMModule(self.config, hidden_size).to(device)
        self.fim_tokenizer = FIMTokenizer(self.config)

    def complete(
        self,
        prefix: str,
        suffix: str = "",
        max_tokens: int = 100,
        temperature: float = 0.2,
        stop_tokens: Optional[List[str]] = None,
    ) -> str:
        """
        Complete code given a prefix and an optional suffix.

        Args:
            prefix: Code before the cursor
            suffix: Code after the cursor (enables FIM)
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature (0 disables sampling)
            stop_tokens: Strings at which to truncate the completion

        Returns:
            Generated code completion
        """
        self.model.eval()

        # Build a FIM prompt when a suffix is available, a plain prompt otherwise.
        prompt = self.fim_tokenizer.format_completion_prompt(prefix, suffix)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        with torch.no_grad():
            generated = self.model.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                do_sample=temperature > 0,
                top_p=0.95,
            )

        # Decode only the newly generated tokens.
        completion = self.tokenizer.decode(
            generated[0][input_ids.shape[1]:],
            skip_special_tokens=True,
        )

        # Truncate at the first occurrence of any stop token.
        if stop_tokens:
            for stop in stop_tokens:
                if stop in completion:
                    completion = completion[:completion.index(stop)]

        return completion

    def complete_function(
        self,
        signature: str,
        context: str = "",
        language: str = "python",
    ) -> str:
        """Complete a function given its signature."""
        if language == "python":
            # Four-space indent primes the model to write the body.
            prompt = f"{context}\n\n{signature}\n    "
        else:
            # Brace-delimited languages (javascript, typescript, rust, go, ...).
            prompt = f"{context}\n\n{signature} {{\n    "

        return self.complete(prompt, max_tokens=500)

    def explain_code(self, code: str, language: str = "python") -> str:
        """Generate an explanation for code."""
        prompt = f"# Explain the following {language} code:\n```{language}\n{code}\n```\n\n# Explanation:\n"
        return self.complete(prompt, max_tokens=300, temperature=0.3)

    def refactor(
        self,
        code: str,
        instruction: str = "Refactor this code to be cleaner and more efficient",
        language: str = "python",
    ) -> str:
        """Refactor code based on an instruction."""
        prompt = f"""# Original code:
```{language}
{code}
```

# Task: {instruction}

# Refactored code:
```{language}
"""
        completion = self.complete(prompt, max_tokens=1000, temperature=0.2)

        # Keep only the code inside the fenced block.
        if "```" in completion:
            completion = completion[:completion.index("```")]

        return completion

    def fix_bug(self, code: str, error: str = "", language: str = "python") -> str:
        """Fix a bug in code."""
        prompt = f"""# Buggy code:
```{language}
{code}
```

# Error: {error if error else "Unknown bug"}

# Fixed code:
```{language}
"""
        completion = self.complete(prompt, max_tokens=1000, temperature=0.1)

        # Keep only the code inside the fenced block.
        if "```" in completion:
            completion = completion[:completion.index("```")]

        return completion


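# Illustrative sketch: FIM completion with a Hugging Face-style causal LM.
# The checkpoint name is a placeholder assumption; any model/tokenizer pair
# exposing `generate`/`encode`/`decode` (and trained with the FIM sentinel
# tokens) should work.
def _vibe_coder_demo() -> None:
    from transformers import AutoModelForCausalLM, AutoTokenizer  # assumed dependency

    name = "your-org/your-code-model"  # hypothetical checkpoint
    model = AutoModelForCausalLM.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)

    coder = VibeCoder(model, tokenizer, device="cpu")
    body = coder.complete(
        prefix="def fibonacci(n):\n    ",
        suffix="\n\nprint(fibonacci(10))",
        max_tokens=64,
        stop_tokens=["\ndef "],  # stop before the next top-level function
    )
    print(body)

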
class CodeDataset(Dataset):
    """Dataset for code training with FIM."""

    def __init__(
        self,
        data_path: str,
        tokenizer,
        config: CodeCompletionConfig,
        max_length: int = 2048,
    ):
        self.tokenizer = tokenizer
        self.config = config
        self.max_length = max_length
        self.fim_tokenizer = FIMTokenizer(config)

        # Load one JSON object per line (JSONL).
        self.examples = []
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.examples.append(json.loads(line))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]
        code = example.get("code", example.get("content", ""))

        # Convert a fim_rate fraction of sufficiently long examples to FIM.
        use_fim = random.random() < self.config.fim_rate

        if use_fim and len(code) > 100:
            mode = "SPM" if random.random() < self.config.fim_spm_rate else "PSM"
            fim_input, target = self.fim_tokenizer.create_fim_example(code, mode=mode)
            text = fim_input + target
        else:
            # Plain left-to-right language modeling.
            text = code

        encodings = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)

        # Mask padding positions out of the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


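# Illustrative sketch: wiring the dataset into a DataLoader. The path is a
# placeholder assumption, and the tokenizer must be callable in the Hugging
# Face style (`tokenizer(text, max_length=..., ...)`).
def _dataset_demo(tokenizer) -> None:
    from torch.utils.data import DataLoader

    dataset = CodeDataset(
        data_path="data/code_train.jsonl",  # hypothetical path
        tokenizer=tokenizer,
        config=CodeCompletionConfig(),
        max_length=512,
    )
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    assert batch["input_ids"].shape == (4, 512)
    # Padding positions are excluded from the loss.
    assert (batch["labels"][batch["attention_mask"] == 0] == -100).all()

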
def prepare_code_dataset(
    raw_data_path: str,
    output_path: str,
    languages: Optional[List[str]] = None,
) -> int:
    """Filter a raw JSONL code dump into a training-ready JSONL file."""
    languages = languages or ["python", "javascript", "typescript", "rust"]
    processed = 0

    with open(raw_data_path, 'r', encoding='utf-8') as fin, \
            open(output_path, 'w', encoding='utf-8') as fout:

        for line in fin:
            if not line.strip():
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue  # skip malformed lines rather than aborting

            code = data.get("code", data.get("content", ""))
            language = data.get("language", "")

            # Keep only the requested languages.
            if languages and language not in languages:
                continue

            # Drop trivially short files and very large ones.
            if len(code) < 50 or len(code) > 100000:
                continue

            processed_example = {
                "code": code,
                "language": language,
            }

            fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n")
            processed += 1

    return processed


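# Illustrative sketch: round-tripping a tiny JSONL file through
# prepare_code_dataset. File names are hypothetical; each input line needs
# "code" (or "content") plus "language".
def _prepare_demo() -> None:
    sample = {
        "content": "import math\n\ndef area(r):\n    return math.pi * r ** 2\n",
        "language": "python",
    }
    with open("raw_sample.jsonl", "w", encoding="utf-8") as f:
        f.write(json.dumps(sample) + "\n")
    kept = prepare_code_dataset("raw_sample.jsonl", "clean_sample.jsonl", languages=["python"])
    assert kept == 1  # passes the language and 50..100000-character length filters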