"""Training Monitor - Check progress and evaluate completed models.

Usage:
    python training/monitor.py        # list runs and best checkpoints
    python training/monitor.py eval   # quick-eval the best SFT and RL checkpoints
"""

import asyncio
import json
import sys

from dotenv import load_dotenv

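# Load environment variables (e.g. the Tinker API key) from .env before
# importing tinker, so the service client can pick them up.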
load_dotenv()

import tinker
from tinker import types
from tinker_cookbook import renderers
from tinker_cookbook.tokenizer_utils import get_tokenizer

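# Base model whose tokenizer and chat renderer are used to build eval prompts.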
BASE_MODEL = "meta-llama/Llama-3.1-8B"

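# Valid category labels for the classification task; "none" means no category applies.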
VALID_CATEGORIES = {
    "company.brand_core", "company.strategic_signatures", "company.knowledge_artifacts",
    "company.business_priorities", "company.tools_config", "company.performance_context",
    "user.communication_style", "user.strategic_approach", "user.role_context",
    "user.workflow_patterns", "user.session_history", "user.interaction_preferences",
    "none",
}


def list_training_runs():
    """List all training runs and their checkpoints."""
    service_client = tinker.ServiceClient()
    rest_client = service_client.create_rest_client()

    runs = rest_client.list_training_runs().result()

    print("=" * 70)
    print("TRAINING RUNS")
    print("=" * 70)

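    # Only the first ten runs returned by the service are shown.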
    for run in runs.training_runs[:10]:
        ckpts = rest_client.list_checkpoints(run.training_run_id).result()

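        # The training stage is encoded in the checkpoint ID: SFT checkpoints
        # contain 'sft', RL checkpoints contain 'rl_'.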
        sft_ckpts = [c for c in ckpts.checkpoints if 'sft' in c.checkpoint_id]
        rl_ckpts = [c for c in ckpts.checkpoints if 'rl_' in c.checkpoint_id]

        print(f"\nRun: {run.training_run_id}")
        print(f" Last request: {run.last_request_time}")
        print(f" SFT checkpoints: {len(sft_ckpts)}")
        print(f" RL checkpoints: {len(rl_ckpts)}")

        if rl_ckpts:
            latest = sorted(rl_ckpts, key=lambda x: x.time)[-1]
            print(f" Latest RL: {latest.checkpoint_id}")

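            # A checkpoint ID containing 'final' marks a completed RL stage.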
            if 'final' in latest.checkpoint_id:
                print(" STATUS: RL COMPLETE")
                print(f" Final checkpoint: tinker://{run.training_run_id}/{latest.checkpoint_id}")


async def quick_eval(checkpoint_path: str, n_samples: int = 20):
    """Quick any-match accuracy check of a checkpoint on held-out test data."""
    service_client = tinker.ServiceClient()
    tokenizer = get_tokenizer(BASE_MODEL)
    renderer = renderers.get_renderer(name="llama3", tokenizer=tokenizer)

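    # Held-out test examples; each record carries chat "messages" and the
    # gold "categories" labels.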
    with open("training/processed_data/test_data.json", "r") as f:
        test_data = json.load(f)

    print(f"\nEvaluating: {checkpoint_path}")
    print(f"Samples: {n_samples}")

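    # Note: sampling below blocks on .result(); this coroutine contains no
    # awaits of its own and is async only so main() can await it.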
    sampling_client = service_client.create_sampling_client(model_path=checkpoint_path)
    stop_sequences = renderer.get_stop_sequences()

    correct = 0
    total = 0

    for example in test_data[:n_samples]:
        gold = example.get("categories", [])
        messages = example.get("messages", [])
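        # Keep only the prompt side of the conversation (drop assistant turns).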
        prompt_messages = [m for m in messages if m.get("role") != "assistant"]

        if not prompt_messages:
            continue

        prompt = renderer.build_generation_prompt(prompt_messages)
        params = types.SamplingParams(max_tokens=50, temperature=0.1, stop=stop_sequences)

        result = sampling_client.sample(prompt=prompt, sampling_params=params, num_samples=1).result()
        response, success = renderer.parse_response(result.sequences[0].tokens)
        predicted_text = response["content"] if success else ""

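        # Parse the comma-separated prediction, keeping only known category names.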
        predicted_set = {c.strip().lower() for c in predicted_text.split(",")
                         if c.strip().lower() in VALID_CATEGORIES}
        gold_set = {c.lower() for c in gold}

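        # "Any match" scoring: an example counts as correct if at least one
        # predicted category appears in the gold set.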
        if predicted_set & gold_set:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0
    print(f"Any Match Accuracy: {accuracy:.1%} ({correct}/{total})")

    return accuracy


def find_best_checkpoint():
    """Find the most recent final SFT and RL checkpoints across all runs."""
    service_client = tinker.ServiceClient()
    rest_client = service_client.create_rest_client()

    runs = rest_client.list_training_runs().result()

    best_rl_checkpoint = None
    best_sft_checkpoint = None

    for run in runs.training_runs:
        ckpts = rest_client.list_checkpoints(run.training_run_id).result()

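        # "Best" here means most recent: a later checkpoint time wins.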
        for ckpt in ckpts.checkpoints:
            if 'rl_final' in ckpt.checkpoint_id:
                path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"
                if best_rl_checkpoint is None or ckpt.time > best_rl_checkpoint[1]:
                    best_rl_checkpoint = (path, ckpt.time)

            if 'sft_final_sampler' in ckpt.checkpoint_id:
                path = f"tinker://{run.training_run_id}/{ckpt.checkpoint_id}"
                if best_sft_checkpoint is None or ckpt.time > best_sft_checkpoint[1]:
                    best_sft_checkpoint = (path, ckpt.time)

    return best_sft_checkpoint, best_rl_checkpoint


async def main():
    if len(sys.argv) > 1 and sys.argv[1] == "eval":
        sft_ckpt, rl_ckpt = find_best_checkpoint()

        print("=" * 70)
        print("CHECKPOINT EVALUATION")
        print("=" * 70)

        if sft_ckpt:
            print(f"\nBest SFT: {sft_ckpt[0]}")
            await quick_eval(sft_ckpt[0], n_samples=50)

        if rl_ckpt:
            print(f"\nBest RL: {rl_ckpt[0]}")
            await quick_eval(rl_ckpt[0], n_samples=50)
    else:
        list_training_runs()

        sft_ckpt, rl_ckpt = find_best_checkpoint()

        print("\n" + "=" * 70)
        print("BEST CHECKPOINTS")
        print("=" * 70)

        if sft_ckpt:
            print(f"\nSFT: {sft_ckpt[0]}")
            print(f" Time: {sft_ckpt[1]}")

        if rl_ckpt:
            print(f"\nRL: {rl_ckpt[0]}")
            print(f" Time: {rl_ckpt[1]}")

        print("\nTo evaluate, run: python training/monitor.py eval")


if __name__ == "__main__":
    asyncio.run(main())