boris commited on
Commit
91d8a29
1 Parent(s): 1d51d0b

feat: cleanup

Browse files
Files changed (1) hide show
  1. dev/inference/wandb-backend.ipynb +32 -109
dev/inference/wandb-backend.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
@@ -26,51 +26,42 @@
26
  },
27
  {
28
  "cell_type": "code",
29
- "execution_count": 2,
30
- "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
31
  "metadata": {},
32
  "outputs": [],
33
  "source": [
34
- "run_ids = ['rjf3rycy']\n",
35
- "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
36
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
37
- "normalize_text = True\n",
38
- "latest_only = False # log only latest or all versions\n",
39
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
40
- "add_clip_32 = False"
41
  ]
42
  },
43
  {
44
  "cell_type": "code",
45
- "execution_count": 3,
46
- "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
47
  "metadata": {},
48
  "outputs": [],
49
  "source": [
50
- "run_ids = ['3kaut6e8']\n",
51
- "ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
52
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
53
- "normalize_text = False\n",
54
- "latest_only = True # log only latest or all versions\n",
55
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
56
- "add_clip_32 = True"
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
- "execution_count": 4,
62
  "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
63
  "metadata": {},
64
- "outputs": [
65
- {
66
- "name": "stderr",
67
- "output_type": "stream",
68
- "text": [
69
- "INFO:absl:Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: \n",
70
- "INFO:absl:Unable to initialize backend 'gpu': NOT_FOUND: Could not find registered platform with name: \"cuda\". Available platform names are: TPU Interpreter Host\n"
71
- ]
72
- }
73
- ],
74
  "source": [
75
  "batch_size = 8\n",
76
  "num_images = 128\n",
@@ -84,18 +75,10 @@
84
  },
85
  {
86
  "cell_type": "code",
87
- "execution_count": 5,
88
  "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
89
  "metadata": {},
90
- "outputs": [
91
- {
92
- "name": "stdout",
93
- "output_type": "stream",
94
- "text": [
95
- "Working with z of shape (1, 256, 16, 16) = 65536 dimensions.\n"
96
- ]
97
- }
98
- ],
99
  "source": [
100
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
101
  "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
@@ -111,7 +94,7 @@
111
  },
112
  {
113
  "cell_type": "code",
114
- "execution_count": 6,
115
  "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
116
  "metadata": {},
117
  "outputs": [],
@@ -134,29 +117,17 @@
134
  },
135
  {
136
  "cell_type": "code",
137
- "execution_count": 7,
138
  "id": "ebf4f7bf-2efa-46cc-b3f4-2d7a54f7b2cb",
139
  "metadata": {},
140
- "outputs": [
141
- {
142
- "data": {
143
- "text/plain": [
144
- "ShardedDeviceArray([4.6051702, 4.6051702, 4.6051702, 4.6051702, 4.6051702,\n",
145
- " 4.6051702, 4.6051702, 4.6051702], dtype=float32)"
146
- ]
147
- },
148
- "execution_count": 7,
149
- "metadata": {},
150
- "output_type": "execute_result"
151
- }
152
- ],
153
  "source": [
154
  "clip_params['logit_scale']"
155
  ]
156
  },
157
  {
158
  "cell_type": "code",
159
- "execution_count": 8,
160
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
161
  "metadata": {},
162
  "outputs": [],
@@ -172,7 +143,7 @@
172
  },
173
  {
174
  "cell_type": "code",
175
- "execution_count": 9,
176
  "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
177
  "metadata": {},
178
  "outputs": [],
@@ -189,7 +160,7 @@
189
  },
190
  {
191
  "cell_type": "code",
192
- "execution_count": 10,
193
  "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
194
  "metadata": {},
195
  "outputs": [],
@@ -202,7 +173,7 @@
202
  },
203
  {
204
  "cell_type": "code",
205
- "execution_count": 11,
206
  "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
207
  "metadata": {},
208
  "outputs": [],
@@ -218,7 +189,7 @@
218
  },
219
  {
220
  "cell_type": "code",
221
- "execution_count": 12,
222
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
223
  "metadata": {},
224
  "outputs": [],
@@ -241,7 +212,7 @@
241
  },
242
  {
243
  "cell_type": "code",
244
- "execution_count": 13,
245
  "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
246
  "metadata": {},
247
  "outputs": [],
@@ -252,19 +223,10 @@
252
  },
253
  {
254
  "cell_type": "code",
255
- "execution_count": 14,
256
  "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
257
  "metadata": {},
258
- "outputs": [
259
- {
260
- "ename": "SyntaxError",
261
- "evalue": "EOL while scanning string literal (1745443972.py, line 60)",
262
- "output_type": "error",
263
- "traceback": [
264
- "\u001b[0;36m File \u001b[0;32m\"/tmp/ipykernel_402605/1745443972.py\"\u001b[0;36m, line \u001b[0;32m60\u001b[0m\n\u001b[0;31m for i in tqdm(range(num_images // jax.device_count()), desc='Generating Images):\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m EOL while scanning string literal\n"
265
- ]
266
- }
267
- ],
268
  "source": [
269
  "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
270
  "last_inference_version = get_last_inference_version(run_id)\n",
@@ -281,8 +243,7 @@
281
  " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
282
  " \n",
283
  " if latest_only:\n",
284
- " pass\n",
285
- " #assert last_inference_version is None or version > last_inference_version\n",
286
  " else:\n",
287
  " if last_inference_version is None:\n",
288
  " # we should start from v0\n",
@@ -325,7 +286,7 @@
325
  "\n",
326
  " # generate images\n",
327
  " images = []\n",
328
- " for i in tqdm(range(num_images // jax.device_count()), desc='Generating Images):\n",
329
  " key, subkey = jax.random.split(key)\n",
330
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
331
  " encoded_images = encoded_images.sequences[..., 1:]\n",
@@ -386,36 +347,6 @@
386
  " run = None # ensure we don't log on this run"
387
  ]
388
  },
389
- {
390
- "cell_type": "code",
391
- "execution_count": null,
392
- "id": "fdcd09d6-079c-461a-a81a-d9e650d3b099",
393
- "metadata": {},
394
- "outputs": [],
395
- "source": [
396
- "p_clip32"
397
- ]
398
- },
399
- {
400
- "cell_type": "code",
401
- "execution_count": null,
402
- "id": "7d86ceee-c9ac-4860-abad-410cadd16c3c",
403
- "metadata": {},
404
- "outputs": [],
405
- "source": [
406
- "clip_inputs['attention_mask'].shape, clip_inputs['pixel_values'].shape"
407
- ]
408
- },
409
- {
410
- "cell_type": "code",
411
- "execution_count": null,
412
- "id": "fbba4858-da2d-4dd5-97b7-ce3ab4746f96",
413
- "metadata": {},
414
- "outputs": [],
415
- "source": [
416
- "clip_inputs['input_ids'].shape"
417
- ]
418
- },
419
  {
420
  "cell_type": "code",
421
  "execution_count": null,
@@ -428,14 +359,6 @@
428
  " for run in tqdm(runs):\n",
429
  " log_run(run)"
430
  ]
431
- },
432
- {
433
- "cell_type": "code",
434
- "execution_count": null,
435
- "id": "a7a5fdf5-3c6e-421b-96a8-5115f730328c",
436
- "metadata": {},
437
- "outputs": [],
438
- "source": []
439
  }
440
  ],
441
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
 
26
  },
