{ "cells": [ { "cell_type": "code", "execution_count": 3, "id": "e8aaf80e", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:38:35.143155Z", "start_time": "2022-11-23T07:38:35.122790Z" } }, "outputs": [], "source": [ "from fastdownload import FastDownload" ] }, { "cell_type": "code", "execution_count": 4, "id": "e8e94748", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:38:35.879304Z", "start_time": "2022-11-23T07:38:35.145311Z" } }, "outputs": [], "source": [ "from PIL import Image\n", "import torch\n", "from torchvision import transforms as tfms\n", "import numpy as np\n", "import cv2\n", "\n", "## Import the CLIP artifacts \n", "from transformers import CLIPTextModel, CLIPTokenizer\n", "from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler\n", "from diffusers import StableDiffusionInpaintPipeline" ] }, { "cell_type": "code", "execution_count": 5, "id": "1b81c483", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:38:45.864209Z", "start_time": "2022-11-23T07:38:35.880756Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.14.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.mlp.fc1.weight', 'vision_model.encoder.layers.14.mlp.fc1.weight', 'vision_model.encoder.layers.19.mlp.fc2.bias', 'vision_model.encoder.layers.7.self_attn.v_proj.weight', 'vision_model.encoder.layers.11.self_attn.k_proj.bias', 'vision_model.encoder.layers.23.self_attn.q_proj.weight', 'vision_model.pre_layrnorm.bias', 'vision_model.encoder.layers.2.mlp.fc1.bias', 'vision_model.encoder.layers.11.mlp.fc2.weight', 'vision_model.encoder.layers.19.self_attn.v_proj.bias', 'vision_model.encoder.layers.20.layer_norm2.bias', 'vision_model.encoder.layers.15.mlp.fc1.weight', 'vision_model.encoder.layers.15.self_attn.v_proj.bias', 'vision_model.encoder.layers.12.mlp.fc1.bias', 'vision_model.encoder.layers.15.layer_norm2.weight', 'vision_model.encoder.layers.4.layer_norm1.weight', 'vision_model.encoder.layers.19.self_attn.v_proj.weight', 'vision_model.encoder.layers.14.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.layer_norm2.weight', 'vision_model.encoder.layers.22.mlp.fc2.bias', 'vision_model.encoder.layers.5.layer_norm2.weight', 'vision_model.encoder.layers.15.self_attn.out_proj.weight', 'vision_model.encoder.layers.14.self_attn.v_proj.weight', 'vision_model.encoder.layers.14.layer_norm1.weight', 'vision_model.encoder.layers.13.self_attn.v_proj.weight', 'vision_model.encoder.layers.8.mlp.fc2.weight', 'vision_model.encoder.layers.14.self_attn.out_proj.bias', 'vision_model.encoder.layers.4.layer_norm2.bias', 'vision_model.encoder.layers.22.self_attn.v_proj.bias', 'vision_model.encoder.layers.7.layer_norm2.bias', 'vision_model.encoder.layers.9.layer_norm1.weight', 'vision_model.encoder.layers.6.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.mlp.fc2.weight', 'vision_model.encoder.layers.18.layer_norm2.weight', 'vision_model.encoder.layers.12.mlp.fc1.weight', 'vision_model.encoder.layers.11.self_attn.out_proj.bias', 'vision_model.encoder.layers.10.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.v_proj.weight', 'vision_model.encoder.layers.13.self_attn.q_proj.weight', 'vision_model.encoder.layers.4.mlp.fc1.weight', 'logit_scale', 'vision_model.encoder.layers.3.self_attn.out_proj.weight', 'vision_model.encoder.layers.18.mlp.fc1.weight', 'vision_model.encoder.layers.19.layer_norm1.weight', 'vision_model.encoder.layers.3.self_attn.v_proj.weight', 'vision_model.encoder.layers.13.mlp.fc1.bias', 'vision_model.encoder.layers.16.self_attn.out_proj.bias', 'vision_model.encoder.layers.18.mlp.fc2.bias', 'vision_model.encoder.layers.13.mlp.fc1.weight', 'vision_model.encoder.layers.16.mlp.fc2.bias', 'vision_model.encoder.layers.13.self_attn.out_proj.weight', 'vision_model.embeddings.class_embedding', 'vision_model.encoder.layers.2.layer_norm2.weight', 'vision_model.encoder.layers.20.layer_norm1.weight', 'vision_model.encoder.layers.1.layer_norm1.weight', 'vision_model.encoder.layers.21.self_attn.v_proj.bias', 'vision_model.encoder.layers.7.mlp.fc1.bias', 'vision_model.encoder.layers.7.self_attn.out_proj.bias', 'vision_model.encoder.layers.12.mlp.fc2.weight', 'vision_model.encoder.layers.18.self_attn.k_proj.weight', 'vision_model.encoder.layers.19.mlp.fc2.weight', 'vision_model.encoder.layers.10.self_attn.out_proj.weight', 'vision_model.encoder.layers.8.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.self_attn.v_proj.weight', 'vision_model.encoder.layers.10.layer_norm2.bias', 'vision_model.encoder.layers.2.mlp.fc1.weight', 'vision_model.encoder.layers.4.self_attn.out_proj.weight', 'vision_model.encoder.layers.18.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.self_attn.q_proj.weight', 'vision_model.encoder.layers.6.layer_norm1.weight', 'vision_model.encoder.layers.23.layer_norm2.bias', 'vision_model.encoder.layers.7.layer_norm1.bias', 'vision_model.encoder.layers.23.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.self_attn.q_proj.bias', 'vision_model.encoder.layers.22.layer_norm2.bias', 'vision_model.encoder.layers.6.mlp.fc1.weight', 'vision_model.encoder.layers.5.mlp.fc2.bias', 'vision_model.encoder.layers.2.layer_norm1.weight', 'vision_model.encoder.layers.9.layer_norm2.bias', 'vision_model.encoder.layers.16.layer_norm2.bias', 'vision_model.encoder.layers.0.layer_norm1.bias', 'vision_model.encoder.layers.18.layer_norm1.bias', 'vision_model.encoder.layers.21.self_attn.v_proj.weight', 'vision_model.encoder.layers.18.layer_norm1.weight', 'vision_model.encoder.layers.5.self_attn.k_proj.bias', 'vision_model.encoder.layers.13.self_attn.k_proj.weight', 'vision_model.encoder.layers.11.self_attn.k_proj.weight', 'vision_model.encoder.layers.11.mlp.fc1.weight', 'vision_model.encoder.layers.17.self_attn.out_proj.bias', 'vision_model.encoder.layers.21.self_attn.out_proj.weight', 'vision_model.encoder.layers.3.self_attn.k_proj.weight', 'vision_model.encoder.layers.8.self_attn.k_proj.weight', 'vision_model.encoder.layers.9.layer_norm2.weight', 'vision_model.encoder.layers.2.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.self_attn.out_proj.bias', 'vision_model.encoder.layers.23.mlp.fc2.bias', 'vision_model.encoder.layers.4.mlp.fc1.bias', 'vision_model.encoder.layers.18.mlp.fc2.weight', 'vision_model.encoder.layers.9.mlp.fc1.weight', 'vision_model.encoder.layers.23.layer_norm1.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.weight', 'vision_model.encoder.layers.10.self_attn.v_proj.bias', 'vision_model.encoder.layers.22.self_attn.q_proj.weight', 'vision_model.encoder.layers.10.mlp.fc1.bias', 'vision_model.encoder.layers.1.self_attn.k_proj.bias', 'vision_model.encoder.layers.2.self_attn.out_proj.weight', 'vision_model.encoder.layers.22.self_attn.v_proj.weight', 'vision_model.encoder.layers.4.self_attn.k_proj.bias', 'vision_model.encoder.layers.18.self_attn.q_proj.bias', 'vision_model.encoder.layers.11.mlp.fc1.bias', 'vision_model.encoder.layers.15.layer_norm1.weight', 'vision_model.encoder.layers.1.layer_norm2.bias', 'vision_model.encoder.layers.16.mlp.fc1.bias', 'vision_model.encoder.layers.2.self_attn.k_proj.weight', 'vision_model.encoder.layers.5.self_attn.out_proj.bias', 'vision_model.encoder.layers.14.layer_norm2.bias', 'vision_model.encoder.layers.15.self_attn.out_proj.bias', 'vision_model.encoder.layers.17.self_attn.v_proj.bias', 'vision_model.encoder.layers.16.self_attn.v_proj.bias', 'vision_model.encoder.layers.15.mlp.fc1.bias', 'vision_model.encoder.layers.17.self_attn.k_proj.bias', 'vision_model.encoder.layers.18.layer_norm2.bias', 'vision_model.encoder.layers.8.self_attn.v_proj.bias', 'vision_model.encoder.layers.6.mlp.fc2.weight', 'vision_model.encoder.layers.5.layer_norm2.bias', 'vision_model.encoder.layers.14.layer_norm2.weight', 'vision_model.encoder.layers.1.self_attn.v_proj.weight', 'vision_model.encoder.layers.10.self_attn.k_proj.bias', 'vision_model.encoder.layers.2.self_attn.q_proj.bias', 'vision_model.encoder.layers.7.self_attn.q_proj.bias', 'vision_model.encoder.layers.15.self_attn.q_proj.bias', 'vision_model.encoder.layers.23.mlp.fc2.weight', 'vision_model.encoder.layers.2.mlp.fc2.bias', 'vision_model.encoder.layers.16.mlp.fc1.weight', 'vision_model.encoder.layers.5.layer_norm1.weight', 'vision_model.encoder.layers.17.layer_norm1.bias', 'vision_model.encoder.layers.8.mlp.fc2.bias', 'vision_model.embeddings.position_ids', 'vision_model.encoder.layers.8.layer_norm1.weight', 'vision_model.encoder.layers.15.self_attn.q_proj.weight', 'vision_model.encoder.layers.17.layer_norm2.weight', 'vision_model.encoder.layers.11.mlp.fc2.bias', 'vision_model.encoder.layers.2.self_attn.v_proj.bias', 'visual_projection.weight', 'vision_model.encoder.layers.23.self_attn.q_proj.bias', 'vision_model.encoder.layers.0.mlp.fc1.weight', 'vision_model.encoder.layers.13.mlp.fc2.bias', 'vision_model.encoder.layers.0.mlp.fc1.bias', 'vision_model.encoder.layers.1.self_attn.v_proj.bias', 'vision_model.encoder.layers.22.mlp.fc1.weight', 'vision_model.encoder.layers.19.mlp.fc1.weight', 'vision_model.encoder.layers.16.self_attn.q_proj.weight', 'vision_model.encoder.layers.0.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.layer_norm1.weight', 'vision_model.encoder.layers.5.self_attn.k_proj.weight', 'vision_model.encoder.layers.17.mlp.fc2.bias', 'vision_model.encoder.layers.23.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.self_attn.out_proj.bias', 'vision_model.encoder.layers.5.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.self_attn.v_proj.weight', 'vision_model.encoder.layers.23.mlp.fc1.weight', 'vision_model.encoder.layers.11.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.layer_norm2.bias', 'vision_model.encoder.layers.6.mlp.fc1.bias', 'vision_model.encoder.layers.13.self_attn.q_proj.bias', 'vision_model.encoder.layers.10.self_attn.q_proj.weight', 'vision_model.encoder.layers.9.self_attn.v_proj.bias', 'vision_model.encoder.layers.17.mlp.fc1.bias', 'vision_model.encoder.layers.4.self_attn.q_proj.weight', 'vision_model.encoder.layers.11.layer_norm1.weight', 'vision_model.encoder.layers.22.mlp.fc2.weight', 'vision_model.encoder.layers.10.self_attn.out_proj.bias', 'vision_model.encoder.layers.1.mlp.fc2.bias', 'vision_model.encoder.layers.6.layer_norm1.bias', 'vision_model.encoder.layers.11.self_attn.q_proj.bias', 'vision_model.encoder.layers.14.layer_norm1.bias', 'vision_model.encoder.layers.11.layer_norm1.bias', 'vision_model.encoder.layers.9.self_attn.out_proj.bias', 'vision_model.encoder.layers.15.self_attn.v_proj.weight', 'vision_model.encoder.layers.19.mlp.fc1.bias', 'vision_model.encoder.layers.16.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.v_proj.bias', 'vision_model.encoder.layers.10.layer_norm1.weight', 'vision_model.encoder.layers.9.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.q_proj.bias', 'vision_model.encoder.layers.7.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.mlp.fc1.bias', 'vision_model.encoder.layers.1.layer_norm1.bias', 'vision_model.encoder.layers.21.layer_norm2.bias', 'vision_model.encoder.layers.20.mlp.fc2.bias', 'vision_model.encoder.layers.10.mlp.fc2.weight', 'vision_model.encoder.layers.14.mlp.fc2.weight', 'vision_model.encoder.layers.15.mlp.fc2.weight', 'vision_model.encoder.layers.18.self_attn.out_proj.bias', 'vision_model.encoder.layers.11.self_attn.q_proj.weight', 'vision_model.encoder.layers.5.layer_norm1.bias', 'vision_model.encoder.layers.9.mlp.fc1.bias', 'vision_model.encoder.layers.21.self_attn.k_proj.weight', 'vision_model.encoder.layers.9.mlp.fc2.bias', 'vision_model.encoder.layers.6.layer_norm2.bias', 'vision_model.encoder.layers.8.mlp.fc1.bias', 'vision_model.encoder.layers.16.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.k_proj.bias', 'vision_model.encoder.layers.19.layer_norm2.weight', 'vision_model.encoder.layers.23.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.self_attn.out_proj.bias', 'vision_model.encoder.layers.5.self_attn.out_proj.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.bias', 'vision_model.encoder.layers.14.mlp.fc2.bias', 'vision_model.encoder.layers.22.self_attn.out_proj.bias', 'vision_model.encoder.layers.4.self_attn.q_proj.bias', 'vision_model.encoder.layers.2.self_attn.k_proj.bias', 'vision_model.encoder.layers.18.mlp.fc1.bias', 'vision_model.encoder.layers.10.mlp.fc2.bias', 'vision_model.encoder.layers.10.mlp.fc1.weight', 'vision_model.encoder.layers.17.mlp.fc2.weight', 'vision_model.encoder.layers.1.self_attn.out_proj.bias', 'vision_model.encoder.layers.0.self_attn.v_proj.weight', 'vision_model.encoder.layers.3.layer_norm2.weight', 'vision_model.encoder.layers.18.self_attn.q_proj.weight', 'vision_model.encoder.layers.12.self_attn.q_proj.bias', 'vision_model.encoder.layers.8.layer_norm2.bias', 'vision_model.encoder.layers.20.mlp.fc2.weight', 'vision_model.encoder.layers.23.mlp.fc1.bias', 'vision_model.encoder.layers.20.self_attn.out_proj.weight', 'vision_model.encoder.layers.17.self_attn.out_proj.weight', 'vision_model.encoder.layers.16.self_attn.q_proj.bias', 'vision_model.encoder.layers.20.self_attn.v_proj.bias', 'vision_model.encoder.layers.11.layer_norm2.bias', 'vision_model.encoder.layers.12.layer_norm2.weight', 'vision_model.encoder.layers.14.self_attn.v_proj.bias', 'vision_model.encoder.layers.4.mlp.fc2.weight', 'vision_model.encoder.layers.20.layer_norm1.bias', 'vision_model.encoder.layers.6.self_attn.k_proj.weight', 'vision_model.encoder.layers.12.self_attn.k_proj.bias', 'vision_model.encoder.layers.3.self_attn.v_proj.bias', 'vision_model.encoder.layers.3.mlp.fc1.bias', 'vision_model.encoder.layers.2.layer_norm1.bias', 'vision_model.encoder.layers.18.self_attn.v_proj.weight', 'vision_model.encoder.layers.21.mlp.fc2.bias', 'vision_model.encoder.layers.8.self_attn.q_proj.weight', 'vision_model.encoder.layers.13.layer_norm2.bias', 'vision_model.encoder.layers.12.self_attn.q_proj.weight', 'vision_model.encoder.layers.23.layer_norm1.weight', 'vision_model.encoder.layers.7.mlp.fc2.bias', 'vision_model.encoder.layers.22.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.self_attn.out_proj.weight', 'vision_model.encoder.layers.8.self_attn.q_proj.bias', 'vision_model.encoder.layers.19.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.k_proj.weight', 'vision_model.encoder.layers.10.self_attn.q_proj.bias', 'vision_model.encoder.layers.22.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.self_attn.q_proj.weight', 'vision_model.encoder.layers.21.layer_norm1.bias', 'vision_model.encoder.layers.4.self_attn.k_proj.weight', 'vision_model.encoder.layers.3.mlp.fc2.bias', 'vision_model.encoder.layers.23.layer_norm2.weight', 'vision_model.encoder.layers.6.self_attn.q_proj.weight', 'vision_model.encoder.layers.0.self_attn.q_proj.weight', 'vision_model.encoder.layers.1.self_attn.q_proj.weight', 'vision_model.encoder.layers.17.self_attn.q_proj.bias', 'vision_model.encoder.layers.22.layer_norm1.bias', 'vision_model.encoder.layers.16.self_attn.out_proj.weight', 'vision_model.encoder.layers.9.self_attn.q_proj.bias', 'vision_model.encoder.layers.15.mlp.fc2.bias', 'vision_model.encoder.layers.16.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.layer_norm2.weight', 'vision_model.encoder.layers.23.self_attn.v_proj.weight', 'vision_model.encoder.layers.0.self_attn.k_proj.bias', 'vision_model.encoder.layers.7.layer_norm2.weight', 'vision_model.encoder.layers.5.self_attn.q_proj.weight', 'vision_model.encoder.layers.17.self_attn.k_proj.weight', 'vision_model.encoder.layers.20.self_attn.q_proj.weight', 'vision_model.encoder.layers.15.layer_norm2.bias', 'vision_model.encoder.layers.6.mlp.fc2.bias', 'vision_model.encoder.layers.19.layer_norm1.bias', 'vision_model.post_layernorm.weight', 'vision_model.encoder.layers.17.layer_norm2.bias', 'text_projection.weight', 'vision_model.encoder.layers.4.layer_norm2.weight', 'vision_model.encoder.layers.12.layer_norm1.weight', 'vision_model.encoder.layers.17.mlp.fc1.weight', 'vision_model.encoder.layers.5.self_attn.v_proj.weight', 'vision_model.encoder.layers.16.self_attn.v_proj.weight', 'vision_model.encoder.layers.9.layer_norm1.bias', 'vision_model.encoder.layers.5.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.layer_norm2.weight', 'vision_model.encoder.layers.0.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.self_attn.q_proj.bias', 'vision_model.encoder.layers.14.mlp.fc1.bias', 'vision_model.encoder.layers.16.layer_norm1.bias', 'vision_model.encoder.layers.12.layer_norm1.bias', 'vision_model.encoder.layers.13.layer_norm2.weight', 'vision_model.encoder.layers.3.layer_norm1.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.bias', 'vision_model.encoder.layers.2.mlp.fc2.weight', 'vision_model.encoder.layers.10.layer_norm2.weight', 'vision_model.encoder.layers.12.self_attn.v_proj.weight', 'vision_model.encoder.layers.22.self_attn.k_proj.weight', 'vision_model.encoder.layers.14.self_attn.k_proj.bias', 'vision_model.encoder.layers.9.self_attn.v_proj.weight', 'vision_model.encoder.layers.4.self_attn.out_proj.bias', 'vision_model.encoder.layers.19.self_attn.q_proj.bias', 'vision_model.encoder.layers.19.self_attn.out_proj.bias', 'vision_model.encoder.layers.12.self_attn.out_proj.bias', 'vision_model.encoder.layers.3.mlp.fc2.weight', 'vision_model.encoder.layers.13.layer_norm1.bias', 'vision_model.encoder.layers.14.self_attn.q_proj.weight', 'vision_model.encoder.layers.0.self_attn.k_proj.weight', 'vision_model.encoder.layers.7.mlp.fc2.weight', 'vision_model.encoder.layers.22.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.k_proj.weight', 'vision_model.encoder.layers.8.mlp.fc1.weight', 'vision_model.encoder.layers.0.layer_norm1.weight', 'vision_model.encoder.layers.7.self_attn.q_proj.weight', 'vision_model.encoder.layers.3.layer_norm2.bias', 'vision_model.encoder.layers.16.mlp.fc2.weight', 'vision_model.encoder.layers.11.layer_norm2.weight', 'vision_model.encoder.layers.4.self_attn.v_proj.bias', 'vision_model.encoder.layers.7.mlp.fc1.weight', 'vision_model.encoder.layers.18.self_attn.v_proj.bias', 'vision_model.encoder.layers.4.layer_norm1.bias', 'vision_model.encoder.layers.23.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.k_proj.weight', 'vision_model.encoder.layers.1.self_attn.out_proj.weight', 'vision_model.post_layernorm.bias', 'vision_model.encoder.layers.3.mlp.fc1.weight', 'vision_model.encoder.layers.19.layer_norm2.bias', 'vision_model.encoder.layers.6.self_attn.out_proj.weight', 'vision_model.encoder.layers.5.mlp.fc1.bias', 'vision_model.encoder.layers.6.layer_norm2.weight', 'vision_model.encoder.layers.13.self_attn.v_proj.bias', 'vision_model.encoder.layers.15.layer_norm1.bias', 'vision_model.encoder.layers.21.layer_norm1.weight', 'vision_model.encoder.layers.1.layer_norm2.weight', 'vision_model.encoder.layers.5.mlp.fc1.weight', 'vision_model.encoder.layers.4.mlp.fc2.bias', 'vision_model.encoder.layers.19.self_attn.k_proj.bias', 'vision_model.encoder.layers.9.self_attn.q_proj.weight', 'vision_model.encoder.layers.12.self_attn.v_proj.bias', 'vision_model.encoder.layers.17.layer_norm1.weight', 'vision_model.encoder.layers.13.mlp.fc2.weight', 'vision_model.encoder.layers.11.self_attn.v_proj.weight', 'vision_model.encoder.layers.1.mlp.fc1.weight', 'vision_model.encoder.layers.20.mlp.fc1.bias', 'vision_model.encoder.layers.20.mlp.fc1.weight', 'vision_model.encoder.layers.2.layer_norm2.bias', 'vision_model.encoder.layers.1.self_attn.q_proj.bias', 'vision_model.encoder.layers.10.layer_norm1.bias', 'vision_model.encoder.layers.13.layer_norm1.weight', 'vision_model.encoder.layers.17.self_attn.v_proj.weight', 'vision_model.encoder.layers.15.self_attn.k_proj.bias', 'vision_model.encoder.layers.22.mlp.fc1.bias', 'vision_model.encoder.layers.11.self_attn.v_proj.bias', 'vision_model.encoder.layers.2.self_attn.q_proj.weight', 'vision_model.embeddings.position_embedding.weight', 'vision_model.encoder.layers.1.mlp.fc1.bias', 'vision_model.encoder.layers.8.layer_norm1.bias', 'vision_model.encoder.layers.7.layer_norm1.weight', 'vision_model.encoder.layers.9.self_attn.out_proj.weight', 'vision_model.encoder.layers.21.self_attn.k_proj.bias', 'vision_model.encoder.layers.0.layer_norm2.weight', 'vision_model.encoder.layers.6.self_attn.k_proj.bias', 'vision_model.encoder.layers.6.self_attn.v_proj.bias', 'vision_model.encoder.layers.3.self_attn.q_proj.bias', 'vision_model.encoder.layers.12.mlp.fc2.bias', 'vision_model.encoder.layers.18.self_attn.out_proj.weight', 'vision_model.encoder.layers.5.mlp.fc2.weight', 'vision_model.encoder.layers.12.self_attn.k_proj.weight', 'vision_model.encoder.layers.7.self_attn.k_proj.bias', 'vision_model.encoder.layers.20.self_attn.out_proj.bias', 'vision_model.embeddings.patch_embedding.weight', 'vision_model.encoder.layers.9.mlp.fc2.weight', 'vision_model.encoder.layers.15.self_attn.k_proj.weight', 'vision_model.pre_layrnorm.weight', 'vision_model.encoder.layers.17.self_attn.q_proj.weight', 'vision_model.encoder.layers.8.layer_norm2.weight', 'vision_model.encoder.layers.9.self_attn.k_proj.bias', 'vision_model.encoder.layers.12.self_attn.out_proj.weight', 'vision_model.encoder.layers.16.layer_norm2.weight', 'vision_model.encoder.layers.3.self_attn.k_proj.bias', 'vision_model.encoder.layers.13.self_attn.k_proj.bias', 'vision_model.encoder.layers.12.layer_norm2.bias', 'vision_model.encoder.layers.8.self_attn.v_proj.weight', 'vision_model.encoder.layers.21.self_attn.out_proj.bias', 'vision_model.encoder.layers.4.self_attn.v_proj.weight', 'vision_model.encoder.layers.0.self_attn.v_proj.bias', 'vision_model.encoder.layers.13.self_attn.out_proj.bias', 'vision_model.encoder.layers.8.self_attn.out_proj.weight', 'vision_model.encoder.layers.0.mlp.fc2.bias', 'vision_model.encoder.layers.19.self_attn.out_proj.weight', 'vision_model.encoder.layers.2.self_attn.v_proj.weight', 'vision_model.encoder.layers.0.mlp.fc2.weight', 'vision_model.encoder.layers.1.mlp.fc2.weight']\n", "- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" ] } ], "source": [ "def load_artifacts():\n", " '''\n", " A function to load all diffusion artifacts\n", " '''\n", " vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\", torch_dtype=torch.float16).to(\"cuda\")\n", " unet = UNet2DConditionModel.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"unet\", torch_dtype=torch.float16).to(\"cuda\")\n", " tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", torch_dtype=torch.float16)\n", " text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\", torch_dtype=torch.float16).to(\"cuda\")\n", " scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False) \n", " return vae, unet, tokenizer, text_encoder, scheduler\n", "\n", "def load_image(p):\n", " '''\n", " Function to load images from a defined path\n", " '''\n", " return Image.open(p).convert('RGB').resize((512,512))\n", "\n", "def pil_to_latents(image):\n", " '''\n", " Function to convert image to latents\n", " '''\n", " init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0\n", " init_image = init_image.to(device=\"cuda\", dtype=torch.float16) \n", " init_latent_dist = vae.encode(init_image).latent_dist.sample() * 0.18215\n", " return init_latent_dist\n", "\n", "def latents_to_pil(latents):\n", " '''\n", " Function to convert latents to images\n", " '''\n", " latents = (1 / 0.18215) * latents\n", " with torch.no_grad():\n", " image = vae.decode(latents).sample\n", " image = (image / 2 + 0.5).clamp(0, 1)\n", " image = image.detach().cpu().permute(0, 2, 3, 1).numpy()\n", " images = (image * 255).round().astype(\"uint8\")\n", " pil_images = [Image.fromarray(image) for image in images]\n", " return pil_images\n", "\n", "def text_enc(prompts, maxlen=None):\n", " '''\n", " A function to take a texual promt and convert it into embeddings\n", " '''\n", " if maxlen is None: maxlen = tokenizer.model_max_length\n", " inp = tokenizer(prompts, padding=\"max_length\", max_length=maxlen, truncation=True, return_tensors=\"pt\") \n", " return text_encoder(inp.input_ids.to(\"cuda\"))[0].half()\n", "\n", "vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()" ] }, { "cell_type": "code", "execution_count": 6, "id": "c2d419ee", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:38:45.873536Z", "start_time": "2022-11-23T07:38:45.866722Z" } }, "outputs": [], "source": [ "def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength =0.5, steps=50, dim=512):\n", " \"\"\"\n", " Diffusion process to convert prompt to image\n", " \"\"\"\n", " # Converting textual prompts to embedding\n", " text = text_enc(prompts) \n", " \n", " # Adding an unconditional prompt , helps in the generation process\n", " uncond = text_enc([\"\"], text.shape[1])\n", " emb = torch.cat([uncond, text])\n", " \n", " # Setting the seed\n", " if seed: torch.manual_seed(seed)\n", " \n", " # Setting number of steps in scheduler\n", " scheduler.set_timesteps(steps)\n", " \n", " # Convert the seed image to latent\n", " init_latents = pil_to_latents(init_img)\n", " \n", " # Figuring initial time step based on strength\n", " init_timestep = int(steps * strength) \n", " timesteps = scheduler.timesteps[-init_timestep]\n", " timesteps = torch.tensor([timesteps], device=\"cuda\")\n", " \n", " # Adding noise to the latents \n", " noise = torch.randn(init_latents.shape, generator=None, device=\"cuda\", dtype=init_latents.dtype)\n", " init_latents = scheduler.add_noise(init_latents, noise, timesteps)\n", " latents = init_latents\n", " \n", " # We need to scale the i/p latents to match the variance\n", " inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)\n", " # Predicting noise residual using U-Net\n", " with torch.no_grad(): u,t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)\n", " \n", " # Performing Guidance\n", " pred = u + g*(t-u)\n", "\n", " # Zero shot prediction\n", " latents = scheduler.step(pred, timesteps, latents).pred_original_sample\n", " \n", " # Returning the latent representation to output an array of 4x64x64\n", " return latents.detach().cpu()\n", "\n", "def create_mask_fast(init_img, rp, qp, n=20, s=0.5):\n", " ## Initialize a dictionary to save n iterations\n", " diff = {}\n", " \n", " ## Repeating the difference process n times\n", " for idx in range(n):\n", " ## Creating denoised sample using reference / original text\n", " orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed = 100*idx)[0]\n", " ## Creating denoised sample using query / target text\n", " query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed = 100*idx)[0]\n", " ## Taking the difference \n", " diff[idx] = (np.array(orig_noise)-np.array(query_noise))\n", " \n", " ## Creating a mask placeholder\n", " mask = np.zeros_like(diff[0])\n", " \n", " ## Taking an average of 10 iterations\n", " for idx in range(n):\n", " ## Note np.abs is a key step\n", " mask += np.abs(diff[idx]) \n", " \n", " ## Averaging multiple channels \n", " mask = mask.mean(0)\n", " \n", " ## Normalizing \n", " mask = (mask - mask.mean()) / np.std(mask)\n", " \n", " ## Binarizing and returning the mask object\n", " return (mask > 0).astype(\"uint8\")\n", "\n", "def improve_mask(mask):\n", " mask = cv2.GaussianBlur(mask*255,(3,3),1) > 0\n", " return mask.astype('uint8')" ] }, { "cell_type": "code", "execution_count": 7, "id": "e3b0bc1e", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:38:45.886013Z", "start_time": "2022-11-23T07:38:45.874939Z" } }, "outputs": [], "source": [ "def fastDiffEdit(init_img, rp , qp, g=7.5, seed=100, strength =0.7, steps=20, dim=512):\n", " \n", " ## Step 1: Create mask\n", " mask = create_mask_fast(init_img=init_img, rp=rp, qp=qp, n=20)\n", " \n", " ## Improve masking using CV trick\n", " mask = improve_mask(mask)\n", " \n", " ## Step 2 and 3: Diffusion process using mask\n", " output = pipe(\n", " prompt=qp, \n", " image=init_img, \n", " mask_image=Image.fromarray(mask*255).resize((512,512)), \n", " generator=torch.Generator(\"cuda\").manual_seed(100),\n", " num_inference_steps = steps\n", " ).images\n", " return output[0]" ] }, { "cell_type": "code", "execution_count": 8, "id": "73a3d383", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:38:56.698098Z", "start_time": "2022-11-23T07:38:45.887255Z" } }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "652fef739c674f7fbc156e36b05db848", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 15 files: 0%| | 0/15 [00:00" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "output" ] }, { "cell_type": "code", "execution_count": 11, "id": "20f567e7", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:39:10.453493Z", "start_time": "2022-11-23T07:39:10.450855Z" } }, "outputs": [], "source": [ "def test_func(init_img, rp , qp):\n", " mask = create_mask_fast(init_img, rp, qp)\n", " mask = improve_mask(mask)\n", " return pipe(\n", " prompt=[qp], \n", " image=init_img, \n", " mask_image=Image.fromarray(mask*255),#.resize((512,512)), \n", " generator=torch.Generator(\"cuda\").manual_seed(100),\n", " num_inference_steps = 20\n", " ).images[0]" ] }, { "cell_type": "code", "execution_count": 12, "id": "0eb2e975", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:39:10.466013Z", "start_time": "2022-11-23T07:39:10.455799Z" } }, "outputs": [], "source": [ "import gradio as gr" ] }, { "cell_type": "code", "execution_count": 14, "id": "2907db3e", "metadata": { "ExecuteTime": { "end_time": "2022-11-23T07:39:35.972730Z", "start_time": "2022-11-23T07:39:31.484620Z" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/aayush/miniconda3/envs/fastai/lib/python3.9/site-packages/gradio/deprecation.py:40: UserWarning: `enable_queue` is deprecated in `Interface()`, please use it within `launch()` instead.\n", " warnings.warn(value)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7861/\n", "Running on public URL: https://a3a54c93e8e5d9f8.gradio.app\n", "\n", "This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://www.huggingface.co/spaces)\n" ] }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "(,\n", " 'http://127.0.0.1:7861/',\n", " 'https://a3a54c93e8e5d9f8.gradio.app')" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "demo = gr.Interface(\n", " fn=fastDiffEdit, \n", " inputs=[\n", " gr.inputs.Image(shape=(512, 512), type=\"pil\", label = \"Upload your image photo\"),\n", " gr.Textbox(label=\"Describe your image. Ex: a horse image\"),\n", " gr.Textbox(label=\"Retype the description with target output. Ex: a zebra image\")], \n", " outputs=\"image\",\n", " title = \"DiffEdit demo\",\n", " description = \"DiffEdit paper demo. Upload an image, pass reference prompt describing the image, pass query prompt to replace the object with target object\",\n", " examples = [\n", " [\"horse.jpg\", \"a horse image\", \"a zebra image\"],\n", " [\"fruitbowl.jpg\", \"a bowl of fruit\", \"a bowl of grapes\"]],\n", " enable_queue=True\n", " )\n", "\n", "demo.launch(share=True) " ] }, { "cell_type": "code", "execution_count": null, "id": "7abccb1f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "465.455px" }, "toc_section_display": true, "toc_window_display": true }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }