File size: 8,160 Bytes
01259cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
import os
import sys
import base64
import functools
import html
import io
import warnings
import jax
import jax.numpy as jnp
import numpy as np
import ml_collections
import tensorflow as tf
import sentencepiece
from PIL import Image

# Colab runtimes that expose remote TPUs set COLAB_TPU_ADDR; this script does
# not support them, so fail fast with a clear message.
if "COLAB_TPU_ADDR" in os.environ:
  # Raising a bare string is itself a TypeError in Python 3 (exceptions must
  # derive from BaseException); raise a real exception so the message shows.
  raise RuntimeError(
      "It seems you are using Colab with remote TPUs which is not supported.")
# Append the `big_vision_repo` directory (resolved relative to the current
# working directory) to the import path so the `big_vision.*` imports work.
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
# Don't let TF use the GPU or TPUs. TF is only used below for image resizing
# (tf.image.resize); hiding accelerators from it presumably keeps TF from
# claiming device memory that JAX needs.
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")
# NOTE(review): `backend` is not referenced anywhere else in this file, and
# `jax.lib.xla_bridge` may be deprecated in newer JAX releases — verify and
# consider removing this line.
backend = jax.lib.xla_bridge.get_backend()

# Fine-tuned checkpoint (.npz) and SentencePiece tokenizer model, both resolved
# relative to the current working directory.
model_path = './Sofa-attributes-paligemma-ckpt.npz'
tokenizer_path = './paligemma_tokenizer.model'

# Define model: LLM and image-encoder settings passed as keyword config to
# paligemma.Model (and reused by paligemma.load below).
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(tokenizer_path)
# Load params - this can take up to 1 minute in T4 colabs.
# NOTE(review): the leading `None` is presumably "no initial params to merge
# into" — confirm against big_vision's paligemma.load signature.
params = paligemma.load(None, model_path, model_config)
# Define `decode` function to sample outputs from the model.
# `decode` is `decode_fn` with the device list and the tokenizer's EOS id
# pre-bound; the remaining arguments (params, batch, ...) are supplied per call.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

