import argparse
import json
import os
import random
import tempfile
from pathlib import Path
from typing import Dict, Optional

import numpy as np
import torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LinearLR, SequentialLR
from tqdm import tqdm

from loader import SoilFormerDataset, build_train_eval_dataloaders
from soilformer import SoilFormer, loss_function
from utils import get_dtype, load_json, save_json

# wandb is an optional dependency; it is only required when the config
# enables wandb logging (see maybe_init_wandb).
try:
    import wandb
except ImportError:
    wandb = None


def set_seed(seed: int, deterministic: bool = True) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    if deterministic:
        # Trade cuDNN autotuning speed for reproducible kernels. Full
        # determinism may additionally require
        # torch.use_deterministic_algorithms(True).
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def resolve_device(device_str: str) -> torch.device:
    device_str = device_str.lower()

    if device_str == "cuda":
        if not torch.cuda.is_available():
            raise RuntimeError("config requests cuda, but CUDA is not available")
        return torch.device("cuda")

    if device_str == "mps":
        if not torch.backends.mps.is_available():
            raise RuntimeError("config requests mps, but MPS is not available")
        return torch.device("mps")

    if device_str == "cpu":
        return torch.device("cpu")

    raise ValueError(f"Unsupported device: {device_str}")


def move_batch_to_device(batch: Dict, device: torch.device, float_dtype: torch.dtype) -> Dict:
    def _move(tensor: torch.Tensor) -> torch.Tensor:
        # Float tensors are cast to the training dtype; integer/bool tensors
        # (ids, positions, masks) keep their dtype.
        if tensor.dtype.is_floating_point:
            return tensor.to(device=device, dtype=float_dtype, non_blocking=True)
        return tensor.to(device=device, non_blocking=True)

    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            out[key] = _move(value)
        elif isinstance(value, dict):
            # One level of nesting, e.g. the per-NIN numeric value dicts.
            out[key] = {
                sub_key: _move(sub_value) if isinstance(sub_value, torch.Tensor) else sub_value
                for sub_key, sub_value in value.items()
            }
        else:
            out[key] = value
    return out
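

# A minimal sketch of the batch layout handled above; key names come from the
# model call in compute_loss_from_batch below, shapes are assumptions for
# illustration only:
#   {
#       "masked_cat_local_ids": LongTensor[B, T_cat],        # kept as integers
#       "pixel_values": FloatTensor[B, C, H, W],             # cast to float_dtype
#       "masked_numeric_values_by_nin": {nin: FloatTensor},  # nested dict
#   }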


def build_scheduler(
    optimizer: torch.optim.Optimizer,
    scheduler_cfg: Dict,
):
    scheduler_type = str(scheduler_cfg.get("type", "none")).lower()

    if scheduler_type == "none":
        return None

    warmup_epochs = int(scheduler_cfg.get("warmup_epochs", 0))
    warmup_start_factor = float(scheduler_cfg.get("warmup_start_factor", 0.1))

    if scheduler_type == "cosine":
        total_epochs = int(scheduler_cfg["total_epochs"])
        eta_min = float(scheduler_cfg.get("eta_min", 1e-6))

        if warmup_epochs > 0:
            # The cosine phase only covers the epochs left after warmup.
            t_max = int(scheduler_cfg.get("t_max", total_epochs - warmup_epochs))
            if t_max <= 0:
                raise ValueError(
                    f"Invalid cosine scheduler config: total_epochs={total_epochs}, "
                    f"warmup_epochs={warmup_epochs}, resulting T_max={t_max}"
                )
        else:
            t_max = int(scheduler_cfg.get("t_max", total_epochs))

        main_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=t_max,
            eta_min=eta_min,
        )

    elif scheduler_type == "step":
        step_size = int(scheduler_cfg["step_size"])
        gamma = float(scheduler_cfg.get("gamma", 0.1))
        main_scheduler = StepLR(
            optimizer,
            step_size=step_size,
            gamma=gamma,
        )

    else:
        raise ValueError(f"Unsupported scheduler type: {scheduler_type}")

    if warmup_epochs <= 0:
        return main_scheduler

    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=warmup_start_factor,
        total_iters=warmup_epochs,
    )

    # Linear warmup for `warmup_epochs`, then hand off to the main scheduler.
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, main_scheduler],
        milestones=[warmup_epochs],
    )
    return scheduler
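

# Example `scheduler` configs accepted by build_scheduler (values illustrative):
#   {"type": "cosine", "total_epochs": 100, "warmup_epochs": 5,
#    "warmup_start_factor": 0.1, "eta_min": 1e-6}
#   {"type": "step", "step_size": 30, "gamma": 0.1}
#   {"type": "none"}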


def get_checkpoint_model_state(model: SoilFormer) -> Dict[str, torch.Tensor]:
    # Prefer the model's own checkpoint hook when it provides one; otherwise
    # fall back to the full state dict.
    if hasattr(model, "_checkpoint_state_dict"):
        return model._checkpoint_state_dict()
    return model.state_dict()


def load_checkpoint_model_state(model: SoilFormer, state_dict: Dict[str, torch.Tensor]) -> None:
    if hasattr(model, "load_weights"):
        # `load_weights` expects a checkpoint file on disk, so round-trip the
        # state dict through a temporary .pt file.
        payload = {"model_state_dict": state_dict}
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
                tmp_path = f.name
            torch.save(payload, tmp_path)
            model.load_weights(tmp_path, map_location="cpu", strict=True)
        finally:
            if tmp_path is not None and os.path.exists(tmp_path):
                os.remove(tmp_path)
        return

    model.load_state_dict(state_dict, strict=True)


def save_checkpoint(
    checkpoint_path: Path,
    model: SoilFormer,
    optimizer: torch.optim.Optimizer,
    scheduler,
    epoch: int,
    global_step: int,
    config_train: Dict,
    config_model: Dict,
    config_data: Dict,
) -> None:
    checkpoint = {
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": get_checkpoint_model_state(model),
        "optimizer_state_dict": optimizer.state_dict(),
        "scheduler_state_dict": None if scheduler is None else scheduler.state_dict(),
        "config_train": config_train,
        "config_model": config_model,
        "config_data": config_data,
    }
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(checkpoint, checkpoint_path)


def rotate_checkpoints(checkpoint_dir: Path, max_saved_checkpoints: int) -> None:
    if max_saved_checkpoints is None or max_saved_checkpoints <= 0:
        return
    # Sort numerically by epoch: a plain lexicographic sort would place
    # checkpoint_epoch_10.pt before checkpoint_epoch_2.pt and prune the
    # wrong files.
    checkpoint_paths = sorted(
        checkpoint_dir.glob("checkpoint_epoch_*.pt"),
        key=lambda p: int(p.stem.rsplit("_", 1)[-1]),
    )
    while len(checkpoint_paths) > max_saved_checkpoints:
        oldest = checkpoint_paths.pop(0)
        oldest.unlink(missing_ok=True)


def compute_loss_from_batch(
    model: SoilFormer,
    batch: Dict,
    device: torch.device,
    dtype: torch.dtype,
    cat_s_bound: Optional[float] = None,
    num_s_bound: Optional[float] = None,
):
    batch = move_batch_to_device(batch, device=device, float_dtype=dtype)

    # Forward pass on the masked batch; the s_* outputs feed loss_function's
    # s_cat / s_num and are optionally bounded by cat_s_bound / num_s_bound.
    cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
        cat_local_ids=batch["masked_cat_local_ids"],
        numeric_values_by_nin=batch["masked_numeric_values_by_nin"],
        cat_valid_positions=batch["masked_cat_valid_positions"],
        numeric_valid_positions_by_nin=batch["masked_numeric_valid_positions_by_nin"],
        pixel_values=batch["pixel_values"],
        vision_valid_positions=batch["vision_valid_positions"],
    )

    total_loss, stats = loss_function(
        x_cat=cat_logits_padded,
        s_cat=cat_s,
        y_cat=batch["original_cat_local_ids"],
        loss_mask_cat=batch["cat_loss_mask"],
        valid_class_mask=valid_class_mask,
        x_num=value_by_nin,
        s_num=s_by_nin,
        y_num=batch["original_numeric_values_by_nin"],
        loss_mask_num=batch["numeric_loss_mask_by_nin"],
        reduction="mean",
        cat_s_bound=cat_s_bound,
        num_s_bound=num_s_bound,
    )

    return total_loss, stats
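

# `stats` is assumed to expose at least the scalar tensors aggregated by the
# train loop and evaluate() below:
#   total, cat_loss, num_loss, cat_base, num_base, cat_acc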


@torch.no_grad()
def evaluate(
    model: SoilFormer,
    dataset: SoilFormerDataset,
    eval_loader,
    device: torch.device,
    dtype: torch.dtype,
    cat_mask_ratio: float,
    num_mask_ratio: float,
    active_mask_seed: int,
    show_tqdm: bool,
    epoch: int,
    cat_s_bound: Optional[float] = None,
    num_s_bound: Optional[float] = None,
):
    model.eval()

    totals = {
        "total": 0.0,
        "cat_loss": 0.0,
        "num_loss": 0.0,
        "cat_base": 0.0,
        "num_base": 0.0,
        "cat_acc": 0.0,
    }
    num_batches = 0

    iterator = eval_loader
    if show_tqdm:
        iterator = tqdm(eval_loader, desc=f"Eval {epoch}", leave=False)

    for batch_idx, raw_batch in enumerate(iterator):
        # Per-batch seed with no epoch offset: eval masks are identical across
        # epochs, so eval metrics stay comparable between epochs.
        mask_seed = int(active_mask_seed + batch_idx)
        masked_batch = dataset.perform_active_mask(
            raw_batch,
            cat_ratio=cat_mask_ratio,
            num_ratio=num_mask_ratio,
            seed=mask_seed,
        )

        _, stats = compute_loss_from_batch(
            model=model,
            batch=masked_batch,
            device=device,
            dtype=dtype,
            cat_s_bound=cat_s_bound,
            num_s_bound=num_s_bound,
        )

        num_batches += 1
        for key in totals:
            totals[key] += float(stats[key].item())

    if num_batches == 0:
        raise RuntimeError("Eval dataloader is empty")

    return {f"eval/{k}": v / num_batches for k, v in totals.items()}
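

# evaluate() returns epoch-mean metrics namespaced for logging, e.g. (values
# illustrative):
#   {"eval/total": 1.23, "eval/cat_loss": 0.8, "eval/num_loss": 0.43,
#    "eval/cat_base": 2.1, "eval/num_base": 1.4, "eval/cat_acc": 0.71}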


def maybe_init_wandb(config_train: Dict):
    wandb_cfg = config_train["logging"]["wandb"]
    if not bool(wandb_cfg.get("enabled", False)):
        return None

    if wandb is None:
        raise ImportError("wandb is enabled in config but package is not installed")

    run = wandb.init(
        project=wandb_cfg["project"],
        entity=wandb_cfg.get("entity"),
        name=wandb_cfg.get("run_name"),
        dir=wandb_cfg.get("dir"),
        config=config_train,
        mode=wandb_cfg.get("mode", "online"),
    )
    return run
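

# Example `logging.wandb` block in config_train.json (illustrative values;
# only `project` is required when enabled):
#   "wandb": {"enabled": true, "project": "soilformer", "entity": null,
#             "run_name": "baseline", "dir": "wandb_logs", "mode": "offline"}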


def print_parameter_stats(model):
    total = 0
    trainable = 0

    for p in model.parameters():
        num = p.numel()
        total += num
        if p.requires_grad:
            trainable += num

    print("\nParameter statistics:")
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Frozen parameters: {total - trainable:,}\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config/config_train.json")
    args = parser.parse_args()

    config_train = load_json(args.config)
    config_paths = config_train["paths"]
    config_data = load_json(config_paths["config_data_path"])
    config_model = load_json(config_paths["config_model_path"])
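
    # Sketch of the expected config_train.json layout, inferred from the keys
    # read in this function (all values illustrative):
    #   {
    #       "paths": {"config_data_path": "...", "config_model_path": "...",
    #                 "output_dir": "..."},
    #       "seed": {"seed": 42, "deterministic": true},
    #       "runtime": {"device": "cuda", "num_epochs": 100,
    #                   "init_weight_std": 0.02},
    #       "optimization": {"lr": 3e-4, "beta1": 0.9, "beta2": 0.999,
    #                        "eps": 1e-8, "weight_decay": 0.01,
    #                        "max_grad_norm": 1.0,
    #                        "scheduler": {"type": "cosine", "total_epochs": 100}},
    #       "checkpoint": {"resume_checkpoint_path": null, "epochs_per_save": 5,
    #                      "max_saved_checkpoints": 3},
    #       "logging": {"tqdm": true, "wandb": {"enabled": false}},
    #       "loss": {"cat_s_bound": null, "num_s_bound": null}
    #   }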
    seed_cfg = config_train["seed"]
    runtime_cfg = config_train["runtime"]
    optim_cfg = config_train["optimization"]
    checkpoint_cfg = config_train["checkpoint"]
    logging_cfg = config_train["logging"]
    loss_cfg = config_train["loss"]

    set_seed(int(seed_cfg["seed"]), deterministic=bool(seed_cfg.get("deterministic", True)))

    device = resolve_device(runtime_cfg["device"])
    dtype = get_dtype(config_model.get("dtype", "bfloat16"))

    output_dir = Path(config_paths["output_dir"])
    checkpoint_dir = output_dir / "checkpoints"
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Snapshot the configs next to the run outputs for reproducibility.
    save_json(config_train, str(output_dir / "config_train.snapshot.json"))
    save_json(config_data, str(output_dir / "config_data.snapshot.json"))
    save_json(config_model, str(output_dir / "config_model.snapshot.json"))

    dataset = SoilFormerDataset(
        csv_path=config_data["data_csv_path"],
        photo_map_path=config_data["photo_map_path"],
        cat_vocab_path=config_data["cat_vocab_path"],
        numeric_vocab_path=config_data["numeric_vocab_path"],
        numeric_stats_path=config_data["numeric_stats_path"],
        photo_root=config_data["photo_root"],
        image_size=int(config_data["image_size"]),
    )

    train_loader, eval_loader, train_generator = build_train_eval_dataloaders(
        dataset=dataset,
        train_ratio=float(config_data["train_ratio"]),
        seed=int(config_data["train_eval_split_seed"]),
        batch_size=int(config_data["batch_size"]),
    )
    print("\nSample statistics:")
    print("Train samples:", len(train_loader.dataset))
    print("Eval samples:", len(eval_loader.dataset))
    # Reseed the shuffle generator so epoch ordering follows the global
    # training seed rather than the train/eval split seed.
    train_generator.manual_seed(int(seed_cfg["seed"]))

    model = SoilFormer(config=config_model, device=str(device))

    resume_path = checkpoint_cfg.get("resume_checkpoint_path")
    if resume_path:
        # Restore model weights now; optimizer/scheduler state is restored
        # after they are constructed below.
        checkpoint = torch.load(resume_path, map_location="cpu")
        load_checkpoint_model_state(model, checkpoint["model_state_dict"])
    else:
        model.init_weights(std=float(runtime_cfg.get("init_weight_std", 0.02)))
        checkpoint = None

    print_parameter_stats(model)

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=float(optim_cfg["lr"]),
        betas=(float(optim_cfg["beta1"]), float(optim_cfg["beta2"])),
        eps=float(optim_cfg["eps"]),
        weight_decay=float(optim_cfg["weight_decay"]),
    )

    scheduler = build_scheduler(
        optimizer=optimizer,
        scheduler_cfg=optim_cfg.get("scheduler", {"type": "none"}),
    )

    start_epoch = 1
    global_step = 0

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = int(checkpoint["epoch"]) + 1
        global_step = int(checkpoint.get("global_step", 0))

    wandb_run = maybe_init_wandb(config_train)

    num_epochs = int(runtime_cfg["num_epochs"])
    show_tqdm = bool(logging_cfg.get("tqdm", True))
    cat_mask_ratio = float(config_data["cat_mask_ratio"])
    num_mask_ratio = float(config_data["num_mask_ratio"])
    active_mask_seed = int(config_data["active_mask_seed"])
    max_grad_norm = optim_cfg.get("max_grad_norm")
    epochs_per_save = int(checkpoint_cfg["epochs_per_save"])
    max_saved_checkpoints = int(checkpoint_cfg["max_saved_checkpoints"])

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()

        epoch_totals = {
            "total": 0.0,
            "cat_loss": 0.0,
            "num_loss": 0.0,
            "cat_base": 0.0,
            "num_base": 0.0,
            "cat_acc": 0.0,
        }
        num_batches = 0

        iterator = train_loader
        if show_tqdm:
            iterator = tqdm(train_loader, desc=f"Train {epoch}", leave=True)

        for batch_idx, raw_batch in enumerate(iterator):
            global_step += 1
            # Epoch offset keeps train masks reproducible but different each
            # epoch; the 1_000_000 stride keeps per-epoch seed ranges disjoint
            # as long as an epoch has fewer than 1_000_000 batches.
            mask_seed = int(active_mask_seed + epoch * 1_000_000 + batch_idx)
            masked_batch = dataset.perform_active_mask(
                raw_batch,
                cat_ratio=cat_mask_ratio,
                num_ratio=num_mask_ratio,
                seed=mask_seed,
            )

            optimizer.zero_grad(set_to_none=True)

            total_loss, stats = compute_loss_from_batch(
                model=model,
                batch=masked_batch,
                device=device,
                dtype=dtype,
                cat_s_bound=loss_cfg.get("cat_s_bound", None),
                num_s_bound=loss_cfg.get("num_s_bound", None),
            )

            total_loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(max_grad_norm))
            optimizer.step()

            num_batches += 1
            for key in epoch_totals:
                epoch_totals[key] += float(stats[key].item())

            current_lr = float(optimizer.param_groups[0]["lr"])
            train_step_log = {
                "train/step_total": float(stats["total"].item()),
                "train/step_cat_loss": float(stats["cat_loss"].item()),
                "train/step_num_loss": float(stats["num_loss"].item()),
                "train/step_cat_acc": float(stats["cat_acc"].item()),
                "train/lr": current_lr,
                "epoch": epoch,
                "global_step": global_step,
            }

            if wandb_run is not None:
                wandb.log(train_step_log, step=global_step)

            if show_tqdm:
                iterator.set_postfix(
                    loss=f"{train_step_log['train/step_total']:.4f}",
                    lr=f"{current_lr:.3e}",
                )

        if num_batches == 0:
            raise RuntimeError("Train dataloader is empty")

        train_epoch_log = {f"train/{k}": v / num_batches for k, v in epoch_totals.items()}
        train_epoch_log["train/lr_epoch_end"] = float(optimizer.param_groups[0]["lr"])
        train_epoch_log["epoch"] = epoch
        train_epoch_log["global_step"] = global_step

        eval_log = evaluate(
            model=model,
            dataset=dataset,
            eval_loader=eval_loader,
            device=device,
            dtype=dtype,
            cat_mask_ratio=cat_mask_ratio,
            num_mask_ratio=num_mask_ratio,
            active_mask_seed=active_mask_seed,
            show_tqdm=show_tqdm,
            epoch=epoch,
            cat_s_bound=loss_cfg.get("cat_s_bound", None),
            num_s_bound=loss_cfg.get("num_s_bound", None),
        )
        eval_log["epoch"] = epoch
        eval_log["global_step"] = global_step

        merged_log = {}
        merged_log.update(train_epoch_log)
        merged_log.update(eval_log)

        print(json.dumps(merged_log, ensure_ascii=False))

        if wandb_run is not None:
            wandb.log(merged_log, step=global_step)

        # All schedulers built above are epoch-granular, so step once per epoch.
        if scheduler is not None:
            scheduler.step()

        if epochs_per_save > 0 and epoch % epochs_per_save == 0:
            checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
            save_checkpoint(
                checkpoint_path=checkpoint_path,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                global_step=global_step,
                config_train=config_train,
                config_model=config_model,
                config_data=config_data,
            )
            rotate_checkpoints(checkpoint_dir, max_saved_checkpoints)
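            # Resulting layout under <output_dir>/checkpoints after rotation,
            # e.g. with epochs_per_save=5 and max_saved_checkpoints=3
            # (epoch numbers illustrative):
            #   checkpoint_epoch_10.pt
            #   checkpoint_epoch_15.pt
            #   checkpoint_epoch_20.pt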

    if wandb_run is not None:
        wandb.finish()


if __name__ == "__main__":
    main()