stmasson committed on
Commit
334164b
·
verified ·
1 Parent(s): e2a4a21

Upload scripts/train_sft_n8n_multitask.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/train_sft_n8n_multitask.py +161 -0
scripts/train_sft_n8n_multitask.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "transformers>=4.46.0",
6
+ # "accelerate>=0.24.0",
7
+ # "peft>=0.7.0",
8
+ # "trackio",
9
+ # "bitsandbytes",
10
+ # "sentencepiece",
11
+ # "protobuf",
12
+ # ]
13
+ # ///
14
+
15
+ """
16
+ SFT training for n8n agentic multi-task workflows.
17
+
18
+ Continues fine-tuning from stmasson/mistral-7b-n8n-thinking-orpo (ORPO-trained model)
19
+ on the n8n-agentic-multitask dataset for complex multi-step tasks:
20
+ - generate: Create n8n workflows from descriptions
21
+ - edit: Modify existing workflows
22
+ - fix: Repair broken workflows
23
+ - improve: Optimize and enhance workflows
24
+ - explain: Describe what workflows do
25
+ - debug: Diagnose workflow issues
26
+
27
+ The model learns to use <thinking> tags for chain-of-thought reasoning
28
+ before producing structured JSON outputs.
29
+ """
30
+
31
+ import trackio
32
+ import torch
33
+ from datasets import load_dataset
34
+ from peft import LoraConfig, PeftModel
35
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
36
+ from trl import SFTTrainer, SFTConfig
37
+
38
+
39
# Fetch the multi-task train/eval data from the Hub. Both files are loaded
# through `data_files`, and datasets exposes ad-hoc data_files under the
# split name "train" — hence split="train" even for the validation file.
print("Loading n8n-agentic-multitask dataset...")

_DATASET_REPO = "stmasson/n8n-agentic-multitask"


def _load_jsonl_split(jsonl_path):
    # One JSONL file -> one Dataset; the split label is always "train".
    return load_dataset(_DATASET_REPO, data_files=jsonl_path, split="train")


train_dataset = _load_jsonl_split("data/multitask_large/train.jsonl")
eval_dataset = _load_jsonl_split("data/multitask_large/val.jsonl")

print(f"Train: {len(train_dataset)} examples")
print(f"Eval: {len(eval_dataset)} examples")
54
+
55
# Checkpoint identifiers: MODEL_NAME is the ORPO LoRA adapter repo,
# BASE_MODEL the full-precision model that adapter was trained on top of.
MODEL_NAME = "stmasson/mistral-7b-n8n-thinking-orpo"
BASE_MODEL = "stmasson/mistral-7b-n8n-workflows"

print(f"Loading tokenizer from {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Mistral-style tokenizers ship without a pad token; reuse EOS so
    # padded batches can be built during training.
    tokenizer.pad_token = tokenizer.eos_token
63
+
64
# --- Step 1: fold the ORPO adapter back into its base model ---------------
# Merging requires an unquantized model, so the base is loaded in plain
# bf16 (no BitsAndBytes) before attaching the adapter.
print(f"Loading base model {BASE_MODEL} (full precision for merge)...")
foundation = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    attn_implementation="sdpa",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

print(f"Loading ORPO adapter from {MODEL_NAME}...")
model = PeftModel.from_pretrained(foundation, MODEL_NAME)

print("Merging ORPO adapter into base model...")
model = model.merge_and_unload()
print("ORPO adapter merged successfully!")

# --- Step 2: make the merged model trainable under checkpointing ----------
# enable_input_require_grads() makes the embedding outputs require grad so
# gradients can reach the fresh LoRA adapters while activation
# checkpointing is active.
model.gradient_checkpointing_enable()
model.enable_input_require_grads()
83
+
84
# Fresh LoRA adapter for this SFT stage: rank-32 adapters on every
# attention and MLP projection of the (now merged) Mistral model.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=32,
    lora_alpha=64,  # alpha = 2 * r
    lora_dropout=0.05,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)
92
+
93
# SFT hyper-parameters. Values are chosen for *continued* fine-tuning of an
# already-trained model: a single epoch, a conservative LR, long context.
config = SFTConfig(
    output_dir="mistral-7b-n8n-agentic-multitask",

    # --- schedule & optimization ---
    num_train_epochs=1,              # large dataset; one pass suffices
    per_device_train_batch_size=1,
    gradient_accumulation_steps=32,  # effective batch size of 32
    learning_rate=2e-5,              # low LR for continued fine-tuning
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    max_length=4096,                 # room for complex workflow JSON

    # --- memory ---
    gradient_checkpointing=True,
    bf16=True,

    # --- logging, checkpointing & evaluation ---
    logging_steps=25,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=500,

    # --- Hub publishing ---
    push_to_hub=True,
    hub_model_id="stmasson/mistral-7b-n8n-agentic-multitask",
    hub_strategy="every_save",
    hub_private_repo=False,

    # --- monitoring (Trackio) ---
    report_to="trackio",
    project="n8n-agentic-training",
    run_name="mistral-7b-multitask-sft",
)
133
+
134
# Build the trainer: SFTTrainer wraps the merged model with the new LoRA
# adapters declared in `lora_config` and trains only those weights.
print("Initializing SFT trainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=lora_config,
    args=config,
)

print("Starting SFT training...")
print(" Base: stmasson/mistral-7b-n8n-thinking-orpo (merged)")
print(" Dataset: stmasson/n8n-agentic-multitask")
print(" Output: stmasson/mistral-7b-n8n-agentic-multitask")
print(" Tasks: generate, edit, fix, improve, explain, debug")

try:
    trainer.train()

    print("Pushing final model to Hub...")
    trainer.push_to_hub()
finally:
    # Close the Trackio run even if training or the final push raises, so
    # the dashboard is not left with a dangling run.
    trackio.finish()

print("Training complete!")
print("Model: https://huggingface.co/stmasson/mistral-7b-n8n-agentic-multitask")
print("Metrics: https://huggingface.co/spaces/stmasson/trackio")