# diffusion_llm/data/data_pipeline.py
"""
AAM Diffusion LLM — Data Pipeline

Orchestrates data preparation: from raw graph data and narratives
to tokenized, batched training data.

The pipeline handles:
1. Loading raw graph→narrative pairs
2. Generating synthetic data if real data isn't available
3. Tokenizing all data
4. Creating train/val splits
5. Building DataLoaders

Analogy: like the preparation process before Jin Soun trains, gathering
all the cases, organizing them, and laying out structured training data.
"""
from __future__ import annotations

import json
import logging
from pathlib import Path
from typing import Optional

from torch.utils.data import DataLoader

from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.data.synthetic_generator import SyntheticDataGenerator
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
from diffusion_llm.training.dataset import GraphNarrativeDataset, collate_fn

logger = logging.getLogger(__name__)
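
# A sketch of the JSONL record layout this pipeline consumes. The values are
# illustrative placeholders, but the keys are exactly those extracted by
# DataPipeline._read_texts() below (GraphNarrativeDataset presumably reads
# the same fields from the same file):
#
#   {"trigger": "...",
#    "evidence_nodes": ["...", "..."],
#    "anomalies": ["..."],
#    "reasoning_steps": ["...", "..."],
#    "narrative": "..."}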


class DataPipeline:
    """Data preparation pipeline for AAM Diffusion LLM training.

    Orchestrates the entire data preparation process:
    1. Check for existing data
    2. Generate synthetic data if needed
    3. Train tokenizer on the data
    4. Create datasets and dataloaders

    Usage:
        pipeline = DataPipeline(config)
        tokenizer, train_loader, val_loader = pipeline.prepare()
    """

    def __init__(self, config: AamDiffusionConfig):
        self.config = config
        self.output_dir = Path(config.output_dir) / "data"
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def prepare(
        self,
        tokenizer: Optional[AamTokenizer] = None,
        force_regenerate: bool = False,
    ) -> tuple[AamTokenizer, DataLoader, Optional[DataLoader]]:
        """Prepare all data for training.

        Args:
            tokenizer: Optional pre-trained tokenizer.
            force_regenerate: Whether to regenerate synthetic data.

        Returns:
            Tuple of (tokenizer, train_loader, val_loader).
        """
        train_path = (
            Path(self.config.training.train_data_path)
            if self.config.training.train_data_path
            else None
        )
        val_path = (
            Path(self.config.training.val_data_path)
            if self.config.training.val_data_path
            else None
        )

        # Step 1: Generate synthetic data if no real data is available
        if train_path is None or not train_path.exists() or force_regenerate:
            logger.info("Generating synthetic training data...")
            train_path, val_path = SyntheticDataGenerator.generate_training_split(
                output_dir=self.output_dir,
                n_train=10000,
                n_val=500,
                language=self.config.inference.language,
                seed=self.config.seed,
            )

        # Step 2: Train the tokenizer if none was provided or it is untrained
        if tokenizer is None or not tokenizer.is_trained:
            logger.info("Training tokenizer...")
            if tokenizer is None:
                tokenizer = AamTokenizer()
            # Read training texts for tokenizer training
            texts = self._read_texts(train_path)
            tokenizer.train(texts, vocab_size=self.config.tokenizer.bpe_vocab_size)
            tokenizer.save(self.output_dir / "tokenizer.json")
            logger.info("Tokenizer trained and saved. Vocab size: %d", tokenizer.vocab_size)

        # Step 3: Create datasets
        logger.info("Creating datasets...")
        train_dataset = GraphNarrativeDataset(
            data_path=train_path,
            tokenizer=tokenizer,
            max_seq_len=self.config.model.max_seq_len,
            max_evidence=self.config.graph_encoder.max_evidence_nodes,
            max_anomalies=self.config.graph_encoder.max_anomalies,
            max_reasoning=self.config.graph_encoder.max_reasoning_steps,
        )
        val_dataset = None
        if val_path and val_path.exists():
            val_dataset = GraphNarrativeDataset(
                data_path=val_path,
                tokenizer=tokenizer,
                max_seq_len=self.config.model.max_seq_len,
                max_evidence=self.config.graph_encoder.max_evidence_nodes,
                max_anomalies=self.config.graph_encoder.max_anomalies,
                max_reasoning=self.config.graph_encoder.max_reasoning_steps,
                augment=False,  # No augmentation for validation
            )

        # Step 4: Create dataloaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.training.batch_size,
            shuffle=True,
            num_workers=self.config.training.num_workers,
            collate_fn=collate_fn,
            pin_memory=True,
        )
        val_loader = None
        if val_dataset is not None:
            val_loader = DataLoader(
                val_dataset,
                batch_size=self.config.training.batch_size,
                shuffle=False,
                num_workers=self.config.training.num_workers,
                collate_fn=collate_fn,
                pin_memory=True,
            )

        logger.info(
            "Data pipeline ready: %d training examples, %s validation examples",
            len(train_dataset),
            len(val_dataset) if val_dataset is not None else 0,
        )
        return tokenizer, train_loader, val_loader

    def _read_texts(self, path: Path) -> list[str]:
        """Read narrative texts from JSONL file for tokenizer training.

        Args:
            path: Path to JSONL data file.

        Returns:
            List of narrative texts.
        """
        texts: list[str] = []
        if not path.exists():
            return texts
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                # Collect narratives plus the graph fields for a richer tokenizer corpus
                if data.get("narrative"):
                    texts.append(data["narrative"])
                if data.get("trigger"):
                    texts.append(data["trigger"])
                texts.extend(data.get("evidence_nodes", []))
                texts.extend(data.get("anomalies", []))
                texts.extend(data.get("reasoning_steps", []))
        return texts
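

# A minimal usage sketch for running the pipeline end to end. It assumes
# AamDiffusionConfig() is constructible with its defaults (hypothetical here;
# real runs may need explicit fields such as output_dir) and that the
# synthetic generator fills in data when no real dataset is configured.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    config = AamDiffusionConfig()  # assumption: defaults are sufficient
    pipeline = DataPipeline(config)
    tokenizer, train_loader, val_loader = pipeline.prepare()

    # Peek at one collated batch without assuming its exact structure.
    batch = next(iter(train_loader))
    if isinstance(batch, dict):
        logger.info("First batch keys: %s", sorted(batch.keys()))
    else:
        logger.info("First batch type: %s", type(batch).__name__)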