27
  {
28
  "cell_type": "code",
29
+ "execution_count": null,
30
+ "id": "23e00271-941c-4e1b-b6a9-107a1b77324d",
31
  "metadata": {},
32
  "outputs": [],
33
  "source": [
34
+ "run_ids = ['3kaut6e8']\n",
35
+ "ENTITY, PROJECT = 'wandb', 'hf-flax-dalle-mini'\n",
36
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
37
+ "normalize_text = False\n",
38
+ "latest_only = True # log only latest or all versions\n",
39
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
40
+ "add_clip_32 = True"
41
  ]
42
  },
43
  {
44
  "cell_type": "code",
45
+ "execution_count": null,
46
+ "id": "92f4557c-fd7f-4edc-81c2-de0b0a10c270",
47
  "metadata": {},
48
  "outputs": [],
49
  "source": [
50
+ "run_ids = ['k76r0v39']\n",
51
+ "ENTITY, PROJECT = 'dalle-mini', 'dalle-mini' # used only for training run\n",
52
  "VQGAN_REPO, VQGAN_COMMIT_ID = 'dalle-mini/vqgan_imagenet_f16_16384', None\n",
53
+ "normalize_text = True\n",
54
+ "latest_only = True # log only latest or all versions\n",
55
  "suffix = '' # mainly for duplicate inference runs with a deleted version\n",
56
+ "add_clip_32 = False"
57
  ]
58
  },
