Spaces:
Sleeping
Sleeping
| # ============================================================================= | |
| # train.py | |
| # Dataset Preparation + Finetuning Entry Point | |
| # SmolLM2 Service Space | |
| # Copyright 2026 - Volkan KΓΌcΓΌkbudak | |
| # Apache License V2 + ESOL 1.1 | |
| # ============================================================================= | |
| # Usage: | |
| # python train.py --mode export β export HF dataset to training format | |
| # python train.py --mode validate β validate ADI weights against dataset | |
| # python train.py --mode finetune β finetune SmolLM2 on exported data | |
| # ============================================================================= | |
| import os | |
| import argparse | |
| import json | |
| import logging | |
| from datetime import datetime | |
| from pathlib import Path | |
| # ββ Path Resolution βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HF Spaces: /tmp/ (read-only filesystem) | |
| # Local dev: current directory | |
| _TMP = Path("/tmp") if os.getenv("SPACE_ID") else Path(".") | |
| TRAIN_DATA = _TMP / "train_data.jsonl" | |
| VALID_RESULT = _TMP / "validation_results.json" | |
| MODEL_OUTPUT = _TMP / "finetuned_model" | |
| import model as model_module | |
| from adi import DumpindexAnalyzer | |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s") | |
| logger = logging.getLogger("train") | |
| # ============================================================================= | |
| # Mode 1 β Export dataset to training format | |
| # ============================================================================= | |
| def export_dataset(output_path: str = None): | |
| """ | |
| Export HF dataset logs to JSONL format for training. | |
| Includes HIGH_PRIORITY, MEDIUM_PRIORITY and BLOCKED entries. | |
| BLOCKED entries teach the model what to reject. | |
| REJECT entries (ADI noise/quality fail) are skipped β no response logged. | |
| """ | |
| output = Path(output_path) if output_path else TRAIN_DATA | |
| logger.info("Loading dataset from HF...") | |
| entries = model_module.load_logs() | |
| # ββ DEBUG: remove after fix βββββββββββββββββββββββββββββββββββββββββββ | |
| #if entries: | |
| #logger.info(f"Keys: {list(entries[0].keys())}") | |
| #logger.info(f"Sample: {entries[0]}") | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if not entries: | |
| logger.warning("Dataset empty β nothing to export") | |
| return | |
| count = 0 | |
| skipped = 0 | |
| with open(output, "w") as f: | |
| for entry in entries: | |
| # Skip ADI-rejected entries β no meaningful response logged | |
| if entry.get("adi_decision") == "REJECT": | |
| skipped += 1 | |
| continue | |
| if not entry.get("response"): | |
| skipped += 1 | |
| continue | |
| # Format as instruction tuning pair | |
| # BLOCKED entries are included β model learns what to refuse | |
| record = { | |
| "instruction": entry.get("system_prompt", "You are a helpful assistant."), | |
| "input": entry.get("prompt", ""), | |
| "output": entry.get("response", ""), | |
| "adi_score": entry.get("adi_score"), | |
| "adi_decision": entry.get("adi_decision"), | |
| "is_safe": entry.get("adi_decision") != "BLOCKED", | |
| } | |
| f.write(json.dumps(record) + "\n") | |
| count += 1 | |
| logger.info(f"Exported {count}/{len(entries)} entries β {output} (skipped: {skipped})") | |
| # ============================================================================= | |
| # Mode 2 β Validate ADI weights against collected data | |
| # ============================================================================= | |
| def validate_adi(): | |
| """ | |
| Run ADI weight validation against dataset. | |
| Uses entries that have human_label field (manually labeled). | |
| """ | |
| logger.info("Loading dataset for ADI validation...") | |
| entries = model_module.load_logs() | |
| labeled = [(e["prompt"], e["human_label"]) for e in entries if e.get("human_label")] | |
| if not labeled: | |
| logger.warning("No labeled entries found β add 'human_label' field to dataset entries") | |
| logger.info("Expected labels: REJECT | MEDIUM_PRIORITY | HIGH_PRIORITY") | |
| return | |
| analyzer = DumpindexAnalyzer() | |
| accuracy = analyzer.validate_weights(labeled) | |
| logger.info(f"ADI Validation accuracy: {accuracy:.1%} on {len(labeled)} samples") | |
| result = { | |
| "timestamp": datetime.utcnow().isoformat(), | |
| "accuracy": accuracy, | |
| "samples": len(labeled), | |
| "weights": analyzer.weights, | |
| } | |
| VALID_RESULT.write_text(json.dumps(result, indent=2)) | |
| logger.info(f"Results saved β {VALID_RESULT}") | |
| # ============================================================================= | |
| # Mode 3 β Finetune SmolLM2 with TRL SFTTrainer | |
| # ============================================================================= | |
| def finetune(): | |
| """ | |
| Finetune SmolLM2 on exported dataset using TRL SFTTrainer. | |
| Requires export first + enough data (500+ samples recommended). | |
| On completion: pushes finetuned weights to private HF model repo. | |
| """ | |
| if not TRAIN_DATA.exists(): | |
| logger.error(f"train_data.jsonl not found at {TRAIN_DATA} β run export first") | |
| return | |
| lines = TRAIN_DATA.read_text().strip().splitlines() | |
| logger.info(f"Training samples available: {len(lines)}") | |
| if len(lines) < 10: | |
| logger.error(f"Too few samples ({len(lines)}) β aborting finetune") | |
| return | |
| if len(lines) < 500: | |
| logger.warning(f"Only {len(lines)} samples β recommend 500+ for meaningful finetuning") | |
| # ββ Imports βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| try: | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from trl import SFTTrainer, SFTConfig | |
| from datasets import Dataset | |
| import torch | |
| except ImportError as e: | |
| logger.error(f"Missing dependency: {e} β run: pip install trl transformers datasets torch") | |
| return | |
| # ββ Load dataset ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Loading training data...") | |
| records = [json.loads(l) for l in lines] | |
| def format_record(record): | |
| """Format record into chat template string.""" | |
| instruction = record.get("instruction", "You are a helpful assistant.") | |
| user_input = record.get("input", "") | |
| output = record.get("output", "") | |
| return { | |
| "text": f"<|system|>\n{instruction}\n<|user|>\n{user_input}\n<|assistant|>\n{output}" | |
| } | |
| formatted = [format_record(r) for r in records] | |
| dataset = Dataset.from_list(formatted) | |
| logger.info(f"Dataset ready: {len(dataset)} samples") | |
| # ββ Load model + tokenizer ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| model_id = model_module.get_model_id() | |
| kwargs = model_module.get_model_kwargs() | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| logger.info(f"Loading base model: {model_id} on {device}...") | |
| tokenizer = AutoTokenizer.from_pretrained(model_id, **kwargs) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) | |
| # Ensure pad token exists | |
| if tokenizer.pad_token is None: | |
| tokenizer.pad_token = tokenizer.eos_token | |
| # ββ Training config βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Conservative settings for CPU / low RAM (2-8GB) | |
| sft_config = SFTConfig( | |
| output_dir=str(MODEL_OUTPUT), | |
| num_train_epochs=3, | |
| per_device_train_batch_size=1, # CPU friendly | |
| gradient_accumulation_steps=4, # effective batch size = 4 | |
| learning_rate=2e-5, | |
| warmup_steps=10, | |
| logging_steps=10, | |
| save_steps=50, | |
| save_total_limit=2, | |
| fp16=False, # no GPU, no fp16 | |
| bf16=False, | |
| dataloader_num_workers=0, # HF Spaces: no multiprocessing | |
| report_to="none", # no wandb/tensorboard | |
| max_seq_length=512, # SmolLM2 context limit | |
| dataset_text_field="text", | |
| ) | |
| # ββ SFTTrainer ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Initializing SFTTrainer...") | |
| trainer = SFTTrainer( | |
| model=model, | |
| args=sft_config, | |
| train_dataset=dataset, | |
| #tokenizer=tokenizer, | |
| ) | |
| # ββ Train βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logger.info("Starting finetuning...") | |
| start = datetime.utcnow() | |
| trainer.train() | |
| duration = (datetime.utcnow() - start).total_seconds() | |
| logger.info(f"Training complete in {duration:.0f}s") | |
| # ββ Save locally ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| trainer.save_model(str(MODEL_OUTPUT)) | |
| tokenizer.save_pretrained(str(MODEL_OUTPUT)) | |
| logger.info(f"Model saved β {MODEL_OUTPUT}") | |
| # ββ Push to HF private repo βββββββββββββββββββββββββββββββββββββββββββββββ | |
| token = model_module.TOKEN | |
| private_repo = model_module.PRIVATE_MODEL | |
| if token and private_repo: | |
| logger.info(f"Pushing to HF: {private_repo}...") | |
| try: | |
| model.push_to_hub(private_repo, token=token, private=True) | |
| tokenizer.push_to_hub(private_repo, token=token, private=True) | |
| model_module.push_model_card({ | |
| "model_id": model_id, | |
| "samples": len(dataset), | |
| "epochs": 3, | |
| "duration_sec": int(duration), | |
| "finetuned_from": model_id, | |
| }) | |
| logger.info(f"Model pushed β {private_repo}") | |
| except Exception as e: | |
| logger.error(f"Push failed: {type(e).__name__}: {e}") | |
| else: | |
| logger.warning("No token or private repo configured β skipping HF push") | |
| # ============================================================================= | |
| # CLI | |
| # ============================================================================= | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description="SmolLM2 Training Utilities") | |
| parser.add_argument( | |
| "--mode", | |
| choices=["export", "validate", "finetune"], | |
| required=True, | |
| help="export: dump dataset to JSONL | validate: test ADI weights | finetune: train model" | |
| ) | |
| parser.add_argument("--output", default=None, help="Output file for export mode (default: auto)") | |
| args = parser.parse_args() | |
| if args.mode == "export": | |
| export_dataset(args.output) | |
| elif args.mode == "validate": | |
| validate_adi() | |
| elif args.mode == "finetune": | |
| finetune() | |