{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "f86ede39-8d9f-4da9-bc12-955f2fddd484",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Copyright 2023 Google LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f3cbf47-a52b-48b1-9bd3-3435f92f2174",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22de629b-581f-4335-9e7b-f73221d8dbcb",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# ControlNet depth with StyleAligned over SDXL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "486b7ebb-c483-4bf0-ace8-f8092c2d1f23",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL\n",
    "from diffusers.utils import load_image\n",
    "from transformers import DPTImageProcessor, DPTForDepthEstimation\n",
    "import torch\n",
    "import mediapy\n",
    "import sa_handler\n",
    "import pipeline_calls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a7e85e7-b5cf-45b2-946a-5ba1e4923586",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# init models\n",
    "\n",
    "# The depth estimator is not part of the diffusion pipeline, so it is placed on the GPU explicitly.\n",
    "depth_estimator = DPTForDepthEstimation.from_pretrained(\"Intel/dpt-hybrid-midas\").to(\"cuda\")\n",
    "feature_processor = DPTImageProcessor.from_pretrained(\"Intel/dpt-hybrid-midas\")\n",
    "\n",
    "controlnet = ControlNetModel.from_pretrained(\n",
    "    \"diffusers/controlnet-depth-sdxl-1.0\",\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "vae = AutoencoderKL.from_pretrained(\"madebyollin/sdxl-vae-fp16-fix\", torch_dtype=torch.float16)\n",
    "pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
    "    controlnet=controlnet,\n",
    "    vae=vae,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    torch_dtype=torch.float16,\n",
    ")\n",
    "# Model CPU offload moves each sub-model to the GPU only while it runs, so the pipeline\n",
    "# (and the controlnet / vae it owns) is deliberately NOT moved to \"cuda\" beforehand:\n",
    "# doing so only raises peak VRAM before the offload sends everything back to the CPU.\n",
    "pipeline.enable_model_cpu_offload()\n",
    "\n",
    "# StyleAligned configuration: share attention across the batch and apply AdaIN to\n",
    "# queries and keys; normalization-layer sharing and value AdaIN are switched off.\n",
    "sa_args = sa_handler.StyleAlignedArgs(\n",
    "    share_group_norm=False,\n",
    "    share_layer_norm=False,\n",
    "    share_attention=True,\n",
    "    adain_queries=True,\n",
    "    adain_keys=True,\n",
    "    adain_values=False,\n",
    ")\n",
    "handler = sa_handler.Handler(pipeline)\n",
    "handler.register(sa_args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94ca26b4-9061-4012-9400-8d97ef212d87",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# get depth maps\n",
    "\n",
    "# The first depth map is estimated from a photo; the second image is loaded and used\n",
    "# as-is (only resized), without running the depth estimator on it.\n",
    "image = load_image(\"./example_image/train.png\")\n",
    "depth_image1 = pipeline_calls.get_depth_map(image, feature_processor, depth_estimator)\n",
    "depth_image2 = load_image(\"./example_image/sun.png\").resize((1024, 1024))\n",
    "mediapy.show_images([depth_image1, depth_image2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8f56fe4-559f-49ff-a2d8-460dcfeb56a0",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# run ControlNet depth with StyleAligned\n",
    "\n",
    "reference_prompt = \"a poster in flat design style\"\n",
    "target_prompts = [\"a train in flat design style\", \"the sun in flat design style\"]\n",
    "depth_maps = [depth_image1, depth_image2]\n",
    "controlnet_conditioning_scale = 0.8\n",
    "num_images_per_prompt = 3  # adjust according to VRAM size\n",
    "seed = 0  # fixed so that Restart & Run All draws the same initial latents\n",
    "# NOTE(review): presumably 4 latent channels at 1/8 of the 1024x1024 output size;\n",
    "# confirm before changing the image resolution.\n",
    "latent_shape = (4, 128, 128)\n",
    "\n",
    "# The generator is created inside this cell so that re-running the cell alone\n",
    "# reproduces the same latents.\n",
    "generator = torch.Generator().manual_seed(seed)\n",
    "latents_dtype = pipeline.unet.dtype\n",
    "\n",
    "# Row 0 is drawn once and kept for every target prompt (it pairs with reference_prompt,\n",
    "# the first prompt passed below); rows 1: are redrawn for each target prompt.\n",
    "latents = torch.randn(1 + num_images_per_prompt, *latent_shape, generator=generator).to(latents_dtype)\n",
    "for depth_map, target_prompt in zip(depth_maps, target_prompts):\n",
    "    latents[1:] = torch.randn(num_images_per_prompt, *latent_shape, generator=generator).to(latents_dtype)\n",
    "    images = pipeline_calls.controlnet_call(\n",
    "        pipeline,\n",
    "        [reference_prompt, target_prompt],\n",
    "        image=depth_map,\n",
    "        num_inference_steps=50,\n",
    "        controlnet_conditioning_scale=controlnet_conditioning_scale,\n",
    "        num_images_per_prompt=num_images_per_prompt,\n",
    "        latents=latents,\n",
    "    )\n",
    "\n",
    "    # Show the reference image and the conditioning depth map first, then the results.\n",
    "    mediapy.show_images(\n",
    "        [images[0], depth_map] + images[1:],\n",
    "        titles=[\"reference\", \"depth\"] + [f'result {i}' for i in range(1, len(images))],\n",
    "    )\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}