{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "e13eff4e-c134-4dac-9523-07b297164250",
   "metadata": {},
   "source": [
    "# Example: fine-tuning the 176-billion-parameter BLOOM with 8-bit weights\n",
    "\n",
    "This notebook shows how to fine-tune BLOOM with low-rank adapters trained on top of frozen 8-bit weights. It is heavily inspired by [Hivemind's work](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "699e94eb-3ce1-4788-999b-fb6d593ba7e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install transformers==4.20.1\n",
    "%pip install bitsandbytes-cuda110\n",
    "%pip install datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0afea72c-691d-4719-a84a-663f1891af6e",
   "metadata": {},
   "source": [
    "### Load and convert the original BLOOM structure to 8-bit\n",
    "\n",
    "You can load an already-compressed 8-bit version of BLOOM from [joaoalvarenga/bloom-8bit](https://huggingface.co/joaoalvarenga/bloom-8bit), but first we need to adapt the original model structure so that its linear and embedding layers store 8-bit weights. Some of the following code is adapted from [Hivemind's GPT-J 8-bit fine-tuning notebook](https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es)."
] }, { "cell_type": "code", "execution_count": null, "id": "aa5f4118-d4d9-474f-ac36-acaadb920c1f", "metadata": {}, "outputs": [], "source": [ "import transformers\n", "\n", "import torch\n", "import torch.nn.functional as F\n", "from torch import nn\n", "from torch.cuda.amp import custom_fwd, custom_bwd\n", "\n", "from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise\n", "\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": null, "id": "cc4f262e-70de-4a06-a5a6-52d1cd5223d3", "metadata": {}, "outputs": [], "source": [ "class FrozenBNBLinear(nn.Module):\n", " def __init__(self, weight, absmax, code, bias=None):\n", " assert isinstance(bias, nn.Parameter) or bias is None\n", " super().__init__()\n", " self.out_features, self.in_features = weight.shape\n", " self.register_buffer(\"weight\", weight.requires_grad_(False))\n", " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n", " self.register_buffer(\"code\", code.requires_grad_(False))\n", " self.adapter = None\n", " self.bias = bias\n", " \n", " def forward(self, input):\n", " output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)\n", " if self.adapter:\n", " output += self.adapter(input)\n", " return output\n", " \n", " @classmethod\n", " def from_linear(cls, linear: nn.Linear) -> \"FrozenBNBLinear\":\n", " weights_int8, state = quantize_blockise_lowmemory(linear.weight)\n", " return cls(weights_int8, *state, linear.bias)\n", " \n", " def __repr__(self):\n", " return f\"{self.__class__.__name__}({self.in_features}, {self.out_features})\"\n", " \n", " \n", "class DequantizeAndLinear(torch.autograd.Function): \n", " @staticmethod\n", " @custom_fwd\n", " def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,\n", " absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):\n", " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n", " 
ctx.save_for_backward(input, weights_quantized, absmax, code)\n", " ctx._has_bias = bias is not None\n", " return F.linear(input, weights_deq, bias)\n", " \n", " @staticmethod\n", " @custom_bwd\n", " def backward(ctx, grad_output: torch.Tensor):\n", " assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]\n", " input, weights_quantized, absmax, code = ctx.saved_tensors\n", " # grad_output: [*batch, out_features]\n", " weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)\n", " grad_input = grad_output @ weights_deq\n", " grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None\n", " return grad_input, None, None, None, grad_bias\n", " \n", " \n", "class FrozenBNBEmbedding(nn.Module):\n", " def __init__(self, weight, absmax, code):\n", " super().__init__()\n", " self.num_embeddings, self.embedding_dim = weight.shape\n", " self.register_buffer(\"weight\", weight.requires_grad_(False))\n", " self.register_buffer(\"absmax\", absmax.requires_grad_(False))\n", " self.register_buffer(\"code\", code.requires_grad_(False))\n", " self.adapter = None\n", " \n", " def forward(self, input, **kwargs):\n", " with torch.no_grad():\n", " # note: both quantuized weights and input indices are *not* differentiable\n", " weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)\n", " output = F.embedding(input, weight_deq, **kwargs)\n", " if self.adapter:\n", " output += self.adapter(input)\n", " return output \n", " \n", " @classmethod\n", " def from_embedding(cls, embedding: nn.Embedding) -> \"FrozenBNBEmbedding\":\n", " weights_int8, state = quantize_blockise_lowmemory(embedding.weight)\n", " return cls(weights_int8, *state)\n", " \n", " def __repr__(self):\n", " return f\"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})\"\n", " \n", " \n", "def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):\n", " assert chunk_size % 
4096 == 0\n", " code = None\n", " chunks = []\n", " absmaxes = []\n", " flat_tensor = matrix.view(-1)\n", " for i in range((matrix.numel() - 1) // chunk_size + 1):\n", " input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()\n", " quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)\n", " chunks.append(quantized_chunk)\n", " absmaxes.append(absmax_chunk)\n", " \n", " matrix_i8 = torch.cat(chunks).reshape_as(matrix)\n", " absmax = torch.cat(absmaxes)\n", " return matrix_i8, (absmax, code)\n", "\n", "\n", "def convert_to_int8(model):\n", " \"\"\"Convert linear and embedding modules to 8-bit with optional adapters\"\"\"\n", " for module in list(model.modules()):\n", " for name, child in module.named_children():\n", " if isinstance(child, nn.Linear):\n", " print(name, child)\n", " setattr( \n", " module,\n", " name,\n", " FrozenBNBLinear(\n", " weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),\n", " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n", " code=torch.zeros(256),\n", " bias=child.bias,\n", " ),\n", " )\n", " elif isinstance(child, nn.Embedding):\n", " setattr(\n", " module,\n", " name,\n", " FrozenBNBEmbedding(\n", " weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),\n", " absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),\n", " code=torch.zeros(256),\n", " )\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "f4673d4c-0f4e-482e-ac04-b7389397af6e", "metadata": {}, "outputs": [], "source": [ "class BloomBlock(transformers.models.bloom.modeling_bloom.BloomBlock):\n", " def __init__(self, config, layer_number=None):\n", " super().__init__(config, layer_number)\n", "\n", " convert_to_int8(self.self_attention)\n", " convert_to_int8(self.mlp)\n", "\n", "\n", "class BloomModel(transformers.models.bloom.modeling_bloom.BloomModel):\n", " def __init__(self, config):\n", " super().__init__(config)\n", " 
convert_to_int8(self)\n", " \n", "\n", "class BloomForCausalLM(transformers.models.bloom.modeling_bloom.BloomForCausalLM):\n", " def __init__(self, config):\n", " super().__init__(config)\n", " convert_to_int8(self)\n", " \n", "transformers.models.bloom.modeling_bloom.BloomBlock = BloomBlock" ] }, { "cell_type": "code", "execution_count": null, "id": "eca11b11-9b0b-4958-89f4-401f7a2cac0e", "metadata": {}, "outputs": [], "source": [ "from transformers import BloomForCausalLM \n", "tokenizer = transformers.AutoTokenizer.from_pretrained('joaoalvarenga/bloom-8bit')\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "model = BloomForCausalLM.from_pretrained('joaoalvarenga/bloom-8bit', low_cpu_mem_usage=True)\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "model.to(device)" ] }, { "cell_type": "markdown", "id": "82ea942b-7fcf-4bbc-adb9-be0bbd98b9f8", "metadata": {}, "source": [ "### Fine-tune and save model" ] }, { "cell_type": "code", "execution_count": null, "id": "26cacf36-56f7-4f9c-b975-33dd34b1ff9c", "metadata": {}, "outputs": [], "source": [ "def add_adapters(model, adapter_dim=16):\n", " assert adapter_dim > 0\n", "\n", " for module in model.modules():\n", " if isinstance(module, FrozenBNBLinear):\n", " module.adapter = nn.Sequential(\n", " nn.Linear(module.in_features, adapter_dim, bias=False),\n", " nn.Linear(adapter_dim, module.out_features, bias=False),\n", " )\n", " nn.init.zeros_(module.adapter[1].weight)\n", " elif isinstance(module, FrozenBNBEmbedding):\n", " module.adapter = nn.Sequential(\n", " nn.Embedding(module.num_embeddings, adapter_dim),\n", " nn.Linear(adapter_dim, module.embedding_dim, bias=False),\n", " )\n", " nn.init.zeros_(module.adapter[1].weight)\n", "\n", "add_adapters(model)\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "4e293eb3-979a-46d7-97b8-cde296f45da8", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "from bitsandbytes.optim import 
Adam8bit\n", "\n", "model.gradient_checkpointing_enable()\n", "\n", "wikisql = load_dataset(\"wikisql\", streaming=True)\n", "optimizer = Adam8bit(model.parameters(), lr=1e-5)\n", "\n", "with torch.cuda.amp.autocast():\n", " for row in tqdm(wikisql['train']):\n", "\n", " batch = tokenizer(row['question'] + row['sql']['human_readable'], truncation=True, max_length=128, return_tensors='pt')\n", " batch = {k: v.cuda() for k, v in batch.items()}\n", "\n", " out = gpt.forward(**batch,)\n", "\n", " loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),\n", " reduction='mean')\n", " print(loss)\n", " loss.backward()\n", "\n", " optimizer.step()\n", " optimizer.zero_grad()" ] }, { "cell_type": "code", "execution_count": null, "id": "4e2251f6-1a5c-4193-b971-0840d6d59c32", "metadata": {}, "outputs": [], "source": [ "model.save_pretrained('bloom-8bit-fine-tuned')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.12" } }, "nbformat": 4, "nbformat_minor": 5 }