{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "f41486ad", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "NVIDIA A100-PCIE-40GB\n" ] } ], "source": [ "# step 0. Preliminary\n", "import torch\n", "# check that cuda doesn't crash on us\n", "print(torch.cuda.get_device_name())\n", "# check that transformers installed\n", "import transformers" ] }, { "cell_type": "code", "execution_count": 2, "id": "ffd19cfb", "metadata": {}, "outputs": [], "source": [ "EPOCHS=3" ] }, { "cell_type": "code", "execution_count": 3, "id": "3a91ef1f", "metadata": {}, "outputs": [], "source": [ "# Step 1. Preparing the training\n", "# First ensure that required files are here\n", "from pathlib import Path\n", "assert Path(\"TinyStoriesV2-GPT4-train.txt\").exists()\n", "assert Path(\"TinyStoriesV2-GPT4-valid.txt\").exists()" ] }, { "cell_type": "code", "execution_count": 4, "id": "56b046d5", "metadata": {}, "outputs": [], "source": [ "# Then prepare directories\n", "Path(\"chunks.txt/train\").mkdir(parents=True, exist_ok=True)\n", "Path(\"chunks.tensors/train\").mkdir(parents=True, exist_ok=True)\n", "Path(\"chunks.txt/valid\").mkdir(parents=True, exist_ok=True)\n", "Path(\"chunks.tensors/valid\").mkdir(parents=True, exist_ok=True)" ] }, { "cell_type": "code", "execution_count": 5, "id": "1bddb2ee", "metadata": {}, "outputs": [], "source": [ "# Then prepare method to split one text to several\n", "from multiprocessing.pool import Pool\n", "from tqdm.contrib.concurrent import process_map\n", "import os\n", "_chunk_me = None\n", "def extract_chunk(chunk):\n", " split, i, chunk_from, chunk_to = chunk\n", " chunk = _chunk_me[chunk_from:chunk_to].strip() \n", " name = f\"chunks.txt/{split}/chunk-{i+1}.txt\"\n", " with open(name, \"w\") as f:\n", " f.write(chunk)\n", " return name\n", "\n", "def split_to_text_chunks(split:str, chunk_size = 16*1024*1024, max_workers=None):\n", " global _chunk_me #text is too chunky to pass as argument. 
"    global _chunk_me  # the text is too chunky to pass as an argument; store it in a global so fork() can take care of it\n", "    print(f\"reading {split}\")\n", "    text = _chunk_me = Path(f\"./TinyStoriesV2-GPT4-{split}.txt\").read_text()\n", "    offsets = []\n", "    delimiter = \"<|endoftext|>\"\n", "    i = 0\n", "    while i < len(text):\n", "        offsets.append(i)\n", "        i += chunk_size\n", "        i = text.find(delimiter, i)\n", "        if i < 0:\n", "            break\n", "        i += len(delimiter)\n", "    offsets.append(len(text))\n", "    chunks = [(split, i, start, end) for (i, (start, end)) in enumerate(zip(offsets[:-1], offsets[1:]))]\n", "\n", "    print(\"writing\")\n", "    process_map(extract_chunk, chunks, max_workers=max_workers)\n", "" ] }, { "cell_type": "code", "execution_count": 7, "id": "e60017ee", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Assuming split has finished already\n" ] } ], "source": [ "# Prepare text chunks of the train split\n", "if not Path(\"chunks.txt/train/chunk-133.txt\").exists():\n", "    split_to_text_chunks(\"train\")\n", "else:\n", "    print(\"Assuming split has finished already\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "e9b7effe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Assuming split has finished already\n" ] } ], "source": [ "# Prepare text chunks of the valid split\n", "if not Path(\"chunks.txt/valid/chunk-2.txt\").exists():\n", "    split_to_text_chunks(\"valid\")\n", "else:\n", "    print(\"Assuming split has finished already\")" ] }, { "cell_type": "code", "execution_count": 10, "id": "b4706f24", "metadata": {}, "outputs": [], "source": [ "# Step 2. Prepare the OpenLLaMA tokenizer.\n", "# Needs to be done only once (TODO: add code to load the tokenizer?)\n", "from transformers import AutoTokenizer\n", "import os\n", "if not Path('tokenizer.json').exists():\n", "    try:\n", "        tokenizer = AutoTokenizer.from_pretrained(\"openlm-research/open_llama_3b\")\n", "        tokenizer.save_pretrained(\".\")\n", "    except Exception as e:\n", "        print(e)\n", "        os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"] = \"python\"\n", "        tokenizer = AutoTokenizer.from_pretrained(\"openlm-research/open_llama_3b\")\n", "        tokenizer.save_pretrained(\".\")\n", "        del os.environ[\"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION\"]\n", "tokenizer = AutoTokenizer.from_pretrained(\".\")" ] },
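{ "cell_type": "code", "execution_count": null, "id": "tokenizer-sanity-check", "metadata": {}, "outputs": [], "source": [ "# (Optional) Quick sanity check of the tokenizer -- added for illustration, not part of the original run.\n", "# The sample story below is made up; any short text would do.\n", "sample_story = \"Once upon a time there was a tiny robot. It loved to read stories.\"\n", "sample_ids = tokenizer(sample_story).input_ids\n", "print(len(sample_ids), \"tokens:\", sample_ids[:10], \"...\")\n", "# decoding should give the text back (modulo whitespace and special tokens)\n", "print(tokenizer.decode(sample_ids, skip_special_tokens=True))\n", "print(\"eos_token_id:\", tokenizer.eos_token_id)" ] },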
{ "cell_type": "code", "execution_count": 11, "id": "f9c935b0", "metadata": {}, "outputs": [], "source": [ "# Step 3. Preparing to tokenize each text chunk\n", "from tqdm.contrib.concurrent import process_map\n", "def tokenize_file(filename: Path):\n", "    text = filename.read_text()\n", "    stories = text.split(\"<|endoftext|>\")\n", "    result = []\n", "    while stories:\n", "        story = stories.pop(0).strip()\n", "        tokenized = tokenizer(story, max_length=None).input_ids\n", "        tokenized.append(tokenizer.eos_token_id)\n", "        result.append(torch.tensor(tokenized))\n", "    output_name = str(filename).replace(\".txt\", \".tensors\")\n", "    torch.save(result, output_name)\n", "\n", "def tokenize_split(split, max_workers=None):\n", "    to_process = list(Path(f\"chunks.txt/{split}\").glob(\"*\"))\n", "    process_map(tokenize_file, to_process, max_workers=max_workers)\n", "" ] }, { "cell_type": "code", "execution_count": 12, "id": "95257f12", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Assuming train was tokenized already\n" ] } ], "source": [ "# processing train (this can take several minutes)\n", "if not Path(\"chunks.tensors/train/chunk-133.tensors\").exists():\n", "    tokenize_split(\"train\")\n", "else:\n", "    print(\"Assuming train was tokenized already\")" ] }, { "cell_type": "code", "execution_count": 13, "id": "bbbe4599", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Assuming valid was tokenized already\n" ] } ], "source": [ "# processing valid (this can take about a minute)\n", "if not Path(\"chunks.tensors/valid/chunk-2.tensors\").exists():\n", "    tokenize_split(\"valid\")\n", "else:\n", "    print(\"Assuming valid was tokenized already\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "a31a4aa7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Resetting [PAD] to [EOS]\n" ] } ], "source": [ "# Step 4. Training.\n", "# Step 4.1. Preparing the tokenizer and setting the pad token if it is not set (it is not)\n", "tokenizer = AutoTokenizer.from_pretrained(\".\")\n", "if not tokenizer.pad_token_id:\n", "    tokenizer.pad_token_id = tokenizer.eos_token_id\n", "    print(\"Resetting [PAD] to [EOS]\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "f677c9c0", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Step 4.2. Preparing the model\n", "from transformers.models.llama.modeling_llama import LlamaConfig, LlamaForCausalLM\n", "\n", "tiny_llama = LlamaConfig(\n", "    hidden_size=64,\n", "    vocab_size=tokenizer.vocab_size,\n", "    intermediate_size=256,\n", "    num_attention_heads=16,\n", "    num_hidden_layers=8)\n", "\n", "torch.manual_seed(11010)\n", "torch.cuda.manual_seed(11010)\n", "model = LlamaForCausalLM(tiny_llama).cuda().bfloat16()" ] },
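{ "cell_type": "code", "execution_count": null, "id": "model-size-check", "metadata": {}, "outputs": [], "source": [ "# (Optional) Added for illustration, not part of the original run:\n", "# report how many parameters this tiny config actually gives us.\n", "n_params = sum(p.numel() for p in model.parameters())\n", "print(f\"{n_params/1e6:.2f}M parameters\")" ] },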
{ "cell_type": "code", "execution_count": 16, "id": "aad9620b", "metadata": {}, "outputs": [], "source": [ "import functools\n", "import torch.nn.functional as F\n", "from tqdm.contrib.concurrent import process_map\n", "from tqdm.auto import tqdm\n", "\n", "# Step 4.3. Preparing the dataset class\n", "def get_file_data_len(filename):\n", "    data = torch.load(filename)\n", "    return (filename, len(data))\n", "from torch.utils.data import Dataset\n", "\n", "CACHE_SIZE = 2000  # there are ~150 train chunk files; we can fit them all in memory, so let's do it\n", "\n", "class TinyDataset(Dataset):\n", "    def __init__(self, split: str, populate_cache=True):\n", "        print(f\"Reading dataset {split} data\")\n", "        self.file_lens = process_map(\n", "            get_file_data_len,\n", "            list(Path(f\"chunks.tensors/{split}\").glob(\"*\")))\n", "        self.file_lens.sort()\n", "        if populate_cache:\n", "            print(\"Populating a cache\")\n", "            for filename, _ in tqdm(self.file_lens):\n", "                self.load_tensor_file(filename)\n", "\n", "    @functools.lru_cache(maxsize=CACHE_SIZE)\n", "    def load_tensor_file(self, filename):\n", "        return torch.load(filename)\n", "\n", "    def __len__(self):\n", "        return sum(x[1] for x in self.file_lens)\n", "\n", "    def global_index_to_local(self, i):\n", "        for (file, length) in self.file_lens:\n", "            if i < length:\n", "                return (file, i)\n", "            i -= length\n", "        raise IndexError(f\"{i} is out-of-bounds, have {len(self)} samples\")\n", "\n", "    def __getitem__(self, index):\n", "        if torch.is_tensor(index):\n", "            index = index.tolist()\n", "        if isinstance(index, int):\n", "            filename, local_index = self.global_index_to_local(index)\n", "            tensors = self.load_tensor_file(filename)\n", "            return {\n", "                'input_ids': tensors[local_index]\n", "            }\n", "        if isinstance(index, list):\n", "            data = []\n", "            indices = index\n", "            for index in indices:\n", "                filename, local_index = self.global_index_to_local(index)\n", "                tensors = self.load_tensor_file(filename)\n", "                data.append(tensors[local_index])\n", "\n", "            return {'input_ids': data}\n", "\n", "        raise TypeError(f'Invalid index type {type(index)}')\n", "\n", "def batch_collate(data: list[dict]):\n", "    max_len = max(len(datum[\"input_ids\"]) for datum in data)\n", "    inputs = []\n", "    attentions = []\n", "    for row in data:\n", "        input_ids = row[\"input_ids\"]\n", "        attention_mask = torch.ones_like(input_ids)\n", "        attention_mask[-1] = 0  # don't care about EOS\n", "        # Manual padding\n", "        to_pad = max_len - len(input_ids)\n", "        is_left_pad = tokenizer.padding_side == \"left\"\n", "        padding = (is_left_pad * to_pad, (1 - is_left_pad) * to_pad)\n", "        input_ids = F.pad(input_ids, padding, value=tokenizer.pad_token_id)\n", "        attention_mask = F.pad(attention_mask, padding, value=0)\n", "        inputs.append(input_ids)\n", "        attentions.append(attention_mask)\n", "\n", "    attention_masks = torch.stack(attentions)\n", "    input_ids = torch.stack(inputs)\n", "    labels = input_ids.clone()\n", "\n", "    # disable prediction of the padding\n", "    labels[attention_masks == 0] = -100\n", "    # enable prediction of an actual EOS\n", "    labels[:, -1] = tokenizer.eos_token_id\n", "\n", "    return {\n", "        'input_ids': input_ids,\n", "        'attention_mask': attention_masks,\n", "        'labels': labels\n", "    }\n", "\n", "def get_max_story_length(ds):\n", "    return max(file_len[1] for file_len in ds.file_lens)\n" ] },
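{ "cell_type": "code", "execution_count": null, "id": "collate-demo", "metadata": {}, "outputs": [], "source": [ "# (Optional) Added for illustration, not part of the original run:\n", "# feed batch_collate two fake \"stories\" of different lengths (the token ids are\n", "# arbitrary placeholders) to see the padding, attention_mask and label masking.\n", "fake_batch = [\n", "    {\"input_ids\": torch.tensor([5, 6, 7, tokenizer.eos_token_id])},\n", "    {\"input_ids\": torch.tensor([8, 9, tokenizer.eos_token_id])},\n", "]\n", "collated = batch_collate(fake_batch)\n", "for key, value in collated.items():\n", "    print(key)\n", "    print(value)" ] },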
{ "cell_type": "code", "execution_count": 17, "id": "2e828afe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Reading dataset train data\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8ca542afc1694073af6dcf9ce5f7e13a", "version_major": 2, "version_minor": 0 }, "text/plain": [ "  0%|          | 0/133 [00:00<?, ?it/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "ename": "AssertionError", "evalue": "WARNING: split long stories", "output_type": "error", "traceback": [ "\u001b[0;31mAssertionError\u001b[0m                            Traceback (most recent call last)", "\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m get_max_story_length(train_ds) \u001b[38;5;241m<\u001b[39m\u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mmodel_max_length, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mWARNING: split long stories\u001b[39m\u001b[38;5;124m\"\u001b[39m\n", "\u001b[0;31mAssertionError\u001b[0m: WARNING: split long stories" ] } ], "source": [ "assert tokenizer.padding_side in [\"left\", \"right\"]\n", "train_ds = TinyDataset(\"train\")\n", "assert get_max_story_length(train_ds) <= tokenizer.model_max_length, \"WARNING: split long stories\"" ] }, { "cell_type": "code", "execution_count": 19, "id": "6412e7c5", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "torch.manual_seed(11010)\n", "torch.cuda.manual_seed(11010)\n", "train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, collate_fn=batch_collate)" ] },
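{ "cell_type": "code", "execution_count": null, "id": "batch-peek", "metadata": {}, "outputs": [], "source": [ "# (Optional) Added sanity check, not part of the original run:\n", "# pull a single batch from the DataLoader and confirm that the shapes line up.\n", "peek = next(iter(train_dl))\n", "for key, value in peek.items():\n", "    print(key, tuple(value.shape), value.dtype)" ] },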
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/ggg4/training-tiny-llama-preview" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/ggg4/training-tiny-llama-preview/runs/rilt4m6f" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# prepare wandb\n", "import wandb\n", "wandb.init(\n", " project=\"training-tiny-llama-preview\",\n", " config={\n", " \"architecture\": \"llama\",\n", " \"dataset\": \"tiny-stories\",\n", " \"epochs\": EPOCHS,\n", " } \n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "aed7b7a4", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 21, "id": "166a4a27", "metadata": {}, "outputs": [], "source": [ "from tqdm.auto import tqdm\n", "def save_imm(epoch, step, saved=[]):\n", " fname = f\"step-{epoch}-{step}.bin\"\n", " torch.save(model.state_dict(), f\"step-{epoch}-{step}.bin\")\n", " saved.append(fname)\n", " if len(saved) > 5:\n", " delete_me = saved.pop(0)\n", " Path(delete_me).unlink(missing_ok=True)\n", "\n", "def epoch_step(epoch, opt):\n", " for i, batch in enumerate(bar := tqdm(train_dl)):\n", " for k in batch:\n", " batch[k] = batch[k].to(device=model.lm_head.weight.device)\n", " \n", " n_batch, n_seq = batch[\"input_ids\"].shape\n", " if n_seq > tokenizer.model_max_length:\n", " assert tokenizer.padding_side == \"right\", \"Left-pad truncation only supported[as model should not see >2k token anyway]\"\n", " batch[\"input_ids\"] = batch[\"input_ids\"][:, -tokenizer.model_max_length]\n", " batch[\"labels\"] = batch[\"labels\"][:, -tokenizer.model_max_length]\n", " batch[\"attention_mask\"] = batch[\"attention_mask\"][:, -tokenizer.model_max_length]\n", " \n", " \n", " loss = model(**batch).loss\n", " loss.backward()\n", " opt.step()\n", " opt.zero_grad()\n", " bar.set_description(f'L:{loss.item():.4f}')\n", " wandb.log({\"loss\": loss.item()})\n", " if (i+1) % 100 == 0:\n", " save_imm(epoch, i+1)\n", " \n", " torch.save(model.state_dict(), f\"epoch-{epoch}.bin\")\n" ] }, { "cell_type": "code", "execution_count": 22, "id": "ec4943c7", "metadata": {}, "outputs": [], "source": [ "opt = torch.optim.AdamW(model.parameters(), fused=True)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "daae9020", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f7ab6fe3b99546f49acb0d43888b7ceb", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/169865 [00:00