import os
import sys
import base64
import functools
import html
import io
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece

from PIL import Image

# Colab runtimes with remote TPUs are not supported.
if "COLAB_TPU_ADDR" in os.environ:
    raise RuntimeError("It seems you are using Colab with remote TPUs, which is not supported.")

# Append big_vision code to the python import path.
if "big_vision_repo" not in sys.path:
    sys.path.append("big_vision_repo")

# Import the model definition from big_vision.
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big_vision utilities.
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
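
# (TF is still used on CPU for input preprocessing below, e.g. tf.image.resize;
# hiding the accelerators from it leaves JAX as the sole owner of device memory.)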

backend = jax.lib.xla_bridge.get_backend()

model_path = './Sofa-attributes-paligemma-ckpt.npz'
tokenizer_path = './paligemma_tokenizer.model'

# Define the model.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"},
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, model_path, model_config)

# Define the `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())
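
# Usage sketch (for reference; the actual call is in `make_predictions` below):
# `decode` takes the params plus a batch dict with "image", "text", "mask_ar"
# and "mask_input" arrays, e.g.
#   tokens = decode({"params": params}, batch=batch,
#                   max_decode_len=128, sampler="greedy")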

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
    if name.startswith("llm/layers/attn/"): return True
    if name.startswith("llm/"): return False
    if name.startswith("img/"): return False
    raise ValueError(f"Unexpected param name {name}")

trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
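
# Optional sanity check (a sketch, not part of the original flow): list a few
# of the parameter names the mask marks as trainable.
_trainable_names = [name for name, m in
                    big_vision.utils.tree_flatten_with_names(trainable_mask)[0]
                    if m]
print("Trainable params (first 5):", _trainable_names[:5])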

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
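
# Optional check (sketch): `infer_sharding` returns a pytree of NamedSharding
# objects mirroring `params`; print a few entries to see how arrays are split.
for _name, _sh in big_vision.utils.tree_flatten_with_names(params_sharding)[0][:3]:
    print(f"{_name}: {_sh}")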

# Silence the harmless "donated buffers" warning raised while moving params.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

def maybe_cast_to_f32(params, trainable):
    return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,
                        params, trainable)

# Loading all params at once, although much faster and more succinct, requires
# more RAM than the T4 colab runtimes have by default (12GB). Instead we
# reshard and cast them param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    params[idx] = big_vision.utils.reshard(params[idx], sharding)
    params[idx] = maybe_cast_to_f32(params[idx], trainable)
    params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
    for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
        print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)


def setup_and_predict(image_path):
    # Preprocess image and tokens.
    def preprocess_image(image, size=224):
        # The model has been trained to handle images of different aspect
        # ratios, resized to 224x224 in the range [-1, 1]. Bilinear and
        # antialias resize options help quality in some tasks.
        image = np.asarray(image)
        if image.ndim == 2:  # Stack a greyscale image into 3 identical channels.
            image = np.stack((image,) * 3, axis=-1)
        image = image[..., :3]  # Remove the alpha layer.
        assert image.shape[-1] == 3

        image = tf.constant(image)
        image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
        return image.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]

    def preprocess_tokens(prefix, suffix=None, seqlen=None):
        # The model has been trained to handle tokenized text composed of a
        # prefix with full attention and a suffix with causal attention.
        separator = "\n"
        tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
        mask_ar = [0] * len(tokens)    # 0 to use full attention for the prefix.
        mask_loss = [0] * len(tokens)  # 0 to exclude prefix tokens from the loss.

        if suffix:
            suffix = tokenizer.encode(suffix, add_eos=True)
            tokens += suffix
            mask_ar += [1] * len(suffix)    # 1 to use causal attention for the suffix.
            mask_loss += [1] * len(suffix)  # 1 to include suffix tokens in the loss.

        mask_input = [1] * len(tokens)  # 1 if it's a real token, 0 if padding.
        if seqlen:
            padding = [0] * max(0, seqlen - len(tokens))
            tokens = tokens[:seqlen] + padding
            mask_ar = mask_ar[:seqlen] + padding
            mask_loss = mask_loss[:seqlen] + padding
            mask_input = mask_input[:seqlen] + padding

        return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
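
    # Example (sketch): preprocess_tokens("caption en", seqlen=8) returns four
    # length-8 int arrays; positions past the prefix + "\n" separator are
    # zero-padded, and mask_input marks which positions hold real tokens.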

    def postprocess_tokens(tokens):
        tokens = tokens.tolist()  # np.array to list[int].
        try:  # Remove tokens at and after EOS, if any.
            eos_pos = tokens.index(tokenizer.eos_id())
            tokens = tokens[:eos_pos]
        except ValueError:
            pass
        return tokenizer.decode(tokens)

    # Evaluation/inference loop.
    SEQLEN = 128

    def make_predictions(data_iterator, *, num_examples=None,
                         batch_size=4, seqlen=SEQLEN, sampler="greedy"):
        outputs = []
        while True:
            # Construct a list of examples in the batch.
            examples = []
            try:
                for _ in range(batch_size):
                    examples.append(next(data_iterator))
                    examples[-1]["_mask"] = np.array(True)  # Indicates a true example.
            except StopIteration:
                if len(examples) == 0:
                    return outputs

            # Not enough examples to complete a batch. Pad by repeating the last example.
            while len(examples) % batch_size:
                examples.append(dict(examples[-1]))
                examples[-1]["_mask"] = np.array(False)  # Indicates a padding example.

            # Convert the list of examples into a dict of np.arrays and load onto devices.
            batch = jax.tree.map(lambda *x: np.stack(x), *examples)
            batch = big_vision.utils.reshard(batch, data_sharding)

            # Make model predictions.
            tokens = decode({"params": params}, batch=batch,
                            max_decode_len=seqlen, sampler=sampler)

            # Fetch model predictions back to the host and detokenize.
            tokens, mask = jax.device_get((tokens, batch["_mask"]))
            tokens = tokens[mask]  # Remove padding examples.
            responses = [postprocess_tokens(t) for t in tokens]

            # Collect (image, response) pairs.
            for example, response in zip(examples, responses):
                outputs.append((example["image"], response))
                if num_examples and len(outputs) >= num_examples:
                    return outputs
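
    # Note on padding: with an iterator yielding, say, 6 examples at
    # batch_size=4, the second batch is filled out by repeating the last
    # example with "_mask" False; those rows are dropped after decoding.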

    def test_data_iterator(file_name):
        image = Image.open(file_name)
        image = preprocess_image(image)
        prefix = "caption en"
        tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
        yield {
            "image": np.asarray(image),
            "text": np.asarray(tokens),
            "mask_ar": np.asarray(mask_ar),
            "mask_input": np.asarray(mask_input),
        }

    # Run the prediction and return the caption.
    image, caption = make_predictions(test_data_iterator(file_name=image_path), batch_size=1)[0]
    return caption
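

# Example usage (sketch): 'sofa.jpg' is a hypothetical path; any local RGB
# image file works.
if __name__ == "__main__":
    print(setup_and_predict("sofa.jpg"))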