ScottzillaSystems committed on
Commit
5e55ab0
·
verified ·
1 Parent(s): fa60c5e

Upload examples/demo_dpo_self_healing.py

Browse files
Files changed (1) hide show
  1. examples/demo_dpo_self_healing.py +114 -0
examples/demo_dpo_self_healing.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Demo: Self-Healing DPO Training
4
+ ===============================
5
+ Loads a pretrained model, does DPO with full self-healing.
6
+ DPO-specific: detects plateau at loss≈0.693 (random chance).
7
+
8
+ Usage:
9
+ python demo_dpo_self_healing.py
10
+ """
11
+ import os, sys, json, time
12
+ import torch
13
+ from transformers import AutoModelForCausalLM, AutoTokenizer
14
+ from datasets import load_dataset
15
+
16
+ from self_healing import SelfHealingTrainer, HealingConfig
17
+
18
def _load_model_and_tokenizer(model_id):
    """Load the pretrained causal LM and its tokenizer.

    Ensures a pad token exists (DPO batching needs padding); falls back to
    the EOS token when the tokenizer ships without one.
    """
    print(f"[1/4] Loading model: {model_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _build_dpo_trainer(model, tokenizer):
    """Load the preference dataset and construct a TRL DPOTrainer.

    The dataset must provide "prompt", "chosen", and "rejected" columns;
    only the first 2000 rows are used to keep the demo fast.
    """
    print("[2/4] Loading DPO dataset: trl-lib/ultrafeedback_binarized")
    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:2000]")

    # Imported lazily so the heavyweight TRL import cost is paid only
    # when the demo actually runs (matches the original script).
    from trl import DPOConfig, DPOTrainer

    training_args = DPOConfig(
        output_dir="./dpo-output",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        learning_rate=5e-6,
        max_steps=200,
        logging_steps=10,
        logging_strategy="steps",
        logging_first_step=True,
        save_steps=500,
        bf16=True,
        beta=0.1,  # DPO temperature
        report_to="none",
        run_name="selfheal-dpo-demo",
        disable_tqdm=True,
    )

    return DPOTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )


def _wrap_self_healing(trainer):
    """Wrap the trainer in SelfHealingTrainer with DPO-tuned settings.

    Returns (sh_trainer, healing_config). Settings are more aggressive for
    DPO because plateau detection (loss stuck near ln 2 ≈ 0.693) is critical.
    """
    print("[3/4] Wrapping with SelfHealingTrainer...")
    healing_config = HealingConfig(
        nan_patience=3,
        loss_spike_factor=5.0,
        divergence_patience=50,
        max_recovery_attempts=5,
        max_lr_reductions=3,
        max_batch_reductions=2,
        zclip_enabled=True,
        zclip_z_threshold=3.0,
        postmortem_path="./dpo-postmortem.json",
    )
    return SelfHealingTrainer(trainer, healing_config), healing_config


def _print_report(sh_trainer, healing_config):
    """Print the final convergence/recovery report and any postmortem file."""
    print("\n" + "=" * 60)
    print(" DPO DEMO COMPLETE")
    print("=" * 60)
    report = sh_trainer.get_report()
    print(f" Converged: {report['converged']}")
    print(f" Attempts: {report['attempts']}")
    print(f" Recoveries: {report['total_recoveries']}")

    if report["recovery_history"]:
        print("\n Recovery log:")
        for i, rec in enumerate(report["recovery_history"]):
            print(f" [{i+1}] {rec['failure']}: {rec['actions']}")

    if os.path.exists(healing_config.postmortem_path):
        # Postmortem is written by the healer on abnormal exit; encoding is
        # pinned so the read does not depend on the platform default.
        with open(healing_config.postmortem_path, encoding="utf-8") as f:
            pm = json.load(f)
        print(f"\n Postmortem: {pm.get('exit_reason', 'unknown')} "
              f"at step {pm.get('last_step', '?')}")


def main():
    """Run the self-healing DPO demo end to end.

    Phases: (1) load a small pretrained model, (2) build a DPOTrainer on a
    preference dataset, (3) wrap it with self-healing, then dry-run, train,
    and report. Exits with status 1 if the dry-run fails.
    """
    print("\n" + "=" * 60)
    print(" SELF-HEALING DPO TRAINING DEMO")
    print("=" * 60 + "\n")

    model, tokenizer = _load_model_and_tokenizer("Qwen/Qwen2.5-0.5B")
    trainer = _build_dpo_trainer(model, tokenizer)
    sh_trainer, healing_config = _wrap_self_healing(trainer)

    # Dry-run: fail fast (exit 1) before committing to a long training run.
    try:
        sh_trainer.dry_run(num_steps=2)
        print(" ✓ Dry-run passed!\n")
    except Exception as e:
        print(f" ✗ Dry-run failed: {e}")
        sys.exit(1)

    # Train (return value intentionally ignored; the report carries results)
    print("[4/4] Training DPO with self-healing...\n")
    sh_trainer.train()

    _print_report(sh_trainer, healing_config)
112
+
113
# Script entry point: run the demo only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()