|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Evaluator for the classification task."""
|
|
|
|
|
import functools |
|
|
|
from big_vision.evaluators import common |
|
import big_vision.utils as u |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
|
|
|
|
# NOTE(review): presumably tells big_vision's evaluation driver which calling
# convention this evaluator follows ('jit': metrics are computed by a
# `jax.jit`-compiled function, as in `get_eval_fn`, rather than a pmap'd one).
# Confirm against the code that reads `API` in big_vision.evaluators.
API = 'jit'
|
|
|
|
|
|
|
|
|
@functools.cache
def get_eval_fn(predict_fn, loss_name):
  """Builds the jitted per-batch metric function for classification.

  Memoized with `functools.cache` on `(predict_fn, loss_name)`. Repeated calls
  with the same arguments therefore return the very same jitted function. The
  returned function is compiled with `jax.jit` only; no `pmap` is applied.

  Args:
    predict_fn: callable invoked as `predict_fn(train_state, batch)`. The first
      element of its returned sequence is taken as the logits. The logits are
      indexed as 2-D `(batch, num_classes)`.
    loss_name: name of a loss function on `big_vision.utils`. It is called as
      `loss_fn(logits=..., labels=..., reduction=False)`. It presumably returns
      one loss value per example (TODO confirm against big_vision.utils).

  Returns:
    A jitted function `(train_state, batch, labels, mask) -> (ncorrect, loss,
    nseen)`.
    - `labels` is indexed as 2-D `(batch, num_classes)`.
    - `mask` is combined elementwise with `labels.max(axis=1)`; presumably it
      holds one weight per example.
    - All three outputs are sums over the batch, each example weighted by
      `mask * labels.max(axis=1)`.
  """

  @jax.jit
  def _eval_fn(train_state, batch, labels, mask):
    logits, *_ = predict_fn(train_state, batch)

    # Zero out examples whose label row is all zeros (no positive class), so
    # they count towards none of the metrics.
    # NOTE(review): for non-binary (e.g. soft) labels, this also rescales each
    # example's weight by its largest label value.
    mask *= labels.max(axis=1)

    loss = getattr(u, loss_name)(
        logits=logits, labels=labels, reduction=False)
    loss = jnp.sum(loss * mask)

    # An example scores its label value at the arg-max class. For multi-hot
    # 0/1 labels, that is a hit whenever the predicted class is any positive.
    top1_idx = jnp.argmax(logits, axis=1)
    top1_correct = jnp.take_along_axis(
        labels, top1_idx[:, None], axis=1)[:, 0]
    ncorrect = jnp.sum(top1_correct * mask)
    nseen = jnp.sum(mask)

    return ncorrect, loss, nseen

  return _eval_fn
|
|
|
|
|
class Evaluator:
  """Top-1 precision and loss evaluator for classification.

  Constructor args:
    predict_fn: callable `predict_fn(train_state, batch)` whose first output
      is used as logits; handed to `get_eval_fn`.
    loss_name: name of the loss function in `big_vision.utils`; handed to
      `get_eval_fn`.
    label_key: key under which each batch stores its labels.
    **kw: forwarded unchanged to `common.eval_input_pipeline`. That call
      supplies the data-iterator factory and the number of eval steps.
  """

  def __init__(self, predict_fn, loss_name, label_key='labels', **kw):
    data_iter_factory, num_steps = common.eval_input_pipeline(**kw)
    self.get_data_iter = data_iter_factory
    self.steps = num_steps
    self.eval_fn = get_eval_fn(predict_fn, loss_name)
    self.label_key = label_key

  def run(self, train_state):
    """Runs the eval loop and yields `(metric_name, value)` pairs.

    Consumes at most `self.steps` batches from a fresh data iterator. The
    per-batch sums returned by `self.eval_fn` are accumulated on the host and
    normalized by the total number of counted examples.

    Args:
      train_state: forwarded as-is to `self.eval_fn`.

    Yields:
      `('prec@1', total_correct / total_seen)`, then
      `('loss', total_loss / total_seen)`.
    """
    total_correct = total_loss = total_seen = 0
    step_budget = range(self.steps)
    batches = self.get_data_iter()
    # zip() evaluates its iterables left to right. Once the step budget runs
    # out, no further batch is pulled from the data iterator.
    for _step, batch in zip(step_budget, batches):
      # Labels and the mask go to eval_fn as separate arguments. Popping them
      # leaves only the remaining entries to be passed on as `batch`.
      labels = batch.pop(self.label_key)
      mask = batch.pop('_mask')
      per_batch = self.eval_fn(train_state, batch, labels, mask)
      # Fetch this batch's three sums to the host before accumulating.
      n_correct, batch_loss, n_seen = jax.device_get(per_batch)
      total_correct += n_correct
      total_loss += batch_loss
      total_seen += n_seen
    # NOTE(review): if nothing is counted, total_seen is 0. With zero steps it
    # is the int 0 and the divisions below raise ZeroDivisionError; otherwise
    # they produce nan/inf.
    yield ('prec@1', total_correct / total_seen)
    yield ('loss', total_loss / total_seen)
|
|