# ffmpeg-command-generator/train_ffmpeg.py
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers", "datasets", "accelerate", "bitsandbytes"]
# ///
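# The block above is PEP 723 inline script metadata, so the script can be run
# with a PEP 723-aware runner, e.g. `uv run train_ffmpeg.py` (assumes uv is
# installed; any compliant runner works).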
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio  # noqa: F401 - imported early so a missing trackio install fails fast before training
# Load the dataset
dataset = load_dataset("kingjux/ffmpeg-commands-cot", split="train")
print(f"Loaded {len(dataset)} training examples")
# LoRA config for efficient fine-tuning
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
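# Rank 16 with alpha 32 gives an effective LoRA scaling of alpha/r = 2, and
# targeting every attention and MLP projection adapts the full transformer block
# while keeping the trainable parameter count small.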
# Training config
training_args = SFTConfig(
    output_dir="ffmpeg-command-generator",
    # Training params
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    warmup_ratio=0.1,
    # Logging and saving
    logging_steps=5,
    save_strategy="epoch",
    # Hub settings
    push_to_hub=True,
    hub_model_id="kingjux/ffmpeg-command-generator",
    hub_strategy="every_save",
    # Trackio monitoring
    report_to="trackio",
    run_name="ffmpeg-sft-30examples",
    # Memory optimization
    gradient_checkpointing=True,
    bf16=True,
    # Other
    seed=42,
    max_length=1024,
)
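# Effective batch size is per_device_train_batch_size * gradient_accumulation_steps
# = 2 * 4 = 8 sequences per optimizer step (per device).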
# Create trainer
trainer = SFTTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
    peft_config=peft_config,
    args=training_args,
)
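# Passing a model id string lets SFTTrainer load the base model and tokenizer
# itself; with peft_config set, only the LoRA adapter weights are trained.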
# Train
print("Starting training...")
trainer.train()
# Save and push
print("Pushing to Hub...")
trainer.save_model()
trainer.push_to_hub()
print("Training complete!")