# Create a pytree mask of the trainable params.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  """Decide whether the parameter at path `name` should be fine-tuned.

  Rules are checked in order and the first matching prefix wins: LLM attention
  weights are trainable, every other LLM weight and all image-encoder weights
  are frozen.

  Args:
    name: '/'-separated parameter path.
    param: The parameter value; unused, but part of the signature expected by
      the name-aware tree map that calls this function.

  Returns:
    True for trainable parameters, False for frozen ones.

  Raises:
    ValueError: if `name` matches none of the known prefixes.
  """
  ordered_rules = (
      ("llm/layers/attn/", True),
      ("llm/", False),
      ("img/", False),
  )
  for prefix, trainable in ordered_rules:
    if name.startswith(prefix):
      return trainable
  raise ValueError(f"Unexpected param name {name}")
# Boolean pytree mirroring `params`: True for leaves that will be fine-tuned.
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
# Note the trailing comma: `("data")` is just the string "data", whereas the
# Mesh API takes a sequence of axis names — here a single axis called "data".
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

# Batches are split along their leading dimension over the "data" mesh axis.
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

# Every parameter (regex '.*') is FSDP-sharded over the same "data" axis.
params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# maybe_cast_to_f32 (below) donates its input buffers. JAX warns when a donated
# buffer cannot be reused (e.g. on backends without donation support); that is
# harmless here, so silence the warning.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  """Upcast `params` to float32 where `trainable` is set; pass others through.

  Args:
    params: An array (or pytree of arrays); its buffers are donated to jit.
    trainable: A bool, or a hashable pytree of bools (e.g. a tuple) with the
      same structure as `params`. It is a static jit argument, so each distinct
      value triggers its own compilation.

  Returns:
    `params` with every leaf whose flag is True cast to float32.
  """
  def _upcast_leaf(leaf, is_trainable):
    if is_trainable:
      return leaf.astype(jnp.float32)
    return leaf

  return jax.tree.map(_upcast_leaf, params, trainable)

# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param.
# `params` is deliberately rebound to the flat list of leaves: assuming nothing
# else still references the original tree, each original leaf can then be freed
# as soon as its resharded/cast replacement is written back into the list.
params, treedef = jax.tree.flatten(params)
# NOTE(review): assumes both trees flatten to leaves in the same order as
# `params` (same structure); zip() would silently truncate on a mismatch.
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  # Place this leaf on its target sharding, then upcast it if trainable.
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  # Block on each leaf so the work is done strictly one param at a time
  # (presumably to bound peak memory) instead of being queued asynchronously.
  params[idx].block_until_ready()
# Restore the original tree structure from the processed leaves.
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  """Print one aligned row per parameter leaf: path, shape and dtype."""
  named_leaves = big_vision.utils.tree_flatten_with_names(params)[0]
  for name, leaf in named_leaves:
    shape_text = str(leaf.shape)
    print(f"{name:80s} {shape_text:22s} {leaf.dtype}")

print(" == Model params == ")
parameter_overview(params)
def setup_and_predict(image_path):
  """Run the loaded PaliGemma model on a single image and return its text.

  Relies on the module-level `tokenizer`, `params`, `decode` and
  `data_sharding` objects. The image is resized to 224x224 with values in
  [-1, 1], prompted with "caption en", and decoded greedily for at most 128
  tokens.

  Args:
    image_path: Anything `PIL.Image.open` accepts (a path or a file object).

  Returns:
    The model's decoded text response for the image.

  Raises:
    ValueError: if the decoded image does not end up with 3 colour channels.
  """

  def preprocess_image(image, size=224):
    """Return `image` resized to `size`x`size`, as floats in [-1, 1]."""
    # Model has been trained to handle images of different aspect ratios
    # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
    # options are helpful to improve quality in some tasks.
    image = np.asarray(image)
    if image.ndim == 2:  # Greyscale without a channel axis: replicate to 3.
      image = np.stack((image,) * 3, axis=-1)
    image = image[..., :3]  # Remove alpha layer.
    if image.shape[-1] != 3:
      # Explicit error rather than `assert`, which is stripped under `python -O`.
      raise ValueError(
          f"Expected an image with 3 colour channels, got shape {image.shape}")

    image = tf.constant(image)
    image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
    return image.numpy() / 127.5 - 1.0  # [0, 255] -> [-1, 1]

  def preprocess_tokens(prefix, suffix=None, seqlen=None):
    """Tokenize a prompt (and optional target) into model input arrays.

    Returns:
      A tuple of 1-D numpy arrays (tokens, mask_ar, mask_loss, mask_input),
      truncated or zero-padded to `seqlen` when it is given.
    """
    # Model has been trained to handle tokenized text composed of a prefix with
    # full attention and a suffix with causal attention.
    separator = "\n"
    tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
    mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
    mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

    if suffix:
      suffix = tokenizer.encode(suffix, add_eos=True)
      tokens += suffix
      mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
      mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

    mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
    if seqlen:
      padding = [0] * max(0, seqlen - len(tokens))
      tokens = tokens[:seqlen] + padding
      mask_ar = mask_ar[:seqlen] + padding
      mask_loss = mask_loss[:seqlen] + padding
      mask_input = mask_input[:seqlen] + padding

    # Convert each list as a whole. jax.tree.map(np.array, ...) would recurse
    # into the lists and yield lists of 0-d arrays instead of 1-D arrays.
    return tuple(np.asarray(seq) for seq in (tokens, mask_ar, mask_loss, mask_input))

  def postprocess_tokens(tokens):
    """Detokenize, dropping the first EOS token and everything after it."""
    tokens = tokens.tolist()  # np.array to list[int]
    try:  # Remove tokens at and after EOS if any.
      eos_pos = tokens.index(tokenizer.eos_id())
      tokens = tokens[:eos_pos]
    except ValueError:
      pass
    return tokenizer.decode(tokens)

  # Evaluation/inference loop.
  SEQLEN = 128

  def make_predictions(data_iterator, *, num_examples=None,
                       batch_size=4, seqlen=SEQLEN, sampler="greedy"):
    """Decode every example from `data_iterator` in fixed-size batches.

    Returns:
      A list of (image, response_text) pairs, one per real example, stopping
      early once `num_examples` responses have been collected (if given).
    """
    outputs = []
    while True:
      # Construct a list of examples in the batch.
      examples = []
      try:
        for _ in range(batch_size):
          examples.append(next(data_iterator))
          examples[-1]["_mask"] = np.array(True)  # Indicates true example.
      except StopIteration:
        if len(examples) == 0:
          return outputs

      # Not enough examples to complete a batch. Pad by repeating last example.
      while len(examples) % batch_size:
        examples.append(dict(examples[-1]))
        examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

      # Convert list of examples into a dict of np.arrays and load onto devices.
      batch = jax.tree.map(lambda *x: np.stack(x), *examples)
      batch = big_vision.utils.reshard(batch, data_sharding)

      # Make model predictions
      tokens = decode({"params": params}, batch=batch,
                      max_decode_len=seqlen, sampler=sampler)

      # Fetch model predictions from the devices to host memory and detokenize.
      tokens, mask = jax.device_get((tokens, batch["_mask"]))
      tokens = tokens[mask]  # remove padding examples.
      responses = [postprocess_tokens(t) for t in tokens]

      # Collect (image, response) pairs; padding examples were removed above,
      # so zip() stops at the last real example.
      for example, response in zip(examples, responses):
        outputs.append((example["image"], response))
        if num_examples and len(outputs) >= num_examples:
          return outputs

  def test_data_iterator(file_name):
    """Yield the single model-input example built from `file_name`."""
    # Convert to RGB so palette ("P"), greyscale+alpha ("LA"), CMYK, etc.
    # images yield the 3 colour channels the model expects; the `with` block
    # closes the underlying file once the pixels have been read by convert().
    with Image.open(file_name) as raw_image:
      rgb_image = raw_image.convert("RGB")
    image = preprocess_image(rgb_image)

    prefix = "caption en"
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input)
    }

  # Call the prediction function and return the text of the only result.
  _, caption = make_predictions(test_data_iterator(file_name=image_path), batch_size=1)[0]
  return caption