{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2e4c3931",
   "metadata": {},
   "outputs": [
    { "name": "stdout", "output_type": "stream", "text": [ "cuda:0\n" ] }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Remember where the kernel started so that re-running this cell always lands in\n",
    "# the same parent directory instead of climbing one level higher on every run.\n",
    "if '_NOTEBOOK_START_DIR' not in globals():\n",
    "    _NOTEBOOK_START_DIR = os.getcwd()\n",
    "os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))\n",
    "\n",
    "# Project imports come after the directory change above.\n",
    "from inference.utils import *\n",
    "from core.utils import load_or_fail\n",
    "from train import LoraCore, WurstCoreB\n",
    "\n",
    "# Use the first CUDA device when one is available, otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b1920cce-3ce7-4b09-853b-3199a1accd46",
   "metadata": {},
   "source": [
    "### Load Config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ed108877",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['model_version', 'effnet_checkpoint_path', 'previewer_checkpoint_path', 'module_filters', 'rank', 'train_tokens']\n",
      "['model_version', 'stage_a_checkpoint_path', 'effnet_checkpoint_path']\n"
     ]
    }
   ],
   "source": [
    "# SETUP STAGE C\n",
    "# Paths are relative to the working directory selected in the first cell.\n",
    "config_file = 'configs/inference/lora_c_3b.yaml'\n",
    "with open(config_file, \"r\", encoding=\"utf-8\") as file:\n",
    "    loaded_config = yaml.safe_load(file)\n",
    "\n",
    "core = LoraCore(config_dict=loaded_config, device=device, training=False)\n",
    "\n",
    "# SETUP STAGE B\n",
    "# The parsed config gets its own name so `config_file_b` keeps meaning \"path\".\n",
    "config_file_b = 'configs/inference/stage_b_3b.yaml'\n",
    "with open(config_file_b, \"r\", encoding=\"utf-8\") as file:\n",
    "    loaded_config_b = yaml.safe_load(file)\n",
    "\n",
    "core_b = WurstCoreB(config_dict=loaded_config_b, device=device, training=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d70294f-fdd3-4371-8aee-8b563d9b889b",
   "metadata": {},
   "source": [
    "### Load Extras & Models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbd7c44f-d0af-4363-8ac2-efc46085ba52",
   "metadata": {},
   "source": [
    "Download an example LoRA for a dog called Fernando. For more information on training your own LoRA for Stable Cascade, check out the [training](../train/) section."
] }, { "cell_type": "code", "execution_count": 10, "id": "9d8b1675-2151-4786-8490-3be3b6be8010", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "lora_fernando_10k.s 100%[===================>] 12.03M 76.3MB/s in 0.2s \n" ] } ], "source": [ "!wget https://huggingface.co/dome272/stable-cascade/resolve/main/lora_fernando_10k.safetensors -P models -q --show-progress" ] }, { "cell_type": "code", "execution_count": 11, "id": "30b6f1f6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "['transforms', 'clip_preprocess', 'gdf', 'sampling_configs', 'effnet_preprocess']\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "0704b382ad1548f0b91ca3f2a1bc9d15", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "batch_size = 4\n", "caption = \"cinematic photo of a dog [fernando] wearing a space suit\"\n", "height, width = 1024, 1024\n", "stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)\n", "\n", "# Stage C Parameters\n", "extras.sampling_configs['cfg'] = 4\n", "extras.sampling_configs['shift'] = 2\n", "extras.sampling_configs['timesteps'] = 20\n", "extras.sampling_configs['t_start'] = 1.0\n", "\n", "# Stage B Parameters\n", "extras_b.sampling_configs['cfg'] = 1.1\n", "extras_b.sampling_configs['shift'] = 1\n", "extras_b.sampling_configs['timesteps'] = 10\n", "extras_b.sampling_configs['t_start'] = 1.0\n", "\n", "# PREPARE CONDITIONS\n", "batch = {'captions': [caption] * batch_size}\n", "conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)\n", "unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False) \n", "conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, 
is_unconditional=False, eval_image_embeds=False)\n", "unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True, eval_image_embeds=False)\n", "\n", "with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):\n", " # torch.manual_seed(42)\n", "\n", " sampling_c = extras.gdf.sample(\n", " models.generator, conditions, stage_c_latent_shape,\n", " unconditions, device=device, **extras.sampling_configs,\n", " )\n", " for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):\n", " sampled_c = sampled_c\n", " \n", " # preview_c = models.previewer(sampled_c).float()\n", " # show_images(preview_c)\n", "\n", " conditions_b['effnet'] = sampled_c\n", " unconditions_b['effnet'] = torch.zeros_like(sampled_c)\n", "\n", " sampling_b = extras_b.gdf.sample(\n", " models_b.generator, conditions_b, stage_b_latent_shape,\n", " unconditions_b, device=device, **extras_b.sampling_configs\n", " )\n", " for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):\n", " sampled_b = sampled_b\n", " sampled = models_b.stage_a.decode(sampled_b).float()\n", "\n", "show_images(sampled)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 5 }