#!/usr/bin/env python3
"""
Fine-tuning script for the SmolLM2-135M model using Unsloth.

This script demonstrates how to:
1. Install and configure Unsloth
2. Prepare and format training data
3. Configure and run the training process
4. Save and evaluate the model

To run this script:
1. Install dependencies: pip install -r requirements.txt
2. Run: python train.py
"""

import logging
import os
from datetime import datetime
from pathlib import Path
from typing import Union

import hydra
from omegaconf import DictConfig, OmegaConf

# isort: off
# Unsloth is imported before transformers/trl so its patches are applied first.
from unsloth import FastLanguageModel, is_bfloat16_supported  # noqa: E402
from unsloth.chat_templates import get_chat_template  # noqa: E402

# isort: on
from datasets import (
    Dataset,
    DatasetDict,
    IterableDataset,
    IterableDatasetDict,
    load_dataset,
)
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from trl import SFTTrainer


# Setup logging
def setup_logging():
    """Configure logging for the training process."""
    # Create the logs directory if it doesn't exist
    log_dir = Path("logs")
    log_dir.mkdir(exist_ok=True)

    # Create a unique log file name with a timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = log_dir / f"training_{timestamp}.log"

    # Log to both the file and the console
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )
    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized. Log file: {log_file}")
    return logger


logger = setup_logging()


def install_dependencies():
    """Install required dependencies."""
    logger.info("Installing dependencies...")
    try:
        # os.system does not raise on failure, so check the exit codes explicitly.
        commands = [
            'pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"',
            "pip install --no-deps xformers trl peft accelerate bitsandbytes",
        ]
        for command in commands:
            if os.system(command) != 0:
                raise RuntimeError(f"Command failed: {command}")
        logger.info("Dependencies installed successfully")
    except Exception as e:
        logger.error(f"Error installing dependencies: {e}")
        raise


def load_model(cfg: DictConfig) -> tuple[FastLanguageModel, AutoTokenizer]:
    """Load and configure the model."""
    logger.info("Loading model and tokenizer...")
    try:
        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=cfg.model.name,
            max_seq_length=cfg.model.max_seq_length,
            dtype=cfg.model.dtype,
            load_in_4bit=cfg.model.load_in_4bit,
        )
        logger.info("Base model loaded successfully")

        # Configure LoRA adapters
        model = FastLanguageModel.get_peft_model(
            model,
            r=cfg.peft.r,
            target_modules=cfg.peft.target_modules,
            lora_alpha=cfg.peft.lora_alpha,
            lora_dropout=cfg.peft.lora_dropout,
            bias=cfg.peft.bias,
            use_gradient_checkpointing=cfg.peft.use_gradient_checkpointing,
            random_state=cfg.peft.random_state,
            use_rslora=cfg.peft.use_rslora,
            loftq_config=cfg.peft.loftq_config,
        )
        logger.info("LoRA configuration applied successfully")
        return model, tokenizer
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        raise


def load_and_format_dataset(
    tokenizer: AutoTokenizer,
    cfg: DictConfig,
) -> tuple[
    Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset], AutoTokenizer
]:
    """Load and format the training dataset."""
    logger.info("Loading and formatting dataset...")
    try:
        # Load the code-act dataset
        dataset = load_dataset("xingyaoww/code-act", split="codeact")
        logger.info(f"Dataset loaded successfully. Size: {len(dataset)} examples")

        # Split into train and validation sets
        dataset = dataset.train_test_split(
            test_size=cfg.dataset.validation_split, seed=cfg.dataset.seed
        )
        logger.info(
            f"Dataset split into train ({len(dataset['train'])} examples) "
            f"and validation ({len(dataset['test'])} examples) sets"
        )

        # Configure the chat template
        tokenizer = get_chat_template(
            tokenizer,
            chat_template="chatml",  # Supports zephyr, chatml, mistral, llama, alpaca, vicuna, vicuna_old, unsloth
            mapping={
                "role": "from",
                "content": "value",
                "user": "human",
                "assistant": "gpt",
            },  # ShareGPT style
            map_eos_token=True,  # Maps <|im_end|> to </s> instead
        )
        logger.info("Chat template configured successfully")

        def formatting_prompts_func(examples):
            convos = examples["conversations"]
            texts = [
                tokenizer.apply_chat_template(
                    convo, tokenize=False, add_generation_prompt=False
                )
                for convo in convos
            ]
            return {"text": texts}

        # Apply formatting to both the train and validation splits
        dataset = DatasetDict(
            {
                "train": dataset["train"].map(formatting_prompts_func, batched=True),
                "validation": dataset["test"].map(
                    formatting_prompts_func, batched=True
                ),
            }
        )
        logger.info("Dataset formatting completed successfully")
        return dataset, tokenizer
    except Exception as e:
        logger.error(f"Error loading/formatting dataset: {e}")
        raise


def create_trainer(
    model: FastLanguageModel,
    tokenizer: AutoTokenizer,
    dataset: Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset],
    cfg: DictConfig,
) -> Trainer:
    """Create and configure the SFTTrainer."""
    logger.info("Creating trainer...")
    try:
        # Create TrainingArguments from the config
        training_args_dict = OmegaConf.to_container(cfg.training.args, resolve=True)

        # Add dynamic precision settings
        training_args_dict.update(
            {
                "fp16": not is_bfloat16_supported(),
                "bf16": is_bfloat16_supported(),
            }
        )
        training_args = TrainingArguments(**training_args_dict)

        # Create the data collator from the config
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer,
            **cfg.training.sft.data_collator,
        )

        # Build the SFT kwargs without data_collator to avoid passing it twice
        sft_config = OmegaConf.to_container(cfg.training.sft, resolve=True)
        sft_config.pop("data_collator", None)

        trainer = SFTTrainer(
            model=model,
            tokenizer=tokenizer,
            train_dataset=dataset["train"],
            eval_dataset=dataset["validation"],
            args=training_args,
            data_collator=data_collator,
            **sft_config,
        )
        logger.info("Trainer created successfully")
        return trainer
    except Exception as e:
        logger.error(f"Error creating trainer: {e}")
        raise


# The config_path/config_name below are assumptions; point them at your Hydra config.
@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Main training function."""
    try:
        logger.info("Starting training process...")
        logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

        # Install dependencies
        install_dependencies()

        # Load the model and tokenizer
        model, tokenizer = load_model(cfg)

        # Load and prepare the dataset
        dataset, tokenizer = load_and_format_dataset(tokenizer, cfg)

        # Create the trainer
        trainer: Trainer = create_trainer(model, tokenizer, dataset, cfg)

        # Train if requested
        if cfg.train:
            logger.info("Starting training...")
            trainer.train()

            # Save the final model
            logger.info(f"Saving final model to {cfg.output.dir}...")
            trainer.save_model(cfg.output.dir)

            # Log final metrics
            final_metrics = trainer.state.log_history[-1]
            logger.info("\nTraining completed!")
            logger.info(f"Final training loss: {final_metrics.get('loss', 'N/A')}")
            logger.info(f"Final validation loss: {final_metrics.get('eval_loss', 'N/A')}")
        else:
            logger.info("Training skipped as train=False")
    except Exception as e:
        logger.error(f"Error in main training process: {e}")
        raise


if __name__ == "__main__":
    main()
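
# Example invocations (assuming the Hydra config sketched near the top of this file;
# adjust key names to match your actual config):
#
#   python train.py
#   python train.py train=false                    # build everything but skip training
#   python train.py training.args.num_train_epochs=3 model.load_in_4bit=false
#
# Illustrative follow-up (assumed Unsloth usage, not part of this script): reload the
# adapter saved to cfg.output.dir for inference.
#
#   from unsloth import FastLanguageModel
#   model, tokenizer = FastLanguageModel.from_pretrained("<cfg.output.dir>")
#   FastLanguageModel.for_inference(model)  # switch Unsloth to its inference mode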