"""
losses.py
"""
from dataclasses import dataclass
from typing import Dict, Tuple
import torch
import torch.nn.functional as F
# ============================================================
# 1) CONFIG
# ============================================================
@dataclass
class GeneralistLossConfig:
    w_action: float = 1.0
    w_physics: float = 20.0
    w_value: float = 100.0
    label_smoothing: float = 0.0
    use_rtg_weighting: bool = True
    rtg_weight_mode: str = "exp"
    rtg_weight_beta: float = 2.0
    min_token_weight: float = 0.05
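# Example (illustrative, not from the training recipe): a config that softens the
# RTG weighting and leans harder on the physics term.
#   cfg = GeneralistLossConfig(w_physics=40.0, rtg_weight_mode="softplus", rtg_weight_beta=1.0)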
# ============================================================
# 2) HELPERS
# ============================================================
def _expand_rtg_to_tokens(rtg_bt: torch.Tensor, K: int) -> torch.Tensor:
    """Broadcast per-timestep weights [B, T] across the K token slots -> [B, T, K]."""
    return rtg_bt.unsqueeze(-1).expand(-1, -1, K)
def _rtg_to_weights(rtg_input: torch.Tensor, mode: str, beta: float) -> torch.Tensor:
    """Map return-to-go values to per-timestep loss weights of shape [B, T]."""
    if mode == "none":
        return torch.ones(rtg_input.shape[:2], device=rtg_input.device)
    if rtg_input.dim() == 3:
        # Multi-component RTG [B, T, R]: z-normalise each component over time,
        # then collapse to a single scalar per timestep.
        mu = rtg_input.mean(dim=1, keepdim=True)
        sig = rtg_input.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5)
        rtg_norm = (rtg_input - mu) / sig
        scalar_rtg = rtg_norm.sum(dim=-1)
    else:
        scalar_rtg = rtg_input
    # Z-score over the time dimension, clamped so the exponential weighting stays bounded.
    mu_s = scalar_rtg.mean(dim=1, keepdim=True)
    sig_s = scalar_rtg.std(dim=1, keepdim=True, unbiased=False).clamp_min(1e-5)
    z = (scalar_rtg - mu_s) / sig_s
    z = torch.clamp(z, -5.0, 5.0)
    if mode == "clamp01":
        w = torch.sigmoid(beta * z)
    elif mode == "softplus":
        w = F.softplus(beta * z)
    elif mode == "exp":
        w = torch.exp(beta * z)
    else:
        raise ValueError(f"Unknown rtg_weight_mode={mode}")
    w = torch.clamp(w, min=0.01, max=50.0)
    return w
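# Illustrative behaviour of the default "exp" mode with beta=2.0: a timestep one
# standard deviation above its sequence's mean RTG gets weight e^2 ~= 7.4, an average
# timestep gets 1.0, and one standard deviation below gets e^-2 ~= 0.14; the final
# clamp keeps all weights inside [0.01, 50.0].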
# ============================================================
# 3) MAIN LOSS
# ============================================================
def compute_generalist_loss(
    model_out: Dict[str, torch.Tensor],
    batch: Dict[str, torch.Tensor],
    config: GeneralistLossConfig,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """
    Computes the RTG-weighted action loss, the physics (state-delta) loss,
    and the rescaled value loss. Returns (total_loss, metrics).
    """
    action_logits = model_out["action_logits"]    # [B, T, K, n_bins]
    state_preds = model_out["state_preds"]        # [B, T, K]
    state_preds_4h = model_out["state_preds_4h"]  # [B, T, K]
    return_preds = model_out["return_preds"]      # [B, T, 2]

    target_tokens = batch["target_action_tokens"]
    target_mask = batch["target_mask"].float()
    attn_mask = batch["attention_mask"].float()
    target_rtg = batch["rtg"].float()
    time_mask = batch.get(
        "time_mask",
        torch.ones(target_rtg.shape[:2], device=target_rtg.device),
    ).float()

    B, T, K, n_bins = action_logits.shape
    is_state = (1.0 - target_mask)
    valid_phys = attn_mask * is_state
    # 1) Stitching: RTG-based importance weights per token.
    if config.use_rtg_weighting:
        w_bt = _rtg_to_weights(target_rtg, config.rtg_weight_mode, config.rtg_weight_beta)
        w_btk = _expand_rtg_to_tokens(w_bt, K)
        # Renormalise so the mean importance over valid action tokens stays ~1.
        norm_factor = (target_mask * attn_mask).sum().clamp_min(1e-6) / \
                      (w_btk * target_mask * attn_mask).sum().clamp_min(1e-6)
        token_importance = w_btk * norm_factor
    else:
        w_bt = torch.ones((B, T), device=action_logits.device)
        token_importance = torch.ones((B, T, K), device=action_logits.device)
    # 2) ACTION LOSS (CE)
    flat_logits = action_logits.reshape(-1, n_bins)
    flat_targets = target_tokens.reshape(-1)
    flat_mask = (target_mask * attn_mask).reshape(-1)
    flat_importance = token_importance.reshape(-1)
    with torch.no_grad():
        valid_t = flat_targets[flat_mask > 0.5]
        if valid_t.numel() > 0:
            # Inverse-frequency class weights, smoothed by +10 counts per bin.
            counts = torch.bincount(valid_t, minlength=n_bins).float()
            class_weights = (1.0 / (counts + 10.0)) / (1.0 / (counts + 10.0)).mean()
        else:
            class_weights = torch.ones(n_bins, device=flat_logits.device)
    ce_per_token = F.cross_entropy(
        flat_logits, flat_targets,
        weight=class_weights, reduction="none", ignore_index=-100,
    )
    loss_action = (ce_per_token * flat_mask * flat_importance).sum() / flat_mask.sum().clamp_min(1e-6)
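    # Worked example of the class weighting above (illustrative numbers): if two
    # bins occur 1000 and 10 times, 1/(counts+10) gives [1/1010, 1/20]; after
    # dividing by the mean, the rare bin carries roughly 50x the weight of the
    # common one.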
    # ============================================================
    # 3) PHYSICS LOSS (The Delta Fix)
    # ============================================================
    # Ground truth from the dataloader:
    #   next_obs        is [B, T, 21]
    #   feature_values  is [B, T, 64] (padded tokens)
    true_next = batch["next_obs"].float()
    target_delta_4h = batch["target_4h_delta"].float()
    K_limit = true_next.shape[2]
    true_vals_sliced = batch["feature_values"].float().narrow(2, 0, K_limit)
    s_pred_valid = state_preds.narrow(2, 0, K_limit)
    s_pred_4h_valid = state_preds_4h.narrow(2, 0, K_limit)
    v_phys_mask = valid_phys.narrow(2, 0, K_limit)

    # Supervise predicted deltas, not absolute states.
    target_delta_1s = true_next - true_vals_sliced
    mse_1s = (s_pred_valid - target_delta_1s) ** 2
    mse_4h = (s_pred_4h_valid - target_delta_4h) ** 2

    with torch.no_grad():
        # Up-weight timesteps where the observed features change the most.
        act_diff = torch.zeros((B, T), device=true_next.device)
        if T > 1:
            act_diff[:, 1:] = torch.abs(true_vals_sliced[:, 1:] - true_vals_sliced[:, :-1]).sum(dim=-1)
        excitation = (1.0 + 5.0 * act_diff).unsqueeze(-1)

    denom = (v_phys_mask * excitation).sum().clamp_min(1e-6)
    loss_phys_1s = (mse_1s * v_phys_mask * excitation).sum() / denom
    loss_phys_4h = (mse_4h * v_phys_mask * excitation).sum() / denom
    loss_physics = loss_phys_1s + 0.5 * loss_phys_4h
    # 4) VALUE LOSS (rescaled return prediction)
    val_mse = ((return_preds - target_rtg) ** 2).sum(dim=-1)
    loss_value = (val_mse * w_bt * time_mask).sum() / time_mask.sum().clamp_min(1e-6)
    loss_value = loss_value * 500.0  # hard-coded rescale applied on top of config.w_value
    total = (config.w_action * loss_action) + \
            (config.w_physics * loss_physics) + \
            (config.w_value * loss_value)
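    # With the defaults above (w_value=100 and the 500x rescale), the value term
    # enters the total with an effective factor of 50,000 relative to its raw MSE.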
    with torch.no_grad():
        acc = ((torch.argmax(flat_logits, -1) == flat_targets).float() * flat_mask).sum() / flat_mask.sum().clamp_min(1e-6)
        # Occasionally (~0.1% of calls) print a one-line loss breakdown.
        if torch.rand(1).item() < 0.001:
            print(f"[Loss Debug] Action: {loss_action.item():.3f} | Phys: {loss_physics.item():.3f} | Val: {loss_value.item():.3f}")
    metrics = {
        "loss_action": loss_action.detach(),
        "loss_physics": loss_physics.detach(),
        "loss_value": loss_value.detach(),
        "accuracy": acc.detach(),
        "total_loss": total.detach(),
    }
    return total, metrics
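

# ------------------------------------------------------------
# Minimal smoke test (illustrative, not part of the training code): builds random
# tensors with shapes inferred from the comments above -- K=64 padded feature
# tokens, 21 observed features, 2 return components -- and runs one loss pass.
# n_bins=32 and the mask statistics are assumptions, not values from the real
# dataloader.
# ------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, K, n_bins, n_obs = 2, 8, 64, 32, 21
    model_out = {
        "action_logits": torch.randn(B, T, K, n_bins),
        "state_preds": torch.randn(B, T, K),
        "state_preds_4h": torch.randn(B, T, K),
        "return_preds": torch.randn(B, T, 2),
    }
    batch = {
        "target_action_tokens": torch.randint(0, n_bins, (B, T, K)),
        "target_mask": (torch.rand(B, T, K) < 0.3).float(),  # ~30% action tokens
        "attention_mask": torch.ones(B, T, K),
        "rtg": torch.randn(B, T, 2),
        "next_obs": torch.randn(B, T, n_obs),
        "target_4h_delta": torch.randn(B, T, n_obs),
        "feature_values": torch.randn(B, T, K),
    }
    total, metrics = compute_generalist_loss(model_out, batch, GeneralistLossConfig())
    print({k: round(float(v), 4) for k, v in metrics.items()})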