shaikhsalman committed
Commit 3aaeeb3 · verified · 1 parent: d678e13

Upload ai-ml/hf-finetuning/train_openthoughts.py with huggingface_hub

ai-ml/hf-finetuning/train_openthoughts.py ADDED
@@ -0,0 +1,126 @@
"""
Train Llama-3.1-8B-Instruct on open-thoughts/OpenThoughts-114k (reasoning CoT).

This dataset contains DeepSeek-R1 distilled reasoning traces,
focused on math, code, and science with chain-of-thought thinking.

Uses the LoRA Without Regret config (r=256, all-linear).
The dataset is comparatively small (114K examples), so this run uses a
higher learning rate and trains for 2 epochs.

Usage:
    python train_openthoughts.py
    python train_openthoughts.py --max_steps 50   # quick smoke test
"""

import argparse

import torch
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer


def convert_openthoughts(example):
    """Convert ShareGPT format to messages format."""
    messages = []
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        role = "user" if turn["from"] == "user" else "assistant"
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}
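
# For illustration only, a hypothetical raw record (not taken from the
# dataset) such as:
#   {"system": "Be concise.",
#    "conversations": [{"from": "user", "value": "What is 2 + 2?"},
#                      {"from": "assistant", "value": "4."}]}
# converts to:
#   {"messages": [{"role": "system", "content": "Be concise."},
#                 {"role": "user", "content": "What is 2 + 2?"},
#                 {"role": "assistant", "content": "4."}]}
# Real OpenThoughts records carry long chain-of-thought traces in the
# assistant turns.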


def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-openthoughts-lora"):

    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-openthoughts",
        config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "dataset": "open-thoughts/OpenThoughts-114k",
            "lora_r": 256,
            "lora_alpha": 16,
            "target_modules": "all-linear",
            "learning_rate": 2e-4,
        },
    )

    # Load and convert
    print("Loading open-thoughts/OpenThoughts-114k...")
    dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
    print(f"Loaded {len(dataset)} examples (raw format)")

    remove_cols = [c for c in dataset.column_names if c != "messages"]
    dataset = dataset.map(convert_openthoughts, remove_columns=remove_cols)
    print(f"Converted to messages format: {len(dataset)} examples")

    # LoRA Without Regret
    peft_config = LoraConfig(
        r=256,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )
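
    # With lora_alpha=16 and r=256 the effective LoRA scaling factor is
    # alpha / r = 16 / 256 = 0.0625. As an optional sanity check (a sketch,
    # not part of the training run; it loads the full base model into
    # memory), the trainable-parameter count for this config can be printed:
    #
    #   from transformers import AutoModelForCausalLM
    #   from peft import get_peft_model
    #   base = AutoModelForCausalLM.from_pretrained(
    #       "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16)
    #   get_peft_model(base, peft_config).print_trainable_parameters()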

    # Smaller dataset = higher LR + more epochs
    training_args = SFTConfig(
        output_dir="./output/llama3.1-8b-openthoughts-lora",
        push_to_hub=push_hub,
        hub_model_id=hub_model_id,
        model_init_kwargs={
            "torch_dtype": torch.bfloat16,
            "attn_implementation": "flash_attention_2",
        },
        learning_rate=2e-4,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch = 16
        num_train_epochs=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_seq_length=4096,
        packing=True,
        packing_strategy="bfd_split",
        gradient_checkpointing=True,
        bf16=True,
        assistant_only_loss=True,
        eos_token="<|eot_id|>",
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        report_to=["trackio"],
        disable_tqdm=True,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        optim="adamw_torch",
    )
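
    # With packing at max_seq_length=4096, each optimizer step covers roughly
    # 2 (per-device batch) * 8 (grad accum) * 4096 ≈ 65K tokens per device,
    # so steps per epoch are driven by the dataset's total token count rather
    # than its 114K example count.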

    if max_steps:
        training_args.max_steps = max_steps

    trainer = SFTTrainer(
        model="meta-llama/Llama-3.1-8B-Instruct",
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
    )

    trainer.train()

    if push_hub:
        trainer.push_to_hub()
        print(f"Model pushed to: https://huggingface.co/{hub_model_id}")

    trackio.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_steps", type=int, default=None)
    parser.add_argument("--hub_model_id", type=str, default="shaikhsalman/llama-3.1-8b-openthoughts-lora")
    parser.add_argument("--no_push", action="store_true")
    args = parser.parse_args()
    train(max_steps=args.max_steps, push_hub=not args.no_push, hub_model_id=args.hub_model_id)
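
Once training finishes and the adapter is on the Hub, it can be loaded for inference with peft's AutoPeftModelForCausalLM. A minimal sketch, assuming the adapter sits at the default hub_model_id above and the gated base model weights are accessible:

    import torch
    from peft import AutoPeftModelForCausalLM
    from transformers import AutoTokenizer

    # Loads the base model and applies the LoRA adapter on top
    model = AutoPeftModelForCausalLM.from_pretrained(
        "shaikhsalman/llama-3.1-8b-openthoughts-lora",
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

    messages = [{"role": "user", "content": "What is the derivative of x**3?"}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    output = model.generate(input_ids, max_new_tokens=512)
    # Decode only the newly generated tokens
    print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))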