# vision/utils.py
# Uploaded by Apoorva96nbd: "added initial files" (commit 01259cd, verified).
import os
import sys
import base64
import functools
import html
import io
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from PIL import Image
# Colab runtimes with remote TPUs (signalled by COLAB_TPU_ADDR) are not supported.
# Raise a real exception: raising a plain string is itself a TypeError in Python 3.
if "COLAB_TPU_ADDR" in os.environ:
    raise RuntimeError(
        "It seems you are using Colab with remote TPUs which is not supported.")
# Append big_vision code to python import path (relative to the working directory).
if "big_vision_repo" not in sys.path:
    sys.path.append("big_vision_repo")
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs. TF is only used for image preprocessing;
# presumably this keeps accelerator memory free for JAX.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
# NOTE(review): `backend` is not referenced in the visible code, and
# `jax.lib.xla_bridge.get_backend` may be deprecated in newer JAX releases;
# confirm whether this line is still needed.
backend = jax.lib.xla_bridge.get_backend()
# Relative paths: resolved against the current working directory.
model_path = './Sofa-attributes-paligemma-ckpt.npz'
tokenizer_path = './paligemma_tokenizer.model'
# Define model: an "llm" part with a 257,152-entry vocabulary and an "img"
# part using the So400m/14 variant (no pooling, scanned layers, float16 dtype_mm).
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)
# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, model_path, model_config)
# Define `decode` function to sample outputs from the model. It is bound to all
# visible JAX devices and receives the tokenizer's EOS id as `eos_token`.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
    """Decide whether the parameter at path `name` should be trained.

    Only the LLM attention weights are trainable; every other LLM parameter
    and every image-encoder parameter stays frozen.

    Args:
        name: Slash-separated parameter path, e.g. "llm/layers/attn/...".
        param: The parameter value (unused; present for tree_map_with_names).

    Returns:
        True for trainable parameters, False for frozen ones.

    Raises:
        ValueError: If `name` is under neither "llm/" nor "img/".
    """
    # Rules are checked in order: the specific attention prefix must win over
    # the generic "llm/" prefix.
    prefix_rules = (
        ("llm/layers/attn/", True),
        ("llm/", False),
        ("img/", False),
    )
    for prefix, trainable in prefix_rules:
        if name.startswith(prefix):
            return trainable
    raise ValueError(f"Unexpected param name {name}")
# Pytree of booleans, parallel to `params`: True where the param is trainable.
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
# Axis names must be a sequence: `("data",)` is a 1-tuple, whereas `("data")`
# would just be the string "data".
mesh = jax.sharding.Mesh(jax.devices(), ("data",))
# Batches are split along their leading dimension over the "data" axis.
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))
# Every parameter uses FSDP-style sharding over the same "data" axis.
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Expected and harmless: JAX may warn that buffers donated to the jitted
# cast below could not be reused, so silence that specific warning.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")
@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
    """Upcast trainable leaves of `params` to float32; leave the rest as-is.

    `params` is donated, so its input buffers may be reused for the output.
    `trainable` is a static (hashable) argument with the same tree structure
    as `params`, e.g. a single bool when `params` is a single array.
    """
    def _cast_leaf(leaf, is_trainable):
        if not is_trainable:
            return leaf
        return leaf.astype(jnp.float32)

    return jax.tree.map(_cast_leaf, params, trainable)
# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param: each leaf is resharded onto the devices,
# cast to float32 if trainable, and waited on before the next one starts.
params, treedef = jax.tree.flatten(params)
# NOTE(review): assumes the leaf order of params_sharding / trainable_mask
# matches the flattened params (all derived from the same tree) - confirm.
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    params[idx] = big_vision.utils.reshard(params[idx], sharding)
    params[idx] = maybe_cast_to_f32(params[idx], trainable)
    # Block before moving on, presumably so work for many params is not
    # queued at once (keeping peak memory low, per the note above).
    params[idx].block_until_ready()
# Rebuild the original pytree structure from the processed leaves.
params = jax.tree.unflatten(treedef, params)
# Print params to show what the model is made of.
def parameter_overview(params):
    """Print a table with one row per parameter leaf: path, shape, dtype."""
    named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
    for leaf_name, leaf in named_leaves:
        shape_text = str(leaf.shape)
        print(f"{leaf_name:80s} {shape_text:22s} {leaf.dtype}")


