mjschock's picture
Refactor SFTTrainer configuration in train.py to remove data_collator from the SFT config, preventing duplication and enhancing clarity in trainer setup.
b21080c unverified
raw
history blame
8.59 kB
#!/usr/bin/env python3
"""
Fine-tuning script for SmolLM2-135M model using Unsloth.
This script demonstrates how to:
1. Install and configure Unsloth
2. Prepare and format training data
3. Configure and run the training process
4. Save and evaluate the model
To run this script:
1. Install dependencies: pip install -r requirements.txt
2. Run: python train.py
"""
import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Union
import hydra
from omegaconf import DictConfig, OmegaConf
# isort: off
from unsloth import FastLanguageModel, is_bfloat16_supported # noqa: E402
from unsloth.chat_templates import get_chat_template # noqa: E402
# isort: on
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
IterableDatasetDict,
load_dataset,
)
from transformers import (
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from trl import SFTTrainer
# Setup logging
def setup_logging():
    """Route log records to a timestamped file under ./logs and to the console.

    The ``logs`` directory is created relative to the current working
    directory if it is missing. Root logging is configured at INFO level via
    ``logging.basicConfig``.

    Returns:
        The logger named after this module.
    """
    logs_path = Path("logs")
    logs_path.mkdir(exist_ok=True)  # reuse the directory across runs

    # One file per run, keyed by wall-clock time down to the second.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_path = logs_path / f"training_{stamp}.log"

    handlers = [logging.FileHandler(log_path), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=handlers,
    )

    module_logger = logging.getLogger(__name__)
    module_logger.info(f"Logging initialized. Log file: {log_path}")
    return module_logger
# Module-level logger used by every function below; created at import time,
# so the log file/handlers are set up before main() is invoked.
logger = setup_logging()
def install_dependencies():
    """Install the Unsloth training stack into the running interpreter's environment.

    pip is invoked as ``sys.executable -m pip`` so packages are installed for
    the same Python that is executing this script; a bare ``pip`` on PATH may
    belong to a different interpreter. Arguments are passed as a list, so no
    shell is involved.

    Raises:
        subprocess.CalledProcessError: If a pip invocation exits non-zero.
            The failure is logged before re-raising instead of being reported
            as a successful install.
        OSError: If the interpreter/pip process cannot be started.
    """
    import subprocess
    import sys

    pip = [sys.executable, "-m", "pip", "install"]
    commands = [
        pip + ["unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"],
        pip + ["--no-deps", "xformers", "trl", "peft", "accelerate", "bitsandbytes"],
    ]

    logger.info("Installing dependencies...")
    try:
        for command in commands:
            # check=True turns a non-zero pip exit status into an exception.
            subprocess.run(command, check=True)
        logger.info("Dependencies installed successfully")
    except (subprocess.CalledProcessError, OSError) as e:
        logger.error(f"Error installing dependencies: {e}")
        raise
def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
    """Load the base model described by ``cfg.model`` and wrap it with LoRA adapters.

    Args:
        cfg: Hydra config. The ``model`` section supplies the checkpoint name,
            sequence length, dtype and 4-bit flag; the ``peft`` section supplies
            every ``FastLanguageModel.get_peft_model`` keyword used below.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the LoRA-wrapped model.

    Raises:
        Exception: Any error from loading or LoRA setup is logged and re-raised.
    """
    logger.info("Loading model and tokenizer...")
    try:
        model_cfg = cfg.model
        base_model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=model_cfg.name,
            max_seq_length=model_cfg.max_seq_length,
            dtype=model_cfg.dtype,
            load_in_4bit=model_cfg.load_in_4bit,
        )
        logger.info("Base model loaded successfully")

        # LoRA hyper-parameters are taken verbatim from the `peft` config section.
        peft_cfg = cfg.peft
        lora_kwargs = dict(
            r=peft_cfg.r,
            target_modules=peft_cfg.target_modules,
            lora_alpha=peft_cfg.lora_alpha,
            lora_dropout=peft_cfg.lora_dropout,
            bias=peft_cfg.bias,
            use_gradient_checkpointing=peft_cfg.use_gradient_checkpointing,
            random_state=peft_cfg.random_state,
            use_rslora=peft_cfg.use_rslora,
            loftq_config=peft_cfg.loftq_config,
        )
        peft_model = FastLanguageModel.get_peft_model(base_model, **lora_kwargs)
        logger.info("LoRA configuration applied successfully")
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise
    return peft_model, tokenizer
