boris committed
Commit 25862e8
1 Parent(s): 89cf9ea

fix: style

tools/inference/inference_pipeline.ipynb CHANGED
@@ -70,15 +70,15 @@
  "# Model references\n",
  "\n",
  "# dalle-mini\n",
- "DALLE_MODEL = 'dalle-mini/dalle-mini/model-3bqwu04f:latest' # can be wandb artifact or 🤗 Hub or local folder\n",
+ "DALLE_MODEL = \"dalle-mini/dalle-mini/model-3bqwu04f:latest\" # can be wandb artifact or 🤗 Hub or local folder\n",
  "DALLE_COMMIT_ID = None # used only with 🤗 hub\n",
  "\n",
  "# VQGAN model\n",
- "VQGAN_REPO = 'dalle-mini/vqgan_imagenet_f16_16384'\n",
- "VQGAN_COMMIT_ID = 'e93a26e7707683d349bf5d5c41c5b0ef69b677a9'\n",
+ "VQGAN_REPO = \"dalle-mini/vqgan_imagenet_f16_16384\"\n",
+ "VQGAN_COMMIT_ID = \"e93a26e7707683d349bf5d5c41c5b0ef69b677a9\"\n",
  "\n",
  "# CLIP model\n",
- "CLIP_REPO = 'openai/clip-vit-base-patch16'\n",
+ "CLIP_REPO = \"openai/clip-vit-base-patch16\"\n",
  "CLIP_COMMIT_ID = None"
  ]
 },
@@ -121,18 +121,28 @@
  "import wandb\n",
  "\n",
  "# Load dalle-mini\n",
- "if ':' in DALLE_MODEL:\n",
+ "if \":\" in DALLE_MODEL:\n",
  "    # wandb artifact\n",
  "    artifact = wandb.Api().artifact(DALLE_MODEL)\n",
  "    # we only download required files (no need for opt_state which is large)\n",
- "    model_files = ['config.json', 'flax_model.msgpack', 'merges.txt', 'special_tokens_map.json', 'tokenizer.json', 'tokenizer_config.json', 'vocab.json']\n",
+ "    model_files = [\n",
+ "        \"config.json\",\n",
+ "        \"flax_model.msgpack\",\n",
+ "        \"merges.txt\",\n",
+ "        \"special_tokens_map.json\",\n",
+ "        \"tokenizer.json\",\n",
+ "        \"tokenizer_config.json\",\n",
+ "        \"vocab.json\",\n",
+ "    ]\n",
  "    for f in model_files:\n",
- "        artifact.get_path(f).download('model')\n",
- "    model = DalleBart.from_pretrained('model', dtype=dtype, abstract_init=True)\n",
- "    tokenizer = AutoTokenizer.from_pretrained('model')\n",
+ "        artifact.get_path(f).download(\"model\")\n",
+ "    model = DalleBart.from_pretrained(\"model\", dtype=dtype, abstract_init=True)\n",
+ "    tokenizer = AutoTokenizer.from_pretrained(\"model\")\n",
  "else:\n",
  "    # local folder or 🤗 Hub\n",
- "    model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True)\n",
+ "    model = DalleBart.from_pretrained(\n",
+ "        DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=dtype, abstract_init=True\n",
+ "    )\n",
  "    tokenizer = AutoTokenizer.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)\n",
  "\n",
  "# Load VQGAN\n",
@@ -191,7 +201,7 @@
  "from functools import partial\n",
  "\n",
  "# model inference\n",
- "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3,4))\n",
+ "@partial(jax.pmap, axis_name=\"batch\", static_broadcasted_argnums=(3, 4))\n",
  "def p_generate(tokenized_prompt, key, params, top_k, top_p):\n",
  "    return model.generate(\n",
  "        **tokenized_prompt,\n",
@@ -203,11 +213,13 @@
  "        top_p=top_p\n",
  "    )\n",
  "\n",
+ "\n",
  "# decode images\n",
  "@partial(jax.pmap, axis_name=\"batch\")\n",
  "def p_decode(indices, params):\n",
  "    return vqgan.decode_code(indices, params=params)\n",
  "\n",
+ "\n",
  "# score images\n",
  "@partial(jax.pmap, axis_name=\"batch\")\n",
  "def p_clip(inputs, params):\n",
@@ -235,7 +247,7 @@
  "import random\n",
  "\n",
  "# create a random key\n",
- "seed = random.randint(0, 2**32-1)\n",
+ "seed = random.randint(0, 2 ** 32 - 1)\n",
  "key = jax.random.PRNGKey(seed)"
  ]
 },
@@ -287,7 +299,7 @@
 },
 "outputs": [],
 "source": [
- "prompt = 'a red T-shirt'"
+ "prompt = \"a red T-shirt\""
 ]
 },
 {
@@ -323,7 +335,13 @@
  "repeated_prompts = [processed_prompt] * jax.device_count()\n",
  "\n",
  "# tokenize\n",
- "tokenized_prompt = tokenizer(repeated_prompts, return_tensors='jax', padding='max_length', truncation=True, max_length=128).data\n",
+ "tokenized_prompt = tokenizer(\n",
+ "    repeated_prompts,\n",
+ "    return_tensors=\"jax\",\n",
+ "    padding=\"max_length\",\n",
+ "    truncation=True,\n",
+ "    max_length=128,\n",
+ ").data\n",
  "tokenized_prompt"
  ]
 },
@@ -408,12 +426,14 @@
  "    # get a new key\n",
  "    key, subkey = jax.random.split(key)\n",
  "    # generate images\n",
- "    encoded_images = p_generate(tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p)\n",
+ "    encoded_images = p_generate(\n",
+ "        tokenized_prompt, shard_prng_key(subkey), model_params, gen_top_k, gen_top_p\n",
+ "    )\n",
  "    # remove BOS\n",
  "    encoded_images = encoded_images.sequences[..., 1:]\n",
  "    # decode images\n",
  "    decoded_images = p_decode(encoded_images, vqgan_params)\n",
- "    decoded_images = decoded_images.clip(0., 1.).reshape((-1, 256, 256, 3))\n",
+ "    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))\n",
  "    for img in decoded_images:\n",
  "        images.append(Image.fromarray(np.asarray(img * 255, dtype=np.uint8)))"
  ]
@@ -436,7 +456,14 @@
 "outputs": [],
 "source": [
  "# get clip scores\n",
- "clip_inputs = processor(text=[prompt] * jax.device_count(), images=images, return_tensors='np', padding='max_length', max_length=77, truncation=True).data\n",
+ "clip_inputs = processor(\n",
+ "    text=[prompt] * jax.device_count(),\n",
+ "    images=images,\n",
+ "    return_tensors=\"np\",\n",
+ "    padding=\"max_length\",\n",
+ "    max_length=77,\n",
+ "    truncation=True,\n",
+ ").data\n",
  "logits = p_clip(shard(clip_inputs), clip_params)\n",
  "logits = logits.squeeze().flatten()"
 ]
@@ -458,10 +485,10 @@
 },
 "outputs": [],
 "source": [
- "print(f'Prompt: {prompt}\\n')\n",
+ "print(f\"Prompt: {prompt}\\n\")\n",
  "for idx in logits.argsort()[::-1]:\n",
  "    display(images[idx])\n",
- "    print(f'Score: {logits[idx]:.2f}\\n')"
+ "    print(f\"Score: {logits[idx]:.2f}\\n\")"
 ]
 }
],
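A side note on the hunk that rewrites static_broadcasted_argnums=(3,4) as (3, 4): the tuple lists the positional indices of p_generate's arguments (top_k and top_p) that jax.pmap should treat as static Python values rather than sharded arrays, recompiling whenever they change. A minimal self-contained sketch of the same mechanism follows; the function and variable names here are illustrative, not taken from the notebook:

from functools import partial

import jax
import jax.numpy as jnp

# Illustrative only: argument index 1 (scale) is declared static, so its
# value is baked into the compiled program instead of being split across
# devices, just like top_k/top_p in p_generate above.
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(1,))
def p_scale(x, scale):
    return x * scale

n = jax.local_device_count()
x = jnp.ones((n, 4))  # leading axis must equal the device count
print(p_scale(x, 2.0))  # passing a different scale triggers a recompile

Because these indices refer to positional arguments, the notebook passes gen_top_k and gen_top_p positionally rather than by keyword.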
tools/train/train.py CHANGED
@@ -219,9 +219,7 @@ class TrainingArguments:
             "help": 'The optimizer to use. Can be "distributed_shampoo" (default), "adam" or "adafactor"'
         },
     )
-    weight_decay: float = field(
-        default=None, metadata={"help": "Weight decay."}
-    )
+    weight_decay: float = field(default=None, metadata={"help": "Weight decay."})
     beta1: float = field(
         default=0.9,
         metadata={"help": "Beta1 for Adam & Distributed Shampoo."},
@@ -237,13 +235,15 @@
         default=1.0, metadata={"help": "Max gradient norm for Adafactor."}
     )
     block_size: int = field(
-        default=1024, metadata={"help": "Chunked size for large layers with Distributed Shampoo."}
+        default=1024,
+        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
     )
     preconditioning_compute_steps: int = field(
         default=10, metadata={"help": "Number of steps to update preconditioner."}
     )
     skip_preconditioning_dim_size_gt: int = field(
-        default=4096, metadata={"help": "Max size for preconditioning with Distributed Shampoo."}
+        default=4096,
+        metadata={"help": "Max size for preconditioning with Distributed Shampoo."},
     )
     optim_quantized: bool = field(
         default=False,
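For context on the field(default=..., metadata={"help": ...}) pattern being reflowed above: dataclass arguments in this style are typically parsed with 🤗 Transformers' HfArgumentParser, which turns each field into a CLI flag and each "help" entry into its description. A minimal sketch under that assumption; DemoArguments is a hypothetical stand-in for TrainingArguments:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    # Same layout black produces above: default and metadata on their own
    # lines, with trailing commas.
    block_size: int = field(
        default=1024,
        metadata={"help": "Chunked size for large layers with Distributed Shampoo."},
    )


parser = HfArgumentParser(DemoArguments)
(args,) = parser.parse_args_into_dataclasses(["--block_size", "2048"])
print(args.block_size)  # -> 2048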