MINDI-1.5-Vision-Coder / src /training /mindi_trainer.py
Mindigenous's picture
Upload src
2957021 verified
"""
MINDI 1.5 Vision-Coder — Trainer
Production-ready 3-phase training loop optimized for AMD MI300X (192GB VRAM).
Streams training data from disk (4.18GB train.jsonl) to avoid RAM exhaustion.
Phases:
Phase 1 (steps 0–5000): LoRA only, LR 2e-4, batch 16
Phase 2 (steps 5000–7500): Vision bridge only, LR 1e-5, batch 8
Phase 3 (steps 7500–10000): All trainable, LR 5e-5, batch 12
MI300X specifics:
- ROCm presents as CUDA to PyTorch (torch.cuda.* works)
- bf16 (NOT fp16) for AMD stability
- torch.compile() optional (works on ROCm)
- Gradient checkpointing enabled
- DataLoader: num_workers=4, pin_memory=True, prefetch_factor=2
"""
from __future__ import annotations
import json
import math
import sys
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Iterator, Optional
# Force unbuffered stdout for docker exec -d log visibility
if not sys.stdout.line_buffering:
sys.stdout.reconfigure(line_buffering=True)
import torch
import torch.nn as nn
from PIL import Image
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.utils.data import DataLoader, IterableDataset
# ── Configuration ─────────────────────────────────────────────────────
# Directory three levels above this file's resolved path; every default data,
# checkpoint and log location in this module is built relative to it.
PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent
@dataclass
class PhaseConfig:
    """Settings for one phase of the multi-phase training schedule.

    Attributes:
        name: Phase label, used in log records and checkpoint directory names.
        start_step: Global step at which the phase begins.
        end_step: Global step at which the phase ends; the phase runs
            ``end_step - start_step`` optimizer steps.
        learning_rate: Peak learning rate for the phase's optimizer.
        batch_size: Micro-batch size of the training DataLoader.
        gradient_accumulation_steps: Micro-batches accumulated per optimizer step.
        lora: Whether the LoRA component is trainable in this phase.
        vision_projection: Whether the vision projection is trainable.
        fusion: Whether the fusion component is trainable.
        data_type: ``"vision"`` selects the vision data files; any other value
            (``"text"``, ``"mixed"``) selects the main data files.
    """

    name: str
    start_step: int
    end_step: int
    learning_rate: float
    batch_size: int
    gradient_accumulation_steps: int = 4

    # Trainable-component switches (all frozen unless enabled).
    lora: bool = False
    vision_projection: bool = False
    fusion: bool = False

    # Data selector: "text" (code only), "vision" (image + code), "mixed" (both).
    data_type: str = "text"
def _under_project_root(*parts: str):
    """Return a zero-argument factory yielding ``PROJECT_ROOT`` joined with *parts*.

    PROJECT_ROOT is looked up when the factory is called (at config
    instantiation), not when the dataclass is defined.
    """

    def _factory() -> Path:
        return PROJECT_ROOT.joinpath(*parts)

    return _factory


def _default_phases() -> list[PhaseConfig]:
    """Build the default three-phase schedule (LoRA, vision bridge, everything)."""
    lora_phase = PhaseConfig(
        name="phase1_lora",
        start_step=0,
        end_step=5000,
        learning_rate=2e-4,
        batch_size=16,
        lora=True,
        vision_projection=False,
        fusion=False,
        data_type="text",
    )
    bridge_phase = PhaseConfig(
        name="phase2_vision_bridge",
        start_step=5000,
        end_step=7500,
        learning_rate=1e-5,
        batch_size=8,
        lora=False,
        vision_projection=True,
        fusion=True,
        data_type="vision",
    )
    joint_phase = PhaseConfig(
        name="phase3_all",
        start_step=7500,
        end_step=10000,
        learning_rate=5e-5,
        batch_size=12,
        lora=True,
        vision_projection=True,
        fusion=True,
        data_type="mixed",
    )
    return [lora_phase, bridge_phase, joint_phase]


