# %% Imports and helpers
import numpy as np
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetInpaintPipeline
from PIL import Image, ImageChops
from rembg import remove

from ip_adapter import IPAdapterXL
from ip_adapter.utils import attnmaps2images, get_net_attn_map, register_cross_attention_hook


def image_grid(imgs, rows, cols):
    """Paste ``imgs`` into a single ``rows`` x ``cols`` grid image.

    Images are placed left-to-right, top-to-bottom. The size of the first
    image is used as the tile size for every slot.

    Args:
        imgs: Sequence of PIL images; must hold exactly ``rows * cols`` items.
        rows: Number of rows in the grid.
        cols: Number of columns in the grid.

    Returns:
        A new RGB ``PIL.Image.Image`` of size ``(cols * w, rows * h)``, where
        ``(w, h)`` is the size of ``imgs[0]``.

    Raises:
        AssertionError: If ``len(imgs) != rows * cols``.
    """
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        # Column index is i % cols, row index is i // cols (row-major order).
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


# %% Model identifiers and checkpoint paths
# NOTE(review): these are presumably consumed by the pipeline-loading cell,
# whose source is not visible in this chunk -- confirm.
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "models/image_encoder"
ip_ckpt = "sdxl_models/ip-adapter_sdxl_vit-h.bin"
controlnet_path = "diffusers/controlnet-depth-sdxl-1.0"
device = "cuda"

# Sizes used below, named once instead of being repeated as literals.
GEN_SIZE = (1024, 1024)    # size every generation input is resized to
THUMB_SIZE = (256, 256)    # size of each tile in the preview grid

# NOTE(review): the cells between the configuration above and the depth-map
# preparation below are not visible in this chunk. The code below requires
# `ip_model`, `obj`, `init_img`, `target_mask` and `ip_image` to have been
# defined by them.

# %% Prepare depth map, init image and mask
np_image = np.array(Image.open('demo_assets/depths/' + obj + '.png'))
# Presumably the depth PNG is 16-bit, so dividing by 256 maps it into the
# 0-255 range before the uint8 cast -- TODO confirm against the assets.
np_image = (np_image / 256).astype('uint8')

depth_map = Image.fromarray(np_image).resize(GEN_SIZE)

init_img = init_img.resize(GEN_SIZE)
mask = target_mask.resize(GEN_SIZE)
grid = image_grid(
    [
        target_mask.resize(THUMB_SIZE),
        ip_image.resize(THUMB_SIZE),
        init_img.resize(THUMB_SIZE),
        depth_map.resize(THUMB_SIZE),
    ],
    1,
    4,
)

# Preview the four inputs side by side: mask, IP-Adapter image, init image, depth map.
grid

# %% Generate
num_samples = 1
images = ip_model.generate(
    pil_image=ip_image,
    image=init_img,
    control_image=depth_map,
    mask_image=mask,
    controlnet_conditioning_scale=0.9,
    num_samples=num_samples,
    num_inference_steps=30,
    seed=42,  # fixed seed passed to generate()
)
images[0].show()