"""Training entry point for CodeMode embeddings.

Fine-tunes a code encoder on (anchor, positive, negative) triplets.
"""

import argparse
import json
import os

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer

from scripts.core.training.model import CodeEmbedder
from scripts.core.training.trainer import CodeTrainer


class RealCodeDataset(Dataset):
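    """Triplet dataset read from a JSONL file.

    Each non-empty line must be a JSON object with 'anchor', 'positive',
    and 'negative' string fields (the keys used in __getitem__ below).
    """
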
    def __init__(self, jsonl_path, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []

        print(f"Loading data from {jsonl_path}...")
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.data.append(json.loads(line))
        print(f"Loaded {len(self.data)} triplets.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Pad/truncate every field to a fixed length so samples stack into batches.
        def tokenize_text(text):
            return self.tokenizer(
                text,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=self.max_length
            )

        anchor = tokenize_text(item['anchor'])
        positive = tokenize_text(item['positive'])
        negative = tokenize_text(item['negative'])

        # squeeze(0) drops the batch dimension that return_tensors='pt' adds;
        # the DataLoader re-batches the flat tensors itself.
        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(0),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(0),
            'positive_input_ids': positive['input_ids'].squeeze(0),
            'positive_attention_mask': positive['attention_mask'].squeeze(0),
            'negative_input_ids': negative['input_ids'].squeeze(0),
            'negative_attention_mask': negative['attention_mask'].squeeze(0)
        }


class DummyCodeDataset(Dataset):
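    """Synthetic triplets for verifying the pipeline end to end without real data."""
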
    def __init__(self, tokenizer, size=100):
        self.tokenizer = tokenizer
        self.size = size
        self.data = [{
            "anchor": "def hello(): return 'world'",
            "positive": "def hi(): return 'earth'",
            "negative": "class Foo: pass"
        }] * size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        item = self.data[idx]

        # Same encoding as RealCodeDataset, with a shorter max length for speed.
        def tokenize_text(text):
            return self.tokenizer(
                text,
                return_tensors='pt',
                padding='max_length',
                truncation=True,
                max_length=128
            )

        anchor = tokenize_text(item['anchor'])
        positive = tokenize_text(item['positive'])
        negative = tokenize_text(item['negative'])

        return {
            'anchor_input_ids': anchor['input_ids'].squeeze(0),
            'anchor_attention_mask': anchor['attention_mask'].squeeze(0),
            'positive_input_ids': positive['input_ids'].squeeze(0),
            'positive_attention_mask': positive['attention_mask'].squeeze(0),
            'negative_input_ids': negative['input_ids'].squeeze(0),
            'negative_attention_mask': negative['attention_mask'].squeeze(0)
        }


def main():
    parser = argparse.ArgumentParser(description="Train CodeMode Embeddings")

    parser.add_argument("--model_name", type=str, default="microsoft/codebert-base", help="Hub model name")
    parser.add_argument("--data_path", type=str, required=False, help="Path to parsed chunks.jsonl")
    parser.add_argument("--output_dir", type=str, default="./output", help="Where to save checkpoints")
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--accumulation_steps", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--lr", type=float, default=2e-5)
    parser.add_argument("--dry_run", action="store_true", help="Run with dummy data for 1 epoch")

    args = parser.parse_args()

    print("Initializing Training Pipeline...")
    print(f"  Model: {args.model_name}")
    print(f"  Output: {args.output_dir}")
    print(f"  Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Honor --dry_run, and fall back to dummy data when no real data is available.
    if not args.dry_run and args.data_path and os.path.exists(args.data_path):
        train_dataset = RealCodeDataset(args.data_path, tokenizer)
    else:
        print("Dry run requested, no data path provided, or file missing. Using DUMMY data for verification.")
        train_dataset = DummyCodeDataset(tokenizer, size=100)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)

    model = CodeEmbedder(model_name_or_path=args.model_name)

    # --dry_run caps training at one epoch, matching the flag's help text.
    epochs = 1 if args.dry_run else args.epochs

    trainer = CodeTrainer(
        model=model,
        train_loader=train_loader,
        epochs=epochs,
        learning_rate=args.lr,
        accumulation_steps=args.accumulation_steps,
        mixed_precision=True,
        output_dir=args.output_dir
    )

    trainer.train()

    print("Training Complete.")


if __name__ == "__main__":
    main()