pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Object detection reward from "Tuning computer vision models with task rewards" (https://arxiv.org/abs/2302.08242).
The `reward_fn` computes the reward for a batch of predictions and ground truth
annotations. When using it to optimize a model that outputs a prediction as a
sequence of tokens like [y0, x0, Y0, X0, class0, confidence0, y1, x1, Y1, ...]
the training loop may look like:
```
# Settings used in the paper.
config.max_level = 1000 # Coordinates are discretized into 1000 buckets.
config.max_conf = 2 # Two tokens are reserved to represent confidence.
config.num_cls = 80 # Number of classes in COCO.
config.nms_w = 0.3 # Weight for duplicate instances.
config.cls_smooth = 0.05 # Adjust the classes weights based on their frequency.
config.reward_thr = (0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95)
config.correct_thr = 0.5 # Learn the IoU when matching with threshold=0.5.
config.conf_w = 0.3 # Weight for the confidence loss.
# 1) Sample N outputs for each input and compute rewards, use one sample to
# optimize and others to compute a reward baseline.
sample_seqs = sample_fn(params, images, num_samples)
sample_rewards, aux = reward_fn(sample_seqs, labels, config)
labels = sample_seqs[:, 0, ...]
rewards = sample_rewards[:, 0]
match_iou = aux["match_iou"][:, 0]
baselines = (jnp.sum(sample_rewards, axis=-1) - rewards) / (num_samples - 1)
# 2) Optimizize the model. By using REINFORCE to adjust the likelihood of the
# sequence based on the reward and with supervision to teach the model to
# predict the expected IoU of each box in its own samples.
def loss_fn(params):
logits = model.apply(params, images, labels, train=True, rngs=rngs)
logits_softmax = jax.nn.log_softmax(logits)
# Use reinforce to optimize the expected reward for the whole sequence.
seq_rewards = (rewards - baselines)
# Note: consider improve this code to skip this loss for confidence tokens.
# The paper did not do it due to a bug (and also does not seem to matter).
target = jax.nn.one_hot(labels, logits.shape[-1]) * seq_rewards[:, None, None]
loss_reward = -jnp.sum(target * logits_softmax, axis=-1)
# Use supervision loss to tune the confidence tokens to predict IoU:
# - (1.0, 0.0, 0.0, ...) -> for padded boxes.
# - (0.0, 1-iou, iou, ...) -> for sampled boxes.
conf0 = (labels[:, 5::6] == 0)
conf1 = (labels[:, 5::6] > 0) * (1.0 - match_iou)
conf2 = (labels[:, 5::6] > 0) * match_iou
target_conf = jnp.stack([conf0, conf1, conf2], axis=-1)
logits_conf = logits_softmax[:, 5::6, :3]
loss_conf = -jnp.sum(target_conf * logits_conf, axis=-1)
loss = jnp.mean(loss_reward) + config.conf_w * jnp.mean(loss_conf)
return loss
```
"""
import functools
import einops
import jax
import jax.numpy as jnp
# Frequency of COCO object detection classes as observed in the training set.
# pylint: disable=bad-whitespace,bad-continuation
CLS_COUNTS = [
262465, 7113, 43867, 8725, 5135, 6069, 4571, 9973, 10759,
12884, 1865, 1983, 1285, 9838, 10806, 4768, 5508, 6587,
9509, 8147, 5513, 1294, 5303, 5131, 8720, 11431, 12354,
6496, 6192, 2682, 6646, 2685, 6347, 9076, 3276, 3747,
5543, 6126, 4812, 24342, 7913, 20650, 5479, 7770, 6165,
14358, 9458, 5851, 4373, 6399, 7308, 7852, 2918, 5821,
7179, 6353, 38491, 5779, 8652, 4192, 15714, 4157, 5805,
4970, 2262, 5703, 2855, 6434, 1673, 3334, 225, 5610,
2637, 24715, 6334, 6613, 1481, 4793, 198, 1954
]
# pylint: enable=bad-whitespace,bad-continuation
def seq2box(seq, max_level, max_conf, num_cls):
"""Extract boxes encoded as sequences."""
# Reshape to instances of boxes
dim_per_box = 6
seq_len = seq.shape[-1]
seq = seq[..., :(seq_len - seq_len % dim_per_box)]
seq = einops.rearrange(seq, "... (n d) -> ... n d", d=dim_per_box)
# Unpack box fields
boxes, labels, confs = seq[..., 0:4], seq[..., 4], seq[..., 5]
boxes = boxes - max_conf - 1
labels = labels - max_conf - 1 - max_level - 1
boxes = jnp.clip(boxes, 0, max_level) / max_level
labels = jnp.clip(labels, 0, num_cls - 1)
confs = jnp.clip(confs, 0, max_conf)
return boxes, labels, confs
def iou_fn(box1, box2):
"""Compute IoU of two boxes."""
ymin1, xmin1, ymax1, xmax1 = box1
ymin2, xmin2, ymax2, xmax2 = box2
a1 = jnp.abs((ymax1 - ymin1) * (xmax1 - xmin1))
a2 = jnp.abs((ymax2 - ymin2) * (xmax2 - xmin2))
yl = jnp.maximum(ymin1, ymin2)
yr = jnp.minimum(ymax1, ymax2)
yi = jnp.maximum(0, yr - yl)
xl = jnp.maximum(xmin1, xmin2)
xr = jnp.minimum(xmax1, xmax2)
xi = jnp.maximum(0, xr - xl)
inter = xi * yi
return inter / (a1 + a2 - inter + 1e-9)
iou_fn_batched = jax.vmap(
jax.vmap(iou_fn, in_axes=(None, 0)), in_axes=(0, None)
)
def _reward_fn_thr(seq_pred, seq_gt,
thr, nms_w, max_level, max_conf, num_cls, cls_smooth):
"""Compute detection reward function for a given IoU threshold."""
# Weight matches of each label inversely proportional to the percentage of
# GT instances with such label in the whole train dataset. Additionally
# smooth out the observed distribution.
cls_counts = jnp.array(CLS_COUNTS)
weights = 1.0 / (cls_counts + cls_smooth*jnp.sum(cls_counts))
weights = num_cls * weights / jnp.sum(weights)
boxes_pred, labels_pred, confs_pred = seq2box(
seq_pred, max_level, max_conf, num_cls)
boxes_gt, labels_gt, confs_gt = seq2box(
seq_gt, max_level, max_conf, num_cls)
# Compute IoU matrix: Predictions X GT
iou = iou_fn_batched(boxes_pred, boxes_gt)
# IoU thr
iou = jnp.where(iou > thr, iou, 0.0)
# EOS mask
confs_mask = (confs_pred[:, None] > 0) * (confs_gt[None, :] > 0)
iou = confs_mask * iou
# Label mask
label_mask = labels_pred[:, None] == labels_gt[None, :]
iou = label_mask * iou
# Each prediction is matched to a single box
single_match_mask = jax.nn.one_hot(jnp.argmax(iou, axis=1), iou.shape[1])
iou = iou * single_match_mask
# Pred. boxes indicators
correct = jnp.any(iou > 0.0, axis=1).astype("int32") + 1
correct = jnp.where(confs_pred > 0, correct, 0)
# For each GT box find best match
matches_idx = jnp.argmax(iou, axis=0)
matches_iou = jnp.take_along_axis(iou, matches_idx[None], axis=0)[0]
matches_idx = jnp.where(matches_iou > 0.0, matches_idx, -1)
match_reward = jnp.sum((matches_idx >= 0) * weights[labels_gt][None, :])
# Compute duplicate penalty (aka NMS).
matches_mask = jax.nn.one_hot(matches_idx, iou.shape[0], axis=0)
nms_penalty = jnp.sum(
(iou > 0.0) * (1 - matches_mask) * weights[labels_pred][:, None])
match_iou = jnp.sum(iou, axis=1)
return {
"reward": (match_reward - nms_w * nms_penalty),
"num_matches": jnp.sum(matches_idx >= 0),
"nms_penalty": nms_penalty,
"correct": correct,
"match_iou": match_iou,
}
def reward_fn(seqs_pred, seqs_gt, config):
"""Total reward."""
result = {}
thrs = config.reward_thr
correct_thr = config.correct_thr
r_keys = ["reward", "num_matches", "nms_penalty"]
for thr in thrs:
fn = functools.partial(
_reward_fn_thr,
thr=thr,
nms_w=config.nms_w,
max_level=config.max_level,
max_conf=config.max_conf,
num_cls=config.num_cls,
cls_smooth=config.cls_smooth,
)
rewards = jax.vmap(jax.vmap(fn, in_axes=(0, None)))(seqs_pred, seqs_gt)
result = {**result, **{f"{k}-{thr:0.1f}": rewards[k]
for k in r_keys}}
if thr == correct_thr:
correct = rewards["correct"]
match_iou = rewards["match_iou"]
result = {
**result,
**{k: jnp.mean(
jnp.array([result[f"{k}-{thr:0.1f}"] for thr in thrs]), axis=0)
for k in r_keys}
}
return result["reward"], {
"result": result,
"correct": correct,
"match_iou": match_iou,
}