Spaces:
Configuration error
Configuration error
File size: 8,160 Bytes
01259cd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
import os
import sys
import base64
import functools
import html
import io
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from PIL import Image
# Remote TPUs on Colab are not supported: fail fast with a clear error.
if "COLAB_TPU_ADDR" in os.environ:
    # Raising a bare string is itself a TypeError in Python 3 ("exceptions must
    # derive from BaseException"), which would hide this message; raise a real
    # exception type instead.
    raise RuntimeError(
        "It seems you are using Colab with remote TPUs which is not supported.")
# The `big_vision` package is imported from the `big_vision_repo` directory, so
# put that directory on the import path. The membership test keeps a repeated
# execution of this code from appending a duplicate entry.
if not any(entry == "big_vision_repo" for entry in sys.path):
    sys.path.append("big_vision_repo")
# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Hide GPUs and TPUs from TensorFlow so that only JAX uses the accelerators.
for _device_type in ("GPU", "TPU"):
    tf.config.set_visible_devices([], _device_type)
del _device_type

# NOTE(review): `backend` is not referenced anywhere else in this file, and
# `jax.lib.xla_bridge` is deprecated in recent JAX releases — confirm whether
# this line can be dropped or moved to a supported API.
backend = jax.lib.xla_bridge.get_backend()

# Checkpoint and tokenizer files, given relative to the working directory.
model_path = "./Sofa-attributes-paligemma-ckpt.npz"
tokenizer_path = "./paligemma_tokenizer.model"
# Model definition: PaliGemma with an `llm` part (257,152-token vocabulary) and
# an `img` part (So400m/14 variant, no pooling, scanned layers, float16 matmuls).
model_config = ml_collections.FrozenConfigDict(
    dict(
        llm=dict(vocab_size=257152),
        img=dict(
            variant="So400m/14",
            pool_type="none",
            scan=True,
            dtype_mm="float16",
        ),
    )
)
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)

# Restore the checkpoint weights from `model_path`. This can take up to a
# minute on a T4 Colab runtime.
params = paligemma.load(None, model_path, model_config)

# `decode` samples output tokens from the model. The device list and the
# tokenizer's EOS id are bound up front, so callers pass only the params, the
# batch and any decoding options.
decode_fn = predict_fns.get_all(model)["decode"]
decode = functools.partial(
    decode_fn,
    eos_token=tokenizer.eos_id(),
    devices=jax.devices(),
)
# Predicate used to build a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
    """Return whether the parameter called `name` should be fine-tuned.

    Only the language model's attention parameters (`llm/layers/attn/...`) are
    trainable. Every other `llm/` parameter and every `img/` parameter stays
    frozen. The `param` value itself is ignored.

    Raises:
      ValueError: if `name` is under neither `llm/` nor `img/`.
    """
    if name.startswith("img/"):
        return False
    if not name.startswith("llm/"):
        raise ValueError(f"Unexpected param name {name}")
    return name.startswith("llm/layers/attn/")
# Per-leaf booleans (same tree as `params`) marking which params are fine-tuned.
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# One-dimensional mesh over every visible device, with a single axis "data".
# With more than one device (e.g. several GPUs) this lets params be sharded
# across them, reducing memory use per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))
# Input batches are split along their leading dimension over the "data" axis.
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("data"))
# Every param (regex ".*") uses the 'fsdp(axis="data")' strategy.
params_sharding = big_vision.sharding.infer_sharding(
    params,
    mesh=mesh,
    strategy=[(".*", 'fsdp(axis="data")')],
)
# The function below donates its input buffer. When JAX cannot reuse a donated
# buffer (presumably e.g. when the cast changes its size) it emits this
# warning, which is noise here, so silence it.
warnings.filterwarnings(action="ignore", message="Some donated buffers were not usable")


@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
    """Cast the leaves of `params` to float32 if `trainable`, else return them as-is.

    `params` (argument 0) is donated, so callers must not reuse it afterwards.
    `trainable` is a static argument, so it must be hashable (e.g. a bool).
    """
    def _upcast_if_trainable(leaf, is_trainable):
        if not is_trainable:
            return leaf
        return leaf.astype(jnp.float32)

    return jax.tree.map(_upcast_if_trainable, params, trainable)
