{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "from mingpt.model import GPT\n", "from mingpt.utils import set_seed\n", "from mingpt.bpe import BPETokenizer\n", "set_seed(3407)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "use_mingpt = True # use minGPT or huggingface/transformers model?\n", "model_type = 'gpt2-xl'\n", "device = 'cuda'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of parameters: 1557.61M\n" ] } ], "source": [ "if use_mingpt:\n", " model = GPT.from_pretrained(model_type)\n", "else:\n", " model = GPT2LMHeadModel.from_pretrained(model_type)\n", " model.config.pad_token_id = model.config.eos_token_id # suppress a warning\n", "\n", "# ship model to device and set to eval mode\n", "model.to(device)\n", "model.eval();" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "\n", "def generate(prompt='', num_samples=10, steps=20, do_sample=True):\n", " \n", " # tokenize the input prompt into integer input sequence\n", " if use_mingpt:\n", " tokenizer = BPETokenizer()\n", " if prompt == '':\n", " # to create unconditional samples...\n", " # manually create a tensor with only the special <|endoftext|> token\n", " # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n", " x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long)\n", " else:\n", " x = tokenizer(prompt).to(device)\n", " else:\n", " tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n", " if prompt == '': \n", " # to create unconditional samples...\n", " # huggingface/transformers tokenizer special cases these strings\n", " prompt = '<|endoftext|>'\n", " encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n", " x = encoded_input['input_ids']\n", " \n", " # we'll process all desired num_samples in a batch, so expand out the batch dim\n", " x = x.expand(num_samples, -1)\n", "\n", " # forward the model `steps` times to get samples, in a batch\n", " y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n", " \n", " for i in range(num_samples):\n", " out = tokenizer.decode(y[i].cpu().squeeze())\n", " print('-'*80)\n", " print(out)\n", " " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n", "--------------------------------------------------------------------------------\n", "Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n", "\n" ] } ], "source": [ "generate(prompt='Andrej Karpathy, the', num_samples=10, steps=20)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.10.4 64-bit", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.4" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" } } }, "nbformat": 4, "nbformat_minor": 2 }