59
  {
60
  "cell_type": "code",
61
+ "execution_count": null,
62
  "id": "93b2e24b-f0e5-4abe-a3ec-0aa834cc3bf3",
63
  "metadata": {},
64
+ "outputs": [],
 
 
 
 
 
 
 
 
 
65
  "source": [
66
  "batch_size = 8\n",
67
  "num_images = 128\n",
 
75
  },
76
  {
77
  "cell_type": "code",
78
+ "execution_count": null,
79
  "id": "c6a878fa-4bf5-4978-abb5-e235841d765b",
80
  "metadata": {},
81
+ "outputs": [],
 
 
 
 
 
 
 
 
82
  "source": [
83
  "vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)\n",
84
  "clip = FlaxCLIPModel.from_pretrained(\"openai/clip-vit-base-patch16\")\n",
 
94
  },
95
  {
96
  "cell_type": "code",
97
+ "execution_count": null,
98
  "id": "a500dd07-dbc3-477d-80d4-2b73a3b83ef3",
99
  "metadata": {},
100
  "outputs": [],
 
117
  },
118
  {
119
  "cell_type": "code",
120
+ "execution_count": null,
121
  "id": "ebf4f7bf-2efa-46cc-b3f4-2d7a54f7b2cb",
122
  "metadata": {},
123
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
124
  "source": [
125
  "clip_params['logit_scale']"
126
  ]
127
  },
128
  {
129
  "cell_type": "code",
130
+ "execution_count": null,
131
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
132
  "metadata": {},
133
  "outputs": [],
 
143
  },
144
  {
145
  "cell_type": "code",
146
+ "execution_count": null,
147
  "id": "f3e02d9d-4ee1-49e7-a7bc-4d8b139e9614",
148
  "metadata": {},
149
  "outputs": [],
 
160
  },
161
  {
162
  "cell_type": "code",
163
+ "execution_count": null,
164
  "id": "f0d7ed17-7abb-4a31-ab3c-a12b9039a570",
165
  "metadata": {},
166
  "outputs": [],
 
173
  },
174
  {
175
  "cell_type": "code",
176
+ "execution_count": null,
177
  "id": "7e784a43-626d-4e8d-9e47-a23775b2f35f",
178
  "metadata": {},
179
  "outputs": [],
 
189
  },
190
  {
191
  "cell_type": "code",
192
+ "execution_count": null,
193
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
194
  "metadata": {},
195
  "outputs": [],
 
212
  },
213
  {
214
  "cell_type": "code",
215
+ "execution_count": null,
216
  "id": "23b2444c-67a9-44d7-abd1-187ed83a9431",
217
  "metadata": {},
218
  "outputs": [],
 
223
  },
224
  {
225
  "cell_type": "code",
226
+ "execution_count": null,
227
  "id": "bba70f33-af8b-4eb3-9973-7be672301a0b",
228
  "metadata": {},
229
+ "outputs": [],
 
 
 
 
 
 
 
 
 
230
  "source": [
231
  "artifact_versions = get_artifact_versions(run_id, latest_only)\n",
232
  "last_inference_version = get_last_inference_version(run_id)\n",
 
243
  " columns = ['Caption'] + [f'Image {i+1}' for i in range(top_k)] + [f'Score {i+1}' for i in range(top_k)]\n",
244
  " \n",
245
  " if latest_only:\n",
246
+ " assert last_inference_version is None or version > last_inference_version\n",
 
247
  " else:\n",
248
  " if last_inference_version is None:\n",
249
  " # we should start from v0\n",
 
286
  "\n",
287
  " # generate images\n",
288
  " images = []\n",
289
+ " for i in tqdm(range(num_images // jax.device_count()), desc='Generating Images'):\n",
290
  " key, subkey = jax.random.split(key)\n",
291
  " encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params)\n",
292
  " encoded_images = encoded_images.sequences[..., 1:]\n",
 
347
  " run = None # ensure we don't log on this run"
348
  ]
349
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
  {
351
  "cell_type": "code",
352
  "execution_count": null,
 
359
  " for run in tqdm(runs):\n",
360
  " log_run(run)"
361
  ]
 
 
 
 
 
 
 
 
362
  }
363
  ],
364
  "metadata": {