|
import dataclasses |
|
import hashlib |
|
import sys |
|
import typing |
|
import warnings |
|
import socket |
|
from typing import Optional, Any, Dict |
|
import os |
|
import logging |
|
import absl.flags |
|
from flax.traverse_util import flatten_dict |
|
|
|
from ml_collections import ConfigDict, config_flags |
|
from ml_collections.config_dict import placeholder |
|
from mlxu import function_args_to_config |
|
|
|
# Module-level dict, initially empty. No function in this part of the file
# reads or writes it; judging by its name it presumably holds extra fields to
# attach to log output -- TODO confirm against its users elsewhere.
_log_extra_fields: Dict[str, Any] = {}
|
|
|
|
|
def is_float_printable(x):
    """Return True if `x` can be rendered with the ``0.2f`` format spec.

    Values whose ``__format__`` rejects the spec (strings, None, lists, ...)
    raise ValueError or TypeError, which is reported as False.
    """
    try:
        format(x, "0.2f")
    except (ValueError, TypeError):
        return False
    return True
|
|
|
|
|
def compute_hash(string: str) -> str:
    """Return the hex-encoded SHA-256 digest of `string`, encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(string.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
def pop_metadata(data):
    """Move every entry whose key starts with ``"metadata"`` out of `data`.

    `data` is modified in place.

    Returns:
        ``(data, meta)``: the same (now mutated) `data` object, and a new dict
        holding the removed entries in their original iteration order.
    """
    # Snapshot the matching keys first so `data` is not mutated while iterating it.
    metadata_keys = [key for key in data if key.startswith("metadata")]
    meta = {}
    for key in metadata_keys:
        meta[key] = data.pop(key)
    return data, meta
|
|
|
|
|
def setup_logging():
    """Configure root logging to stdout at INFO with a compact line format.

    Each line looks like ``[I HH:MM:SS file.py:123] message`` (the level is
    truncated to one letter). Also routes `warnings` through logging and
    raises the ``urllib3`` logger's threshold to ERROR.

    Note: `logging.basicConfig` is a no-op when the root logger already has
    handlers, so the stdout handler is only installed on first configuration.
    """
    line_format = logging.Formatter(
        "[%(levelname)-.1s %(asctime)s %(filename)s:%(lineno)s] %(message)s",
        datefmt="%H:%M:%S",
    )
    stdout_handler: logging.Handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(line_format)
    logging.basicConfig(level=logging.INFO, handlers=[stdout_handler])

    logging.captureWarnings(True)
    noisy_logger = logging.getLogger("urllib3")
    noisy_logger.setLevel(logging.ERROR)
|
|
|
|
|
def get_maybe_optional_type(field_type):
    """Unwrap ``Optional[X]`` (``Union[X, None]`` / ``X | None``) to ``X``.

    Any other type is returned unchanged. Only real unions are unwrapped:
    generics whose type arguments merely contain ``NoneType`` -- e.g.
    ``Callable[[int], None]``, ``Dict[str, None]``, ``Tuple[int, None]`` --
    are returned as-is instead of being mistaken for ``Optional``.

    Raises:
        AssertionError: if `field_type` is an optional union of more than one
            other type, e.g. ``Optional[Union[int, str]]``.
    """
    import types  # function-local: only needed for PEP 604 unions

    origin = typing.get_origin(field_type)
    # `types.UnionType` (the runtime type of `X | None`) only exists on 3.10+;
    # fall back to `typing.Union` so the check degrades gracefully.
    pep604_union = getattr(types, "UnionType", typing.Union)
    if origin is not typing.Union and origin is not pep604_union:
        return field_type

    args = typing.get_args(field_type)
    if type(None) not in args:
        return field_type  # a union without None, e.g. Union[int, str]

    non_none = [arg for arg in args if arg is not type(None)]
    assert len(non_none) == 1, (
        f"Expected Optional[X] with exactly one X, got {field_type}")
    return non_none[0]
|
|
|
|
|
def config_from_dataclass(dataclass, defaults_to_none=False) -> ConfigDict:
    """Build a `ConfigDict` mirroring the (possibly nested) dataclass `dataclass`.

    Args:
        dataclass: A dataclass instance or a dataclass type. For an instance,
            defaults are the instance's current attribute values. For a type,
            defaults are the fields' plain ``default`` values, or None when a
            field has no plain default (required fields and, since dataclasses
            does not leave a class attribute for them, ``default_factory``
            fields).
        defaults_to_none: If True, ignore all defaults and use None everywhere.

    Returns:
        A `ConfigDict` with one entry per ``init=True`` field:
          * nested dataclass fields (including ``Optional[SubConfig]``) become
            nested `ConfigDict`s, built from the default instance if there is
            one and otherwise from the unwrapped field type; they are omitted
            entirely when their default is None and `defaults_to_none` is
            False;
          * other fields whose default is None become a typed `placeholder`;
          * all remaining fields hold their default value.
    """
    out = {}
    for field in dataclasses.fields(dataclass):
        if not field.init:
            continue

        if defaults_to_none:
            default = None
        elif hasattr(dataclass, field.name):
            default = getattr(dataclass, field.name)
        elif field.default is dataclasses.MISSING:
            default = None
        else:
            default = field.default

        # Unwrap Optional[X] -> X so the inner type drives the representation.
        field_type = get_maybe_optional_type(field.type)

        if hasattr(field_type, "__dataclass_fields__"):
            if defaults_to_none or default is not None:
                # Recurse on the unwrapped `field_type`, not the raw
                # `field.type`: `dataclasses.fields` raises TypeError for
                # `Optional[SubConfig]`.
                source = field_type if default is None else default
                out[field.name] = config_from_dataclass(
                    source, defaults_to_none=defaults_to_none)
        elif default is None:
            assert field_type != typing.Any, (
                f"Field {field.name!r} is typed Any; a placeholder needs a "
                f"concrete type")
            # Hand placeholder() the runtime class of parameterized generics
            # (e.g. List[int] -> list) rather than the typing alias.
            origin = getattr(field_type, "__origin__", None)
            if origin is not None:
                field_type = origin
            out[field.name] = placeholder(field_type)
        else:
            out[field.name] = default
    return ConfigDict(out)
|
|
|
|
|
def dataclass_with_none(cls):
    """Instantiate dataclass `cls` with every constructor argument set to None.

    Fields whose declared type is itself a dataclass are filled recursively
    with such an all-None instance. ``init=False`` fields are not passed, so
    they keep whatever the dataclass assigns them.
    """
    kwargs = {}
    for f in dataclasses.fields(cls):
        if not f.init:
            continue
        kwargs[f.name] = (
            dataclass_with_none(f.type) if dataclasses.is_dataclass(f.type) else None
        )
    return cls(**kwargs)
|
|
|
|
|
def dataclass_from_config(cls, config: Dict):
    """Build an instance of dataclass `cls` with attributes taken from `config`.

    Args:
        cls: The dataclass type to instantiate.
        config: A mapping (``dict`` or ``ConfigDict``) of field name -> value.
            * Nested dataclass fields (``Optional`` is unwrapped) are set to
              None when missing or None. Otherwise they are built via the field
              type's ``from_dict`` if it has one (passing a plain dict), or
              recursively with this function.
            * Other fields are passed only when present in `config`, with
              ``ConfigDict`` values converted to plain dicts. Absent fields fall
              back to the dataclass's own defaults.
            * Keys naming ``init=False`` fields are accepted but ignored.

    Raises:
        ValueError: if `config` has a key that is not a field of `cls`.
    """
    fields = dataclasses.fields(cls)
    field_names = {f.name for f in fields}
    for key in config.keys():
        if key not in field_names:
            raise ValueError(f"Config has unknown arg {key} for {cls}")

    kwargs = {}
    for field in fields:
        if not field.init:
            continue

        field_type = get_maybe_optional_type(field.type)
        if hasattr(field_type, "__dataclass_fields__"):
            value = config.get(field.name)
            if value is None:
                kwargs[field.name] = None
            elif hasattr(field_type, "from_dict"):
                # Types providing their own `from_dict` receive a plain dict.
                if isinstance(value, ConfigDict):
                    value = value.to_dict()
                kwargs[field.name] = field_type.from_dict(value)
            else:
                kwargs[field.name] = dataclass_from_config(field_type, value)
        elif field.name in config:
            value = config[field.name]
            if isinstance(value, ConfigDict):
                value = value.to_dict()
            kwargs[field.name] = value
    return cls(**kwargs)
|
|
|
|
|
def update_dataclass(obj, updates):
    """Set attributes of dataclass `obj` from the non-None entries of `updates`.

    Args:
        obj: Dataclass instance; modified in place.
        updates: Mapping (``dict`` or ``ConfigDict``) of field name -> value.
            * Missing keys and None values leave the field untouched.
            * For a field currently holding a dataclass, the entry is applied
              recursively.
            * A mapping given for any other field must contain only None
              leaves (i.e. "nothing to override") and is otherwise ignored.
            * Any other value replaces the attribute.
            ``init=False`` fields are never touched.

    Raises:
        AssertionError: if a mapping with non-None leaves targets a field that
            does not currently hold a dataclass.
    """
    for field in dataclasses.fields(obj):
        if not field.init:
            continue
        update = updates.get(field.name)
        if update is None:
            continue
        current_value = getattr(obj, field.name)
        if dataclasses.is_dataclass(current_value):
            update_dataclass(current_value, update)
        elif isinstance(update, (ConfigDict, dict)):
            # `flatten_dict` only accepts (Frozen)dicts, so a ConfigDict must
            # be converted first or flatten_dict's own type check fails.
            if isinstance(update, ConfigDict):
                update = update.to_dict()
            assert all(x is None for x in flatten_dict(update).values()), (
                f"Cannot apply nested overrides to field {field.name!r} of "
                f"{type(obj).__name__}, whose current value is {current_value!r}")
        else:
            setattr(obj, field.name, update)
|
|
|
|
|
def log_metrics_to_console(prefix: str, metrics: Dict[str, float]):
    """Log `metrics` at INFO level as `prefix` followed by one ``name=value`` line each.

    Metrics whose name starts with ``"optim/"`` are skipped. Strings are
    printed verbatim. Numbers are rounded with precision that shrinks as their
    magnitude grows:

    * magnitude > 1000: thousands-separated integer;
    * > 100: 1 decimal place;
    * > 10: 2 decimal places;
    * > 1: 3 decimal places;
    * otherwise: 4 decimal places.

    Values that are tiny (< 1e-4 in magnitude), non-finite (inf/nan), or not
    numeric at all (e.g. None) fall back to ``str(value)``.
    """
    import math  # function-local to keep this helper self-contained

    def format_value(value: Any) -> str:
        if isinstance(value, str):
            return value
        try:
            # Thresholds use the magnitude so negative metrics are rounded
            # like positive ones instead of being printed raw.
            magnitude = abs(value)
            # Tiny values would round to 0, and int(inf) raises OverflowError.
            if not math.isfinite(magnitude) or magnitude < 0.0001:
                return str(value)
            if magnitude > 1000:
                return f"{int(value):,d}"
            if magnitude > 100:
                return f"{value:.1f}"
            if magnitude > 10:
                return f"{value:.2f}"
            if magnitude > 1:
                return f"{value:.3f}"
            return f"{value:.4f}"
        except (TypeError, ValueError, OverflowError):
            # Non-numeric or non-scalar values (None, multi-element arrays, ...)
            return str(value)

    logging.info(
        f"{prefix}\n"
        + "\n".join(
            [
                f" {name}={format_value(value)}"
                for name, value in metrics.items()
                if not name.startswith("optim/")
            ]
        )
    )