{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### Original GPT-J perlexity" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n", "import transformers\n", "from tqdm.auto import tqdm\n", "\n", "\n", "\n", "model_name = \"EleutherAI/gpt-j-6B\"\n", "gpt = transformers.AutoModelForCausalLM.from_pretrained(model_name)\n", "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "device = 'cuda' if torch.cuda.is_available else 'cpu'\n", "gpt.to(device).train(False);" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Reusing dataset wikitext (/home/jheuristic/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "47f0459174da4ee2bf064c9ae81fdecd", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/3 [00:00 \"FrozenBNBLinear\":\n", " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n", " return cls(weights_int8, *state, linear.bias)\n", " \n", " def __repr__(self):\n", " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n", " \n", " \n", "class FrozenBNBEmbedding(nn.Module):\n", " def __init__(self, weight, absmax, code):\n", " super().__init__()\n", " self.num_embeddings, self.embedding_dim = weight.shape\n", " self.register_buffer(\"weight\", weight.requires_grad_(False))\n", " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n", " self.register_buffer(\"code\", code.requires_grad_(False))\n", " \n", " def forward(self, x, **kwargs):\n", " with torch.no_grad():\n", " # note: both quantuized weights and input indices are *not* differentiable\n", " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n", " return F.embedding(x, weight_deq, **kwargs)\n", " \n", " @classmethod\n", " def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n", " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n", " return cls(weights_int8, *state)\n", " \n", " def __repr__(self):\n", " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n", " \n", " \n", "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n", " assert chunk_size % 4096 == 0\n", " code = None\n", " chunks = []\n", " absmaxes = []\n", " flat_tensor = matrix.view(-1)\n", " for i in range((matrix.numel() - 1) // chunk_size + 1):\n", " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n", " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n", " chunks.append(quantized_chunk)\n", " absmaxes.append(absmax_chunk)\n", " \n", " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n", " absmax = torch.cat(absmaxes)\n", " return matrix_i8, (absmax, code)\n", "\n", "\n", "def dummify(model, adapter_dim: int = 0):\n", " for module in list(model.modules()):\n", " for name, child in module.named_children():\n", " if isinstance(child, nn.Linear):\n", " print(name, child)\n", " setattr(\n", " module,\n", " name,\n", " FrozenBNBLinear(\n", " 
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):\n",
    "    def __init__(self, config):\n",
    "        print(\"MONKEYPATCH BLOCK\")\n",
    "        super().__init__(config)\n",
    "\n",
    "        dummify(self.attn)\n",
    "        dummify(self.mlp)\n",
    "\n",
    "transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock\n",
    "\n",
    "\n",
    "class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        dummify(self)\n",
    "\n",
    "\n",
    "class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):\n",
    "    def __init__(self, config):\n",
    "        super().__init__(config)\n",
    "        dummify(self)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1c98b9ebbf8d44d8b0bc422d4bfce21f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading:   0%|          | 0.00/0.98k [00:00<?, ?B/s]"