# Loading all params simultaneously - although much faster and more succinct -
# requires more RAM than T4 Colab runtimes have by default (12 GB RAM).
# Instead we reshard and cast the params one leaf at a time.
# Flatten to a list of leaves so each entry can be replaced in place.
params, treedef = jax.tree.flatten(params)
# NOTE(review): these leaf lists are assumed to line up one-to-one with the
# params leaves (both trees were derived from `params` above) — confirm that
# infer_sharding preserves the tree structure.
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
    # Overwriting params[idx] drops the reference to the previous version of
    # the leaf, presumably letting it be freed before the next leaf is loaded.
    params[idx] = big_vision.utils.reshard(params[idx], sharding)
    params[idx] = maybe_cast_to_f32(params[idx], trainable)
    # Wait for the (asynchronous) reshard + cast to finish before moving on,
    # so leaves are processed strictly one after another.
    params[idx].block_until_ready()
# Rebuild the original tree structure from the processed leaves.
params = jax.tree.unflatten(treedef, params)
# Show what the model is made of: one row per parameter.
def parameter_overview(params):
    """Print each parameter's name, shape and dtype as fixed-width columns."""
    named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
    for leaf_name, leaf in named_leaves:
        shape_text = str(leaf.shape)
        print(f"{leaf_name:80s} {shape_text:22s} {leaf.dtype}")


print(" == Model params == ")
parameter_overview(params)
# Prompt length in tokens (after padding/truncation); also the maximum number
# of tokens decoded per example.
SEQLEN = 128


def _preprocess_image(image, size=224):
    """Resize an image to (size, size, 3) and map [0, 255] pixels to [-1, 1].

    The model was trained on images of various aspect ratios resized to
    224x224 with values in [-1, 1]. Bilinear resizing with antialiasing helps
    quality on some tasks.

    Args:
      image: a PIL image or anything `np.asarray` accepts, shaped (H, W) or
        (H, W, C). 2-D (greyscale) input is replicated into three channels, and
        channels past the third (e.g. alpha) are dropped.
      size: output height and width in pixels.

    Returns:
      A numpy array of shape (size, size, 3).

    Raises:
      ValueError: if the image does not end up with exactly three channels.
    """
    image = np.asarray(image)
    if image.ndim == 2:  # Greyscale: replicate into three identical channels.
        image = np.stack((image,) * 3, axis=-1)
    image = image[..., :3]  # Drop an alpha channel if present.
    if image.shape[-1] != 3:
        raise ValueError(f"Expected an RGB image, got shape {image.shape}")
    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method="bilinear", antialias=True)
    return image.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]


