luiscosio commited on
Commit
b0e8599
·
verified ·
1 Parent(s): 2311842

Upload test_qwen3_capybara.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_qwen3_capybara.py +76 -0
test_qwen3_capybara.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = [
3
+ # "trl>=0.12.0",
4
+ # "peft>=0.7.0",
5
+ # "transformers>=4.36.0",
6
+ # "accelerate>=0.24.0",
7
+ # "datasets",
8
+ # "torch",
9
+ # ]
10
+ # ///
11
+
12
+ from datasets import load_dataset
13
+ from peft import LoraConfig
14
+ from trl import SFTTrainer, SFTConfig
15
+
16
+ # Load known-working TRL dataset
17
+ print("Loading dataset...")
18
+ dataset = load_dataset("trl-lib/Capybara", split="train")
19
+ print(f"Dataset loaded: {len(dataset)} examples")
20
+
21
+ # Small subset for quick test
22
+ dataset = dataset.shuffle(seed=42).select(range(1000))
23
+ print(f"Using {len(dataset)} examples")
24
+
25
+ # Split
26
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
27
+ train_dataset = dataset_split["train"]
28
+ eval_dataset = dataset_split["test"]
29
+
30
+ # Training configuration
31
+ config = SFTConfig(
32
+ output_dir="qwen3-0.6b-test",
33
+ push_to_hub=True,
34
+ hub_model_id="luiscosio/qwen3-0.6b-test",
35
+ num_train_epochs=1,
36
+ per_device_train_batch_size=2,
37
+ gradient_accumulation_steps=4,
38
+ gradient_checkpointing=True,
39
+ learning_rate=2e-4,
40
+ logging_steps=10,
41
+ save_strategy="steps",
42
+ save_steps=50,
43
+ eval_strategy="steps",
44
+ eval_steps=50,
45
+ warmup_ratio=0.1,
46
+ bf16=True,
47
+ max_length=1024,
48
+ report_to="none",
49
+ )
50
+
51
+ # LoRA configuration
52
+ peft_config = LoraConfig(
53
+ r=16,
54
+ lora_alpha=32,
55
+ lora_dropout=0.05,
56
+ bias="none",
57
+ task_type="CAUSAL_LM",
58
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
59
+ )
60
+
61
+ # Initialize and train
62
+ print("Initializing trainer...")
63
+ trainer = SFTTrainer(
64
+ model="Qwen/Qwen3-0.6B",
65
+ train_dataset=train_dataset,
66
+ eval_dataset=eval_dataset,
67
+ args=config,
68
+ peft_config=peft_config,
69
+ )
70
+
71
+ print("Starting training...")
72
+ trainer.train()
73
+
74
+ print("Pushing to Hub...")
75
+ trainer.push_to_hub()
76
+ print("Done!")