Codyfederer committed on
Commit 5d5708e · verified · 1 Parent(s): 8bca88b

Upload train_h100.py with huggingface_hub

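The commit message notes the file was uploaded with huggingface_hub. A minimal sketch of such an upload using the HfApi client follows; the target repo id and token handling are assumptions, not shown anywhere on this page:

from huggingface_hub import HfApi

api = HfApi()  # authentication comes from `huggingface-cli login` or the HF_TOKEN environment variable
api.upload_file(
    path_or_fileobj="train_h100.py",   # local file to upload
    path_in_repo="train_h100.py",      # destination path inside the repo
    repo_id="Codyfederer/Synthia-S1-27b-tool-calling",  # assumed repo id, not confirmed by this commit
    commit_message="Upload train_h100.py with huggingface_hub",
)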
Files changed (1)
  1. train_h100.py +188 -0
train_h100.py ADDED
@@ -0,0 +1,188 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "torch>=2.0.0",
+ #     "transformers>=4.50.0",
+ #     "datasets>=2.14.0",
+ #     "peft>=0.7.0",
+ #     "accelerate>=0.25.0",
+ #     "trackio",
+ #     "huggingface_hub",
+ # ]
+ # ///
+ """
+ LoRA Fine-tuning: Add Tool Calling to Synthia-S1-27b
+ Using pre-tokenized data from Codyfederer/synthia-tool-calling-tokenized
+ Optimized for H100 80GB
+ """
+
+ import os
+ from datasets import load_dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     DataCollatorForLanguageModeling,
+     Trainer,
+     TrainingArguments,
+ )
+ from peft import LoraConfig, get_peft_model
+ import torch
+ import trackio
+ from huggingface_hub import whoami
+
+ # Configuration
+ BASE_MODEL = "Tesslate/Synthia-S1-27b"
+ OUTPUT_MODEL = "Synthia-S1-27b-tool-calling"
+ TOKENIZED_DATASET = "Codyfederer/synthia-tool-calling-tokenized"
+ MAX_SEQ_LENGTH = 4096
+
+ # H100 optimized parameters
+ BATCH_SIZE = 4  # Higher batch size for H100 80GB
+ GRADIENT_ACCUMULATION = 8  # Effective batch = 32
+ LEARNING_RATE = 2e-4
+ NUM_EPOCHS = 1
+ LORA_R = 64
+ LORA_ALPHA = 128
+
+ print("=" * 60)
+ print("Tool Calling Fine-tuning for Synthia-S1-27b (H100)")
+ print("=" * 60)
+
+ # Initialize Trackio
+ trackio.init(project="synthia-tool-calling")
+
+ # Get HF username
+ try:
+     username = whoami()["name"]
+     hub_model_id = f"{username}/{OUTPUT_MODEL}"
+     print(f"Will push to: {hub_model_id}")
+ except Exception as e:
+     print(f"Error getting username: {e}")
+     raise
+
+ # Load tokenizer
+ print(f"\nLoading tokenizer from {BASE_MODEL}...")
+ tokenizer = AutoTokenizer.from_pretrained(
+     BASE_MODEL,
+     trust_remote_code=True,
+     padding_side="right",
+ )
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.pad_token_id = tokenizer.eos_token_id
+ print(f"Vocab size: {len(tokenizer):,}")
+
+ # Load pre-tokenized dataset
+ print(f"\nLoading pre-tokenized dataset: {TOKENIZED_DATASET}")
+ tokenized_ds = load_dataset(TOKENIZED_DATASET)
+
+ train_dataset = tokenized_ds["train"]
+ eval_dataset = tokenized_ds.get("test", tokenized_ds.get("validation"))
+
+ print(f"Train samples: {len(train_dataset):,}")
+ if eval_dataset:
+     print(f"Eval samples: {len(eval_dataset):,}")
+
+ # Truncate to MAX_SEQ_LENGTH
+ def truncate_example(example):
+     return {
+         "input_ids": example["input_ids"][:MAX_SEQ_LENGTH],
+         "attention_mask": example["attention_mask"][:MAX_SEQ_LENGTH],
+         "labels": example["labels"][:MAX_SEQ_LENGTH] if "labels" in example else example["input_ids"][:MAX_SEQ_LENGTH],
+     }
+
+ print(f"Truncating to max_length={MAX_SEQ_LENGTH}...")
+ train_dataset = train_dataset.map(truncate_example, desc="Truncating train")
+ if eval_dataset:
+     eval_dataset = eval_dataset.map(truncate_example, desc="Truncating eval")
+
+ # Load model
+ print(f"\nLoading model: {BASE_MODEL}...")
+ model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     device_map="auto",
+     trust_remote_code=True,
+     torch_dtype=torch.bfloat16,
+     attn_implementation="sdpa",
+ )
+ print(f"Model loaded. Parameters: {model.num_parameters():,}")
+
+ # Configure LoRA
+ print(f"\nConfiguring LoRA (r={LORA_R}, alpha={LORA_ALPHA})...")
+ lora_config = LoraConfig(
+     r=LORA_R,
+     lora_alpha=LORA_ALPHA,
+     lora_dropout=0.05,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+     bias="none",
+     task_type="CAUSAL_LM",
+ )
+ model = get_peft_model(model, lora_config)
+ model.print_trainable_parameters()
+
+ # Training arguments - H100 optimized
+ print("\nConfiguring training...")
+ training_args = TrainingArguments(
+     output_dir=f"./{OUTPUT_MODEL}",
+     num_train_epochs=NUM_EPOCHS,
+     per_device_train_batch_size=BATCH_SIZE,
+     per_device_eval_batch_size=BATCH_SIZE,
+     gradient_accumulation_steps=GRADIENT_ACCUMULATION,
+     learning_rate=LEARNING_RATE,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.03,
+     weight_decay=0.01,
+     optim="adamw_torch",
+     gradient_checkpointing=True,
+     gradient_checkpointing_kwargs={"use_reentrant": False},
+     max_grad_norm=1.0,
+     eval_strategy="steps",
+     eval_steps=500,
+     save_strategy="steps",
+     save_steps=500,
+     save_total_limit=3,
+     push_to_hub=True,
+     hub_model_id=hub_model_id,
+     hub_strategy="checkpoint",
+     logging_steps=10,
+     report_to="trackio",
+     run_name=f"synthia-tool-calling-lora-r{LORA_R}",
+     bf16=True,
+     dataloader_num_workers=4,
+     dataloader_pin_memory=True,
+     seed=42,
+     remove_unused_columns=False,
+ )
+
+ # Initialize trainer
+ print("\nInitializing trainer...")
+ data_collator = DataCollatorForLanguageModeling(
+     tokenizer=tokenizer,
+     mlm=False,
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=train_dataset,
+     eval_dataset=eval_dataset,
+     tokenizer=tokenizer,
+     data_collator=data_collator,
+ )
+
+ # Train
+ print("\n" + "=" * 60)
+ print("Starting training...")
+ print("=" * 60 + "\n")
+ trainer.train()
+
+ # Save and push
+ print("\nSaving final model...")
+ trainer.save_model()
+ print(f"Pushing to Hub: {hub_model_id}")
+ trainer.push_to_hub()
+
+ print(f"\n" + "=" * 60)
+ print(f"Training complete!")
+ print(f"Model available at: https://huggingface.co/{hub_model_id}")
+ print("=" * 60)