{ "cells": [ { "cell_type": "markdown", "id": "3b5fbb8f-5789-45db-a551-f0e6633b4f46", "metadata": {}, "source": [ "# Populate a HDF5 dataset with base64 Pokémon images keyed by energy type" ] }, { "cell_type": "markdown", "id": "a9234c78-1ac5-4c71-b11b-3a81be57f3f3", "metadata": {}, "source": [ "Used in [**This Pokémon Does Not Exist**](https://huggingface.co/spaces/ronvolutional/ai-pokemon-card)\n", "\n", "Model fine-tuned by [**Max Woolf**](https://huggingface.co/minimaxir/ai-generated-pokemon-rudalle)\n", "\n", "ruDALL-E by [**Sber**](https://rudalle.ru/en)" ] }, { "cell_type": "markdown", "id": "09850949-235d-4997-b8e3-c5b2aeffe109", "metadata": {}, "source": [ "## Initialise datasets" ] }, { "cell_type": "code", "execution_count": 1, "id": "cead6fa8-e9ef-4672-bbfb-beadcaf5f3a0", "metadata": {}, "outputs": [], "source": [ "import os\n", "import h5py\n", "\n", "datasets_dir = './datasets'\n", "datasets_file = 'pregenerated_pokemon.h5'\n", "h5_file = os.path.join(datasets_dir, datasets_file)\n", "\n", "energy_types = ['grass', 'fire', 'water', 'lightning', 'fighting', 'psychic', 'colorless', 'darkness', 'metal', 'dragon', 'fairy']" ] }, { "cell_type": "raw", "id": "2df90e94-15c0-4eb6-914e-875ec80b7c24", "metadata": {}, "source": [ "# Only run if the datasets file does not exist\n", "\n", "with h5py.File(h5_file, 'x') as datasets:\n", " for energy in energy_types:\n", " datasets.create_dataset(energy, (0,1), h5py.string_dtype(encoding='utf-8'), maxshape=(None,1))\n", "\n", " print(datasets.keys())" ] }, { "cell_type": "markdown", "id": "cdd3eb59-bbf5-4b85-b6bc-35f591317b47", "metadata": {}, "source": [ "### Dataset functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "fca1947f-8a66-4636-8049-99d6ff0ace93", "metadata": {}, "outputs": [], "source": [ "import math\n", "from time import gmtime, strftime, time\n", "from random import choices, randint\n", "from IPython import display\n", "\n", "def get_stats(h5_file=h5_file):\n", " with h5py.File(h5_file, 
'r') as datasets:\n", "        return {\n", "            \"size_counts\": {key: datasets[key].size.item() for key in datasets.keys()},\n", "            \"size_total\": sum(datasets[energy].size.item() for energy in datasets.keys()),\n", "            \"size_mb\": round(os.path.getsize(h5_file) / 1024**2, 1)\n", "        }\n", "\n", "\n", "def add_row(energy, image):\n", "    \"\"\"Append one base64 image string to the dataset for `energy`.\"\"\"\n", "    with h5py.File(h5_file, 'r+') as datasets:\n", "        dataset = datasets[energy]\n", "        dataset.resize(dataset.size + 1, 0)\n", "        dataset[-1] = image\n", "\n", "\n", "def get_image(energy=None, row=None):\n", "    \"\"\"Return one stored base64 image; random energy and/or row when omitted.\"\"\"\n", "    if not energy:\n", "        energy = choices(energy_types)[0]\n", "\n", "    with h5py.File(h5_file, 'r') as datasets:\n", "        size = datasets[energy].size\n", "        if size == 0:\n", "            raise ValueError(f\"No images stored for energy type '{energy}'\")\n", "        # Compare against None (not truthiness) so an explicit row=0 is honoured\n", "        if row is None:\n", "            row = randint(0, size - 1)\n", "\n", "        return datasets[energy].asstr()[row][0]\n", "\n", "def pretty_time(seconds):\n", "    \"\"\"Format a duration as e.g. '1h 2m 3s', omitting zero units ('0s' for zero).\"\"\"\n", "    m, s = divmod(seconds, 60)\n", "    h, m = divmod(m, 60)\n", "    parts = f\"{f'{math.floor(h)}h ' if h else ''}{f'{math.floor(m)}m ' if m else ''}{f'{math.floor(s)}s' if s else ''}\"\n", "    # strip() drops the trailing space; fall back to '0s' instead of ''\n", "    return parts.strip() or \"0s\"\n", "\n", "def populate_dataset(batches=1, batch_size=1, image_cap=100_000, filesize_cap=4_000):\n", "    \"\"\"Generate `batch_size` images per energy type per batch until a cap is hit.\"\"\"\n", "    initial_stats = get_stats()\n", "\n", "    iterations = 0\n", "    start_time = time()\n", "\n", "    while iterations < batches and get_stats()['size_total'] < image_cap and get_stats()['size_mb'] < filesize_cap:\n", "        for energy in energy_types:\n", "            current = get_stats()\n", "            new_images_count = (current['size_total'] - initial_stats['size_total'])\n", "            new_mb_count = round(current['size_mb'] - initial_stats['size_mb'], 1)\n", "            elapsed = time() - start_time\n", "            eta_total = elapsed / (new_images_count or 1) * batches * batch_size * len(energy_types)\n", "\n", "            display.clear_output(wait=True)\n", "            if new_images_count:\n", "                print(f\"ETA: {pretty_time(eta_total - elapsed)} left of {pretty_time(eta_total)}\")\n", "            print(f\"Images in dataset: {current['size_total']}{f' (+{new_images_count})' if new_images_count else ''}\")\n", "            print(f\"Size of dataset: {current['size_mb']}MB{f' 
(+{new_mb_count}MB)' if new_mb_count else ''}\")\n", " print(f\"Batch {iterations + 1} of {batches}:\")\n", " print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Generating {batch_size} {energy} Pokémon...\")\n", "\n", " generate_pokemon(energy, batch_size)\n", "\n", " iterations += 1\n", "\n", " new_stats = get_stats()\n", " elapsed = time() - start_time\n", "\n", " display.clear_output(wait=True)\n", " print(f\"{strftime('%Y-%m-%d %H:%M:%S', gmtime())} Finished populating dataset with {batches} {'batches' if batches > 1 else 'batch'} after {pretty_time(elapsed)}\")\n", " print(f\"Images in dataset: {new_stats['size_total']} (+{new_stats['size_total'] - initial_stats['size_total']})\")\n", " print(f\"Size of dataset: {new_stats['size_mb']}MB (+{round(new_stats['size_mb'] - initial_stats['size_mb'], 1)}MB)\")" ] }, { "cell_type": "markdown", "id": "b0da582a-2b29-4df6-8a3f-ddd56377af16", "metadata": {}, "source": [ "## Load Pokémon model" ] }, { "cell_type": "code", "execution_count": 3, "id": "ad435db5-bd35-4440-87b2-5a108f4ae385", "metadata": {}, "outputs": [], "source": [ "from rudalle import get_rudalle_model, get_tokenizer, get_vae\n", "from huggingface_hub import cached_download, hf_hub_url\n", "import torch" ] }, { "cell_type": "code", "execution_count": 4, "id": "5a2a4e6a-2086-4f98-b2b5-65a3631be61e", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GPUs available: 1\n" ] } ], "source": [ "print(f\"GPUs available: {torch.cuda.device_count()}\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "720df30d-f42c-406a-92ba-4a465f6ff1d3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Working with z of shape (1, 256, 32, 32) = 262144 dimensions.\n", "vae --> ready\n", "tokenizer --> ready\n", "GPU[0] memory: 11263Mib\n", "GPU[0] memory reserved: 5144Mib\n", "GPU[0] memory allocated: 2767Mib\n" ] } ], "source": [ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "fp16 = 
torch.cuda.is_available()\n", "map_location = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "file_dir = \"./models\"\n", "file_name = \"pytorch_model.bin\"\n", "config_file_url = hf_hub_url(repo_id=\"minimaxir/ai-generated-pokemon-rudalle\", filename=file_name)\n", "cached_download(config_file_url, cache_dir=file_dir, force_filename=file_name)\n", "\n", "model = get_rudalle_model('Malevich', pretrained=False, fp16=fp16, device=device)\n", "model.load_state_dict(torch.load(f\"{file_dir}/{file_name}\", map_location=map_location))\n", "\n", "vae = get_vae().to(device)\n", "tokenizer = get_tokenizer()\n", "\n", "print(f\"GPU[0] memory: {int(torch.cuda.get_device_properties(0).total_memory / 1024**2)}Mib\")\n", "print(f\"GPU[0] memory reserved: {int(torch.cuda.memory_reserved(0) / 1024**2)}Mib\")\n", "print(f\"GPU[0] memory allocated: {int(torch.cuda.memory_allocated(0) / 1024**2)}Mib\")" ] }, { "cell_type": "markdown", "id": "88d413a4-a8c9-401e-9cb5-32a1ae34c179", "metadata": {}, "source": [ "### Model functions" ] }, { "cell_type": "code", "execution_count": 6, "id": "0624c686-c75f-46c6-afae-bbe6c455caa1", "metadata": {}, "outputs": [], "source": [ "import base64\n", "from io import BytesIO\n", "from time import gmtime, strftime, time\n", "from rudalle.pipelines import generate_images\n", "\n", "def english_to_russian(english):\n", " word_map = {\n", " \"colorless\": \"Покемон нормального типа\",\n", " \"dragon\": \"Покемон типа дракона\",\n", " \"darkness\": \"Покемон темного типа\",\n", " \"fairy\": \"Покемон фея\",\n", " \"fighting\": \"Покемон боевого типа\",\n", " \"fire\": \"Покемон огня\",\n", " \"grass\": \"Покемон трава\",\n", " \"lightning\": \"Покемон электрического типа\",\n", " \"metal\": \"Покемон из стали типа\",\n", " \"psychic\": \"Покемон психического типа\",\n", " \"water\": \"Покемон в воду\"\n", " }\n", "\n", " return word_map[english.lower()]\n", "\n", "\n", "def generate_pokemon(energy, num=1):\n", " if energy in 
energy_types:\n", " russian_prompt = english_to_russian(energy)\n", " \n", " images, _ = generate_images(russian_prompt, tokenizer, model, vae, top_k=2048, images_num=num, top_p=0.995)\n", " \n", " for image in images:\n", " buffer = BytesIO()\n", " image.save(buffer, format=\"JPEG\", quality=100, optimize=True)\n", " base64_bytes = base64.b64encode(buffer.getvalue())\n", " base64_string = base64_bytes.decode(\"UTF-8\")\n", " base64_image = \"data:image/jpeg;base64,\" + base64_string\n", " add_row(energy, base64_image)" ] }, { "cell_type": "markdown", "id": "7b309b8e-0c34-4a80-a411-8f093a105494", "metadata": {}, "source": [ "## Populate dataset" ] }, { "cell_type": "markdown", "id": "a96d026f-e300-40c7-88a6-20be0012c584", "metadata": {}, "source": [ "Total number of images per population = `batches` × `len(energy_types)` (11) × `batch_size`" ] }, { "cell_type": "code", "execution_count": 7, "id": "74d2c50a-93ae-4040-89a0-87f818187bbb", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2022-03-16 05:07:48 Finished populating dataset with 1 batch after 10m 8s\n", "Images in dataset: 5082 (+66)\n", "Size of dataset: 199.8MB (+2.5MB)\n" ] } ], "source": [ "batches = 1\n", "batch_size = 6\n", "image_cap = 100_000\n", "filesize_cap = 4_000 # MB\n", "\n", "populate_dataset(batches, batch_size, image_cap, filesize_cap)" ] }, { "cell_type": "markdown", "id": "4eb6f750-aa2d-42b9-89dc-2cb103e8869e", "metadata": {}, "source": [ "## Getting images" ] }, { "cell_type": "markdown", "id": "9365ac7b-36e4-42b9-8380-a039d556356b", "metadata": {}, "source": [ "### Random image" ] }, { "cell_type": "code", "execution_count": 8, "id": "dca3bc8e-1f65-4566-8385-bf6b9f20eaaf", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "display.HTML(f'')" ] }, { "cell_type": "markdown", "id": "f3b9278e-b6dc-4bea-9ef6-3aa312739718", "metadata": 
{}, "source": [ "### Random image of specific energy type" ] }, { "cell_type": "code", "execution_count": 9, "id": "b6e98817-19f7-44f3-8970-0a11b01cf37b", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "display.HTML(f'')" ] }, { "cell_type": "markdown", "id": "34ca6f96-625b-459a-ba90-2c58a1d0ea47", "metadata": {}, "source": [ "### Specific image" ] }, { "cell_type": "code", "execution_count": 10, "id": "5a60fc60-4198-4318-a5b4-26095cb2c0bb", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "display.HTML(f'')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 5 }