{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "inference", "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" }, "accelerator": "GPU", "gpuClass": "standard" }, "cells": [ { "cell_type": "code", "source": [ "!pip install transformers==4.14.1\n", "!pip install bitsandbytes-cuda111==0.26.0\n", "\n", "from IPython import display \n", "display.clear_output()" ], "metadata": { "id": "Q8cuAdVDGXR6" }, "execution_count": 1, "outputs": [] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "8mkqaWlNGLKn" }, "outputs": [], "source": [ "import transformers\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.cuda.amp import custom_fwd, custom_bwd\n", "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "source": [ "#@title convert to 8bit\n", "class FrozenBNBLinear(nn.Module):\n", " def __init__(self, weight, absmax, code, bias=None):\n", " assert isinstance(bias, nn.Parameter) or bias is None\n", " super().__init__()\n", " self.out_features, self.in_features = weight.shape\n", " self.register_buffer(\"weight\", weight.requires_grad_(False))\n", " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n", " self.register_buffer(\"code\", code.requires_grad_(False))\n", " self.adapter = None\n", " self.bias = bias\n", " \n", " def forward(self, input):\n", " output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n", " if self.adapter:\n", " output += self.adapter(input)\n", " return output\n", " \n", " @classmethod\n", " def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n", " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n", " return cls(weights_int8, *state, linear.bias)\n", " \n", " def __repr__(self):\n", " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n", " \n", " \n", "class DequantizeAndLinear(torch.autograd.Function): \n", " @staticmethod\n", " @custom_fwd\n", " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n", " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n", " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n", " ctx.save_for_backward(input, weights_quantized, absmax, code)\n", " ctx._has_bias = bias is not None\n", " return F.linear(input, weights_deq, bias)\n", " \n", " @staticmethod\n", " @custom_bwd\n", " def backward(ctx, grad_output: torch.Tensor):\n", " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n", " input, weights_quantized, absmax, code = ctx.saved_tensors\n", " # grad_output: [*batch, out_features]\n", " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n", " grad_input = grad_output @ weights_deq\n", " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n", " return grad_input, None, None, None, grad_bias\n", " \n", " \n", "class FrozenBNBEmbedding(nn.Module):\n", " def __init__(self, weight, absmax, code):\n", " super().__init__()\n", " self.num_embeddings, self.embedding_dim = weight.shape\n", " self.register_buffer(\"weight\", weight.requires_grad_(False))\n", " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n", " self.register_buffer(\"code\", code.requires_grad_(False))\n", " self.adapter = None\n", " \n", " def 
forward(self, input, **kwargs):\n", " with torch.no_grad():\n", " # note: both quantized weights and input indices are *not* differentiable\n", " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n", " output = F.embedding(input, weight_deq, **kwargs)\n", " if self.adapter:\n", " output += self.adapter(input)\n", " return output \n", " \n", " @classmethod\n", " def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n", " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n", " return cls(weights_int8, *state)\n", " \n", " def __repr__(self):\n", " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n", " \n", " \n", "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n", " \"\"\"Quantize a large matrix block-wise in fixed-size chunks to limit peak memory usage\"\"\"\n", " assert chunk_size % 4096 == 0\n", " code = None\n", " chunks = []\n", " absmaxes = []\n", " flat_tensor = matrix.view(-1)\n", " for i in range((matrix.numel() - 1) // chunk_size + 1):\n", " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n", " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n", " chunks.append(quantized_chunk)\n", " absmaxes.append(absmax_chunk)\n", " \n", " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n", " absmax = torch.cat(absmaxes)\n", " return matrix_i8, (absmax, code)\n", " \n", " \n", "def convert_to_int8(model):\n", " \"\"\"Convert linear and embedding modules to 8-bit with optional adapters\"\"\"\n", " for module in list(model.modules()):\n", " for name, child in module.named_children():\n", " if isinstance(child, nn.Linear):\n", " setattr( \n", " module,\n", " name,\n", " FrozenBNBLinear(\n", " weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n", " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n", " code=torch.zeros(256),\n", " bias=child.bias,\n", " ),\n", " )\n", " elif isinstance(child, nn.Embedding):\n", " setattr(\n", " module,\n", " name,\n", " FrozenBNBEmbedding(\n", " weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n", " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n", " code=torch.zeros(256),\n", " )\n", " )\n", "class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):\n", " def __init__(self, config):\n", " super().__init__(config)\n", "\n", " convert_to_int8(self.attn)\n", " convert_to_int8(self.mlp)\n", "\n", "class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):\n", " def __init__(self, config):\n", " super().__init__(config)\n", " convert_to_int8(self)\n", " \n", "class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):\n", " def __init__(self, config):\n", " super().__init__(config)\n", " convert_to_int8(self)\n", "\n", "\n", "transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock # monkey-patch GPT-J so newly constructed models are built with 8-bit blocks" ], "metadata": { "cellView": "form", "id": "fmpdVvfVG7Pc" }, "execution_count": 3, "outputs": [] }, { "cell_type": "code", "source": [ "tokenizer = transformers.AutoTokenizer.from_pretrained(\"EleutherAI/gpt-j-6B\")\n", "gpt = GPTJForCausalLM.from_pretrained(\"crumb/gpt-j-6b-shakespeare\", low_cpu_mem_usage=True)\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "gpt = gpt.to(device)" ], "metadata": { "id": "ttKTRoUlG5YM" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "prompt = \"\"\"ROMEO: I would I were thy bird. \n", "JULIET: Sweet, so would I, Yet I should kill thee with much cherishing. 
Good night, good night! Parting is such sweet\"\"\"\n", "prompt = tokenizer(prompt, return_tensors='pt')\n", "prompt = {key: value.to(device) for key, value in prompt.items()}\n", "out = gpt.generate(**prompt, min_length=32, max_length=64, do_sample=True)\n", "out = tokenizer.decode(out[0])\n", "print(out)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "kSXSZz_kGcfm", "outputId": "d91dda66-88ab-4e52-bfb9-3df8092abe2f" }, "execution_count": 9, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "ROMEO: I would I were thy bird. \n", "JULIET: Sweet, so would I, Yet I should kill thee with much cherishing. Good night, good night! Parting is such sweet sorrow, As a lost angel's song, in answer to An evil dream.\n", "\n", "ROMEO\n" ] } ] } ] }