import os
import sys
import base64
import functools
import html
import io
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from PIL import Image
# Colab runtimes with remote TPUs are not supported.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError("It seems you are using Colab with remote TPUs, which is not supported.")
# Append big_vision code to the python import path.
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")
# Import model definition from big_vision.
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big_vision utilities.
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.lib.xla_bridge.get_backend()
model_path = './Sofa-attributes-paligemma-ckpt.npz'
tokenizer_path = './paligemma_tokenizer.model'

# Define model.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, model_path, model_config)
# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"): return True
  if name.startswith("llm/"): return False
  if name.startswith("img/"): return False
  raise ValueError(f"Unexpected param name {name}")

trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
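
# Optional sanity check (an illustrative addition, not from the original
# script): count how many parameter tensors the mask marks as trainable.
mask_leaves = jax.tree.leaves(trainable_mask)
print(f"Trainable tensors: {sum(mask_leaves)} / {len(mask_leaves)}")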
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)
# Suppress a benign warning: some donated buffers cannot be reused.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")
def maybe_cast_to_f32(params, trainable):
  return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,
                      params, trainable)
# Loading all params at once - although much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB).
# Instead we load them param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)
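
# Optional check (illustrative addition): every leaf is now a jax.Array whose
# .sharding attribute reports how it is placed across the mesh.
print("Example leaf sharding:", jax.tree.leaves(params)[0].sharding)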
# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)
def setup_and_predict(image_path):
  # Preprocess image and tokens.
  def preprocess_image(image, size=224):
    # The model has been trained to handle images of different aspect ratios
    # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
    # options are helpful to improve quality in some tasks.
    image = np.asarray(image)
    if image.ndim == 2:  # Convert a greyscale image to 3 channels.
      image = np.stack((image,)*3, axis=-1)
    image = image[..., :3]  # Remove alpha layer.
    assert image.shape[-1] == 3
    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
    return image.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]
  def preprocess_tokens(prefix, suffix=None, seqlen=None):
    # The model has been trained to handle tokenized text composed of a prefix
    # with full attention and a suffix with causal attention.
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
    mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.
    if suffix:
      suffix = tokenizer.encode(suffix, add_eos=True)
      tokens += suffix
      mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
      mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.
    mask_input = [1] * len(tokens)    # 1 if it's a real token, 0 if padding.
    if seqlen:
      padding = [0] * max(0, seqlen - len(tokens))
      tokens = tokens[:seqlen] + padding
      mask_ar = mask_ar[:seqlen] + padding
      mask_loss = mask_loss[:seqlen] + padding
      mask_input = mask_input[:seqlen] + padding
    return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))
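
  # Illustrative example (token ids are made up; real ids depend on the
  # tokenizer): preprocess_tokens("caption en", seqlen=8) yields arrays like
  #   tokens     = [bos, t1, t2, sep, 0, 0, 0, 0]
  #   mask_ar    = [0,   0,  0,  0,   0, 0, 0, 0]   # prefix: full attention
  #   mask_input = [1,   1,  1,  1,   0, 0, 0, 0]   # 1 = real token, 0 = pad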
  def postprocess_tokens(tokens):
    tokens = tokens.tolist()  # np.array to list[int]
    try:  # Remove tokens at and after EOS if any.
      eos_pos = tokens.index(tokenizer.eos_id())
      tokens = tokens[:eos_pos]
    except ValueError:
      pass
    return tokenizer.decode(tokens)
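
  # Example (hypothetical ids): if the model emits [t1, t2, EOS, t3],
  # postprocess_tokens truncates at EOS and decodes [t1, t2] back to text.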
  # Evaluation/inference loop.
  SEQLEN = 128

  def make_predictions(data_iterator, *, num_examples=None,
                       batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    outputs = []
    while True:
      # Construct a list of examples in the batch.
      examples = []
      try:
        for _ in range(batch_size):
          examples.append(next(data_iterator))
          examples[-1]["_mask"] = np.array(True)  # Indicates a real example.
      except StopIteration:
        if len(examples) == 0:
          return outputs

      # Not enough examples to complete a batch. Pad by repeating last example.
      while len(examples) % batch_size:
        examples.append(dict(examples[-1]))
        examples[-1]["_mask"] = np.array(False)  # Indicates a padding example.

      # Convert list of examples into a dict of np.arrays and load onto devices.
      batch = jax.tree.map(lambda *x: np.stack(x), *examples)
      batch = big_vision.utils.reshard(batch, data_sharding)

      # Make model predictions.
      tokens = decode({"params": params}, batch=batch,
                      max_decode_len=seqlen, sampler=sampler)

      # Fetch model predictions to the host and detokenize.
      tokens, mask = jax.device_get((tokens, batch["_mask"]))
      tokens = tokens[mask]  # Remove padding examples.
      responses = [postprocess_tokens(t) for t in tokens]

      # Collect (image, response) pairs.
      for example, response in zip(examples, responses):
        outputs.append((example["image"], response))
        if num_examples and len(outputs) >= num_examples:
          return outputs
  def test_data_iterator(file_name):
    image = Image.open(file_name)
    image = preprocess_image(image)
    prefix = "caption en"
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)
    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }
  # Call the prediction function and return the caption.
  image, caption = make_predictions(test_data_iterator(file_name=image_path), batch_size=1)[0]
  return caption
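
# Example usage (this image path is hypothetical; point it at any local image):
if __name__ == "__main__":
  print(setup_and_predict("./sofa_example.jpg"))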