import os
import warnings
from collections.abc import Mapping
from typing import Any, Dict, Optional
|
|
|
|
# Becomes True once ``import wandb`` succeeds inside _try_import_wandb().
_WANDB_AVAILABLE: bool = False
# Run object returned by ``wandb.init`` in init(); stays None until a run is
# started. The logging helpers treat None as "no active run" and do nothing.
_WANDB_RUN: Optional[Any] = None
|
|
|
|
def _try_import_wandb():
    """Return whether ``wandb`` is importable, remembering a success.

    A successful import is cached in ``_WANDB_AVAILABLE``, so later calls
    return True immediately. A failed import stores False and is retried on
    the next call. Any exception raised while importing counts as "not
    available".
    """
    global _WANDB_AVAILABLE
    if not _WANDB_AVAILABLE:
        try:
            import wandb  # noqa: F401  (imported only to probe availability)
        except Exception:
            _WANDB_AVAILABLE = False
        else:
            _WANDB_AVAILABLE = True
    return _WANDB_AVAILABLE
|
|
|
|
def _safe_get(cfg: Dict[str, Any], path: list[str], default: Any = None) -> Any:
    """Return the value at a nested key path in ``cfg``, or ``default``.

    Follows ``path`` one key at a time. Any ``collections.abc.Mapping`` is
    traversed, not only ``dict``. OmegaConf ``DictConfig`` objects are
    mappings but not ``dict`` subclasses, so they are read the same way as
    plain dicts.

    Args:
        cfg: Nested mapping to read from.
        path: Keys to follow, outermost first. An empty path returns ``cfg``.
        default: Returned when a key is missing or an intermediate value is
            not a mapping.

    Returns:
        The value stored at ``path`` (which may itself be ``None``), or
        ``default``.
    """
    cur: Any = cfg
    for key in path:
        if not isinstance(cur, Mapping) or key not in cur:
            return default
        cur = cur[key]
    return cur
|
|
|
|
def is_enabled(cfg: Dict[str, Any]) -> bool:
    """Report whether ``logging.wandb.enabled`` is truthy in ``cfg``.

    A missing key, or a missing ``logging``/``wandb`` section, counts as
    disabled.
    """
    enabled_flag = _safe_get(cfg, ["logging", "wandb", "enabled"], default=False)
    return bool(enabled_flag)
|
|
|
|
def init(cfg: Dict[str, Any], run_dir: str, run_name: Optional[str] = None) -> None:
    """
    Initialize Weights & Biases if enabled in config. No-op if disabled or wandb not installed.

    When enabled, reads the optional ``logging.wandb`` settings ``project``
    (default ``"llm-negotiation"``), ``entity``, ``mode`` (default
    ``"online"``), ``tags``, ``notes``, ``group`` and ``name``. It then
    creates ``run_dir`` and stores the run returned by ``wandb.init`` in the
    module-level ``_WANDB_RUN``.

    Args:
        cfg: Experiment configuration. If it is an OmegaConf config, it is
            resolved to plain containers before being attached to the run.
        run_dir: Directory for W&B's local files. It is created if missing
            and used as the ``WANDB_DIR`` default.
        run_name: Run name used when the ``logging.wandb.name`` key is
            absent.

    Raises:
        Whatever ``wandb.init`` raises; failures there are not suppressed.
    """
    global _WANDB_RUN
    if not is_enabled(cfg):
        return
    if not _try_import_wandb():
        return

    import wandb

    project = _safe_get(cfg, ["logging", "wandb", "project"], "llm-negotiation")
    entity = _safe_get(cfg, ["logging", "wandb", "entity"], None)
    mode = _safe_get(cfg, ["logging", "wandb", "mode"], "online")
    notes = _safe_get(cfg, ["logging", "wandb", "notes"], None)
    group = _safe_get(cfg, ["logging", "wandb", "group"], None)
    name = _safe_get(cfg, ["logging", "wandb", "name"], run_name)

    # wandb expects a sequence of strings. A bare string is itself a sequence
    # of characters, and YAML may yield numeric tags, so normalize to list[str].
    raw_tags = _safe_get(cfg, ["logging", "wandb", "tags"], []) or []
    if isinstance(raw_tags, (str, int, float)):
        tags = [str(raw_tags)]
    else:
        tags = [str(tag) for tag in raw_tags]

    os.makedirs(run_dir, exist_ok=True)
    # setdefault: a WANDB_DIR already set in the environment takes precedence.
    os.environ.setdefault("WANDB_DIR", run_dir)

    # Resolve OmegaConf configs to plain containers for the run config. If that
    # fails (e.g. omegaconf is unavailable or cfg is not an OmegaConf object),
    # cfg is passed through as-is.
    try:
        from omegaconf import OmegaConf

        cfg_container = OmegaConf.to_container(cfg, resolve=True)
    except Exception:
        cfg_container = cfg

    _WANDB_RUN = wandb.init(
        project=project,
        entity=entity,
        mode=mode,
        name=name,
        group=group,
        tags=tags,
        notes=notes,
        config=cfg_container,
        dir=run_dir,
        reinit=True,
    )
