# aam-diffusion-v1 / diffusion_llm/scripts/train_minimal.py
# AAM Diffusion LLM v1.0 — The Body of Aphantasic Abstraction Model
# (source page residue: "Wolfvin's picture", commit 2d7e335 verified)
#!/usr/bin/env python3
"""
AAM Diffusion LLM — Minimal Training Script for CPU
Trains a very small AAM Diffusion LLM model on CPU.
"""
import sys
import json
import time
import logging
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
import numpy as np
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger("train")
def _collect_texts(train_path):
    """Collect every text field from a JSONL training file for tokenizer training.

    Each line is expected to be a JSON object; blank and malformed lines are
    skipped (best-effort, matching the tolerant behaviour of the original loop).

    Returns:
        List of text strings drawn from the scalar fields ``narrative`` and
        ``trigger`` plus the list fields ``evidence_nodes``, ``anomalies``,
        and ``reasoning_steps``.
    """
    texts = []
    with open(train_path, "r", encoding="utf-8") as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                continue
            try:
                record = json.loads(raw)
            except json.JSONDecodeError:
                continue  # tolerate malformed lines rather than abort training
            # Scalar text fields — only when present and truthy.
            for key in ("narrative", "trigger"):
                if record.get(key):
                    texts.append(record[key])
            # List fields — presumably lists of strings; TODO confirm schema.
            for key in ("evidence_nodes", "anomalies", "reasoning_steps"):
                texts.extend(record.get(key, []))
    return texts


def _batch_loss(model, batch, device, n_timesteps):
    """Move one batch to *device*, draw random timesteps, return the diffusion loss.

    Shared by the train and eval loops so the forward/loss call is written once
    (the original duplicated this block verbatim in both loops).
    """
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v
             for k, v in batch.items()}
    batch_size = batch["token_ids"].shape[0]
    # One uniformly random diffusion timestep per sample in [0, n_timesteps).
    t = torch.randint(0, n_timesteps, (batch_size,), device=device)
    predicted, target = model(
        token_ids=batch["token_ids"],
        timestep=t,
        evidence_ids=batch.get("evidence_ids"),
        evidence_confidence=batch.get("evidence_confidence"),
        anomaly_ids=batch.get("anomaly_ids"),
        anomaly_confidence=batch.get("anomaly_confidence"),
        reasoning_ids=batch.get("reasoning_ids"),
        reasoning_confidence=batch.get("reasoning_confidence"),
        source_trust=batch.get("source_trust"),
    )
    return model.compute_loss(predicted, target, t)


