DongKyung committed
Commit 617b10f • 1 Parent(s): 0ef260c

Upload DDPM_Inversion.ipynb

Files changed (1)
  1. DDPM_Inversion.ipynb +352 -0
DDPM_Inversion.ipynb ADDED
@@ -0,0 +1,352 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "machine_shape": "hm",
+ "gpuType": "L4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# https://github.com/inbarhub/DDPM_inversion"
+ ],
+ "metadata": {
+ "id": "2pmc1ZdmtAQJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "GsGhwPzb_RBH"
+ },
+ "outputs": [],
+ "source": [
+ "%pip install numpy\n",
+ "%pip install matplotlib\n",
+ "%pip install fastai\n",
+ "%pip install accelerate\n",
+ "%pip install -U transformers diffusers ftfy\n",
+ "%pip install torch\n",
+ "%pip install torchvision\n",
+ "%pip install opencv-python\n",
+ "%pip install ipywidgets"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import inspect\n",
+ "\n",
+ "from pathlib import Path\n",
+ "\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from accelerate import Accelerator\n",
+ "from diffusers import (\n",
+ "    AutoencoderKL,\n",
+ "    UNet2DConditionModel,\n",
+ "    DDIMScheduler,\n",
+ "    DPMSolverMultistepScheduler,\n",
+ ")\n",
+ "from huggingface_hub import notebook_login\n",
+ "from PIL import Image\n",
+ "from torchvision import transforms as tfms\n",
+ "from tqdm.auto import tqdm\n",
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
+ "from typing import Optional\n",
+ "import requests\n",
+ "\n",
+ "notebook_login()"
+ ],
+ "metadata": {
+ "id": "sYCb0YhF_YqC"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from google.colab import drive\n",
+ "drive.mount('/content/drive')"
+ ],
+ "metadata": {
+ "id": "W3Ik_48j_Y1q"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Cell that creates init_image, i.e. the input image\n",
+ "\n",
+ "init_image = load_image(path=\"/content/DDPM_inversion/Input_Images/cherry blossom branch petal.png\") #fill your own directory\n",
+ "\n",
+ "init_path = \"/content/DDPM_inversion/Input_Images/cherry blossom branch petal.png\" #fill your own directory"
+ ],
+ "metadata": {
+ "id": "tuhPV23T_Y4k"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from transformers import Blip2Processor, Blip2ForConditionalGeneration\n",
+ "\n",
+ "processor = Blip2Processor.from_pretrained(\"Salesforce/blip2-opt-2.7b\")\n",
+ "imagecaptioningmodel = Blip2ForConditionalGeneration.from_pretrained(\"Salesforce/blip2-opt-2.7b\").to(device)\n",
+ "inputs = processor(init_image, return_tensors=\"pt\").to(device) # parameters for the captioning model\n",
+ "outputs = imagecaptioningmodel.generate(**inputs)\n",
+ "print(processor.decode(outputs[0], skip_special_tokens=True))"
+ ],
+ "metadata": {
+ "id": "WRyROFhX_Y7c"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "prompt = str(processor.decode(outputs[0], skip_special_tokens=True))"
+ ],
+ "metadata": {
+ "id": "rh01KUQh_vW1"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import yaml\n",
+ "data = [\n",
+ "    {\n",
+ "        \"init_img\": \"/content/DDPM_inversion/Input_Images/Cherry Blossoms.png\", # use init_path here\n",
+ "        \"source_prompt\": \"\",\n",
+ "        \"target_prompts\": [\n",
+ "            \"\",\n",
+ "        ]\n",
+ "    },\n",
+ "]\n",
+ "\n",
+ "file_path = '/content/DDPM_inversion/test.yaml' # file path (can be changed)\n",
+ "\n",
+ "with open(file_path, 'w') as file:\n",
+ "    yaml.dump(data, file)\n",
+ "with open(file_path, 'r') as file:\n",
+ "    print(file.read())"
+ ],
+ "metadata": {
+ "id": "wZighP5oNL1X"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!git clone https://github.com/Kangdongkyung/DDPM_inversion.git # do not use this; change to the original git repository"
+ ],
+ "metadata": {
+ "id": "fuW0T7AzRPEz"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content/DDPM_inversion #fill your own directory"
+ ],
+ "metadata": {
+ "id": "mM7wwPjycqSK"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from easydict import EasyDict\n",
+ "from diffusers import StableDiffusionPipeline\n",
+ "from diffusers import DDIMScheduler\n",
+ "import os\n",
+ "from prompt_to_prompt.ptp_classes import AttentionStore, AttentionReplace, AttentionRefine, EmptyControl,load_512\n",
+ "from prompt_to_prompt.ptp_utils import register_attention_control, text2image_ldm_stable, view_images\n",
+ "from ddm_inversion.inversion_utils import inversion_forward_process, inversion_reverse_process\n",
+ "from ddm_inversion.utils import image_grid,dataset_from_yaml\n",
+ "\n",
+ "from torch import autocast, inference_mode\n",
+ "from ddm_inversion.ddim_inversion import ddim_inversion\n",
+ "\n",
+ "import calendar\n",
+ "import time\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ "    # parser = argparse.ArgumentParser()\n",
+ "    # parser.add_argument(\"--device_num\", type=int, default=0)\n",
+ "    # parser.add_argument(\"--cfg_src\", type=float, default=3.5)\n",
+ "    # parser.add_argument(\"--cfg_tar\", type=float, default=15)\n",
+ "    # parser.add_argument(\"--num_diffusion_steps\", type=int, default=100)\n",
+ "    # parser.add_argument(\"--dataset_yaml\", default=\"test.yaml\")\n",
+ "    # parser.add_argument(\"--eta\", type=float, default=1)\n",
+ "    # parser.add_argument(\"--mode\", default=\"our_inv\", help=\"modes: our_inv,p2pinv,p2pddim,ddim\")\n",
+ "    # parser.add_argument(\"--skip\", type=int, default=36)\n",
+ "    # parser.add_argument(\"--xa\", type=float, default=0.6)\n",
+ "    # parser.add_argument(\"--sa\", type=float, default=0.2)\n",
+ "\n",
+ "    # args = parser.parse_args()\n",
+ "    args = EasyDict()\n",
+ "    args.dataset_yaml = file_path\n",
+ "    args.cfg_src = 3.5\n",
+ "    args.cfg_tar = 15\n",
+ "    args.num_diffusion_steps = 100\n",
+ "    args.eta = 1\n",
+ "    args.mode = \"our_inv\"\n",
+ "    args.skip = 36\n",
+ "    args.xa = 0.6\n",
+ "    args.sa = 0.2\n",
+ "\n",
+ "    full_data = dataset_from_yaml(args.dataset_yaml)\n",
+ "\n",
+ "    # create scheduler\n",
+ "    # load diffusion model\n",
+ "    model_id = \"CompVis/stable-diffusion-v1-4\"\n",
+ "    # model_id = \"stable_diff_local\" # load local save of model (for internet problems)\n",
+ "\n",
+ "\n",
+ "    cfg_scale_src = args.cfg_src\n",
+ "    cfg_scale_tar_list = [args.cfg_tar]\n",
+ "    eta = args.eta # = 1\n",
+ "    skip_zs = [args.skip]\n",
+ "    xa_sa_string = f'_xa_{args.xa}_sa{args.sa}_' if args.mode=='p2pinv' else '_'\n",
+ "\n",
+ "    current_GMT = time.gmtime()\n",
+ "    time_stamp = calendar.timegm(current_GMT)\n",
+ "\n",
+ "    # load/reload model:\n",
+ "    ldm_stable = StableDiffusionPipeline.from_pretrained(model_id).to(device)\n",
+ "\n",
+ "    for i in range(len(full_data)):\n",
+ "        current_image_data = full_data[i]\n",
+ "        image_path = current_image_data['init_img']\n",
+ "        image_path = image_path # '.' was removed to indicate this is not the current path, so this needs to be fixed\n",
+ "        image_folder = image_path.split('/')[1] # after '.'\n",
+ "        prompt_src = current_image_data.get('source_prompt', \"\") # default empty string\n",
+ "        prompt_tar_list = current_image_data['target_prompts']\n",
+ "\n",
+ "        if args.mode==\"p2pddim\" or args.mode==\"ddim\":\n",
+ "            scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule=\"scaled_linear\", clip_sample=False, set_alpha_to_one=False)\n",
+ "            ldm_stable.scheduler = scheduler\n",
+ "        else:\n",
+ "            ldm_stable.scheduler = DDIMScheduler.from_config(model_id, subfolder = \"scheduler\")\n",
+ "\n",
+ "        ldm_stable.scheduler.set_timesteps(args.num_diffusion_steps)\n",
+ "\n",
+ "        # load image\n",
+ "        offsets=(0,0,0,0)\n",
+ "        x0 = load_512(image_path, *offsets, device)\n",
+ "\n",
+ "        # vae encode image\n",
+ "        with autocast(\"cuda\"), inference_mode():\n",
+ "            w0 = (ldm_stable.vae.encode(x0).latent_dist.mode() * 0.18215).float()\n",
+ "\n",
+ "        # find Zs and wts - forward process\n",
+ "        if args.mode==\"p2pddim\" or args.mode==\"ddim\":\n",
+ "            wT = ddim_inversion(ldm_stable, w0, prompt_src, cfg_scale_src)\n",
+ "        else:\n",
+ "            wt, zs, wts = inversion_forward_process(ldm_stable, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=args.num_diffusion_steps)\n",
+ "\n",
+ "        # iterate over decoder prompts\n",
+ "        for k in range(len(prompt_tar_list)):\n",
+ "            prompt_tar = prompt_tar_list[k]\n",
+ "            save_path = os.path.join(f'./results/', args.mode+xa_sa_string+str(time_stamp), image_path.split(sep='.')[0], 'src_' + prompt_src.replace(\" \", \"_\"), 'dec_' + prompt_tar.replace(\" \", \"_\"))\n",
+ "            os.makedirs(save_path, exist_ok=True)\n",
+ "\n",
+ "            # Check if number of words in encoder and decoder text are equal\n",
+ "            src_tar_len_eq = (len(prompt_src.split(\" \")) == len(prompt_tar.split(\" \")))\n",
+ "\n",
+ "            for cfg_scale_tar in cfg_scale_tar_list:\n",
+ "                for skip in skip_zs:\n",
+ "                    if args.mode==\"our_inv\":\n",
+ "                        # reverse process (via Zs and wT)\n",
+ "                        controller = AttentionStore()\n",
+ "                        register_attention_control(ldm_stable, controller)\n",
+ "                        w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)\n",
+ "\n",
+ "                    elif args.mode==\"p2pinv\":\n",
+ "                        # inversion with attention replace\n",
+ "                        cfg_scale_list = [cfg_scale_src, cfg_scale_tar]\n",
+ "                        prompts = [prompt_src, prompt_tar]\n",
+ "                        if src_tar_len_eq:\n",
+ "                            controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)\n",
+ "                        else:\n",
+ "                            # Should use Refine for target prompts with different number of tokens\n",
+ "                            controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=args.xa, self_replace_steps=args.sa, model=ldm_stable)\n",
+ "\n",
+ "                        register_attention_control(ldm_stable, controller)\n",
+ "                        w0, _ = inversion_reverse_process(ldm_stable, xT=wts[args.num_diffusion_steps-skip], etas=eta, prompts=prompts, cfg_scales=cfg_scale_list, prog_bar=True, zs=zs[:(args.num_diffusion_steps-skip)], controller=controller)\n",
+ "                        w0 = w0[1].unsqueeze(0)\n",
+ "\n",
+ "                    elif args.mode==\"p2pddim\" or args.mode==\"ddim\":\n",
+ "                        # only z=0\n",
+ "                        if skip != 0:\n",
+ "                            continue\n",
+ "                        prompts = [prompt_src, prompt_tar]\n",
+ "                        if args.mode==\"p2pddim\":\n",
+ "                            if src_tar_len_eq:\n",
+ "                                controller = AttentionReplace(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)\n",
+ "                            # Should use Refine for target prompts with different number of tokens\n",
+ "                            else:\n",
+ "                                controller = AttentionRefine(prompts, args.num_diffusion_steps, cross_replace_steps=.8, self_replace_steps=0.4, model=ldm_stable)\n",
+ "                        else:\n",
+ "                            controller = EmptyControl()\n",
+ "\n",
+ "                        register_attention_control(ldm_stable, controller)\n",
+ "                        # perform ddim inversion\n",
+ "                        cfg_scale_list = [cfg_scale_src, cfg_scale_tar]\n",
+ "                        w0, latent = text2image_ldm_stable(ldm_stable, prompts, controller, args.num_diffusion_steps, cfg_scale_list, None, wT)\n",
+ "                        w0 = w0[1:2]\n",
+ "                    else:\n",
+ "                        raise NotImplementedError\n",
+ "\n",
+ "                    # vae decode image\n",
+ "                    with autocast(\"cuda\"), inference_mode():\n",
+ "                        x0_dec = ldm_stable.vae.decode(1 / 0.18215 * w0).sample\n",
+ "                    if x0_dec.dim()<4:\n",
+ "                        x0_dec = x0_dec[None,:,:,:]\n",
+ "                    img = image_grid(x0_dec)\n",
+ "\n",
+ "                    # same output\n",
+ "                    current_GMT = time.gmtime()\n",
+ "                    time_stamp_name = calendar.timegm(current_GMT)\n",
+ "                    image_name_png = f'cfg_d_{cfg_scale_tar}_' + f'skip_{skip}_{time_stamp_name}' + \".png\"\n",
+ "\n",
+ "                    save_full_path = os.path.join(save_path, image_name_png)\n",
+ "                    img.save(save_full_path)"
+ ],
+ "metadata": {
+ "id": "dcVYikEa_wQ1"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+ }