{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MambaBit: bit-level language model training\n",
    "\n",
    "Trains `MambaBit` on a text file converted to bits with `string_to_bits`, saves the weights, then greedily generates a continuation of a prompt."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from pathlib import Path\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch import Tensor\n",
    "from tqdm.auto import tqdm\n",
    "from mamba_ssm.modules.mamba_simple import Mamba\n",
    "\n",
    "\n",
    "def model_numel(m: nn.Module) -> int:\n",
    "    \"\"\"Return the total number of parameter elements in ``m``.\"\"\"\n",
    "    return sum(p.numel() for p in m.parameters())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune lives here.\n",
    "DATA_PATH = Path(\"../shake.txt\")\n",
    "CHECKPOINT_PATH = Path(\"mamba_bit.bin\")\n",
    "HOLDOUT_FRACTION = 0.05  # share of the text used for EACH of the validation and test splits\n",
    "RESUME_FROM_CHECKPOINT = True  # set False to start from freshly initialised weights\n",
    "RUN_TRAINING = True  # set False to skip the training cell"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_txt = DATA_PATH.read_text()\n",
    "total_len = len(raw_txt)\n",
    "aux_len = int(total_len * HOLDOUT_FRACTION)\n",
    "\n",
    "# Explicit split points instead of negative slices: `s[:-0]` is the empty\n",
    "# string, so `raw_txt[:-aux_len]` would silently drop all training text if\n",
    "# aux_len ever rounded down to 0.\n",
    "valid_start = total_len - 2 * aux_len\n",
    "test_start = total_len - aux_len\n",
    "\n",
    "train_txt = raw_txt[:valid_start]\n",
    "valid_txt = raw_txt[valid_start:test_start]\n",
    "test_txt = raw_txt[test_start:]\n",
    "\n",
    "assert len(train_txt) + len(valid_txt) + len(test_txt) == total_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(train_txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mambabit import string_to_bits, bits_to_string\n",
    "\n",
    "train_ds = string_to_bits(train_txt)\n",
    "valid_ds = string_to_bits(valid_txt)\n",
    "test_ds = string_to_bits(test_txt)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def random_batches(split: Tensor, n_batch: int, bs: int) -> Tensor:\n",
    "    \"\"\"Sample ``n_batch`` random windows of ``bs`` bits from ``split``.\n",
    "\n",
    "    Every window starts at a multiple of 8 bits. The windows are stacked\n",
    "    along a new leading dimension and the result is moved to the GPU.\n",
    "\n",
    "    Args:\n",
    "        split: bit tensor to sample from, indexed along its first dimension.\n",
    "        n_batch: number of windows to draw.\n",
    "        bs: window length in bits; must be a multiple of 8.\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: if ``bs`` is not a multiple of 8, or ``split`` is\n",
    "            too short to hold a single window.\n",
    "    \"\"\"\n",
    "    assert bs % 8 == 0, \"have mercy\"\n",
    "    max_allowed_pos = len(split) // 8 - bs // 8\n",
    "    # Without this check random.randint fails with an opaque 'empty range' error.\n",
    "    assert max_allowed_pos >= 0, f\"split has {len(split)} bits, need at least {bs}\"\n",
    "\n",
    "    # random.randint is inclusive on both ends; the last legal start still\n",
    "    # leaves a full window because max_allowed_pos * 8 + bs <= len(split).\n",
    "    starts = [random.randint(0, max_allowed_pos) * 8 for _ in range(n_batch)]\n",
    "    return torch.stack([split[s : s + bs] for s in starts]).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mambabit import dim_model, n_vocab, n_layers, MambaBit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mamba_bit = MambaBit().cuda()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resume only when a checkpoint is actually present, so the notebook also\n",
    "# runs top to bottom on a first run, before any weights have been saved.\n",
    "if RESUME_FROM_CHECKPOINT and CHECKPOINT_PATH.exists():\n",
    "    # weights_only=True limits unpickling to tensors and plain containers,\n",
    "    # which is all a state_dict holds.\n",
    "    mamba_bit.load_state_dict(torch.load(CHECKPOINT_PATH, weights_only=True))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(m: nn.Module,\n",
    "          n_epoch: int = 100,\n",
    "          n_batch: int = 4,\n",
    "          bs: int = 256) -> None:\n",
    "    \"\"\"Run ``n_epoch`` optimisation steps of next-bit prediction on ``train_ds``.\n",
    "\n",
    "    Despite the name, one \"epoch\" here is a single optimiser step on one\n",
    "    freshly sampled random batch, not a full pass over the data.\n",
    "\n",
    "    Args:\n",
    "        m: model to optimise in place; called as ``m(batch)``.\n",
    "        n_epoch: number of optimiser steps.\n",
    "        n_batch: number of windows per step.\n",
    "        bs: window length in bits (must be a multiple of 8).\n",
    "    \"\"\"\n",
    "    opt = torch.optim.AdamW(m.parameters(), lr=0.0001, fused=True)\n",
    "\n",
    "    for _ in (bar := tqdm(range(n_epoch))):\n",
    "        b = random_batches(train_ds, n_batch, bs)\n",
    "\n",
    "        # Output at position t is scored against the input at t + 1, so the\n",
    "        # last prediction and the first input bit have no counterpart.\n",
    "        y_pred = m(b)[:, :-1].reshape(-1, n_vocab)\n",
    "        y_true = b[:, 1:].ravel()\n",
    "\n",
    "        loss = F.cross_entropy(y_pred, y_true)\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "        opt.zero_grad()\n",
    "\n",
    "        bar.set_description(f\"L:{loss.item():.10f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if RUN_TRAINING:\n",
    "    train(mamba_bit, 5000, 9, 8*128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(mamba_bit.state_dict(), CHECKPOINT_PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TEST\n",
    "@torch.no_grad()\n",
    "def test(prompt: str, chars=10):\n",
    "    \"\"\"Greedily extend ``prompt`` by ``chars`` characters (8 bits each).\n",
    "\n",
    "    The whole sequence is fed through ``mamba_bit`` again for every new bit,\n",
    "    and the arg-max of the last position is appended.\n",
    "    \"\"\"\n",
    "    x = string_to_bits(prompt).cuda()[None]  # add a leading batch dimension\n",
    "    for _ in tqdm(range(chars * 8)):\n",
    "        y = mamba_bit(x)\n",
    "        new = y[:, -1:].argmax(-1)  # most likely next bit, one column wide\n",
    "        x = torch.cat((x, new), 1)\n",
    "    # NOTE(review): assumes bits_to_string accepts a 1-D CPU bit tensor, i.e.\n",
    "    # the same layout string_to_bits returns -- confirm against mambabit.\n",
    "    return bits_to_string(x[0].cpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(test(\"FIRST CIT\", chars=10))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sd",
   "language": "python",
   "name": "sd"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}