print(" == Model params == ")
parameter_overview(params)
def setup_and_predict(image_path):
    """Caption one image file with the loaded PaliGemma checkpoint.

    The image is preprocessed to 224x224 with values in [-1, 1], paired with
    the fixed prompt "caption en", and decoded greedily (up to ``SEQLEN``
    tokens) using the module-level ``decode`` function and ``params``.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The decoded model response as a string, cut at the first EOS token.
    """

    # Preprocess image and tokens
    def preprocess_image(image, size=224):
        """Return `image` resized to (size, size, 3) with values in [-1, 1].

        Assumes 8-bit pixel values in [0, 255].
        """
        # Model has been trained to handle images of different aspect ratios
        # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
        # options are helpful to improve quality in some tasks.
        image = np.asarray(image)
        if image.ndim == 2:  # Greyscale without a channel axis: replicate to RGB.
            image = np.stack((image,) * 3, axis=-1)
        image = image[..., :3]  # Remove alpha layer.
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if image.shape[-1] != 3:
            raise ValueError(
                f"Expected an image with 3 colour channels, got shape {image.shape}")
        image = tf.constant(image)
        image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
        return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

    def preprocess_tokens(prefix, suffix=None, seqlen=None):
        """Tokenize prefix (+ optional suffix) and build the per-token masks.

        Returns a tuple of numpy arrays (tokens, mask_ar, mask_loss,
        mask_input), each truncated / zero-padded to `seqlen` when given.
        """
        # Model has been trained to handle tokenized text composed of a prefix with
        # full attention and a suffix with causal attention.
        separator = "\n"
        tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
        mask_ar = [0] * len(tokens)  # 0 to use full attention for prefix.
        mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.
        if suffix:
            suffix = tokenizer.encode(suffix, add_eos=True)
            tokens += suffix
            mask_ar += [1] * len(suffix)  # 1 to use causal attention for suffix.
            mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.
        mask_input = [1] * len(tokens)  # 1 if its a token, 0 if padding.
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding
        # Convert each list into ONE array. (jax.tree.map(np.array, ...) would
        # recurse into the lists and return lists of 0-d arrays instead.)
        return tuple(np.array(seq) for seq in (tokens, mask_ar, mask_loss, mask_input))

    def postprocess_tokens(tokens):
        """Detokenize a 1-D token array, dropping EOS and everything after it."""
        tokens = tokens.tolist()  # np.array to list[int]
        try:  # Remove tokens at and after EOS if any.
            eos_pos = tokens.index(tokenizer.eos_id())
            tokens = tokens[:eos_pos]
        except ValueError:
            pass
        return tokenizer.decode(tokens)

    # Make predictions
    # Evaluation/inference loop.
    SEQLEN = 128

    def make_predictions(data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        """Run `decode` over `data_iterator`; return [(image, response), ...].

        Incomplete final batches are padded with masked copies of the last
        example, and those padding rows are dropped from the result. Stops
        early once `num_examples` results have been collected (if given).
        """
        outputs = []
        while True:
            # Construct a list of examples in the batch.
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)  # Indicates true example.
            except StopIteration:
                if len(examples) == 0:
                    return outputs

            # Not enough examples to complete a batch. Pad by repeating last example.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

            # Convert list of examples into a dict of np.arrays and load onto devices.
            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, data_sharding)

            # Make model predictions
            tokens = decode({"params": params}, batch=batch,
                            max_decode_len=seqlen, sampler=sampler)

            # Fetch model predictions from the devices to host memory and detokenize.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]  # remove padding examples.
            responses = [postprocess_tokens(t) for t in tokens]

            # Collect (image, response) pairs; padding rows were already dropped
            # and come last, so zip() pairs each real example with its response.
            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs

    def test_data_iterator(file_name):
        """Yield the single model-ready example built from `file_name`."""
        # The context manager closes the file handle. convert("RGB") normalises
        # PIL modes such as palette ("P"), greyscale, RGBA and CMYK to 3-channel
        # 8-bit RGB; otherwise e.g. palette indices would be read as grey levels.
        with Image.open(file_name) as img:
            image = preprocess_image(img.convert("RGB"))
        prefix = "caption en"
        tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input)
        }

    # The batch is sharded over the "data" mesh axis, which spans all devices,
    # so its leading dimension must be a multiple of the device count.
    # make_predictions pads with masked copies, so a batch of one real example
    # padded to `len(jax.devices())` works on single- and multi-device hosts.
    # Call the prediction function and return the result.
    _, caption = make_predictions(
        test_data_iterator(file_name=image_path),
        batch_size=len(jax.devices()))[0]
    return caption