"""Autoregressive sampler for GIVT."""

import functools
from typing import Any, Optional

from big_vision.models.proj.givt import parallel_decode
import flax
from flax import linen as nn
import jax
from jax import lax
from jax import numpy as jnp
import ml_collections


def _sample_gmm(
    gmm_pdf,
    *,
    rng,
    cfg_inference_weight=None,
    gmm_pdf_uncond=None,
):
  """Draw a single sample from a GMM."""
  if cfg_inference_weight is not None:
    assert gmm_pdf_uncond is not None
    # Classifier-free guidance: combine the conditional and unconditional
    # densities into a single reweighted density to sample from.
    gmm_pdf = parallel_decode.CFGDensity(
        gmm_pdf, gmm_pdf_uncond, w=cfg_inference_weight, rng=rng
    )
  samples = gmm_pdf.sample(seed=rng)
  logprobs = gmm_pdf.log_prob(samples)
  if logprobs.ndim == 2:
    # Add a trailing feature axis so log-probabilities broadcast with samples.
    logprobs = logprobs[..., None]
  return samples, logprobs
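# For intuition only: in density space, classifier-free guidance typically
# reweights the two densities along the lines of
#
#   log p_cfg(x) = (1 + w) * log p_cond(x) - w * log p_uncond(x)
#
# (one common convention; the exact rule and the sampling procedure live in
# parallel_decode.CFGDensity and are not reproduced here).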


def _flatten_samples_dim(x):
  """Flattens samples dimension into batch dimension."""
  if x.ndim == 0:  # Ignore scalars (e.g. the cache index).
    return x
  return x.reshape((x.shape[0] * x.shape[1],) + x.shape[2:])


def _unflatten_samples_dim(x, batch_size, num_samples):
  """Unflattens first dimension into batch and samples dimensions."""
  if x.ndim == 0:  # Ignore scalars (e.g. the cache index).
    return x
  assert batch_size * num_samples == x.shape[0]
  return x.reshape((batch_size, num_samples) + x.shape[1:])
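# Shape round-trip, with hypothetical sizes for illustration:
#
#   x = jnp.zeros((4, 3, 16))           # (batch=4, num_samples=3, features=16)
#   flat = _flatten_samples_dim(x)      # shape (12, 16)
#   back = _unflatten_samples_dim(flat, batch_size=4, num_samples=3)
#   assert back.shape == (4, 3, 16)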


def _cache_map(fn, cache, scan=False):
  """Maps function over cache."""
  if scan:
    # For models whose layers are stacked via nn.scan, cache entries carry a
    # leading layer axis; map fn over it (scalars are applied directly).
    fn_mod = lambda x: jax.lax.map(fn, x) if x.ndim > 0 else fn(x)
  else:
    fn_mod = fn

  frozen = isinstance(cache, flax.core.FrozenDict)
  if frozen:
    cache = flax.core.unfreeze(cache)
  flat_cache = flax.traverse_util.flatten_dict(cache)
  # Leave "cached_bias" entries untouched; only the remaining cache entries
  # are transformed.
  keyvals = {k: v for k, v in flat_cache.items() if k[-1] != "cached_bias"}
  keyvals = jax.tree_util.tree_map(fn_mod, keyvals)
  flat_cache.update(keyvals)
  new_cache = flax.traverse_util.unflatten_dict(flat_cache)
  if frozen:
    new_cache = flax.core.freeze(new_cache)
  return new_cache
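# Minimal sketch of the behavior on a toy, hand-built cache (not a real model
# cache; key names other than "cached_bias" are made up for illustration):
#
#   toy = flax.core.freeze({"layer_0": {
#       "cached_key": jnp.ones((2, 4)), "cached_bias": jnp.ones((4,))}})
#   out = _cache_map(lambda x: x * 2, toy)
#   # out["layer_0"]["cached_key"] is all 2s, "cached_bias" is still all 1s,
#   # and out is again a FrozenDict.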


@flax.struct.dataclass
class LoopState:
  """Internal state of the sampling loop.

  Attributes:
    rng: PRNGKey carried through the loop and split at every step.
    cache: Decoding cache of the conditional model.
    sequences: Sampled sequences, shape (batch * beam_size, seq_len, dim).
    logprobs: Per-feature log-probabilities matching `sequences`.
    cache_u: Decoding cache of the unconditional model when classifier-free
      guidance is active, otherwise None.
  """

  rng: jnp.ndarray
  cache: Any
  sequences: jnp.ndarray
  logprobs: jnp.ndarray
  cache_u: Any


def _create_cache(
    labels,
    model,
    init_sequence,
    params,
    encoded,
    uncond=False,
):
  """Creates the cache and returns the initial (prefill) logits."""
  if uncond:
    assert labels is not None
    # Drop all labels to obtain the unconditional model for CFG.
    drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_)
  else:
    drop_labels = None

  def init_cache(model):
    return model.decode(
        init_sequence, labels, encoded, decode=True, drop_labels=drop_labels
    )

  # Run a full decode once to initialize the cache variables.
  cache = nn.apply(init_cache, model, mutable=True)(params)[1]["cache"]

  def prefill_cache(model):
    return model.prefill(
        labels, init_sequence.shape[0], encoded, drop_labels=drop_labels
    )

  # Prefill the cache with the conditioning prefix and obtain the logits for
  # sampling the first position.
  prefill_logits, aux = nn.apply(prefill_cache, model, mutable=True)(
      {"params": params["params"], "cache": cache})
  cache = aux["cache"]
  return cache, prefill_logits


def generate(
    params: Any,
    seed: jax.Array,
    *,
    model: nn.Module,
    seq_len: int,
    feature_dim: int,
    labels: Optional[jnp.ndarray] = None,
    cond_image: Optional[jnp.ndarray] = None,
    batch_size: Optional[int] = None,
    config: Optional[ml_collections.ConfigDict] = None,
) -> tuple[jax.Array, jax.Array]:
  """Sampling loop for GIVT."""
  if model.style != "ar":
    raise ValueError(f"Invalid style: {model.style}")
  if model.has_encoder != (cond_image is not None):
    raise ValueError("Need cond_image if and only if the model has an encoder!")

  assert labels is not None or batch_size, (
      "Please provide either labels or batch_size.")

  config = config or {}
  config = dict(config)

  # Optional teacher forcing: keep ground-truth features at selected steps.
  keep_gt = config.pop("keep_gt", None)
  gt = config.pop("gt", None)

  if isinstance(seed, int):
    seed = jax.random.PRNGKey(seed)

  beam_size = config.pop("beam_size", 1)
  fan_size = config.pop("fan_size", 1)

  if labels is not None:
    batch_size = labels.shape[0]
    # Replicate each label once per beam.
    labels = labels.repeat(beam_size, axis=0)

  # Beams are flattened into the batch dimension.
  init_sequence = jnp.zeros((batch_size * beam_size, seq_len, feature_dim))
  init_logprobs = jnp.zeros_like(init_sequence)

  if cond_image is not None:

    def encode_cond_img(model, cond_img):
      return model.encode(cond_img)

    encoded = nn.apply(encode_cond_img, model)(params, cond_image)
    encoded = jnp.repeat(encoded, beam_size, axis=0)
  else:
    encoded = None

  cache, prefill_logits = _create_cache(
      labels, model, init_sequence, params, encoded
  )

  cfg_inference_weight = config.pop("cfg_inference_weight", None)
  if cfg_inference_weight == 0.0:
    cfg_inference_weight = None
  cfg = cfg_inference_weight is not None

  get_pdf = functools.partial(
      model.get_pdf,
      temperature_scales=config.pop("temp", None),
      temperature_probs=config.pop("temp_probs", None),
  )

  sample = functools.partial(
      _sample_gmm, cfg_inference_weight=cfg_inference_weight
  )

  # Sample the token at position 0 from the prefill logits.
  pdf_first = get_pdf(prefill_logits)
  rng_first, rng = jax.random.split(seed)

  if cfg:
    assert beam_size == 1 and fan_size == 1
    # Classifier-free guidance needs a second, unconditional decoding cache.
    cache_u, prefill_logits_u = _create_cache(
        labels, model, init_sequence, params, encoded, uncond=True
    )
    pdf_first_u = get_pdf(prefill_logits_u)
  else:
    cache_u = None
    pdf_first_u = None

  tokens_first, logprobs_first = sample(
      pdf_first, rng=rng_first, gmm_pdf_uncond=pdf_first_u
  )
  init_sequence = init_sequence.at[:, 0].set(tokens_first.squeeze(axis=1))
  init_logprobs = init_logprobs.at[:, 0].set(logprobs_first.squeeze(axis=1))

  def tokens_to_logits(tokens, cache, uncond=False):
    if uncond:
      drop_labels = jnp.ones((labels.shape[0],), dtype=jnp.bool_)
    else:
      drop_labels = None

    def decode_step(model, tokens):
      return model.decode(tokens, labels, encoded,
                          decode=True, drop_labels=drop_labels)

    logits, aux = nn.apply(decode_step, model, mutable=True)(
        {"params": params["params"], "cache": cache}, tokens)
    return logits, aux["cache"]
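  # One decode step, sketched with hypothetical shapes: `tokens` is
  # (batch * beam_size, 1, feature_dim), and the returned `logits`
  # parameterize the next-position GMM with the same leading dimensions.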

  init_state = LoopState(
      cache=cache,
      sequences=init_sequence,
      logprobs=init_logprobs,
      rng=rng,
      cache_u=cache_u,
  )

  rand_top_k = config.pop("rand_top_k", False)
  rand_top_k_temp = config.pop("rand_top_k_temp", 1.0)

  assert not config, f"Sampling config is expected to be empty: {config}"
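  # For reference, the config keys consumed above: gt, keep_gt, beam_size,
  # fan_size, cfg_inference_weight, temp, temp_probs, rand_top_k, and
  # rand_top_k_temp; anything left over triggers the assert.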

  def sampling_iteration(i, state):
    rng_sampling, rng_local = jax.random.split(state.rng)
    cur_tokens = state.sequences[:, i][:, None]

    cur_logits, cache = tokens_to_logits(cur_tokens, state.cache)

    # (batch * beam_size, 1, ...) -> (batch, beam_size, ...)
    cur_logits = _unflatten_samples_dim(
        cur_logits, batch_size, beam_size).squeeze(axis=2)

    # Fan out: propose fan_size candidate continuations per beam.
    cur_pdf = get_pdf(cur_logits.repeat(fan_size, axis=1))

    if cfg:
      cur_logits_u, cache_u = tokens_to_logits(
          cur_tokens, state.cache_u, uncond=True
      )
      cur_logits_u = _unflatten_samples_dim(
          cur_logits_u, batch_size, beam_size).squeeze(axis=2)
      cur_pdf_u = get_pdf(cur_logits_u.repeat(fan_size, axis=1))
      new_tokens, new_logprobs = sample(
          cur_pdf, rng=rng_sampling, gmm_pdf_uncond=cur_pdf_u
      )
    else:
      new_tokens, new_logprobs = sample(cur_pdf, rng=rng_sampling)
      cache_u = None

    if gt is not None:
      assert keep_gt is not None
      new_tokens = jnp.where(keep_gt[i], gt[:, i, :][:, None], new_tokens)

    # Fast path: plain autoregressive sampling without beams or fan-out.
    if beam_size == fan_size == 1:
      sampled_tokens = new_tokens.squeeze(axis=1)
      sequences = state.sequences.at[:, i + 1].set(sampled_tokens)
      return LoopState(
          cache=cache,
          rng=rng_local,
          sequences=sequences,
          logprobs=state.logprobs,
          cache_u=cache_u,
      )

    logprobs = _unflatten_samples_dim(state.logprobs, batch_size, beam_size)
    cur_logprobs = logprobs[:, :, i]

    # Accumulate log-probabilities along each fanned-out candidate.
    new_logprobs = new_logprobs + cur_logprobs.repeat(fan_size, axis=1)
    beam_logprobs = new_logprobs.sum(axis=-1)

    if rand_top_k:
      # Stochastic beam selection: sample beam_size candidates without
      # replacement, with a temperature on the candidate scores.
      def stoc_top_k(r, x, p):
        return jax.random.choice(r, x, shape=(beam_size,), replace=False, p=p)

      index_grid = jnp.arange(beam_logprobs.shape[1], dtype=jnp.int32)
      index_grid = index_grid[None].repeat(beam_logprobs.shape[0], axis=0)
      top_k_rng, rng_local = jax.random.split(rng_local)
      top_k_rng = jax.random.split(top_k_rng, beam_logprobs.shape[0])

      top_beam_fan_indices = jax.vmap(stoc_top_k, in_axes=(0, 0, 0))(
          top_k_rng,
          index_grid,
          nn.softmax(beam_logprobs / rand_top_k_temp, axis=-1))
    else:
      _, top_beam_fan_indices = lax.top_k(beam_logprobs, k=beam_size)

    # Candidate j in the fanned-out axis originated from beam j // fan_size.
    top_beam_indices = top_beam_fan_indices // fan_size
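    # Index arithmetic with hypothetical sizes, for illustration: with
    # beam_size=2 and fan_size=3 there are 6 candidates per batch element;
    # candidates 0-2 extend beam 0 and candidates 3-5 extend beam 1, so e.g.
    # top_beam_fan_indices == [4, 0] keeps beams [4 // 3, 0 // 3] == [1, 0].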

    def _gather_beams(x):
      if x.ndim == 0:  # Ignore scalars (e.g. the cache index).
        return x
      # Expand the (batch, beam_size) indices with trailing singleton axes so
      # take_along_axis broadcasts over the remaining dimensions of x.
      expanded_indices = top_beam_indices.reshape(
          top_beam_indices.shape + (1,) * (x.ndim - 2))
      return jnp.take_along_axis(x, expanded_indices, axis=1)

    def _gather_tokens(x):
      # Select the winning candidates from the fanned-out samples.
      return jnp.take_along_axis(x, top_beam_fan_indices[..., None], axis=1)

    sequences = _unflatten_samples_dim(state.sequences, batch_size, beam_size)
    sequences = _gather_beams(sequences)
    sequences = sequences.at[:, :, i + 1].set(_gather_tokens(new_tokens))
    sequences = _flatten_samples_dim(sequences)

    logprobs = _gather_beams(logprobs)
    logprobs = logprobs.at[:, :, i + 1].set(_gather_tokens(new_logprobs))
    logprobs = _flatten_samples_dim(logprobs)

    # Reorder the cache so each surviving beam continues from its parent.
    scanned_cache = getattr(model, "scan", False)
    cache = _cache_map(
        lambda x: _unflatten_samples_dim(x, batch_size, beam_size),
        cache, scanned_cache)
    cache = _cache_map(_gather_beams, cache, scanned_cache)
    cache = _cache_map(_flatten_samples_dim, cache, scanned_cache)

    if cfg:
      assert cache_u is not None
      cache_u = _cache_map(
          lambda x: _unflatten_samples_dim(x, batch_size, beam_size),
          cache_u, scanned_cache
      )
      cache_u = _cache_map(_gather_beams, cache_u, scanned_cache)
      cache_u = _cache_map(_flatten_samples_dim, cache_u, scanned_cache)
    else:
      assert cache_u is None

    return LoopState(
        cache=cache,
        rng=rng_local,
        sequences=sequences,
        logprobs=logprobs,
        cache_u=cache_u,
    )

  # The final iteration's write at index seq_len is out of bounds and is
  # dropped by JAX's scatter semantics.
  final_state = lax.fori_loop(0, seq_len, sampling_iteration, init_state)
  # Keep one sequence per batch element: the first beam (with deterministic
  # top-k selection this is the highest-scoring one). The logprobs entry at
  # the last position holds the accumulated beam score.
  final_logprobs = final_state.logprobs[::beam_size][:, -1].sum(axis=-1)
  return final_state.sequences[::beam_size], final_logprobs
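

# Hedged usage sketch (hypothetical shapes and values; `givt_model` and
# `params` must come from an actual GIVT checkpoint, here assumed to be a
# class-conditional model without an encoder, so no cond_image is passed):
#
#   samples, logprobs = generate(
#       params,
#       seed=jax.random.PRNGKey(0),
#       model=givt_model,          # GIVT nn.Module with style == "ar"
#       seq_len=256,
#       feature_dim=16,
#       labels=jnp.zeros((8,), dtype=jnp.int32),   # implies batch_size == 8
#       config=ml_collections.ConfigDict(
#           {"temp": 0.95, "cfg_inference_weight": 0.4}),
#   )
#   # samples: (8, 256, 16) continuous latents; logprobs: shape (8,).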