"""Wraps `big_vision` PaliGemma model for easy use in demo.""" from collections.abc import Callable import dataclasses from typing import Any import jax import jax.numpy as jnp import ml_collections import numpy as np import PIL.Image from big_vision import sharding from big_vision import utils from big_vision.models.proj.paligemma import paligemma from big_vision.pp import builder as pp_builder from big_vision.pp import ops_general # pylint: disable=unused-import from big_vision.pp import ops_image # pylint: disable=unused-import from big_vision.pp import ops_text # pylint: disable=unused-import from big_vision.pp import tokenizer from big_vision.pp.proj.paligemma import ops as ops_paligemma # pylint: disable=unused-import from big_vision.trainers.proj.paligemma import predict_fns mesh = jax.sharding.Mesh(jax.devices(), 'data') def _recover_bf16(x): if x.dtype == np.dtype('V2'): x = x.view('bfloat16') return x def _load( path, tokenizer_spec='gemma(tokensets=("loc", "seg"))', vocab_size=257_152 ): """Loads model, params, decode functions and tokenizer.""" tok = tokenizer.get_tokenizer(tokenizer_spec) config = ml_collections.FrozenConfigDict(dict( llm_model='proj.paligemma.gemma_bv', llm=dict(vocab_size=vocab_size, variant='gemma_2b'), img=dict(variant='So400m/14', pool_type='none', scan=True), )) model = paligemma.Model(**config) decode = predict_fns.get_all(model)['decode'] beam_decode = predict_fns.get_all(model)['beam_decode'] params_cpu = paligemma.load(None, path, config) # Some numpy versions don't load bfloat16 correctly: params_cpu = jax.tree.map(_recover_bf16, params_cpu) return model, params_cpu, decode, beam_decode, tok def _shard_params(params_cpu): """Shards `params_cpu` with fsdp strategy on all available devices.""" params_sharding = sharding.infer_sharding( params_cpu, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh ) params = jax.tree.map(utils.reshard, params_cpu, params_sharding) return params def _pil2np(img): """Accepts `PIL.Image` or `np.ndarray` and returns `np.ndarray`.""" if isinstance(img, PIL.Image.Image): img = np.array(img) img = img[..., :3] if img.ndim == 2: img = img[..., None] if img.shape[-1] == 1: img = np.repeat(img, 3, axis=-1) return img def _prepare_batch( images, prefixes, *, res=224, tokenizer_spec='gemma(tokensets=("loc", "seg"))', suffixes=None, text_len=64, ): """Returns non-sharded batch.""" pp_fn = pp_builder.get_preprocess_fn('|'.join([ f'resize({res}, antialias=True)|value_range(-1, 1)', f"tok(key='prefix', bos='yes', model='{tokenizer_spec}')", f"tok(key='septok', text='\\n', model='{tokenizer_spec}')", f"tok(key='suffix', model='{tokenizer_spec}')", 'masked_concat(["prefix", "septok", "suffix"], mask_ar=[0, 0, 1], mask_input=[1, 1, 1])', # pylint: disable=line-too-long f'tolen({text_len}, pad_value=0, key="text")', f'tolen({text_len}, pad_value=1, key="mask_ar")', f'tolen({text_len}, pad_value=0, key="mask_input")', 'keep("image", "text", "mask_ar", "mask_input")', ]), log_data=False) assert not isinstance(prefixes, str), f'expected batch: {prefixes}' assert ( isinstance(images, (list, tuple)) or images.ndim == 4 ), f'expected batch: {images.shape}' if suffixes is None: suffixes = [''] * len(prefixes) assert len(prefixes) == len(suffixes) == len(images) examples = [{'_mask': True, **pp_fn({ 'image': np.asarray(_pil2np(image)), 'prefix': np.array(prefix), 'suffix': np.array(suffix), })} for image, prefix, suffix in zip(images, prefixes, suffixes)] batch = jax.tree_map(lambda *xs: np.stack(xs), *examples) return batch def 
def _shard_batch(batch, n=None):
  """Shards `batch` with fsdp strategy on all available devices."""
  if n is None:
    n = jax.local_device_count()

  def pad(x):
    # Pad the leading (batch) axis to a multiple of `n` so it shards evenly.
    return jnp.pad(x, [(0, -len(x) % n)] + [(0, 0)] * (x.ndim - 1))

  batch = {k: pad(v) for k, v in batch.items()}
  data_sharding = jax.sharding.NamedSharding(
      mesh, jax.sharding.PartitionSpec('data')
  )
  batch_on_device = utils.reshard(batch, data_sharding)
  return batch_on_device


@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class PaligemmaConfig:
  """Describes a `big_vision` PaliGemma model."""

  ckpt: str
  res: int
  text_len: int
  tokenizer: str
  vocab_size: int


@dataclasses.dataclass(frozen=True, kw_only=True)
class PaliGemmaModel:
  """Wraps a `big_vision` PaliGemma model."""

  config: PaligemmaConfig
  tokenizer: tokenizer.Tokenizer
  decode: Callable[..., Any]
  beam_decode: Callable[..., Any]

  @classmethod
  def shard_batch(cls, batch):
    return _shard_batch(batch)

  @classmethod
  def shard_params(cls, params_cpu):
    return _shard_params(params_cpu)

  def prepare_batch(self, images, texts, suffixes=None):
    return _prepare_batch(
        images=images,
        prefixes=texts,
        suffixes=suffixes,
        res=self.config.res,
        tokenizer_spec=self.config.tokenizer,
        text_len=self.config.text_len,
    )

  def predict(
      self,
      params,
      batch,
      devices=None,
      max_decode_len=128,
      sampler='greedy',
      **kw,
  ):
    """Returns tokens."""
    if devices is None:
      devices = jax.devices()
    if sampler == 'beam':
      decode = self.beam_decode
    else:
      decode = self.decode
      kw['sampler'] = sampler  # Only the non-beam decoder takes a sampler.
    return decode(
        {'params': params},
        batch=batch,
        devices=devices,
        eos_token=self.tokenizer.eos_token,
        max_decode_len=max_decode_len,
        **kw,
    )


ParamsCpu = Any


def load_model(config: PaligemmaConfig) -> tuple[PaliGemmaModel, ParamsCpu]:
  """Loads model from config."""
  model, params_cpu, decode, beam_decode, tok = _load(
      path=config.ckpt,
      tokenizer_spec=config.tokenizer,
      vocab_size=config.vocab_size,
  )
  del model
  return PaliGemmaModel(
      config=config,
      tokenizer=tok,
      decode=decode,
      beam_decode=beam_decode,
  ), params_cpu
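
# Minimal end-to-end sketch (hypothetical, not executed here): the checkpoint
# path and image file are placeholders, and `to_str` is assumed to be the
# big_vision tokenizer's detokenization method.
#
#   config = PaligemmaConfig(
#       ckpt='paligemma.npz',  # hypothetical checkpoint path
#       res=224,
#       text_len=64,
#       tokenizer='gemma(tokensets=("loc", "seg"))',
#       vocab_size=257_152,
#   )
#   model, params_cpu = load_model(config)
#   params = model.shard_params(params_cpu)
#   batch = model.shard_batch(model.prepare_batch(
#       images=[PIL.Image.open('example.jpg')], texts=['caption en']))
#   tokens = model.predict(params, batch, max_decode_len=64)
#   print(model.tokenizer.to_str(tokens[0]))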