YUNTA88 committed
Commit 36ec0aa · verified · 1 Parent(s): 15fa009

Upload scripts/train_sft.py with huggingface_hub

Files changed (1)
  1. scripts/train_sft.py +301 -0
scripts/train_sft.py ADDED
@@ -0,0 +1,301 @@
"""
SFT Training Script for Qwen2.5-VL-3B-Instruct on Physics CoT Data.
Aligned with RL-with-Cold-Start 7B reference configuration.

Key changes from previous version:
- Full fine-tuning (no LoRA) for a stronger cold start
- Vision encoder NOT frozen (freeze_aligner=false in reference)
- 3 epochs (not 16) to avoid overfitting
- Higher image resolution (max_pixels=1204224), matching reference
- Larger effective batch size (grad_accum=8, effective batch 64 on 8 GPUs)
- DeepSpeed ZeRO-2 for memory efficiency
- Lower learning rate (1e-5), appropriate for full FT
"""
import os
import json
import torch
from PIL import Image
from torch.utils.data import Dataset
from transformers import (
    Qwen2_5_VLForConditionalGeneration,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)


# ===== Configuration =====
MODEL_NAME = "/workspace/rl4phyx/models/Qwen2.5-VL-3B-Instruct"
DATA_PATH = "/workspace/rl4phyx/RL4Phyx/SFT/sft_train/coldstart_formatted.jsonl"
OUTPUT_DIR = "/workspace/rl4phyx/RL4Phyx/SFT/checkpoints/sft_qwen25vl_3b_fullft"

# Training hyperparameters (aligned with 7B reference)
NUM_EPOCHS = 3              # Reference uses 3 epochs
LEARNING_RATE = 1e-5        # Full FT uses lower LR than LoRA
PER_DEVICE_BATCH_SIZE = 1   # Small batch for VLM
GRAD_ACCUM_STEPS = 8        # Effective batch = 1 * 8 GPUs * 8 = 64
MAX_LENGTH = 4096           # Max total sequence length
FREEZE_VISION = False       # Reference: freeze_aligner=false
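
# Note: the effective-batch arithmetic above assumes an 8-GPU data-parallel run.
# The launch command is not part of this commit; with the Trainer + DeepSpeed
# integration, a typical invocation (an assumption, not the author's documented
# command) would be one of:
#   deepspeed --num_gpus=8 scripts/train_sft.py
#   torchrun --nproc_per_node=8 scripts/train_sft.py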


class PhysicsCoTDataset(Dataset):
    """Dataset for Qwen2.5-VL SFT with physics CoT."""

    def __init__(self, data_path, processor, max_length=4096):
        self.processor = processor
        self.max_length = max_length

        with open(data_path, 'r', encoding='utf-8') as f:
            self.records = [json.loads(line) for line in f]

        print(f"Loaded {len(self.records)} records from {data_path}")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        messages = record['messages']

        # Extract image path from user message
        user_msg = messages[0]
        image_path = None
        text_content = ""

        for content in user_msg['content']:
            if content['type'] == 'image':
                image_path = content['image'].replace('file://', '')
            elif content['type'] == 'text':
                text_content = content['text']

        # Extract assistant response
        assistant_msg = messages[1]
        assistant_text = assistant_msg['content'][0]['text']

        # Load image
        image = Image.open(image_path).convert('RGB')
        # Ensure minimum image size for Qwen2.5-VL vision encoder (factor=28)
        # Strategy: scale up proportionally (preserve aspect ratio), then pad with white
        MIN_DIM = 56  # Must be >= 28; use 56 for safety (2 * factor)
        w, h = image.size
        if w < MIN_DIM or h < MIN_DIM:
            # Scale proportionally so the smaller dimension reaches MIN_DIM
            scale = max(MIN_DIM / w, MIN_DIM / h)
            new_w = int(w * scale)
            new_h = int(h * scale)
            image = image.resize((new_w, new_h), Image.LANCZOS)
            # Pad with white if any dimension is still < MIN_DIM (shouldn't happen, but safety)
            if new_w < MIN_DIM or new_h < MIN_DIM:
                from PIL import ImageOps
                padded = Image.new('RGB', (max(new_w, MIN_DIM), max(new_h, MIN_DIM)), (255, 255, 255))
                padded.paste(image, (0, 0))
                image = padded

        # Build conversation for apply_chat_template
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_content},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": assistant_text},
                ],
            },
        ]

        # Use processor to create inputs
        text = self.processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=False,
        )

        inputs = self.processor(
            text=[text],
            images=[image],
            padding=False,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Squeeze batch dimension
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        # Create labels: mask user tokens (only train on the assistant response)
        labels = input_ids.clone()

        # Find the assistant turn start token and mask everything before it
        assistant_token_str = "<|im_start|>assistant\n"
        assistant_token_ids = self.processor.tokenizer.encode(
            assistant_token_str, add_special_tokens=False
        )
        input_ids_list = input_ids.tolist()
        assistant_start = -1
        for i in range(len(input_ids_list) - len(assistant_token_ids) + 1):
            if input_ids_list[i:i + len(assistant_token_ids)] == assistant_token_ids:
                assistant_start = i + len(assistant_token_ids)
                break

        if assistant_start > 0:
            labels[:assistant_start] = -100  # Mask the user prompt
        else:
            raise ValueError(f"FATAL: assistant start token not found in sample {idx}.")

        # Also mask padding
        labels[attention_mask == 0] = -100

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'pixel_values': inputs.get('pixel_values', torch.tensor([])).squeeze(0) if 'pixel_values' in inputs else None,
            'image_grid_thw': inputs.get('image_grid_thw', torch.tensor([])).squeeze(0) if 'image_grid_thw' in inputs else None,
        }


class VLMDataCollator:
    """Custom data collator for variable-length VLM inputs."""

    def __init__(self, processor):
        self.processor = processor
        self.pad_token_id = processor.tokenizer.pad_token_id or processor.tokenizer.eos_token_id

    def __call__(self, features):
        max_len = max(f['input_ids'].size(0) for f in features)

        input_ids = []
        attention_mask = []
        labels = []
        pixel_values = []
        image_grid_thw = []

        for f in features:
            seq_len = f['input_ids'].size(0)
            pad_len = max_len - seq_len

            input_ids.append(torch.cat([
                f['input_ids'],
                torch.full((pad_len,), self.pad_token_id, dtype=f['input_ids'].dtype)
            ]))
            attention_mask.append(torch.cat([
                f['attention_mask'],
                torch.zeros(pad_len, dtype=f['attention_mask'].dtype)
            ]))
            labels.append(torch.cat([
                f['labels'],
                torch.full((pad_len,), -100, dtype=f['labels'].dtype)
            ]))

            if f.get('pixel_values') is not None:
                pixel_values.append(f['pixel_values'])
            if f.get('image_grid_thw') is not None:
                image_grid_thw.append(f['image_grid_thw'])

        batch = {
            'input_ids': torch.stack(input_ids),
            'attention_mask': torch.stack(attention_mask),
            'labels': torch.stack(labels),
        }

        if pixel_values:
            batch['pixel_values'] = torch.cat(pixel_values, dim=0)
        if image_grid_thw:
            batch['image_grid_thw'] = torch.stack(image_grid_thw)

        return batch


def main():
    print(f"Loading model: {MODEL_NAME}")
    print(f"Data: {DATA_PATH}")
    print(f"Output: {OUTPUT_DIR}")
    print(f"Full FT (no LoRA), Freeze Vision: {FREEZE_VISION}")
    print(f"Epochs: {NUM_EPOCHS}, LR: {LEARNING_RATE}, Batch: {PER_DEVICE_BATCH_SIZE} x {GRAD_ACCUM_STEPS}")

    # Load processor (higher resolution matching 7B reference)
    processor = AutoProcessor.from_pretrained(
        MODEL_NAME,
        min_pixels=3136,      # 56x56
        max_pixels=1204224,   # ~1100x1100, matching reference MAX_PIXELS
    )

    # Load model
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    # Vision encoder: NOT frozen (matching reference freeze_aligner=false)
    if FREEZE_VISION:
        for name, param in model.named_parameters():
            if 'visual' in name:
                param.requires_grad = False
        print("Froze vision encoder parameters")
    else:
        print("Vision encoder is trainable (matching 7B reference)")

    # Full fine-tuning: enable input grads for gradient checkpointing
    model.enable_input_require_grads()

    # Create dataset
    dataset = PhysicsCoTDataset(data_path=DATA_PATH, processor=processor, max_length=MAX_LENGTH)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
        gradient_accumulation_steps=GRAD_ACCUM_STEPS,
        learning_rate=LEARNING_RATE,
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,          # Matching reference
        weight_decay=0.01,
        bf16=True,
        logging_steps=10,
        save_strategy="steps",
        save_steps=20,              # Matching reference
        save_total_limit=2,         # Matching reference
        eval_steps=20,              # Matching reference
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={'use_reentrant': False},
        remove_unused_columns=False,
        report_to="none",
        deepspeed="ds_zero2.json",  # DeepSpeed ZeRO-2 for full FT
        save_only_model=True,       # Matching reference
    )
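
    # Note: "ds_zero2.json" is referenced above but is not included in this commit.
    # A minimal ZeRO-2 config that defers its settings to the Trainer could look
    # like the following sketch (an assumption, not the author's actual file):
    # {
    #   "train_micro_batch_size_per_gpu": "auto",
    #   "gradient_accumulation_steps": "auto",
    #   "gradient_clipping": "auto",
    #   "bf16": {"enabled": "auto"},
    #   "zero_optimization": {"stage": 2, "overlap_comm": true, "contiguous_gradients": true}
    # }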

    # Collator
    collator = VLMDataCollator(processor)

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=collator,
    )

    # Train
    print("\n===== Starting SFT Training (Full FT, aligned with 7B reference) =====")
    trainer.train()

    # Save final model
    print("\n===== Saving final model =====")
    trainer.save_model(os.path.join(OUTPUT_DIR, "final"))
    processor.save_pretrained(os.path.join(OUTPUT_DIR, "final"))
    print(f"Final model saved to: {os.path.join(OUTPUT_DIR, 'final')}")

    print("\n===== SFT Training Complete =====")


if __name__ == "__main__":
    main()