boris commited on
Commit
b8bbe68
·
1 Parent(s): 378a628

feat: add functions

Browse files
Files changed (1) hide show
  1. dev/inference/wandb-backend.ipynb +85 -24
dev/inference/wandb-backend.ipynb CHANGED
@@ -2,13 +2,15 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
10
  "import csv\n",
11
  "import tempfile\n",
 
 
12
  "import wandb\n",
13
  "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
14
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
@@ -42,26 +44,82 @@
42
  },
43
  {
44
  "cell_type": "code",
45
- "execution_count": null,
46
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
47
  "metadata": {},
48
  "outputs": [],
49
  "source": [
50
  "with open('samples.csv', newline='', encoding='utf8') as f:\n",
51
- " reader = csv.reader(f)\n",
 
52
  " for row in reader:\n",
53
- " breakpoint()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  ]
55
  },
56
  {
57
  "cell_type": "code",
58
  "execution_count": null,
 
 
 
 
 
 
 
 
 
 
59
  "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
60
  "metadata": {},
61
  "outputs": [],
62
  "source": [
 
63
  "wandb_run = wandb_runs[0]\n",
64
- "api = wandb.Api()"
65
  ]
66
  },
67
  {
@@ -280,27 +338,30 @@
280
  },
281
  {
282
  "cell_type": "code",
283
- "execution_count": null,
284
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
285
  "metadata": {},
286
  "outputs": [],
287
- "source": []
288
- },
289
- {
290
- "cell_type": "code",
291
- "execution_count": null,
292
- "id": "43d2a99b-3501-4b30-b041-0fdeead12380",
293
- "metadata": {},
294
- "outputs": [],
295
- "source": []
296
- },
297
- {
298
- "cell_type": "code",
299
- "execution_count": null,
300
- "id": "06472541-75f1-44e5-841f-a4a26a0493e3",
301
- "metadata": {},
302
- "outputs": [],
303
- "source": []
 
 
 
304
  },
305
  {
306
  "cell_type": "code",
@@ -323,7 +384,7 @@
323
  {
324
  "cell_type": "code",
325
  "execution_count": null,
326
- "id": "b37c1714-d54b-479e-a9e8-740affc0de2c",
327
  "metadata": {},
328
  "outputs": [],
329
  "source": []
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 197,
6
  "id": "4ff2a984-b8b2-4a69-89cf-0d16da2393c8",
7
  "metadata": {},
8
  "outputs": [],
9
  "source": [
10
  "import csv\n",
11
  "import tempfile\n",
12
+ "from functools import partial\n",
13
+ "import jax\n",
14
  "import wandb\n",
15
  "from dalle_mini.model import CustomFlaxBartForConditionalGeneration\n",
16
  "from vqgan_jax.modeling_flax_vqgan import VQModel\n",
 
44
  },
45
  {
46
  "cell_type": "code",
47
+ "execution_count": 245,
48
  "id": "e57797ab-0b3a-4490-be58-03d8d1c23fe9",
49
  "metadata": {},
50
  "outputs": [],
51
  "source": [
52
  "with open('samples.csv', newline='', encoding='utf8') as f:\n",
53
+ " reader = csv.DictReader(f)\n",
54
+ " samples = []\n",
55
  " for row in reader:\n",
56
+ " samples.append(row)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": 246,
62
+ "id": "f75b2869-fc25-4f56-b937-e97bbb712ede",
63
+ "metadata": {},
64
+ "outputs": [
65
+ {
66
+ "data": {
67
+ "text/plain": [
68
+ "101"
69
+ ]
70
+ },
71
+ "execution_count": 246,
72
+ "metadata": {},
73
+ "output_type": "execute_result"
74
+ }
75
+ ],
76
+ "source": [
77
+ "len(samples)"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 248,
83
+ "id": "2ea0b166-a20c-4d78-bffb-b792ca512d17",
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "data": {
88
+ "text/plain": [
89
+ "104"
90
+ ]
91
+ },
92
+ "execution_count": 248,
93
+ "metadata": {},
94
+ "output_type": "execute_result"
95
+ }
96
+ ],
97
+ "source": [
98
+ "samples_to_add = ['empty'] * (-len(samples) % 8)\n",
99
+ "samples.extend(samples_to_add)\n",
100
+ "len(samples)"
101
  ]
102
  },
103
  {
104
  "cell_type": "code",
105
  "execution_count": null,
106
+ "id": "a2c629e9-1a82-40c6-a260-ca1780c19a2e",
107
+ "metadata": {},
108
+ "outputs": [],
109
+ "source": [
110
+ "api = wandb.Api()"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 204,
116
  "id": "3ffb1d09-bd1c-4f57-9ae5-3eda6f7d3a08",
117
  "metadata": {},
118
  "outputs": [],
119
  "source": [
120
+ "# TODO: iterate on runs\n",
121
  "wandb_run = wandb_runs[0]\n",
122
+ "functions_pmapped = False"
123
  ]
124
  },
125
  {
 
338
  },
339
  {
340
  "cell_type": "code",
341
+ "execution_count": 207,
342
  "id": "d1cc9993-1bfc-4ec6-a004-c056189c42ac",
343
  "metadata": {},
344
  "outputs": [],
345
+ "source": [
346
+ "# function to generate encoded images\n",
347
+ "# we should generate this function only once per run\n",
348
+ "if not functions_pmapped:\n",
349
+ " @partial(jax.pmap, axis_name=\"batch\")\n",
350
+ " def p_generate(tokenized_prompt, key, params):\n",
351
+ " return model.generate(\n",
352
+ " **tokenized_prompt,\n",
353
+ " do_sample=True,\n",
354
+ " num_beams=1,\n",
355
+ " prng_key=key,\n",
356
+ " params=params\n",
357
+ " )\n",
358
+ " \n",
359
+ " @partial(jax.pmap, axis_name=\"batch\")\n",
360
+ " def p_decode(indices, params):\n",
361
+ " return vqgan.decode_code(indices, params=params)\n",
362
+ " \n",
363
+ " functions_pmapped = False"
364
+ ]
365
  },
366
  {
367
  "cell_type": "code",
 
384
  {
385
  "cell_type": "code",
386
  "execution_count": null,
387
+ "id": "e79ac8f2-adc2-4a16-970c-dadcceadd566",
388
  "metadata": {},
389
  "outputs": [],
390
  "source": []