vgtomahawk committed · verified
Commit 78d2009 · Parent: 218bb54

Upload train_sft_qwen.py with huggingface_hub

Files changed (1):
  1. train_sft_qwen.py +48 -39
train_sft_qwen.py CHANGED
@@ -1,10 +1,17 @@
  # /// script
- # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch", "transformers>=4.40.0"]
+ # dependencies = [
+ #     "trl>=0.12.0",
+ #     "peft>=0.7.0",
+ #     "trackio",
+ #     "transformers>=4.40.0",
+ #     "datasets>=2.18.0",
+ #     "torch>=2.0.0",
+ # ]
  # ///

  """
- SFT Fine-tuning Script for Qwen/Qwen2.5-0.5B
- Optimized for Hugging Face Jobs with Trackio monitoring
+ SFT (Supervised Fine-Tuning) training script for Qwen/Qwen2.5-0.5B
+ Uses TRL with LoRA, Trackio monitoring, and automatic Hub push
  """

  from datasets import load_dataset
@@ -12,81 +19,83 @@ from peft import LoraConfig
  from trl import SFTTrainer, SFTConfig
  import trackio

- # Load dataset - using TRL-compatible Capybara dataset
- print("Loading dataset...")
+ # Load a high-quality instruction dataset
  dataset = load_dataset("trl-lib/Capybara", split="train")

- # Create train/eval split for monitoring progress
+ # Create train/eval split for monitoring training progress
  dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
- print(f"Train size: {len(dataset_split['train'])}, Eval size: {len(dataset_split['test'])}")

- # LoRA configuration for efficient fine-tuning
+ # Configure LoRA for efficient fine-tuning
  peft_config = LoraConfig(
-     r=16,                   # LoRA rank
-     lora_alpha=32,          # LoRA alpha (scaling factor)
-     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention layers
-     lora_dropout=0.05,      # Dropout for regularization
-     bias="none",            # Don't train bias terms
-     task_type="CAUSAL_LM"   # Causal language modeling
+     r=16,                   # LoRA rank
+     lora_alpha=32,          # LoRA alpha scaling
+     lora_dropout=0.05,      # Dropout for regularization
+     bias="none",            # Don't train bias parameters
+     task_type="CAUSAL_LM",  # Causal language modeling
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Qwen attention layers
  )

- # Training configuration
- training_args = SFTConfig(
-     # Output and Hub settings
-     output_dir="qwen-0.5b-sft-capybara",
+ # Configure trainer
+ training_config = SFTConfig(
+     # Model and output
+     output_dir="qwen-sft-capybara",
+
+     # Hub configuration - CRITICAL for saving results
      push_to_hub=True,
-     hub_model_id="vgtomahawk/qwen-0.5b-sft-capybara",
-     hub_strategy="every_save",
+     hub_model_id="qwen-sft-capybara-demo",  # Will use format: username/qwen-sft-capybara-demo
+     hub_strategy="every_save",              # Push checkpoints during training
      hub_private_repo=False,

      # Training parameters
      num_train_epochs=3,
-     per_device_train_batch_size=4,
-     per_device_eval_batch_size=4,
-     gradient_accumulation_steps=4,  # Effective batch size: 4 * 4 = 16
-     gradient_checkpointing=True,
+     per_device_train_batch_size=2,
+     gradient_accumulation_steps=4,  # Effective batch size: 2 * 4 = 8

      # Optimization
      learning_rate=2e-4,
      lr_scheduler_type="cosine",
      warmup_ratio=0.1,
-     optim="paged_adamw_8bit",  # Memory-efficient optimizer

-     # Evaluation and logging
+     # Evaluation
      eval_strategy="steps",
      eval_steps=50,
-     logging_steps=10,
+     per_device_eval_batch_size=2,
+
+     # Checkpointing
      save_strategy="steps",
      save_steps=100,
      save_total_limit=3,  # Keep only last 3 checkpoints

-     # Trackio monitoring
+     # Logging - Trackio integration
+     logging_steps=10,
      report_to="trackio",
-     run_name="qwen-0.5b-sft-capybara-test",
+     run_name="qwen-0.5b-sft-demo",

-     # Performance
-     bf16=True,  # Use bfloat16 for better numerical stability
+     # Performance optimization
+     bf16=True,                    # Use bfloat16 for better performance on modern GPUs
+     gradient_checkpointing=True,  # Reduce memory usage
+
+     # Misc
+     seed=42,
      dataloader_num_workers=4,
-     remove_unused_columns=True,
  )

  # Initialize trainer
- print("Initializing SFT Trainer...")
  trainer = SFTTrainer(
      model="Qwen/Qwen2.5-0.5B",
      train_dataset=dataset_split["train"],
      eval_dataset=dataset_split["test"],
      peft_config=peft_config,
-     args=training_args,
+     args=training_config,
  )

  # Train the model
  print("Starting training...")
  trainer.train()

- # Save final model to Hub
- print("Pushing final model to Hub...")
- trainer.push_to_hub(commit_message="Training completed")
+ # Final push to Hub
+ print("Training complete! Pushing final model to Hub...")
+ trainer.push_to_hub()

- print("✅ Training completed successfully!")
- print(f"Model saved to: https://huggingface.co/{training_args.hub_model_id}")
+ print("✅ Training complete and model saved to Hub!")
+ print(f"Model available at: https://huggingface.co/{trainer.hub_model_id}")
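A note on the header block: the expanded dependency list follows the PEP 723 inline script metadata convention (the # /// script ... # /// fence), so a runner such as uv can resolve the pinned dependencies straight from the file, e.g. uv run train_sft_qwen.py, which is also how Hugging Face Jobs runs uv scripts. On the training side, the new settings give per_device_train_batch_size=2 × gradient_accumulation_steps=4 = 8 examples per optimizer step per device, half the previous effective batch of 16, and the memory-saving paged_adamw_8bit optimizer is dropped in favor of the Trainer's default AdamW, with gradient_checkpointing=True retained to keep activation memory in check.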
 
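To sanity-check what the LoraConfig above actually trains, one can attach it to the base model and count trainable parameters. A minimal sketch, not part of the commit, assuming the same peft API the script imports:

# Sketch: verify the LoRA footprint before launching a full training job.
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
peft_config = LoraConfig(          # same LoRA config as the script above
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
model = get_peft_model(base, peft_config)
model.print_trainable_parameters()  # reports trainable vs. total parameter counts

With r=16 on the four attention projections only, the adapter is a small fraction of the 0.5B base model, which is what keeps the per-checkpoint Hub pushes cheap.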
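report_to="trackio" routes the Trainer's metrics to Trackio every logging_steps=10 steps; the import trackio at the top of the script makes that dependency explicit. A hedged sketch of direct usage, assuming Trackio keeps its wandb-style API (init/log/finish) and its show() dashboard entry point; the project name here is hypothetical:

# Sketch of manual Trackio logging; the Trainer does the equivalent automatically.
import trackio

run = trackio.init(project="qwen-sft-demo", config={"learning_rate": 2e-4})
trackio.log({"train/loss": 0.42})  # one metrics dict per logged step
trackio.finish()

trackio.show()  # launch the local dashboard to browse logged runs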
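Because hub_strategy="every_save" pushes a checkpoint every save_steps=100 steps, an interrupted job loses at most 100 steps of work. Resuming uses the standard Trainer flag; a one-line sketch, not part of the commit:

trainer.train(resume_from_checkpoint=True)  # picks up the latest checkpoint under output_dir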
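Once the final push completes, the LoRA adapter can be pulled back down and applied to the base model for a quick smoke test. A sketch with USERNAME as a placeholder (the actual repo id is whatever hub_model_id resolves to under the committing account), and assuming the Qwen tokenizer's bundled chat template:

# Sketch: load the pushed adapter on top of the base model and generate once.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
model = PeftModel.from_pretrained(base, "USERNAME/qwen-sft-capybara-demo")  # placeholder repo id

messages = [{"role": "user", "content": "Explain LoRA in one sentence."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))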