def _preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Tokenize a prefix (and optional suffix) into model inputs.

    The prefix plus a "\\n" separator gets full attention and is excluded from
    the loss. The suffix, if given, gets causal attention and counts towards
    the loss. With `seqlen`, every output is truncated or zero-padded to that
    length.

    Returns:
      (tokens, mask_ar, mask_loss, mask_input) as 1-D numpy arrays.
    """
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)    # 0: full attention for the prefix.
    mask_loss = [0] * len(tokens)  # 0: prefix tokens are not in the loss.
    if suffix:
        suffix = tokenizer.encode(suffix, add_eos=True)
        tokens += suffix
        mask_ar += [1] * len(suffix)    # 1: causal attention for the suffix.
        mask_loss += [1] * len(suffix)  # 1: suffix tokens are in the loss.
    mask_input = [1] * len(tokens)  # 1 for a real token, 0 for padding.
    if seqlen:
        padding = [0] * max(0, seqlen - len(tokens))
        tokens = tokens[:seqlen] + padding
        mask_ar = mask_ar[:seqlen] + padding
        mask_loss = mask_loss[:seqlen] + padding
        mask_input = mask_input[:seqlen] + padding
    # Convert each list as a whole. jax.tree.map would treat the lists as pytree
    # nodes and wrap every int in its own 0-d array instead.
    return tuple(np.asarray(x) for x in (tokens, mask_ar, mask_loss, mask_input))


def _postprocess_tokens(tokens):
    """Detokenize one row of generated ids, cutting at the first EOS if any."""
    tokens = tokens.tolist()  # np.ndarray -> list[int]
    try:
        tokens = tokens[:tokens.index(tokenizer.eos_id())]
    except ValueError:  # No EOS generated: keep every token.
        pass
    return tokenizer.decode(tokens)


def _make_predictions(data_iterator, *, num_examples=None,
                      batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    """Decode a model response for every example yielded by `data_iterator`.

    Examples are consumed in batches of `batch_size`. A final partial batch is
    padded by repeating its last example. Padded rows carry `_mask=False` and
    are dropped after decoding.

    Args:
      data_iterator: iterator of example dicts. A `_mask` entry is added to
        each one.
      num_examples: stop once this many outputs exist (None or 0: no limit).
      batch_size: examples per decode call.
      seqlen: forwarded to `decode` as `max_decode_len`.
      sampler: forwarded to `decode` as `sampler`.

    Returns:
      A list of (image, response_text) tuples in input order, with at most
      `num_examples` entries when a limit is given.
    """
    outputs = []
    while True:
        # Gather up to `batch_size` real examples.
        examples = []
        try:
            for _ in range(batch_size):
                examples.append(next(data_iterator))
                examples[-1]["_mask"] = np.array(True)  # Real example.
        except StopIteration:
            if not examples:
                return outputs
        # Not enough examples to fill the batch: pad by repeating the last one.
        while len(examples) % batch_size:
            examples.append(dict(examples[-1]))
            examples[-1]["_mask"] = np.array(False)  # Padding example.

        # Stack the list of example dicts into one dict of arrays on the devices.
        batch = jax.tree.map(lambda *x: np.stack(x), *examples)
        batch = big_vision.utils.reshard(batch, data_sharding)

        tokens = decode({"params": params}, batch=batch,
                        max_decode_len=seqlen, sampler=sampler)

        # Copy the predictions and the padding mask from the devices to the host.
        tokens, mask = jax.device_get((tokens, batch["_mask"]))
        tokens = tokens[mask]  # Drop padding rows.
        responses = [_postprocess_tokens(t) for t in tokens]

        # Padding rows come last, so zip() pairs every real example with its
        # response and ignores the padding.
        outputs.extend(
            (example["image"], response)
            for example, response in zip(examples, responses))
        if num_examples and len(outputs) >= num_examples:
            # Whole batches are appended at once, so trim any overshoot.
            return outputs[:num_examples]


def _single_example_iterator(file_name, prompt="caption en"):
    """Yield one model-ready example built from the image at `file_name`."""
    # Decode inside `with` so the file handle is closed. convert("RGB") turns
    # every PIL mode (palette, LA, CMYK, 1-bit, ...) into three RGB channels;
    # for RGB/RGBA/L input it matches what _preprocess_image would do anyway.
    with Image.open(file_name) as img:
        image = _preprocess_image(img.convert("RGB"))
    tokens, mask_ar, _, mask_input = _preprocess_tokens(prompt, seqlen=SEQLEN)
    yield {
        "image": image,
        "text": tokens,
        "mask_ar": mask_ar,
        "mask_input": mask_input,
    }


def setup_and_predict(image_path, prompt="caption en"):
    """Run the loaded PaliGemma model on one image and return its text output.

    Args:
      image_path: path to an image file readable by PIL.
      prompt: text prefix fed to the model (defaults to "caption en").

    Returns:
      The decoded model response for the image.
    """
    # Batches are resharded with `data_sharding`, which splits the leading axis
    # across every device in the mesh. JAX needs that axis to divide evenly, so
    # use one row per device; extra rows are padding and are discarded.
    batch_size = len(jax.devices())
    results = _make_predictions(
        _single_example_iterator(image_path, prompt), batch_size=batch_size)
    _, caption = results[0]
    return caption
|