@dataclass
class TrainingConfig:
    """Complete configuration for a training run: paths, hardware, schedule."""

    # --- Data locations (default to files under PROJECT_ROOT/data) ---
    train_file: Path = field(default_factory=_under_project_root("data", "processed", "train.jsonl"))
    val_file: Path = field(default_factory=_under_project_root("data", "processed", "val.jsonl"))
    vision_train_file: Path = field(default_factory=_under_project_root("data", "websight", "train.jsonl"))
    vision_val_file: Path = field(default_factory=_under_project_root("data", "websight", "val.jsonl"))

    # --- Output locations ---
    output_dir: Path = field(default_factory=_under_project_root("checkpoints", "training"))
    log_dir: Path = field(default_factory=_under_project_root("logs", "training"))

    # --- Model ---
    max_seq_length: int = 8192
    use_compile: bool = False
    gradient_checkpointing: bool = True

    # --- Hardware (MI300X defaults) ---
    dtype: str = "bf16"
    num_workers: int = 4
    pin_memory: bool = True
    prefetch_factor: int = 2

    # --- Optimisation ---
    weight_decay: float = 0.01
    warmup_ratio: float = 0.03
    max_grad_norm: float = 1.0
    seed: int = 42

    # --- Logging / evaluation / checkpoint cadence (in optimizer steps) ---
    log_every_n_steps: int = 10
    eval_every_n_steps: int = 250
    save_every_n_steps: int = 500

    # --- Phase schedule ---
    phases: list[PhaseConfig] = field(default_factory=_default_phases)

    @property
    def total_steps(self) -> int:
        """End step of the last phase, or 0 when no phases are configured."""
        if not self.phases:
            return 0
        return self.phases[-1].end_step

    @property
    def torch_dtype(self) -> torch.dtype:
        """Torch dtype matching ``dtype``; raises KeyError for an unknown name."""
        by_name = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
            "fp32": torch.float32,
        }
        return by_name[self.dtype]
