{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "_nSGBgG98qRC" }, "outputs": [], "source": [ "# Trainer\n", "import torch\n", "from tqdm import tqdm\n", "\n", "iterator = tqdm(dataloader, desc=\"Training\", postfix={\"train_loss\":0.0})\n", "\n", "for item in iterator:\n", " item = tokenizer.bos_token + \" \" + item[0] + \" \" + tokenizer.eos_token\n", " encoded_inp = tokenizer(item, return_tensors='pt').input_ids.to(\"cuda\")\n", " logits = mamba_model(encoded_inp)\n", "\n", " labels = encoded_inp.to(logits.device)\n", " shift_logits = logits[:, :-1, :].contiguous()\n", " labels = labels[:, 1:].contiguous()\n", " loss_fct = torch.nn.CrossEntropyLoss()\n", " loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1))\n", "\n", " optimizer.zero_grad(set_to_none=True)\n", " loss.backward()\n", " optimizer.step()\n", "\n", " # moving data's from gpu to cpu\n", " loss = loss.detach().cpu().numpy()\n", " logits = logits.detach().cpu().numpy()\n", " labels = labels.detach().cpu().numpy()\n", " encoded_inp = encoded_inp.detach().cpu().numpy()\n", " shift_logits = shift_logits.detach().cpu().numpy()\n", "\n", " iterator.set_postfix({\"train_loss\": loss.item()}, refresh=False)" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "id": "feaR0XKtOGug" }, "outputs": [], "source": [ "# Inference\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "\n", "def generate(model,\n", " tokenizer,\n", " prompt: str,\n", " n_tokens_to_gen: int = 200,\n", " sample: bool = True,\n", " top_k: int = 40):\n", " model.eval()\n", "\n", " input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(\"cuda\")\n", "\n", " for token_n in range(n_tokens_to_gen):\n", " with torch.no_grad():\n", " indices_to_input = input_ids\n", " next_token_logits = mamba_model(indices_to_input)[:, -1]\n", "\n", " probs = F.softmax(next_token_logits, dim=-1)\n", " (batch, vocab_size) = probs.shape\n", "\n", " if top_k is not None:\n", " (values, indices) = torch.topk(probs, k=top_k)\n", " probs[probs < values[:, -1, None]] = 0\n", " probs = probs / probs.sum(axis=1, keepdims=True)\n", "\n", " if sample:\n", " next_indices = torch.multinomial(probs, num_samples=1)\n", " else:\n", " next_indices = torch.argmax(probs, dim=-1)[:, None]\n", "\n", " input_ids = torch.cat([input_ids, next_indices], dim=1)\n", "\n", " output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]\n", "\n", " return output_completions" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "T4", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }