File size: 18,252 Bytes
bd28213 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 |
{
"cells": [
{
"cell_type": "markdown",
"id": "953903b9-5d23-40f1-8b77-d4aa692c7a75",
"metadata": {},
"source": [
"# Preamble"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "fea7b061-d527-4b01-a945-54124753640e",
"metadata": {},
"outputs": [],
"source": [
"from pathlib import Path\n",
"from tqdm.notebook import tqdm\n",
"\n",
"import math\n",
"import numpy as np\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import optax\n",
"import flax\n",
"from flax.training import train_state\n",
"from flax.training.common_utils import get_metrics, onehot, shard\n",
"from flax import jax_utils, traverse_util\n",
"\n",
"from datasets import load_dataset\n",
"from transformers import AutoTokenizer, AutoConfig, GPT2Tokenizer"
]
},
{
"cell_type": "markdown",
"id": "4fd5179e-372c-4222-a000-5ab1567c05b8",
"metadata": {},
"source": [
"# Set up model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "53883b92-02c6-4601-87f4-6b7ab227cb70",
"metadata": {},
"outputs": [],
"source": [
"model_config = 'gpt2-large'\n",
"model_dir = model_config + f\"-finetuned\"\n",
"Path(model_dir).mkdir(parents=True, exist_ok=True)\n",
"config = AutoConfig.from_pretrained('gpt2-large')\n",
"config.save_pretrained(f\"{model_dir}\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "f03ac4c5-77c4-46d4-af78-774077b60b8b",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:absl:Starting the local TPU driver.\n",
"INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://\n",
"INFO:absl:Unable to initialize backend 'gpu': Not found: Could not find registered platform with name: \"cuda\". Available platform names are: Interpreter TPU Host\n",
"tcmalloc: large alloc 3096141824 bytes == 0x8c128000 @ 0x7f216d775680 0x7f216d796824 0x5f7b11 0x648631 0x5c38e6 0x4f30e6 0x64ee88 0x505653 0x56acb6 0x568d9a 0x50b868 0x56fb87 0x568d9a 0x68cdc7 0x5ff5d4 0x5c3cb0 0x56aadf 0x501148 0x56c422 0x501148 0x56c422 0x501148 0x504d56 0x56acb6 0x5f5956 0x56aadf 0x5f5956 0x56acb6 0x568d9a 0x5f5b33 0x50b7f8\n"
]
}
],
"source": [
"from transformers import FlaxGPT2LMHeadModel\n",
"model = FlaxGPT2LMHeadModel.from_pretrained('gpt2-large')#, dtype=jnp.dtype(\"bfloat16\"))"
]
},
{
"cell_type": "markdown",
"id": "dec9146f-e942-49d6-9c83-7e832368eb07",
"metadata": {},
"source": [
"# Load preprocessed data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "571809b9-e4e4-468c-96fd-05d82976fc75",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:datasets.builder:Using custom data configuration default-f00827eba4e2a675\n",
"WARNING:datasets.builder:Reusing dataset text (/home/user/.cache/huggingface/datasets/text/default-f00827eba4e2a675/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)\n"
]
}
],
"source": [
"dataset = load_dataset('text', \n",
" data_files={'train': \"project-data/raw_data/layout_prompts_train.txt\",\n",
" 'test': \"project-data/raw_data/layout_prompts_valid.txt\"})\n",
"\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large', use_fast=True)\n",
"\n",
"lm_dataset = dataset.load_from_disk('project-data/gpt2_processed/grouped_256')"
]
},
{
"cell_type": "markdown",
"id": "614b535c-b850-42b3-a260-8913cd6bf974",
"metadata": {},
"source": [
"# Training options"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "858b317b-fbce-4afa-9257-4930b66a8337",
"metadata": {},
"outputs": [],
"source": [
"per_device_batch_size = 1\n",
"num_epochs = 3\n",
"training_seed=42\n",
"learning_rate=5e-5\n",
"total_batch_size = per_device_batch_size * jax.device_count()\n",
"num_train_steps = len(lm_dataset[\"train\"]) // total_batch_size * num_epochs\n",
"transition = int(num_train_steps * 0.1)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "225e1a08-c708-4030-bab1-eefa5cb53615",
"metadata": {},
"outputs": [],
"source": [
"def decay_mask_fn(params):\n",
" flat_params = traverse_util.flatten_dict(params)\n",
" flat_mask = {\n",
" path: (path[-1] != \"bias\" and path[-2:] not in [(\"ln_1\", \"scale\"), (\"ln_2\", \"scale\"), (\"ln_f\", \"scale\")])\n",
" for path in flat_params\n",
" }\n",
" return traverse_util.unflatten_dict(flat_mask)\n",
"\n",
"linear_decay_lr_schedule_fn = optax.linear_schedule(init_value=learning_rate, end_value=5e-06, transition_steps=num_train_steps-transition, transition_begin=transition)\n",
"adamw = optax.adamw(\n",
" learning_rate=linear_decay_lr_schedule_fn, \n",
" b1=0.9, \n",
" b2=0.98, \n",
" eps=1e-8, \n",
" weight_decay=0.1,\n",
" mask=decay_mask_fn)\n",
"\n",
"state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)"
]
},
{
"cell_type": "markdown",
"id": "35890bfe-c226-42f1-9b21-fbe2546ef769",
"metadata": {},
"source": [
"# Train"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cd07325b-b785-495c-bb51-8249902b96ce",
"metadata": {},
"outputs": [],
"source": [
"def data_loader(rng, dataset, batch_size, shuffle=False):\n",
" steps_per_epoch = len(dataset) // batch_size\n",
"\n",
" if shuffle:\n",
" batch_idx = jax.random.permutation(rng, len(dataset))\n",
" else:\n",
" batch_idx = jnp.arange(len(dataset))\n",
"\n",
" batch_idx = batch_idx[: steps_per_epoch * batch_size] # Skip incomplete batch.\n",
" batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))\n",
"\n",
" for idx in batch_idx:\n",
" batch = dataset[idx]\n",
" batch = {k: jnp.array(v) for k, v in batch.items()}\n",
"\n",
" batch = shard(batch)\n",
"\n",
" yield batch\n",
" \n",
"def train_step(state, batch, dropout_rng):\n",
" dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)\n",
"\n",
" def loss_fn(params):\n",
" labels = batch.pop(\"labels\")\n",
" logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]\n",
" \n",
" loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()\n",
" return loss\n",
"\n",
" grad_fn = jax.value_and_grad(loss_fn)\n",
" loss, grad = grad_fn(state.params)\n",
" grad = jax.lax.pmean(grad, \"batch\")\n",
" new_state = state.apply_gradients(grads=grad)\n",
"\n",
" metrics = jax.lax.pmean(\n",
" {\"loss\": loss, \"learning_rate\": linear_decay_lr_schedule_fn(state.step)}, axis_name=\"batch\"\n",
" )\n",
"\n",
" return new_state, metrics, new_dropout_rng\n",
"\n",
"def eval_step(params, batch):\n",
" labels = batch.pop(\"labels\")\n",
"\n",
" logits = model(**batch, params=params, train=False)[0]\n",
"\n",
" loss = optax.softmax_cross_entropy(logits[..., :-1, :], onehot(labels[..., 1:], logits.shape[-1])).mean()\n",
"\n",
" # summarize metrics\n",
" metrics = {\"loss\": loss, \"perplexity\": jnp.exp(loss)}\n",
" metrics = jax.lax.pmean(metrics, axis_name=\"batch\")\n",
" return metrics"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "159e462e-0dc3-49e1-8e5f-265b41044fdf",
"metadata": {},
"outputs": [],
"source": [
"parallel_train_step = jax.pmap(train_step, \"batch\")\n",
"parallel_eval_step = jax.pmap(eval_step, \"batch\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "f3888e1e-6bd6-415d-a6a9-892abdcbf34c",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/user/neo/lib/python3.8/site-packages/jax/lib/xla_bridge.py:382: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.\n",
" warnings.warn(\n",
"/home/user/neo/lib/python3.8/site-packages/jax/lib/xla_bridge.py:369: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.\n",
" warnings.warn(\n"
]
}
],
"source": [
"state = flax.jax_utils.replicate(state)\n",
"rng = jax.random.PRNGKey(training_seed)\n",
"dropout_rngs = jax.random.split(rng, jax.local_device_count())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "7b8c6e1a-ab32-481a-ad6b-885960119380",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b5170afdbf364e0a923961574abbb39b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Epoch ...: 0%| | 0/3 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "dd5a2ee9554742f49975ba19a820b10d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Training...: 0%| | 0/396670 [00:00<?, ?it/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"2021-07-06 10:40:07.010317: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 0 failed: Resource exhausted: Failed to allocate request for 18.75MiB (19660800B) on device ordinal 0\n",
"2021-07-06 10:40:07.019948: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 4 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 6\n",
"2021-07-06 10:40:07.020278: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 1 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 1\n",
"2021-07-06 10:40:07.020375: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 3 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 3\n",
"2021-07-06 10:40:07.020546: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 7 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 5\n",
"2021-07-06 10:40:07.020854: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 6 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 4\n",
"2021-07-06 10:40:07.021840: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 2 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 2\n",
"2021-07-06 10:40:07.021971: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2036] Execution of replica 5 failed: Resource exhausted: Failed to allocate request for 6.25MiB (6553600B) on device ordinal 7\n"
]
},
{
"ename": "RuntimeError",
"evalue": "Resource exhausted: Failed to allocate request for 18.75MiB (19660800B) on device ordinal 0: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well).",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_839094/3626790602.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmodel_inputs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrain_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;31m# Model forward\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrain_metric\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mparallel_train_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdropout_rngs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 10\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mprogress_bar_train\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
" \u001b[0;31m[... skipping hidden 7 frame]\u001b[0m\n",
"\u001b[0;32m~/neo/lib/python3.8/site-packages/jax/interpreters/pxla.py\u001b[0m in \u001b[0;36mexecute_replicated\u001b[0;34m(compiled, backend, in_handler, out_handler, *args)\u001b[0m\n\u001b[1;32m 1150\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mexecute_replicated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcompiled\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1151\u001b[0m \u001b[0minput_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0min_handler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1152\u001b[0;31m \u001b[0mout_bufs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompiled\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexecute_sharded_on_local_devices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_bufs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1153\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mxla\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mneeds_check_special\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1154\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mbufs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mout_bufs\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: Resource exhausted: Failed to allocate request for 18.75MiB (19660800B) on device ordinal 0: while running replica 0 and partition 0 of a replicated computation (other replicas may have failed as well)."
]
}
],
"source": [
"for epoch in tqdm(range(1, num_epochs + 1), desc=f\"Epoch ...\", position=0, leave=True):\n",
" rng, input_rng = jax.random.split(rng)\n",
"\n",
" # -- Train --\n",
" train_loader = data_loader(input_rng, lm_dataset[\"train\"], total_batch_size, shuffle=True)\n",
" with tqdm(total=len(lm_dataset[\"train\"]) // total_batch_size, desc=\"Training...\", leave=False) as progress_bar_train:\n",
" for model_inputs in train_loader:\n",
" # Model forward\n",
" state, train_metric, dropout_rngs = parallel_train_step(state, model_inputs, dropout_rngs)\n",
"\n",
" progress_bar_train.update(1)\n",
"\n",
" progress_bar_train.write(\n",
" f\"Train... ({epoch}/{num_epochs} | Loss: {round(train_metric['loss'].mean(), 3)}, Learning Rate: {round(train_metric['learning_rate'].mean(), 6)})\"\n",
" )\n",
"\n",
" # -- Eval --\n",
" eval_loader = data_loader(input_rng, lm_dataset[\"test\"], total_batch_size)\n",
" eval_metrics = []\n",
" \n",
" with tqdm(total=len(lm_dataset[\"test\"]) // total_batch_size, desc=\"Evaluation...\", leave=False) as progress_bar_eval:\n",
" for model_inputs in eval_loader:\n",
" # Model forward\n",
" eval_metric = parallel_eval_step(state.params, model_inputs)\n",
" eval_metrics.append(eval_metric)\n",
"\n",
" progress_bar_eval.update(1)\n",
" \n",
" eval_metrics = get_metrics(eval_metrics)\n",
" eval_metrics = jax.tree_map(jnp.mean, eval_metrics)\n",
" progress_bar_eval.write(\n",
" f\"Eval... ({epoch}/{num_epochs} | Loss: {eval_metrics['loss']} | Perplexity: {eval_metrics['perplexity']})\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f79fe81-821a-4004-9e58-bebc933d5942",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|