SmolFactory / src /data.py
Tonic's picture
adds sft , quantization, better readmes
40fd629 verified
"""
SmolLM3 Dataset Handler
Handles data loading, preprocessing, and tokenization for SmolLM3 fine-tuning
"""
import os
import json
import torch
from typing import Dict, List, Optional, Union
from datasets import Dataset, load_dataset
from transformers import PreTrainedTokenizer
import logging
# Module-level logger for this file (named after the module).
logger = logging.getLogger(__name__)
class SmolLM3Dataset:
    """Dataset handler for SmolLM3 fine-tuning.

    Loads raw data, optionally filters and subsamples it, renders every
    example to text (through ``tokenizer.apply_chat_template`` when
    ``use_chat_template`` is True) and tokenizes it for causal-LM training.

    Args:
        data_path: A local directory containing ``train.json`` (and
            optionally ``validation.json`` / ``test.json``), a path to a
            single ``.json`` file, or a name accepted by
            ``datasets.load_dataset``.
        tokenizer: Tokenizer used for the chat template and tokenization.
        max_seq_length: Maximum tokens per example; longer texts are
            truncated.
        use_chat_template: Render conversations with the chat template
            instead of using a raw ``text`` column.
        chat_template_kwargs: Formatter options. Keys read here:
            ``no_think_system_message`` (use the ``/no_think`` system prompt)
            and ``add_generation_prompt`` (defaults to ``True``).
        filter_bad_entries: Drop rows whose ``bad_entry_field`` is truthy.
        bad_entry_field: Column consulted by ``filter_bad_entries``.
        sample_size: If set, randomly keep this many training rows.
        sample_seed: Seed for the random subsampling.
    """

    # (user_key, assistant_key) column pairs tried, in order, when an example
    # has neither a "messages" nor a "conversations" column.
    _PAIR_FIELDS = (
        ("user", "assistant"),
        ("instruction", "output"),
        ("prompt", "completion"),
        ("prompt", "accepted_completion"),
    )

    # System prompts injected when a conversation does not start with one.
    # NOTE(review): the /no_think variant has a space before the comma
    # ("TonicIA ,"); kept byte-identical to avoid changing training text.
    _SYSTEM_PROMPT = "Tu es TonicIA, un assistant francophone rigoureux et bienveillant."
    _SYSTEM_PROMPT_NO_THINK = "Tu es TonicIA , un assistant francophone rigoureux et bienveillant. /no_think"

    # Validation splits larger than this are subsampled when sample_size is set.
    _MAX_VALIDATION_SAMPLES = 1000

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizer,
        max_seq_length: int = 4096,
        use_chat_template: bool = True,
        chat_template_kwargs: Optional[Dict] = None,
        filter_bad_entries: bool = False,
        bad_entry_field: str = "bad_entry",
        sample_size: Optional[int] = None,
        sample_seed: int = 42
    ):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.use_chat_template = use_chat_template
        self.chat_template_kwargs = chat_template_kwargs or {}
        self.filter_bad_entries = filter_bad_entries
        self.bad_entry_field = bad_entry_field
        self.sample_size = sample_size
        self.sample_seed = sample_seed
        # Load and process dataset eagerly at construction time.
        self.dataset = self._load_dataset()
        self.processed_dataset = self._process_dataset()

    # ------------------------------------------------------------------ #
    # Loading
    # ------------------------------------------------------------------ #
    def _load_dataset(self) -> Union[Dataset, Dict[str, Dataset]]:
        """Load the raw dataset from ``self.data_path``.

        Resolution order:

        1. Local directory: ``train.json`` plus whichever of
           ``validation.json`` / ``test.json`` exist are loaded with the
           ``json`` builder. If that is impossible or fails, fall through
           to step 3.
        2. Single ``.json`` file: parsed with :func:`json.load` and returned
           as one :class:`Dataset` (no filtering, sampling or splitting).
        3. Anything else is passed to ``datasets.load_dataset``.

        Split-based results (steps 1 and 3) are then passed through
        :meth:`_postprocess_splits`.

        Returns:
            A dict-like mapping of split name to ``Dataset``, or a single
            ``Dataset`` for the single-file case.

        Raises:
            Exception: whatever the JSON parser or ``load_dataset`` raises
                in steps 2/3 (logged, then re-raised).
        """
        logger.info("Loading dataset from %s", self.data_path)

        if os.path.isdir(self.data_path):
            local_dataset = self._load_local_json_dir()
            if local_dataset is not None:
                return self._postprocess_splits(local_dataset)

        if os.path.isfile(self.data_path) and self.data_path.endswith('.json'):
            try:
                with open(self.data_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # A top-level list is a list of rows; otherwise column -> values.
                if isinstance(data, list):
                    dataset = Dataset.from_list(data)
                else:
                    dataset = Dataset.from_dict(data)
                logger.info("Loaded dataset from single JSON file")
                return dataset
            except Exception as e:
                logger.error("Failed to load JSON file: %s", e)
                raise

        try:
            dataset = load_dataset(self.data_path)
            logger.info("Loaded Hugging Face dataset: %s", self.data_path)
            return self._postprocess_splits(dataset)
        except Exception as e:
            logger.error("Failed to load dataset: %s", e)
            raise

    def _load_local_json_dir(self) -> Optional[Dict[str, Dataset]]:
        """Load per-split JSON files from a local directory, or return None."""
        # Only pass files that exist: None entries in data_files are invalid.
        data_files = {}
        for split in ("train", "validation", "test"):
            candidate = os.path.join(self.data_path, f"{split}.json")
            if os.path.exists(candidate):
                data_files[split] = candidate

        if "train" not in data_files:
            logger.warning("No train.json found in %s", self.data_path)
            return None

        try:
            dataset = load_dataset("json", data_files=data_files)
        except Exception as e:
            logger.warning("Failed to load as JSON dataset: %s", e)
            return None
        logger.info("Loaded dataset from local JSON files")
        return dataset

    def _postprocess_splits(self, dataset):
        """Filter, subsample and (if needed) auto-split a dict of splits."""
        if not isinstance(dataset, dict):
            return dataset
        if self.filter_bad_entries:
            self._filter_bad_entries(dataset)
        if self.sample_size is not None and "train" in dataset:
            self._sample_splits(dataset)
        return self._ensure_eval_splits(dataset)

    def _filter_bad_entries(self, dataset) -> None:
        """In place: drop rows whose bad-entry column is truthy, per split.

        Does nothing unless the "train" split exists and has the column.
        """
        field = self.bad_entry_field
        if "train" not in dataset or field not in dataset["train"].column_names:
            return
        logger.info("Filtering out bad entries using field: %s", field)
        for split in list(dataset):
            if field in dataset[split].column_names:
                original_size = len(dataset[split])
                # Close over the column name only (not self) so the filter
                # function stays small for datasets' fingerprint hashing.
                dataset[split] = dataset[split].filter(lambda x: not x[field])
                filtered_size = len(dataset[split])
                logger.info("Filtered %s: %d -> %d samples", split, original_size, filtered_size)

    def _sample_splits(self, dataset) -> None:
        """In place: randomly subsample train (and a large validation split)."""
        import random

        # A private RNG yields the same indices as seeding the global one,
        # without clobbering the process-wide random state.
        rng = random.Random(self.sample_seed)

        total_samples = len(dataset["train"])
        logger.info("Sampling %s random samples from %d total samples", self.sample_size, total_samples)
        if self.sample_size > total_samples:
            logger.warning(
                "Requested sample size (%s) is larger than dataset size (%d). Using all samples.",
                self.sample_size, total_samples,
            )
            sampled_indices = list(range(total_samples))
        else:
            sampled_indices = rng.sample(range(total_samples), self.sample_size)
        dataset["train"] = dataset["train"].select(sampled_indices)
        logger.info("Sampled %d train samples", len(dataset["train"]))

        # Also cap a large validation split (same RNG stream, after train).
        cap = self._MAX_VALIDATION_SAMPLES
        if "validation" in dataset and len(dataset["validation"]) > cap:
            val_total = len(dataset["validation"])
            logger.info("Sampling %d validation samples from %d total", cap, val_total)
            val_indices = rng.sample(range(val_total), cap)
            dataset["validation"] = dataset["validation"].select(val_indices)
            logger.info("Sampled %d validation samples", len(dataset["validation"]))

    def _ensure_eval_splits(self, dataset):
        """Carve validation/test out of train when the dataset has neither.

        Existing validation or test splits are kept as-is. Train sets too
        small to split are returned unchanged.
        """
        if "train" not in dataset or "validation" in dataset or "test" in dataset:
            return dataset

        train = dataset["train"]
        if len(train) < 2:
            logger.warning(
                "Train split has %d rows; too small to create validation/test splits", len(train)
            )
            return dataset

        logger.info("Automatically splitting train into train/validation/test (98/1/1)")
        split_dataset = train.train_test_split(test_size=0.02, seed=42)
        held_out = split_dataset["test"]
        if len(held_out) < 2:
            # A single held-out row cannot be halved; use it for validation.
            return {"train": split_dataset["train"], "validation": held_out}
        # Halve the 2% held-out portion into validation and test (1% each).
        val_test_split = held_out.train_test_split(test_size=0.5, seed=42)
        return {
            "train": split_dataset["train"],
            "validation": val_test_split["train"],
            "test": val_test_split["test"],
        }

    # ------------------------------------------------------------------ #
    # Formatting / tokenization
    # ------------------------------------------------------------------ #
    @classmethod
    def _extract_messages(cls, example) -> Optional[List[Dict]]:
        """Return a fresh list of chat messages for ``example``, or None.

        Recognises "messages" / "conversations" columns and the column
        pairs in ``_PAIR_FIELDS``. The returned list is a copy, so inserting
        a system message never mutates the source example.
        """
        for key in ("messages", "conversations"):
            if key in example:
                return list(example[key])
        for user_key, assistant_key in cls._PAIR_FIELDS:
            if user_key in example and assistant_key in example:
                return [
                    {"role": "user", "content": example[user_key]},
                    {"role": "assistant", "content": example[assistant_key]},
                ]
        return None

    def _system_prompt(self) -> str:
        """Pick the injected system prompt based on ``chat_template_kwargs``."""
        if self.chat_template_kwargs.get("no_think_system_message") is True:
            return self._SYSTEM_PROMPT_NO_THINK
        return self._SYSTEM_PROMPT

    def _process_dataset(self) -> Union[Dataset, Dict[str, Dataset]]:
        """Format and tokenize ``self.dataset`` (a single Dataset or a dict of splits).

        Each resulting row has ``input_ids``, ``attention_mask``, ``labels``
        (a copy of ``input_ids``) and ``length`` (number of non-padding tokens).
        """
        logger.info("Processing dataset for training")

        def format_chat_template(example):
            """Render one example to {"text": ...}."""
            if not self.use_chat_template:
                return {"text": example["text"] if "text" in example else str(example)}
            try:
                messages = self._extract_messages(example)
                if messages is None:
                    # Unknown schema: fall back to the example's repr.
                    return {"text": str(example)}
                if messages and messages[0]["role"] != "system":
                    messages.insert(0, {"role": "system", "content": self._system_prompt()})
                # NOTE(review): add_generation_prompt defaults to True, which appends
                # an empty assistant header after the final turn; confirm intended for SFT.
                text = self.tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=self.chat_template_kwargs.get("add_generation_prompt", True),
                )
                return {"text": text}
            except Exception as e:
                logger.warning("Failed to apply chat template: %s", e)
                return {"text": str(example)}

        def tokenize_function(examples):
            """Tokenize a batch of texts for causal-LM training."""
            # padding=True pads to the longest row within each map() batch;
            # the data collator re-pads per training batch.
            tokenized = self.tokenizer(
                examples["text"],
                truncation=True,
                padding=True,
                max_length=self.max_seq_length,
                return_overflowing_tokens=False,
            )
            input_ids = tokenized["input_ids"]
            attention_mask = tokenized["attention_mask"]
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                # Causal LM: labels mirror input_ids.
                "labels": [list(ids) for ids in input_ids],
                # Real (unpadded) token count; padded len(ids) would over-report.
                "length": [sum(mask) for mask in attention_mask],
            }

        def process(split: Dataset, label: str) -> Dataset:
            """Format then tokenize one split, dropping its original columns."""
            formatted = split.map(
                format_chat_template,
                remove_columns=split.column_names,
                desc=f"Formatting {label}",
            )
            return formatted.map(
                tokenize_function,
                remove_columns=formatted.column_names,
                desc=f"Tokenizing {label}",
                batched=True,
            )

        if isinstance(self.dataset, dict):
            processed_dataset = {}
            for split_name, split_dataset in self.dataset.items():
                logger.info("Processing %s split...", split_name)
                processed_dataset[split_name] = process(split_dataset, f"{split_name} dataset")
            if "train" in processed_dataset:
                logger.info("Dataset processed. Train samples: %d", len(processed_dataset["train"]))
            if "validation" in processed_dataset:
                logger.info("Validation samples: %d", len(processed_dataset["validation"]))
            if "test" in processed_dataset:
                logger.info("Test samples: %d", len(processed_dataset["test"]))
        else:
            processed_dataset = process(self.dataset, "dataset")
            logger.info("Dataset processed. Samples: %d", len(processed_dataset))
        return processed_dataset

    # ------------------------------------------------------------------ #
    # Accessors
    # ------------------------------------------------------------------ #
    def get_train_dataset(self) -> Dataset:
        """Return the tokenized training split.

        When the source had no splits (single JSON file), the whole processed
        dataset is returned.

        Raises:
            KeyError: if the processed splits contain no "train" entry.
        """
        if isinstance(self.processed_dataset, dict):
            return self.processed_dataset["train"]
        return self.processed_dataset

    def get_eval_dataset(self) -> Optional[Dataset]:
        """Return the validation split, else the test split, else None."""
        if not isinstance(self.processed_dataset, dict):
            return None
        if "validation" in self.processed_dataset:
            return self.processed_dataset["validation"]
        if "test" in self.processed_dataset:
            return self.processed_dataset["test"]
        return None

    def get_data_collator(self):
        """Build a causal-LM collator that also attaches token statistics.

        The returned callable wraps ``DataCollatorForLanguageModeling``
        (``mlm=False``, padding to a multiple of 8, PyTorch tensors) and adds
        ``total_tokens``, ``padding_tokens``, ``truncated_tokens``,
        ``batch_size`` and ``seq_len`` to each batch.
        """
        from transformers import DataCollatorForLanguageModeling

        base_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
            pad_to_multiple_of=8,
            return_tensors="pt",
        )
        max_len = self.max_seq_length

        def collator_with_stats(features):
            # With mlm=False the base collator rebuilds "labels" from the padded
            # input_ids, but tokenizer.pad() does not pad extra list columns, so
            # ragged "labels" rows would break tensor conversion. Drop them first.
            inputs = [
                {k: v for k, v in f.items() if k != "labels"} if isinstance(f, dict) else f
                for f in features
            ]
            batch = base_collator(inputs)

            input_ids = batch["input_ids"]
            batch_size, seq_len = input_ids.shape
            attention_mask = batch.get("attention_mask", None)
            if attention_mask is not None:
                # The mask distinguishes real EOS tokens from padding even when
                # pad_token_id == eos_token_id.
                total_tokens = int(attention_mask.sum().item())
            else:
                pad_token_id = self.tokenizer.pad_token_id
                if pad_token_id is None:
                    pad_token_id = self.tokenizer.eos_token_id
                total_tokens = int((input_ids != pad_token_id).sum().item())
            padding_tokens = batch_size * seq_len - total_tokens

            # Rows were already truncated to max_seq_length during tokenization,
            # so for rows from _process_dataset this counts one per row that
            # reached the cap (a lower bound on tokens actually cut).
            truncated_tokens = 0
            for f in features:
                if "length" in f and f["length"] >= max_len:
                    truncated_tokens += f["length"] - max_len + 1

            batch["total_tokens"] = total_tokens
            batch["padding_tokens"] = padding_tokens
            batch["truncated_tokens"] = truncated_tokens
            batch["batch_size"] = batch_size
            batch["seq_len"] = seq_len
            return batch

        return collator_with_stats
