|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Object detection reward from "Tuning computer vision models with task rewards" (https://arxiv.org/abs/2302.08242). |
|
|
|
The `reward_fn` computes the reward for a batch of predictions and ground truth |
|
annotations. When using it to optimize a model that outputs a prediction as a |
|
sequence of tokens like [y0, x0, Y0, X0, class0, confidence0, y1, x1, Y1, ...] |
|
the training loop may look like: |
|
|
|
``` |
|
# Settings used in the paper. |
|
config.max_level = 1000 # Coordinates are discretized into 1000 buckets. |
|
config.max_conf = 2 # Two tokens are reserved to represent confidence. |
|
config.num_cls = 80 # Number of classes in COCO. |
|
config.nms_w = 0.3 # Weight for duplicate instances. |
|
config.cls_smooth = 0.05 # Adjust the classes weights based on their frequency. |
|
config.reward_thr = (0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95) |
|
config.correct_thr = 0.5 # Learn the IoU when matching with threshold=0.5. |
|
config.conf_w = 0.3 # Weight for the confidence loss. |
|
|
|
|
|
# 1) Sample N outputs for each input and compute rewards, use one sample to |
|
# optimize and others to compute a reward baseline. |
|
sample_seqs = sample_fn(params, images, num_samples) |
|
sample_rewards, aux = reward_fn(sample_seqs, labels, config) |
|
labels = sample_seqs[:, 0, ...] |
|
rewards = sample_rewards[:, 0] |
|
match_iou = aux["match_iou"][:, 0] |
|
baselines = (jnp.sum(sample_rewards, axis=-1) - rewards) / (num_samples - 1) |
|
|
|
# 2) Optimize the model, using REINFORCE to adjust the likelihood of the
# sequence based on the reward, and supervision to teach the model to
# predict the expected IoU of each box in its own samples.
|
def loss_fn(params): |
|
logits = model.apply(params, images, labels, train=True, rngs=rngs) |
|
logits_softmax = jax.nn.log_softmax(logits) |
|
|
|
# Use reinforce to optimize the expected reward for the whole sequence. |
|
seq_rewards = (rewards - baselines) |
|
# Note: consider improving this code to skip this loss for confidence tokens.
|
# The paper did not do it due to a bug (and also does not seem to matter). |
|
target = jax.nn.one_hot(labels, logits.shape[-1]) * seq_rewards[:, None, None] |
|
loss_reward = -jnp.sum(target * logits_softmax, axis=-1) |
|
|
|
# Use supervision loss to tune the confidence tokens to predict IoU: |
|
# - (1.0, 0.0, 0.0, ...) -> for padded boxes. |
|
# - (0.0, 1-iou, iou, ...) -> for sampled boxes. |
|
conf0 = (labels[:, 5::6] == 0) |
|
conf1 = (labels[:, 5::6] > 0) * (1.0 - match_iou) |
|
conf2 = (labels[:, 5::6] > 0) * match_iou |
|
target_conf = jnp.stack([conf0, conf1, conf2], axis=-1) |
|
logits_conf = logits_softmax[:, 5::6, :3] |
|
loss_conf = -jnp.sum(target_conf * logits_conf, axis=-1) |
|
|
|
loss = jnp.mean(loss_reward) + config.conf_w * jnp.mean(loss_conf) |
|
return loss |
|
``` |
|
""" |
|
import functools |
|
|
|
import einops |
|
import jax |
|
import jax.numpy as jnp |
|
|
|
|
|
|
|
|
|
# Per-class instance counts, indexed by class id (80 entries, 860001 in total).
# `_reward_fn_thr` turns them into class weights that are inversely
# proportional to the (smoothed) class frequency.
# NOTE(review): presumably these are the COCO train-split counts -- confirm.
CLS_COUNTS = [
    262465, 7113, 43867, 8725, 5135, 6069, 4571, 9973, 10759, 12884,
    1865, 1983, 1285, 9838, 10806, 4768, 5508, 6587, 9509, 8147,
    5513, 1294, 5303, 5131, 8720, 11431, 12354, 6496, 6192, 2682,
    6646, 2685, 6347, 9076, 3276, 3747, 5543, 6126, 4812, 24342,
    7913, 20650, 5479, 7770, 6165, 14358, 9458, 5851, 4373, 6399,
    7308, 7852, 2918, 5821, 7179, 6353, 38491, 5779, 8652, 4192,
    15714, 4157, 5805, 4970, 2262, 5703, 2855, 6434, 1673, 3334,
    225, 5610, 2637, 24715, 6334, 6613, 1481, 4793, 198, 1954,
]
|
|
|
|
|
|
|
def seq2box(seq, max_level, max_conf, num_cls):
    """Decode a token sequence into boxes, class labels and confidences.

    The last axis of `seq` is read as consecutive groups of six tokens,
    `[y0, x0, Y0, X0, class, confidence]`; trailing tokens that do not fill a
    whole group are dropped. Any leading axes are preserved.

    Args:
      seq: array of tokens whose last axis is the sequence.
      max_level: largest coordinate bucket; coordinates are scaled to [0, 1].
      max_conf: largest confidence token value.
      num_cls: number of classes; labels are clipped to [0, num_cls - 1].

    Returns:
      A `(boxes, labels, confs)` tuple with shapes `(..., n, 4)`, `(..., n)`
      and `(..., n)`, where `n` is the number of complete six-token groups.
    """
    tokens_per_box = 6

    # Keep only as many tokens as form complete boxes, then split them up.
    usable_len = (seq.shape[-1] // tokens_per_box) * tokens_per_box
    per_box = einops.rearrange(
        seq[..., :usable_len],
        "... (box field) -> ... box field",
        field=tokens_per_box,
    )

    # Vocabulary layout: confidence tokens come first, then the coordinate
    # buckets, then the class tokens.
    coord_offset = max_conf + 1
    label_offset = coord_offset + max_level + 1

    boxes = jnp.clip(per_box[..., 0:4] - coord_offset, 0, max_level) / max_level
    labels = jnp.clip(per_box[..., 4] - label_offset, 0, num_cls - 1)
    confs = jnp.clip(per_box[..., 5], 0, max_conf)

    return boxes, labels, confs
|
|
|
|
|
def iou_fn(box1, box2):
    """Return the intersection-over-union of two boxes.

    Args:
      box1: four values `(ymin, xmin, ymax, xmax)`.
      box2: four values `(ymin, xmin, ymax, xmax)`.

    Returns:
      The IoU of the two boxes. A small epsilon in the denominator makes the
      result 0 (instead of NaN) when both boxes have zero area.
    """
    top1, left1, bottom1, right1 = box1
    top2, left2, bottom2, right2 = box2

    def overlap(lo_a, hi_a, lo_b, hi_b):
        # Length of the overlap of two 1-D intervals; 0 when they are disjoint.
        return jnp.maximum(0, jnp.minimum(hi_a, hi_b) - jnp.maximum(lo_a, lo_b))

    area1 = jnp.abs((bottom1 - top1) * (right1 - left1))
    area2 = jnp.abs((bottom2 - top2) * (right2 - left2))

    inter = (overlap(left1, right1, left2, right2)
             * overlap(top1, bottom1, top2, bottom2))
    union = area1 + area2 - inter + 1e-9
    return inter / union
|
|
|
# Pairwise IoU: for `boxes1` of shape (n, 4) and `boxes2` of shape (m, 4) the
# result is an (n, m) matrix whose entry [i, j] is iou_fn(boxes1[i], boxes2[j]).
# The inner map fixes one box and sweeps over the second argument; the outer
# map then sweeps over the first argument.
_iou_one_to_many = jax.vmap(iou_fn, in_axes=(None, 0))
iou_fn_batched = jax.vmap(_iou_one_to_many, in_axes=(0, None))
|
|
|
|
|
def _reward_fn_thr(seq_pred, seq_gt,
                   thr, nms_w, max_level, max_conf, num_cls, cls_smooth):
    """Compute the detection reward of one prediction at one IoU threshold.

    Args:
      seq_pred: 1-D token sequence of the predicted boxes.
      seq_gt: 1-D token sequence of the ground-truth boxes.
      thr: IoU threshold; only pairs with IoU strictly above it can match.
      nms_w: weight of the duplicate-detection penalty.
      max_level: largest coordinate bucket (forwarded to `seq2box`).
      max_conf: largest confidence token value (forwarded to `seq2box`).
      num_cls: number of classes.
      cls_smooth: smoothing factor applied to the class counts before they
        are inverted into class weights.

    Returns:
      A dict with:
        "reward": class-weighted matches minus `nms_w` times the penalty.
        "num_matches": number of ground-truth boxes that got a match.
        "nms_penalty": class-weighted count of duplicate detections.
        "correct": per predicted box, 0 for padding (confidence 0), 1 for an
          unmatched box and 2 for a matched box.
        "match_iou": per predicted box, the IoU with its matched ground-truth
          box, or 0 when it has none.
    """
    # Class weights: inverse of the smoothed class counts, rescaled so that
    # they sum to `num_cls`.
    counts = jnp.array(CLS_COUNTS)
    inv_freq = 1.0 / (counts + cls_smooth*jnp.sum(counts))
    cls_weights = num_cls * inv_freq / jnp.sum(inv_freq)

    pred_boxes, pred_labels, pred_confs = seq2box(
        seq_pred, max_level, max_conf, num_cls)
    gt_boxes, gt_labels, gt_confs = seq2box(
        seq_gt, max_level, max_conf, num_cls)

    # IoU matrix (predictions x ground truth), keeping only pairs above `thr`.
    pair_iou = iou_fn_batched(pred_boxes, gt_boxes)
    pair_iou = jnp.where(pair_iou > thr, pair_iou, 0.0)

    # A pair can only match when neither box is padding (confidence 0) and
    # both boxes carry the same class label.
    pred_is_real = pred_confs > 0
    both_real = pred_is_real[:, None] & (gt_confs[None, :] > 0)
    same_label = pred_labels[:, None] == gt_labels[None, :]
    pair_iou = (both_real & same_label) * pair_iou

    num_pred, num_gt = pair_iou.shape

    # Each prediction keeps only its single best ground-truth box.
    best_gt = jnp.argmax(pair_iou, axis=1)
    pair_iou = pair_iou * jax.nn.one_hot(best_gt, num_gt)

    # Per-prediction indicator: 0 = padding, 1 = unmatched, 2 = matched.
    pred_hit = jnp.any(pair_iou > 0.0, axis=1).astype("int32")
    correct = jnp.where(pred_is_real, pred_hit + 1, 0)

    # Each ground-truth box picks its best prediction; -1 means no match.
    best_pred = jnp.argmax(pair_iou, axis=0)
    best_pred_iou = jnp.max(pair_iou, axis=0)
    matches_idx = jnp.where(best_pred_iou > 0.0, best_pred, -1)
    gt_matched = matches_idx >= 0

    match_reward = jnp.sum(gt_matched * cls_weights[gt_labels][None, :])

    # Duplicate penalty (aka NMS): every surviving pair whose prediction is
    # not the best match of that ground-truth box counts as a duplicate.
    # An index of -1 yields an all-zero one-hot column.
    is_best_match = jax.nn.one_hot(matches_idx, num_pred, axis=0)
    nms_penalty = jnp.sum(
        (pair_iou > 0.0) * (1 - is_best_match) * cls_weights[pred_labels][:, None])

    # At most one entry per row is non-zero, so the row sum is the matched IoU.
    match_iou = jnp.sum(pair_iou, axis=1)

    return {
        "reward": (match_reward - nms_w * nms_penalty),
        "num_matches": jnp.sum(gt_matched),
        "nms_penalty": nms_penalty,
        "correct": correct,
        "match_iou": match_iou,
    }
|
|
|
|
|
def reward_fn(seqs_pred, seqs_gt, config):
    """Compute the detection reward averaged over all IoU thresholds.

    Args:
      seqs_pred: predicted token sequences. The two leading axes are mapped
        over (batch and samples in the module-level example); the last axis
        is the token sequence.
      seqs_gt: ground-truth token sequences. Only the leading axis is mapped
        over, so the same ground truth is shared by every sample of an input.
      config: object providing `reward_thr` (iterable of IoU thresholds),
        `correct_thr` (the threshold whose per-box outputs are returned),
        `nms_w`, `max_level`, `max_conf`, `num_cls` and `cls_smooth`.

    Returns:
      A `(reward, aux)` tuple. `reward` is the mean reward over all
      thresholds, one value per (input, sample). `aux` is a dict with:
        "result": per-threshold "reward", "num_matches" and "nms_penalty"
          under keys of the form "<name>-<thr:0.1f>", plus their means over
          all thresholds under the plain names. The one-decimal keys are kept
          for compatibility; thresholds that format to the same label (for
          example 0.55 and 0.6) share a key and the later one wins.
        "correct": per-box indicators computed at `correct_thr`.
        "match_iou": per-box matched IoU computed at `correct_thr`.

    Raises:
      ValueError: if `config.correct_thr` is not one of `config.reward_thr`.
    """
    thrs = tuple(config.reward_thr)
    correct_thr = config.correct_thr
    r_keys = ("reward", "num_matches", "nms_penalty")

    # Fail early with a clear message: without this check a missing threshold
    # only surfaces as an unbound local after all the work has been done.
    if not any(thr == correct_thr for thr in thrs):
        raise ValueError(
            f"config.correct_thr={correct_thr!r} must be one of "
            f"config.reward_thr={thrs!r}")

    result = {}
    per_thr = {k: [] for k in r_keys}
    correct = None
    match_iou = None
    for thr in thrs:
        fn = functools.partial(
            _reward_fn_thr,
            thr=thr,
            nms_w=config.nms_w,
            max_level=config.max_level,
            max_conf=config.max_conf,
            num_cls=config.num_cls,
            cls_smooth=config.cls_smooth,
        )
        # Outer map: over inputs (both arguments). Inner map: over the
        # samples of one input, all scored against the same ground truth.
        rewards = jax.vmap(jax.vmap(fn, in_axes=(0, None)))(seqs_pred, seqs_gt)

        for k in r_keys:
            result[f"{k}-{thr:0.1f}"] = rewards[k]
            per_thr[k].append(rewards[k])

        if thr == correct_thr:
            correct = rewards["correct"]
            match_iou = rewards["match_iou"]

    # Average the per-threshold arrays directly instead of re-reading them
    # through the formatted keys: `:0.1f` maps several thresholds to the same
    # key, so a lookup by key would average overwritten values.
    for k in r_keys:
        result[k] = jnp.mean(jnp.array(per_thr[k]), axis=0)

    return result["reward"], {
        "result": result,
        "correct": correct,
        "match_iou": match_iou,
    }
|
|