def main():
    """Run the full minimal CPU training pipeline.

    Steps: generate synthetic data, train a BPE tokenizer on it, build the
    model config, create the model and dataloaders, train for a fixed number
    of steps, evaluate on the validation split, and save model + config.

    Returns:
        Tuple of ``(model, tokenizer, config, output_dir)``.
    """
    # Project imports stay inside main() (as in the original) so importing
    # this module does not pull in the whole package at import time.
    from diffusion_llm.config.model_config import (
        AamDiffusionConfig, ModelConfig, DiffusionConfig,
        GraphEncoderConfig, TokenizerConfig, TrainingConfig, InferenceConfig,
    )
    from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
    from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
    from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn
    from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
    from torch.utils.data import DataLoader

    output_dir = Path("./aam-diffusion-v1")
    output_dir.mkdir(parents=True, exist_ok=True)
    data_dir = output_dir / "data"
    data_dir.mkdir(parents=True, exist_ok=True)

    # ===== STEP 1: Generate Data =====
    logger.info("STEP 1: Generating synthetic data...")
    train_path, val_path = SyntheticDataGenerator.generate_training_split(
        output_dir=data_dir, n_train=200, n_val=20, language="id", seed=42,
    )

    # ===== STEP 2: Train Tokenizer =====
    logger.info("STEP 2: Training tokenizer...")
    tokenizer = AamTokenizer()
    tokenizer.train(_collect_texts(train_path), vocab_size=2000)
    tokenizer.save(data_dir / "tokenizer.json")
    actual_vocab = tokenizer.vocab_size
    logger.info(f" Tokenizer: vocab_size={actual_vocab}, merges={len(tokenizer.merges)}")

    # ===== STEP 3: Config =====
    # Single source of truth for every hyperparameter; the loops below read
    # from here instead of re-hardcoding literals (the original duplicated
    # lr/batch_size/max_steps/grad-clip as magic numbers).
    config = AamDiffusionConfig(
        model=ModelConfig(
            d_model=128,
            n_layers=2,
            n_heads=4,
            d_ff=256,
            vocab_size=actual_vocab,  # must match the tokenizer just trained
            max_seq_len=64,
            pos_encoding_type="learned",
            use_flash_attention=False,  # CPU-only run
            norm_type="layernorm",
            init_std=0.02,
        ),
        diffusion=DiffusionConfig(
            n_timesteps=100,
            n_inference_steps=10,
            schedule_type="cosine",
            prediction_type="epsilon",
            loss_type="mse",
            loss_weighting="none",
        ),
        graph_encoder=GraphEncoderConfig(
            d_graph=64,
            n_graph_layers=1,
            n_graph_heads=2,
            max_evidence_nodes=5,
            max_compositions=3,
            max_anomalies=3,
            max_reasoning_steps=3,
            conditioning_method="cross_attention",
            embed_confidence=False,
            embed_temporal=False,
        ),
        tokenizer=TokenizerConfig(bpe_vocab_size=2000),
        training=TrainingConfig(
            batch_size=4,
            learning_rate=1e-3,
            max_steps=100,
            warmup_steps=10,
            use_amp=False,
            num_workers=0,
            grad_clip_norm=1.0,
        ),
        inference=InferenceConfig(n_steps=10),
        model_name="aam-diffusion-v1.0",
        output_dir=str(output_dir),
        seed=42,
    )

    # ===== STEP 4: Create Model =====
    logger.info("STEP 3: Creating model...")
    model = AamDiffusionModel(config)
    n_params = model.get_num_params()
    logger.info(f" Parameters: {model._format_params(n_params)} ({n_params:,})")

    # ===== STEP 5: Create DataLoaders =====
    logger.info("STEP 4: Creating dataloaders...")

    def make_dataset(path, augment):
        # Train and val datasets differ only in path and augmentation flag.
        return GraphNarrativeDataset(
            data_path=path, tokenizer=tokenizer,
            max_seq_len=config.model.max_seq_len,
            max_evidence=config.graph_encoder.max_evidence_nodes,
            max_anomalies=config.graph_encoder.max_anomalies,
            max_reasoning=config.graph_encoder.max_reasoning_steps,
            augment=augment,
        )

    train_loader = DataLoader(
        make_dataset(train_path, augment=True),
        batch_size=config.training.batch_size, shuffle=True,
        num_workers=config.training.num_workers, collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        make_dataset(val_path, augment=False),
        batch_size=config.training.batch_size, shuffle=False,
        num_workers=config.training.num_workers, collate_fn=collate_fn,
    )

    # ===== STEP 6: Train =====
    logger.info("STEP 5: Training...")
    device = torch.device("cpu")
    model.to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.training.learning_rate,
        weight_decay=0.01,
    )
    max_steps = config.training.max_steps
    start_time = time.time()
    global_step = 0
    train_losses = []
    for epoch in range(50):  # hard epoch cap; max_steps is the real stop
        model.train()
        for batch in train_loader:
            if global_step >= max_steps:
                break
            loss = _batch_loss(model, batch, device, config.diffusion.n_timesteps)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), config.training.grad_clip_norm,
            )
            optimizer.step()
            train_losses.append(loss.item())
            global_step += 1
            if global_step % 10 == 0:
                window = train_losses[-10:]
                avg = sum(window) / len(window)
                elapsed = time.time() - start_time
                logger.info(f" Step {global_step}/{max_steps} | Loss: {avg:.4f} | Time: {elapsed:.1f}s")
        if global_step >= max_steps:
            break

    # ===== STEP 7: Evaluate =====
    logger.info("STEP 6: Evaluating...")
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            loss = _batch_loss(model, batch, device, config.diffusion.n_timesteps)
            val_losses.append(loss.item())
    avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else 0
    logger.info(f" Val loss: {avg_val_loss:.4f}")

    # ===== STEP 8: Save =====
    logger.info("STEP 7: Saving model...")
    # Tokenizer was already saved to data_dir in STEP 2.
    model_path = output_dir / "model.pt"
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": config.to_dict(),
    }, model_path)
    config.to_json(output_dir / "config.json")

    elapsed = time.time() - start_time
    logger.info(f"\n DONE! {global_step} steps in {elapsed:.1f}s")
    # Guard against zero completed steps (original indexed train_losses[-1]
    # unconditionally, which would raise IndexError).
    final_train_loss = train_losses[-1] if train_losses else float("nan")
    logger.info(f" Final train loss: {final_train_loss:.4f}")
    logger.info(f" Val loss: {avg_val_loss:.4f}")
    logger.info(f" Parameters: {model._format_params(n_params)}")
    logger.info(f" Output: {output_dir}")
    return model, tokenizer, config, output_dir
# Script entry point: run the whole pipeline; the returned tuple is ignored.
if __name__ == "__main__":
    main()