Dynamic UI + model loading.
Files changed:
- .gitignore +2 -0
- README.md +4 -4
- app.py +279 -0
- big_vision_contrastive_models.py +241 -0
- gradio_helpers.py +165 -0
- requirements.txt +12 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+/env
+/__pycache__
README.md
CHANGED
@@ -1,13 +1,13 @@
 ---
-title:
+title: LiT Demo (big_vision)
-emoji:
+emoji: π
 colorFrom: purple
 colorTo: yellow
 sdk: gradio
 sdk_version: 4.21.0
 app_file: app.py
-pinned:
+pinned: true
 license: apache-2.0
 ---
 
-
+Gradio clone of the original [LiT Demo](https://google-research.github.io/vision_transformer/lit/)
app.py
ADDED
@@ -0,0 +1,279 @@
"""Gradio clone of https://google-research.github.io/vision_transformer/lit/.

Features:

- Models are downloaded dynamically.
- Models are cached on local disk, and in RAM.
- Progress bars when downloading/reading/computing.
- Dynamic update of model controls.
- Dynamic generation of output sliders.
- Use of `gr.State()` for better use of progress bars.
"""

import dataclasses
import json
import logging
import os
import time
import urllib.request

import gradio as gr
import PIL.Image

import big_vision_contrastive_models as models
import gradio_helpers


INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
MAX_ANSWERS = 10

MAX_DISK_CACHE = 20e9
MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM

LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}


# family/variant/res -> name
MODEL_MAP = {
    'lit': {
        'B/16': {
            224: 'lit_b16b',
        },
        'L/16': {
            224: 'lit_l16l',
        },
    },
    'siglip': {
        'B/16': {
            224: 'siglip_b16b_224',
            256: 'siglip_b16b_256',
            384: 'siglip_b16b_384',
            512: 'siglip_b16b_512',
        },
        'L/16': {
            256: 'siglip_l16l_256',
            384: 'siglip_l16l_384',
        },
        'So400m/14': {
            224: 'siglip_so400m14so400m_224',
            384: 'siglip_so400m14so400m_384',
        },
    },
}


def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
  """Loads model and computes answers."""

  if image_path is None:
    raise gr.Error('Must first select an image!')

  t0 = time.monotonic()

  model_name = MODEL_MAP[family][variant][res]
  config = models.MODEL_CONFIGS[model_name]
  local_ckpt = gradio_helpers.get_disk_cache(
      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
  config = dataclasses.replace(config, ckpt=local_ckpt)
  params, model = gradio_helpers.get_memory_cache(
      config,
      lambda: models.load_model(config),
      max_cache_size_bytes=MAX_RAM_CACHE,
      progress=progress,
      estimated_secs={
          ('lit', 'B/16'): 1,
          ('lit', 'L/16'): 2.5,
          ('siglip', 'B/16'): 9,
          ('siglip', 'L/16'): 28,
          ('siglip', 'So400m/14'): 36,
      }.get((family, variant))
  )
  model: models.ContrastiveModel = model

  it = progress.tqdm(list(range(3)), desc='compute')

  logging.info('Opening image "%s"', image_path)
  with gradio_helpers.timed(f'opening image "{image_path}"'):
    image = PIL.Image.open(image_path)
  next(it)
  with gradio_helpers.timed('image features'):
    zimg, out = model.embed_images(
        params, model.preprocess_images([image])
    )
  next(it)
  with gradio_helpers.timed('text features'):
    prompts = prompts.split('\n')
    ztxt, out = model.embed_texts(
        params, model.preprocess_texts(prompts)
    )
  next(it)

  t = model.get_temperature(out)
  if family == 'lit':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
  elif family == 'siglip':
    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])

  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))

  dt = time.monotonic() - t0
  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
  status = gr.Markdown(
      f'Computed inference in {dt:.1f} seconds ('
      f'memory cache {mem_n} items, {mem_sz/1e6:.1f} M, '
      f'disk cache {disk_n} items, {disk_sz/1e6:.1f} M)')

  if 'b' in out:
    logging.info('model_name=%s default bias=%f', model_name, out['b'])

  return status, state


def update_answers(state):
  """Generates visible sliders for answers."""
  answers = []
  for prompt, prob in state[:MAX_ANSWERS]:
    answers.append(gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
  while len(answers) < MAX_ANSWERS:
    answers.append(gr.Slider(visible=False))
  return answers


def create_app():
  """Creates demo UI."""

  css = '''
  .slider input[type="number"] { width: 5em; }
  #examples td.textbox > div {
    white-space: pre-wrap !important;
    text-align: left;
  }
  '''

  with gr.Blocks(css=css) as demo:

    gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')

    status = gr.Markdown()

    with gr.Row():
      image = gr.Image(label='Image', type='filepath')
      source = gr.Markdown('', visible=False)
      state = gr.State([])
      with gr.Column():
        prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
        with gr.Row():

          values = {}

          family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
          values['family'] = family.value

          # Unfortunately the reactive UI code below is a bit convoluted, because:
          # 1. When e.g. `family.change()` updates `variant`, then that does not
          #    trigger a `variant.change()`.
          # 2. The widget values like `family.value` are *not* updated when the
          #    widget is updated. Therefore, we keep a manual copy in `values`.

          def make_variant(family_value):
            choices = list(MODEL_MAP[family_value])
            values['variant'] = choices[0]
            return gr.Dropdown(value=values['variant'], choices=choices, label='Variant')
          variant = make_variant(family.value)

          def make_res(family, variant):
            choices = list(MODEL_MAP[family][variant])
            values['res'] = choices[0]
            return gr.Dropdown(value=values['res'], choices=choices, label='Resolution')
          res = make_res(family.value, variant.value)
          values['res'] = res.value

          def make_bias(family, variant, res):
            visible = family == 'siglip'
            value = {
                ('siglip', 'B/16', 224): -12.9,
                ('siglip', 'L/16', 256): -12.7,
                ('siglip', 'L/16', 384): -16.5,
                # ...
            }.get((family, variant, res), -10.0)
            return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
          bias = make_bias(family.value, variant.value, res.value)
          values['bias'] = bias.value

          def family_changed(family):
            variant = list(MODEL_MAP[family])[0]
            res = list(MODEL_MAP[family][variant])[0]
            values['family'] = family
            values['variant'] = variant
            values['res'] = res
            return [
                make_variant(family),
                make_res(family, variant),
                make_bias(family, variant, res),
            ]

          def variant_changed(variant):
            res = list(MODEL_MAP[values['family']][variant])[0]
            values['variant'] = variant
            values['res'] = res
            return [
                make_res(values['family'], variant),
                make_bias(values['family'], variant, res),
            ]

          def res_changed(res):
            return make_bias(values['family'], values['variant'], res)

          family.change(family_changed, family, [variant, res, bias])
          variant.change(variant_changed, variant, [res, bias])
          res.change(res_changed, res, bias)

          # (end of reactive UI code)

        run = gr.Button('Run')
        answers = [
            # Will be set to visible in `update_answers()`.
            gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
            for _ in range(MAX_ANSWERS)
        ]

    # We want to avoid showing multiple progress bars, so we only update
    # a single `status` widget here, and store the computed information in
    # `state`...
    run.click(
        fn=compute, inputs=[image, prompts, family, variant, res, bias], outputs=[status, state])
    # ... then we use `state` to update UI components without showing a
    # progress bar in their place.
    status.change(fn=update_answers, inputs=state, outputs=answers)

    info = json.load(urllib.request.urlopen(INFO_URL))
    gr.Markdown('Note: below images have 224 px resolution only:')
    gr.Examples(
        examples=[
            [
                IMG_URL_FMT.format(ex['id']),
                ex['prompts'].replace(', ', '\n'),
                '[source](%s)' % ex['source'],
            ]
            for ex in info
        ],
        inputs=[image, prompts, source],
        outputs=answers,
        elem_id='examples',
    )

  return demo


if __name__ == "__main__":

  logging.basicConfig(level=logging.INFO,
                      format='%(asctime)s - %(levelname)s - %(message)s')

  for k, v in os.environ.items():
    logging.info('environ["%s"] = %r', k, v)

  models.setup()

  create_app().queue().launch()
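The `status`/`state` indirection at the end of `create_app()` is the docstring's "use of `gr.State()`" point: `run.click` targets only `status` (which hosts the single progress bar) and the invisible `state`, and `status.change` then fans the stored results out to the sliders. A minimal standalone sketch of the same pattern (hypothetical widgets, not part of the commit):

import gradio as gr

def compute(x, progress=gr.Progress()):
  # Long-running work: only `status` (and the invisible `state`) are outputs,
  # so Gradio draws a single progress bar over `status`.
  results = [x * i for i in range(1, 4)]
  return f'Computed {len(results)} values.', results

def fan_out(state):
  # Triggered by `status.change`: updates many widgets without progress bars.
  return [gr.Slider(value=v, visible=True) for v in state]

with gr.Blocks() as demo:
  status = gr.Markdown()
  state = gr.State([])
  x = gr.Number(value=1)
  run = gr.Button('Run')
  sliders = [gr.Slider(0, 100, 0, visible=False) for _ in range(3)]
  run.click(compute, inputs=x, outputs=[status, state])
  status.change(fan_out, inputs=state, outputs=sliders)

demo.launch()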
big_vision_contrastive_models.py
ADDED
@@ -0,0 +1,241 @@
"""Wrapper for big_vision contrastive models.

Before using any of the functions, make sure to call `setup()`.

Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get
the params and model wrapper.
"""

import dataclasses
import enum
import functools
import importlib
import os
import subprocess
import sys
import tempfile

import flax.linen as nn
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import PIL.Image
import sentencepiece
from tensorflow.io import gfile
import transformers


def _clone_git(url, destination_folder, commit_hash=None):
  subprocess.run([
      'git', 'clone', '--depth=1',
      url, destination_folder
  ], check=True)
  if commit_hash:
    subprocess.run(['git', '-C', destination_folder, 'checkout', commit_hash], check=True)


def setup(commit_hash=None):
  for url, dst_name in (
      ('https://github.com/google-research/big_vision', 'big_vision_repo'),
      ('https://github.com/google/flaxformer', 'flaxformer_repo'),
  ):
    dst_path = os.path.join(tempfile.gettempdir(), dst_name)
    if not os.path.exists(dst_path):
      _clone_git(url, dst_path, commit_hash)
    if dst_path not in sys.path:
      sys.path.insert(0, dst_path)


class ContrastiveModelFamily(enum.Enum):
  LIT = 'lit'
  SIGLIP = 'siglip'

  @property
  def paper(self):
    return {
        self.LIT: 'https://arxiv.org/abs/2111.07991',
        self.SIGLIP: 'https://arxiv.org/abs/2303.15343',
    }[self]

  def __lt__(self, other):
    return self.value < other.value


@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
class ContrastiveModelConfig:
  """Describes a `big_vision` contrastive model."""
  family: ContrastiveModelFamily
  variant: str
  res: int
  textvariant: str
  embdim: int
  seqlen: int
  tokenizer: str
  vocab_size: int
  ckpt: str


@dataclasses.dataclass(frozen=True, kw_only=True)
class ContrastiveModel:
  """Wraps a `big_vision` contrastive model."""

  config: ContrastiveModelConfig
  flax_module: nn.Module
  tokenizer_sp: sentencepiece.SentencePieceProcessor | None
  tokenizer_bert: transformers.BertTokenizer | None

  def embed_images(self, params, images):
    assert getattr(images, 'ndim', None) == 4, 'Must call `.preprocess_images()`'
    zimg, _, out = self.flax_module.apply(dict(params=params), images, None)
    return zimg, out

  def embed_texts(self, params, texts):
    assert getattr(texts, 'ndim', None) == 2, 'Must call `.preprocess_texts()`'
    _, ztxt, out = self.flax_module.apply(dict(params=params), None, texts)
    return ztxt, out

  def preprocess_texts(self, texts):

    def tokenize_pad(text, seqlen=self.config.seqlen):

      if self.config.family == ContrastiveModelFamily.LIT:
        tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)[:-1]  # removes [SEP]
        tokens = tokens[:seqlen]
        return tokens + [0] * (seqlen - len(tokens))

      if self.config.family == ContrastiveModelFamily.SIGLIP:
        tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
        if len(tokens) >= seqlen:
          return tokens[:seqlen - 1] + [self.tokenizer_sp.eos_id()]  # "sticky" eos
        return tokens + [0] * (seqlen - len(tokens))

    return np.array([tokenize_pad(text) for text in texts])

  def preprocess_images(self, images):
    if not isinstance(images, (list, tuple)):
      images = [images]
    def topil(image):
      if not isinstance(image, PIL.Image.Image):
        image = PIL.Image.fromarray(image)
      return image
    return np.array([
        topil(image).resize([self.config.res, self.config.res])
        for image in images
    ]) / 127.5 - 1.0

  def get_bias(self, out):
    assert self.config.family == ContrastiveModelFamily.SIGLIP, self.config.family
    return out['b'].item()

  def get_temperature(self, out):
    return out['t'].item()

  def get_probabilities(self, zimg, ztxt, temperature, *, axis=None, bias=None):
    # Note: zimg, ztxt are already normalized.

    if self.config.family == ContrastiveModelFamily.LIT:
      assert bias is None
      assert axis in (-1, -2), 'Must specify axis: -1/-2=normalize texts/images'
      return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis)

    if self.config.family == ContrastiveModelFamily.SIGLIP:
      assert axis is None
      assert bias is not None, 'Must specify bias.'
      return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)


def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
  if family == 'lit':
    tokenizer = ckpt.replace('.npz', '.txt')
  else:
    tokenizer = 'c4_en'
  return ContrastiveModelConfig(
      family=ContrastiveModelFamily(family), variant=variant, res=res,
      textvariant=textvariant, embdim=embdim, seqlen=seqlen,
      tokenizer=tokenizer, vocab_size=vocab_size,
      ckpt=ckpt,
  )


MODEL_CONFIGS = dict(
    lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
    lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
    lit_l16s=_make_config('lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000),
    lit_l16ti=_make_config('lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000),

    siglip_b16b_224=_make_config('siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000),
    siglip_b16b_256=_make_config('siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000),
    siglip_b16b_384=_make_config('siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000),
    siglip_b16b_512=_make_config('siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000),
    siglip_l16l_256=_make_config('siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000),
    siglip_l16l_384=_make_config('siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000),
    siglip_so400m14so400m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
    siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
)


@functools.cache
def load_tokenizer_sp(name_or_path):
  tok = sentencepiece.SentencePieceProcessor()
  path = {
      'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
  }.get(name_or_path, name_or_path)
  tok.LoadFromSerializedProto(gfile.GFile(path, 'rb').read())
  return tok


@functools.cache
def load_tokenizer_bert(path):
  if path.startswith('gs://'):
    dst = tempfile.mktemp()
    gfile.copy(path, dst)
    path = dst
  return transformers.BertTokenizer(path, do_lower_case=True)


def load_model(config, check_params=False):
  """Loads `big_vision` model."""
  assert isinstance(config, ContrastiveModelConfig), type(config)

  cfg = ml_collections.ConfigDict()
  cfg.image_model = 'vit'  # TODO(lbeyer): remove later, default
  if config.family == ContrastiveModelFamily.LIT:
    cfg.text_model = 'proj.flaxformer.bert'
    cfg.image = dict(variant=config.variant, pool_type='tok', head_zeroinit=False)
    bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
    cfg.text = dict(config=bert_config, head_zeroinit=False)
    tokenizer_bert = load_tokenizer_bert(config.tokenizer)
    tokenizer_sp = None
    if config.variant == 'L/16':
      cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
    else:
      cfg.out_dim = (config.embdim, config.embdim)  # (image_out_dim, text_out_dim)
  else:
    cfg.image = dict(variant=config.variant, pool_type='map')
    cfg.text_model = 'proj.image_text.text_transformer'  # TODO(lbeyer): remove later, default
    cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
    cfg.bias_init = -10.0
    tokenizer_sp = load_tokenizer_sp(config.tokenizer)
    tokenizer_bert = None
    cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
  cfg.temperature_init = 10.0

  model_mod = importlib.import_module(
      'big_vision.models.proj.image_text.two_towers')
  model = model_mod.Model(**cfg)

  init_params = None  # Faster but bypasses loading sanity-checks.
  if check_params:
    imgs = jnp.zeros([1, config.res, config.res, 3])
    txts = jnp.zeros([1, config.seqlen], jnp.int32)
    init_params = model.init(jax.random.PRNGKey(0), imgs, txts)['params']
  params_cpu = model_mod.load(init_params, config.ckpt, cfg)

  return params_cpu, ContrastiveModel(
      config=config,
      flax_module=model,
      tokenizer_sp=tokenizer_sp,
      tokenizer_bert=tokenizer_bert,
  )
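For reference, the docstring's `setup()` → config → `load_model()` flow looks roughly like this end to end (a sketch, not part of the commit; `cat.jpg` is a placeholder path, and cloning the repos plus fetching the `gs://` checkpoint needs network access):

import PIL.Image

import big_vision_contrastive_models as models

models.setup()  # clones big_vision/flaxformer into /tmp and extends sys.path

config = models.MODEL_CONFIGS['lit_b16b']
params, model = models.load_model(config)

image = PIL.Image.open('cat.jpg')  # placeholder image path
zimg, out = model.embed_images(params, model.preprocess_images([image]))
ztxt, out = model.embed_texts(params, model.preprocess_texts(['a cat', 'a dog']))

# LiT scores prompts with a softmax; SigLIP would pass `bias=` instead of `axis=`.
probs = model.get_probabilities(zimg, ztxt, model.get_temperature(out), axis=-1)
print(dict(zip(['a cat', 'a dog'], [p.item() for p in probs[0]])))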
gradio_helpers.py
ADDED
@@ -0,0 +1,165 @@
"""Gradio utilities.

Note that the optional `progress` parameter can be both a `tqdm` module or a
`gr.Progress` instance.
"""

import concurrent.futures
import contextlib
import glob
import hashlib
import logging
import os
import tempfile
import time
import urllib.request

import jax
import numpy as np
from tensorflow.io import gfile


@contextlib.contextmanager
def timed(name):
  t0 = time.monotonic()
  timing = dict(secs=None)
  try:
    yield timing
  finally:
    timing['secs'] = time.monotonic() - t0
    logging.info('Timed %s: %.1f secs', name, timing['secs'])


def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
  """Copies a file with progress bar.

  Args:
    src: Source file (readable by `tf.io.gfile`) or URL.
    dst: Destination file. Path must be readable by `tf.io.gfile`.
    progress: An object with a `.tqdm` attribute, or `None`.
    block_size: Size of individual blocks to be read/written.
    overwrite: If `True`, overwrites an existing `dst`.
  """
  if os.path.dirname(dst):
    os.makedirs(os.path.dirname(dst), exist_ok=True)
  if os.path.exists(dst) and not overwrite:
    return

  if src.startswith('http://') or src.startswith('https://'):
    opener = urllib.request.urlopen
    request = urllib.request.Request(src, method='HEAD')
    response = urllib.request.urlopen(request)
    content_length = response.headers.get('Content-Length')
    n = int(np.ceil(int(content_length) / block_size))
    logging.info('Copying %s (content_length=%s)', src, content_length)
  else:
    opener = lambda path: gfile.GFile(path, 'rb')
    stats = gfile.stat(src)
    n = int(np.ceil(stats.length / block_size))

  if progress is None:
    range_or_trange = range
  else:
    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')

  with opener(src) as fin:
    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
      for _ in range_or_trange(n):
        fout.write(fin.read(block_size))
  gfile.rename(f'{dst}-PARTIAL', dst)


_estimated_real = [(10, 10)]
_memory_cache = {}


def get_with_progress(getter, secs, progress, step=0.1):
  """Returns result from `getter` while showing a progress bar."""
  if progress is None:
    return getter()
  with concurrent.futures.ThreadPoolExecutor() as executor:
    future = executor.submit(getter)
    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
      if not future.done():
        time.sleep(step)
    return future.result()


def _get_array_sizes(tree):
  return [getattr(x, 'nbytes', 0) for x in jax.tree_util.tree_leaves(tree)]


def get_memory_cache(key, getter, max_cache_size_bytes, progress=None, estimated_secs=None):
  """Keeps cache below specified size by removing elements not last accessed."""
  if key in _memory_cache:
    _memory_cache[key] = _memory_cache.pop(key)  # updates "last accessed" order
    return _memory_cache[key]

  est, real = zip(*_estimated_real)
  if estimated_secs is None:
    estimated_secs = sum(est) / len(est)
  with timed(f'loading {key}') as timing:
    estimated_secs *= sum(real) / sum(est)
    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
  _estimated_real.append((estimated_secs, timing['secs']))

  sz = sum(_get_array_sizes(list(_memory_cache.values())))
  logging.info('New memory cache size=%.1f MB', sz/1e6)

  while sz > max_cache_size_bytes:
    k, v = next(iter(_memory_cache.items()))
    if k == key:
      break
    s = sum(_get_array_sizes(v))
    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
    _memory_cache.pop(k)
    sz -= s

  return _memory_cache[key]


def get_memory_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = _get_array_sizes(_memory_cache)
  return len(_memory_cache), sum(sizes)


CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')


def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
  """Keeps cache below specified size by removing elements not last accessed."""
  fname = os.path.basename(path_or_url)
  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
  dst = os.path.join(CACHE_DIR, path_hash, fname)
  if os.path.exists(dst):
    return dst

  os.makedirs(os.path.dirname(dst), exist_ok=True)
  with timed(f'copying {path_or_url}'):
    copy_file(path_or_url, dst, progress=progress)

  atimes_sizes_paths = sorted([
      (os.path.getatime(p), os.path.getsize(p), p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
      if os.path.isfile(p)
  ])
  sz = sum(sz for _, sz, _ in atimes_sizes_paths)
  logging.info('New disk cache size=%.1f MB', sz/1e6)

  while sz > max_cache_size_bytes:
    _, s, path = atimes_sizes_paths.pop(0)
    if path == dst:
      break
    logging.info('Removing %s from disk cache (%.1f MB)', path, s/1e6)
    os.unlink(path)
    sz -= s

  return dst


def get_disk_cache_info():
  """Returns number of items and total size in bytes."""
  sizes = [
      os.path.getsize(p)
      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
  ]
  return len(sizes), sum(sizes)
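The two cache layers are meant to compose, as `app.py` does: `get_disk_cache()` materializes a remote file locally at most once (evicting least-recently-accessed files beyond the byte budget), and `get_memory_cache()` keeps the expensively-loaded object in RAM under the same LRU policy. A hedged sketch (`load_fn` is a stand-in for any slow loader such as `models.load_model`, and reading the public `gs://` bucket needs network access):

import gradio_helpers

CKPT = 'gs://vit_models/lit/LiT-B16B.npz'

# First level: download to a local path at most once, capped at 20 GB.
local_path = gradio_helpers.get_disk_cache(CKPT, max_cache_size_bytes=20e9)

# Second level: keep the loaded object in RAM, capped at 10 GB; the key can
# be any hashable value identifying the loaded artifact.
def load_fn():
  return open(local_path, 'rb').read()  # stand-in for a real model loader

data = gradio_helpers.get_memory_cache(CKPT, load_fn, max_cache_size_bytes=10e9)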
requirements.txt
ADDED
@@ -0,0 +1,12 @@
aqtp  # for flaxformer
einops
flax
gradio
jax
jaxlib
ml_collections
numpy
Pillow
sentencepiece
transformers  # for transformers.BertTokenizer
tensorflow  # for tf.io.gfile