diff --git "a/Gradio Demo.ipynb" "b/Gradio Demo.ipynb" new file mode 100644--- /dev/null +++ "b/Gradio Demo.ipynb" @@ -0,0 +1,527 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 3, + "id": "e8aaf80e", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:38:35.143155Z", + "start_time": "2022-11-23T07:38:35.122790Z" + } + }, + "outputs": [], + "source": [ + "from fastdownload import FastDownload" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e8e94748", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:38:35.879304Z", + "start_time": "2022-11-23T07:38:35.145311Z" + } + }, + "outputs": [], + "source": [ + "from PIL import Image\n", + "import torch\n", + "from torchvision import transforms as tfms\n", + "import numpy as np\n", + "import cv2\n", + "\n", + "## Import the CLIP artifacts \n", + "from transformers import CLIPTextModel, CLIPTokenizer\n", + "from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler\n", + "from diffusers import StableDiffusionInpaintPipeline" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1b81c483", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:38:45.864209Z", + "start_time": "2022-11-23T07:38:35.880756Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.14.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.mlp.fc1.weight', 'vision_model.encoder.layers.14.mlp.fc1.weight', 'vision_model.encoder.layers.19.mlp.fc2.bias', 'vision_model.encoder.layers.7.self_attn.v_proj.weight', 'vision_model.encoder.layers.11.self_attn.k_proj.bias', 'vision_model.encoder.layers.23.self_attn.q_proj.weight', 'vision_model.pre_layrnorm.bias', 'vision_model.encoder.layers.2.mlp.fc1.bias', 'vision_model.encoder.layers.11.mlp.fc2.weight', 'vision_model.encoder.layers.19.self_attn.v_proj.bias', 'vision_model.encoder.layers.20.layer_norm2.bias', 'vision_model.encoder.layers.15.mlp.fc1.weight', 'vision_model.encoder.layers.15.self_attn.v_proj.bias', 'vision_model.encoder.layers.12.mlp.fc1.bias', 'vision_model.encoder.layers.15.layer_norm2.weight', 'vision_model.encoder.layers.4.layer_norm1.weight', 'vision_model.encoder.layers.19.self_attn.v_proj.weight', 'vision_model.encoder.layers.14.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.layer_norm2.weight', 'vision_model.encoder.layers.22.mlp.fc2.bias', 'vision_model.encoder.layers.5.layer_norm2.weight', 'vision_model.encoder.layers.15.self_attn.out_proj.weight', 'vision_model.encoder.layers.14.self_attn.v_proj.weight', 'vision_model.encoder.layers.14.layer_norm1.weight', 'vision_model.encoder.layers.13.self_attn.v_proj.weight', 'vision_model.encoder.layers.8.mlp.fc2.weight', 'vision_model.encoder.layers.14.self_attn.out_proj.bias', 'vision_model.encoder.layers.4.layer_norm2.bias', 'vision_model.encoder.layers.22.self_attn.v_proj.bias', 'vision_model.encoder.layers.7.layer_norm2.bias', 'vision_model.encoder.layers.9.layer_norm1.weight', 'vision_model.encoder.layers.6.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.mlp.fc2.weight', 'vision_model.encoder.layers.18.layer_norm2.weight', 'vision_model.encoder.layers.12.mlp.fc1.weight', 'vision_model.encoder.layers.11.self_attn.out_proj.bias', 'vision_model.encoder.layers.10.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.v_proj.weight', 'vision_model.encoder.layers.13.self_attn.q_proj.weight', 'vision_model.encoder.layers.4.mlp.fc1.weight', 'logit_scale', 'vision_model.encoder.layers.3.self_attn.out_proj.weight', 'vision_model.encoder.layers.18.mlp.fc1.weight', 'vision_model.encoder.layers.19.layer_norm1.weight', 'vision_model.encoder.layers.3.self_attn.v_proj.weight', 'vision_model.encoder.layers.13.mlp.fc1.bias', 'vision_model.encoder.layers.16.self_attn.out_proj.bias', 'vision_model.encoder.layers.18.mlp.fc2.bias', 'vision_model.encoder.layers.13.mlp.fc1.weight', 'vision_model.encoder.layers.16.mlp.fc2.bias', 'vision_model.encoder.layers.13.self_attn.out_proj.weight', 'vision_model.embeddings.class_embedding', 'vision_model.encoder.layers.2.layer_norm2.weight', 'vision_model.encoder.layers.20.layer_norm1.weight', 'vision_model.encoder.layers.1.layer_norm1.weight', 'vision_model.encoder.layers.21.self_attn.v_proj.bias', 'vision_model.encoder.layers.7.mlp.fc1.bias', 'vision_model.encoder.layers.7.self_attn.out_proj.bias', 'vision_model.encoder.layers.12.mlp.fc2.weight', 'vision_model.encoder.layers.18.self_attn.k_proj.weight', 'vision_model.encoder.layers.19.mlp.fc2.weight', 'vision_model.encoder.layers.10.self_attn.out_proj.weight', 'vision_model.encoder.layers.8.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.self_attn.v_proj.weight', 'vision_model.encoder.layers.10.layer_norm2.bias', 'vision_model.encoder.layers.2.mlp.fc1.weight', 'vision_model.encoder.layers.4.self_attn.out_proj.weight', 'vision_model.encoder.layers.18.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.self_attn.q_proj.weight', 'vision_model.encoder.layers.6.layer_norm1.weight', 'vision_model.encoder.layers.23.layer_norm2.bias', 'vision_model.encoder.layers.7.layer_norm1.bias', 'vision_model.encoder.layers.23.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.self_attn.q_proj.bias', 'vision_model.encoder.layers.22.layer_norm2.bias', 'vision_model.encoder.layers.6.mlp.fc1.weight', 'vision_model.encoder.layers.5.mlp.fc2.bias', 'vision_model.encoder.layers.2.layer_norm1.weight', 'vision_model.encoder.layers.9.layer_norm2.bias', 'vision_model.encoder.layers.16.layer_norm2.bias', 'vision_model.encoder.layers.0.layer_norm1.bias', 'vision_model.encoder.layers.18.layer_norm1.bias', 'vision_model.encoder.layers.21.self_attn.v_proj.weight', 'vision_model.encoder.layers.18.layer_norm1.weight', 'vision_model.encoder.layers.5.self_attn.k_proj.bias', 'vision_model.encoder.layers.13.self_attn.k_proj.weight', 'vision_model.encoder.layers.11.self_attn.k_proj.weight', 'vision_model.encoder.layers.11.mlp.fc1.weight', 'vision_model.encoder.layers.17.self_attn.out_proj.bias', 'vision_model.encoder.layers.21.self_attn.out_proj.weight', 'vision_model.encoder.layers.3.self_attn.k_proj.weight', 'vision_model.encoder.layers.8.self_attn.k_proj.weight', 'vision_model.encoder.layers.9.layer_norm2.weight', 'vision_model.encoder.layers.2.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.self_attn.out_proj.bias', 'vision_model.encoder.layers.23.mlp.fc2.bias', 'vision_model.encoder.layers.4.mlp.fc1.bias', 'vision_model.encoder.layers.18.mlp.fc2.weight', 'vision_model.encoder.layers.9.mlp.fc1.weight', 'vision_model.encoder.layers.23.layer_norm1.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.weight', 'vision_model.encoder.layers.10.self_attn.v_proj.bias', 'vision_model.encoder.layers.22.self_attn.q_proj.weight', 'vision_model.encoder.layers.10.mlp.fc1.bias', 'vision_model.encoder.layers.1.self_attn.k_proj.bias', 'vision_model.encoder.layers.2.self_attn.out_proj.weight', 'vision_model.encoder.layers.22.self_attn.v_proj.weight', 'vision_model.encoder.layers.4.self_attn.k_proj.bias', 'vision_model.encoder.layers.18.self_attn.q_proj.bias', 'vision_model.encoder.layers.11.mlp.fc1.bias', 'vision_model.encoder.layers.15.layer_norm1.weight', 'vision_model.encoder.layers.1.layer_norm2.bias', 'vision_model.encoder.layers.16.mlp.fc1.bias', 'vision_model.encoder.layers.2.self_attn.k_proj.weight', 'vision_model.encoder.layers.5.self_attn.out_proj.bias', 'vision_model.encoder.layers.14.layer_norm2.bias', 'vision_model.encoder.layers.15.self_attn.out_proj.bias', 'vision_model.encoder.layers.17.self_attn.v_proj.bias', 'vision_model.encoder.layers.16.self_attn.v_proj.bias', 'vision_model.encoder.layers.15.mlp.fc1.bias', 'vision_model.encoder.layers.17.self_attn.k_proj.bias', 'vision_model.encoder.layers.18.layer_norm2.bias', 'vision_model.encoder.layers.8.self_attn.v_proj.bias', 'vision_model.encoder.layers.6.mlp.fc2.weight', 'vision_model.encoder.layers.5.layer_norm2.bias', 'vision_model.encoder.layers.14.layer_norm2.weight', 'vision_model.encoder.layers.1.self_attn.v_proj.weight', 'vision_model.encoder.layers.10.self_attn.k_proj.bias', 'vision_model.encoder.layers.2.self_attn.q_proj.bias', 'vision_model.encoder.layers.7.self_attn.q_proj.bias', 'vision_model.encoder.layers.15.self_attn.q_proj.bias', 'vision_model.encoder.layers.23.mlp.fc2.weight', 'vision_model.encoder.layers.2.mlp.fc2.bias', 'vision_model.encoder.layers.16.mlp.fc1.weight', 'vision_model.encoder.layers.5.layer_norm1.weight', 'vision_model.encoder.layers.17.layer_norm1.bias', 'vision_model.encoder.layers.8.mlp.fc2.bias', 'vision_model.embeddings.position_ids', 'vision_model.encoder.layers.8.layer_norm1.weight', 'vision_model.encoder.layers.15.self_attn.q_proj.weight', 'vision_model.encoder.layers.17.layer_norm2.weight', 'vision_model.encoder.layers.11.mlp.fc2.bias', 'vision_model.encoder.layers.2.self_attn.v_proj.bias', 'visual_projection.weight', 'vision_model.encoder.layers.23.self_attn.q_proj.bias', 'vision_model.encoder.layers.0.mlp.fc1.weight', 'vision_model.encoder.layers.13.mlp.fc2.bias', 'vision_model.encoder.layers.0.mlp.fc1.bias', 'vision_model.encoder.layers.1.self_attn.v_proj.bias', 'vision_model.encoder.layers.22.mlp.fc1.weight', 'vision_model.encoder.layers.19.mlp.fc1.weight', 'vision_model.encoder.layers.16.self_attn.q_proj.weight', 'vision_model.encoder.layers.0.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.layer_norm1.weight', 'vision_model.encoder.layers.5.self_attn.k_proj.weight', 'vision_model.encoder.layers.17.mlp.fc2.bias', 'vision_model.encoder.layers.23.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.self_attn.out_proj.bias', 'vision_model.encoder.layers.5.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.self_attn.v_proj.weight', 'vision_model.encoder.layers.23.mlp.fc1.weight', 'vision_model.encoder.layers.11.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.layer_norm2.bias', 'vision_model.encoder.layers.6.mlp.fc1.bias', 'vision_model.encoder.layers.13.self_attn.q_proj.bias', 'vision_model.encoder.layers.10.self_attn.q_proj.weight', 'vision_model.encoder.layers.9.self_attn.v_proj.bias', 'vision_model.encoder.layers.17.mlp.fc1.bias', 'vision_model.encoder.layers.4.self_attn.q_proj.weight', 'vision_model.encoder.layers.11.layer_norm1.weight', 'vision_model.encoder.layers.22.mlp.fc2.weight', 'vision_model.encoder.layers.10.self_attn.out_proj.bias', 'vision_model.encoder.layers.1.mlp.fc2.bias', 'vision_model.encoder.layers.6.layer_norm1.bias', 'vision_model.encoder.layers.11.self_attn.q_proj.bias', 'vision_model.encoder.layers.14.layer_norm1.bias', 'vision_model.encoder.layers.11.layer_norm1.bias', 'vision_model.encoder.layers.9.self_attn.out_proj.bias', 'vision_model.encoder.layers.15.self_attn.v_proj.weight', 'vision_model.encoder.layers.19.mlp.fc1.bias', 'vision_model.encoder.layers.16.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.layer_norm1.weight', 'vision_model.encoder.layers.9.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.q_proj.bias', 'vision_model.encoder.layers.7.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.mlp.fc1.bias', 'vision_model.encoder.layers.1.layer_norm1.bias', 'vision_model.encoder.layers.21.layer_norm2.bias', 'vision_model.encoder.layers.20.mlp.fc2.bias', 'vision_model.encoder.layers.10.mlp.fc2.weight', 'vision_model.encoder.layers.14.mlp.fc2.weight', 'vision_model.encoder.layers.15.mlp.fc2.weight', 'vision_model.encoder.layers.18.self_attn.out_proj.bias', 'vision_model.encoder.layers.11.self_attn.q_proj.weight', 'vision_model.encoder.layers.5.layer_norm1.bias', 'vision_model.encoder.layers.9.mlp.fc1.bias', 'vision_model.encoder.layers.21.self_attn.k_proj.weight', 'vision_model.encoder.layers.9.mlp.fc2.bias', 'vision_model.encoder.layers.6.layer_norm2.bias', 'vision_model.encoder.layers.8.mlp.fc1.bias', 'vision_model.encoder.layers.16.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.layer_norm2.weight', 'vision_model.encoder.layers.23.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.self_attn.out_proj.bias', 'vision_model.encoder.layers.5.self_attn.out_proj.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.bias', 'vision_model.encoder.layers.14.mlp.fc2.bias', 'vision_model.encoder.layers.22.self_attn.out_proj.bias', 'vision_model.encoder.layers.4.self_attn.q_proj.bias', 'vision_model.encoder.layers.2.self_attn.k_proj.bias', 'vision_model.encoder.layers.18.mlp.fc1.bias', 'vision_model.encoder.layers.10.mlp.fc2.bias', 'vision_model.encoder.layers.10.mlp.fc1.weight', 'vision_model.encoder.layers.17.mlp.fc2.weight', 'vision_model.encoder.layers.1.self_attn.out_proj.bias', 'vision_model.encoder.layers.0.self_attn.v_proj.weight', 'vision_model.encoder.layers.3.layer_norm2.weight', 'vision_model.encoder.layers.18.self_attn.q_proj.weight', 'vision_model.encoder.layers.12.self_attn.q_proj.bias', 'vision_model.encoder.layers.8.layer_norm2.bias', 'vision_model.encoder.layers.20.mlp.fc2.weight', 'vision_model.encoder.layers.23.mlp.fc1.bias', 'vision_model.encoder.layers.20.self_attn.out_proj.weight', 'vision_model.encoder.layers.17.self_attn.out_proj.weight', 'vision_model.encoder.layers.16.self_attn.q_proj.bias', 'vision_model.encoder.layers.20.self_attn.v_proj.bias', 'vision_model.encoder.layers.11.layer_norm2.bias', 'vision_model.encoder.layers.12.layer_norm2.weight', 'vision_model.encoder.layers.14.self_attn.v_proj.bias', 'vision_model.encoder.layers.4.mlp.fc2.weight', 'vision_model.encoder.layers.20.layer_norm1.bias', 'vision_model.encoder.layers.6.self_attn.k_proj.weight', 'vision_model.encoder.layers.12.self_attn.k_proj.bias', 'vision_model.encoder.layers.3.self_attn.v_proj.bias', 'vision_model.encoder.layers.3.mlp.fc1.bias', 'vision_model.encoder.layers.2.layer_norm1.bias', 'vision_model.encoder.layers.18.self_attn.v_proj.weight', 'vision_model.encoder.layers.21.mlp.fc2.bias', 'vision_model.encoder.layers.8.self_attn.q_proj.weight', 'vision_model.encoder.layers.13.layer_norm2.bias', 'vision_model.encoder.layers.12.self_attn.q_proj.weight', 'vision_model.encoder.layers.23.layer_norm1.weight', 'vision_model.encoder.layers.7.mlp.fc2.bias', 'vision_model.encoder.layers.22.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.self_attn.out_proj.weight', 'vision_model.encoder.layers.8.self_attn.q_proj.bias', 'vision_model.encoder.layers.19.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.k_proj.weight', 'vision_model.encoder.layers.10.self_attn.q_proj.bias', 'vision_model.encoder.layers.22.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.self_attn.q_proj.weight', 'vision_model.encoder.layers.21.layer_norm1.bias', 'vision_model.encoder.layers.4.self_attn.k_proj.weight', 'vision_model.encoder.layers.3.mlp.fc2.bias', 'vision_model.encoder.layers.23.layer_norm2.weight', 'vision_model.encoder.layers.6.self_attn.q_proj.weight', 'vision_model.encoder.layers.0.self_attn.q_proj.weight', 'vision_model.encoder.layers.1.self_attn.q_proj.weight', 'vision_model.encoder.layers.17.self_attn.q_proj.bias', 'vision_model.encoder.layers.22.layer_norm1.bias', 'vision_model.encoder.layers.16.self_attn.out_proj.weight', 'vision_model.encoder.layers.9.self_attn.q_proj.bias', 'vision_model.encoder.layers.15.mlp.fc2.bias', 'vision_model.encoder.layers.16.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.layer_norm2.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.weight', 'vision_model.encoder.layers.0.self_attn.k_proj.bias', 'vision_model.encoder.layers.7.layer_norm2.weight', 'vision_model.encoder.layers.5.self_attn.q_proj.weight', 'vision_model.encoder.layers.17.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.q_proj.weight', 'vision_model.encoder.layers.15.layer_norm2.bias', 'vision_model.encoder.layers.6.mlp.fc2.bias', 'vision_model.encoder.layers.19.layer_norm1.bias', 'vision_model.post_layernorm.weight', 'vision_model.encoder.layers.17.layer_norm2.bias', 'text_projection.weight', 'vision_model.encoder.layers.4.layer_norm2.weight', 'vision_model.encoder.layers.12.layer_norm1.weight', 'vision_model.encoder.layers.17.mlp.fc1.weight', 'vision_model.encoder.layers.5.self_attn.v_proj.weight', 'vision_model.encoder.layers.16.self_attn.v_proj.weight', 'vision_model.encoder.layers.9.layer_norm1.bias', 'vision_model.encoder.layers.5.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.layer_norm2.weight', 'vision_model.encoder.layers.0.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.self_attn.q_proj.bias', 'vision_model.encoder.layers.14.mlp.fc1.bias', 'vision_model.encoder.layers.16.layer_norm1.bias', 'vision_model.encoder.layers.12.layer_norm1.bias', 'vision_model.encoder.layers.13.layer_norm2.weight', 'vision_model.encoder.layers.3.layer_norm1.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.bias', 'vision_model.encoder.layers.2.mlp.fc2.weight', 'vision_model.encoder.layers.10.layer_norm2.weight', 'vision_model.encoder.layers.12.self_attn.v_proj.weight', 'vision_model.encoder.layers.22.self_attn.k_proj.weight', 'vision_model.encoder.layers.14.self_attn.k_proj.bias', 'vision_model.encoder.layers.9.self_attn.v_proj.weight', 'vision_model.encoder.layers.4.self_attn.out_proj.bias', 'vision_model.encoder.layers.19.self_attn.q_proj.bias', 'vision_model.encoder.layers.19.self_attn.out_proj.bias', 'vision_model.encoder.layers.12.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.mlp.fc2.weight', 'vision_model.encoder.layers.13.layer_norm1.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.weight', 'vision_model.encoder.layers.0.self_attn.k_proj.weight', 'vision_model.encoder.layers.7.mlp.fc2.weight', 'vision_model.encoder.layers.22.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.k_proj.weight', 'vision_model.encoder.layers.8.mlp.fc1.weight', 'vision_model.encoder.layers.0.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.q_proj.weight', 'vision_model.encoder.layers.3.layer_norm2.bias', 'vision_model.encoder.layers.16.mlp.fc2.weight', 'vision_model.encoder.layers.11.layer_norm2.weight', 'vision_model.encoder.layers.4.self_attn.v_proj.bias', 'vision_model.encoder.layers.7.mlp.fc1.weight', 'vision_model.encoder.layers.18.self_attn.v_proj.bias', 'vision_model.encoder.layers.4.layer_norm1.bias', 'vision_model.encoder.layers.23.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.out_proj.weight', 'vision_model.post_layernorm.bias', 'vision_model.encoder.layers.3.mlp.fc1.weight', 'vision_model.encoder.layers.19.layer_norm2.bias', 'vision_model.encoder.layers.6.self_attn.out_proj.weight', 'vision_model.encoder.layers.5.mlp.fc1.bias', 'vision_model.encoder.layers.6.layer_norm2.weight', 'vision_model.encoder.layers.13.self_attn.v_proj.bias', 'vision_model.encoder.layers.15.layer_norm1.bias', 'vision_model.encoder.layers.21.layer_norm1.weight', 'vision_model.encoder.layers.1.layer_norm2.weight', 'vision_model.encoder.layers.5.mlp.fc1.weight', 'vision_model.encoder.layers.4.mlp.fc2.bias', 'vision_model.encoder.layers.19.self_attn.k_proj.bias', 'vision_model.encoder.layers.9.self_attn.q_proj.weight', 'vision_model.encoder.layers.12.self_attn.v_proj.bias', 'vision_model.encoder.layers.17.layer_norm1.weight', 'vision_model.encoder.layers.13.mlp.fc2.weight', 'vision_model.encoder.layers.11.self_attn.v_proj.weight', 'vision_model.encoder.layers.1.mlp.fc1.weight', 'vision_model.encoder.layers.20.mlp.fc1.bias', 'vision_model.encoder.layers.20.mlp.fc1.weight', 'vision_model.encoder.layers.2.layer_norm2.bias', 'vision_model.encoder.layers.1.self_attn.q_proj.bias', 'vision_model.encoder.layers.10.layer_norm1.bias', 'vision_model.encoder.layers.13.layer_norm1.weight', 'vision_model.encoder.layers.17.self_attn.v_proj.weight', 'vision_model.encoder.layers.15.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.mlp.fc1.bias', 'vision_model.encoder.layers.11.self_attn.v_proj.bias', 'vision_model.encoder.layers.2.self_attn.q_proj.weight', 'vision_model.embeddings.position_embedding.weight', 'vision_model.encoder.layers.1.mlp.fc1.bias', 'vision_model.encoder.layers.8.layer_norm1.bias', 'vision_model.encoder.layers.7.layer_norm1.weight', 'vision_model.encoder.layers.9.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.self_attn.k_proj.bias', 'vision_model.encoder.layers.0.layer_norm2.weight', 'vision_model.encoder.layers.6.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.self_attn.v_proj.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.bias', 'vision_model.encoder.layers.12.mlp.fc2.bias', 'vision_model.encoder.layers.18.self_attn.out_proj.weight', 'vision_model.encoder.layers.5.mlp.fc2.weight', 'vision_model.encoder.layers.12.self_attn.k_proj.weight', 'vision_model.encoder.layers.7.self_attn.k_proj.bias', 'vision_model.encoder.layers.20.self_attn.out_proj.bias', 'vision_model.embeddings.patch_embedding.weight', 'vision_model.encoder.layers.9.mlp.fc2.weight', 'vision_model.encoder.layers.15.self_attn.k_proj.weight', 'vision_model.pre_layrnorm.weight', 'vision_model.encoder.layers.17.self_attn.q_proj.weight', 'vision_model.encoder.layers.8.layer_norm2.weight', 'vision_model.encoder.layers.9.self_attn.k_proj.bias', 'vision_model.encoder.layers.12.self_attn.out_proj.weight', 'vision_model.encoder.layers.16.layer_norm2.weight', 'vision_model.encoder.layers.3.self_attn.k_proj.bias', 'vision_model.encoder.layers.13.self_attn.k_proj.bias', 'vision_model.encoder.layers.12.layer_norm2.bias', 'vision_model.encoder.layers.8.self_attn.v_proj.weight', 'vision_model.encoder.layers.21.self_attn.out_proj.bias', 'vision_model.encoder.layers.4.self_attn.v_proj.weight', 'vision_model.encoder.layers.0.self_attn.v_proj.bias', 'vision_model.encoder.layers.13.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.mlp.fc2.bias', 'vision_model.encoder.layers.19.self_attn.out_proj.weight', 'vision_model.encoder.layers.2.self_attn.v_proj.weight', 'vision_model.encoder.layers.0.mlp.fc2.weight', 'vision_model.encoder.layers.1.mlp.fc2.weight']\n", + "- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" + ] + } + ], + "source": [ + "def load_artifacts():\n", + " '''\n", + " A function to load all diffusion artifacts\n", + " '''\n", + " vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\", torch_dtype=torch.float16).to(\"cuda\")\n", + " unet = UNet2DConditionModel.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"unet\", torch_dtype=torch.float16).to(\"cuda\")\n", + " tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", torch_dtype=torch.float16)\n", + " text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\", torch_dtype=torch.float16).to(\"cuda\")\n", + " scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False) \n", + " return vae, unet, tokenizer, text_encoder, scheduler\n", + "\n", + "def load_image(p):\n", + " '''\n", + " Function to load images from a defined path\n", + " '''\n", + " return Image.open(p).convert('RGB').resize((512,512))\n", + "\n", + "def pil_to_latents(image):\n", + " '''\n", + " Function to convert image to latents\n", + " '''\n", + " init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0\n", + " init_image = init_image.to(device=\"cuda\", dtype=torch.float16) \n", + " init_latent_dist = vae.encode(init_image).latent_dist.sample() * 0.18215\n", + " return init_latent_dist\n", + "\n", + "def latents_to_pil(latents):\n", + " '''\n", + " Function to convert latents to images\n", + " '''\n", + " latents = (1 / 0.18215) * latents\n", + " with torch.no_grad():\n", + " image = vae.decode(latents).sample\n", + " image = (image / 2 + 0.5).clamp(0, 1)\n", + " image = image.detach().cpu().permute(0, 2, 3, 1).numpy()\n", + " images = (image * 255).round().astype(\"uint8\")\n", + " pil_images = [Image.fromarray(image) for image in images]\n", + " return pil_images\n", + "\n", + "def text_enc(prompts, maxlen=None):\n", + " '''\n", + " A function to take a texual promt and convert it into embeddings\n", + " '''\n", + " if maxlen is None: maxlen = tokenizer.model_max_length\n", + " inp = tokenizer(prompts, padding=\"max_length\", max_length=maxlen, truncation=True, return_tensors=\"pt\") \n", + " return text_encoder(inp.input_ids.to(\"cuda\"))[0].half()\n", + "\n", + "vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c2d419ee", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:38:45.873536Z", + "start_time": "2022-11-23T07:38:45.866722Z" + } + }, + "outputs": [], + "source": [ + "def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength =0.5, steps=50, dim=512):\n", + " \"\"\"\n", + " Diffusion process to convert prompt to image\n", + " \"\"\"\n", + " # Converting textual prompts to embedding\n", + " text = text_enc(prompts) \n", + " \n", + " # Adding an unconditional prompt , helps in the generation process\n", + " uncond = text_enc([\"\"], text.shape[1])\n", + " emb = torch.cat([uncond, text])\n", + " \n", + " # Setting the seed\n", + " if seed: torch.manual_seed(seed)\n", + " \n", + " # Setting number of steps in scheduler\n", + " scheduler.set_timesteps(steps)\n", + " \n", + " # Convert the seed image to latent\n", + " init_latents = pil_to_latents(init_img)\n", + " \n", + " # Figuring initial time step based on strength\n", + " init_timestep = int(steps * strength) \n", + " timesteps = scheduler.timesteps[-init_timestep]\n", + " timesteps = torch.tensor([timesteps], device=\"cuda\")\n", + " \n", + " # Adding noise to the latents \n", + " noise = torch.randn(init_latents.shape, generator=None, device=\"cuda\", dtype=init_latents.dtype)\n", + " init_latents = scheduler.add_noise(init_latents, noise, timesteps)\n", + " latents = init_latents\n", + " \n", + " # We need to scale the i/p latents to match the variance\n", + " inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)\n", + " # Predicting noise residual using U-Net\n", + " with torch.no_grad(): u,t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)\n", + " \n", + " # Performing Guidance\n", + " pred = u + g*(t-u)\n", + "\n", + " # Zero shot prediction\n", + " latents = scheduler.step(pred, timesteps, latents).pred_original_sample\n", + " \n", + " # Returning the latent representation to output an array of 4x64x64\n", + " return latents.detach().cpu()\n", + "\n", + "def create_mask_fast(init_img, rp, qp, n=20, s=0.5):\n", + " ## Initialize a dictionary to save n iterations\n", + " diff = {}\n", + " \n", + " ## Repeating the difference process n times\n", + " for idx in range(n):\n", + " ## Creating denoised sample using reference / original text\n", + " orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed = 100*idx)[0]\n", + " ## Creating denoised sample using query / target text\n", + " query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed = 100*idx)[0]\n", + " ## Taking the difference \n", + " diff[idx] = (np.array(orig_noise)-np.array(query_noise))\n", + " \n", + " ## Creating a mask placeholder\n", + " mask = np.zeros_like(diff[0])\n", + " \n", + " ## Taking an average of 10 iterations\n", + " for idx in range(n):\n", + " ## Note np.abs is a key step\n", + " mask += np.abs(diff[idx]) \n", + " \n", + " ## Averaging multiple channels \n", + " mask = mask.mean(0)\n", + " \n", + " ## Normalizing \n", + " mask = (mask - mask.mean()) / np.std(mask)\n", + " \n", + " ## Binarizing and returning the mask object\n", + " return (mask > 0).astype(\"uint8\")\n", + "\n", + "def improve_mask(mask):\n", + " mask = cv2.GaussianBlur(mask*255,(3,3),1) > 0\n", + " return mask.astype('uint8')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e3b0bc1e", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:38:45.886013Z", + "start_time": "2022-11-23T07:38:45.874939Z" + } + }, + "outputs": [], + "source": [ + "def fastDiffEdit(init_img, rp , qp, g=7.5, seed=100, strength =0.7, steps=20, dim=512):\n", + " \n", + " ## Step 1: Create mask\n", + " mask = create_mask_fast(init_img=init_img, rp=rp, qp=qp, n=20)\n", + " \n", + " ## Improve masking using CV trick\n", + " mask = improve_mask(mask)\n", + " \n", + " ## Step 2 and 3: Diffusion process using mask\n", + " output = pipe(\n", + " prompt=qp, \n", + " image=init_img, \n", + " mask_image=Image.fromarray(mask*255).resize((512,512)), \n", + " generator=torch.Generator(\"cuda\").manual_seed(100),\n", + " num_inference_steps = steps\n", + " ).images\n", + " return output[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "73a3d383", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:38:56.698098Z", + "start_time": "2022-11-23T07:38:45.887255Z" + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "652fef739c674f7fbc156e36b05db848", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 15 files: 0%| | 0/15 [00:00" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "20f567e7", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:39:10.453493Z", + "start_time": "2022-11-23T07:39:10.450855Z" + } + }, + "outputs": [], + "source": [ + "def test_func(init_img, rp , qp):\n", + " mask = create_mask_fast(init_img, rp, qp)\n", + " mask = improve_mask(mask)\n", + " return pipe(\n", + " prompt=[qp], \n", + " image=init_img, \n", + " mask_image=Image.fromarray(mask*255),#.resize((512,512)), \n", + " generator=torch.Generator(\"cuda\").manual_seed(100),\n", + " num_inference_steps = 20\n", + " ).images[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "0eb2e975", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:39:10.466013Z", + "start_time": "2022-11-23T07:39:10.455799Z" + } + }, + "outputs": [], + "source": [ + "import gradio as gr" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "2907db3e", + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-23T07:39:35.972730Z", + "start_time": "2022-11-23T07:39:31.484620Z" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/aayush/miniconda3/envs/fastai/lib/python3.9/site-packages/gradio/deprecation.py:40: UserWarning: `enable_queue` is deprecated in `Interface()`, please use it within `launch()` instead.\n", + " warnings.warn(value)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on local URL: http://127.0.0.1:7861/\n", + "Running on public URL: https://a3a54c93e8e5d9f8.gradio.app\n", + "\n", + "This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://www.huggingface.co/spaces)\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "(,\n", + " 'http://127.0.0.1:7861/',\n", + " 'https://a3a54c93e8e5d9f8.gradio.app')" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "demo = gr.Interface(\n", + " fn=fastDiffEdit, \n", + " inputs=[\n", + " gr.inputs.Image(shape=(512, 512), type=\"pil\", label = \"Upload your image photo\"),\n", + " gr.Textbox(label=\"Describe your image. Ex: a horse image\"),\n", + " gr.Textbox(label=\"Retype the description with target output. Ex: a zebra image\")], \n", + " outputs=\"image\",\n", + " title = \"DiffEdit demo\",\n", + " description = \"DiffEdit paper demo. Upload an image, pass reference prompt describing the image, pass query prompt to replace the object with target object\",\n", + " examples = [\n", + " [\"horse.jpg\", \"a horse image\", \"a zebra image\"],\n", + " [\"fruitbowl.jpg\", \"a bowl of fruit\", \"a bowl of grapes\"]],\n", + " enable_queue=True\n", + " )\n", + "\n", + "demo.launch(share=True) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7abccb1f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": { + "height": "calc(100% - 180px)", + "left": "10px", + "top": "150px", + "width": "465.455px" + }, + "toc_section_display": true, + "toc_window_display": true + }, + "varInspector": { + "cols": { + "lenName": 16, + "lenType": 16, + "lenVar": 40 + }, + "kernels_config": { + "python": { + "delete_cmd_postfix": "", + "delete_cmd_prefix": "del ", + "library": "var_list.py", + "varRefreshCmd": "print(var_dic_list())" + }, + "r": { + "delete_cmd_postfix": ") ", + "delete_cmd_prefix": "rm(", + "library": "var_list.r", + "varRefreshCmd": "cat(var_dic_list()) " + } + }, + "types_to_exclude": [ + "module", + "function", + "builtin_function_or_method", + "instance", + "_Feature" + ], + "window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}