# /// script
# requires-python = ">=3.10"
# dependencies = [
# "transformers>=4.45.0",
# "trl>=0.12.0",
# "peft>=0.13.0",
# "datasets>=3.0.0",
# "accelerate>=1.0.0",
# "huggingface_hub>=0.26.0",
# "torch>=2.4.0",
# ]
# [tool.uv]
# index-strategy = "unsafe-best-match"
# extra-index-url = ["https://download.pytorch.org/whl/cu124"]
# ///
"""
SFT Multi-Task Training Script for n8n Agent
This script fine-tunes the DPO-trained Qwen3-0.6B model on multi-task n8n workflows.
It builds on the reasoning capabilities from DPO training and adds task-specific skills.
Tasks covered:
- generate: Create workflows from descriptions
- edit: Modify existing workflows
- fix: Correct errors in workflows
- explain: Explain what workflows do
- debug: Diagnose execution issues
- improve: Optimize and enhance workflows
Usage:
hf jobs uv run \
--script train_qwen3_sft_multitask.py \
--flavor l4x1 \
--timeout 24h
"""
import os
import json
import torch
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, PeftModel
from trl import SFTTrainer, SFTConfig
from huggingface_hub import login, hf_hub_download
# ============================================================================
# CONFIGURATION
# ============================================================================
# Base model (the DPO-trained model with reasoning capabilities)
BASE_MODEL = os.environ.get("BASE_MODEL", "Qwen/Qwen3-0.6B")
DPO_ADAPTER = os.environ.get("DPO_ADAPTER", "stmasson/qwen3-0.6b-n8n-reasoning")
# Dataset
DATASET_REPO = "stmasson/n8n-agentic-multitask"
TRAIN_FILE = "data/multitask_large/train.jsonl"
VAL_FILE = "data/multitask_large/val.jsonl"
# Output
OUTPUT_DIR = "./qwen3-sft-multitask"
HF_REPO = os.environ.get("HF_REPO", "stmasson/qwen3-0.6b-n8n-agent")
# Hyperparameters
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", "1"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "1"))
GRAD_ACCUM = int(os.environ.get("GRAD_ACCUM", "8"))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", "1e-5"))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", "4096"))
# LoRA (continuing from DPO adapter)
LORA_R = int(os.environ.get("LORA_R", "32"))
LORA_ALPHA = int(os.environ.get("LORA_ALPHA", "64"))
LORA_DROPOUT = float(os.environ.get("LORA_DROPOUT", "0.05"))
# ============================================================================
# AUTHENTICATION
# ============================================================================
print("=" * 60)
print("SFT MULTI-TASK TRAINING - N8N AGENT")
print("=" * 60)
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("Authenticated with HuggingFace")
else:
    print("Warning: HF_TOKEN not set, push disabled")
# ============================================================================
# LOAD MODEL WITH DPO ADAPTER
# ============================================================================
print(f"\nLoading base model: {BASE_MODEL}")
print(f"Loading DPO adapter: {DPO_ADAPTER}")
# Load base model
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
    device_map="auto",
    trust_remote_code=True,
)
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
# Load DPO adapter and merge it into the base model
print("Loading and merging DPO adapter...")
model = PeftModel.from_pretrained(model, DPO_ADAPTER)
model = model.merge_and_unload()
print("DPO adapter merged successfully!")
print(f"Model loaded: {model.config.num_hidden_layers} layers, {model.config.hidden_size} hidden size")
# ============================================================================
# NEW LORA CONFIG FOR SFT
# ============================================================================
print(f"\nNew LoRA config for SFT: r={LORA_R}, alpha={LORA_ALPHA}")
lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)
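# With the default alpha/r = 64/32, LoRA updates are scaled by a factor of 2;
# targeting all attention and MLP projections adapts every linear layer in
# each decoder block.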
# ============================================================================
# LOAD DATASET
# ============================================================================
print(f"\nLoading dataset: {DATASET_REPO}")
def load_jsonl_dataset(repo_id: str, filename: str) -> Dataset:
    """Load JSONL dataset and extract only the messages column."""
    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",
    )
    messages_list = []
    with open(local_path, "r", encoding="utf-8") as f:
        for line in f:
            data = json.loads(line)
            messages_list.append({"messages": data["messages"]})
    return Dataset.from_list(messages_list)
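# Each JSONL line is expected to carry a chat transcript under "messages",
# e.g. (an assumed schema, inferred from the field accessed above):
# {"messages": [{"role": "user", "content": "Create a workflow that ..."},
#               {"role": "assistant", "content": "{...workflow JSON...}"}]}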
# Load train and validation
train_dataset = load_jsonl_dataset(DATASET_REPO, TRAIN_FILE)
val_dataset = load_jsonl_dataset(DATASET_REPO, VAL_FILE)
print(f"Train: {len(train_dataset)} examples")
print(f"Validation: {len(val_dataset)} examples")
# Filter out very long examples to avoid OOM
def filter_by_length(example):
    """Filter out examples that would be too long."""
    total_len = sum(len(m.get("content", "")) for m in example["messages"])
    return total_len < 30000  # ~7500 tokens, assuming roughly 4 characters per token
print("Filtering long examples...")
train_dataset = train_dataset.filter(filter_by_length)
val_dataset = val_dataset.filter(filter_by_length)
print(f"After filtering - Train: {len(train_dataset)}, Val: {len(val_dataset)}")
# Format examples
def format_example(example):
    """Format messages to text for training."""
    messages = example["messages"]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return {"text": text}
print("Formatting data...")
train_dataset = train_dataset.map(format_example, remove_columns=train_dataset.column_names)
val_dataset = val_dataset.map(format_example, remove_columns=val_dataset.column_names)
# Show example
print("\nExample formatted data:")
print(train_dataset[0]["text"][:500] + "...")
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================
print(f"\nTraining configuration:")
print(f" - Epochs: {NUM_EPOCHS}")
print(f" - Batch size: {BATCH_SIZE}")
print(f" - Gradient accumulation: {GRAD_ACCUM}")
print(f" - Effective batch size: {BATCH_SIZE * GRAD_ACCUM}")
print(f" - Learning rate: {LEARNING_RATE}")
print(f" - Max sequence length: {MAX_SEQ_LENGTH}")
training_args = SFTConfig(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=LEARNING_RATE,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    weight_decay=0.01,
    bf16=True,
    tf32=True,
    logging_steps=50,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=3,
    eval_strategy="steps",
    eval_steps=1000,
    max_seq_length=MAX_SEQ_LENGTH,
    packing=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_text_field="text",
    report_to="none",
    run_name="qwen3-sft-multitask",
    hub_model_id=HF_REPO if hf_token else None,
    push_to_hub=bool(hf_token),
    hub_strategy="checkpoint",
)
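# Note: newer TRL releases rename SFTConfig's `max_seq_length` to `max_length`;
# the `trl>=0.12.0` pin above assumes the older name. If the job resolves a
# newer TRL, swap the argument name accordingly.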
# ============================================================================
# TRAINING
# ============================================================================
print("\nInitializing SFT trainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=lora_config,
    processing_class=tokenizer,
)
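# Passing `peft_config` lets SFTTrainer wrap the merged model with fresh LoRA
# adapters itself, so `get_peft_model` never needs to be called here; only the
# injected adapter weights remain trainable.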
# Show trainable parameters
# Count parameters on trainer.model, where the LoRA adapters were injected
trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in trainer.model.parameters())
print(f"\nTrainable parameters: {trainable_params:,} / {total_params:,} ({100 * trainable_params / total_params:.2f}%)")
print("\n" + "=" * 60)
print("STARTING SFT MULTI-TASK TRAINING")
print("=" * 60)
trainer.train()
# ============================================================================
# SAVE MODEL
# ============================================================================
print("\nSaving model...")
trainer.save_model(f"{OUTPUT_DIR}/final")
if hf_token:
    print(f"Pushing to {HF_REPO}...")
    trainer.push_to_hub()
    print(f"Model available at: https://huggingface.co/{HF_REPO}")
print("\n" + "=" * 60)
print("SFT MULTI-TASK TRAINING COMPLETE")
print("=" * 60)