{ "cells": [ { "cell_type": "markdown", "id": "df85b023-cdb5-498e-8373-0fd5b7c31853", "metadata": {}, "source": [ "## Setup: imports, model, and dataset" ] }, { "cell_type": "code", "execution_count": 1, "id": "895fc851-817b-4c23-baf4-72cf73238781", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from mario_gpt.dataset import MarioDataset\n", "from mario_gpt.prompter import Prompter\n", "from mario_gpt.lm import MarioLM\n", "from mario_gpt.utils import view_level, convert_level_to_png" ] }, { "cell_type": "markdown", "id": "28c11f07-b604-4603-8fe3-d53874ba02a8", "metadata": {}, "source": [ "### Load Model" ] }, { "cell_type": "code", "execution_count": 2, "id": "6f656e57-24a6-4624-b6ed-8aa871581007", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using shyamsn97/Mario-GPT2-700-context-length model\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:1177: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. 
Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Using shyamsn97/Mario-GPT2-700-context-length tokenizer\n" ] } ], "source": [ "mario_lm = MarioLM()" ] }, { "cell_type": "code", "execution_count": 3, "id": "1a60f6ed-42be-4d17-af15-151fa24e0f91", "metadata": {}, "outputs": [], "source": [ "TILE_DIR = \"../data/tiles\"" ] }, { "cell_type": "markdown", "id": "a7d7bd55-14d4-45a3-9539-c7c385f63070", "metadata": {}, "source": [ "### Load Dataset (Optional)" ] }, { "cell_type": "code", "execution_count": 4, "id": "6c0840d0-ea5b-4111-9198-6b5a716083bd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "No level string specified, using default string FULL_LEVEL_STR_WITH_PATHS...\n", "\n", "\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Token indices sequence length is longer than the specified maximum sequence length for this model (102116 > 1024). 
Running this sequence through the model will result in indexing errors\n" ] } ], "source": [ "dataset = MarioDataset(mario_lm.tokenizer)" ] }, { "cell_type": "markdown", "id": "c80a131f-c68f-475d-ab24-acd3da814c39", "metadata": {}, "source": [ "#### View string representation of level" ] }, { "cell_type": "code", "execution_count": 5, "id": "2bdab45e-58cb-4bcb-8d6e-dee6c946d6fd", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['--------------------------------------------------',\n", " '--------------------------------------------------',\n", " '--------------------------------------------------',\n", " '--------------------------------------------------',\n", " '-------------------------------------------------o',\n", " '--------XSSSSS---------------------------------SSS',\n", " '--------X-----------------------------------------',\n", " '--------X-----------------------------------------',\n", " '-------EX--E-X---------------xxxx-?-----------xxxx',\n", " '--------XSS?SX---QQ?QQ------xx<>-x-----------xx--?',\n", " '---------------------------xx-[]--x---------xx----',\n", " '--------------------------xx--[]---x-------xx-----',\n", " 'xxxxxxxxxxxxxxxxxxxxxxxxxxx---[]----xxxxxxxx------',\n", " 'XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX---XXX']" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "view_level(dataset.input_ids[:700], mario_lm.tokenizer)" ] }, { "cell_type": "markdown", "id": "99be5b3a-c968-4fbd-a51a-f623003072c0", "metadata": {}, "source": [ "#### Image" ] }, { "cell_type": "markdown", "id": "d5614fc2-59bc-40ee-a92a-0cfd971e1ad3", "metadata": {}, "source": [ "##### Previewing the first 50 columns of the dataset" ] }, { "cell_type": "code", "execution_count": 6, "id": "0d6a3bf3-d050-4760-a48e-8b8655142c67", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/home/kokkgoblin/miniconda3/envs/py39/lib/python3.9/site-packages/Pillow-9.1.1-py3.9-linux-x86_64.egg/PIL/Image.py:992: UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images\n", " warnings.warn(\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAyAAAADgCAIAAAB0EpUWAAAYPUlEQVR4nO3dT2wcVZ7A8fe6q9OOsWNv4mwDQXHsvRhWyybZbAA7M4qGS9QoWhGFA1oJgmwNITkwBywuHo1G8S0ckBgFZXGCJZgDYsgcMmmNOIyi7MYBokSMUFgIKAECJuGPkzh2/Kf/1B6abZz+U1Wv+3X5VdX3o9EInG93qitdzo/X5Sq551BWeCalsG3vOT09PT09PT19FPuYQi7Unp2enp6enp6ePpq92oAFAAAAVwxYAAAAmjFgAQAAaMaABQAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgWUxKtQfQ09PT09PT09M7s5KWsG1RsEWu4H6zaCkFPT09PT09PT29cy/3HMq6VHc+wPVJ6enp6enp6ekj3qudg6X07PT09PT09PT00ew5yR0AAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANAsJqXaA+jp6enp6enp6Z1ZSUvYtijYIldwv1m0lIKenp6enp6ent65l3sOZV2qOx/g+qT09PT09PT09BHv1c7BUnp2enp6enp6evpo9pzkDgAAoBkDFgAAgGYMWAAAAJoxYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAaBaTUu0B9PT09PT09PT0zqykJWxbFGyRK7jfLFpKQU9PT09PT09P79zLPYeyLtWdD3B9Unp6enp6enr6iPdq52ApPTs9PT09PT09fTR7TnIHAADQjAELAABAMwYsAAAAzRiwAAAANGPAAgAA0IwBCwAAQDMGLAAAAM0YsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0i0mp9gB6enp6enp6enpnVtISti0KtsgV3G8WLaWgp6enp6enp6d37uWeQ1mX6s4HuD4pPT09PT09PX3I+gdmXlZ4gOo5WEpbQ09PT09PT08fwV4IYSk/AgAAINr+8x8Olf753qHLQoiDBw8uDfgpQgAAAAVLpyshxORYjxBieHh46RcZsAAAAJRljl3OHLtc61cZsAAAALwqLl9ljl1O7+pJ7+opfj5YuYjFgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmjFgAQAAePXH6/uEEMUrYN07dLl4BazKu+VYBt6wmp6enp6enp7eqL5ScbSqxUpawrZFwRa5gvvvJKWgp6enp6enp49av9Qfr+9zvdmz3HMo6/Ksd/4GRo2TIeiPPnT2yEj/4OiEe7y5Xwhhn3cvw+SZ9/7de2zgny89PT09vT/90YfOKjyg+dQGLGj3+sNnhRCuM1ZxuhJCjKWr/OpQJrRf5/0JAPBi218SlV9cxr+/OMndCIOjE0dG+mv9amm6AgAAgcAK1jIrrmAV1ZqxhjJ+bY15eH8CALwY31dlBWsZWcu9AVBTWoo08OO8Znz9f6okAACUM+3vRz4iBAAA0IwBCwAAQDPOwVpmnIPljPcnAMALzsFCQ0z4XNnPrwMA4IVpf3/xESEAAIBmDFgAAACacQ7WMiudg+VwNdGxdHQ/LOP9CQDwgnOwUIXztdqLH/EWZyzT
rvPR7K9zHSwAgBem/f3IR4TLz8udcGr9cQIAAAPFpFR7AL3m3vN9BqM5Yxn350VPT09Pb2RvGrn3taxti4ItcgVh2261FElL0GvsX3/OrM+MTcP7k56enp4+iH+fqp3kLqX7i6RX6o8+dPbISP/g6IR7vLlfKJ70beDrpaenp6enj0Kvdg6W0rPTe+wHRydqXcO9xPsniY1vDz09PT09PX2DPSe5G8F5xqpvugIAAMuFyzSYwss6FgAACARWsAAAADRjwAIAANCMAQsAAEAzBiwAAADNGLAAAAA0Y8ACAADQjAELAABAM66DZQqHq4mOpcVQxs9tAQAADWEFywjO12ofyoixtG/bAgAAGsWAtfy83AmHGQsAgACJSan2AHrNvef7DBZnLOO2n56enp6enr6ClbSEbYuCLXIF95tFSyno9fZKhjJi7+NmbT89PT09PT19ZW/NZ12iMvR6e/v8xJGR/sHRCdeyuNZl2vbT09PT09PTV1I7B8t1ZKOvox8cnTgy4vJBofdPEhvfHnp6enp6evoGe05yN4LzjFXfdAUAAJYL18EyhZd1LAAAEAisYAEAAGjGgAUAAKAZAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZlwHyxQOVxMdS4uhjJ/bAgAAGsIKlhGcr9U+lBFjad+2BQAANIoBa/l5uRMOMxYAAAESk1LtAfSae8/3GSzOWMZtPz09PT09PX0FK2kJ2xYFW+QK7jeLllLQ6+2VDGXE3sfN2n56enp6enr6yt6az7pEZeid/dONl8u+surk70v/PL39d2W/ap+fODLSPzg64frMxbUu07b/wl2/UXr+V//trNLr3XNI7QWY9n6gp49yv+XPCe/x4OjpZ97b2tTtoaf3s1f7KULXkY2+zNLppPivlTPK4Kj7jOX9k0Tn7VHtvWy/6vOb/Hrp6en19o/t6vHYy80DSv9BZebrpacv9Zzk3kRl04nDF4szR63nqW/aaJz37Vdl5usFfBYv5NbMXl01P7X69rXUrStrZq8mc/MB6j26Z+hygwEgzHv/u/ZcB6tZHAYRh3WsJm+UAtXtV2Xa6wX8t/r2tf2ZJ7/p2tgxO5maunCz7b6TD+57vzedjScD0Xv07dhPi1gnjt0xSJUWt0oB4MC0979rzwpWU6w6+ftnT8w8e2Km8peKX9eyDtQ8Qd9+IBAS+YXuyTObLr6xkGj7oO/plsXp3aee77t6Nii9kuJ0NZT56X+iYt4CnJn2/nftGbD0Kw4fhx9rs9/7a9mM8uyJGfu9vx5+rE1o+qytGYK+/UCA5Kz4xfXpV3aMv7N1ePzRw4nc7Uc+fStAfR2+HespLVnx4SCUmPb+d+75iLDpqq4DBUjQtx8wWdZqubJ2oxCiIOPTK7tutPd0zE4GqAf8ZNr737lnBUuzynWd4pJPcUHINV52Qd9+IFgSufnu786tmp+6a3E6Nf1F561LU+3dAeq9e2xXT+meqnw4iPqY9v537lnBaq6yoaRyRjFc0LcfMFw8n98weeqJMwe+7+zd/Nnbt1rXne/dGaBeSfGTwdLJWIAq097/zj0DVhP914HdDv9qvqBvP2C+rLXiq7sHUtc/uf/LzExr6t0tL364fnuA+kbww4NQZdr737lnwNJsevvvln5w9uvf/qksWDqmNH6xA+2Cvv1AsOTjyc/XbTu+aX/H3I8zyc5FtwsimNbX4bFdPZzbjvqY9v537hmw9CuNHS/84l8qf/XXv/3TS//9UeXXHa6uOZb2dTm9vu1XZc7rBZZdXlpTrang9kpOHLs8lla4wjtQxrT3f62ek9yXwfDw8PDw8NKvOF+7fCgjxtJN3iYVlduvKlivF2gSW8YWEu25mNcb9pnWN4JFLKgy7f3v2jNgNcvw8LB8eEfZF+XDO0pfLM0oXu4M4//M4X37VZn5egH/zSXazvU99fWaBwLae1Q2Sy39ccKqAVCVae9/114+82pW6RaGUqrd8jBq/T/Pvlz8
h9L8UfrJu8p55eDBg0pjylhanN7Z3D8v1e3/uO03Ss8/vk/hP459eL309PTN67f9JeH9o8B7hi7z9xF9mHoraQnbFgVb5Aruj5RS0Dv3JaXhST68o+rVDQ4ePOjydBWGMmLv42Ztf93P74UPr5eenr55/VBGiIzC6pRp209P30gv9xzKulR3PsCo8TAE/dGHzh4Z6R8cnXCPN/cLIezz7mUjlG7APDh6WgjFoUkIpddr2vtz4LjCCtzg6OkjIwOB7p95b6v33sDjK+j9688193Qo044venqTe9Xv/2o/Rai0NfQe+8HRCdeZo3TeUtUBqNYZS/V93fuSvtw8oPr89vkJpderxJ8/L9X9E+he6S9gM4+voPcvfaz2KO9eUDzVysz9Q0/vZ6/0/ZPLNBjBeeaob9poxD1Dl52vAegaODPt9apS3T9B7+FRvJDrnPshG19hFbKJ/GIulphJdi5YLY30o3+74yEjvyp/krLAT6qvF1iqGceLD7x//2TAMkVx5nDNal0gStfXi0rvnrJbhpWG91JQ3/MLz6/XTN73Tzh6eLT69rX9mSe/6drYMTuZmrpws+2+kw/ue783na1xuUIvfWmichikio3zpDXyKzH6t5//XwjxgtqLq3P7gVqacbz4wPv3TwasgCl99Kb3Y8Gqym4ZNpYWJ45dLlsgrXt7QsDL/glTD1eJ/EL35Jl/nProfzf8xwd9T//rpT/vPvX8zda1H63bpqVvRNl0pYWf24/wMfl4ceXl+yfXwYK7b8d6SiM5V6yppLp/gt7DQc6KX1yffmXH+Dtbh8cfPZzI3X7k07c09nVbuoKlkW/bj1Ay9njxzuH7JwMWAGiTtVqurN0ohCjI+PTKrhvtPR2zkxr7ujVjBUv4uP0IJWOPFy34iDBglp7b1IzzsZZ+nLf0nqxlHzbXvT2Do9WbIPKyf8LUw4tEbr77u3Or5qfyMSs1/UXnrUuX7v2Fxr5uledgaeHb9iOUjD1eXHn5/smAFTBNPe+qckIqrnyWfdjcyPaEjOv+CVkPV/F8fsPkqSfOHPi+s3fzZ2/fal13vnenxr5uTVrB8m37EUrGHi9euH7/ZMCCAn64zJnq/gl6j0pZa8VXdw+krn9y/5eZmdbUu1te/HD9do193Zq0guXb9iOUjD1e6lD5/ZMBC+6WroWikur+CXoPB/l48vN1245v2t8x9+NMsnPR7QfIVfu6NWkFy7ftRygZe7x45/D9kwHLFA5X1xxL/7z86M91sMqcOHa51hXA635+j683EBz2Tyh7uMpLa6o11by+Dk1awSryYfsRYgYeL945fP9kwDKC87XLi6c0FWcOP6+DVabq5Wvr2x7vrzdAVK9+HvQelWwZW0i052Jeb1im2jeiGdOVn9uP8DH5eFFV9fsnl2lYfl7uDOPz5TrLFjwf29VTNu408omSga9Xler+CXoPj+YSbef6nvp6jdeb/Kn2jWjGdbD83H6Ej8nHiwPv3z8tA29YHa3e8333/Jw5KifxWvdaUaX6ek8b9udVpLp/gtsbd7yY3d9Y2fXmwAGNvai4JXPV+9t4uenNCxX/X7k9rhp8varPTx/uXvvxYtr3f7n3taxti4ItcgX330lKkbQEvcb+9edMXO00h2nvz6j9eZm2/6PWH91rCSGklC6pOtu2hRDPjeWMer309Cb3qt//5Z5DWYXa7HGYnp6ePkx96Rv6Sx//9MXiB3xl51GVzqyq9fFf2dlXpVUxvv/T0zevVzsHS3UxjZ6enp7en96BlrOvTHu99PSG9/wUIQBUFy/kOud+yMZXWIVsIr+YiyVmkp0LVouu3jdNujqDKmP3D7QIzfGiCwMWAFS3+va1/Zknv+na2DE7mZq6cLPtvpMP7nu/N52tcXlD1d43zbsClhJj9w+0CM3xoguXaQCA6hL5he7JM5suvrGQaPug7+mWxendp57vu3pWV+8bE6YrYfD+gRahOV50YcACgJpyVvzi+vQrO8bf2To8/ujh
RO72I5++pbH3h/YrYNXNzP0DXcJxvOjCgAUANWWtlitrNwohCjI+vbLrRntPx+ykxt4fhqxgCVP3D3QJx/GiCwMWANSUyM13f3du1fzUXYvTqekvOm9dmmrv1tj7w5wVLDP3D3QJx/GiCye5A0BN8Xx+w+SpJ84c+L6zd/Nnb99qXXe+d6fG3h/mrGCZuX+gSziOF11YwQKAmrLWiq/uHkhd/+SXf/+DEOLdLS9+uH67xt4f5qxgmbl/oEs4jhddWMECgJry8eTn67Yd37S/Y+7HmWTnotsPkKv2/jBnBcvM/QNdwnG86MIKFgC4yEtrqjXl/bu/at9s5qxgFZm2f6BX0I8XXRiwAKA6W8YWEu25mNc7vKr2vjFkBcvY/QMtQnO86MKABQDVzSXazvU99fWaB9zTunrfGLKCZez+gRahOV50sQy8ATU9PT29Cf2NlV1vDhxoXu+bWitYhu/PZm8Pvd5+2Y8X03oraQnbFgVb5Aruj5RS0NPT09P70+tS616Epr1eevow9dZ81iUqQ09PT0/vT69LrRUs014vPX2YerVzsFQXt+np6enp/ekdaDkHy7TXS09veM91sABERbyQ65z7IRtfYRWyifxiLpaYSXYuWC26emMZ8lOEqkKz/wMqsseLLgxYAKJi9e1r+zNPftO1sWN2MjV14WbbfScf3Pd+bzpb4wI8qr2xap2DZbjQ7P+AiuzxoguXaQAQFYn8QvfkmU0X31hItH3Q93TL4vTuU8/3XT2rqzdWEKcrEaL9H1CRPV50YcACECE5K35xffqVHePvbB0ef/RwInf7kU/f0tibyZDrYNUhHPs/uKJ5vOjCgAUgQrJWy5W1G4UQBRmfXtl1o72nY3ZSY2+mgK5gibDs/+CK5vGiCwMWgAhJ5Oa7vzu3an7qrsXp1PQXnbcuTbV3a+zNFNwVrHDs/+CK5vGiCye5A4iQeD6/YfLUE2cOfN/Zu/mzt2+1rjvfu1Njb6bgrmCFY/8HVzSPF11YwQIQIVlrxVd3D6Suf/LLv/9BCPHulhc/XL9dY2+m4K5ghWP/B1c0jxddWMECECH5ePLzdduOb9rfMffjTLJz0e0HyFV7MwV3BSsc+z+4onm86MIKFoDIyUtrqjXl/bu/am+a4K5gFQV9/wdd1I4XXRiwAESFLWMLifZcLNGk3lgBXcEKzf4PqMgeL7owYAGIirlE27m+p75e80CTemMFdAUrNPs/oCJ7vOhiSal2C0N6enr6gPY3Vna9OXCgeb3q9vim1gpWyP68mr09UesDd7yY1ltJS9i2KNgiV3B/pJSCnp6ent6fXpda9yI07fXS04ept+azLlEZenp6enp/el1qrWCZ9nrp6cPUq52Dpbq4TU9PT0/vT+9AyzlYpr1eenrD+/LrYMULuc65H7LxFVYhm8gv5mKJmWTngtVS6yno6enpw9qHRkB/ilCVae+fqPUoUz5grb59bX/myW+6NnbMTqamLtxsu+/kg/ve701na1zQgp6enj6sfWjUOgcrZEx7/0StR5nyjwgT+YXuyTObLr6xkGj7oO/plsXp3aee77t6ttbj6enp6cPah0YUpith3vsnaj3KVDkHK2fFL65Pv7Jj/J2tw+OPHk7kbj/y6VsOT0FPT08f1j4cAnodrDqY9v6JWo+lqgxYWavlytqNQoiCjE+v7LrR3tMxO+nwFPT09PRh7cMhIitYwrz3T9R6LFVlwErk5ru/O7dqfuquxenU9Bedty5NtXc7PAU9PT19WPtwiM4Klmnvn6j1WKr8JHchRDyf3zB56okzB77v7N382du3Wted793p8BT09PT0Ye3DITorWKa9f6LWY6kqA1bWWvHV3QOp65/c/2VmpjX17pYXP1y/3eEp6Onp6cPah0NEfopQmPf+iVqPpaoMWPl48vN1245v2t8x9+NMsnPR7Qcy6enp6cPah0NEpith3vsnaj2WqjJgFeWlNdWa8v5E9PT09GHtgy46K1hFpr1/otaj
qPwkd1vGFhLtuVjC4+Pp6enpw9qHRkSmK9PeP1HrUaZ8wJpLtJ3re+rrNQ94fDw9PT19WPvQiMhPEZr2/olajzLymVezSrcwlFLtlof09PT09PX14/t+Wjx46eOfvlIcksrWokqrU7VGqLIVrBf+/29Mvv/T0zevt5KWsG1RsEWu4P5IKQU9PT09vT+9LrXOwTLt9dLTh6mX9vmJIyP9g6MTLq0QcnO/EIKenp6enp6efs+hrGv580Oav7x09KGzRu2fmBBicHTiyEi/l7qInp6enp6ent47pWmp7t6o/RPz8pjKvUlPT09PT09Pbxpz9s/P18HyMpd5/z3o6enp6enpo9CbxpD9U+VmzwAAAGgEAxYAAIBmDFgAAACaMWABAABoxoAFAACgGQMWAACAZgxYAAAAmv18HSyHq4eNpcVQpvyL9PT09PT09PSmMWT/xFxrIcRQRoylvT47PT09PT09fUR605izf2KudeVj6Onp6enp6elNY9T+ka7pUqprg/T09PT09PSh7E/vzCrdkllKtVs4q/bj+xLeYx/2j9qABQAAIITY+1rWtkXBFrmC+yQkpUhaoqn9688pDFg++D9omqPgoC0DtAAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = convert_level_to_png(dataset.input_ids[:700], TILE_DIR, mario_lm.tokenizer)[0]\n", "img" ] }, { "cell_type": "markdown", "id": "28a7e683-a9a2-4321-b21a-807daf7aa744", "metadata": {}, "source": [ "#### Set device" ] }, { "cell_type": "code", "execution_count": 7, "id": "7a6f684a-63a9-4a34-9a57-fd6aa84375a0", "metadata": {}, "outputs": [], "source": [ "# Use the GPU when one is available; otherwise fall back to CPU so the notebook still runs end to end.\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "mario_lm = mario_lm.to(device)" ] }, { "cell_type": "markdown", "id": "3869772f-e3a6-43d4-94ee-40364028bea8", "metadata": {}, "source": [ "## Generating Levels" ] }, { "cell_type": "code", "execution_count": 45, "id": "1e7589f2-2b48-4174-9fc7-7e7de7ff3615", "metadata": {}, "outputs": [], "source": [ "prompts = [\"many pipes, many enemies, some blocks, high elevation\"]" ] }, { "cell_type": "markdown", "id": "aa0a437f-4123-44b2-b08f-985f60165fb2", "metadata": {}, "source": [ "##### We generate 1399 predictions for an even 1400 output (including the input seed which is just a single block). Mario Levels have height of 14, so we generate 100 columns." 
] }, { "cell_type": "code", "execution_count": 46, "id": "766362fb-8b90-43a4-b405-17fed2342d31", "metadata": { "scrolled": true, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "shape: torch.Size([1, 685]), torch.Size([1, 1400]) first: \n" ] } ], "source": [ "generated_level = mario_lm.sample(\n", " prompts=prompts,\n", " num_steps=1399,\n", " temperature=2.0,\n", " use_tqdm=True\n", ")" ] }, { "cell_type": "code", "execution_count": 47, "id": "777f94cf-a765-4f7a-a7b4-223c29680e17", "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAABkAAAADgCAIAAADUoj0kAAAwcElEQVR4nO3dXYxUVaLo8bW7qqjutpvuAJ5SMbT0U8vJmKYvx4+GuSHj5KRtDg9jEOOZHEXo+EHfhJmEHpMJxkzo3AfxwROPGGYa7MTxgVHHBw6t4WFCmMuHIgw3BqPoxQ+0BZUW+4P+qKre92FjWVRX79qraq9da+39/4UY6P7XZlXXrtqL5a5d1qbdaeGZZQnb9p7T09PT09PT09PT09PT09PT09NX2tdI5EJu6/T09PT09PT09PT09PT09PT09JX3cgtYAAAAAAAAQMBYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1ljAAgAAAAAAgNZYwAIAAAAAAIDWWMACAAAAAACA1mosS+4G9PT09PT09PT09PT09PT09PT0QfbxZFzYtpi1RWZW2HbprdPT09PT09PT09PT09PT09PT0wfZW5t2p0tU19+g5Ebp6enp6enp6enp6enp6enp6el97OWugSW1dXp6enp6enp6enp6enp6enp6+sp7LuIOAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACtsYAFAAAAAAAArbGABQAAAAAAAK2xgAUAAAAAAACt1ViW3A3o6enp6enp6enp6enp6enp6emD7OPJuLBtMWuLzKyw7dJbp6enp6enp6enp6enp6enp6enD7K3Nu1Ol6iuv0HJjdLT09PT09PT09PT09PT09PT0/vYy10DS2rr9PT09PT09PT09PT09PT09PT0lfdcxB0AAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFpjAQsAAAAAAABaYwELAAAAAAAAWmMBCwAAAAAAAFqrsSy5G9DT09PT09PT09PT09PT09PT0wfZx5NxYdti1haZWWHbpbdOT09PT09PT09PT09PT09PT08fZG9t2p0uUV1/g5Ibpaenp6enp6enp6enp6enp6en97GXuwaW1Nbp6enp6enp6enp6enp6enp6ekr7+NytwAAABEQm800T36Xji2Iz6YT2ZlMTWI82Twdr61WDwAAgIhjAQsAABRadPVS79BDXy1pb5oYTo2c/aHh1sN3bH2ntTsdS1alBwAAQMTJvYUQAABEQSI73TJ8fOW5V6YTDe+2PVI7M7rhyLa2iyer1QMAACDiWMACAABFZOK
xc8u6X+gafOPOvsF79yQyV+/5aH8VewAAAEQZC1gAAKCIdLz2wo3tQohZKzZat+RK4/KmieEq9gAAAIgyFrAAAEARicxUyzenFk6N3DAzmhr9rHns/EhjSxV7AAAARBkXcQcAAEXEstnbho88cHznt82tHR+/Nla/9HTr+ir2AAAAiDIWsAAAQBHp+IIvblqd+v7D2z8fGq9PHVr11Jlla6vYAwAAIMpYwAIAAEVkY8lPlq45sLK3afLyeLJ5Jpasbg8AAIAoYwELAADMK2vFR+pT+vQAAACIJi7iDgAACtlWzXSiMVOT0KQHAABAxLGABQAACk0mGk61Pfzl4hWa9AAAAIg469GX0rYtcwNL0NPT09PT09PT09PT09PT09PTB9bHk3Fh22LWFpnZ0re0LEFPT09PT09PT09PT09PT09PTx9kb23anS5RXX+Dkhulp6enp6enp6enp6enp6enp6f3sZe7BpbU1unp6enp6enp6enp6enp6enp6Svv43K3AAAYIjabaZ78Lh1bEJ9NJ7IzmZrEeLJ5Ol5rSg93UXu8dBsPAAAAAsYCFgCE06Krl3qHHvpqSXvTxHBq5OwPDbcevmPrO63d6VjSiB7uovZ46TYeAAAABEzuLYQAAFMkstMtw8dXnntlOtHwbtsjtTOjG45sa7t40pQe7qL2eOk2HgAAAASMBSwACK1MPHZuWfcLXYNv3Nk3eO+eRObqPR/tN6iHu6g9XrqNBwAAAEFiAQsAQisdr71wY7sQYtaKjdYtudK4vGli2KAe7qL2eOk2HgAAAASJBSwACK1EZqrlm1MLp0ZumBlNjX7WPHZ+pLHFoB7uovZ46TYeAAAABImLuANAaMWy2duGjzxwfOe3za0dH782Vr/0dOt6g3q4i9rjpdt4AAAAECQWsAAgtNLxBV/ctDr1/Ye3fz40Xp86tOqpM8vWGtTDXdQeL93GAwAAgCCxgAUAoZWNJT9ZuubAyt6mycvjyeaZWNKsHu6i9njpNh4AAAAEiQUsAAi5rBUfqU+Z28Nd1B4v3cYDAACAYHARdwAIJ9uqmU40ZmoShvZwF7XHS7fxAAAAIGAsYAFAOE0mGk61Pfzl4hWG9nAXtcdLt/EAAAAgYNajL6VtW+YGlqCnp6enp6enp6enp6enp6enpw+sjyfjwrbFrC0ys6VvaVmCnp6enp6enp6enp6enp6enp4+yN7atDtdorr+BiU3Sq9Vv2L8+YIvLjz8h9zvR9c+U/Ddszf8JlLb1+3xoqenp/er33fXyb07Orf0Hysdd3QKIezTpct8j574F6nx6Pbzoaenp6enp5/blzF/kOpZf6CvpJf7FEKprdNr2Oev/jh/nLsGFOXt09PT04ep39Jfek7pzCaFEHt3dM79bs+QGOgucqueIbkJqJ4/H3p6enp6evq5vdT8oYxedjz09Lmei7hHSMHqj8sXo7l9GCc2m1k8cXHh1Miiq5dSYxcWT1xMZqYM6mXpNh7VTB+/LEWPrzOnnG8j5c0mjRC1/QcAAB/Jzh8iO99AwOTOwIK5XBZ6KjyPKRzbh4kWXb3UO/TQV0vamyaGUyNnf2i49fAdW99p7U7Hkkb0ut1f3Zg+flnqHl/3OWVOz5Dc1zUXtf0HAAB/eZw/lN0DZeAMrEhYePgPjx8cf/zg+NxvOV+v8Dwm07cPQyWy0y3Dx1eee2U60fBu2yO1M6Mbjmxru3jSlF6WbuNRzfTxy6r64zvQfe1X/u/zfxV8XXNR238AAABCjwWs8HMWd/asa7BPvF2wBvT4wXH7xNt71jWICt6LZ/r2YbRMPHZuWfcLXYNv3Nk3eO+eRObqPR/tN6iXpdt4VDN9/LKi9viqxs8HAAAgTHgLYeQUPY+J7cNQ6XjthRvbhRCzVmy0bsmVxuVNE8MG9bJ0G49qpo9fVtQeX9X4+QAAAIQJC1ghN/e8JPvE23/cuUEI8djTr1t3dxXG634Tqe3DdInMVMs3pxZOjWRr4qnRz5r
Hzp+/5ecG9bJ0G49qpo9fVnUf3/xrXXm5Htamfyv7rwpI1PYfAACAcGMBK1rsE2+7/JHtwzixbPa24SMPHN/5bXNrx8evjdUvPd263qBelm7jUc308cuq7uNb9LJWPUPzfl1/Udt/AAAAwo0FrAhxTlya749sHyZKxxd8cdPq1Pcf3v750Hh96tCqp84sW2tQL0u38ahm+vhlRe3xVY2fDwAAQJiwgBVyo2ufyX8X3mNPv14Q5C8Dja59Jmrbh+myseQnS9ccWNnbNHl5PNk8E0ua1cvSbTyqmT5+WVF7fFXj5wMAABAmLGCFX25ZZ/vPfzb3u489/fpzf38/yttHCGSt+Eh9ytxelm7jUc308cvy/fG1Ojrn+9ZA909vBvRy3SsTRW3/AQDAFx7nD2X3QBlYwILo6+sTQuzatYvtwyy2VTOdaMzUJAztZek2HtVMH78sRY+vy2xS/HiJK2dOmbvWlct1r/K//n88DrRKorb/AADgI+/zh/J6oDw11R4AAtLX11fwmX1CCOvurtwXnWWgyG4fJppMNJxqe/jLxSsM7WXpNh7VTB+/LBWPr/ts0jHfcpXporb/AADgF9n5Q5TnGwhY3LKEbUvcgN6svvDmd3flPrlv7npQBLdPb3R/pW7Jn1fvNLcvSfPxqO5NH79s7//j62E26ShvThmyn6fq8dDT09PT05vRS84fZPujut1ferP6J/6Utm0xa4vMbOlbWpZIxgW9Qf2K8edzf8ydo+SsARUsADlvwft/zb+J1PZ1e7zo6enp/epfflLtu+eYP9DT09PT04evZ/5Ar3NvbdqdLlFdf4OSG6VX2qt+QTEd+zN9mHrVz3eeL2b1+Qv6jvwPaZ37Mazb7129d0fnlv5jpTfe0SnYH+jp6enp6ektse+uk1LzB/v0MeYb9IH1chdxl9o6vaL+uQ/kbuXd9hVh2L53ej6+9PT5Pc8X+vn6/NUr549z17C29JeeU3o/8999PPT09PT09PTh6GXnD8w36APr+RRC8/T/7bo/7vhFiSBq24e5YrOZ5snv0rEF8dl0IjuTqUmMJ5un47V+9brxMn6eL97ptj8o3Z8LVq9yX5RdwypvNhlKuu0/AABUkez8gfkGgsEClnly/4J1+Yer07j/y3bHL0T/3376rxBieyi2D3Mtunqpd+ihr5a0N00Mp0bO/tBw6+E7tr7T2p2OJX3pdeNl/DxfvNNtf1C3Pxddvcp9a741rArvTujptv8AAFBdsvMH5hsIQE21B4CqKfjXLNtH1SWy0y3Dx1eee2U60fBu2yO1M6Mbjmxru3jSr143QY4/Cs8X3fYHRfvzwsN/ePzg+OMHx+duwfm6y/IWXOi2/wAAAKAAC1jRlX9GBtuHJjLx2Lll3S90Db5xZ9/gvXsSmav3fLTfx143gY0/Is8X3fYH3/dnZ3Fqz7oG+8TbBWtYjx8ct0+8vWddg3A9RQsudNt/AAAAkI+3EEaX6WdIReGMkghKx2sv3NguhJi1YqN1S640Lm+aGPax101g44/I80W3/UH1/lz0PCyUTbf9BwAAAPk4Ayu6TD9DKiJnlERNIjPV8s2phVMjN8yMpkY/ax47P9LY4mOvm8DGH5Hni277g7/789zzqpxTrpwTsgq+xUlYZdBt/wEAAEA+zsCKLtPPkIrIGSVRE8tmbxs+8sDxnd82t3Z8/NpY/dLTret97HUT2Pgj8nzRbX9Quj8XLFrNXcOCLN32HwAAAORjASu65n4qGdtH1aXjC764aXXq+w9v/3xovD51aNVTZ5at9bHXTWDjj8jzRbf9Qd3+/MedG1z+iPLotv8AAAAgHwtY0WX6GVJR+Nd4BGVjyU+Wrjmwsrdp8vJ4snmm1AfYy/a6CWz8EXm+6LY/+Ls/j659Jv+NgY89/XrBzfOXsUbXPlPx8CNHt/0HAAAA+VjAii7Tz5CKyBkl0ZS14iP1KXW9bgIYf6SeL7rtDz7uz7llqe0//9nc7z729OvP/f39uV+
3Ojrn+7sGukXPkPehRYJu+w8AAFUhO39gvoEAcBH36DL9DKno/Gs8OmyrZjrRmKlJKOp1E+T4o/B80W1/qMr+3NfX19fXl/8Vl9mkEKJnSAx0V/IXhodu+w8AAFUkO39gvoFgsIAVXaZ/SmBEPlUtUiYTDafaHv5y8QpFvW6CHH8Uni+67Q/q9ue+vj7r7q6CL1p3d+W+mFvDcp9NOphTOnTbfwAAqBbZ+QPzDQTGevSltG3L3MAS9FXsB7fyP4fdsD/Th6lX/Xzn+WJW/88Tzzu/ya1P5T55cO561q5duwpOxXI30C2Ormd/oKenp6enj3ovNf+UfW8g8w36CvuaZFwk4yIRE5blaev01e1t27alHmHP7Dzmbl+3x4uevpKe5wt9fp+za9eua1+cs25VEHjXM6Td/aWnp6enp6cPvpcie2Ur5hv0FfbxqXTpLh99dXvrxwf2uQ+ufcV5Q1DBdW1yV7qZ7+1CBVfD2b7ipy0bvf1Nu+V+oLo9vvT0+VQ/33m+VLd/6X+c3Lujc0v/sZKlc2a+ffpYrt9++rpb2df/cW7vZfu6/Xzo6enp6enpg+9l5w/MN+iD7OWugSV7KgB9dXsXqq+GY8T2dXu86Okr6V3wfNG239J/bO+OEpeNyL+uhOreOz1/nvT09PT09PSV98w36LXtuYh7RKn+PDLTtw9fxGYziycuLpwaWXT1UmrswuKJi8nMlI89POL5UpTq/dNj7z7nmzvbU93DI16vAABeaDLfkMV8A3qKV3sAqI7cGRmK/k1r+vbhi0VXL/UOPfTVkvamieHUyNkfGm49fMfWd1q707GkLz084vlSlOr903vv5f9bBtnDC16vAABe6DPfkMV8AxriDKyIMv0MKf41boREdrpl+PjKc69MJxrebXukdmZ0w5FtbRdP+tXDI54vRaneP9mfw43HFwDgBfMNwEcsYEWUEdeoquL24ZdMPHZuWfcLXYNv3Nk3eO+eRObqPR/t97GHFzxf5qN6/2R/DjceXwCAF8w3AL+wgBVRpp8hxRklpkjHay/c2C6EmLVio3VLrjQub5oY9rGHFzxf5qN6/2R/DjceXwCAF8w3AL+wgBVRpp8hxRklpkhkplq+ObVwauSGmdHU6GfNY+dHGlt87OEFz5f5qN4/2Z/DjccXAOAF8w3AL1zEPaJMP0OKM0pMEctmbxs+8sDxnd82t3Z8/NpY/dLTret97OEFz5f5qN4/2Z/DjccXAOAF8w3AL5yBFVGmnyHFGSWmSMcXfHHT6tT3H/7P//tfQohDq546s2ytjz284PkyH9X7J/tzuPH4AgC8YL4B+IUzsCLK9DOkOKPEFNlY8pOlaw6s7G2avDyebJ4p9YG+sj284PkyH9X7J/tzuPH4AgC8YL4B+IUFrIjKnZGh6N+0pm8f/spa8ZH6lLoe7ni+uFO9f5bsrY7O+b410C16hoLuIYXXKwCAF1Wfb8hivgEN8RbCiDL9DCn+NW4E26qZTjRmahKKenjE86Uo1funx95ltieE6BkSA92B9vCI1ysAgBeazDdkMd+AnljAiijTr1HFNX2MMJloONX28JeLVyjq4RHPl6JU759eevfZniN/zqe6h3e8XgEAvNBhviGL+Qa0FbcsYdsSN6Cvbu8X08+Qmm/7uj1eEe+v1C358+qd6nrZ8RjX+4XnS9Fe9f5Zuvcw23M4cz7V/VG9Hy/del6v6Onp6em99NWfb8iOn/kGvcZ9TTIuknGRiAnL8rR1+ur2fjH9DKn5tq/b40VPX0nvF54vevZSZK8cUUav28+Hnp6enp6ePvheCvMN+oB7a9PudOkw7wZaLb+V0a8+IPH24C39R/fuWC3VP3riTqnx6DZ+03vVP396+iD7l5+89nx/7oNrX3QWoQrOpcqdXTXfknHBGVjbfzzHPGqv/6r7fXed3Lujc0v/sdJxR6cQwj59TLe+ZJnv0RP/4j3W8PGip6enp6fXoZedP6iev6mezzD/pK+kl/sUQqm
ta9uvu3+5x97qWD3QLddLPSE1HL/pfQA/f3p6PXsXvnwKoW73V89+S3/pOVz+mfZa9Xt3FHkLwHxXrOgZkpuA6vl40dPT09PT69DLHt91G49u46cPcR/Ri7jf3POpVCDbq6Z6/Kb3UCQ2m1k8cXHh1Miiq5dSYxcWT1xMZqZ87KEIn0LoC4/7szOHm28jc2dvuvVQhNdDAAgHRfPhso/Xio4vzDegJ7kzsELj64FrJ/Uc/Ot1Cx+5k31yQXm9aqrHb3oPRRZdvdQ79NBXS9qbJoZTI2d/aLj18B1b32ntTseSvvRQxJczsOB9f3afw82lST/fZSxkL2+Bong9BIBwUDcflj2+lzce71TPT4AyRPQMLIezGtIzdO2XmLM+UmGvmurxm97Dd4nsdMvw8ZXnXplONLzb9kjtzOiGI9vaLp70q4cirF75IvT780D3tV/5v8//VfB1SAn9/gMAEaHbfJjjCyIl0gtYOV8PLM+dwuPlzWiyvWqqx296Dx9l4rFzy7pf6Bp8486+wXv3JDJX7/lov489VFD9qaDRwf6MSrD/AEA46DYf5viC6GABC4CEdLz2wo3tQohZKzZat+RK4/KmiWEfe6jAGVh+YX9GJdh/ACAcdJsPc3xBdET0GliOdfcvz52/4+XNaLK9aqrHb3oPFRKZqZZvTi2cGsnWxFOjnzWPnT9/y8997KEC18DyS7j35/xrXXm5Htamf1M7nvAJ9/4DANGh23yY4wuiI9ILWOLHi3/nLq7ke6+a6vGb3sN3sWz2tuEjDxzf+W1za8fHr43VLz3dut7HHiqweuWXcO/PRS9r1TM079chK9z7DwBEh27zYY4viI6oL2AVkP0wO90+/E71+E3vUbl0fMEXN61Off/h7Z8PjdenDq166syytT72UIEzsPzC/oxKsP8AQDjoNh/m+ILoYAFLiOvfm6aiV031+E3v4aNsLPnJ0jUHVvY2TV4eTzbPlPqAXtkeKrB65Rf2Z1SC/QcAwkG3+TDHF0QHC1hCCHHwr58OdIt193s9nUe2V031+E3v4busFR+pT6nr4S/OwPJXyf3Z6uic71sD3UXefKdJ7+W6V6gcr4cAEA6+z4dlj+8Vjqck1fMToAx8CuFPZE/q0e0kINXjN71H5WyrZjrRmKlJKOqhCKtXvvC4P7vM3kSxS0rp0w90X/uV//v8XwVfhxReDwEgHBTNh2WP72WPxyPV8xOgPBFdwCpY+1h3//KCJeGCQLZXTfX4Te+hyGSi4VTbw18uXqGohyK5M7BQCS/7s/vszZE/h9Othzq8HgJAOKiYD1dyvFZxfGG+AW3FLUvYtsQNTO8dcy/+XfAV9z+6BKaP3/Ret/0tZP2VuiV/Xr1TXS87ntD3fpnvDCzd7q/mfen92cPszeHM4XTrZWn+eOnW83pIT09PH47e//mw5PH6qOrji+L5xlG9H1963fsn/pS2bTFri8xs6VtalkjGhdH9y0+qPXtf9c9T9fhNF7X9mT7cfe75/twH177onEVVsBSVW5xyOccqfw1r+4//i47ni7991F6f2X/o6enp6ekr72XnD6b/e5P5A30lvbVpd7pEdf0NSm6Unp6ent6Xft9dJ/fu6NzSf6x03NEphLBPH5Pqef1371X//HXr2R/o6dX1Azfvd37/qzO/dinfbH/V+c2W4QeVjoeenl5db/r8zfTx04e7l7sGltTW6enp6ekr7Lf0H9u7o8SJ2flnbsv2suOJWq/6569b752ejxc9vc593cbNdRs3v/2/k2+2vzrfr1wWwHjo6enV9aYfr00fP32I+4hexB2AECI2m1k8cXHh1Miiq5dSYxcWT1xMZqZ87OEL9znB3NmAbA93qn/+uvXQBK+34TP5l32Tf9knflyiKvorPwMQDEXzYdOP16aPH2EVr/YAAFTNoquXeoce+mpJe9PEcGrk7A8Ntx6+Y+s7rd3pWNKXHn7x8v+1KunhTvXPX7ceOuD1NsRYnwK0om4+bPrx2vT
xI5Q4AwuIrkR2umX4+Mpzr0wnGt5te6R2ZnTDkW1tF0/61QMAysPrbfh4fG9gGW8hBFAJ5sOAQTgDC4i0TDx2bln3C12DNXb2H8u7fvvmL+/5aP/7S9f41QMAysPrbVi5rE9xchZQFcyHAVNwBhYQael47YUb24UQs1ZstG7JlcblTRPDPvYAgPLwegsAwWA+DJiCBSwg0hKZqZZvTi2cGrlhZjQ1+lnz2PmRxhYfewBAeXi9BYBgMB8GTMFbCIFIi2Wztw0feeD4zm+bWzs+fm2sfunp1vU+9gCA8vB6G2757xbkoldAdTEfBkzBAhYQaen4gi9uWp36/sPbPx8ar08dWvXUmWVrfewBAOXh9TbEJv+yL3/R6q1Vdfe9N1nF8QARx3wYMAULWECkZWPJT5auObCyt2ny8niyeabUB7TL9gCA8vB6G1a51SvnJKy6jZvve2+SNSygipgPA6ZgAQuAyFrxkfqUuh4Vsjo65/vWQLfoGaq0hzvVP3/demiF19uwyi1jsXQFaML3+bDpx2vTx49Q4iLuQHTZVs10ojFTk1DUwxcuswEhRM+QGOiuqIc71T9/3XpogtdbAAiGovmw6cdr08ePsGIBC4iuyUTDqbaHv1y8QlGPyrnPBhz5cwLZHu5U//x166EPXm9Dr27j5rdW1eWffpV/WXcAgVExHzb9eG36+BFi1qMvpW1b5gaWoKenp6cPoB/cKnHyhey52QPd4uh6Xv/dqP7569azP9DTq+v33rLf+U3dxs0FV3B3OCtZuTWsnq8f1Gr89PT03pk+fzN9/PQh75/4U9q2xawtMrOlb2lZIhkX9PT09PQB9C8/qfbdQ7z+u/eqf/66YX+gp1fXD9y831m06vr99JvtrxZdwPrPf72S+9bmFye0Gj89Pb0+8wfVx2vTx08f7t7atDtdorr+BiU3Sk9PT09Pr2G/766Te3d0buk/Vjru6BRC2KePRapnPkBPr66/fOna5yYVXb1yOGtYzu8X/VNG6XhWH5D4B+qW/qOPnrhTavsDN1874+xXZ37tUr7Z/uq1v2L4Qant6/b4qu5XjD9f8MWFh/+Q+/3o2mcKvnv2ht8oHU/U7q9u8wfVx2vV95f5Bn0lvdynEEptnZ6enp6eXqt+S3/pOVb+dRyi1nun5+NLTx+OXlZ541l3/3KPvdWxWuofnM72f1yn2+x+bS8n+/fnJRawdHu8gu/zV3OcP85d0wlyPKp7He5v1I7Xpo+fPsR94UXcY7OZxRMXF06NLLp6KTV2YfHExWRmSm6rAHwi+3xU3QMh4Myx5vvu3NlV1HoYitdzI7icfiWEuO+9yW2HmgMcjri559MKAxeTf9nnLF3Vbdw836/8DB4VrOa4fDEcyri/iubDoTlea3J/gfIUnoG16Oql3qGHvlrS3jQxnBo5+0PDrYfv2PpOa3c6lqzK+IAok30+qu6BcHCfY9HDRLyeowxfD1w7CevgX69bqMqdnJULKsH6lI9cFm4qPC9JT+XdX3Xz4XAcr/W5v0AZCs/ASmSnW4aPrzz3ynSi4d22R2pnRjcc2dZ28WRVBgdEnOzzUXUPANATr+f6y51+9daquqK/RDVOwhI/rl71DF37JeasZ5Und46VLxmEEAsP/+Hxg+OPHxyf+y3n6yE7D6vs+8t82F3U7i9Cpsg1sDLx2Lll3S90DdbY2X8s7/rtm7+856P97y9dE/zgAMg+H1X3AAA98XpuBGehSgix9ncv5r54+Nle51v3vTdZnWEJIX4838pZvbq551NfTr8SP10MqwhOzvLOWazZs67hsadft+7u2rOuIfetxw+O2yfe/uPODSJE52FVeH+ZD7uL2v1FmBRZwErHay/c2C6EmLVio3VLrjQub5oYDnpcAIQQ8s9H1T0AQE+8nuus4OpX+atXzh+dNSzAo6LnJYWY7P1lPuwuavcXYVL4FkIhRCIz1fLNqYVTIzfMjKZGP2seOz/S2BL8yAAI+eej6h4AoCdez82
SW7EqWLoK/l2E6+5f7rxzUPj05kH4bu575ewTb+9Z17BnXYN94u2SsXEqv7/Mh91F7f4iTIqcgRXLZm8bPvLA8Z3fNrd2fPzaWP3S063rgx8ZACH/fFTdAwD0xOu5Keaeb3X42d7qnoSV/+bB3GKWv/LfLchFrypUsIgzd00nZMq4v8yH3UXt/iJMir6FcMEXN61Off/h7Z8PjdenDq166syytYEPDIAQ8s9H1T0AQE+8nsMvfl39KmfyL/vyF62qfrUvoznXfprvj+FT3v1lPuwuavcXYVJkASsbS36ydM2Blb1Nk5fHk80zfAAzUD2yz0fVPQBAT7yeo0Lr7l9+c4//byHMrV45J2HVbdx833uTrGFJGV37TP4b5R57+vWCIH9ZJwQXca/8/jIfdhe1+4swKbKA5cha8ZH6VJBDATAf2eej6h4wmtXROd+3BrqLvH0maj2Mxuu5EfIv4l5wQfcqOvjXTwe6xbr7fT79ypFbxmLpqjy5ZZrtP//Z3O8+9vTrz/39/WBHpJYv99f3+XDIjtdVv79AGQov4m5bNdOJxkxNoiqjAZBP9vmougdCwGV2JYToGRID3ZHuYShezw1y+Nle53JXc3+jAxUnYSEAfX19fX191R5FcFzur6L5cGiO15rcX6A8hQtYk4mGU20Pf7l4RVVGAyCf7PNRdQ+Yzn125cifY0Wth7l4PTfR2t+9WN3TrwrWqvI/jrBoUIm6jZvfWlWXf/pV/mXd4VFfX591d1fBF627u3JfDNkaVnn3V8V8OEzHax3uL1A269GX0rYtcwNL0NPT09PTG9cPbpU4OUX2XPcQ9EfXMx+gp1fV771lf+4q5m+tqiva5FZ23lpV9+f/+E7peNb8d8L7WwVv7vlU9t8Le2/Z7/ymbuPmgiu4O5yVrNwaVs/XD2r1eOnW//PE885vcus1uU/im7u+s2vXrg8afqPV+E2/v7rNH1Qfr1XfX+Yb9JX08WRc2LaYtUVmtvQtLUvQ09PT09Ob2EuRvVJDCPonfqXX40VPH6Y+n5crQKkeT8+QEEOflhxG2eMRQjiLVl2/nxbi1aKZ8y0n2/zig1o9Xrr1Obt27XLWdKy7u3JrOvl27dol1O8/Ubu/UkJwvJYdj2zPfIO+kt7atDtdorr+BiU3Sk9PH1j/8pNqL3fC6wO9zv2+u07u3dG5pf9Y6bijUwhhnz5G79LzfKenV9cP3HztjKRfnfm1S/lm+7W1ni3DDyodj+r+8qVrnxP1Zvurc0+/cry1qu4///WK8/tF/5RROh56+vze9PkDx2v6KPfzfgphUVJbp6enD6B/7gO5W3m3XfJSKnr+fOjD3W/pLz3ny78uA717752e+wM9vc79j+s4m92v/eRk//68xAKWnvdXHd3GT29ir9vxl+M1Pb3HXu4MLMRmM82T36VjC+Kz6UR2JlOTGE82T8drq9XDLL7vD4NbE83/dd1NdvyicCP9fytztFf+l9z/4QEC9vLdJ53fuMz58md79ulj9C49z/dIYb4RsPxrQrlkubUtqTOwNDTyTVy4nn7lyJ2EJXUGFqLG9/mz6fOHgI/XHC/MEvr1CrkzsLDo6qXeoYe+WtLeNDGcGjn7Q8Oth+/Y+k5rdzqWrEoPs6jYH3IrVi4LVU7jvpK14xei/28//VcIsV3uzgFV4/x/S3q/ekQB840q4tP3ACnq/j2l2/FXz+M1xwuzhH69oiaYvyY0EtnpluHjK8+9Mp1oeLftkdqZ0Q1HtrVdPFmtHmbReX8oWL0CAIQY843g1W3c7H46klSmv9zpV2+tqiv6Swhx33uT2w41V3mg0J7O8+co4OdpltCvV3AGlrRMPHZuWfcLXYM1dvYfy7t+++Yv7/lo//tL11Srh1m03R/mnoEFAAgx5hvV4rI+Fb6Ts5yFKiHE2t+9mPvi4Wd7nW95+UBGQGg8f44Ifp5mCfd6BWdgSUvHay/c2C6EmLVio3VLrjQub5oYrmIPs2i7P7B6BQCRwnw
D6hRc/Sp/9WruH4GStJ0/RwQ/T7OEe72CBSxpicxUyzenFk6N3DAzmhr9rHns/EhjSxV7mEXb/SH/DCwAQOgx30CQnLOu8n/j4F2E8ELb+XNE8PM0S7jXK3gLobRYNnvb8JEHju/8trm14+PXxuqXnm5dX8UeZtF2f+AMLACIFOYb1ZX/bsFwXPRqPmt/92LBotXhZ3vnfhFwoe38OSL4eZol3OsVLGBJS8cXfHHT6tT3H97++dB4ferQqqfOLFtbxR5m0XZ/4BpYABApzDeqaPIv+/IXrbgaFOBO2/lzRPDzNEu41ytYwJKWjSU/WbrmwMrepsnL48nmmVIfGKm6h1m03R9YvQKASGG+US251SvnJKy6jZvve2+SNSzAhbbz54jg52mWcK9XsIBVpqwVH6lP6dPDLBruD5yBBXNZHZ3zfWugW/QM0cv1iBTmG9WSW8aKwtJV/lXbuYI7yub7/Fm346/mx2uOF2YJ63oFF3GXY1s104nGTE1Ckx5m0Xl/YPUKhnKZ7QkheobEQDe9RI+IYL6BwBx+tte53NXc3wBeKJo/63b81fZ4zfHCLKFfr2ABS85kouFU28NfLl6hSQ+z6Lw/8CmEMJH7bM+RP+ejd+8RHcw3qq5u4+a3VtXln36Vf1n3UFr7uxc5/QplUDF/1u34q/PxmuOFWUK/XmE9+lLatmVuYAl6enpN+sGtahe/eX2g17mX2v9lz72PYH90Pc93enpV/d5b9ju/qdu4ueAK7g5nJSu3htXz9YNajb+M+5u7j2+tqiva5Fbu3lpV9+f/+E6r8dOHuzd9/sDxmj7KfTwZF7YtZm2RmS19S8sS9PT0+vS2bQshLMsqkcpztvzkQEar+0tPn99Lkb1yRAT7J36l1+NLTx+mXgjhLOh0/X5aiFeLZs63nGzziw9qNf5KXp+9XOFLt/HTh7uXwvGanl6r3tq0O12iuv4GJTdKT08fWP/yk9f+D9JzH1z7ovMGwILrWOWubDXf2wMLrn61/cdzQnl9oA+y33fXyb07Orf0Hysdd3QKIezTx+h97Hm+m9XLPl94fOnp6cPaR23+wOs5fSW96fMHuWtgSW2dnp5eq96FL1e/0u3+0pvYb+k/tndHictA5F8ngt7f3js995+o9Ty+9PT09NGcP3in5+NFX93e6P0tXvCl2GymefK7dGxBfDadyM5kahLjyebpeO18m6B371XT7f7Sm7X/5PD5g9CHc0yd7/8LzT2a0vvbwyyBPb66HR/p6enpCzLdjqeRPV5rsj/Q+/t8KZvv4y9cwFp09VLv0ENfLWlvmhhOjZz9oeHWw3dsfae1Ox1LFv0L6N171XS7v/Rm7T85uTOwWMOCDrz8fyF6dT3MEszjq9vxkZ6enn5ur9vxNJrHa332B3p/ny/l8X38hW8hTGSnW4aPrzz3ynSi4d22R2pnRjcc2dZ28eR8A6J371XT7f7Sm7X/5LB6BQBwodvxkZ6enn6+HtWl2/5AH7L1iiLXwMrEY+eWdb/QNfjGnX2D9+5JZK7e89F+lzHRu/eq6XZ/6c3afxy+XAMLABBiuh0f6enp6aEn3fYH+jCtVxRZwErHay/c2C6EmLVio3VLrjQub5oYdvkL6N171XS7v/Rm7T8OzsACALjT7fhIT09PDz3ptj/Qh2m9osgCViIz1fLNqYVTIzfMjKZGP2seOz/S2OLyF9C796rpdn/pzdp/HJyBBQBwp9vxkZ6enh560m1/oA/TekXhRdyFELFs9rbhIw8c3/ltc2vHx6+N1S893bre5S+gd+9V0+3+0pu1/zg4AwsA4E634yM9PT099KTb/kAfpvWKIgtY6fiCL25anfr+w9s/HxqvTx1a9dSZZWtd/gJ691413e4vvVn7j4NPIQQAuNPt+EhPT08PPem2P9CHab2iyAJWNpb8ZOmaAyt7myYvjyebZ+b5gEN6j71qut1fen/7YLB6BQBwp9vxkZ6enh560m1/oHfvVfN3/EUWsK7dzIq
P1KckhkVfVbrdX3p/e9U4AwtasTo65/vWQLfoGaJX28MsAT++uh0f6enp6XN0O55G/Hhd9f2B3t/nS4X8Gn/hRdxtq2Y60ZipSXjcLn116XZ/6f3tA8PqFfThcjQVQvQMiYFueoU9zBLY46vb8ZGenp6+gG7H08gerzXZH+jdM3PnD4ULWJOJhlNtD3+5eIXHv4C+unS7v/T+9oHhUwihCfejqSP/mErvbw+zBPn46nZ8pKenp8+n2/E0ysdrHfYHen+fL5XwffzWoy+lbVtiBJYl6OnpNekHt15bnH7ug2tfcRahCs6lyp1dNd8SVcEZWNt/fMXg9YE+yD63P3she24zfcn+6Hqe7yb1ss8XHl96evqw9lGbP/B6Tl9Jb/r8oSYZF8m4SMSEZXnaOj09vT69X+Y7A0u3+0sf7l6K7Dvz6Uv2uu0P9O69FB5fenr6EPdSNDz+yva6/fzpzeqlaLi/WfbpY3t3dG7pP1a67ugUQtC795t2p0uWP91Efnly310ntbq/9PQ696qfj/T+9ry+0Ue5l3294vlCT09PT09PH7W+Rgixpf/Y3h0l3gaZ/z5JevfeO6l/3eV63e4vPb3OvXflPR/p/e1123/o6XV+vdJt/PT09PT09PT0SvsaL7eZO7uid+9V0+3+0tPr3MMsuu0/9PRB9rJ0Gz89PT09PT09vbo+7vE2c9FXl273l55e5x5m0W3/oacPspel2/jp6enp6enp6RX1Nd63CAAAAAAAAASPBSwAAAAAAABojQUsAAAAAAAAaI0FLAAAAAAAAGiNBSwAAAAAAABojQUsAAAAAAAAaI0FLAAAAAAAAGgtnvud1dE5XzTQLXqGCr9I796rptv9pafXuYdZdNt/6OmD7GXpNn56enp6enp6ekV9TclaCNEzJAa6vW6dPgC63V96ep17mEW3/YeePshelm7jp6enp6enp6dX19eUrOfeht69V023+0tPr3MPs+i2/9DT6/x6pdv46enp6enp6emV9lbJNN+A5LnuEeyPrk/btsRNLEtI9YNbE1Lj0e3nQ08fZK/6+Ujvb8/rG32Ue9nXK54v9PT09PT09FHr5RawUNITf0rbtpi1RWa29L/cLEsk40Kqf/lJiQkrEHGqn4/0/va8viHKZF+veL4AAICo+f/elNjhpnnY9wAAAABJRU5ErkJggg==\n", "text/plain": [ "" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "img = convert_level_to_png(generated_level.squeeze(), TILE_DIR, mario_lm.tokenizer)[0]\n", "img" ] }, { "cell_type": "markdown", "id": "7233c86a-eb02-48cb-8369-bb8a521bc330", "metadata": { "tags": [] }, "source": [ "#### Check if the model generated the correct level\n", "##### Because of the stochastic nature of the model and the small training dataset, the model may generate levels that do not completely match the given prompt" ] }, { "cell_type": "code", "execution_count": 50, "id": "d3489875-e648-4c75-97f0-7ae55dc51b81", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'some pipes, many enemies, some blocks, high elevation'" ] }, "execution_count": 50, "metadata": {}, 
"output_type": "execute_result" } ], "source": [ "mario_lm.prompter(generated_level)[0]" ] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:py39] *", "language": "python", "name": "conda-env-py39-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.0" } }, "nbformat": 4, "nbformat_minor": 5 }