{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPT-J-6B with 8-bit blockwise-quantized weights\n",
    "\n",
    "Loads `EleutherAI/gpt-j-6B`, replaces every `nn.Linear` / `nn.Embedding` with a frozen 8-bit (bitsandbytes blockwise-quantized) counterpart that can carry an optional trainable low-rank adapter, and samples a short continuation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import transformers\n",
    "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n",
    "from torch.cuda.amp import custom_fwd, custom_bwd\n",
    "\n",
    "%config Completer.use_jedi = False\n",
    "\n",
    "\n",
    "model_name = \"EleutherAI/gpt-j-6B\"\n",
    "gpt = transformers.AutoModelForCausalLM.from_pretrained(model_name)\n",
    "tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# quantizing frozen weights needs no autograd graph\n",
    "@torch.no_grad()\n",
    "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n",
    "    \"\"\"Blockwise-quantize `matrix` to 8 bits, `chunk_size` elements at a time, to bound peak memory.\n",
    "\n",
    "    Args:\n",
    "        matrix: tensor of any shape to quantize.\n",
    "        chunk_size: elements per `quantize_blockwise` call; must be a multiple of 4096 so chunk\n",
    "            boundaries coincide with quantization block boundaries (assumes bitsandbytes' 4096-element\n",
    "            blocks -- TODO confirm for the installed version).\n",
    "\n",
    "    Returns:\n",
    "        `(matrix_i8, (absmax, code))`: the quantized tensor shaped like `matrix`, the concatenated\n",
    "        per-chunk absmax tensors, and the code returned by `quantize_blockwise`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `matrix` is empty.\n",
    "    \"\"\"\n",
    "    assert chunk_size % 4096 == 0\n",
    "    if matrix.numel() == 0:\n",
    "        raise ValueError(\"cannot quantize an empty tensor\")\n",
    "    code = None\n",
    "    chunks = []\n",
    "    absmaxes = []\n",
    "    # reshape instead of view: still a view for contiguous tensors, but does not crash on non-contiguous ones\n",
    "    flat_tensor = matrix.reshape(-1)\n",
    "    for start in range(0, flat_tensor.numel(), chunk_size):\n",
    "        # clone so only this chunk is materialized as a standalone tensor\n",
    "        input_chunk = flat_tensor[start: start + chunk_size].clone()\n",
    "        # pass back the code returned by the previous call (None on the first) so every chunk gets the same code\n",
    "        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n",
    "        chunks.append(quantized_chunk)\n",
    "        absmaxes.append(absmax_chunk)\n",
    "\n",
    "    matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n",
    "    absmax = torch.cat(absmaxes)\n",
    "    return matrix_i8, (absmax, code)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DequantizeAndLinear(torch.autograd.Function):\n",
    "    \"\"\"`F.linear` with a blockwise-quantized weight that is dequantized on the fly.\n",
    "\n",
    "    The quantized weight, `absmax` and `code` are treated as constants: only `input` and `bias`\n",
    "    receive gradients (backward asserts that no gradient is requested for the others).\n",
    "    \"\"\"\n",
    "\n",
    "    @staticmethod\n",
    "    @custom_fwd\n",
    "    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n",
    "                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n",
    "        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
    "        # backward never reads `input` (the weight gets no gradient), so it is not saved; keeping only\n",
    "        # the compact quantized state avoids holding activations alive until backward\n",
    "        ctx.save_for_backward(weights_quantized, absmax, code)\n",
    "        ctx._has_bias = bias is not None\n",
    "        return F.linear(input, weights_deq, bias)\n",
    "\n",
    "    @staticmethod\n",
    "    @custom_bwd\n",
    "    def backward(ctx, grad_output: torch.Tensor):\n",
    "        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n",
    "        weights_quantized, absmax, code = ctx.saved_tensors\n",
    "        grad_input = grad_bias = None\n",
    "        if ctx.needs_input_grad[0]:\n",
    "            # grad_output: [*batch, out_features]; weights_deq: [out_features, in_features]\n",
    "            weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n",
    "            grad_input = grad_output @ weights_deq\n",
    "        if ctx._has_bias and ctx.needs_input_grad[4]:\n",
    "            grad_bias = grad_output.flatten(0, -2).sum(dim=0)\n",
    "        return grad_input, None, None, None, grad_bias\n",
    "\n",
    "\n",
    "class BNBLinearWithAdapter(nn.Module):\n",
    "    \"\"\"Frozen 8-bit blockwise-quantized replacement for `nn.Linear`, with an optional low-rank adapter.\n",
    "\n",
    "    Output is `F.linear(input, dequantized_weight, bias)`, plus `adapter(input)` when `adapter_dim > 0`.\n",
    "    The adapter maps in_features -> adapter_dim -> out_features without biases; its second layer is\n",
    "    zero-initialized, so a freshly built layer adds exactly zero on top of the quantized linear.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weight, absmax, code, bias=None, adapter_dim=0):\n",
    "        assert isinstance(bias, nn.Parameter) or bias is None\n",
    "        super().__init__()\n",
    "        self.out_features, self.in_features = weight.shape\n",
    "        # quantized state is registered as buffers, so it is not returned by parameters()\n",
    "        self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
    "        self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
    "        self.register_buffer(\"code\", code.requires_grad_(False))\n",
    "        self.bias = bias\n",
    "\n",
    "        if adapter_dim > 0:\n",
    "            self.adapter = nn.Sequential(\n",
    "                nn.Linear(self.in_features, adapter_dim, bias=False),\n",
    "                nn.Linear(adapter_dim, self.out_features, bias=False),\n",
    "            )\n",
    "            nn.init.zeros_(self.adapter[1].weight)\n",
    "        else:\n",
    "            self.adapter = None\n",
    "\n",
    "    def forward(self, input):\n",
    "        out = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n",
    "        if self.adapter is not None:\n",
    "            return self.adapter(input) + out\n",
    "        return out\n",
    "\n",
    "    @classmethod\n",
    "    def from_linear(cls, linear: nn.Linear, **kwargs) -> \"BNBLinearWithAdapter\":\n",
    "        \"\"\"Quantize `linear.weight` and wrap it; `linear.bias` (if any) is reused unquantized.\"\"\"\n",
    "        weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n",
    "        return cls(weights_int8, *state, linear.bias, **kwargs)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n",
    "\n",
    "\n",
    "class BNBEmbeddingWithAdapter(nn.Module):\n",
    "    \"\"\"Frozen 8-bit blockwise-quantized replacement for `nn.Embedding`, with an optional low-rank adapter.\n",
    "\n",
    "    The adapter is a trainable `adapter_dim`-wide embedding followed by a bias-free projection to\n",
    "    `embedding_dim`; the projection is zero-initialized, so the adapter initially contributes zero.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, weight, absmax, code, adapter_dim=0):\n",
    "        super().__init__()\n",
    "        self.num_embeddings, self.embedding_dim = weight.shape\n",
    "        self.register_buffer(\"weight\", weight.requires_grad_(False))\n",
    "        self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n",
    "        self.register_buffer(\"code\", code.requires_grad_(False))\n",
    "\n",
    "        if adapter_dim > 0:\n",
    "            self.adapter = nn.Sequential(\n",
    "                nn.Embedding(self.num_embeddings, adapter_dim),\n",
    "                nn.Linear(adapter_dim, self.embedding_dim, bias=False),\n",
    "            )\n",
    "            nn.init.zeros_(self.adapter[1].weight)\n",
    "        else:\n",
    "            self.adapter = None\n",
    "\n",
    "    def forward(self, input, **kwargs):\n",
    "        \"\"\"Embed `input` indices; `kwargs` go to `F.embedding` for both the frozen table and the adapter.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            # note: both quantized weights and input indices are *not* differentiable\n",
    "            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n",
    "            out = F.embedding(input, weight_deq, **kwargs)\n",
    "        if self.adapter is not None:\n",
    "            # nn.Sequential.forward takes a single input, so kwargs cannot go through self.adapter(...);\n",
    "            # look up the adapter's own embedding table functionally instead\n",
    "            adapter_embedding, adapter_projection = self.adapter\n",
    "            low_rank = F.embedding(input, adapter_embedding.weight, **kwargs)\n",
    "            return out + adapter_projection(low_rank)\n",
    "        return out\n",
    "\n",
    "    @classmethod\n",
    "    def from_embedding(cls, embedding: nn.Embedding, **kwargs) -> \"BNBEmbeddingWithAdapter\":\n",
    "        \"\"\"Quantize `embedding.weight` and wrap it (options such as `padding_idx` are not carried over).\"\"\"\n",
    "        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n",
    "        return cls(weights_int8, *state, **kwargs)\n",
    "\n",
    "    def __repr__(self):\n",
    "        return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bnbfy_(model, adapter_dim: int = 0, verbose: bool = True):\n",
    "    \"\"\"Replace, in place, every `nn.Linear` / `nn.Embedding` inside `model` with its 8-bit counterpart.\n",
    "\n",
    "    Args:\n",
    "        model: module to convert; modified in place.\n",
    "        adapter_dim: width of the trainable adapter added to each replaced layer (0 disables adapters).\n",
    "        verbose: print each layer's name and repr before replacing it.\n",
    "    \"\"\"\n",
    "    # snapshot the module tree first, since children are swapped out while it is being walked\n",
    "    for module in list(model.modules()):\n",
    "        for name, child in module.named_children():\n",
    "            if isinstance(child, nn.Linear):\n",
    "                convert = BNBLinearWithAdapter.from_linear\n",
    "            elif isinstance(child, nn.Embedding):\n",
    "                convert = BNBEmbeddingWithAdapter.from_embedding\n",
    "            else:\n",
    "                continue\n",
    "            if verbose:\n",
    "                print(name, child)\n",
    "            setattr(module, name, convert(child, adapter_dim=adapter_dim))"
   ]
  },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "lm_head Linear(in_features=4096, out_features=50400, bias=True)\n", "wte Embedding(50400, 4096)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, 
bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", 
"v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj 
Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out 
Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, 
out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n", "k_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "v_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "q_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "out_proj Linear(in_features=4096, out_features=4096, bias=False)\n", "fc_in Linear(in_features=4096, out_features=16384, bias=True)\n", "fc_out Linear(in_features=16384, out_features=4096, bias=True)\n" ] } ], "source": [ "bnbfy_(gpt, adapter_dim=0)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting 
`pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    }
   ],
   "source": [
    "# do_sample=True draws random tokens: seed the RNG so the sampled continuation is reproducible\n",
    "torch.manual_seed(0)\n",
    "prompt = tokenizer(\"A cat sat on a mat and\", return_tensors='pt')\n",
    "# NOTE: min_length/max_length count the prompt tokens too, not only newly generated ones\n",
    "out = gpt.generate(**prompt, min_length=8, max_length=8, do_sample=True,\n",
    "                   # explicit pad id; without it generate falls back to eos_token_id and warns\n",
    "                   pad_token_id=tokenizer.eos_token_id)\n",
    "tokenizer.decode(out[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}