def load_and_format_dataset(
    tokenizer: AutoTokenizer,
    cfg: DictConfig,
) -> tuple[
    Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
]:
    """Load the code-act dataset, split it, and render each conversation to text.

    The ``codeact`` split of ``xingyaoww/code-act`` is divided into train and
    validation parts using ``cfg.dataset.validation_split`` and
    ``cfg.dataset.seed``. The tokenizer is given a ChatML chat template with a
    ShareGPT-style field mapping, and each example's ``conversations`` are
    rendered with it into a new ``text`` column.

    Args:
        tokenizer: Tokenizer returned by ``load_model``.
        cfg: Hydra config; only the ``dataset`` section is read.

    Returns:
        ``(dataset, tokenizer)``: a ``DatasetDict`` with ``train`` and
        ``validation`` splits, plus the chat-template-configured tokenizer.

    Raises:
        Exception: Any loading/formatting error is logged and re-raised.
    """
    logger.info("Loading and formatting dataset...")
    try:
        raw = load_dataset("xingyaoww/code-act", split="codeact")
        logger.info(f"Dataset loaded successfully. Size: {len(raw)} examples")

        splits = raw.train_test_split(
            test_size=cfg.dataset.validation_split,
            seed=cfg.dataset.seed,
        )
        train_split = splits["train"]
        eval_split = splits["test"]
        logger.info(
            f"Dataset split into train ({len(train_split)} examples) and validation ({len(eval_split)} examples) sets"
        )

        # ShareGPT-style keys ("from"/"value", "human"/"gpt") mapped onto chat roles.
        sharegpt_mapping = {
            "role": "from",
            "content": "value",
            "user": "human",
            "assistant": "gpt",
        }
        tokenizer = get_chat_template(
            tokenizer,
            # Other options: zephyr, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
            chat_template="chatml",
            mapping=sharegpt_mapping,
            map_eos_token=True,  # Maps <|im_end|> to </s> instead
        )
        logger.info("Chat template configured successfully")

        def render_conversations(batch):
            # Batched map: turn every conversation in the batch into one string.
            rendered = []
            for conversation in batch["conversations"]:
                rendered.append(
                    tokenizer.apply_chat_template(
                        conversation, tokenize=False, add_generation_prompt=False
                    )
                )
            return {"text": rendered}

        formatted = DatasetDict(
            {
                "train": train_split.map(render_conversations, batched=True),
                "validation": eval_split.map(render_conversations, batched=True),
            }
        )
        logger.info("Dataset formatting completed successfully")
    except Exception as e:
        logger.error(f"Error loading/formatting dataset: {e}")
        raise
    return formatted, tokenizer
def create_trainer(
    model: FastLanguageModel,
    tokenizer: AutoTokenizer,
    dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
    cfg: DictConfig,
) -> Trainer:
    """Assemble an ``SFTTrainer`` from the Hydra config.

    * ``cfg.training.args`` becomes ``TrainingArguments``; ``fp16``/``bf16``
      are always overwritten from ``is_bfloat16_supported()``.
    * ``cfg.training.sft.data_collator`` supplies keyword arguments for a
      ``DataCollatorForLanguageModeling``.
    * Every other key of ``cfg.training.sft`` is forwarded to ``SFTTrainer``.

    Args:
        model: Model to fine-tune.
        tokenizer: Tokenizer used by both the collator and the trainer.
        dataset: Mapping with ``train`` and ``validation`` splits.
        cfg: Hydra config; the ``training`` section is read.

    Returns:
        The configured ``SFTTrainer``.

    Raises:
        Exception: Any construction error is logged and re-raised.
    """
    logger.info("Creating trainer...")
    try:
        args_kwargs = OmegaConf.to_container(cfg.training.args, resolve=True)
        # Precision follows the hardware, overriding any value from the config.
        use_bf16 = is_bfloat16_supported()
        args_kwargs["fp16"] = not use_bf16
        args_kwargs["bf16"] = use_bf16
        training_args = TrainingArguments(**args_kwargs)

        collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            **cfg.training.sft.data_collator,
        )

        # The collator is passed explicitly, so drop its config entry to avoid
        # forwarding a second `data_collator` keyword to SFTTrainer.
        trainer_kwargs = OmegaConf.to_container(cfg.training.sft, resolve=True)
        trainer_kwargs.pop("data_collator", None)

        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
            args=training_args,
            data_collator=collator,
            **trainer_kwargs,
        )
        logger.info("Trainer created successfully")
    except Exception as e:
        logger.error(f"Error creating trainer: {e}")
        raise
    return trainer
def _last_logged_value(log_history, key, default="N/A"):
    """Return the most recent value recorded under ``key`` in a trainer log history.

    ``Trainer.state.log_history`` is a list of dicts in logging order. Its
    final entry after ``train()`` is the end-of-training summary, which holds
    ``train_loss`` but not the per-step ``loss`` or ``eval_loss``. So the
    history is scanned newest-first for the first entry containing ``key``.

    Args:
        log_history: Sequence of metric dicts (may be empty).
        key: Metric name to look up.
        default: Value returned when no entry contains ``key``.
    """
    for entry in reversed(log_history):
        if key in entry:
            return entry[key]
    return default


@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Run the full pipeline: install deps, load model/data, build trainer, train.

    Args:
        cfg: Hydra config loaded from ``conf/config``; ``cfg.train`` decides
            whether training runs and ``cfg.output.dir`` is where the model is saved.

    Raises:
        Exception: Any failure is logged (with traceback) and re-raised.
    """
    try:
        logger.info("Starting training process...")
        logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

        install_dependencies()
        model, tokenizer = load_model(cfg)
        dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)
        trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)

        if cfg.train:
            logger.info("Starting training...")
            trainer.train()

            logger.info(f"Saving final model to {cfg.output.dir}...")
            trainer.save_model(cfg.output.dir)

            history = trainer.state.log_history
            # Prefer the last per-step loss; fall back to the run's average
            # train_loss from the summary entry if no step loss was logged.
            final_train_loss = _last_logged_value(
                history, "loss", _last_logged_value(history, "train_loss")
            )
            final_eval_loss = _last_logged_value(history, "eval_loss")

            logger.info("\nTraining completed!")
            logger.info(f"Final training loss: {final_train_loss}")
            logger.info(f"Final validation loss: {final_eval_loss}")
        else:
            logger.info("Training skipped as train=False")
    except Exception as e:
        # logger.exception keeps the traceback in the log file, not just stderr.
        logger.exception(f"Error in main training process: {e}")
        raise
if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator builds `cfg` from
    # conf/config and any command-line overrides.
    main()