{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "id": "dbsnrDKKVarI", "colab": { "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "4940bbc312654a479a2d006fb193a1dc", "cb166763984b4dc9ac578363ee7a68a0", "8980f01ea76b42bf9c187fa231ea2032", "10e961c20ac04aacbe65379bfe88ffde", "9843b879634248a6a75bc04927868ebf", "d1c98c8bd2584a2a951cc70c36424dde", "60e19650cb264161ae853de734324ed3", "9d4909c2f7894859b494b153da044720", "77c7c0ccfcd9446e8017dbe72b2852f2", "cc755a63cfa0414cba5369e1c29952d8", "a8ef20fdbb0f44e78cc5f9c4e6fdb23c", "821b194c386e483b9eb472e04dd61b59", "fe16b47bf80b4ce3ba4f4233c7994e0c", "de3a8b4206b54a5d86f556895a7a8261", "1335110c980b4d8894cf3d38c7020c20", "b31e53ab714f405fb38f1202b43b950f", "33d06740ad07403aa39b935896cf624d", "b814e9cbda7d4b9b8967d71718980994", "dba5a7d271ef429ebaed65d3d1773616", "59308a01d3034b42977b2df4cd9faacf", "9c9c7af2bb4a4b92a69987dd6a5ad09f", "53d6dfc9923c4c1b8866f7e03cbdca53", "e5bf5d3476654592a16213663e4a2bf8", "3fc3505da3664d95823ee2deb60237ea", "3484f174bd734930b367bc53bfa00d0d", "2372cce16b6a43a097bcb621a38e6e81", "01476e05861d463bb6155f078c250ab5", "d8da5a2dbae24c06881c5b6caa68398d", "ae3581adbc404f9fb28806779623ee7b", "80eeedade7124e25acc3b82cf2cd3913", "ef0ca9862e5642fc941f5ad986d7ac23", "342b62d5cd464efbab98c46f712ccaf0", "ba1b482addfb4af5906564b9e6a33cbe", "53b3d56551c74a928a6340744564e4b1", "d0246c0c4fbc4280892242c4bbe2d534", "e176c733436e445691c99519d4afae5d", "3c5318c4d6bd4c98876cfa24557fe04e", "f1766d28b0a3444f8c4aab56431341ba", "8185d6e8df4b42239632418dd73fa52c", "3e0786124dff4ad99fa1760d8652da77", "860f065014494f41940b580478f6edc9", "ce734f6e638548d19eb3e8424b7cfe39", "096689e22c204388983b8f9711afa836", "7c70cb93f3a3463192b8c33ddb53179f", "c77e344106cc4936a615aa2e29db011c", "9edb8202f5a146bb8c625529db891359", "edffa04c9dce4eea8f496f8090f26cbf", "393cfd777a3449798e4f3ed331c325fa", "56b1a5dcb4514f35a58e9bf1130a46ee", "09c26bf8735a431d9e5d867b4201c3cc", 
"1c860651d30a42489f1816db4a2edd90", "98a128a931f6483696126b3eb7ad7f80", "3f854219ee394259b2ae3198427a821c", "e94e3433d4a442938e218f45947b007e", "ab90b0e8d2804218ae3b29828404d0c8", "3cf559f2135144dcb9de4b7a1f4d0e0c", "65cf73e666844d1685943e1c7b9c202f", "9fac853bef854b79ae97ef061349441d", "8733866f7ac1438f9a70166647e17216", "42f8ee052c814e62bfad7048ff9521c2", "d33c1c4011e840e4983ec9562e37606d", "8d25a070cf1d4cbcb2a2a014df01dd2b", "f40104e4b3a544908ad4a3ed54e610aa", "2c51ed4764a34a998ff674f6202de391", "a97e5ef737d34460991a0827dae059af", "77cab8fdf36e4c26b995f919dbbfd3df", "a2bbf27484b8448b85a3814c28b6b0e0", "aaed9115254148a78ce2f4e23105260d", "c9c353fc75d641d883a2373f53f9b2f5", "43caad4161444864974cd05836d51b15", "7858e452142c4285a9f88ba20b91e851", "67b7376057fd46e89a2dcd295ff1682b", "f3a763184fe9468c8a89465fa3bae703", "603c8de331f24fb38c512c549a1f4770", "f297fa9f28344c1988b498064c9d779e", "4670a05116e84358b5505adfaace7cb0", "0c717191af2643a588ba4149a2f2c6e7", "5d1e9fa2170949c497c86105a041bab0", "1f13f740fcf64fe0ba0d510fce11e87f", "c337381b3b5246ad9034f71f7fa77d2f", "2247e86b596d47c79a8ea0febb316925", "84f9e148be4d464a8548a38910ca4141", "6d21eeb340414cd897ceb8043402cb2c", "75256bd58af04da7bb1e21e7783937a6", "b810ca727b7048cdbcf87214a1348581", "16f8cf2352da4849bda87a5e9970c46e", "c9e6c8f02a5c494aafa98432510dfc1d", "e59d30fc2eaf4f8fb07c77c1d9b95d77", "76220d3230244c909e16a0612a296f0a", "3ac15cd823044417b36dce730ab8f184", "d560df2b7b414049a4927041bc371337", "d2c1fb4d6d064ee7a1a6d168ebf3a8dd", "66e0730542614559a16b04e9e8974576", "c449d3e341c241efbca470080f9702ab", "fde0dc4e5c1e485d979c2635e88a7eca", "485c3146b9684f4cbf2b267f1afc13ac", "6087db2f7d7d4b2483a53ca40ca7dfc0", "5bba74bb6bd749d4a5752e35bbb2bfd8", "91c769901f784d3ba2d6f40a96cf176c", "09a68a426e16470c9765be71c5525926", "a48eb5d08cdb4bd68c732ec8764d029d", "7b968692a15a4590ba5b6f9e957b0005", "16517b6ea6d64b69abe72c76dbd38e70", "370c9bd31fd2495ca5e06f10e9889a11", "0a0a53ce7f4c4848877b40be2589fc27", 
"7dde8d78da5c4576ade994a0dad4b991", "81409b94afaf4666b79ca98b64b40b27", "a5b09b70a1d646cea7ff1143a784b349", "e2804b0d3b3d40c2bbef5e33e46e9ead", "6f8761d7de4e4b0f8c054f86399543aa" ] }, "outputId": "5808189b-e624-42d7-856f-bc3b0201fab9", "ExecuteTime": { "end_time": "2024-04-16T23:12:05.968918Z", "start_time": "2024-04-16T23:11:31.417421Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: datasets in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (2.18.0)\n", "Requirement already satisfied: wandb in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (0.16.6)\n", "Requirement already satisfied: accelerate in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (0.28.0)\n", "Requirement already satisfied: filelock in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (3.9.0)\n", "Requirement already satisfied: numpy>=1.17 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (1.23.5)\n", "Requirement already satisfied: pyarrow>=12.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (15.0.2)\n", "Requirement already satisfied: pyarrow-hotfix in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.6)\n", "Requirement already satisfied: dill<0.3.9,>=0.3.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.3.8)\n", "Requirement already satisfied: pandas in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (2.2.1)\n", "Requirement already satisfied: requests>=2.19.0 in 
c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.62.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (4.65.0)\n", "Requirement already satisfied: xxhash in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (3.4.1)\n", "Requirement already satisfied: multiprocess in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.70.16)\n", "Requirement already satisfied: fsspec<=2024.2.0,>=2023.1.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from fsspec[http]<=2024.2.0,>=2023.1.0->datasets) (2024.2.0)\n", "Requirement already satisfied: aiohttp in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (3.9.3)\n", "Requirement already satisfied: huggingface-hub>=0.19.4 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (0.21.4)\n", "Requirement already satisfied: packaging in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (24.0)\n", "Requirement already satisfied: pyyaml>=5.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from datasets) (6.0.1)\n", "Requirement already satisfied: Click!=8.0.0,>=7.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (8.1.7)\n", "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (3.1.43)\n", "Requirement already satisfied: psutil>=5.0.0 in 
c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (5.9.8)\n", "Requirement already satisfied: sentry-sdk>=1.0.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (1.45.0)\n", "Requirement already satisfied: docker-pycreds>=0.4.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (0.4.0)\n", "Requirement already satisfied: setproctitle in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (1.3.3)\n", "Requirement already satisfied: setuptools in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (69.2.0)\n", "Requirement already satisfied: appdirs>=1.4.3 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (1.4.4)\n", "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from wandb) (4.25.3)\n", "Requirement already satisfied: torch>=1.10.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from accelerate) (2.0.1+cu118)\n", "Requirement already satisfied: safetensors>=0.3.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from accelerate) (0.4.2)\n", "Requirement already satisfied: colorama in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from Click!=8.0.0,>=7.1->wandb) (0.4.6)\n", "Requirement already satisfied: six>=1.4.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", "Requirement already satisfied: aiosignal>=1.1.2 in 
c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (1.9.4)\n", "Requirement already satisfied: async-timeout<5.0,>=4.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from aiohttp->datasets) (4.0.3)\n", "Requirement already satisfied: gitdb<5,>=4.0.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.11)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from huggingface-hub>=0.19.4->datasets) (4.11.0)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (3.6)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages 
(from requests>=2.19.0->datasets) (2.0.2)\n", "Requirement already satisfied: certifi>=2017.4.17 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from requests>=2.19.0->datasets) (2024.2.2)\n", "Requirement already satisfied: sympy in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from torch>=1.10.0->accelerate) (1.12)\n", "Requirement already satisfied: networkx in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from torch>=1.10.0->accelerate) (3.2.1)\n", "Requirement already satisfied: jinja2 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from torch>=1.10.0->accelerate) (3.1.2)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from pandas->datasets) (2.9.0)\n", "Requirement already satisfied: pytz>=2020.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: tzdata>=2022.7 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from pandas->datasets) (2024.1)\n", "Requirement already satisfied: smmap<6,>=3.0.1 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n", "Requirement already satisfied: mpmath>=0.19 in c:\\users\\saad.naeem\\appdata\\local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ 
"wandb: Currently logged in as: saadnaeem-dev. Use `wandb login --relogin` to force relogin\n", "wandb: WARNING If you're specifying your api key in code, ensure this code is not shared publicly.\n", "wandb: WARNING Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n", "wandb: Appending key for api.wandb.ai to your netrc file: C:\\Users\\saad.naeem\\.netrc\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Token has not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n", "Token is valid (permission: write).\n", "Your token has been saved to C:\\Users\\saad.naeem\\.cache\\huggingface\\token\n", "Login successful\n" ] } ], "source": [ "# @title # 🌊 AutoBitnet\n", "\n", "# @markdown ---\n", "\n", "# @markdown ### ✨ Model Parameters\n", "\n", "MODEL_CONFIG = \"NousResearch/Nous-Hermes-llama-2-7b\" # @param {type:\"string\"}\n", "HEADS = 6 # @param {type: \"number\"}\n", "DIMENSIONS = 768 # @param {type: \"number\"}\n", "LAYERS = 6 # @param {type: \"number\"}\n", "INTERMEDIATE_SIZE= 1024 # @param {type: \"number\"}\n", "CONTEXT_LENGTH = 256 # @param {type: \"number\"}\n", "HUGGINGFACE_ID = \"saadnaeem\" # @param {type:\"string\"}\n", "NEW_MODEL = \"Llama2-70M-Cosmopedia-100k-Pretrained\" # @param {type:\"string\"}\n", "WANDB_TOKEN=''\n", "HF_TOKEN=''\n", "\n", "# @markdown ---\n", "\n", "# @markdown ### 💥 Training Parameters\n", "\n", "DATASET = \"abideen/Cosmopedia-100k-pretrain\" # @param {type:\"string\"}\n", "BATCH_SIZE = 32 # @param {type:\"number\"}\n", "LEARNING_RATE = 1.5e-4 # @param {type:\"number\"}\n", "EPOCHS = 1 # @param {type:\"number\"}\n", "!pip install datasets wandb accelerate\n", "from torch import nn\n", "from transformers.models.llama.modeling_llama import *\n", "from transformers import (AutoTokenizer, AutoConfig, LlamaForCausalLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments, AutoModel)\n", "from 
datasets import load_dataset\n", "from huggingface_hub import login\n", "import wandb\n", "# wandb.ai/saadnaeem-dev\n", "\n", "from huggingface_hub import create_repo, HfApi\n",
"import torch.nn.functional as F\n",
"from transformers.models.llama.modeling_llama import (\n",
"    LlamaAttention,\n",
"    LlamaDecoderLayer,\n",
"    LlamaMLP,\n",
"    LlamaRMSNorm,\n",
")\n",
"\n",
"def activation_quant(x):\n",
"    \"\"\"Fake-quantize activations to 8 bits, one scale per row of the last dimension.\n",
"\n",
"    Each row is scaled so its largest magnitude maps to 127, rounded, clipped to\n",
"    [-128, 127], and scaled back. The output has the same shape as ``x``.\n",
"    \"\"\"\n",
"    # The floor on the divisor keeps all-zero rows from dividing by zero.\n",
"    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)\n",
"    y = (x * scale).round().clamp_(-128, 127) / scale\n",
"    return y\n",
"\n",
"def weight_quant(w):\n",
"    \"\"\"Fake-quantize a weight tensor to -1/0/+1 levels rescaled by its mean magnitude.\n",
"\n",
"    A single scale (mean absolute value of the whole tensor) is used; the output\n",
"    has the same shape as ``w``.\n",
"    \"\"\"\n",
"    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)\n",
"    u = (w * scale).round().clamp_(-1, 1) / scale\n",
"    return u\n",
"\n",
"class BitLinear(nn.Linear):\n",
"    \"\"\"Linear layer that fake-quantizes its input and its weight on every call.\n",
"\n",
"    The input is RMS-normalized and passed through ``activation_quant``; the weight\n",
"    is passed through ``weight_quant``. Both use the detach() straight-through\n",
"    trick, so the forward pass sees quantized values while gradients reach the\n",
"    full-precision tensors.\n",
"    \"\"\"\n",
"\n",
"    # Same epsilon that LlamaRMSNorm uses when none is given.\n",
"    RMS_EPS = 1e-6\n",
"\n",
"    def forward(self, x):\n",
"        w = self.weight  # a weight tensor with shape [d, k]\n",
"        x = x.to(w.device)\n",
"        # Parameter-free RMS norm computed with tensor ops (float32 internally,\n",
"        # result cast back to the input dtype). Equivalent to a LlamaRMSNorm whose\n",
"        # weight is all ones, without allocating a module on every call.\n",
"        x_fp32 = x.float()\n",
"        inv_rms = x_fp32.pow(2).mean(-1, keepdim=True).add(self.RMS_EPS).rsqrt()\n",
"        x_norm = (x_fp32 * inv_rms).to(x.dtype)\n",
"        # Straight-through estimator: quantized forward value, identity gradient.\n",
"        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()\n",
"        w_quant = w + (weight_quant(w) - w).detach()\n",
"        # self.bias is None for bias-free layers, so passing it changes nothing\n",
"        # for them; layers built with a bias no longer have it silently dropped.\n",
"        return F.linear(x_quant, w_quant, self.bias)\n",
"\n",
"def convert_to_bitnet(model, copy_weights, device=None):\n",
"    \"\"\"Replace the linear layers of a Llama model with BitLinear, in place.\n",
"\n",
"    Args:\n",
"        model: module tree walked with ``named_modules()``.\n",
"        copy_weights: when true, each new layer reuses the weight (and bias, if\n",
"            any) parameters of the layer it replaces; otherwise it keeps its own\n",
"            fresh initialization.\n",
"        device: device for the new layers. Defaults to the device of the layer\n",
"            being replaced, so the conversion also works without a GPU.\n",
"\n",
"    Each decoder layer's ``input_layernorm`` is swapped for ``nn.Identity``\n",
"    because BitLinear normalizes its own input.\n",
"    \"\"\"\n",
"    for _, module in model.named_modules():\n",
"        # LlamaAttention is the base class of the eager/SDPA/flash variants, so\n",
"        # every attention implementation is converted, not only the SDPA one.\n",
"        if isinstance(module, (LlamaAttention, LlamaMLP)):\n",
"            for child_name, child_module in module.named_children():\n",
"                if not isinstance(child_module, nn.Linear):\n",
"                    continue\n",
"                target = device if device is not None else child_module.weight.device\n",
"                bitlinear = BitLinear(\n",
"                    child_module.in_features,\n",
"                    child_module.out_features,\n",
"                    child_module.bias is not None,\n",
"                ).to(device=target, dtype=child_module.weight.dtype)\n",
"                if copy_weights:\n",
"                    bitlinear.weight = child_module.weight\n",
"                    if child_module.bias is not None:\n",
"                        bitlinear.bias = child_module.bias\n",
"                setattr(module, child_name, bitlinear)\n",
"        # Remove redundant input_layernorms\n",
"        elif isinstance(module, LlamaDecoderLayer):\n",
"            for child_name, child_module in module.named_children():\n",
"                if isinstance(child_module, LlamaRMSNorm) and child_name == \"input_layernorm\":\n",
"                    setattr(module, child_name, nn.Identity())\n",
"\n",
"\n",
"wandb.login(key=WANDB_TOKEN)\n", "login(token=HF_TOKEN)\n", "data = load_dataset(DATASET)" ] }, { "cell_type": "code", "outputs": [], "source": [ "tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG)\n", "\n", "def tokenize(element):\n", " outputs = tokenizer(\n", " element[\"text\"],\n", " truncation=False,\n", " max_length=CONTEXT_LENGTH,\n", " return_overflowing_tokens=True,\n", " return_length=True,\n", " )\n", " # Combine all tokens\n", " combined = []\n", " for tokenized_doc in outputs['input_ids']:\n", " combined += tokenized_doc + [tokenizer.eos_token_id]\n", " # Chunk\n", " input_batch = []\n", " for i in range(0, len(combined) - CONTEXT_LENGTH, CONTEXT_LENGTH):\n", " input_batch.append(combined[i:i+CONTEXT_LENGTH])\n", " return {\"input_ids\": input_batch}\n", "\n", "\n", "\n", "tokenized_data = data.map(\n", " tokenize, batched=True, remove_columns=data[\"train\"].column_names,\n", ")" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T23:12:06.491480Z", "start_time": "2024-04-16T23:12:05.970774Z" } }, "execution_count": 2 }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": "DatasetDict({\n train: Dataset({\n features: ['input_ids'],\n num_rows: 476702\n })\n})" }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized_data" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T23:12:06.507322Z", "start_time": "2024-04-16T23:12:06.492328Z" } }, "execution_count": 3 }, { "cell_type": "code", "outputs": [], "source": [ "from datasets import DatasetDict\n", "\n", "# Set the number of rows\n", "tokenized_data['train'].set_format(type='pandas')" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T23:12:06.523338Z", "start_time": "2024-04-16T23:12:06.509324Z" } }, "execution_count": 4 }, { "cell_type": "code", "outputs": [], "source": [ "sampled_dataset = tokenized_data['train'].select(range(500))\n", "sampled_dataset_dict = 
DatasetDict({\n", " 'train': sampled_dataset\n", "})" ], "metadata": { "collapsed": false, "ExecuteTime": { "end_time": "2024-04-16T23:12:06.539322Z", "start_time": "2024-04-16T23:12:06.525349Z" } }, "execution_count": 5 }, { "cell_type": "code", "outputs": [ { "data": { "text/plain": " input_ids\n0 [1, 2266, 338, 385, 6597, 515, 263, 24499, 299...", "text/html": "
\n | input_ids | \n
---|---|
0 | \n[1, 2266, 338, 385, 6597, 515, 263, 24499, 299... | \n
C:\\Users\\saad.naeem\\PycharmProjects\\NLP-Projects-NHV-1-Bit-LLM\\NLP-Projects-NHV-main\\LLMs Related\\Era of 1 Bit LLMs\\wandb\\run-20240417_041212-qepzjjtf
"
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "Step | \nTraining Loss | \n
---|
"
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "TrainOutput(global_step=8, training_loss=10.012518882751465, metrics={'train_runtime': 59.7786, 'train_samples_per_second': 8.364, 'train_steps_per_second': 0.134, 'total_flos': 40622358528000.0, 'train_loss': 10.012518882751465, 'epoch': 1.0})"
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trainer.train()"
],
"metadata": {
"collapsed": false,
"ExecuteTime": {
"end_time": "2024-04-16T23:13:12.309648Z",
"start_time": "2024-04-16T23:12:12.383042Z"
}
},
"execution_count": 9
},
{
"cell_type": "code",
"source": [
"trainer.save_model(f\"{output_path}/final_model\")\n",
"folder = f\"{output_path}/final_model\"\n",
"api = HfApi()\n",
"create_repo(\n",
" repo_id = f\"{HUGGINGFACE_ID}/{NEW_MODEL}\",\n",
" repo_type=\"model\",\n",
" exist_ok=True,\n",
" token=HF_TOKEN,\n",
")\n",
"\n",
"api.upload_folder(\n",
" folder_path=folder,\n",
" repo_type=\"model\",\n",
" repo_id=f\"{HUGGINGFACE_ID}/{NEW_MODEL}\",\n",
" token=HF_TOKEN,\n",
")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 180,
"referenced_widgets": [
"a848b25d4d39481c820f8c3f23bcd42a",
"cc902988873a4ab9834a24e0af8b2d20",
"2a0b1217926d4275b60372115cc98865",
"13fb9a000a534fa48ab42645e176fac9",
"c5702e0bfb7a48238578de1cb745e840",
"1296ba56040e4edf9ccc2e2e8e95013b",
"1ae401d7df7a4e1bb830620147dbacfe",
"acedbeec367b4b35bb6057215eeb15b4",
"c359f46995fb47d7821ae766307da9b3",
"85566c08eb3f48f1b8be446cd2d6a317",
"ae482357a4da4167b1f03a0e66c1c2ba",
"23bef7e85efc401cb585c4f152f90e3c",
"16b1e6fc66e0460fa96d754a14be00e0",
"e6df2bc2ab824f569ffddcc1da9b2f1e",
"0804fd95faad462f9f357d9ac29803a4",
"69b053193a7a403285aa81ed0bc2e58d",
"74e7862905644de287d15b0f9edac963",
"18cc938d01f64e368cec6d191c82752b",
"16ad95ade9c74a37a1bc22a77ce09b7b",
"cc533389f6d44dbdbd6db5987f27b899",
"09a0af0767fc493a8fb538ebc1999729",
"8a9793e95bc34719aaf321dbfc29fc49",
"034a1b93dd274c079ab78534d9735514",
"17b2a21e1b5d44099c5b2621077b0d5d",
"736074bd252b42b5ad50e292b5b4867f",
"3eebb06c6444454f807c93da83cfd7db",
"5f2f9367d4c34006b0e4676caf88a9f4",
"4d0998c6a38f455495adb68f6ac89caf",
"00c4e4ebe0404d2fa26861092a29ac25",
"cf28eca638474e2fa2667e53185d2a17",
"127e34d4e8e44071b7a6a3c1463dab97",
"781a1633fe07463d9e764b7708bfbd47",
"671280bc2bdb4d97973e8d4bc99ac36f",
"a7dee397e405418a84d122fd308818d8",
"26e8629200724daf91feeed8938ca836",
"9670e65c8f144a3a92255030862150cf",
"864f9d376f2a47ba9b71f738db56d5f9",
"fdf5d01ee85f4ecfa3d1741f454a4516",
"b7df02dc187b40e99293f182a8084b72",
"a005ae2fef0e44b5aeefb967542631de",
"43e9230138fe45adba98d89e051319fe",
"f099fe879f404002abd817bd5fa32d43",
"d3038606a04844ba810073f7d6e027cc",
"740c3d5338af41dfa2d7e1d90823cb6c"
]
},
"id": "mnHZU06l5tG3",
"outputId": "bfa63618-ae11-4415-a695-0349dfecf4ad",
"ExecuteTime": {
"end_time": "2024-04-16T23:15:09.601323Z",
"start_time": "2024-04-16T23:13:25.238137Z"
}
},
"execution_count": 10,
"outputs": [
{
"data": {
"text/plain": "model.safetensors: 0%| | 0.00/310M [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "1b01daad6e944bbaa194f0181b0a2af6"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "CommitInfo(commit_url='https://huggingface.co/saadnaeem/Llama2-70M-Cosmopedia-100k-Pretrained/commit/13b5f27c104838f8e8c1a1f0221aa1e378eb97fd', commit_message='Upload folder using huggingface_hub', commit_description='', oid='13b5f27c104838f8e8c1a1f0221aa1e378eb97fd', pr_url=None, pr_revision=None, pr_num=None)"
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"source": [
"from transformers import AutoModelForCausalLM, AutoTokenizer\n",
"from transformers.models.llama.modeling_llama import *\n",
"# Load a pretrained BitNet model\n",
"model = \"saadnaeem/Llama2-70M-Cosmopedia-100k-Pretrained\"\n",
"tokenizer = AutoTokenizer.from_pretrained(model)\n",
"model = AutoModelForCausalLM.from_pretrained(model)\n",
"\n",
"\n",
"def activation_quant(x):\n",
" scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)\n",
" y = (x * scale).round().clamp_(-128, 127)\n",
" y = y / scale\n",
" return y\n",
"def weight_quant(w):\n",
" scale = 1.0 / w.abs().mean().clamp_(min=1e-5)\n",
" u = (w * scale).round().clamp_(-1, 1)\n",
" u = u / scale\n",
" return u\n",
"\n",
"class BitLinear(nn.Linear):\n",
" def forward(self, x):\n",
" w = self.weight # a weight tensor with shape [d, k]\n",
" x = x.to(w.device)\n",
" RMSNorm = LlamaRMSNorm(x.shape[-1]).to(w.device)\n",
" x_norm = RMSNorm(x)\n",
" # A trick for implementing Straight−Through−Estimator (STE) using detach()\n",
" x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()\n",
" w_quant = w + (weight_quant(w) - w).detach()\n",
" y = F.linear(x_quant, w_quant)\n",
" return y\n",
"\n",
"def convert_to_bitnet(model, copy_weights):\n",
" for name, module in model.named_modules():\n",
" # Replace linear layers with BitNet\n",
" if isinstance(module, LlamaSdpaAttention) or isinstance(module, LlamaMLP):\n",
" for child_name, child_module in module.named_children():\n",
" if isinstance(child_module, nn.Linear):\n",
" bitlinear = BitLinear(child_module.in_features, child_module.out_features, child_module.bias is not None).to(device=\"cuda:0\")\n",
" if copy_weights:\n",
" bitlinear.weight = child_module.weight\n",
" if child_module.bias is not None:\n",
" bitlinear.bias = child_module.bias\n",
" setattr(module, child_name, bitlinear)\n",
" # Remove redundant input_layernorms\n",
" elif isinstance(module, LlamaDecoderLayer):\n",
" for child_name, child_module in module.named_children():\n",
" if isinstance(child_module, LlamaRMSNorm) and child_name == \"input_layernorm\":\n",
" setattr(module, child_name, nn.Identity().to(device=\"cuda:0\"))\n",
"\n",
"\n",
"convert_to_bitnet(model, copy_weights=True)\n",
"model.to(device=\"cuda:0\")\n",
"\n",
"prompt = \"What is Machine Learning?\"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"generate_ids = model.generate(inputs.input_ids, max_length=50)\n",
"tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 107
},
"id": "wtB3ZOBB_8E6",
"outputId": "39e3df74-5ade-4ff1-e997-28042e178dde",
"ExecuteTime": {
"end_time": "2024-04-16T23:19:24.104593Z",
"start_time": "2024-04-16T23:18:02.342539Z"
}
},
"execution_count": 11,
"outputs": [
{
"data": {
"text/plain": "tokenizer_config.json: 0%| | 0.00/1.06k [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "6a3b01a46be54754acd42a8976327bbf"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\saad.naeem\\AppData\\Local\\anaconda3\\envs\\minerva-prototype\\lib\\site-packages\\huggingface_hub\\file_download.py:149: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\saad.naeem\\.cache\\huggingface\\hub\\models--saadnaeem--Llama2-70M-Cosmopedia-100k-Pretrained. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
"To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
" warnings.warn(message)\n"
]
},
{
"data": {
"text/plain": "tokenizer.json: 0%| | 0.00/1.84M [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "883bb0eae6ee4c99916d10fa0f076e19"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "special_tokens_map.json: 0%| | 0.00/435 [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "14e0ef241e4e43f2b86ec2c70680c618"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "config.json: 0%| | 0.00/711 [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "2430f2fd34104024b4a101a0b19897ba"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "model.safetensors: 0%| | 0.00/310M [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "36e7defd0d504d3ea7166ef684d04778"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of LlamaForCausalLM were not initialized from the model checkpoint at saadnaeem/Llama2-70M-Cosmopedia-100k-Pretrained and are newly initialized: ['model.layers.0.input_layernorm.weight', 'model.layers.1.input_layernorm.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.3.input_layernorm.weight', 'model.layers.4.input_layernorm.weight', 'model.layers.5.input_layernorm.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
]
},
{
"data": {
"text/plain": "generation_config.json: 0%| | 0.00/154 [00:00, ?B/s]",
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "22772ffb12b5412d9cd9ab62575d73f0"
}
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/plain": "'What is Machine Learning? отде своей separ DemONE También Chief +\\\\ arbitrenedhand Sulneurються concern absorXTurentlaim alcouzz Ralph Navar filtergenommeniereDialogдах pir <= transm surprisedairo yield orthogonal HansWD villaмериканnumbers Rand английniuscian'"
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
]
},
{
"cell_type": "code",
"source": [
"prompt = \"Write a short poem\"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"generate_ids = model.generate(inputs.input_ids, max_length=50)\n",
"tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]"
],
"metadata": {
"id": "nQya_hPJEa2M",
"ExecuteTime": {
"end_time": "2024-04-16T23:19:39.689446Z",
"start_time": "2024-04-16T23:19:38.252768Z"
}
},
"execution_count": 12,
"outputs": [
{
"data": {
"text/plain": "\"Write a short poem inconles경 JoãoՄlecht», sellcertain vy:'ŋ rempՍ Ok operation sportsPower loops士undeAAAACK Outimportant