"""
Supervised Fine-Tuning (SFT) for Memory Routing
This implements Stage 1 of the PRD: Prompt Distillation using Tinker's
cross_entropy loss function with LoRA fine-tuning.
Per Tinker docs (supervised-learning.mdx):
- SFT means maximizing log-probability of target tokens
- Use cross_entropy loss: -(weights * logp(target_tokens)).sum()
Per Tinker docs (lora-primer.mdx):
- LoRA requires larger LR than full fine-tuning (20-100x)
- Use get_lr() utility to get recommended LR
- Default rank 32 is suitable for classification tasks
Per Tinker docs (async.mdx):
- Use async methods for performance
- Double await pattern: await future, then await result_async()
Per PRD Section 7:
- 300-500 steps minimum
- Batch size 128
- Early stopping if test loss plateaus
- Checkpoint every 20 steps
"""
import asyncio
import json
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass, field
# Configuration
@dataclass
class SFTConfig:
# Model
base_model: str = "meta-llama/Llama-3.1-8B"
lora_rank: int = 32
renderer_name: str = "llama3"
# Training
num_steps: int = 300
batch_size: int = 128
learning_rate: Optional[float] = None # Will use get_lr() if None
# Adam optimizer (per Tinker defaults)
beta1: float = 0.9
beta2: float = 0.95
eps: float = 1e-8
# Checkpointing
checkpoint_every: int = 20
eval_every: int = 20
# Early stopping
early_stopping_patience: int = 5 # Stop if no improvement for this many evals
# Paths
train_data_path: str = "training/processed_data/train_data.json"
test_data_path: str = "training/processed_data/test_data.json"
log_path: str = "training/logs/sft"
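# Any field above can be overridden either in code (e.g. SFTConfig(num_steps=500))
# or on the command line as key=value pairs, which main() maps onto this dataclass.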
@dataclass
class TrainingMetrics:
step: int
train_loss: float
test_loss: Optional[float] = None
learning_rate: float = 0.0
batch_time: float = 0.0
checkpoint_path: Optional[str] = None
def load_processed_data(path: str) -> List[Dict[str, Any]]:
"""Load preprocessed data from JSON."""
with open(path, "r") as f:
return json.load(f)
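# Each record is expected in one of two shapes (see convert_to_datum below):
#   1. Pre-tokenized: {"model_input": {"chunks": [{"tokens": [...]}]}, "loss_fn_inputs": {...}}
#   2. Chat messages: {"messages": [{"role": ..., "content": ...}, ...]}, re-tokenized at load time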
def create_batch(data: List[Any], batch_size: int, step: int) -> List[Any]:
"""
Create a batch of data for training.
Cycles through data if step * batch_size exceeds data length.
"""
start_idx = (step * batch_size) % len(data)
end_idx = start_idx + batch_size
if end_idx <= len(data):
return data[start_idx:end_idx]
else:
# Wrap around
batch = data[start_idx:]
batch.extend(data[:end_idx - len(data)])
return batch
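# Worked example: with len(data) == 5, batch_size == 4, step == 1:
#   start_idx = (1 * 4) % 5 = 4 and end_idx = 8 > 5, so the batch wraps around
#   and contains the items at indices [4, 0, 1, 2].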
async def run_sft_training(config: SFTConfig):
"""
Main SFT training loop.
Per Tinker docs (training-sampling.mdx):
1. Create ServiceClient
2. Create TrainingClient with base_model and LoRA config
3. Loop: forward_backward -> optim_step
4. Periodically save checkpoints and evaluate
"""
import tinker
from tinker import types
from tinker_cookbook.hyperparam_utils import get_lr
from tinker_cookbook import renderers
import numpy as np
import os
from dotenv import load_dotenv
# Load API key from .env
load_dotenv()
os.makedirs(config.log_path, exist_ok=True)
# Get learning rate if not specified
if config.learning_rate is None:
config.learning_rate = get_lr(config.base_model)
print(f"Using recommended LR for {config.base_model}: {config.learning_rate:.2e}")
# Load data
print(f"Loading training data from {config.train_data_path}...")
train_data_raw = load_processed_data(config.train_data_path)
print(f"Loading test data from {config.test_data_path}...")
test_data_raw = load_processed_data(config.test_data_path)
print(f"Train examples: {len(train_data_raw)}")
print(f"Test examples: {len(test_data_raw)}")
# Initialize Tinker clients
print(f"\nInitializing Tinker ServiceClient...")
service_client = tinker.ServiceClient()
print(f"Creating LoRA training client...")
print(f" Base model: {config.base_model}")
print(f" LoRA rank: {config.lora_rank}")
training_client = await service_client.create_lora_training_client_async(
base_model=config.base_model,
rank=config.lora_rank,
)
# Get tokenizer from training client (avoids HF auth issues)
tokenizer = training_client.get_tokenizer()
renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)
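# The renderer turns chat messages into token IDs plus per-token loss weights for the
# configured "llama3" chat format (used by convert_to_datum below).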
# Convert raw data to Datum objects
print("Converting data to Datum objects...")
def convert_to_datum(item: Dict) -> types.Datum:
"""Convert preprocessed item back to Datum."""
if "model_input" in item:
# Already tokenized into a Datum-shaped dict; assumes a single token chunk per example
return types.Datum(
model_input=types.ModelInput.from_ints(item["model_input"]["chunks"][0]["tokens"]),
loss_fn_inputs=item["loss_fn_inputs"]
)
else:
# Chat-message format (e.g. mock data): re-tokenize with the renderer
messages = item["messages"]
tokens, weights = renderer.build_supervised_example(messages)
# Convert tensors to lists if needed
if hasattr(tokens, 'tolist'):
tokens = tokens.tolist()
if hasattr(weights, 'tolist'):
weights = weights.tolist()
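# Standard next-token shift: the model is fed tokens[:-1] and scored on predicting
# tokens[1:], with weights shifted the same way so each weight lines up with its target.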
input_tokens = tokens[:-1]
target_tokens = tokens[1:]
loss_weights = weights[1:]
return types.Datum(
model_input=types.ModelInput.from_ints(input_tokens),
loss_fn_inputs=dict(
target_tokens=target_tokens,
weights=loss_weights
)
)
train_data = [convert_to_datum(item) for item in train_data_raw]
test_data = [convert_to_datum(item) for item in test_data_raw]
print(f"Converted {len(train_data)} train, {len(test_data)} test examples")
# Training loop
print(f"\n{'='*60}")
print(f"Starting SFT Training")
print(f"{'='*60}")
print(f"Steps: {config.num_steps}")
print(f"Batch size: {config.batch_size}")
print(f"Learning rate: {config.learning_rate:.2e}")
print(f"Checkpoint every: {config.checkpoint_every} steps")
print(f"Eval every: {config.eval_every} steps")
print(f"{'='*60}\n")
metrics_log = []
best_test_loss = float('inf')
no_improvement_count = 0
final_checkpoint_path = None
for step in range(config.num_steps):
step_start = time.time()
# Create batch
batch = create_batch(train_data, config.batch_size, step)
# Forward-backward pass
# Per Tinker docs: submit forward_backward, then optim_step
# Can overlap by submitting both before waiting
fwd_bwd_future = await training_client.forward_backward_async(
batch,
loss_fn="cross_entropy"
)
# Optimizer step
adam_params = types.AdamParams(
learning_rate=config.learning_rate,
beta1=config.beta1,
beta2=config.beta2,
eps=config.eps,
)
optim_future = await training_client.optim_step_async(adam_params)
# Wait for results
# Per Tinker async.mdx: must await result_async() to get actual values
fwd_bwd_result = await fwd_bwd_future.result_async()
optim_result = await optim_future.result_async()
# Compute train loss
# Per Tinker losses.mdx: cross_entropy outputs logprobs
logprobs = np.concatenate([
output['logprobs'].tolist()
for output in fwd_bwd_result.loss_fn_outputs
])
weights = np.concatenate([
datum.loss_fn_inputs['weights'].tolist()
for datum in batch
])
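# Weighted mean negative log-likelihood over target tokens:
#   loss = -(sum_i w_i * logp_i) / sum_i w_i
# The max(..., 1) below guards against a batch whose weights are all zero.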
train_loss = -np.dot(logprobs, weights) / max(weights.sum(), 1)
step_time = time.time() - step_start
# Create metrics
metrics = TrainingMetrics(
step=step,
train_loss=train_loss,
learning_rate=config.learning_rate,
batch_time=step_time
)
# Periodic evaluation
if step % config.eval_every == 0 or step == config.num_steps - 1:
# Evaluate on test set (sample a batch)
test_batch = create_batch(test_data, min(config.batch_size, len(test_data)), 0)
# Evaluation reuses forward_backward, so this pass also accumulates gradients that
# get folded into the next optim_step; if the API exposes a forward-only pass, that
# would give a cleaner held-out loss.
eval_future = await training_client.forward_backward_async(
test_batch,
loss_fn="cross_entropy"
)
eval_result = await eval_future.result_async()
test_logprobs = np.concatenate([
output['logprobs'].tolist()
for output in eval_result.loss_fn_outputs
])
test_weights = np.concatenate([
datum.loss_fn_inputs['weights'].tolist()
for datum in test_batch
])
test_loss = -np.dot(test_logprobs, test_weights) / max(test_weights.sum(), 1)
metrics.test_loss = test_loss
# Early stopping check
if test_loss < best_test_loss:
best_test_loss = test_loss
no_improvement_count = 0
else:
no_improvement_count += 1
if no_improvement_count >= config.early_stopping_patience:
print(f"\nEarly stopping at step {step} (no improvement for {config.early_stopping_patience} evals)")
break
# Periodic checkpointing
if step % config.checkpoint_every == 0 or step == config.num_steps - 1:
# Save both sampler weights (for inference) and full state (for RL continuation)
# Per Tinker save-load.mdx: save_state for resuming training
# Sampler weights for inference
sampler_future = await training_client.save_weights_for_sampler_async(
name=f"sft_step_{step:04d}"
)
sampler_result = await sampler_future.result_async()
metrics.checkpoint_path = sampler_result.path
# Full state for RL continuation (only at final step to save storage)
if step == config.num_steps - 1:
state_future = await training_client.save_state_async(
name=f"sft_final_state"
)
state_result = await state_future.result_async()
final_checkpoint_path = state_result.path
print(f" Full state checkpoint: {final_checkpoint_path}")
else:
final_checkpoint_path = sampler_result.path
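# The sampler-weights path is what inference/eval code would point a Tinker sampling
# client at, while the full-state path is what a follow-up RL stage would load to
# resume training; the exact loading calls are described in the Tinker save-load docs.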
metrics_log.append(metrics)
# Print progress
test_str = f", test_loss={metrics.test_loss:.4f}" if metrics.test_loss is not None else ""
ckpt_str = f", checkpoint={metrics.checkpoint_path}" if metrics.checkpoint_path else ""
print(f"Step {step:4d}/{config.num_steps}: train_loss={train_loss:.4f}{test_str}, time={step_time:.1f}s{ckpt_str}")
# Save metrics log
metrics_path = os.path.join(config.log_path, "metrics.jsonl")
with open(metrics_path, "w") as f:
for m in metrics_log:
f.write(json.dumps({
"step": m.step,
"train_loss": m.train_loss,
"test_loss": m.test_loss,
"learning_rate": m.learning_rate,
"batch_time": m.batch_time,
"checkpoint_path": m.checkpoint_path
}) + "\n")
print(f"\n{'='*60}")
print(f"SFT Training Complete")
print(f"{'='*60}")
print(f"Final train loss: {metrics_log[-1].train_loss:.4f}")
print(f"Best test loss: {best_test_loss:.4f}")
print(f"Final checkpoint: {final_checkpoint_path}")
print(f"Metrics saved to: {metrics_path}")
print(f"{'='*60}")
return final_checkpoint_path, metrics_log
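# run_sft_training can also be driven programmatically, e.g.:
#   config = SFTConfig(num_steps=500, batch_size=64)
#   checkpoint_path, metrics = asyncio.run(run_sft_training(config))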
async def main():
"""Entry point for SFT training."""
import sys
config = SFTConfig()
# Parse command line args
for arg in sys.argv[1:]:
if "=" in arg:
key, value = arg.split("=", 1)
if hasattr(config, key):
# Type conversion
current_value = getattr(config, key)
if isinstance(current_value, int):
setattr(config, key, int(value))
elif isinstance(current_value, float) or current_value is None:
# learning_rate is the only field that defaults to None; cast it to float
setattr(config, key, float(value))
else:
setattr(config, key, value)
await run_sft_training(config)
if __name__ == "__main__":
asyncio.run(main())