"""
Production Encoder LoRA Training for Stablebridge

Trains LoRA adapters on BAAI/bge-m3 for the US regulatory domain.
Implements the tech spec requirements:
- LoRA rank 16, alpha 32
- 8192-token context window
- MultipleNegativesRankingLoss (in-batch negatives)
- WandB logging, checkpointing, evaluation
- Model Hub push
"""

import argparse
import json
import os
import torch
import wandb
import yaml
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional
from dataclasses import dataclass, field

from transformers import (
    AutoTokenizer,
    AutoModel,
    get_cosine_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from tqdm import tqdm
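# Example invocations (script and config filenames assumed for illustration;
# the flags are defined in main() below):
#
#     python train_encoder_lora.py
#     python train_encoder_lora.py --batch-size 2 --epochs 1 --no-wandb --no-push
#     python train_encoder_lora.py --config configs/encoder.yaml
#
# Set WANDB_API_KEY to enable logging and HF_TOKEN to allow the Hub push.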


@dataclass
class EncoderTrainingConfig:
    """Complete training configuration matching the tech spec."""

    # Base model
    base_model: str = "BAAI/bge-m3"
    max_length: int = 8192

    # LoRA (the target modules are the attention projections of bge-m3's
    # XLM-RoBERTa backbone)
    lora_rank: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.1
    target_modules: List[str] = field(default_factory=lambda: ["query", "key", "value"])

    # Optimization
    epochs: int = 3
    per_device_batch_size: int = 4
    gradient_accumulation_steps: int = 16
    learning_rate: float = 5e-5
    weight_decay: float = 0.01
    warmup_ratio: float = 0.1
    max_grad_norm: float = 1.0

    # Mixed precision ("bf16" falls back to fp16 on unsupported GPUs)
    mixed_precision: str = "bf16"

    # Checkpoint / eval / logging cadence, in optimizer steps
    save_steps: int = 500
    eval_steps: int = 500
    logging_steps: int = 50

    # Data and output paths
    data_path: str = "/workspace/data/labels/encoder_triplets.jsonl"
    corpus_dir: str = "/workspace/data/raw"
    output_dir: str = "/workspace/checkpoints/bge-m3-us-regulatory-lora"

    # WandB
    wandb_project: str = "stablebridge-encoder"
    wandb_run_name: Optional[str] = None

    # Hugging Face Hub
    push_to_hub: bool = True
    hub_model_id: str = "cognilogue/bge-m3-us-regulatory-lora"
    hub_token: Optional[str] = None

    # Evaluation
    eval_split: float = 0.1
    eval_metrics: List[str] = field(default_factory=lambda: ["ndcg@10", "mrr@10", "recall@100"])
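# Note on batch sizes: gradient accumulation raises the effective optimization
# batch to 4 × 16 = 64, but compute_loss below sees one micro-batch at a time,
# so each query is contrasted against only 3 in-batch negatives per forward pass.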


class TripletDataset(Dataset):
    """Dataset for encoder triplet training with in-batch negatives."""

    def __init__(
        self,
        triplets: List[Dict],
        corpus: Dict[str, str],
        tokenizer,
        max_length: int = 8192,
    ):
        self.triplets = triplets
        self.corpus = corpus
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        triplet = self.triplets[idx]
        query = triplet["query"]
        pos_id = triplet["positive"]

        # Resolve the positive document text by id (empty string if missing).
        positive_text = self.corpus.get(pos_id, "")

        # Fixed-length padding keeps tensor shapes static across the batch.
        query_enc = self.tokenizer(
            query,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        pos_enc = self.tokenizer(
            positive_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        return {
            "query_input_ids": query_enc["input_ids"].squeeze(0),
            "query_attention_mask": query_enc["attention_mask"].squeeze(0),
            "pos_input_ids": pos_enc["input_ids"].squeeze(0),
            "pos_attention_mask": pos_enc["attention_mask"].squeeze(0),
        }
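# Expected triplet schema (one JSON object per line in the triplets JSONL,
# inferred from the fields read above; the query text is illustrative). The
# "positive" field holds a corpus doc_id; explicit negatives are not needed
# because training relies on in-batch negatives:
#
#     {"query": "What reserve assets may back a payment stablecoin?",
#      "positive": "doc_00123"}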


def mean_pooling(model_output, attention_mask):
    """Mean pooling over token embeddings (ignores padding)."""
    token_embeddings = model_output[0]  # (batch, seq_len, hidden)
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )
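# Example: with attention_mask [[1, 1, 0]], only the first two token vectors
# are averaged; the clamp guards against division by zero for an all-zero mask.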


def compute_loss(query_emb, pos_emb, temperature=0.05):
    """
    Multiple Negatives Ranking Loss (InfoNCE).

    Uses in-batch negatives: all other positives in the batch serve as
    negatives for each query. This is the standard contrastive setup in
    sentence-transformers.

    Args:
        query_emb: (batch_size, hidden_dim) query embeddings (normalized below)
        pos_emb: (batch_size, hidden_dim) positive embeddings (normalized below)
        temperature: Temperature for the softmax (default 0.05)

    Returns:
        loss: Scalar loss value
    """
    # L2-normalize so the dot products below are cosine similarities.
    query_emb = F.normalize(query_emb, p=2, dim=1)
    pos_emb = F.normalize(pos_emb, p=2, dim=1)

    # (batch_size, batch_size) similarity matrix; entry (i, j) scores query i
    # against positive j. Matching pairs sit on the diagonal.
    sim_matrix = torch.matmul(query_emb, pos_emb.T) / temperature

    # The target for query i is its own positive, i.e. column i.
    labels = torch.arange(sim_matrix.size(0)).to(sim_matrix.device)

    loss = F.cross_entropy(sim_matrix, labels)

    return loss
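# Sanity check for compute_loss (a minimal sketch with random tensors, purely
# illustrative; 1024 is bge-m3's hidden size):
#
#     q = torch.randn(8, 1024)
#     p = torch.randn(8, 1024)
#     compute_loss(q, p)  # roughly ln(8) ≈ 2.08, chance level for 8-way InfoNCE
#
# Training should drive the loss well below this chance-level value.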


def evaluate_retrieval(model, tokenizer, eval_data, corpus, device, config):
    """
    Evaluate retrieval quality on the validation set.

    Each query has exactly one relevant document, so IDCG is 1 and NDCG
    reduces to 1 / log2(rank + 1).

    Metrics:
    - NDCG@10: Ranking quality
    - MRR@10: Mean Reciprocal Rank
    - Recall@100: Coverage
    """
    model.eval()

    # Encode the full corpus once; every query is scored against these embeddings.
    print("\nEncoding corpus for evaluation...")
    doc_ids = list(corpus.keys())
    doc_embeddings = []

    with torch.no_grad():
        for doc_id in tqdm(doc_ids, desc="Encoding docs"):
            doc_text = corpus[doc_id]
            doc_enc = tokenizer(
                doc_text,
                max_length=config.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            ).to(device)

            doc_output = model(**doc_enc)
            doc_emb = mean_pooling(doc_output, doc_enc["attention_mask"])
            doc_emb = F.normalize(doc_emb, p=2, dim=1)
            doc_embeddings.append(doc_emb.cpu())

    doc_embeddings = torch.cat(doc_embeddings, dim=0)

    ndcg_scores = []
    mrr_scores = []
    recall_scores = []

    with torch.no_grad():
        for triplet in tqdm(eval_data, desc="Evaluating"):
            query = triplet["query"]
            pos_id = triplet["positive"]

            query_enc = tokenizer(
                query,
                max_length=config.max_length,
                truncation=True,
                padding="max_length",
                return_tensors="pt",
            ).to(device)

            query_output = model(**query_enc)
            query_emb = mean_pooling(query_output, query_enc["attention_mask"])
            query_emb = F.normalize(query_emb, p=2, dim=1)

            # Cosine similarity against every corpus document.
            similarities = torch.matmul(query_emb.cpu(), doc_embeddings.T).squeeze(0)

            ranks = torch.argsort(similarities, descending=True)

            # 1-based rank of the positive; if the positive document is absent
            # from the corpus, treat it as ranked beyond the corpus.
            try:
                pos_idx = doc_ids.index(pos_id)
                pos_rank = (ranks == pos_idx).nonzero(as_tuple=True)[0].item() + 1
            except (ValueError, IndexError):
                pos_rank = len(doc_ids) + 1

            # NDCG@10 with a single relevant document
            ndcg = 1.0 / np.log2(pos_rank + 1) if pos_rank <= 10 else 0.0
            ndcg_scores.append(ndcg)

            # MRR@10
            mrr = 1.0 / pos_rank if pos_rank <= 10 else 0.0
            mrr_scores.append(mrr)

            # Recall@100
            recall = 1.0 if pos_rank <= 100 else 0.0
            recall_scores.append(recall)

    metrics = {
        "eval/ndcg@10": np.mean(ndcg_scores),
        "eval/mrr@10": np.mean(mrr_scores),
        "eval/recall@100": np.mean(recall_scores),
    }

    return metrics
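# Worked example of the metrics above: if the positive lands at rank 3, then
# NDCG@10 = 1/log2(4) = 0.5, MRR@10 = 1/3 ≈ 0.333, and Recall@100 = 1.0; at
# rank 11, both NDCG@10 and MRR@10 drop to 0.0 while Recall@100 is still 1.0.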


def load_data(config: EncoderTrainingConfig):
    """Load triplets and corpus, split train/eval."""

    print(f"Loading triplets from {config.data_path}...")
    triplets = []
    with open(config.data_path) as f:
        for line in f:
            if line.strip():
                triplets.append(json.loads(line))
    print(f"✅ {len(triplets)} triplets")

    print(f"Loading corpus from {config.corpus_dir}...")
    corpus = {}
    corpus_dir = Path(config.corpus_dir)
    for json_file in corpus_dir.glob("*.json"):
        with open(json_file) as f:
            doc = json.load(f)
            doc_id = doc.get("doc_id")
            content = doc.get("content", "")
            if doc_id and content:
                corpus[doc_id] = content
    print(f"✅ {len(corpus)} documents")

    # Positional split: the first eval_split fraction becomes the eval set,
    # which assumes the triplet file was shuffled upstream.
    num_eval = int(len(triplets) * config.eval_split)
    eval_triplets = triplets[:num_eval]
    train_triplets = triplets[num_eval:]

    print(f"\nSplit: {len(train_triplets)} train, {len(eval_triplets)} eval")

    return train_triplets, eval_triplets, corpus
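# Expected corpus file shape (one JSON object per *.json file in corpus_dir;
# field names taken from the reader above, values illustrative):
#
#     {"doc_id": "doc_00123", "content": "Full text of the regulatory document ..."}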


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, help="Path to YAML config file (optional)")
    parser.add_argument("--data-path", type=str, help="Override triplets path")
    parser.add_argument("--output-dir", type=str, help="Override output directory")
    parser.add_argument("--batch-size", type=int, help="Override batch size")
    parser.add_argument("--epochs", type=int, help="Override number of epochs")
    parser.add_argument("--no-wandb", action="store_true", help="Disable WandB logging")
    parser.add_argument("--no-push", action="store_true", help="Disable Hub push")
    args = parser.parse_args()

    config = EncoderTrainingConfig()

    # Apply the YAML config first (if given), then CLI overrides on top.
    if args.config:
        with open(args.config) as f:
            overrides = yaml.safe_load(f) or {}
        for key, value in overrides.items():
            if hasattr(config, key):
                setattr(config, key, value)

    if args.data_path:
        config.data_path = args.data_path
    if args.output_dir:
        config.output_dir = args.output_dir
    if args.batch_size:
        config.per_device_batch_size = args.batch_size
    if args.epochs:
        config.epochs = args.epochs
    if args.no_push:
        config.push_to_hub = False

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("=" * 80)
    print("STABLEBRIDGE ENCODER LORA TRAINING")
    print("=" * 80)
    print(f"Device: {device}")
    print(f"Base model: {config.base_model}")
    print(f"LoRA rank: {config.lora_rank}, alpha: {config.lora_alpha}")
    print(f"Max length: {config.max_length}")
    print(f"Batch size: {config.per_device_batch_size} × {config.gradient_accumulation_steps} = {config.per_device_batch_size * config.gradient_accumulation_steps} (effective)")
    print(f"Epochs: {config.epochs}")
    print(f"Output: {config.output_dir}")

    # WandB is enabled only when requested and an API key is present.
    use_wandb = not args.no_wandb and bool(os.getenv("WANDB_API_KEY"))
    if use_wandb:
        wandb.init(
            project=config.wandb_project,
            name=config.wandb_run_name or f"encoder-lora-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
            config=vars(config),
        )

    train_triplets, eval_triplets, corpus = load_data(config)

    print("\n" + "=" * 80)
    print("MODEL SETUP")
    print("=" * 80)

    print("\nLoading tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained(
        config.base_model, trust_remote_code=True, local_files_only=True
    )

    # Prefer bf16 where the GPU supports it; otherwise fall back to fp16.
    if config.mixed_precision == "bf16" and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
        print("Using bfloat16 precision")
    else:
        dtype = torch.float16
        print("Using float16 precision")

    model = AutoModel.from_pretrained(
        config.base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
        local_files_only=True,
    ).to(device)

    # Wrap the base encoder with LoRA adapters; only adapter weights train.
    print(f"\nApplying LoRA (rank={config.lora_rank}, alpha={config.lora_alpha})...")
    lora_config = LoraConfig(
        r=config.lora_rank,
        lora_alpha=config.lora_alpha,
        target_modules=config.target_modules,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type=TaskType.FEATURE_EXTRACTION,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Eval keeps the raw triplets; evaluate_retrieval tokenizes on the fly.
    train_dataset = TripletDataset(train_triplets, corpus, tokenizer, config.max_length)
    eval_dataset = eval_triplets

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.per_device_batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )

    # One optimizer step per `gradient_accumulation_steps` micro-batches.
    num_training_steps = len(train_loader) * config.epochs // config.gradient_accumulation_steps
    num_warmup_steps = int(num_training_steps * config.warmup_ratio)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
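    # Worked example of the step math above (illustrative numbers, not real
    # data sizes): 6,400 train triplets at batch size 4 give 1,600 micro-batches
    # per epoch; over 3 epochs that is 4,800 // 16 = 300 optimizer steps, of
    # which 10% (30 steps) are warmup.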
    # GradScaler is only needed for fp16; bf16 trains without loss scaling.
    scaler = GradScaler() if dtype == torch.float16 else None

    print("\n" + "=" * 80)
    print("TRAINING")
    print("=" * 80)
    print(f"Total steps: {num_training_steps}")
    print(f"Warmup steps: {num_warmup_steps}")

    global_step = 0
    best_ndcg = 0.0

    for epoch in range(config.epochs):
        print(f"\n{'='*80}")
        print(f"EPOCH {epoch + 1}/{config.epochs}")
        print(f"{'='*80}")

        model.train()
        epoch_loss = 0.0
        optimizer.zero_grad()

        pbar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
        for step, batch in enumerate(pbar):
            query_ids = batch["query_input_ids"].to(device)
            query_mask = batch["query_attention_mask"].to(device)
            pos_ids = batch["pos_input_ids"].to(device)
            pos_mask = batch["pos_attention_mask"].to(device)

            # Forward pass under autocast; queries and positives share the encoder.
            with autocast(dtype=dtype):
                query_output = model(input_ids=query_ids, attention_mask=query_mask)
                query_emb = mean_pooling(query_output, query_mask)

                pos_output = model(input_ids=pos_ids, attention_mask=pos_mask)
                pos_emb = mean_pooling(pos_output, pos_mask)

                # Scale the loss so gradients average over the accumulation window.
                loss = compute_loss(query_emb, pos_emb)
                loss = loss / config.gradient_accumulation_steps

            if scaler:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            epoch_loss += loss.item() * config.gradient_accumulation_steps

            # Optimizer step once per accumulation window.
            if (step + 1) % config.gradient_accumulation_steps == 0:
                # Unscale before clipping so the norm is measured in true units.
                if scaler:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

                if scaler:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                # Periodic logging
                if global_step % config.logging_steps == 0:
                    lr = scheduler.get_last_lr()[0]
                    pbar.set_postfix({
                        "loss": f"{loss.item() * config.gradient_accumulation_steps:.4f}",
                        "lr": f"{lr:.2e}",
                    })

                    if use_wandb:
                        wandb.log({
                            "train/loss": loss.item() * config.gradient_accumulation_steps,
                            "train/learning_rate": lr,
                            "train/epoch": epoch,
                            "train/step": global_step,
                        })

                # Periodic evaluation; keep the best checkpoint by NDCG@10.
                if global_step % config.eval_steps == 0:
                    print("\n" + "-" * 80)
                    print(f"EVALUATION at step {global_step}")
                    print("-" * 80)

                    eval_metrics = evaluate_retrieval(
                        model, tokenizer, eval_dataset, corpus, device, config
                    )

                    print("\nEvaluation Results:")
                    for metric, value in eval_metrics.items():
                        print(f"  {metric}: {value:.4f}")

                    if use_wandb:
                        wandb.log(eval_metrics)

                    if eval_metrics["eval/ndcg@10"] > best_ndcg:
                        best_ndcg = eval_metrics["eval/ndcg@10"]
                        print(f"\n✅ New best NDCG@10: {best_ndcg:.4f}")
                        best_model_dir = Path(config.output_dir) / "best"
                        best_model_dir.mkdir(parents=True, exist_ok=True)
                        model.save_pretrained(best_model_dir)
                        tokenizer.save_pretrained(best_model_dir)

                    model.train()
                    print("-" * 80)

                # Periodic checkpointing
                if global_step % config.save_steps == 0:
                    checkpoint_dir = Path(config.output_dir) / f"checkpoint-{global_step}"
                    checkpoint_dir.mkdir(parents=True, exist_ok=True)
                    model.save_pretrained(checkpoint_dir)
                    tokenizer.save_pretrained(checkpoint_dir)
                    print(f"\n💾 Checkpoint saved: {checkpoint_dir}")

        avg_loss = epoch_loss / len(train_loader)
        print(f"\nEpoch {epoch + 1} - Average Loss: {avg_loss:.4f}")

    # Final save (PEFT saves only the adapter weights, not the full base model)
    print("\n" + "=" * 80)
    print("SAVING FINAL MODEL")
    print("=" * 80)

    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"✅ Model saved to: {output_dir}")

    # Optional Hub push (adapter + tokenizer)
    if config.push_to_hub:
        print("\n" + "=" * 80)
        print("PUSHING TO HUGGING FACE HUB")
        print("=" * 80)

        try:
            model.push_to_hub(
                config.hub_model_id,
                token=config.hub_token or os.getenv("HF_TOKEN"),
            )
            tokenizer.push_to_hub(
                config.hub_model_id,
                token=config.hub_token or os.getenv("HF_TOKEN"),
            )
            print(f"✅ Model pushed to: {config.hub_model_id}")
        except Exception as e:
            print(f"❌ Failed to push to Hub: {e}")

    print("\n" + "=" * 80)
    print("FINAL EVALUATION")
    print("=" * 80)

    final_metrics = evaluate_retrieval(
        model, tokenizer, eval_dataset, corpus, device, config
    )

    print("\nFinal Results:")
    for metric, value in final_metrics.items():
        print(f"  {metric}: {value:.4f}")

    if use_wandb:
        wandb.log({k.replace("eval/", "final/"): v for k, v in final_metrics.items()})
        wandb.finish()

    print("\n" + "=" * 80)
    print("TRAINING COMPLETE!")
    print("=" * 80)
    print(f"Best NDCG@10: {best_ndcg:.4f}")
    print(f"Model saved to: {output_dir}")
    if config.push_to_hub:
        print(f"Hub: https://huggingface.co/{config.hub_model_id}")


if __name__ == "__main__":
    main()
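# Loading the trained adapter for inference later (a minimal sketch; assumes
# the same base model plus the peft package, and that the Hub push succeeded):
#
#     from transformers import AutoModel
#     from peft import PeftModel
#
#     base = AutoModel.from_pretrained("BAAI/bge-m3", trust_remote_code=True)
#     encoder = PeftModel.from_pretrained(base, "cognilogue/bge-m3-us-regulatory-lora")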