|
|
import logging
|
|
|
import math
|
|
|
import time
|
|
|
import torch
|
|
|
from typing import Dict, Tuple, Optional
|
|
|
|
|
|
# Module-level logger; the host application is expected to configure handlers.
logger = logging.getLogger(__name__)


# MLflow is an optional dependency: record its availability once at import
# time so the training/validation loops can skip metric logging when it is
# not installed.
try:
    import mlflow

    MLFLOW_AVAILABLE = True
except ImportError:
    MLFLOW_AVAILABLE = False
|
|
|
|
|
|
|
|
|
def _is_deepspeed_engine(obj) -> bool:
|
|
|
return hasattr(obj, "backward") and hasattr(obj, "step") and hasattr(obj, "module")
|
|
|
|
|
|
|
|
|
def train_one_epoch(
    *,
    model,
    train_loader,
    optimizer,
    scheduler,
    scaler,
    device,
    args,
    global_step: int,
    use_wandb: bool,
    is_main_process: bool,
) -> Tuple[float, int]:
    """Run one training epoch and return ``(avg_loss, global_step)``.

    Two execution modes are supported:

    * DeepSpeed engine (detected via :func:`_is_deepspeed_engine`): the engine
      owns loss scaling, accumulation and optimizer stepping, so this function
      only calls ``model.backward(loss)`` and ``model.step()``.
    * Plain PyTorch: manual gradient accumulation over
      ``args.gradient_accumulation_steps`` micro-batches, optional AMP via
      ``scaler`` + ``torch.amp.autocast``, optional gradient clipping, and
      explicit optimizer/scheduler stepping.

    Along the way the function prints ``[step]``/``[diag]``/``[mem]``/``[train]``
    console lines and, when the module-level ``MLFLOW_AVAILABLE`` flag and
    ``args.use_mlflow`` are both set, logs metrics to MLflow.

    NOTE(review): ``use_wandb`` is accepted but never read in this body, and
    ``grad_norms`` is never appended to, so the percentile summary near the
    end can never fire — both look like leftovers; confirm with the callers.
    NOTE(review): in the DeepSpeed branch ``global_step`` is never
    incremented inside this function — verify the caller advances it.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    grad_norms = []  # see NOTE(review) in the docstring: currently never filled

    # Most recently measured (post-clip) gradient norms, reported in [step].
    current_total_grad_norm = 0.0
    current_router_grad_norm = 0.0

    start_time = time.time()
    last_step_time = time.time()

    logger.debug("Starting training loop, train_loader length estimate: %s", len(train_loader) if hasattr(train_loader, '__len__') else 'unknown')
    logger.debug("gradient_accumulation_steps: %s", args.gradient_accumulation_steps)

    for batch_idx, batch in enumerate(train_loader):
        measured_grad_this_step = False

        logger.debug("Batch %d: global_step=%d", batch_idx, global_step)

        # Move every tensor in the batch onto the training device.
        batch = {k: v.to(device) for k, v in batch.items()}
        step_start = time.time()

        # DRE warmup: disable the dynamic reasoning engine for the first
        # ``args.dre_warmup_steps`` optimizer steps.
        use_dre_now = True
        if getattr(args, "dre_warmup_steps", 0) and global_step < args.dre_warmup_steps:
            use_dre_now = False

        if _is_deepspeed_engine(model):
            # DeepSpeed path: the engine handles scaling/accumulation/stepping.
            outputs = model(
                **batch,
                use_dre=use_dre_now,
                reasoning_path=getattr(args, "dre_force_path", None)
            )
            loss = outputs["loss"]
            model.backward(loss)
            model.step()
        else:
            # Vanilla PyTorch path, optionally under AMP autocast.
            if scaler is not None:
                device_type = "cuda" if torch.cuda.is_available() else "cpu"
                # AMP warmup: run the first ``args.amp_warmup_steps`` steps in
                # full precision for numerical stability.
                amp_enabled = True
                if getattr(args, "amp_warmup_steps", 0) and global_step < args.amp_warmup_steps:
                    amp_enabled = False
                with torch.amp.autocast(device_type=device_type, enabled=amp_enabled):
                    outputs = model(
                        **batch,
                        use_dre=use_dre_now,
                        reasoning_path=getattr(args, "dre_force_path", None)
                    )
                    loss = outputs["loss"]
            else:
                outputs = model(
                    **batch,
                    use_dre=use_dre_now,
                    reasoning_path=getattr(args, "dre_force_path", None)
                )
                loss = outputs["loss"]
            # Scale so gradients average over the accumulation window.
            loss = loss / args.gradient_accumulation_steps

            # Drop batches that produce NaN/Inf losses: clear any partial
            # gradients and move on rather than poisoning the statistics.
            if not torch.isfinite(loss):
                if optimizer is not None:
                    optimizer.zero_grad(set_to_none=True)
                continue

            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            current_router_grad_norm = 0.0

            # Optimizer step only at accumulation boundaries.
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
                if optimizer is not None:
                    if args.gradient_clipping > 0:
                        # Gradients must be unscaled before measuring/clipping.
                        if scaler is not None:
                            scaler.unscale_(optimizer)
                        # Per-micro-batch losses were divided by the
                        # accumulation factor, so the clip threshold is scaled
                        # back up by the same factor.
                        effective_max_norm = float(args.gradient_clipping) * float(args.gradient_accumulation_steps)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), effective_max_norm)

                        # Re-measure the post-clip norm, overall and for the
                        # router/gate parameter subset, for the [step] line.
                        total_sq = 0.0
                        router_sq = 0.0
                        for name, param in model.named_parameters():
                            if param.grad is not None:
                                pn = float(param.grad.data.norm(2).item())
                                total_sq += pn * pn
                                if 'gate' in name or 'router' in name:
                                    router_sq += pn * pn
                        current_total_grad_norm = total_sq ** 0.5
                        current_router_grad_norm = router_sq ** 0.5
                        # 1% tolerance over the threshold before declaring
                        # clipping broken.
                        if current_total_grad_norm > effective_max_norm * 1.01:
                            print(f"[ERROR] Clipping failed! norm={current_total_grad_norm:.2f} max={effective_max_norm:.2f}")
                        measured_grad_this_step = True

                        # Early-training diagnostics: per-component gradient
                        # breakdown and logits statistics for the first 10
                        # optimizer steps. Entirely best-effort.
                        if global_step < 10:
                            try:
                                comp_sq = {
                                    'embedding': 0.0,
                                    'attention': 0.0,
                                    'experts': 0.0,
                                    'router': 0.0,
                                    'output': 0.0,
                                }
                                expert_param_grads = []
                                router_param_norms = []
                                for name, param in model.named_parameters():
                                    if param.grad is None:
                                        continue
                                    g = float(param.grad.data.norm(2).item())
                                    lname = name.lower()
                                    # Bucket by substring match on the lowered
                                    # parameter name; first match wins.
                                    if ('embedding' in lname) or ('wte' in lname) or ('wpe' in lname):
                                        comp_sq['embedding'] += g * g
                                    elif ('router' in lname) or ('gate' in lname):
                                        comp_sq['router'] += g * g
                                    elif ('expert' in lname) or ('experts' in lname) or ('moe' in lname):
                                        comp_sq['experts'] += g * g
                                    elif ('attn' in lname) or ('attention' in lname):
                                        comp_sq['attention'] += g * g
                                    elif ('lm_head' in lname) or ('output' in lname):
                                        comp_sq['output'] += g * g

                                    if (('expert' in lname) or ('experts' in lname)) and g > 0:
                                        expert_param_grads.append((name, g))

                                    if (('router' in lname) or ('gate' in lname)):
                                        try:
                                            router_param_norms.append((name, float(param.data.norm(2).item())))
                                        except Exception:
                                            pass
                                comp_norms = {k: (v ** 0.5) for k, v in comp_sq.items()}
                                total_comp_norm = (sum(v for v in comp_sq.values())) ** 0.5
                                print("[diag] grad_norms:",
                                      f"embedding={comp_norms['embedding']:.4f}",
                                      f"attention={comp_norms['attention']:.4f}",
                                      f"experts={comp_norms['experts']:.4f}",
                                      f"router={comp_norms['router']:.4f}",
                                      f"output={comp_norms['output']:.4f}",
                                      f"total={total_comp_norm:.4f}")
                                if expert_param_grads:
                                    # Top-10 expert parameters by gradient norm.
                                    expert_param_grads.sort(key=lambda x: x[1], reverse=True)
                                    topk = expert_param_grads[:10]
                                    print("[diag] top_expert_grads:")
                                    for n, g in topk:
                                        print(f"  {n}: {g:.6f}")
                                if router_param_norms:
                                    # Top-5 router/gate parameters by weight norm.
                                    router_param_norms.sort(key=lambda x: x[1], reverse=True)
                                    topk_r = router_param_norms[:5]
                                    print("[diag] router_param_norms:")
                                    for n, pn in topk_r:
                                        print(f"  {n}: {pn:.6f}")

                                # Logits sanity check (collapse / explosion).
                                try:
                                    if hasattr(outputs, 'get') and outputs.get('logits') is not None:
                                        logits = outputs['logits']
                                        m = float(logits.mean().detach().cpu().item())
                                        s = float(logits.std().detach().cpu().item())
                                        mn = float(logits.min().detach().cpu().item())
                                        mx = float(logits.max().detach().cpu().item())
                                        print(f"[diag] logits: mean={m:.4f} std={s:.4f} min={mn:.4f} max={mx:.4f}")
                                        if s < 0.1:
                                            print("[diag] WARN: logits std < 0.1 (low variance)")
                                        if abs(m) > 10.0:
                                            print("[diag] WARN: logits mean magnitude > 10 (extreme)")
                                except Exception:
                                    pass
                            except Exception:
                                pass

                    # Step the optimizer (through the scaler when AMP is on),
                    # advance the LR schedule, and reset gradients.
                    if scaler is not None:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
                    global_step += 1

        # Running epoch statistics (note: loss here is the accumulation-scaled
        # value on the non-DeepSpeed path).
        total_loss += float(loss.detach())
        num_batches += 1

        # Optional CUDA memory report every ``args.perf_log_interval`` batches.
        if getattr(args, "log_cuda_memory", False) and torch.cuda.is_available():
            if (batch_idx + 1) % max(1, int(getattr(args, "perf_log_interval", 200))) == 0 and is_main_process:
                mem_alloc = torch.cuda.memory_allocated() / (1024**2)
                mem_res = torch.cuda.memory_reserved() / (1024**2)
                print(f"[mem] step={global_step} alloc_mb={mem_alloc:.1f} reserved_mb={mem_res:.1f}")

        # Log at every accumulation boundary, plus once on the very first batch.
        should_log = ((batch_idx + 1) % args.gradient_accumulation_steps == 0) or (batch_idx == 0)
        logger.debug("After batch %d: (batch_idx+1)=%d, grad_accum=%d, should_log=%s", batch_idx, batch_idx+1, args.gradient_accumulation_steps, should_log)

        if should_log:
            # Undo the accumulation scaling to report the full-step loss.
            step_loss = float(loss.detach()) * args.gradient_accumulation_steps
            try:
                # Clamp the exponent to avoid overflow in exp().
                perplexity = math.exp(min(step_loss, 20))
            except (OverflowError, ValueError):
                perplexity = float('inf')

            # Throughput estimate for this micro-batch.
            try:
                toks = int(batch['input_ids'].numel())
            except Exception:
                toks = 0
            step_time = time.time() - step_start
            toks_per_sec = toks / step_time if step_time > 0 and toks > 0 else 0.0

            # Collected metric snippets for the [step] console line.
            moe_metrics = {}
            dre_metrics = {}
            aux_loss_value = 0.0

            logger.debug("batch_idx=%d, global_step=%d, should_log=%s", batch_idx, global_step, should_log)
            if hasattr(outputs, 'keys'):
                logger.debug("outputs keys: %s", list(outputs.keys()))
            else:
                logger.debug("outputs type: %s", type(outputs))

            # --- DRE routing metrics, when the model reports routing_info ---
            if hasattr(outputs, 'get') and outputs.get('routing_info'):
                routing_info = outputs['routing_info']

                # Instantaneous per-batch values.
                if isinstance(routing_info, dict):
                    if 'path' in routing_info:
                        dre_metrics['path_now'] = routing_info['path']
                    if 'complexity_score' in routing_info:
                        try:
                            dre_metrics['comp_now'] = float(routing_info['complexity_score'])
                        except Exception:
                            pass
                    if 'confidence' in routing_info:
                        try:
                            dre_metrics['conf_now'] = float(routing_info['confidence'])
                        except Exception:
                            pass

                # Aggregated DRE statistics, when provided.
                if 'dre_metrics' in routing_info:
                    dre_info = routing_info['dre_metrics']

                    if 'avg_complexity' in dre_info:
                        dre_metrics['complexity'] = dre_info['avg_complexity']
                    if 'avg_confidence' in dre_info:
                        dre_metrics['confidence'] = dre_info['avg_confidence']
                    if 'path_distribution' in dre_info:
                        # Report only the dominant reasoning path.
                        path_dist = dre_info['path_distribution']
                        if path_dist:
                            most_used_path = max(path_dist.items(), key=lambda x: x[1])
                            dre_metrics['main_path'] = f"{most_used_path[0]}({most_used_path[1]:.0f}%)"
                    if 'cache_hit_rate' in dre_info:
                        dre_metrics['cache_hit'] = dre_info['cache_hit_rate']

            if hasattr(outputs, 'get') and outputs.get('routing_info'):
                try:
                    used_flag = bool(outputs['routing_info'].get('used_moe', False))
                    moe_metrics['used_moe'] = used_flag
                except Exception:
                    pass

            # --- MoE expert-utilization and auxiliary-loss metrics ---
            if hasattr(outputs, 'get') and outputs.get('moe_info'):
                moe_info = outputs['moe_info']

                if 'expert_utilization' in moe_info:
                    util = moe_info['expert_utilization']

                    if 'avg_routing_entropy' in util:
                        moe_metrics['entropy'] = util['avg_routing_entropy']

                    # Track the worst routing concentration across expert
                    # groups, both absolute (% to the single top expert) and
                    # relative to the ideal uniform top-k share.
                    max_concentration = 0
                    max_concentration_multi = 0
                    max_ratio = 0.0
                    k_ratio = None
                    s_ratio = None
                    num_used = moe_info.get('num_experts_used', {})
                    for expert_type in ['knowledge', 'skill', 'meta', 'safety']:
                        key = f"{expert_type}_top_expert_pct"
                        if key in util:
                            val = util[key]
                            if val > max_concentration:
                                max_concentration = val
                            n = num_used.get(expert_type, 0)
                            # Only meaningful when more than one expert exists.
                            if n and n > 1 and val > max_concentration_multi:
                                max_concentration_multi = val

                            try:
                                n_int = int(n)
                            except Exception:
                                n_int = 0
                            if n_int > 0:
                                # Ideal top-expert share under uniform top-k
                                # routing over n experts.
                                ideal_pct = 100.0 * float(min(getattr(args, 'moe_top_k', 2), n_int)) / float(n_int)
                                if ideal_pct > 0:
                                    ratio = float(val) / float(ideal_pct)
                                    if ratio > max_ratio:
                                        max_ratio = ratio
                                    if expert_type == 'knowledge':
                                        k_ratio = ratio
                                    if expert_type == 'skill':
                                        s_ratio = ratio
                    if max_concentration > 0:
                        moe_metrics['max_expert_pct'] = max_concentration
                    if max_concentration_multi > 0:
                        moe_metrics['max_expert_pct_multi'] = max_concentration_multi
                    if 'knowledge_top_expert_pct' in util:
                        moe_metrics['k_max'] = util['knowledge_top_expert_pct']
                    if 'skill_top_expert_pct' in util:
                        moe_metrics['s_max'] = util['skill_top_expert_pct']
                    if max_ratio > 0:
                        moe_metrics['max_exp_rel'] = max_ratio
                    if k_ratio is not None:
                        moe_metrics['k_rel'] = k_ratio
                    if s_ratio is not None:
                        moe_metrics['s_rel'] = s_ratio

                # Sum the MoE auxiliary losses by family (load-balance, z,
                # importance, entropy regularization).
                if 'aux_losses' in moe_info:
                    aux_losses = moe_info['aux_losses']
                    total_aux = 0.0
                    lb_total = 0.0
                    z_total = 0.0
                    imp_total = 0.0
                    ent_reg_total = 0.0
                    for key, loss_val in aux_losses.items():
                        # Values may be tensors (possibly multi-element) or
                        # plain numbers.
                        if isinstance(loss_val, torch.Tensor):
                            val = float(loss_val.detach().mean()) if loss_val.numel() > 1 else float(loss_val.detach())
                        else:
                            try:
                                val = float(loss_val)
                            except Exception:
                                val = 0.0

                        if 'load_loss' in key:
                            lb_total += val
                            total_aux += val
                        elif 'z_loss' in key:
                            z_total += val
                            total_aux += val
                        elif 'importance_loss' in key:
                            imp_total += val
                            total_aux += val
                        elif 'entropy_reg_loss' in key:
                            ent_reg_total += val
                            total_aux += val
                    aux_loss_value = total_aux
                    moe_metrics['aux_loss'] = aux_loss_value

                    if lb_total > 0:
                        moe_metrics['lb'] = lb_total
                    if z_total > 0:
                        moe_metrics['z'] = z_total
                    if imp_total > 0:
                        moe_metrics['imp'] = imp_total
                    if ent_reg_total > 0:
                        moe_metrics['ent_reg'] = ent_reg_total

            # --- Assemble the compact " moe=[...]" console fragment ---
            moe_str = ""
            if moe_metrics:
                moe_parts = []
                if 'entropy' in moe_metrics:
                    moe_parts.append(f"entropy={moe_metrics['entropy']:.2f}")
                if 'max_expert_pct' in moe_metrics:
                    moe_parts.append(f"max_exp={moe_metrics['max_expert_pct']:.1f}%")
                if 'max_expert_pct_multi' in moe_metrics:
                    moe_parts.append(f"max_exp_multi={moe_metrics['max_expert_pct_multi']:.1f}%")
                if 'k_max' in moe_metrics:
                    moe_parts.append(f"k_max={moe_metrics['k_max']:.1f}%")
                if 's_max' in moe_metrics:
                    moe_parts.append(f"s_max={moe_metrics['s_max']:.1f}%")
                if 'max_exp_rel' in moe_metrics:
                    moe_parts.append(f"max_rel={moe_metrics['max_exp_rel']:.2f}x")
                if 'k_rel' in moe_metrics:
                    moe_parts.append(f"k_rel={moe_metrics['k_rel']:.2f}x")
                if 's_rel' in moe_metrics:
                    moe_parts.append(f"s_rel={moe_metrics['s_rel']:.2f}x")
                if 'aux_loss' in moe_metrics and moe_metrics['aux_loss'] > 0:
                    moe_parts.append(f"aux={moe_metrics['aux_loss']:.4f}")

                if 'lb' in moe_metrics:
                    moe_parts.append(f"lb={moe_metrics['lb']:.4f}")
                if 'z' in moe_metrics:
                    moe_parts.append(f"z={moe_metrics['z']:.4f}")
                if 'imp' in moe_metrics:
                    moe_parts.append(f"imp={moe_metrics['imp']:.4f}")
                if 'ent_reg' in moe_metrics:
                    moe_parts.append(f"ent_reg={moe_metrics['ent_reg']:.4f}")
                if 'used_moe' in moe_metrics:
                    moe_parts.append(f"used_moe={moe_metrics['used_moe']}")
                if moe_parts:
                    moe_str = f" moe=[{','.join(moe_parts)}]"

            # --- Assemble the " dre=[...]" fragment; instantaneous values
            # take precedence over epoch averages ---
            dre_str = ""
            if dre_metrics:
                dre_parts = []

                if 'comp_now' in dre_metrics:
                    dre_parts.append(f"comp={dre_metrics['comp_now']:.2f}")
                elif 'complexity' in dre_metrics:
                    dre_parts.append(f"comp_avg={dre_metrics['complexity']:.2f}")
                if 'conf_now' in dre_metrics:
                    dre_parts.append(f"conf={dre_metrics['conf_now']:.2f}")
                elif 'confidence' in dre_metrics:
                    dre_parts.append(f"conf_avg={dre_metrics['confidence']:.2f}")
                if 'path_now' in dre_metrics:
                    dre_parts.append(f"path={dre_metrics['path_now']}")
                elif 'main_path' in dre_metrics:
                    dre_parts.append(f"path_avg={dre_metrics['main_path']}")
                if dre_parts:
                    dre_str = f" dre=[{','.join(dre_parts)}]"

            # --- " grad=[...]" fragment from the last measured norms ---
            grad_str = ""
            if current_total_grad_norm > 0:
                grad_str = f" grad=[total={current_total_grad_norm:.3f}"
                if current_router_grad_norm > 0:
                    grad_str += f",router={current_router_grad_norm:.3f}"
                grad_str += "]"

            try:
                curr_lr = float(scheduler.get_last_lr()[0]) if scheduler is not None else 0.0
            except Exception:
                curr_lr = 0.0
            print(f"[step] step={global_step} loss={step_loss:.4f} ppl={perplexity:.2f} toks/s={toks_per_sec:.1f} lr={curr_lr:.6g}{moe_str}{dre_str}{grad_str}")

            # --- Best-effort MLflow step metrics ---
            if MLFLOW_AVAILABLE and getattr(args, "use_mlflow", False):
                try:
                    metrics = {
                        'train/step_loss': step_loss,
                        'train/step_perplexity': perplexity if perplexity != float('inf') else 1e10,
                        'train/tokens_per_sec': toks_per_sec,
                        'train/learning_rate': float(scheduler.get_last_lr()[0]) if scheduler is not None else 0.0,
                    }

                    if hasattr(outputs, 'get') and outputs.get('moe_info'):
                        moe_info = outputs['moe_info']
                        if 'expert_utilization' in moe_info:
                            util = moe_info['expert_utilization']

                            # Scalars directly; short lists fanned out per expert.
                            for key, value in util.items():
                                if isinstance(value, (int, float)):
                                    metrics[f'moe/{key}'] = value
                                elif isinstance(value, list) and len(value) <= 10:
                                    for i, v in enumerate(value):
                                        metrics[f'moe/{key}_expert_{i}'] = v

                    if hasattr(outputs, 'get') and outputs.get('routing_info'):
                        routing_info = outputs['routing_info']
                        if 'dre_metrics' in routing_info:
                            dre_info = routing_info['dre_metrics']

                            for key, value in dre_info.items():
                                if isinstance(value, (int, float)):
                                    metrics[f'dre/{key}'] = value
                                elif isinstance(value, dict):
                                    # e.g. path_distribution: one metric per path.
                                    for path_name, path_pct in value.items():
                                        metrics[f'dre/path_{path_name}'] = path_pct

                    mlflow.log_metrics(metrics, step=global_step)
                except Exception as e:
                    # MLflow failures must never interrupt training.
                    pass

            last_step_time = time.time()

        # Optional torch profiler smuggled in via args; advance it per batch.
        profiler = getattr(args, "_profiler", None)
        if profiler is not None:
            try:
                profiler.step()
            except Exception:
                pass

    # Flush any leftover accumulated gradients when the epoch length is not a
    # multiple of the accumulation factor (non-DeepSpeed path only; DeepSpeed
    # manages its own stepping). Entire block is best-effort.
    try:
        if (num_batches % max(1, int(getattr(args, "gradient_accumulation_steps", 1)))) != 0 and optimizer is not None and not _is_deepspeed_engine(model):

            if getattr(args, "gradient_clipping", 0) > 0:
                if scaler is not None:
                    scaler.unscale_(optimizer)
                # NOTE(review): unlike the in-loop step, this clip does not
                # scale the threshold by the accumulation factor — confirm
                # whether that asymmetry is intended.
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clipping)

            if scaler is not None:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()
            if scheduler is not None:
                scheduler.step()
            optimizer.zero_grad(set_to_none=True)
            global_step += 1

            # Emit a final [step] line for the flushed partial step.
            try:
                step_loss = float(loss.detach()) * args.gradient_accumulation_steps
                try:
                    perplexity = math.exp(min(step_loss, 20))
                except (OverflowError, ValueError):
                    perplexity = float('inf')

                # Reuse fragments from the last logged step if they exist.
                if 'moe_str' not in locals():
                    moe_str = ""
                if 'dre_str' not in locals():
                    dre_str = ""
                if 'grad_str' not in locals():
                    grad_str = ""
                print(f"[step] step={global_step} loss={step_loss:.4f} ppl={perplexity:.2f} toks/s=0.0{moe_str}{dre_str}{grad_str}")
            except Exception:
                pass
    except Exception:
        pass

    # --- End-of-epoch summary (main process only) ---
    epoch_time = time.time() - start_time
    avg_loss = total_loss / max(1, num_batches)
    if is_main_process:
        try:
            avg_perplexity = math.exp(min(avg_loss, 20))
        except (OverflowError, ValueError):
            avg_perplexity = float('inf')

        # Rough epoch throughput; assumes every sample is max_seq_length
        # tokens — TODO confirm against the dataset.
        toks = (len(train_loader.dataset) * getattr(args, "max_seq_length", 1)) if hasattr(train_loader, "dataset") else 0
        if toks:
            toks_per_sec = toks / max(1e-6, epoch_time)
        else:
            toks_per_sec = 0.0
        print(f"[train] avg_loss={avg_loss:.4f} avg_ppl={avg_perplexity:.2f} epoch_time={epoch_time:.1f}s toks/s={toks_per_sec:.1f}")

        if MLFLOW_AVAILABLE and getattr(args, "use_mlflow", False):
            try:
                mlflow.log_metrics({
                    'train/epoch_avg_loss': avg_loss,
                    'train/epoch_avg_perplexity': avg_perplexity if avg_perplexity != float('inf') else 1e10,
                    'train/epoch_time_sec': epoch_time,
                }, step=global_step)
            except Exception:
                pass

        # Gradient-norm percentile summary; dead in practice because
        # grad_norms is never populated (see docstring NOTE).
        if grad_norms:
            import numpy as np
            arr = np.array(grad_norms, dtype=float)
            q50, q90, q99 = np.percentile(arr, [50, 90, 99]).tolist()
            print(f"[grad_norm] p50={q50:.3f} p90={q90:.3f} p99={q99:.3f}")

    return avg_loss, global_step
|
|
|
|
|
|
|
|
|
def validate_epoch(*, model, val_loader, device, is_main_process: bool) -> float:
    """Run one evaluation pass over ``val_loader`` and return the mean loss.

    The model is switched to eval mode and gradients are disabled.  On the
    main process a progress line is printed every 50 batches, and the epoch
    averages are printed and (when MLflow is importable) logged at the end.
    """
    model.eval()
    loss_sum = 0.0
    seen = 0

    with torch.no_grad():
        for raw_batch in val_loader:
            moved = {key: tensor.to(device) for key, tensor in raw_batch.items()}
            result = model(**moved)
            batch_loss = float(result["loss"].detach())
            loss_sum += batch_loss
            seen += 1

            # Periodic progress report, main process only.
            if is_main_process and seen % 50 == 0:
                try:
                    # Exponent clamped at 20 to keep exp() finite.
                    batch_ppl = math.exp(min(batch_loss, 20))
                except (OverflowError, ValueError):
                    batch_ppl = float('inf')
                print(f"[val_progress] batch={seen} loss={batch_loss:.4f} ppl={batch_ppl:.2f}")

    # Guard against an empty loader.
    mean_loss = loss_sum / max(1, seen)
    if is_main_process:
        try:
            mean_ppl = math.exp(min(mean_loss, 20))
        except (OverflowError, ValueError):
            mean_ppl = float('inf')
        print(f"[val] avg_loss={mean_loss:.4f} avg_ppl={mean_ppl:.2f}")

        # Best-effort MLflow logging; failures never interrupt validation.
        if MLFLOW_AVAILABLE:
            try:
                mlflow.log_metrics({
                    'val/avg_loss': mean_loss,
                    'val/avg_perplexity': mean_ppl if mean_ppl != float('inf') else 1e10,
                })
            except Exception:
                pass
    return mean_loss
|
|
|
|