"""Utils for evaluators in general.""" |

import dataclasses
import functools
import importlib
import json
import os
from typing import Any, Callable

from absl import flags
from big_vision import input_pipeline
from big_vision.datasets import core as ds_core
from big_vision.pp import builder as pp_builder
import big_vision.utils as u
import flax
import jax
from jax.experimental import multihost_utils
import numpy as np

from tensorflow.io import gfile
|
|
|
|
|
def from_config(config, predict_fns, |
|
write_note=lambda s: s, |
|
get_steps=lambda key, cfg: cfg[f"{key}_steps"], |
|
devices=None): |
|
"""Creates a list of evaluators based on `config`.""" |
|
evaluators = [] |
|
specs = config.get("evals", {}) |
|
|
|
for name, cfg in specs.items(): |
|
write_note(name) |
|
|
|
|
|
cfg = cfg.to_dict() |
|
module = cfg.pop("type", name) |
|
pred_key = cfg.pop("pred", "predict") |
|
pred_kw = cfg.pop("pred_kw", None) |
|
prefix = cfg.pop("prefix", f"{name}/") |
|
cfg.pop("skip_first", None) |
|
logsteps = get_steps("log", cfg) |
|
for typ in ("steps", "epochs", "examples", "percent"): |
|
cfg.pop(f"log_{typ}", None) |
|
|
|
|
|
|
|
cfg["batch_size"] = cfg.get("batch_size") or config.get("batch_size_eval") or config.get("input.batch_size") or config.get("batch_size") |
|
|
|
module = importlib.import_module(f"big_vision.evaluators.{module}") |
|
|
|
if devices is not None: |
|
cfg["devices"] = devices |
|
|
|
api_type = getattr(module, "API", "pmap") |
|
if api_type == "pmap" and "devices" in cfg: |
|
raise RuntimeError( |
|
"You are seemingly using the old pmap-based evaluator, but with " |
|
"jit-based train loop, see (internal link) for more details.") |
|
if api_type == "jit" and "devices" not in cfg: |
|
raise RuntimeError( |
|
"You are seemingly using new jit-based evaluator, but with " |
|
"old pmap-based train loop, see (internal link) for more details.") |
|
|
|
try: |
|
predict_fn = predict_fns[pred_key] |
|
except KeyError as e: |
|
raise ValueError( |
|
f"Unknown predict_fn '{pred_key}'. Available predict_fns are:\n" |
|
+ "\n".join(predict_fns)) from e |
|
if pred_kw is not None: |
|
predict_fn = _CacheablePartial(predict_fn, flax.core.freeze(pred_kw)) |
|
evaluator = module.Evaluator(predict_fn, **cfg) |
|
evaluators.append((name, evaluator, logsteps, prefix)) |
|
|
|
return evaluators |
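
# Hypothetical usage sketch (config values are made up, not taken from any
# real config): with
#   config.evals.val = dict(type="classification", data=..., pp_fn=...,
#                           log_steps=1000, pred="predict"),
# a call like
#   from_config(config, {"predict": predict_fn}, devices=jax.devices())
# would import big_vision.evaluators.classification, build its Evaluator from
# the remaining kwargs, and return [("val", evaluator, 1000, "val/")].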
|
|
|
|
|
@dataclasses.dataclass(frozen=True, eq=True)
class _CacheablePartial:
  """partial(fn, **kwargs) that defines hash and eq - to help with jit caches.

  This is particularly common in evaluators when one has many evaluator
  instances that run on different slices of data.

  Example:

  ```
  f1 = _CacheablePartial(fn, a=1)
  jax.jit(f1)(...)
  jax.jit(_CacheablePartial(fn, a=1))(...)  # fn won't be retraced.
  del f1
  jax.jit(_CacheablePartial(fn, a=1))(...)  # fn will be retraced.
  ```
  """
  fn: Callable[..., Any]
  kwargs: flax.core.FrozenDict

  def __call__(self, *args, **kwargs):
    return functools.partial(self.fn, **self.kwargs)(*args, **kwargs)
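
# Illustrative note (not part of the module API): a plain functools.partial
# would not share jit cache entries, because two partials with identical
# arguments still compare unequal, whereas two equal _CacheablePartial
# instances (built with flax.core.freeze'd kwargs, as in from_config above)
# hash and compare equal:
#
#   p1, p2 = functools.partial(fn, a=1), functools.partial(fn, a=1)
#   p1 == p2                                          # -> False
#   c1 = _CacheablePartial(fn, flax.core.freeze({"a": 1}))
#   c2 = _CacheablePartial(fn, flax.core.freeze({"a": 1}))
#   c1 == c2 and hash(c1) == hash(c2)                 # -> True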
|
|
|
|
|
def eval_input_pipeline(
    data, pp_fn, batch_size, devices, keep_on_cpu=(),
    cache="pipeline", prefetch=1, warmup=False,
):
  """Create an input pipeline in the way used by most evaluators.

  Args:
    data: The configuration to create the data source (like for training).
    pp_fn: A string representing the preprocessing to be performed.
    batch_size: The batch size to use.
    devices: The devices that the batches are sharded and pre-fetched onto.
    keep_on_cpu: See input_pipeline.start_global. Entries in the batch that
      should be kept on the CPU, hence could be ragged or of string type.
    cache: One of "none", "pipeline", "raw_data", "final_data". Determines what
      part of the input stream should be cached across evaluator runs. They use
      more and more RAM, but make evals faster, in that order.
      - "none": Entirely re-create and destroy the input pipeline each run.
      - "pipeline": Keep the (tf.data) pipeline object alive across runs.
      - "raw_data": Cache the full data before pre-processing.
      - "final_data": Cache the full data after pre-processing.
    prefetch: How many batches to fetch ahead.
    warmup: Start fetching the first batch at creation time (right now),
      instead of once the iteration starts.

  Returns:
    A tuple (get_iter, steps): the first element is a function that returns
    the iterator to be used for an evaluation, the second one is how many
    steps should be iterated for doing one evaluation.
  """
  cache = (cache or "none").lower()
  assert cache in ("none", "pipeline", "raw_data", "final_data"), (
      f"Unknown value for cache: {cache}")

  data_source = ds_core.get(**data)
  tfdata, steps = input_pipeline.make_for_inference(
      data_source.get_tfdata(ordered=True, allow_cache=cache != "none"),
      batch_size=batch_size,
      num_ex_per_process=data_source.num_examples_per_process(),
      preprocess_fn=pp_builder.get_preprocess_fn(pp_fn, str(data)),
      cache_raw=cache == "raw_data",
      cache_final=cache == "final_data")
  get_data_iter = lambda: input_pipeline.start_global(
      tfdata, devices, prefetch, keep_on_cpu, warmup)

  # For all caching modes other than "none", create the iterator once and
  # hand out the same one on every call, keeping the tf.data pipeline (and
  # any cached data) alive across evaluator runs instead of rebuilding it.
  if cache in ("pipeline", "raw_data", "final_data"):
    data_iter = get_data_iter()
    get_data_iter = lambda: data_iter

  return get_data_iter, steps
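
# Illustrative usage sketch (dataset name, pp string and sizes are made up):
# the returned get_iter is called once per evaluation run, and `steps` tells
# the caller how many batches make up one full pass over the data.
#
#   get_iter, steps = eval_input_pipeline(
#       data=dict(name="imagenet2012", split="validation"),
#       pp_fn="decode|resize(256)|keep('image', 'label')",
#       batch_size=1024, devices=jax.devices(), cache="raw_data")
#   for _, batch in zip(range(steps), get_iter()):
#     ...  # run the model on `batch` and accumulate metrics.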
|
|
|
|
|
def process_sum(tree):
  """Sums the pytree across all processes."""
  if jax.process_count() == 1:
    return tree

  with jax.transfer_guard_device_to_host("allow"):
    gathered = multihost_utils.process_allgather(tree)
  return jax.tree.map(functools.partial(np.sum, axis=0), gathered)
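
# Illustrative example (hypothetical counters): each process sums its local
# counts, then process_sum returns the same structure with every leaf summed
# over all processes.
#
#   local = {"ncorrect": np.array(37), "nseen": np.array(128)}
#   totals = process_sum(local)  # Global correct/seen counts on every host.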
|
|
|
|
|
def resolve_outfile(outfile, split="", **kw):
  """Formats an `outfile` template into a concrete path, or returns None."""
  if not outfile:
    return None

  # Templates may reference {workdir}; without a workdir flag, skip writing.
  if "{workdir}" in outfile and not flags.FLAGS.workdir:
    return None

  return outfile.format(
      workdir=flags.FLAGS.workdir,
      split="".join(c if c not in "[]%:" else "_" for c in split),
      step=getattr(u.chrono, "prev_step", None),
      **kw,
  )
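
# Example with made-up values: with --workdir=/tmp/run and u.chrono reporting
# prev_step=1000,
#   resolve_outfile("{workdir}/preds_{split}_{step}.json", split="validation")
# would return "/tmp/run/preds_validation_1000.json"; the characters "[", "]",
# "%" and ":" in split names are replaced with "_".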
|
|
|
|
|
def multiprocess_write_json(outfile, jobj):
  """Write a single json file combining all processes' `jobj`s."""
  if not outfile:
    return

  outfile = resolve_outfile(outfile)
  gfile.makedirs(os.path.dirname(outfile))

  if isinstance(jobj, list):
    combine_fn = list.extend
  elif isinstance(jobj, dict):
    combine_fn = dict.update
  else:
    raise TypeError(f"Can only write list or dict jsons, but got {type(jobj)}")

  # Each process first writes its own shard next to the target file.
  with gfile.GFile(outfile + f".p{jax.process_index()}", "w+") as f:
    f.write(json.dumps(jobj))

  u.sync()  # Make sure all shards are fully written.

  # Only the first process combines the shards into the final file; the other
  # processes return an empty container of the same type.
  all_json = type(jobj)()
  if jax.process_index() == 0:
    for pid in range(jax.process_count()):
      with gfile.GFile(outfile + f".p{pid}", "r") as f:
        combine_fn(all_json, json.loads(f.read()))
    with gfile.GFile(outfile, "w+") as f:
      f.write(json.dumps(all_json))

  u.sync()  # Wait for the combined file before deleting the shards.
  gfile.remove(outfile + f".p{jax.process_index()}")

  return all_json
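
# Illustrative call from an evaluator (hypothetical template and records):
# every process passes its local list of predictions, and a single combined
# json is written at the resolved path by process 0.
#
#   multiprocess_write_json(
#       "{workdir}/preds_{step}.json",
#       [{"id": i, "pred": p} for i, p in zip(example_ids, predictions)])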
|
|