andsteing committed
Commit e380bd8 • 1 Parent(s): 296fec5

Dynamic UI + model loading.

Files changed (6):
  1. .gitignore +2 -0
  2. README.md +4 -4
  3. app.py +279 -0
  4. big_vision_contrastive_models.py +241 -0
  5. gradio_helpers.py +165 -0
  6. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
+/env
+/__pycache__
README.md CHANGED
@@ -1,13 +1,13 @@
 ---
-title: Lit Demo Bv
-emoji: 📉
+title: LiT Demo (big_vision)
+emoji: 🔒
 colorFrom: purple
 colorTo: yellow
 sdk: gradio
 sdk_version: 4.21.0
 app_file: app.py
-pinned: false
+pinned: true
 license: apache-2.0
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+Gradio clone of the original [LiT Demo](https://google-research.github.io/vision_transformer/lit/)
app.py ADDED
@@ -0,0 +1,279 @@
+"""Gradio clone of https://google-research.github.io/vision_transformer/lit/.
+
+Features:
+
+- Models are downloaded dynamically.
+- Models are cached on local disk, and in RAM.
+- Progress bars when downloading/reading/computing.
+- Dynamic update of model controls.
+- Dynamic generation of output sliders.
+- Use of `gr.State()` for better use of progress bars.
+"""
+import dataclasses
+import json
+import logging
+import os
+import time
+import urllib.request
+
+import gradio as gr
+import PIL.Image
+
+import big_vision_contrastive_models as models
+import gradio_helpers
+
+
+INFO_URL = 'https://google-research.github.io/vision_transformer/lit/data/images/info.json'
+IMG_URL_FMT = 'https://google-research.github.io/vision_transformer/lit/data/images/{}.jpg'
+MAX_ANSWERS = 10
+
+MAX_DISK_CACHE = 20e9
+MAX_RAM_CACHE = 10e9  # CPU basic has 16G RAM
+
+LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}
+
+
+# family/variant/res -> name
+MODEL_MAP = {
+    'lit': {
+        'B/16': {
+            224: 'lit_b16b',
+        },
+        'L/16': {
+            224: 'lit_l16l',
+        },
+    },
+    'siglip': {
+        'B/16': {
+            224: 'siglip_b16b_224',
+            256: 'siglip_b16b_256',
+            384: 'siglip_b16b_384',
+            512: 'siglip_b16b_512',
+        },
+        'L/16': {
+            256: 'siglip_l16l_256',
+            384: 'siglip_l16l_384',
+        },
+        'So400m/14': {
+            224: 'siglip_so400m14so400m_224',
+            384: 'siglip_so400m14so400m_384',
+        },
+    },
+}
+
+
+def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
+  """Loads model and computes answers."""
+
+  if image_path is None:
+    raise gr.Error('Must first select an image!')
+
+  t0 = time.monotonic()
+
+  model_name = MODEL_MAP[family][variant][res]
+  config = models.MODEL_CONFIGS[model_name]
+  local_ckpt = gradio_helpers.get_disk_cache(
+      config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
+  config = dataclasses.replace(config, ckpt=local_ckpt)
+  params, model = gradio_helpers.get_memory_cache(
+      config,
+      lambda: models.load_model(config),
+      max_cache_size_bytes=MAX_RAM_CACHE,
+      progress=progress,
+      estimated_secs={
+          ('lit', 'B/16'): 1,
+          ('lit', 'L/16'): 2.5,
+          ('siglip', 'B/16'): 9,
+          ('siglip', 'L/16'): 28,
+          ('siglip', 'So400m/14'): 36,
+      }.get((family, variant))
+  )
+  model: models.ContrastiveModel = model
+
+  it = progress.tqdm(list(range(3)), desc='compute')
+
+  logging.info('Opening image "%s"', image_path)
+  with gradio_helpers.timed(f'opening image "{image_path}"'):
+    image = PIL.Image.open(image_path)
+  next(it)
+  with gradio_helpers.timed('image features'):
+    zimg, out = model.embed_images(
+        params, model.preprocess_images([image])
+    )
+  next(it)
+  with gradio_helpers.timed('text features'):
+    prompts = prompts.split('\n')
+    ztxt, out = model.embed_texts(
+        params, model.preprocess_texts(prompts)
+    )
+  next(it)
+
+  t = model.get_temperature(out)
+  if family == 'lit':
+    text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
+  elif family == 'siglip':
+    text_probs = list(model.get_probabilities(zimg, ztxt, t, bias=bias)[0])
+
+  state = list(zip(prompts, [round(p.item(), 3) for p in text_probs]))
+
+  dt = time.monotonic() - t0
+  mem_n, mem_sz = gradio_helpers.get_memory_cache_info()
+  disk_n, disk_sz = gradio_helpers.get_disk_cache_info()
+  status = gr.Markdown(
+      f'Computed inference in {dt:.1f} seconds ('
+      f'memory cache {mem_n} items, {mem_sz/1e6:.1f} M, '
+      f'disk cache {disk_n} items, {disk_sz/1e6:.1f} M)')
+
+  if 'b' in out:
+    logging.info('model_name=%s default bias=%f', model_name, out['b'])
+
+  return status, state
+
+
+def update_answers(state):
+  """Generates visible sliders for answers."""
+  answers = []
+  for prompt, prob in state[:MAX_ANSWERS]:
+    answers.append(gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
+  while len(answers) < MAX_ANSWERS:
+    answers.append(gr.Slider(visible=False))
+  return answers
+
+
+def create_app():
+  """Creates demo UI."""
+
+  css = '''
+  .slider input[type="number"] { width: 5em; }
+  #examples td.textbox > div {
+    white-space: pre-wrap !important;
+    text-align: left;
+  }
+  '''
+
+  with gr.Blocks(css=css) as demo:
+
+    gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')
+
+    status = gr.Markdown()
+
+    with gr.Row():
+      image = gr.Image(label='Image', type='filepath')
+      source = gr.Markdown('', visible=False)
+      state = gr.State([])
+      with gr.Column():
+        prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
+        with gr.Row():
+
+          values = {}
+
+          family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
+          values['family'] = family.value
+
+          # Unfortunately the reactive UI code below is a bit convoluted, because:
+          # 1. When e.g. `family.change()` updates `variant`, that does not
+          #    trigger a `variant.change()`.
+          # 2. Widget values like `family.value` are *not* updated when the
+          #    widget is updated. Therefore, we keep a manual copy in `values`.
+
+          def make_variant(family_value):
+            choices = list(MODEL_MAP[family_value])
+            values['variant'] = choices[0]
+            return gr.Dropdown(value=values['variant'], choices=choices, label='Variant')
+          variant = make_variant(family.value)
+
+          def make_res(family, variant):
+            choices = list(MODEL_MAP[family][variant])
+            values['res'] = choices[0]
+            return gr.Dropdown(value=values['res'], choices=choices, label='Resolution')
+          res = make_res(family.value, variant.value)
+          values['res'] = res.value
+
+          def make_bias(family, variant, res):
+            visible = family == 'siglip'
+            value = {
+                ('siglip', 'B/16', 224): -12.9,
+                ('siglip', 'L/16', 256): -12.7,
+                ('siglip', 'L/16', 384): -16.5,
+                # ...
+            }.get((family, variant, res), -10.0)
+            return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
+          bias = make_bias(family.value, variant.value, res.value)
+          values['bias'] = bias.value
+
+          def family_changed(family):
+            variant = list(MODEL_MAP[family])[0]
+            res = list(MODEL_MAP[family][variant])[0]
+            values['family'] = family
+            values['variant'] = variant
+            values['res'] = res
+            return [
+                make_variant(family),
+                make_res(family, variant),
+                make_bias(family, variant, res),
+            ]
+
+          def variant_changed(variant):
+            res = list(MODEL_MAP[values['family']][variant])[0]
+            values['variant'] = variant
+            values['res'] = res
+            return [
+                make_res(values['family'], variant),
+                make_bias(values['family'], variant, res),
+            ]
+
+          def res_changed(res):
+            return make_bias(values['family'], values['variant'], res)
+
+          family.change(family_changed, family, [variant, res, bias])
+          variant.change(variant_changed, variant, [res, bias])
+          res.change(res_changed, res, bias)
+
+          # (end of reactive UI code)
+
+    run = gr.Button('Run')
+    answers = [
+        # Will be set to visible in `update_answers()`.
+        gr.Slider(0, 100, 0, visible=False, elem_classes='slider')
+        for _ in range(MAX_ANSWERS)
+    ]
+
+    # We want to avoid showing multiple progress bars, so we only update
+    # a single `status` widget here, and store the computed information in
+    # `state`...
+    run.click(
+        fn=compute, inputs=[image, prompts, family, variant, res, bias], outputs=[status, state])
+    # ... then we use `state` to update UI components without showing a
+    # progress bar in their place.
+    status.change(fn=update_answers, inputs=state, outputs=answers)
+
+    info = json.load(urllib.request.urlopen(INFO_URL))
+    gr.Markdown('Note: the images below have 224 px resolution only:')
+    gr.Examples(
+        examples=[
+            [
+                IMG_URL_FMT.format(ex['id']),
+                ex['prompts'].replace(', ', '\n'),
+                '[source](%s)' % ex['source'],
+            ]
+            for ex in info
+        ],
+        inputs=[image, prompts, source],
+        outputs=answers,
+        elem_id='examples',
+    )
+
+  return demo
+
+
+if __name__ == "__main__":
+
+  logging.basicConfig(level=logging.INFO,
+                      format='%(asctime)s - %(levelname)s - %(message)s')
+
+  for k, v in os.environ.items():
+    logging.info('environ["%s"] = %r', k, v)
+
+  models.setup()
+
+  create_app().queue().launch()
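
The `run.click()` → `status.change()` indirection above is how this commit keeps Gradio's progress bar on a single widget: the expensive function writes only `status` and `state`, and the sliders are fanned out afterwards. A minimal self-contained sketch of that pattern (the component and function names here are illustrative, not from the commit):

```python
import time

import gradio as gr


def slow_fn(progress=gr.Progress()):
  # Only `status` and `state` are outputs here, so Gradio renders its
  # progress bar in a single place.
  for _ in progress.tqdm(list(range(3)), desc='compute'):
    time.sleep(0.5)
  return 'done', [('a photo of a cat', 0.9), ('a photo of a dog', 0.1)]


def fan_out(state):
  # Triggered by `status.change()`: updates the sliders without a
  # progress bar ever covering them.
  sliders = [gr.Slider(value=100 * p, label=t, visible=True) for t, p in state]
  return sliders + [gr.Slider(visible=False)] * (2 - len(sliders))


with gr.Blocks() as demo:
  status = gr.Markdown()
  state = gr.State([])
  run = gr.Button('Run')
  answers = [gr.Slider(0, 100, 0, visible=False) for _ in range(2)]
  run.click(slow_fn, [], [status, state])
  status.change(fan_out, state, answers)


if __name__ == '__main__':
  demo.launch()
```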
big_vision_contrastive_models.py ADDED
@@ -0,0 +1,241 @@
+"""Wrapper for big_vision contrastive models.
+
+Before using any of the functions, make sure to call `setup()`.
+
+Choose one of the configs in `MODEL_CONFIGS` and then call `load_model()` to get
+the params and model wrapper.
+"""
+
+import dataclasses
+import enum
+import functools
+import importlib
+import os
+import subprocess
+import sys
+import tempfile
+
+import flax.linen as nn
+import jax
+import jax.numpy as jnp
+import ml_collections
+import numpy as np
+import PIL.Image
+import sentencepiece
+from tensorflow.io import gfile
+import transformers
+
+
+def _clone_git(url, destination_folder, commit_hash=None):
+  subprocess.run([
+      'git', 'clone', '--depth=1',
+      url, destination_folder
+  ], check=True)
+  if commit_hash:
+    subprocess.run(['git', '-C', destination_folder, 'checkout', commit_hash], check=True)
+
+
+def setup(commit_hash=None):
+  for url, dst_name in (
+      ('https://github.com/google-research/big_vision', 'big_vision_repo'),
+      ('https://github.com/google/flaxformer', 'flaxformer_repo'),
+  ):
+    dst_path = os.path.join(tempfile.gettempdir(), dst_name)
+    if not os.path.exists(dst_path):
+      _clone_git(url, dst_path, commit_hash)
+    if dst_path not in sys.path:
+      sys.path.insert(0, dst_path)
+
+
+class ContrastiveModelFamily(enum.Enum):
+  LIT = 'lit'
+  SIGLIP = 'siglip'
+
+  @property
+  def paper(self):
+    return {
+        self.LIT: 'https://arxiv.org/abs/2111.07991',
+        self.SIGLIP: 'https://arxiv.org/abs/2303.15343',
+    }[self]
+
+  def __lt__(self, other):
+    return self.value < other.value
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, order=True)
+class ContrastiveModelConfig:
+  """Describes a `big_vision` contrastive model."""
+  family: ContrastiveModelFamily
+  variant: str
+  res: int
+  textvariant: str
+  embdim: int
+  seqlen: int
+  tokenizer: str
+  vocab_size: int
+  ckpt: str
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class ContrastiveModel:
+  """Wraps a `big_vision` contrastive model."""
+
+  config: ContrastiveModelConfig
+  flax_module: nn.Module
+  tokenizer_sp: sentencepiece.SentencePieceProcessor | None
+  tokenizer_bert: transformers.BertTokenizer | None
+
+  def embed_images(self, params, images):
+    assert getattr(images, 'ndim', None) == 4, 'Must call `.preprocess_images()`'
+    zimg, _, out = self.flax_module.apply(dict(params=params), images, None)
+    return zimg, out
+
+  def embed_texts(self, params, texts):
+    assert getattr(texts, 'ndim', None) == 2, 'Must call `.preprocess_texts()`'
+    _, ztxt, out = self.flax_module.apply(dict(params=params), None, texts)
+    return ztxt, out
+
+  def preprocess_texts(self, texts):
+
+    def tokenize_pad(text, seqlen=self.config.seqlen):
+
+      if self.config.family == ContrastiveModelFamily.LIT:
+        tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)[:-1]  # removes [SEP]
+        tokens = tokens[:seqlen]
+        return tokens + [0] * (seqlen - len(tokens))
+
+      if self.config.family == ContrastiveModelFamily.SIGLIP:
+        tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
+        if len(tokens) >= seqlen:
+          return tokens[:seqlen - 1] + [self.tokenizer_sp.eos_id()]  # "sticky" eos
+        return tokens + [0] * (seqlen - len(tokens))
+
+    return np.array([tokenize_pad(text) for text in texts])
+
+  def preprocess_images(self, images):
+    if not isinstance(images, (list, tuple)):
+      images = [images]
+    def topil(image):
+      if not isinstance(image, PIL.Image.Image):
+        image = PIL.Image.fromarray(image)
+      return image
+    return np.array([
+        topil(image).resize([self.config.res, self.config.res])
+        for image in images
+    ]) / 127.5 - 1.0
+
+  def get_bias(self, out):
+    assert self.config.family == ContrastiveModelFamily.SIGLIP, self.config.family
+    return out['b'].item()
+
+  def get_temperature(self, out):
+    return out['t'].item()
+
+  def get_probabilities(self, zimg, ztxt, temperature, *, axis=None, bias=None):
+    # Note: zimg, ztxt are already normalized.
+
+    if self.config.family == ContrastiveModelFamily.LIT:
+      assert bias is None
+      assert axis in (-1, -2), 'Must specify axis: -1/-2=normalize texts/images'
+      return jax.nn.softmax(zimg @ ztxt.T * temperature, axis=axis)
+
+    if self.config.family == ContrastiveModelFamily.SIGLIP:
+      assert axis is None
+      assert bias is not None, 'Must specify bias.'
+      return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
+
+
+def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
+  if family == 'lit':
+    tokenizer = ckpt.replace('.npz', '.txt')
+  else:
+    tokenizer = 'c4_en'
+  return ContrastiveModelConfig(
+      family=ContrastiveModelFamily(family), variant=variant, res=res,
+      textvariant=textvariant, embdim=embdim, seqlen=seqlen,
+      tokenizer=tokenizer, vocab_size=vocab_size,
+      ckpt=ckpt,
+  )
+
+
+MODEL_CONFIGS = dict(
+    lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
+    lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
+    lit_l16s=_make_config('lit', 'L/16', 224, 'S', 'gs://vit_models/lit/LiT-L16S.npz', 1024, 16, 32_000),
+    lit_l16ti=_make_config('lit', 'L/16', 224, 'Ti', 'gs://vit_models/lit/LiT-L16Ti.npz', 1024, 16, 32_000),
+
+    siglip_b16b_224=_make_config('siglip', 'B/16', 224, 'B', 'gs://big_vision/siglip/webli_en_b16_224_63724782.npz', 768, 64, 32_000),
+    siglip_b16b_256=_make_config('siglip', 'B/16', 256, 'B', 'gs://big_vision/siglip/webli_en_b16_256_60500360.npz', 768, 64, 32_000),
+    siglip_b16b_384=_make_config('siglip', 'B/16', 384, 'B', 'gs://big_vision/siglip/webli_en_b16_384_68578854.npz', 768, 64, 32_000),
+    siglip_b16b_512=_make_config('siglip', 'B/16', 512, 'B', 'gs://big_vision/siglip/webli_en_b16_512_68580893.npz', 768, 64, 32_000),
+    siglip_l16l_256=_make_config('siglip', 'L/16', 256, 'L', 'gs://big_vision/siglip/webli_en_l16_256_60552751.npz', 1024, 64, 32_000),
+    siglip_l16l_384=_make_config('siglip', 'L/16', 384, 'L', 'gs://big_vision/siglip/webli_en_l16_384_63634585.npz', 1024, 64, 32_000),
+    siglip_so400m14so400m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
+    siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
+)
+
+
+@functools.cache
+def load_tokenizer_sp(name_or_path):
+  tok = sentencepiece.SentencePieceProcessor()
+  path = {
+      'c4_en': 'gs://t5-data/vocabs/cc_en.32000/sentencepiece.model',
+  }.get(name_or_path, name_or_path)
+  tok.LoadFromSerializedProto(gfile.GFile(path, 'rb').read())
+  return tok
+
+
+@functools.cache
+def load_tokenizer_bert(path):
+  if path.startswith('gs://'):
+    dst = tempfile.mktemp()
+    gfile.copy(path, dst)
+    path = dst
+  return transformers.BertTokenizer(path, do_lower_case=True)
+
+
+def load_model(config, check_params=False):
+  """Loads `big_vision` model."""
+  assert isinstance(config, ContrastiveModelConfig), type(config)
+
+  cfg = ml_collections.ConfigDict()
+  cfg.image_model = 'vit'  # TODO(lbeyer): remove later, default
+  if config.family == ContrastiveModelFamily.LIT:
+    cfg.text_model = 'proj.flaxformer.bert'
+    cfg.image = dict(variant=config.variant, pool_type='tok', head_zeroinit=False)
+    bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
+    cfg.text = dict(config=bert_config, head_zeroinit=False)
+    tokenizer_bert = load_tokenizer_bert(config.tokenizer)
+    tokenizer_sp = None
+    if config.variant == 'L/16':
+      cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
+    else:
+      cfg.out_dim = (config.embdim, config.embdim)  # (image_out_dim, text_out_dim)
+  else:
+    cfg.image = dict(variant=config.variant, pool_type='map')
+    cfg.text_model = 'proj.image_text.text_transformer'  # TODO(lbeyer): remove later, default
+    cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
+    cfg.bias_init = -10.0
+    tokenizer_sp = load_tokenizer_sp(config.tokenizer)
+    tokenizer_bert = None
+    cfg.out_dim = (None, config.embdim)  # (image_out_dim, text_out_dim)
+  cfg.temperature_init = 10.0
+
+  model_mod = importlib.import_module(
+      'big_vision.models.proj.image_text.two_towers')
+  model = model_mod.Model(**cfg)
+
+  init_params = None  # Faster but bypasses loading sanity-checks.
+  if check_params:
+    imgs = jnp.zeros([1, config.res, config.res, 3])
+    txts = jnp.zeros([1, config.seqlen], jnp.int32)
+    init_params = model.init(jax.random.PRNGKey(0), imgs, txts)['params']
+  params_cpu = model_mod.load(init_params, config.ckpt, cfg)
+
+  return params_cpu, ContrastiveModel(
+      config=config,
+      flax_module=model,
+      tokenizer_sp=tokenizer_sp,
+      tokenizer_bert=tokenizer_bert,
+  )
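
Per the module docstring, usage is: call `setup()`, pick a config from `MODEL_CONFIGS`, then `load_model()`. A sketch of the zero-shot round trip that `app.py` performs, assuming a local image file (the image path and prompt strings are illustrative):

```python
import PIL.Image

import big_vision_contrastive_models as models

models.setup()  # Clones big_vision/flaxformer into tempdir, extends sys.path.

config = models.MODEL_CONFIGS['lit_b16b']
params, model = models.load_model(config)  # Reads the checkpoint from GCS.

image = PIL.Image.open('cat.jpg')  # Hypothetical local image.
prompts = ['a photo of a cat', 'a photo of a dog']

zimg, _ = model.embed_images(params, model.preprocess_images([image]))
ztxt, out = model.embed_texts(params, model.preprocess_texts(prompts))

# LiT: softmax over prompts (axis=-1); SigLIP would pass `bias=...` instead.
probs = model.get_probabilities(zimg, ztxt, model.get_temperature(out), axis=-1)
print(dict(zip(prompts, [p.item() for p in probs[0]])))
```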
gradio_helpers.py ADDED
@@ -0,0 +1,165 @@
+"""Gradio utilities.
+
+Note that the optional `progress` parameter can be both a `tqdm` module or a
+`gr.Progress` instance.
+"""
+
+import concurrent.futures
+import contextlib
+import glob
+import hashlib
+import logging
+import os
+import tempfile
+import time
+import urllib.request
+
+import jax
+import numpy as np
+from tensorflow.io import gfile
+
+
+@contextlib.contextmanager
+def timed(name):
+  t0 = time.monotonic()
+  timing = dict(secs=None)
+  try:
+    yield timing
+  finally:
+    timing['secs'] = time.monotonic() - t0
+    logging.info('Timed %s: %.1f secs', name, timing['secs'])
+
+
+def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
+  """Copies a file with progress bar.
+
+  Args:
+    src: Source file (readable by `tf.io.gfile`) or URL.
+    dst: Destination file. Path must be writable by `tf.io.gfile`.
+    progress: An object with a `.tqdm` attribute, or `None`.
+    block_size: Size of individual blocks to be read/written.
+    overwrite: If `True`, overwrites an existing `dst`.
+  """
+  if os.path.dirname(dst):
+    os.makedirs(os.path.dirname(dst), exist_ok=True)
+  if os.path.exists(dst) and not overwrite:
+    return
+
+  if src.startswith('http://') or src.startswith('https://'):
+    opener = urllib.request.urlopen
+    request = urllib.request.Request(src, method='HEAD')
+    response = urllib.request.urlopen(request)
+    content_length = response.headers.get('Content-Length')
+    n = int(np.ceil(int(content_length) / block_size))
+    logging.info('content_length=%s', content_length)
+  else:
+    opener = lambda path: gfile.GFile(path, 'rb')
+    stats = gfile.stat(src)
+    n = int(np.ceil(stats.length / block_size))
+
+  if progress is None:
+    range_or_trange = range
+  else:
+    range_or_trange = lambda n: progress.tqdm(list(range(n)), desc='download')
+
+  with opener(src) as fin:
+    with gfile.GFile(f'{dst}-PARTIAL', 'wb') as fout:
+      for _ in range_or_trange(n):
+        fout.write(fin.read(block_size))
+  gfile.rename(f'{dst}-PARTIAL', dst)
+
+
+_estimated_real = [(10, 10)]
+_memory_cache = {}
+
+
+def get_with_progress(getter, secs, progress, step=0.1):
+  """Returns result from `getter` while showing a progress bar."""
+  with concurrent.futures.ThreadPoolExecutor() as executor:
+    future = executor.submit(getter)
+    for _ in progress.tqdm(list(range(int(np.ceil(secs/step)))), desc='read'):
+      if not future.done():
+        time.sleep(step)
+  return future.result()
+
+
+def _get_array_sizes(tree):
+  return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
+
+
+def get_memory_cache(key, getter, max_cache_size_bytes, progress=None, estimated_secs=None):
+  """Keeps cache below specified size by removing elements not last accessed."""
+  if key in _memory_cache:
+    _memory_cache[key] = _memory_cache.pop(key)  # updates "last accessed" order
+    return _memory_cache[key]
+
+  est, real = zip(*_estimated_real)
+  if estimated_secs is None:
+    estimated_secs = sum(est) / len(est)
+  with timed(f'loading {key}') as timing:
+    estimated_secs *= sum(real) / sum(est)
+    _memory_cache[key] = get_with_progress(getter, estimated_secs, progress)
+  _estimated_real.append((estimated_secs, timing['secs']))
+
+  sz = sum(_get_array_sizes(list(_memory_cache.values())))
+  logging.info('New memory cache size=%.1f MB', sz/1e6)
+
+  while sz > max_cache_size_bytes:
+    k, v = next(iter(_memory_cache.items()))
+    if k == key:
+      break
+    s = sum(_get_array_sizes(v))
+    logging.info('Removing %s from memory cache (%.1f MB)', k, s/1e6)
+    _memory_cache.pop(k)
+    sz -= s
+
+  return _memory_cache[key]
+
+
+def get_memory_cache_info():
+  """Returns number of items and total size in bytes."""
+  sizes = _get_array_sizes(_memory_cache)
+  return len(_memory_cache), sum(sizes)
+
+
+CACHE_DIR = os.path.join(tempfile.gettempdir(), 'downloads_cache')
+
+
+def get_disk_cache(path_or_url, max_cache_size_bytes, progress=None):
+  """Keeps cache below specified size by removing elements not last accessed."""
+  fname = os.path.basename(path_or_url)
+  path_hash = hashlib.md5(path_or_url.encode()).hexdigest() + '__' + fname
+  dst = os.path.join(CACHE_DIR, path_hash, fname)
+  if os.path.exists(dst):
+    return dst
+
+  os.makedirs(os.path.dirname(dst), exist_ok=True)
+  with timed(f'copying {path_or_url}'):
+    copy_file(path_or_url, dst, progress=progress)
+
+  atimes_sizes_paths = sorted([
+      (os.path.getatime(p), os.path.getsize(p), p)
+      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
+      if os.path.isfile(p)
+  ])
+  sz = sum(sz for _, sz, _ in atimes_sizes_paths)
+  logging.info('New disk cache size=%.1f MB', sz/1e6)
+
+  while sz > max_cache_size_bytes:
+    _, s, path = atimes_sizes_paths.pop(0)
+    if path == dst:
+      break
+    logging.info('Removing %s from disk cache (%.1f MB)', path, s/1e6)
+    os.unlink(path)
+    sz -= s
+
+  return dst
+
+
+def get_disk_cache_info():
+  """Returns number of items and total size in bytes."""
+  sizes = [
+      os.path.getsize(p)
+      for p in glob.glob(os.path.join(CACHE_DIR, '*', '*'))
+  ]
+  return len(sizes), sum(sizes)
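
The two cache layers compose the way `compute()` in `app.py` uses them: `get_disk_cache()` resolves a remote path to a local file, and `get_memory_cache()` keeps the loaded object in RAM, both with LRU-style eviction. A sketch using the plain `tqdm` module as the `progress` object (which the module docstring allows); the URL and loader below are hypothetical:

```python
import numpy as np
import tqdm

import gradio_helpers


def load_weights(path):
  # Hypothetical loader; returns an array so the RAM cache can measure
  # its size via `.nbytes`.
  with open(path, 'rb') as f:
    return np.frombuffer(f.read(), dtype=np.uint8)


local = gradio_helpers.get_disk_cache(
    'https://example.com/weights.npz',
    max_cache_size_bytes=20e9, progress=tqdm)
weights = gradio_helpers.get_memory_cache(
    'weights', lambda: load_weights(local),
    max_cache_size_bytes=10e9, progress=tqdm, estimated_secs=5)
```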
requirements.txt ADDED
@@ -0,0 +1,12 @@
+aqtp  # for flaxformer
+einops
+flax
+gradio
+jax
+jaxlib
+ml_collections
+numpy
+Pillow
+sentencepiece
+transformers  # for transformers.BertTokenizer
+tensorflow  # for tf.io.gfile