|
|
|
|
def log(metrics: Dict[str, Any], step: Optional[int] = None) -> None:
    """Log a flat dictionary of metrics to W&B if active.

    Does nothing unless ``wandb`` was imported and ``init`` created a run.
    Metrics are sent to that run (``_WANDB_RUN``) rather than to whatever
    run is globally current in ``wandb``.

    When ``step`` is given, it is added to the logged payload as an ordinary
    ``"step"`` entry, overriding any ``"step"`` key in ``metrics``. It is not
    passed as wandb's own ``step=`` argument. ``metrics`` itself is never
    mutated.

    Logging is best-effort. Any exception raised while building or sending
    the payload is reported as a ``RuntimeWarning`` and otherwise ignored.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    try:
        payload = metrics if step is None else dict(metrics, step=step)
        _WANDB_RUN.log(payload)
    except Exception as exc:  # best-effort: W&B errors must not reach callers
        warnings.warn(
            f"W&B logging failed ({type(exc).__name__}: {exc}); metrics were dropped.",
            RuntimeWarning,
            stacklevel=2,
        )
|
|
|
|
def _flatten(prefix: str, data: Dict[str, Any], out: Dict[str, Any]) -> None:
    """Flatten nested mappings in ``data`` into ``out`` using dotted keys.

    Each key is prefixed with ``prefix + "."`` when ``prefix`` is non-empty.
    For example, ``{"a": {"b": 1}}`` with prefix ``"p"`` sets
    ``out["p.a.b"] = 1``.

    Any ``collections.abc.Mapping`` value is descended into, not only
    ``dict``, so config-style mappings are expanded too. Other values are
    stored unchanged, and empty nested mappings contribute no entries.
    ``out`` is updated in place.
    """
    for k, v in data.items():
        key = f"{prefix}.{k}" if prefix else k
        if isinstance(v, Mapping):
            _flatten(key, v, out)
        else:
            out[key] = v
|
|
|
|
def _summarize_value(value: Any) -> Dict[str, Any]:
    """Reduce a single tally leaf to a small dict of loggable values.

    - ``None`` -> ``{"none": 1}``
    - a Python ``int``/``float`` (``bool`` included) -> ``{"value": float(value)}``
    - otherwise the value is converted with ``numpy.asarray``:
        * an empty array -> ``{"size": 0}``
        * else NaN-ignoring ``mean``/``min``/``max``, the ``last`` element in
          flattened order, and ``size``
    - if conversion or any reduction raises -> ``{"text": str(value)}``
    """
    import numpy as np

    if value is None:
        return {"none": 1}
    if isinstance(value, (int, float)):
        return {"value": float(value)}

    try:
        values = np.asarray(value)
        if not values.size:
            return {"size": 0}
        summary: Dict[str, Any] = {}
        summary["mean"] = float(np.nanmean(values))
        summary["min"] = float(np.nanmin(values))
        summary["max"] = float(np.nanmax(values))
        summary["last"] = float(values.reshape(-1)[-1])
        summary["size"] = int(values.size)
        return summary
    except Exception:
        # Non-numeric or otherwise unreducible input: fall back to its text.
        return {"text": str(value)}
|
|
|
|
def log_tally(array_tally: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """
    Flatten and summarize Tally.array_tally and log to WandB.

    Nested mappings are walked recursively. Each non-mapping leaf is reduced
    with ``_summarize_value``, and every summary field is logged under
    ``"<prefix>.<key path>.<field>"`` (the prefix is omitted when empty).
    For example, ``{"reward": [1, 2]}`` becomes ``reward.mean``,
    ``reward.min``, ``reward.max``, ``reward.last`` and ``reward.size``.

    A leaf whose summarization raises is logged as ``"<key>.error": 1``.
    Key-path parts are converted with ``str()``, so non-string keys are
    allowed. Nothing is logged when no run is active or the tally has no
    leaves.
    """
    if not _WANDB_AVAILABLE or _WANDB_RUN is None:
        return
    summarized: Dict[str, Any] = {}
    root: list[Any] = [prefix] if prefix else []

    def walk(node: Any, path: list[Any]) -> None:
        if isinstance(node, Mapping):
            for k, v in node.items():
                walk(v, path + [k])
            return
        # str() each part: tally keys need not be strings (e.g. int ids), and
        # str.join rejects non-str items with a TypeError outside the try below.
        key = ".".join(str(part) for part in root + path)
        try:
            summary = _summarize_value(node)
            for sk, sv in summary.items():
                summarized[f"{key}.{sk}"] = sv
        except Exception:
            summarized[f"{key}.error"] = 1

    walk(array_tally, [])
    if summarized:
        log(summarized, step=step)
|
|
|
|
def log_flat_stats(stats: Dict[str, Any], prefix: str = "", step: Optional[int] = None) -> None:
    """Flatten a nested stats dict to dotted keys and log it to W&B.

    Nested keys are joined with ``"."`` and prefixed with ``prefix`` when it
    is non-empty (via ``_flatten``). Does nothing when no run is active or
    when flattening produces no entries.
    """
    if _WANDB_AVAILABLE and _WANDB_RUN is not None:
        flattened: Dict[str, Any] = {}
        _flatten(prefix, stats, flattened)
        if flattened:
            log(flattened, step=step)
|
|
|
|
|
|