# ── Streaming Dataset ─────────────────────────────────────────────────
class StreamingJSONLDataset(IterableDataset):
    """
    Streams JSONL training data from disk line by line.

    Tokenizes on-the-fly to avoid loading 4+ GB into RAM.
    Supports optional image loading for vision-code pairs.

    When iterated inside DataLoader worker processes, each worker reads a
    disjoint, interleaved subset of the file's lines (line index modulo the
    number of workers), so an example is produced once per pass instead of
    once per worker.

    Expected JSONL format:
        {"id": "...", "type": "...", "source": "...",
         "image_path": "data/websight/images/ws_0000001.jpg",  (optional)
         "messages": [{"role": "system", "content": "..."},
                      {"role": "user", "content": "..."},
                      {"role": "assistant", "content": "..."}],
         "metadata": {...}}
    """

    def __init__(
        self,
        file_path: Path,
        tokenizer: Any,
        max_length: int = 8192,
        shuffle_buffer: int = 10000,
        seed: int = 42,
        project_root: Optional[Path] = None,
    ) -> None:
        """
        Args:
            file_path: JSONL file to stream.
            tokenizer: Callable tokenizer; called with ``max_length``,
                ``truncation``, ``padding`` and ``return_tensors`` keywords.
            max_length: Every example is truncated/padded to this length.
            shuffle_buffer: Number of records shuffled together at a time.
            seed: Seed for the shuffle RNG.
            project_root: Base directory for relative ``image_path`` values;
                defaults to ``PROJECT_ROOT``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        self.file_path = Path(file_path)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.shuffle_buffer = shuffle_buffer
        self.seed = seed
        self.project_root = Path(project_root) if project_root else PROJECT_ROOT
        if not self.file_path.exists():
            raise FileNotFoundError(f"Training data not found: {self.file_path}")

    def _format_messages(self, messages: list[dict[str, str]]) -> str:
        """Format chat messages into a single training string."""
        # Use the tokenizer's chat template if available
        if hasattr(self.tokenizer, "apply_chat_template"):
            return self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
        # Fallback: simple concatenation
        parts = [
            f"<|{msg.get('role', 'user')}|>\n{msg.get('content', '')}"
            for msg in messages
        ]
        return "\n".join(parts)

    def _tokenize(self, text: str) -> Optional[dict[str, torch.Tensor]]:
        """Tokenize text to fixed length and create training labels.

        Labels are a copy of ``input_ids`` with positions whose attention
        mask is 0 (padding) replaced by -100.
        """
        encoded = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encoded["input_ids"].squeeze(0)
        attention_mask = encoded["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

    def _worker_shard(self) -> tuple[int, int]:
        """Return ``(worker_id, num_workers)`` for the current process.

        Outside a DataLoader worker this is ``(0, 1)``, i.e. the whole file.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return 0, 1
        return info.id, info.num_workers

    def _line_iterator(self) -> Iterator[dict]:
        """Yield the parsed JSON records belonging to this worker's shard.

        Blank lines are skipped. Sharding is by physical line index so that
        all workers agree on the split without any coordination.
        """
        worker_id, num_workers = self._worker_shard()
        with open(self.file_path, "r", encoding="utf-8") as f:
            for line_no, line in enumerate(f):
                if line_no % num_workers != worker_id:
                    continue
                line = line.strip()
                if line:
                    yield json.loads(line)

    def _shuffled_iterator(self) -> Iterator[dict]:
        """Yield records shuffled within consecutive chunks of ``shuffle_buffer``.

        The buffer is filled, shuffled and flushed as a whole, so records
        never move across chunk boundaries. The RNG is re-seeded with
        ``self.seed`` on every call, which makes each pass over the file
        produce the same order.
        """
        import random

        rng = random.Random(self.seed)
        buffer: list[dict] = []
        for item in self._line_iterator():
            buffer.append(item)
            if len(buffer) >= self.shuffle_buffer:
                rng.shuffle(buffer)
                yield from buffer
                buffer.clear()
        # Flush remaining items
        if buffer:
            rng.shuffle(buffer)
            yield from buffer

    def _load_image(self, image_path: str) -> Optional[Image.Image]:
        """Load an image given a path relative to ``project_root``.

        Returns:
            The image converted to RGB, or None if the file is missing,
            unreadable or the path is malformed.
        """
        # Deliberately best-effort: one bad image must not abort a streaming
        # training run, so every failure maps to "no image".
        try:
            full_path = self.project_root / image_path
            if not full_path.exists():
                return None
            # The context manager closes the underlying file; convert()
            # returns a fully loaded copy that stays valid afterwards.
            with Image.open(str(full_path)) as img:
                return img.convert("RGB")
        except Exception:
            return None

    def __iter__(self) -> Iterator[dict[str, Any]]:
        """Yield tokenized examples with an ``"image"`` entry (PIL image or None).

        Records without ``messages`` are skipped.
        """
        for example in self._shuffled_iterator():
            messages = example.get("messages", [])
            if not messages:
                continue
            tokenized = self._tokenize(self._format_messages(messages))
            if tokenized is None:
                continue
            image_path = example.get("image_path")
            tokenized["image"] = self._load_image(image_path) if image_path else None
            yield tokenized

    def count_lines(self) -> int:
        """Count total lines (for progress estimation). Reads file once."""
        with open(self.file_path, "r", encoding="utf-8") as f:
            return sum(1 for _ in f)
