"""Training module for AAM Diffusion LLM.

Re-exports the public training API so callers can import it directly from
``diffusion_llm.training``:

- ``AamTrainer`` from ``diffusion_llm.training.trainer``
- ``GraphNarrativeDataset`` from ``diffusion_llm.training.dataset``
- ``DiffusionLoss`` and ``compute_loss`` from ``diffusion_llm.training.losses``
- ``JEPAPredictor``, ``JEPAConfig`` and ``JEPATrainer`` from
  ``diffusion_llm.training.llm_jepa``
"""

# Import order is kept as in the original so any submodule import side
# effects happen in the same sequence.
from diffusion_llm.training.trainer import AamTrainer
from diffusion_llm.training.dataset import GraphNarrativeDataset
from diffusion_llm.training.losses import DiffusionLoss, compute_loss
from diffusion_llm.training.llm_jepa import JEPAPredictor, JEPAConfig, JEPATrainer

# Public API of this package: these are the names bound by
# ``from diffusion_llm.training import *``.
__all__ = [
    "AamTrainer",
    "GraphNarrativeDataset",
    "DiffusionLoss",
    "compute_loss",
    "JEPAPredictor",
    "JEPAConfig",
    "JEPATrainer",
]