def create_sample_dataset(output_path: str = "my_dataset"):
    """Write a tiny three-conversation demo dataset to ``output_path``.

    The directory is created if needed. ``train.json`` receives the first
    two conversations and ``validation.json`` the last one; each file is a
    JSON list of ``{"messages": [user, assistant]}`` records.

    Args:
        output_path: Directory to write the split files into.

    Returns:
        ``output_path``, unchanged.
    """
    os.makedirs(output_path, exist_ok=True)

    def chat(question, answer):
        # One user turn followed by one assistant turn.
        return {
            "messages": [
                {"role": "user", "content": question},
                {"role": "assistant", "content": answer},
            ]
        }

    samples = [
        chat(
            "What is machine learning?",
            "Machine learning is a subset of artificial intelligence that enables computers to learn and improve from experience without being explicitly programmed.",
        ),
        chat(
            "Explain gravity in simple terms.",
            "Gravity is the force that pulls objects toward each other, like how the Earth pulls things down to the ground.",
        ),
        chat(
            "How do I make a cup of coffee?",
            "To make a cup of coffee: 1) Boil water, 2) Add coffee grounds to a filter, 3) Pour hot water over the grounds, 4) Let it brew for a few minutes, 5) Enjoy!",
        ),
    ]

    # (file name, records) written in order: train first, then validation.
    outputs = (
        ("train.json", samples[:2]),
        ("validation.json", samples[2:]),
    )
    for file_name, records in outputs:
        with open(os.path.join(output_path, file_name), "w", encoding="utf-8") as handle:
            json.dump(records, handle, indent=2, ensure_ascii=False)

    logger.info("Sample dataset created in %s", output_path)
    return output_path