# ── Trainer ───────────────────────────────────────────────────────────
class MINDITrainer:
"""
3-phase trainer for MINDI 1.5 Vision-Coder.
Optimized for AMD MI300X 192GB:
- bf16 mixed precision
- Gradient checkpointing
- Streaming data from disk
- Optional torch.compile()
- Phase-based component freezing/unfreezing
"""
def __init__(
self,
model: nn.Module,
config: TrainingConfig,
) -> None:
"""
Initialize the trainer.
Args:
model: MINDI15 model instance (already initialized).
config: Training configuration.
"""
self.model = model
self.config = config
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.global_step = 0
self.best_val_loss = float("inf")
# Create output directories
self.config.output_dir.mkdir(parents=True, exist_ok=True)
self.config.log_dir.mkdir(parents=True, exist_ok=True)
# Gradient checkpointing
if config.gradient_checkpointing:
base_model = self.model.architecture.get_model()
if hasattr(base_model, "gradient_checkpointing_enable"):
base_model.gradient_checkpointing_enable()
print("[MINDITrainer] Gradient checkpointing enabled")
# Required for PEFT/LoRA + gradient checkpointing without torch.compile
if hasattr(self.model.llm, "enable_input_require_grads"):
self.model.llm.enable_input_require_grads()
else:
def _make_inputs_require_grad(module, input, output):
output.requires_grad_(True)
self.model.llm.get_input_embeddings().register_forward_hook(_make_inputs_require_grad)
# Optional torch.compile (works on ROCm)
if config.use_compile:
print("[MINDITrainer] Compiling model with torch.compile() ...")
self.model.llm = torch.compile(self.model.llm)
self.model.architecture.peft_model = self.model.llm
print("[MINDITrainer] Compilation complete")
# Mixed precision scaler (bf16 doesn't need GradScaler, but keep structure)
self.use_amp = config.dtype in ("bf16", "fp16")
self.amp_dtype = config.torch_dtype
# Training log
self.log_file = config.log_dir / "training_log.jsonl"
self.metrics_history: list[dict] = []
print(f"[MINDITrainer] Device: {self.device}")
print(f"[MINDITrainer] Dtype: {config.dtype}")
print(f"[MINDITrainer] Total steps: {config.total_steps}")
print(f"[MINDITrainer] Phases: {len(config.phases)}")
def _build_optimizer(self, phase: PhaseConfig) -> AdamW:
"""Build optimizer for the current phase (only trainable params)."""
params = [p for p in self.model.parameters() if p.requires_grad]
if not params:
raise RuntimeError(f"No trainable parameters in phase '{phase.name}'")
return AdamW(
params,
lr=phase.learning_rate,
weight_decay=self.config.weight_decay,
betas=(0.9, 0.95),
)
def _build_scheduler(
self, optimizer: AdamW, phase: PhaseConfig
) -> torch.optim.lr_scheduler.LRScheduler:
"""Build LR scheduler: linear warmup + cosine decay."""
phase_steps = phase.end_step - phase.start_step
warmup_steps = max(1, int(phase_steps * self.config.warmup_ratio))
decay_steps = max(1, phase_steps - warmup_steps)
warmup = LinearLR(
optimizer,
start_factor=0.01,
end_factor=1.0,
total_iters=warmup_steps,
)
cosine = CosineAnnealingLR(
optimizer,
T_max=decay_steps,
eta_min=phase.learning_rate * 0.1,
)
return SequentialLR(
optimizer,
schedulers=[warmup, cosine],
milestones=[warmup_steps],
)
def _build_dataloader(
self, file_path: Path, batch_size: int, shuffle_buffer: int = 10000
) -> DataLoader:
"""Build a streaming DataLoader."""
dataset = StreamingJSONLDataset(
file_path=file_path,
tokenizer=self.model.tokenizer,
max_length=self.config.max_seq_length,
shuffle_buffer=shuffle_buffer,
seed=self.config.seed,
)
def _collate_fn(batch):
"""Custom collate: stack tensors, keep images as list."""
collated = {
"input_ids": torch.stack([b["input_ids"] for b in batch]),
"attention_mask": torch.stack([b["attention_mask"] for b in batch]),
"labels": torch.stack([b["labels"] for b in batch]),
"images": [b.get("image") for b in batch],
}
return collated
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=self.config.num_workers,
pin_memory=self.config.pin_memory,
prefetch_factor=self.config.prefetch_factor if self.config.num_workers > 0 else None,
drop_last=True,
collate_fn=_collate_fn,
)
def _log_metrics(self, metrics: dict) -> None:
"""Append metrics to log file and history."""
self.metrics_history.append(metrics)
with open(self.log_file, "a", encoding="utf-8") as f:
f.write(json.dumps(metrics) + "\n")
@torch.no_grad()
def evaluate(self, val_loader: DataLoader, max_batches: int = 50) -> float:
"""
Run validation and return average loss.
Args:
val_loader: Validation DataLoader.
max_batches: Maximum batches to evaluate (for speed).
Returns:
Average validation loss.
"""
self.model.eval()
total_loss = 0.0
count = 0
for batch_idx, batch in enumerate(val_loader):
if batch_idx >= max_batches:
break
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
images = batch.get("images")
image = None
if images:
for img in images:
if img is not None:
image = img
break
with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
result = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
image=image,
)
if result["loss"] is not None:
total_loss += result["loss"].item()
count += 1
self.model.train()
return total_loss / max(count, 1)
def _save_checkpoint(self, phase_name: str, step: int, val_loss: float) -> Path:
"""Save a training checkpoint."""
ckpt_dir = self.config.output_dir / f"{phase_name}_step{step}"
ckpt_dir.mkdir(parents=True, exist_ok=True)
# Save model weights
self.model.save(ckpt_dir)
# Save trainer state
state = {
"global_step": self.global_step,
"phase": phase_name,
"step_in_phase": step,
"val_loss": val_loss,
"best_val_loss": self.best_val_loss,
}
torch.save(state, ckpt_dir / "trainer_state.pt")
print(f"[MINDITrainer] Checkpoint saved: {ckpt_dir}")
return ckpt_dir
def train_phase(self, phase: PhaseConfig) -> dict:
"""
Execute a single training phase.
Args:
phase: Phase configuration.
Returns:
Dict with phase training metrics.
"""
print()
print("=" * 60)
print(f" Phase: {phase.name}")
print(f" Steps: {phase.start_step} β†’ {phase.end_step}")
print(f" LR: {phase.learning_rate} | Batch: {phase.batch_size}")
print(f" Components: LoRA={phase.lora}, Vision={phase.vision_projection}, "
f"Fusion={phase.fusion}")
print(f" Data: {phase.data_type}")
print("=" * 60)
# Set trainable components
self.model.set_trainable_components(
lora=phase.lora,
vision_projection=phase.vision_projection,
fusion=phase.fusion,
)
# Build optimizer and scheduler for this phase
optimizer = self._build_optimizer(phase)
scheduler = self._build_scheduler(optimizer, phase)
# Select data files based on phase data_type
if phase.data_type == "vision":
train_file = self.config.vision_train_file
val_file = self.config.vision_val_file
else:
# "text" or "mixed" β€” use main data (mixed has images inline)
train_file = self.config.train_file
val_file = self.config.val_file
# Build data loaders
train_loader = self._build_dataloader(
train_file, phase.batch_size
)
val_loader = self._build_dataloader(
val_file, batch_size=max(phase.batch_size // 2, 1),
shuffle_buffer=1000,
)
self.model.train()
phase_steps = phase.end_step - phase.start_step
step_in_phase = 0
accum_loss = 0.0
accum_count = 0
phase_start_time = time.time()
train_iter = iter(train_loader)
while step_in_phase < phase_steps:
# Get next batch (restart iterator if exhausted = new epoch)
try:
batch = next(train_iter)
except StopIteration:
train_iter = iter(train_loader)
batch = next(train_iter)
input_ids = batch["input_ids"].to(self.device)
attention_mask = batch["attention_mask"].to(self.device)
labels = batch["labels"].to(self.device)
images = batch.get("images") # list of PIL Images or Nones
# Pick first non-None image in batch (model processes one image at a time)
image = None
if images:
for img in images:
if img is not None:
image = img
break
# Forward pass with mixed precision
with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.use_amp):
result = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels,
image=image,
)
loss = result["loss"]
if loss is None:
continue
# Scale loss for gradient accumulation
loss = loss / phase.gradient_accumulation_steps
# Backward pass
loss.backward()
accum_loss += loss.item() * phase.gradient_accumulation_steps
accum_count += 1
# Optimizer step (every gradient_accumulation_steps)
if accum_count % phase.gradient_accumulation_steps == 0:
# Gradient clipping
torch.nn.utils.clip_grad_norm_(
[p for p in self.model.parameters() if p.requires_grad],
self.config.max_grad_norm,
)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
step_in_phase += 1
self.global_step += 1
avg_loss = accum_loss / phase.gradient_accumulation_steps
accum_loss = 0.0
# Logging
if step_in_phase % self.config.log_every_n_steps == 0:
elapsed = time.time() - phase_start_time
steps_per_sec = step_in_phase / elapsed if elapsed > 0 else 0.0
eta_sec = (phase_steps - step_in_phase) / steps_per_sec if steps_per_sec > 0 else 0.0
metrics = {
"phase": phase.name,
"global_step": self.global_step,
"step_in_phase": step_in_phase,
"loss": round(avg_loss, 4),
"lr": optimizer.param_groups[0]["lr"],
"steps_per_sec": round(steps_per_sec, 3),
"eta_minutes": round(eta_sec / 60, 1),
"elapsed_minutes": round(elapsed / 60, 1),
}
self._log_metrics(metrics)
print(f" [{phase.name}] step {step_in_phase}/{phase_steps} | "
f"loss={avg_loss:.4f} | "
f"lr={optimizer.param_groups[0]['lr']:.2e} | "
f"speed={steps_per_sec:.2f} steps/s | "
f"ETA={eta_sec / 60:.1f}min")
# Evaluation
if step_in_phase % self.config.eval_every_n_steps == 0:
val_loss = self.evaluate(val_loader)
print(f" [{phase.name}] EVAL step {step_in_phase} | val_loss={val_loss:.4f}")
self._log_metrics({
"phase": phase.name,
"global_step": self.global_step,
"val_loss": round(val_loss, 4),
"type": "eval",
})
# Save best model
if val_loss < self.best_val_loss:
self.best_val_loss = val_loss
self._save_checkpoint(phase.name, step_in_phase, val_loss)
print(f" [{phase.name}] New best val_loss: {val_loss:.4f}")
# Periodic save
if step_in_phase % self.config.save_every_n_steps == 0:
self._save_checkpoint(phase.name, step_in_phase, self.best_val_loss)
# End-of-phase save
phase_elapsed = time.time() - phase_start_time
self._save_checkpoint(phase.name, step_in_phase, self.best_val_loss)
phase_summary = {
"phase": phase.name,
"total_steps": step_in_phase,
"elapsed_minutes": round(phase_elapsed / 60, 1),
"best_val_loss": round(self.best_val_loss, 4),
"type": "phase_complete",
}
self._log_metrics(phase_summary)
print(f"\n [{phase.name}] Complete β€” {step_in_phase} steps in "
f"{phase_elapsed / 60:.1f} min")
return phase_summary
def train(self) -> dict:
"""
Run all 3 training phases sequentially.
Returns:
Dict with complete training summary.
"""
print()
print("=" * 60)
print(" MINDI 1.5 β€” Training Start")
print(f" Total phases: {len(self.config.phases)}")
print(f" Total steps: {self.config.total_steps}")
print(f" Device: {self.device}")
print(f" Dtype: {self.config.dtype}")
print(f" Output: {self.config.output_dir}")
print("=" * 60)
torch.manual_seed(self.config.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(self.config.seed)
training_start = time.time()
phase_summaries = []
for phase in self.config.phases:
summary = self.train_phase(phase)
phase_summaries.append(summary)
total_elapsed = time.time() - training_start
# Final save
final_dir = self.config.output_dir / "final"
final_dir.mkdir(parents=True, exist_ok=True)
self.model.save(final_dir)
training_summary = {
"total_steps": self.global_step,
"total_minutes": round(total_elapsed / 60, 1),
"best_val_loss": round(self.best_val_loss, 4),
"phases": phase_summaries,
"type": "training_complete",
}
self._log_metrics(training_summary)
print()
print("=" * 60)
print(" MINDI 1.5 β€” Training Complete")
print(f" Total steps: {self.global_step}")
print(f" Total time: {total_elapsed / 60:.1f} minutes")
print(f" Best val loss: {self.best_val_loss:.4f}")
print(f" Final saved to: {final_dir}")
print("=" * 60)
return training_summary
def resume_from_checkpoint(self, checkpoint_dir: Path) -> None:
"""
Resume training from a checkpoint.
Args:
checkpoint_dir: Directory containing saved checkpoint.
"""
checkpoint_dir = Path(checkpoint_dir)
state_file = checkpoint_dir / "trainer_state.pt"
if not state_file.exists():
raise FileNotFoundError(f"Trainer state not found: {state_file}")
# Load model weights
self.model.load(checkpoint_dir)
# Load trainer state
state = torch.load(state_file, map_location=self.device, weights_only=True)
self.global_step = state["global_step"]
self.best_val_loss = state["best_val_loss"]
print(f"[MINDITrainer] Resumed from step {self.global_step} "
f"(val_loss={self.best_val_loss:.4f})")
# ── Test block ────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Self-test: checks configuration defaults and, when the training file
    # and tokenizer are present on disk, streams a few examples.
    print("=" * 60)
    print(" MINDI 1.5 — Trainer Test")
    print("=" * 60)
    print()

    # ── Test 1: Config defaults ──────────────────────────────────
    print(" Test 1: TrainingConfig defaults")
    config = TrainingConfig()
    assert config.total_steps == 10000
    assert config.dtype == "bf16"
    assert config.torch_dtype == torch.bfloat16
    assert len(config.phases) == 3
    assert config.gradient_checkpointing is True
    assert config.num_workers == 4
    assert config.pin_memory is True
    assert config.prefetch_factor == 2
    print(f" Total steps: {config.total_steps}")
    print(f" Dtype: {config.dtype}")
    print(f" Phases: {[p.name for p in config.phases]}")
    print(" ✓ Config defaults correct")

    # ── Test 2: Phase configs ────────────────────────────────────
    print("\n Test 2: Phase configurations")
    p1, p2, p3 = config.phases
    assert p1.name == "phase1_lora"
    assert p1.batch_size == 16
    assert p1.learning_rate == 2e-4
    assert p1.lora is True and p1.vision_projection is False and p1.fusion is False
    assert p2.name == "phase2_vision_bridge"
    assert p2.batch_size == 8
    assert p2.learning_rate == 1e-5
    assert p2.lora is False and p2.vision_projection is True and p2.fusion is True
    assert p3.name == "phase3_all"
    assert p3.batch_size == 12
    assert p3.learning_rate == 5e-5
    assert p3.lora is True and p3.vision_projection is True and p3.fusion is True
    print(" Phase 1: LoRA only, batch=16, lr=2e-4 ✓")
    print(" Phase 2: Vision bridge, batch=8, lr=1e-5 ✓")
    print(" Phase 3: All, batch=12, lr=5e-5 ✓")

    # ── Test 3: Streaming dataset (if data exists) ───────────────
    print("\n Test 3: StreamingJSONLDataset")
    train_path = config.train_file
    if train_path.exists():
        from transformers import AutoTokenizer

        tok = AutoTokenizer.from_pretrained(
            str(PROJECT_ROOT / "data" / "tokenizer" / "mindi_tokenizer"),
            trust_remote_code=True,
        )
        dataset = StreamingJSONLDataset(
            file_path=train_path,
            tokenizer=tok,
            max_length=512,  # small for test
            shuffle_buffer=100,
        )
        count = 0
        # Remembered separately so the summary line below cannot hit an
        # unbound loop variable when the dataset yields nothing.
        last_shape = None
        for item in dataset:
            assert "input_ids" in item
            assert "attention_mask" in item
            assert "labels" in item
            assert item["input_ids"].shape[0] == 512
            last_shape = item["input_ids"].shape
            count += 1
            if count >= 5:
                break
        print(f" Streamed {count} examples, shape={last_shape} ✓")
    else:
        print(f" [SKIP] Train file not found: {train_path}")

    # ── Test 4: PhaseConfig step ranges ──────────────────────────
    print("\n Test 4: Phase step continuity")
    for i in range(1, len(config.phases)):
        prev = config.phases[i - 1]
        curr = config.phases[i]
        assert prev.end_step == curr.start_step, \
            f"Gap between {prev.name} and {curr.name}"
    print(" All phases are contiguous ✓")

    # ── Test 5: Gradient accumulation ────────────────────────────
    print("\n Test 5: Gradient accumulation steps")
    for phase in config.phases:
        assert phase.gradient_accumulation_steps == 4
    print(" All phases: grad_accum=4 ✓")

    print("\n ✓ All trainer tests passed!")
    print("=" * 60)