File size: 6,366 Bytes
4673b21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
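"# Generating text with minGPT or huggingface/transformers\n",
"\n",
"Shows how to generate text from a prompt (or unconditionally, from an empty prompt) given a few sampling hyperparameters, using either a minGPT or a huggingface/transformers GPT-2 model."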
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"\n",
"from mingpt.model import GPT\n",
"from mingpt.utils import set_seed\n",
"from mingpt.bpe import BPETokenizer\n",
"\n",
"# fix the random seed (minGPT helper) so the sampled generations below are repeatable\n",
"set_seed(3407)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"use_mingpt = True # use minGPT or huggingface/transformers model?\n",
"model_type = 'gpt2-xl'\n",
"# fall back to the CPU when no GPU is present so Restart & Run All still works\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"number of parameters: 1557.61M\n"
]
}
],
"source": [
"# load the pretrained weights with whichever library was selected above\n",
"if not use_mingpt:\n",
"    model = GPT2LMHeadModel.from_pretrained(model_type)\n",
"    # reuse the end-of-text id as the pad id to suppress a warning\n",
"    model.config.pad_token_id = model.config.eos_token_id\n",
"else:\n",
"    model = GPT.from_pretrained(model_type)\n",
"\n",
"# ship model to device and switch to eval mode (trailing ; hides the repr)\n",
"model.to(device)\n",
"model.eval();"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def generate(prompt='', num_samples=10, steps=20, do_sample=True, top_k=40):\n",
"    \"\"\"Generate `num_samples` continuations of `prompt` and print each one.\n",
"\n",
"    prompt: text to condition on; '' means unconditional generation, starting\n",
"        from the single <|endoftext|> token.\n",
"    num_samples: number of continuations, generated together as one batch.\n",
"    steps: passed as `max_new_tokens` to model.generate.\n",
"    do_sample: forwarded to model.generate (presumably sampling vs. greedy decoding).\n",
"    top_k: forwarded to model.generate; defaults to the previously hard-coded 40.\n",
"\n",
"    Each decoded sample is printed after a line of 80 dashes; returns None.\n",
"    \"\"\"\n",
"    # tokenize the input prompt into integer input sequence\n",
"    if use_mingpt:\n",
"        tokenizer = BPETokenizer()\n",
"        if prompt == '':\n",
"            # to create unconditional samples...\n",
"            # manually create a tensor with only the special <|endoftext|> token\n",
"            # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n",
"            x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long)\n",
"        else:\n",
"            x = tokenizer(prompt)\n",
"    else:\n",
"        tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n",
"        if prompt == '':\n",
"            # to create unconditional samples...\n",
"            # huggingface/transformers tokenizer special cases these strings\n",
"            prompt = '<|endoftext|>'\n",
"        encoded_input = tokenizer(prompt, return_tensors='pt')\n",
"        x = encoded_input['input_ids']\n",
"\n",
"    # put the ids on the model's device in one place: the hand-built\n",
"    # unconditional tensor above is created on the CPU and would otherwise\n",
"    # cause a device mismatch when the model lives on a GPU\n",
"    x = x.to(device)\n",
"\n",
"    # we'll process all desired num_samples in a batch, so expand out the batch dim\n",
"    x = x.expand(num_samples, -1)\n",
"\n",
"    # forward the model `steps` times to get samples, in a batch\n",
"    y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=top_k)\n",
"\n",
"    for i in range(num_samples):\n",
"        out = tokenizer.decode(y[i].cpu().squeeze())\n",
"        print('-'*80)\n",
"        print(out)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. The IMF had\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n",
"--------------------------------------------------------------------------------\n",
"Andrej Karpathy, the state's environmental protection commissioner. \"That's why we have to keep these systems in place.\"\n",
"\n"
]
}
],
"source": [
"# 10 continuations of up to 20 new tokens each for the prompt below\n",
"generate(\n",
"    prompt='Andrej Karpathy, the',\n",
"    num_samples=10,\n",
"    steps=20,\n",
")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.10.4 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
|