{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c423d2a1-475e-482e-b759-f16456fd6707",
   "metadata": {},
   "source": [
    "# Install"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0440d6a7-78b9-49e9-98a2-9a5ed75e1a2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "!git clone https://github.com/kopyl/PixArt-alpha.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0abadf51-a7e3-4091-bb02-0bdd8d28fb73",
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd PixArt-alpha"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4df1af24-f439-485d-a946-966dbf16c49b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "%pip install torch==2.0.0+cu117 torchvision==0.15.1+cu117 torchaudio==2.0.1 --index-url https://download.pytorch.org/whl/cu117\n",
    "%pip install -r requirements.txt\n",
    "%pip install wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d44474fd-0b92-48fc-b4cf-142b59d3917c",
   "metadata": {},
   "source": [
    "## Download model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06b1c1c9-f8b1-4719-8564-2383eac9ff28",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python tools/download.py --model_names \"PixArt-XL-2-512x512.pth\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f298a89c-d2a5-4da7-8304-c1390da0ba58",
   "metadata": {},
   "source": [
    "## Make dataset out of Hugging Face dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e17b8883-0a5c-4fa3-a7d0-e8ee95e42027",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from tqdm.notebook import tqdm\n",
    "from datasets import load_dataset\n",
    "import json"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "92957b2c-6765-48ee-9296-d6739066d74d",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset(\"lambdalabs/pokemon-blip-captions\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0095cdda-c31a-48ee-a115-076a5fc393c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output layout under root_dir. The extract-features cell further down hard-codes the\n",
    "# same /workspace/pixart-pokemon paths, so keep the two in sync when changing them.\n",
    "root_dir = \"/workspace/pixart-pokemon\"\n",
    "images_dir = \"images\"\n",
    "captions_dir = \"captions\"\n",
    "\n",
    "images_dir_absolute = os.path.join(root_dir, images_dir)\n",
    "captions_dir_absolute = os.path.join(root_dir, captions_dir)\n",
    "\n",
    "image_format = \"png\"\n",
    "json_name = \"partition/data_info.json\"\n",
    "absolute_json_name = os.path.join(root_dir, json_name)\n",
    "\n",
    "# exist_ok=True also creates missing parents (root_dir) and keeps the cell re-runnable.\n",
    "for directory in (\n",
    "    images_dir_absolute,\n",
    "    captions_dir_absolute,\n",
    "    os.path.dirname(absolute_json_name),\n",
    "):\n",
    "    os.makedirs(directory, exist_ok=True)\n",
    "\n",
    "# Every sample is recorded as a square 512x512 image, independent of the saved file's size.\n",
    "# NOTE(review): presumably chosen to match `--img_size 512` in the extract-features cell - confirm.\n",
    "width, height = 512, 512\n",
    "ratio = 1\n",
    "\n",
    "data_info = []\n",
    "for order, item in enumerate(tqdm(dataset[\"train\"])):\n",
    "    image_name = f\"{order}.{image_format}\"\n",
    "\n",
    "    # One image file and one caption file per sample, both named after the sample index.\n",
    "    image = item[\"image\"]\n",
    "    image.save(os.path.join(images_dir_absolute, image_name))\n",
    "    caption_path = os.path.join(captions_dir_absolute, f\"{order}.txt\")\n",
    "    with open(caption_path, \"w\", encoding=\"utf-8\") as text_file:\n",
    "        text_file.write(item[\"text\"])\n",
    "\n",
    "    data_info.append({\n",
    "        \"height\": height,\n",
    "        \"width\": width,\n",
    "        \"ratio\": ratio,\n",
    "        # Stored relative to root_dir, always with a forward slash.\n",
    "        \"path\": f\"{images_dir}/{image_name}\",\n",
    "        \"prompt\": item[\"text\"],\n",
    "    })\n",
    "\n",
    "with open(absolute_json_name, \"w\", encoding=\"utf-8\") as json_file:\n",
    "    json.dump(data_info, json_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "25be1c03",
   "metadata": {},
   "source": [
    "## Extract features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f07a4f5-1873-48bf-86d0-9304942de5d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python /workspace/PixArt-alpha/tools/extract_features.py \\\n",
    "    --img_size 512 \\\n",
    "    --json_path \"/workspace/pixart-pokemon/partition/data_info.json\" \\\n",
    "    --t5_save_root \"/workspace/pixart-pokemon/caption_feature_wmask\" \\\n",
    "    --vae_save_root \"/workspace/pixart-pokemon/img_vae_features\" \\\n",
    "    --pretrained_models_dir \"/workspace/PixArt-alpha/output/pretrained_models\" \\\n",
    "    --dataset_root \"/workspace/pixart-pokemon\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fc653d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wandb login REPLACE_THIS_WITH_YOUR_AUTH_TOKEN_OF_WANDB"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cf1fd1a",
   "metadata": {},
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea0e9dab-17bc-45ed-9c81-b670bbb8de47",
   "metadata": {},
   "outputs": [],
   "source": [
    "!python -m torch.distributed.launch \\\n",
    "    train_scripts/train.py \\\n",
    "    /workspace/PixArt-alpha/notebooks/PixArt_xl2_img512_internal_for_pokemon_sample_training.py \\\n",
    "    --work-dir output/trained_model \\\n",
    "    --report_to=\"wandb\" \\\n",
    "    --loss_report_name=\"train_loss\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}