"""
Supervised Fine-Tuning (SFT) for Memory Routing

This implements Stage 1 of the PRD: Prompt Distillation using Tinker's
cross_entropy loss function with LoRA fine-tuning.

Per Tinker docs (supervised-learning.mdx):
- SFT means maximizing the log-probability of the target tokens
- Use the cross_entropy loss: -(weights * logp(target_tokens)).sum()

Per Tinker docs (lora-primer.mdx):
- LoRA requires a larger LR than full fine-tuning (20-100x)
- Use the get_lr() utility to get the recommended LR
- The default rank of 32 is suitable for classification tasks

Per Tinker docs (async.mdx):
- Use the async methods for performance
- Double-await pattern: await the *_async call to get a future, then
  await future.result_async()

Per PRD Section 7:
- 300-500 steps minimum
- Batch size 128
- Early stopping if the test loss plateaus
- Checkpoint every 20 steps
"""

import asyncio
import json
import time
from typing import List, Dict, Any, Optional
from dataclasses import dataclass


@dataclass
class SFTConfig:
    # Model
    base_model: str = "meta-llama/Llama-3.1-8B"
    lora_rank: int = 32
    renderer_name: str = "llama3"

    # Training schedule
    num_steps: int = 300
    batch_size: int = 128
    learning_rate: Optional[float] = None  # None -> use the get_lr() recommendation

    # Adam hyperparameters
    beta1: float = 0.9
    beta2: float = 0.95
    eps: float = 1e-8

    # Checkpointing / evaluation cadence
    checkpoint_every: int = 20
    eval_every: int = 20

    # Evals without improvement before stopping
    early_stopping_patience: int = 5

    # Data and log paths
    train_data_path: str = "training/processed_data/train_data.json"
    test_data_path: str = "training/processed_data/test_data.json"
    log_path: str = "training/logs/sft"
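
# Since SFTConfig is a dataclass, overrides can also be applied in code, e.g.:
#
#   config = SFTConfig(num_steps=500, learning_rate=2e-4)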


@dataclass
class TrainingMetrics:
    step: int
    train_loss: float
    test_loss: Optional[float] = None
    learning_rate: float = 0.0
    batch_time: float = 0.0
    checkpoint_path: Optional[str] = None


def load_processed_data(path: str) -> List[Dict[str, Any]]:
    """Load preprocessed data from JSON."""
    with open(path, "r") as f:
        return json.load(f)


def create_batch(data: List[Any], batch_size: int, step: int) -> List[Any]:
    """
    Create a batch of data for training.
    Cycles through the data if step * batch_size exceeds the data length.
    """
    start_idx = (step * batch_size) % len(data)
    end_idx = start_idx + batch_size

    if end_idx <= len(data):
        return data[start_idx:end_idx]
    else:
        # Wrap around to the start of the dataset.
        batch = data[start_idx:]
        batch.extend(data[:end_idx - len(data)])
        return batch
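
# For example, with len(data) == 10 and batch_size == 4, step 2 starts at
# index 8 and wraps around to yield items [8, 9, 0, 1].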

async def run_sft_training(config: SFTConfig):
    """
    Main SFT training loop.

    Per Tinker docs (training-sampling.mdx):
    1. Create a ServiceClient
    2. Create a TrainingClient with the base_model and LoRA config
    3. Loop: forward_backward -> optim_step
    4. Periodically save checkpoints and evaluate
    """
    # Heavy dependencies are imported lazily, inside the entry point.
    import tinker
    from tinker import types
    from tinker_cookbook.hyperparam_utils import get_lr
    from tinker_cookbook import renderers
    import numpy as np
    import os
    from dotenv import load_dotenv

    # Pick up credentials (e.g. the Tinker API key) from a local .env file.
    load_dotenv()

    os.makedirs(config.log_path, exist_ok=True)

    # Per lora-primer.mdx, LoRA needs a much larger LR than full fine-tuning
    # (roughly 20-100x); get_lr() returns the recommended value per model.
    if config.learning_rate is None:
        config.learning_rate = get_lr(config.base_model)
        print(f"Using recommended LR for {config.base_model}: {config.learning_rate:.2e}")

    print(f"Loading training data from {config.train_data_path}...")
    train_data_raw = load_processed_data(config.train_data_path)
    print(f"Loading test data from {config.test_data_path}...")
    test_data_raw = load_processed_data(config.test_data_path)

    print(f"Train examples: {len(train_data_raw)}")
    print(f"Test examples: {len(test_data_raw)}")

    print("\nInitializing Tinker ServiceClient...")
    service_client = tinker.ServiceClient()

    print("Creating LoRA training client...")
    print(f"  Base model: {config.base_model}")
    print(f"  LoRA rank: {config.lora_rank}")

    training_client = await service_client.create_lora_training_client_async(
        base_model=config.base_model,
        rank=config.lora_rank,
    )

    # The renderer converts chat messages into the model's prompt format
    # and provides per-token loss weights for supervised examples.
    tokenizer = training_client.get_tokenizer()
    renderer = renderers.get_renderer(name=config.renderer_name, tokenizer=tokenizer)

    print("Converting data to Datum objects...")

    def convert_to_datum(item: Dict) -> types.Datum:
        """Convert a preprocessed item back to a Datum."""
        if "model_input" in item:
            # Item was already serialized as a Datum; rebuild it directly.
            return types.Datum(
                model_input=types.ModelInput.from_ints(item["model_input"]["chunks"][0]["tokens"]),
                loss_fn_inputs=item["loss_fn_inputs"],
            )
        else:
            # Item is a raw chat transcript; render it into tokens + weights.
            messages = item["messages"]
            tokens, weights = renderer.build_supervised_example(messages)

            # Tensors -> plain lists, if needed.
            if hasattr(tokens, 'tolist'):
                tokens = tokens.tolist()
            if hasattr(weights, 'tolist'):
                weights = weights.tolist()

            # Shift by one position for next-token prediction.
            input_tokens = tokens[:-1]
            target_tokens = tokens[1:]
            loss_weights = weights[1:]

            return types.Datum(
                model_input=types.ModelInput.from_ints(input_tokens),
                loss_fn_inputs=dict(
                    target_tokens=target_tokens,
                    weights=loss_weights,
                ),
            )
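
    # Shift example: tokens [a, b, c, d] become inputs [a, b, c] with
    # targets [b, c, d]; prompt positions typically carry weight 0, so
    # only completion tokens contribute to the cross-entropy loss.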

    train_data = [convert_to_datum(item) for item in train_data_raw]
    test_data = [convert_to_datum(item) for item in test_data_raw]

    print(f"Converted {len(train_data)} train, {len(test_data)} test examples")

    print(f"\n{'='*60}")
    print("Starting SFT Training")
    print(f"{'='*60}")
    print(f"Steps: {config.num_steps}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate:.2e}")
    print(f"Checkpoint every: {config.checkpoint_every} steps")
    print(f"Eval every: {config.eval_every} steps")
    print(f"{'='*60}\n")

    metrics_log = []
    best_test_loss = float('inf')
    no_improvement_count = 0
    final_checkpoint_path = None

    for step in range(config.num_steps):
        step_start = time.time()

        # Assemble the next training batch (wraps around the dataset).
        batch = create_batch(train_data, config.batch_size, step)

        # Submit forward/backward and the optimizer step back-to-back,
        # then await both results (the double-await pattern, async.mdx).
        fwd_bwd_future = await training_client.forward_backward_async(
            batch,
            loss_fn="cross_entropy"
        )

        adam_params = types.AdamParams(
            learning_rate=config.learning_rate,
            beta1=config.beta1,
            beta2=config.beta2,
            eps=config.eps,
        )
        optim_future = await training_client.optim_step_async(adam_params)

        fwd_bwd_result = await fwd_bwd_future.result_async()
        optim_result = await optim_future.result_async()

        # Weighted mean negative log-probability over the batch; cast to
        # float so json.dumps can serialize the metric later.
        logprobs = np.concatenate([
            output['logprobs'].tolist()
            for output in fwd_bwd_result.loss_fn_outputs
        ])
        weights = np.concatenate([
            datum.loss_fn_inputs['weights'].tolist()
            for datum in batch
        ])
        train_loss = float(-np.dot(logprobs, weights) / max(weights.sum(), 1))
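
        # e.g. logprobs [-0.5, -1.0] with weights [1.0, 0.0] give a loss
        # of 0.5: the zero-weight (prompt) token is ignored entirely.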

        step_time = time.time() - step_start

        metrics = TrainingMetrics(
            step=step,
            train_loss=train_loss,
            learning_rate=config.learning_rate,
            batch_time=step_time
        )

        # Periodic evaluation on the held-out test set.
        if step % config.eval_every == 0 or step == config.num_steps - 1:
            test_batch = create_batch(test_data, min(config.batch_size, len(test_data)), 0)

            # NOTE: forward_backward also accumulates gradients, so this eval
            # pass leaks test-set gradients into the next optim_step; a
            # forward-only pass would be preferable here if one is available.
            eval_future = await training_client.forward_backward_async(
                test_batch,
                loss_fn="cross_entropy"
            )
            eval_result = await eval_future.result_async()

            test_logprobs = np.concatenate([
                output['logprobs'].tolist()
                for output in eval_result.loss_fn_outputs
            ])
            test_weights = np.concatenate([
                datum.loss_fn_inputs['weights'].tolist()
                for datum in test_batch
            ])
            test_loss = float(-np.dot(test_logprobs, test_weights) / max(test_weights.sum(), 1))
            metrics.test_loss = test_loss

            # Early stopping if the test loss plateaus (PRD Section 7).
            if test_loss < best_test_loss:
                best_test_loss = test_loss
                no_improvement_count = 0
            else:
                no_improvement_count += 1

            if no_improvement_count >= config.early_stopping_patience:
                print(f"\nEarly stopping at step {step} (no improvement for {config.early_stopping_patience} evals)")
                metrics_log.append(metrics)  # record the final step before stopping
                break

        # Periodic checkpointing (PRD: every 20 steps).
        if step % config.checkpoint_every == 0 or step == config.num_steps - 1:
            # Save sampler weights (usable for inference/sampling).
            sampler_future = await training_client.save_weights_for_sampler_async(
                name=f"sft_step_{step:04d}"
            )
            sampler_result = await sampler_future.result_async()
            metrics.checkpoint_path = sampler_result.path

            if step == config.num_steps - 1:
                # On the last step, also save the full training state
                # (weights plus optimizer state) so training can be resumed.
                state_future = await training_client.save_state_async(
                    name="sft_final_state"
                )
                state_result = await state_future.result_async()
                final_checkpoint_path = state_result.path
                print(f"  Full state checkpoint: {final_checkpoint_path}")
            else:
                final_checkpoint_path = sampler_result.path
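
        # Resuming later would reload this state on a fresh client, e.g.
        # (a sketch; assumes the load_state API described alongside
        # save_state in the Tinker docs, with the usual async naming):
        #
        #   future = await training_client.load_state_async(final_checkpoint_path)
        #   await future.result_async()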

        metrics_log.append(metrics)

        test_str = f", test_loss={metrics.test_loss:.4f}" if metrics.test_loss is not None else ""
        ckpt_str = f", checkpoint={metrics.checkpoint_path}" if metrics.checkpoint_path else ""
        print(f"Step {step:4d}/{config.num_steps}: train_loss={train_loss:.4f}{test_str}, time={step_time:.1f}s{ckpt_str}")

    # Persist per-step metrics as JSONL for later analysis.
    metrics_path = os.path.join(config.log_path, "metrics.jsonl")
    with open(metrics_path, "w") as f:
        for m in metrics_log:
            f.write(json.dumps({
                "step": m.step,
                "train_loss": m.train_loss,
                "test_loss": m.test_loss,
                "learning_rate": m.learning_rate,
                "batch_time": m.batch_time,
                "checkpoint_path": m.checkpoint_path
            }) + "\n")
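
    # The JSONL log is easy to reload for plotting, e.g.:
    #
    #   with open(metrics_path) as f:
    #       history = [json.loads(line) for line in f]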

    print(f"\n{'='*60}")
    print("SFT Training Complete")
    print(f"{'='*60}")
    print(f"Final train loss: {metrics_log[-1].train_loss:.4f}")
    print(f"Best test loss: {best_test_loss:.4f}")
    print(f"Final checkpoint: {final_checkpoint_path}")
    print(f"Metrics saved to: {metrics_path}")
    print(f"{'='*60}")

    return final_checkpoint_path, metrics_log


async def main():
    """Entry point for SFT training."""
    import sys

    config = SFTConfig()

    # Apply key=value overrides from the command line, coercing each value
    # to the type of the existing field. learning_rate defaults to None, so
    # it is handled explicitly as a float.
    for arg in sys.argv[1:]:
        if "=" in arg:
            key, value = arg.split("=", 1)
            if hasattr(config, key):
                current_value = getattr(config, key)
                if key == "learning_rate":
                    setattr(config, key, float(value))
                elif isinstance(current_value, int):
                    setattr(config, key, int(value))
                elif isinstance(current_value, float):
                    setattr(config, key, float(value))
                else:
                    setattr(config, key, value)

    await run_sft_training(config)

if __name__ == "__main__":
    asyncio.run(main())