JiayiGuo821 committed
Commit
b6e0092
1 Parent(s): e43437a
app.py ADDED
@@ -0,0 +1,964 @@
+ ################################################################################
+ # Copyright (C) 2023 Jiayi Guo, Xingqian Xu, Manushree Vasu - All Rights Reserved #
+ ################################################################################
+
+ import gradio as gr
+ import os
+ import os.path as osp
+ import PIL
+ from PIL import Image
+ import numpy as np
+ from collections import OrderedDict
+ from easydict import EasyDict as edict
+ from functools import partial
+
+ import torch
+ import torchvision.transforms as tvtrans
+ import time
+ import argparse
+ import json
+ import hashlib
+ import copy
+ from tqdm import tqdm
+
+ from diffusers import StableDiffusionPipeline
+ from diffusers import DDIMScheduler
+ from app_utils import auto_dropdown
+
+ from huggingface_hub import hf_hub_download
+
+ version = "Smooth Diffusion Demo v1.0"
+ refresh_symbol = "\U0001f504"  # 🔄
+ recycle_symbol = '\U0000267b'  # ♻
+
+ ##############
+ # model_book #
+ ##############
+
+ choices = edict()
+ choices.diffuser = OrderedDict([
+     ['SD-v1-5', "runwayml/stable-diffusion-v1-5"],
+     ['OJ-v4', "prompthero/openjourney-v4"],
+     ['RR-v2', "SG161222/Realistic_Vision_V2.0"],
+ ])
+
+ choices.lora = OrderedDict([
+     ['empty', ""],
+     ['Smooth-LoRA-v1', hf_hub_download('shi-labs/smooth-diffusion-lora', 'pytorch_model.bin')],
+ ])
+
+ choices.scheduler = OrderedDict([
+     ['DDIM', DDIMScheduler],
+ ])
+
+ choices.inversion = OrderedDict([
+     ['NTI', 'NTI'],
+     ['DDIM w/o text', 'DDIM w/o text'],
+     ['DDIM', 'DDIM'],
+ ])
+
+ default = edict()
+ default.diffuser = 'SD-v1-5'
+ default.scheduler = 'DDIM'
+ default.lora = 'Smooth-LoRA-v1'
+ default.inversion = 'NTI'
+ default.step = 50
+ default.cfg_scale = 7.5
+ default.framen = 24
+ default.fps = 16
+ default.nullinv_inner_step = 10
+ default.threshold = 0.8
+ default.variation = 0.8
+
+ ##########
+ # helper #
+ ##########
+
+ def lerp(t, v0, v1):
+     if isinstance(t, float):
+         return v0*(1-t) + v1*t
+     elif isinstance(t, (list, np.ndarray)):
+         return [v0*(1-ti) + v1*ti for ti in t]
+
+ def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
+     # mostly copied from
+     # https://gist.github.com/dvschultz/3af50c40df002da3b751efab1daddf2c
+     v0_unit = v0 / np.linalg.norm(v0)
+     v1_unit = v1 / np.linalg.norm(v1)
+     dot = np.sum(v0_unit * v1_unit)
+     if np.abs(dot) > DOT_THRESHOLD:
+         return lerp(t, v0, v1)
+     # Calculate initial angle between v0 and v1
+     theta_0 = np.arccos(dot)
+     sin_theta_0 = np.sin(theta_0)
+     # Angle at timestep t
+
+     if isinstance(t, float):
+         tlist = [t]
+     elif isinstance(t, (list, np.ndarray)):
+         tlist = t
+
+     v2_list = []
+
+     for ti in tlist:
+         theta_t = theta_0 * ti
+         sin_theta_t = np.sin(theta_t)
+         # Finish the slerp algorithm
+         s0 = np.sin(theta_0 - theta_t) / sin_theta_0
+         s1 = sin_theta_t / sin_theta_0
+         v2 = s0 * v0 + s1 * v1
+         v2_list.append(v2)
+
+     if isinstance(t, float):
+         return v2_list[0]
+     else:
+         return v2_list
+
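+ # A quick sanity check of slerp (illustrative only, not used by the app):
+ # slerp(0.5, np.array([1.0, 0.0]), np.array([0.0, 1.0])) gives roughly
+ # [0.707, 0.707], the great-circle midpoint, matching
+ # slerp(t, v0, v1) = sin((1-t)*theta)/sin(theta) * v0 + sin(t*theta)/sin(theta) * v1.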
+ def offset_resize(image, width=512, height=512, left=0, right=0, top=0, bottom=0):
+     image = np.array(image)[:, :, :3]
+     h, w, c = image.shape
+     # clamp crop offsets so they stay inside the image
+     left = min(left, w - 1)
+     right = min(right, w - left - 1)
+     top = min(top, h - 1)
+     bottom = min(bottom, h - top - 1)
+     image = image[top:h-bottom, left:w-right]
+     h, w, c = image.shape
+     if h < w:
+         offset = (w - h) // 2
+         image = image[:, offset:offset + h]
+     elif w < h:
+         offset = (h - w) // 2
+         image = image[offset:offset + w]
+     image = Image.fromarray(image).resize((width, height))
+     return image
+
+ def auto_dtype_device_shape(tlist, v0, v1, func,):
+     vshape = v0.shape
+     assert v0.shape == v1.shape
+     assert isinstance(tlist, (list, np.ndarray))
+
+     if isinstance(v0, torch.Tensor):
+         is_torch = True
+         dtype, device = v0.dtype, v0.device
+         v0 = v0.to('cpu').numpy().astype(float).flatten()
+         v1 = v1.to('cpu').numpy().astype(float).flatten()
+     else:
+         is_torch = False
+         dtype = v0.dtype
+         assert isinstance(v0, np.ndarray)
+         assert isinstance(v1, np.ndarray)
+         v0 = v0.astype(float).flatten()
+         v1 = v1.astype(float).flatten()
+
+     r = func(tlist, v0, v1)
+
+     if is_torch:
+         r = [torch.Tensor(ri).view(*vshape).to(dtype).to(device) for ri in r]
+     else:
+         r = [ri.astype(dtype) for ri in r]
+     return r
+
+ auto_lerp = partial(auto_dtype_device_shape, func=lerp)
+ auto_slerp = partial(auto_dtype_device_shape, func=slerp)
+
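+ # auto_lerp / auto_slerp wrap the interpolants above: inputs are flattened to
+ # float64 numpy arrays, interpolated, then restored to the original
+ # dtype/device/shape, so they apply equally to CUDA latents and embeddings.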
+ def frames2mp4(vpath, frames, fps):
+     import moviepy.editor as mpy
+     frames = [np.array(framei) for framei in frames]
+     clip = mpy.ImageSequenceClip(frames, fps=fps)
+     clip.write_videofile(vpath, fps=fps)
+
+ def negseed_to_rndseed(seed):
+     if seed < 0:
+         seed = np.random.randint(0, np.iinfo(np.uint32).max-100)
+     return seed
+
+ def regulate_image(pilim):
+     w, h = pilim.size
+     w = int(round(w/64)) * 64
+     h = int(round(h/64)) * 64
+     return pilim.resize([w, h], resample=PIL.Image.BILINEAR)
+
+ def txt_to_emb(model, prompt):
+     text_input = model.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",)
+     text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
+     return text_embeddings
+
+ def hash_pilim(pilim):
+     hasha = hashlib.md5(pilim.tobytes()).hexdigest()
+     return hasha
+
+ def hash_cfgdict(cfgdict):
+     hashb = hashlib.md5(json.dumps(cfgdict, sort_keys=True).encode('utf-8')).hexdigest()
+     return hashb
+
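+ # The inversion cache key is "md5(image bytes)--md5(sorted config JSON)", so a
+ # cached result is reused only when both the exact image and the exact
+ # settings (prompt, steps, cfg scale, model tags) match.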
+ def remove_earliest_file(path, max_allowance=500, remove_ratio=0.1, ext=None):
+     if len(os.listdir(path)) <= max_allowance:
+         return
+     def get_mtime(fname):
+         return osp.getmtime(osp.join(path, fname))
+     if ext is None:
+         flist = sorted(os.listdir(path), key=get_mtime)
+     else:
+         flist = [fi for fi in os.listdir(path) if fi.endswith(ext)]
+         flist = sorted(flist, key=get_mtime)
+     exceedn = max(len(flist)-max_allowance, 0)
+     removen = int(max_allowance*remove_ratio)
+     removen = max(1, removen) + exceedn
+     for fi in flist[0:removen]:
+         os.remove(osp.join(path, fi))
+
+ def remove_decoupled_file(path, exta='.mp4', extb='.json'):
+     tag_a = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(exta)]
+     tag_b = [osp.splitext(fi)[0] for fi in os.listdir(path) if fi.endswith(extb)]
+     tag_a_extra = set(tag_a) - set(tag_b)
+     tag_b_extra = set(tag_b) - set(tag_a)
+     [os.remove(osp.join(path, tagi+exta)) for tagi in tag_a_extra]
+     [os.remove(osp.join(path, tagi+extb)) for tagi in tag_b_extra]
+
+ @torch.no_grad()
+ def t2i_core(model, xt, emb, nemb, step=30, cfg_scale=7.5, return_list=False):
+     from nulltxtinv_wrapper import diffusion_step, latent2image
+     model.scheduler.set_timesteps(step)
+     xi = xt
+     emb = txt_to_emb(model, "") if emb is None else emb
+     nemb = txt_to_emb(model, "") if nemb is None else nemb
+     if return_list:
+         xi_list = [xi.clone()]
+     for i, t in enumerate(tqdm(model.scheduler.timesteps)):
+         embi = emb[i] if isinstance(emb, list) else emb
+         nembi = nemb[i] if isinstance(nemb, list) else nemb
+         context = torch.cat([nembi, embi])
+         xi = diffusion_step(model, xi, context, t, cfg_scale, low_resource=False)
+         if return_list:
+             xi_list.append(xi.clone())
+     x0 = xi
+     im = latent2image(model.vae, x0, return_type='pil')
+
+     if return_list:
+         return im, xi_list
+     else:
+         return im
+
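+ # t2i_core denoises a (possibly batched) latent back to pixels. `emb` and
+ # `nemb` may be per-timestep lists (as produced by null-text inversion) or a
+ # single embedding; one classifier-free-guidance step runs per timestep.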
+ ########
+ # main #
+ ########
+
+ class wrapper(object):
+     def __init__(self,
+                  fp16=False,
+                  tag_diffuser=None,
+                  tag_lora=None,
+                  tag_scheduler=None,):
+
+         self.device = "cuda"
+         if fp16:
+             self.torch_dtype = torch.float16
+         else:
+             self.torch_dtype = torch.float32
+         self.load_all(tag_diffuser, tag_lora, tag_scheduler)
+
+         self.image_latent_dim = 4
+         self.batchsize = 8
+         self.seed = {}
+
+         self.cache_video_folder = "temp/video"
+         self.cache_video_maxn = 500
+         self.cache_image_folder = "temp/image"
+         self.cache_image_maxn = 500
+         self.cache_inverse_folder = "temp/inverse"
+         self.cache_inverse_maxn = 500
+
+     def load_all(self, tag_diffuser, tag_lora, tag_scheduler):
+         self.load_diffuser_lora(tag_diffuser, tag_lora)
+         self.load_scheduler(tag_scheduler)
+         return tag_diffuser, tag_lora, tag_scheduler
+
+     def load_diffuser_lora(self, tag_diffuser, tag_lora):
+         self.net = StableDiffusionPipeline.from_pretrained(
+             choices.diffuser[tag_diffuser], torch_dtype=self.torch_dtype).to(self.device)
+         self.net.safety_checker = None
+         if tag_lora != 'empty':
+             self.net.unet.load_attn_procs(
+                 choices.lora[tag_lora], use_safetensors=False,)
+         self.tag_diffuser = tag_diffuser
+         self.tag_lora = tag_lora
+         return tag_diffuser, tag_lora
+
+     def load_scheduler(self, tag_scheduler):
+         self.net.scheduler = choices.scheduler[tag_scheduler].from_config(self.net.scheduler.config)
+         self.tag_scheduler = tag_scheduler
+         return tag_scheduler
+
+     def reset_seed(self, which='ltintp'):
+         return -1
+
+     def recycle_seed(self, which='ltintp'):
+         if which not in self.seed:
+             return self.reset_seed(which=which)
+         else:
+             return self.seed[which]
+
+     ##########
+     # helper #
+     ##########
+
+     def precheck_model(self, tag_diffuser, tag_lora, tag_scheduler):
+         if (tag_diffuser != self.tag_diffuser) or (tag_lora != self.tag_lora):
+             self.load_all(tag_diffuser, tag_lora, tag_scheduler)
+         if tag_scheduler != self.tag_scheduler:
+             self.load_scheduler(tag_scheduler)
+
+     ########
+     # main #
+     ########
+
+     def ddiminv(self, img, cfgdict):
+         txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
+         from nulltxtinv_wrapper import NullInversion
+         null_inversion_model = NullInversion(self.net, step, cfg_scale)
+         with torch.no_grad():
+             emb = txt_to_emb(self.net, txt)
+             nemb = txt_to_emb(self.net, "")
+             xt = null_inversion_model.ddim_invert(img, txt)
+         data = {
+             'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
+             'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
+             'xt': xt, 'emb': emb, 'nemb': nemb,}
+         return data
+
+     def nullinv_or_loadcache(self, img, cfgdict, force_reinvert=False):
+         hashkey = hash_pilim(img) + "--" + hash_cfgdict(cfgdict)
+         cdir = self.cache_inverse_folder
+         cfname = osp.join(cdir, hashkey+'.pth')
+
+         if osp.isfile(cfname) and (not force_reinvert):
+             cache_data = torch.load(cfname)
+             dtype = next(self.net.unet.parameters()).dtype
+             device = next(self.net.unet.parameters()).device
+             cache_data['xt'] = cache_data['xt'].to(device=device, dtype=dtype)
+             cache_data['emb'] = cache_data['emb'].to(device=device, dtype=dtype)
+             cache_data['nemb'] = [
+                 nembi.to(device=device, dtype=dtype)
+                 for nembi in cache_data['nemb']]
+             return cache_data
+         else:
+             txt, step, cfg_scale = cfgdict['txt'], cfgdict['step'], cfgdict['cfg_scale']
+             inner_step = cfgdict['inner_step']
+             from nulltxtinv_wrapper import NullInversion
+             null_inversion_model = NullInversion(self.net, step, cfg_scale)
+             with torch.no_grad():
+                 emb = txt_to_emb(self.net, txt)
+             xt, nemb = null_inversion_model.null_invert(img, txt, num_inner_steps=inner_step)
+             cache_data = {
+                 'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
+                 'inner_step' : inner_step,
+                 'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
+                 'xt' : xt.to('cpu'),
+                 'emb' : emb.to('cpu'),
+                 'nemb' : [nembi.to('cpu') for nembi in nemb],}
+             os.makedirs(cdir, exist_ok=True)
+             remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
+             torch.save(cache_data, cfname)
+             data = {
+                 'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt,
+                 'inner_step' : inner_step,
+                 'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
+                 'xt' : xt, 'emb' : emb, 'nemb' : nemb,}
+             return data
+
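+     # Null-text inversion is expensive (an inner optimization per DDIM step),
+     # so results are cached under temp/inverse and moved back to the live
+     # device/dtype on a hit; force_reinvert bypasses the cache.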
+     def nullinvdual_or_loadcachedual(self, img0, img1, cfgdict, force_reinvert=False):
+         hashkey = hash_pilim(img0) + "--" + hash_pilim(img1) + "--" + hash_cfgdict(cfgdict)
+         cdir = self.cache_inverse_folder
+         cfname = osp.join(cdir, hashkey+'.pth')
+
+         if osp.isfile(cfname) and (not force_reinvert):
+             cache_data = torch.load(cfname)
+             dtype = next(self.net.unet.parameters()).dtype
+             device = next(self.net.unet.parameters()).device
+             cache_data['xt0'] = cache_data['xt0'].to(device=device, dtype=dtype)
+             cache_data['xt1'] = cache_data['xt1'].to(device=device, dtype=dtype)
+             cache_data['emb0'] = cache_data['emb0'].to(device=device, dtype=dtype)
+             cache_data['emb1'] = cache_data['emb1'].to(device=device, dtype=dtype)
+             cache_data['nemb'] = [
+                 nembi.to(device=device, dtype=dtype)
+                 for nembi in cache_data['nemb']]
+
+             cache_data_a = copy.deepcopy(cache_data)
+             cache_data_a['xt'] = cache_data_a.pop('xt0')
+             cache_data_a['emb'] = cache_data_a.pop('emb0')
+             cache_data_a.pop('xt1'); cache_data_a.pop('emb1')
+
+             cache_data_b = cache_data
+             cache_data_b['xt'] = cache_data_b.pop('xt1')
+             cache_data_b['emb'] = cache_data_b.pop('emb1')
+             cache_data_b.pop('xt0'); cache_data_b.pop('emb0')
+
+             return cache_data_a, cache_data_b
+         else:
+             txt0, txt1, step, cfg_scale, inner_step = \
+                 cfgdict['txt0'], cfgdict['txt1'], cfgdict['step'], \
+                 cfgdict['cfg_scale'], cfgdict['inner_step']
+
+             from nulltxtinv_wrapper import NullInversion
+             null_inversion_model = NullInversion(self.net, step, cfg_scale)
+             with torch.no_grad():
+                 emb0 = txt_to_emb(self.net, txt0)
+                 emb1 = txt_to_emb(self.net, txt1)
+
+             xt0, xt1, nemb = null_inversion_model.null_invert_dual(
+                 img0, img1, txt0, txt1, num_inner_steps=inner_step)
+             cache_data = {
+                 'step' : step, 'cfg_scale' : cfg_scale,
+                 'txt0' : txt0, 'txt1' : txt1,
+                 'inner_step' : inner_step,
+                 'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
+                 'xt0' : xt0.to('cpu'), 'xt1' : xt1.to('cpu'),
+                 'emb0' : emb0.to('cpu'), 'emb1' : emb1.to('cpu'),
+                 'nemb' : [nembi.to('cpu') for nembi in nemb],}
+             os.makedirs(cdir, exist_ok=True)
+             remove_earliest_file(cdir, max_allowance=self.cache_inverse_maxn)
+             torch.save(cache_data, cfname)
+             data0 = {
+                 'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt0,
+                 'inner_step' : inner_step,
+                 'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
+                 'xt' : xt0, 'emb' : emb0, 'nemb' : nemb,}
+             data1 = {
+                 'step' : step, 'cfg_scale' : cfg_scale, 'txt' : txt1,
+                 'inner_step' : inner_step,
+                 'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,
+                 'xt' : xt1, 'emb' : emb1, 'nemb' : nemb,}
+             return data0, data1
+
+     def image_inversion(
+             self, img, txt,
+             cfg_scale, step,
+             inversion, inner_step, force_reinvert):
+         from nulltxtinv_wrapper import text2image_ldm
+         if inversion == 'DDIM w/o text':
+             txt = ''
+         if not inversion == 'NTI':
+             data = self.ddiminv(img, {'txt':txt, 'step':step, 'cfg_scale':cfg_scale,})
+         else:
+             data = self.nullinv_or_loadcache(
+                 img, {'txt':txt, 'step':step,
+                       'cfg_scale':cfg_scale, 'inner_step':inner_step,
+                       'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
+
+         if inversion == 'NTI':
+             img_inv, _ = text2image_ldm(
+                 self.net, [txt], step, cfg_scale,
+                 latent=data['xt'], uncond_embeddings=data['nemb'])
+         else:
+             img_inv, _ = text2image_ldm(
+                 self.net, [txt], step, cfg_scale,
+                 latent=data['xt'], uncond_embeddings=None)
+
+         return img_inv
+
+     def image_editing(
+             self, img, txt_0, txt_1,
+             cfg_scale, step, thresh,
+             inversion, inner_step, force_reinvert):
+         from nulltxtinv_wrapper import text2image_ldm_imedit
+         if inversion == 'DDIM w/o text':
+             txt_0 = ''
+         if not inversion == 'NTI':
+             data = self.ddiminv(img, {'txt':txt_0, 'step':step, 'cfg_scale':cfg_scale,})
+             img_edited, _ = text2image_ldm_imedit(
+                 self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
+                 latent=data['xt'], uncond_embeddings=None)
+         else:
+             data = self.nullinv_or_loadcache(
+                 img, {'txt':txt_0, 'step':step,
+                       'cfg_scale':cfg_scale, 'inner_step':inner_step,
+                       'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
+             img_edited, _ = text2image_ldm_imedit(
+                 self.net, thresh, [txt_0], [txt_1], step, cfg_scale,
+                 latent=data['xt'], uncond_embeddings=data['nemb'])
+
+         return img_edited
+
+     def general_interpolation(
+             self, xset0, xset1,
+             cfg_scale, step, tlist,):
+
+         xt0, emb0, nemb0 = xset0['xt'], xset0['emb'], xset0['nemb']
+         xt1, emb1, nemb1 = xset1['xt'], xset1['emb'], xset1['nemb']
+         framen = len(tlist)
+
+         xt_list = auto_slerp(tlist, xt0, xt1)
+         emb_list = auto_lerp(tlist, emb0, emb1)
+
+         if isinstance(nemb0, list) and isinstance(nemb1, list):
+             assert len(nemb0) == len(nemb1)
+             nemb_list = [auto_lerp(tlist, e0, e1) for e0, e1 in zip(nemb0, nemb1)]
+             nemb_islist = True
+         else:
+             nemb_list = auto_lerp(tlist, nemb0, nemb1)
+             nemb_islist = False
+
+         im_list = []
+         for frameidx in range(0, len(xt_list), self.batchsize):
+             xt_batch = [xt_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
+             xt_batch = torch.cat(xt_batch, dim=0)
+             emb_batch = [emb_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
+             emb_batch = torch.cat(emb_batch, dim=0)
+             if nemb_islist:
+                 nemb_batch = []
+                 for nembi in nemb_list:
+                     nembi_batch = [nembi[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
+                     nembi_batch = torch.cat(nembi_batch, dim=0)
+                     nemb_batch.append(nembi_batch)
+             else:
+                 nemb_batch = [nemb_list[idx] for idx in range(frameidx, min(frameidx+self.batchsize, framen))]
+                 nemb_batch = torch.cat(nemb_batch, dim=0)
+
+             im = t2i_core(
+                 self.net, xt_batch, emb_batch, nemb_batch, step, cfg_scale)
+             im_list += im if isinstance(im, list) else [im]
+
+         return im_list
+
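+     # Interpolation recipe: inverted noise latents are slerp-ed (Gaussian
+     # noise concentrates near a sphere, so the spherical path keeps frames
+     # on-distribution), text embeddings are lerp-ed, and frames are denoised
+     # in batches of self.batchsize.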
+     def run_iminvs(
+             self, img, text,
+             cfg_scale, step,
+             force_resize, width, height,
+             inversion, inner_step, force_reinvert,
+             tag_diffuser, tag_lora, tag_scheduler, ):
+
+         self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
+
+         if force_resize:
+             img = offset_resize(img, width, height)
+         else:
+             img = regulate_image(img)
+
+         recon_output = self.image_inversion(
+             img, text, cfg_scale, step,
+             inversion, inner_step, force_reinvert)
+
+         idir = self.cache_image_folder
+         os.makedirs(idir, exist_ok=True)
+         remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
+         sname = "time{}_iminvs_{}_{}".format(
+             int(time.time()), self.tag_diffuser, self.tag_lora,)
+         ipath = osp.join(idir, sname+'.png')
+         recon_output.save(ipath)
+
+         return [recon_output]
+
+     def run_imedit(
+             self, img, txt_0, txt_1,
+             threshold, cfg_scale, step,
+             force_resize, width, height,
+             inversion, inner_step, force_reinvert,
+             tag_diffuser, tag_lora, tag_scheduler, ):
+
+         self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
+         if force_resize:
+             img = offset_resize(img, width, height)
+         else:
+             img = regulate_image(img)
+
+         edited_img = self.image_editing(
+             img, txt_0, txt_1, cfg_scale, step, threshold,
+             inversion, inner_step, force_reinvert)
+
+         idir = self.cache_image_folder
+         os.makedirs(idir, exist_ok=True)
+         remove_earliest_file(idir, max_allowance=self.cache_image_maxn)
+         sname = "time{}_imedit_{}_{}".format(
+             int(time.time()), self.tag_diffuser, self.tag_lora,)
+         ipath = osp.join(idir, sname+'.png')
+         edited_img.save(ipath)
+
+         return [edited_img]
+
+     def run_imintp(
+             self,
+             img0, img1, txt0, txt1,
+             cfg_scale, step,
+             framen, fps,
+             force_resize, width, height,
+             inversion, inner_step, force_reinvert,
+             tag_diffuser, tag_lora, tag_scheduler,):
+
+         self.precheck_model(tag_diffuser, tag_lora, tag_scheduler)
+         if txt1 == '':
+             txt1 = txt0
+         if force_resize:
+             img0 = offset_resize(img0, width, height)
+             img1 = offset_resize(img1, width, height)
+         else:
+             img0 = regulate_image(img0)
+             img1 = regulate_image(img1)
+
+         if inversion == 'DDIM':
+             data0 = self.ddiminv(img0, {'txt':txt0, 'step':step, 'cfg_scale':cfg_scale,})
+             data1 = self.ddiminv(img1, {'txt':txt1, 'step':step, 'cfg_scale':cfg_scale,})
+         elif inversion == 'DDIM w/o text':
+             data0 = self.ddiminv(img0, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
+             data1 = self.ddiminv(img1, {'txt':"", 'step':step, 'cfg_scale':cfg_scale,})
+         else:
+             data0, data1 = self.nullinvdual_or_loadcachedual(
+                 img0, img1, {'txt0':txt0, 'txt1':txt1, 'step':step,
+                              'cfg_scale':cfg_scale, 'inner_step':inner_step,
+                              'diffuser' : self.tag_diffuser, 'lora' : self.tag_lora,}, force_reinvert)
+
+         tlist = np.linspace(0.0, 1.0, framen)
+
+         iminv0 = t2i_core(self.net, data0['xt'], data0['emb'], data0['nemb'], step, cfg_scale)
+         iminv1 = t2i_core(self.net, data1['xt'], data1['emb'], data1['nemb'], step, cfg_scale)
+         frames = self.general_interpolation(data0, data1, cfg_scale, step, tlist)
+
+         vdir = self.cache_video_folder
+         os.makedirs(vdir, exist_ok=True)
+         remove_earliest_file(vdir, max_allowance=self.cache_video_maxn)
+         sname = "time{}_imintp_{}_{}_framen{}_fps{}".format(
+             int(time.time()), self.tag_diffuser, self.tag_lora, framen, fps)
+         vpath = osp.join(vdir, sname+'.mp4')
+         frames2mp4(vpath, frames, fps)
+         jpath = osp.join(vdir, sname+'.json')
+         cfgdict = {
+             "method" : "image_interpolation",
+             "txt0" : txt0, "txt1" : txt1,
+             "cfg_scale" : cfg_scale, "step" : step,
+             "framen" : framen, "fps" : fps,
+             "force_resize" : force_resize, "width" : width, "height" : height,
+             "inversion" : inversion, "inner_step" : inner_step,
+             "force_reinvert" : force_reinvert,
+             "tag_diffuser" : tag_diffuser, "tag_lora" : tag_lora, "tag_scheduler" : tag_scheduler,}
+         with open(jpath, 'w') as f:
+             json.dump(cfgdict, f, indent=4)
+
+         return frames, vpath, [iminv0, iminv1]
+
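+     # Every rendered .mp4 gets a .json sidecar recording the exact settings;
+     # remove_decoupled_file() can prune orphans if cache eviction ever drops
+     # only one file of a pair.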
+ #################
+ # get examples #
+ #################
+ cache_examples = False
+ def get_imintp_example():
+     case = [
+         [
+             'assets/images/interpolation/cityview1.png',
+             'assets/images/interpolation/cityview2.png',
+             'A city view',],
+         [
+             'assets/images/interpolation/woman1.png',
+             'assets/images/interpolation/woman2.png',
+             'A woman face',],
+         [
+             'assets/images/interpolation/land1.png',
+             'assets/images/interpolation/land2.png',
+             'A beautiful landscape',],
+         [
+             'assets/images/interpolation/dog1.png',
+             'assets/images/interpolation/dog2.png',
+             'A realistic dog',],
+         [
+             'assets/images/interpolation/church1.png',
+             'assets/images/interpolation/church2.png',
+             'A church',],
+         [
+             'assets/images/interpolation/rabbit1.png',
+             'assets/images/interpolation/rabbit2.png',
+             'A cute rabbit',],
+         [
+             'assets/images/interpolation/horse1.png',
+             'assets/images/interpolation/horse2.png',
+             'A robot horse',],
+     ]
+     return case
+
+ def get_iminvs_example():
+     case = [
+         [
+             'assets/images/inversion/000000560011.jpg',
+             'A mouse is next to a keyboard on a desk',],
+         [
+             'assets/images/inversion/000000029596.jpg',
+             'A room with a couch, table set with dinnerware and a television.',],
+     ]
+     return case
+
+ def get_imedit_example():
+     case = [
+         [
+             'assets/images/editing/rabbit.png',
+             'A rabbit is eating a watermelon on the table',
+             'A cat is eating a watermelon on the table',
+             0.7,],
+         [
+             'assets/images/editing/cake.png',
+             'A chocolate cake with cream on it',
+             'A chocolate cake with strawberries on it',
+             0.9,],
+         [
+             'assets/images/editing/banana.png',
+             'A banana on the table',
+             'A banana and an apple on the table',
+             0.8,],
+     ]
+     return case
+
+ #################
+ # sub interface #
+ #################
+
+ def interface_imintp(wrapper_obj):
+     with gr.Row():
+         with gr.Column():
+             img0 = gr.Image(label="Image Input 0", type='pil', elem_id='customized_imbox')
+         with gr.Column():
+             img1 = gr.Image(label="Image Input 1", type='pil', elem_id='customized_imbox')
+         with gr.Column():
+             video_output = gr.Video(label="Video Result", format='mp4', elem_id='customized_imbox')
+     with gr.Row():
+         with gr.Column():
+             txt0 = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
+         with gr.Column():
+             with gr.Row():
+                 inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
+                 inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
+                 force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
+
+     with gr.Row():
+         with gr.Column():
+             with gr.Row():
+                 framen = gr.Slider(label="Frame Number", minimum=8, maximum=default.framen, value=default.framen, step=1)
+                 fps = gr.Slider(label="Video FPS", minimum=4, maximum=default.fps, value=default.fps, step=4)
+             with gr.Row():
+                 button_run = gr.Button("Run")
+
+         with gr.Column():
+             with gr.Accordion('Frame Results', open=False):
+                 frame_output = gr.Gallery(label="Frames", elem_id='customized_imbox')
+             with gr.Accordion("Inversion Results", open=False):
+                 inv_output = gr.Gallery(label="Inversion Results", elem_id='customized_imbox')
+             with gr.Accordion('Advanced Settings', open=False):
+                 with gr.Row():
+                     tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
+                     tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
+                     tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
+                 with gr.Row():
+                     cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
+                     step = gr.Number(default.step, label="Step", precision=0)
+                 with gr.Row():
+                     force_resize = gr.Checkbox(label="Force Resize", value=True)
+                     inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
+                     inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
+                 with gr.Row():
+                     txt1 = gr.Textbox(label='Optional Different Text Input for Image Input 1', lines=1, placeholder="Input prompt...", )
+
+     tag_diffuser.change(
+         wrapper_obj.load_all,
+         inputs=[tag_diffuser, tag_lora, tag_scheduler],
+         outputs=[tag_diffuser, tag_lora, tag_scheduler],)
+
+     tag_lora.change(
+         wrapper_obj.load_all,
+         inputs=[tag_diffuser, tag_lora, tag_scheduler],
+         outputs=[tag_diffuser, tag_lora, tag_scheduler],)
+
+     tag_scheduler.change(
+         wrapper_obj.load_scheduler,
+         inputs=[tag_scheduler],
+         outputs=[tag_scheduler],)
+
+     button_run.click(
+         wrapper_obj.run_imintp,
+         inputs=[img0, img1, txt0, txt1,
+                 cfg_scale, step,
+                 framen, fps,
+                 force_resize, inp_width, inp_height,
+                 inversion, inner_step, force_reinvert,
+                 tag_diffuser, tag_lora, tag_scheduler,],
+         outputs=[frame_output, video_output, inv_output])
+
+     gr.Examples(
+         label='Examples',
+         examples=get_imintp_example(),
+         fn=wrapper_obj.run_imintp,
+         inputs=[img0, img1, txt0,],
+         outputs=[frame_output, video_output, inv_output],
+         cache_examples=cache_examples,)
+
+ def interface_iminvs(wrapper_obj):
+     with gr.Row():
+         image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
+         recon_output = gr.Gallery(label="Reconstruction output", elem_id='customized_imbox')
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label='Text Input', lines=1, placeholder="Input prompt...", )
+             with gr.Row():
+                 button_run = gr.Button("Run")
+
+         with gr.Column():
+             with gr.Row():
+                 inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
+                 inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
+                 force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
+             with gr.Accordion('Advanced Settings', open=False):
+                 with gr.Row():
+                     tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
+                     tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
+                     tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
+                 with gr.Row():
+                     cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
+                     step = gr.Number(default.step, label="Step", precision=0)
+                 with gr.Row():
+                     force_resize = gr.Checkbox(label="Force Resize", value=True)
+                     inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
+                     inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
+
+     tag_diffuser.change(
+         wrapper_obj.load_all,
+         inputs=[tag_diffuser, tag_lora, tag_scheduler],
+         outputs=[tag_diffuser, tag_lora, tag_scheduler],)
+
+     tag_lora.change(
+         wrapper_obj.load_all,
+         inputs=[tag_diffuser, tag_lora, tag_scheduler],
+         outputs=[tag_diffuser, tag_lora, tag_scheduler],)
+
+     tag_scheduler.change(
+         wrapper_obj.load_scheduler,
+         inputs=[tag_scheduler],
+         outputs=[tag_scheduler],)
+
+     button_run.click(
+         wrapper_obj.run_iminvs,
+         inputs=[image_input, prompt,
+                 cfg_scale, step,
+                 force_resize, inp_width, inp_height,
+                 inversion, inner_step, force_reinvert,
+                 tag_diffuser, tag_lora, tag_scheduler,],
+         outputs=[recon_output])
+
+     gr.Examples(
+         label='Examples',
+         examples=get_iminvs_example(),
+         fn=wrapper_obj.run_iminvs,
+         inputs=[image_input, prompt,],
+         outputs=[recon_output],
+         cache_examples=cache_examples,)
+
+ def interface_imedit(wrapper_obj):
+     with gr.Row():
+         image_input = gr.Image(label="Image input", type='pil', elem_id='customized_imbox')
+         edited_output = gr.Gallery(label="Edited output", elem_id='customized_imbox')
+     with gr.Row():
+         with gr.Column():
+             prompt_0 = gr.Textbox(label='Source Text', lines=1, placeholder="Source prompt...", )
+             prompt_1 = gr.Textbox(label='Target Text', lines=1, placeholder="Target prompt...", )
+             with gr.Row():
+                 button_run = gr.Button("Run")
+
+         with gr.Column():
+             with gr.Row():
+                 inversion = auto_dropdown('Inversion', choices.inversion, default.inversion)
+                 inner_step = gr.Slider(label="Inner Step (NTI)", value=default.nullinv_inner_step, minimum=1, maximum=10, step=1)
+                 force_reinvert = gr.Checkbox(label="Force ReInvert (NTI)", value=False)
+                 threshold = gr.Slider(label="Threshold", minimum=0, maximum=1, value=default.threshold, step=0.1)
+             with gr.Accordion('Advanced Settings', open=False):
+                 with gr.Row():
+                     tag_diffuser = auto_dropdown('Diffuser', choices.diffuser, default.diffuser)
+                     tag_lora = auto_dropdown('Use LoRA', choices.lora, default.lora)
+                     tag_scheduler = auto_dropdown('Scheduler', choices.scheduler, default.scheduler)
+                 with gr.Row():
+                     cfg_scale = gr.Number(label="Scale", minimum=1, maximum=10, value=default.cfg_scale, step=0.5)
+                     step = gr.Number(default.step, label="Step", precision=0)
+                 with gr.Row():
+                     force_resize = gr.Checkbox(label="Force Resize", value=True)
+                     inp_width = gr.Slider(label="Width", minimum=256, maximum=1024, value=512, step=64)
+                     inp_height = gr.Slider(label="Height", minimum=256, maximum=1024, value=512, step=64)
+
+     tag_diffuser.change(
+         wrapper_obj.load_all,
+         inputs=[tag_diffuser, tag_lora, tag_scheduler],
+         outputs=[tag_diffuser, tag_lora, tag_scheduler],)
+
+     tag_lora.change(
+         wrapper_obj.load_all,
+         inputs=[tag_diffuser, tag_lora, tag_scheduler],
+         outputs=[tag_diffuser, tag_lora, tag_scheduler],)
+
+     tag_scheduler.change(
+         wrapper_obj.load_scheduler,
+         inputs=[tag_scheduler],
+         outputs=[tag_scheduler],)
+
+     button_run.click(
+         wrapper_obj.run_imedit,
+         inputs=[image_input, prompt_0, prompt_1,
+                 threshold, cfg_scale, step,
+                 force_resize, inp_width, inp_height,
+                 inversion, inner_step, force_reinvert,
+                 tag_diffuser, tag_lora, tag_scheduler,],
+         outputs=[edited_output])
+
+     gr.Examples(
+         label='Examples',
+         examples=get_imedit_example(),
+         fn=wrapper_obj.run_imedit,
+         inputs=[image_input, prompt_0, prompt_1, threshold,],
+         outputs=[edited_output],
+         cache_examples=cache_examples,)
+
+ #############
+ # Interface #
+ #############
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('-p', '--port', type=int, default=None)
+     args = parser.parse_args()
+     from app_utils import css_empty, css_version_4_11_0
+     # css = css_empty
+     css = css_version_4_11_0
+
+     wrapper_obj = wrapper(
+         fp16=False,
+         tag_diffuser=default.diffuser,
+         tag_lora=default.lora,
+         tag_scheduler=default.scheduler)
+
+     with gr.Blocks(css=css) as demo:
+         gr.HTML(
+             """
+             <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
+             <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
+                 {}
+             </h1>
+             </div>
+             """.format(version))
+
+         with gr.Tab('Image Interpolation'):
+             interface_imintp(wrapper_obj)
+         with gr.Tab('Image Inversion'):
+             interface_iminvs(wrapper_obj)
+         with gr.Tab('Image Editing'):
+             interface_imedit(wrapper_obj)
+
+     demo.launch(server_port=args.port)
app_utils.py ADDED
@@ -0,0 +1,102 @@
+ import os
+ import os.path as osp
+ import cv2
+ import numpy as np
+ import numpy.random as npr
+ import torch
+ import torch.nn.functional as F
+ import torchvision.transforms as tvtrans
+ import PIL.Image
+ from tqdm import tqdm
+ from PIL import Image
+ import copy
+ import json
+ from collections import OrderedDict
+
+ #######
+ # css #
+ #######
+
+ css_empty = ""
+
+ css_version_4_11_0 = """
+ #customized_imbox {
+     min-height: 450px;
+     max-height: 450px;
+ }
+ #customized_imbox>div[data-testid="image"] {
+     min-height: 450px;
+ }
+ #customized_imbox>div[data-testid="image"]>span[data-testid="source-select"] {
+     max-height: 0px;
+ }
+ #customized_imbox>div[data-testid="image"]>span[data-testid="source-select"]>button {
+     max-height: 0px;
+ }
+ #customized_imbox>div[data-testid="image"]>div.upload-container>div.image-frame>img {
+     position: absolute;
+     top: 50%;
+     left: 50%;
+     transform: translateX(-50%) translateY(-50%);
+     width: unset;
+     height: unset;
+     max-height: 450px;
+ }
+ #customized_imbox>div.unpadded_box {
+     min-height: 450px;
+ }
+ #myinst {
+     font-size: 0.8rem;
+     margin: 0rem;
+     color: #6B7280;
+ }
+ #maskinst {
+     text-align: justify;
+     min-width: 1200px;
+ }
+ #maskinst>img {
+     min-width: 399px;
+     max-width: 450px;
+     vertical-align: top;
+     display: inline-block;
+ }
+ #maskinst:after {
+     content: "";
+     width: 100%;
+     display: inline-block;
+ }
+ """
+
+ ##########
+ # helper #
+ ##########
+
+ def highlight_print(info):
+     print('')
+     print(''.join(['#']*(len(info)+4)))
+     print('# '+info+' #')
+     print(''.join(['#']*(len(info)+4)))
+     print('')
+
+ def auto_dropdown(name, choices_od, value):
+     import gradio as gr
+     option_list = [pi for pi in choices_od.keys()]
+     return gr.Dropdown(label=name, choices=option_list, value=value)
+
+ def load_sd_from_file(target):
+     if osp.splitext(target)[-1] == '.ckpt':
+         sd = torch.load(target, map_location='cpu')['state_dict']
+     elif osp.splitext(target)[-1] == '.pth':
+         sd = torch.load(target, map_location='cpu')
+     elif osp.splitext(target)[-1] == '.safetensors':
+         from safetensors.torch import load_file as stload
+         sd = OrderedDict(stload(target, device='cpu'))
+     else:
+         assert False, "File type must be .ckpt or .pth or .safetensors"
+     return sd
+
+ def torch_to_numpy(x):
+     return x.detach().to('cpu').numpy()
+
+ if __name__ == '__main__':
+     pass
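+ # Usage sketch (names taken from app.py): auto_dropdown('Diffuser',
+ # choices.diffuser, 'SD-v1-5') builds a gr.Dropdown over the OrderedDict keys;
+ # the caller maps the selected key back to a model id via choices.diffuser[key].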
assets/.DS_Store ADDED
Binary file (6.15 kB).
 
assets/images/.DS_Store ADDED
Binary file (6.15 kB).
 
assets/images/editing/banana.png ADDED
assets/images/editing/cake.png ADDED
assets/images/editing/rabbit.png ADDED
assets/images/interpolation/church1.png ADDED
assets/images/interpolation/church2.png ADDED
assets/images/interpolation/dog1.png ADDED
assets/images/interpolation/dog2.png ADDED
assets/images/interpolation/horse1.png ADDED
assets/images/interpolation/horse2.png ADDED
assets/images/interpolation/land1.png ADDED
assets/images/interpolation/land2.png ADDED
assets/images/interpolation/rabbit1.png ADDED
assets/images/interpolation/rabbit2.png ADDED
assets/images/interpolation/woman1.png ADDED
assets/images/interpolation/woman2.png ADDED
assets/images/inversion/000000029596.jpg ADDED
assets/images/inversion/000000560011.jpg ADDED
nulltxtinv_wrapper.py ADDED
@@ -0,0 +1,450 @@
+ import numpy as np
+ import torch
+ import PIL.Image
+ from tqdm import tqdm
+ from typing import Optional, Union, List
+ import warnings
+ warnings.filterwarnings('ignore')
+
+ from torch.optim.adam import Adam
+ import torch.nn.functional as nnf
+
+ from diffusers import DDIMScheduler
+
+ ##########
+ # helper #
+ ##########
+
+ def diffusion_step(model, latents, context, t, guidance_scale, low_resource=False):
+     if low_resource:
+         noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
+         noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
+     else:
+         latents_input = torch.cat([latents] * 2)
+         noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+         noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+     noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+     latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
+     return latents
+
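+ # Classifier-free guidance: eps = eps_uncond + s * (eps_text - eps_uncond),
+ # i.e. the unconditional prediction pushed toward the text-conditional one
+ # with strength s (guidance_scale), followed by one scheduler step toward x_{t-1}.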
+ def image2latent(vae, image):
+     with torch.no_grad():
+         if isinstance(image, PIL.Image.Image):
+             image = np.array(image)
+         if isinstance(image, np.ndarray):
+             dtype = next(vae.parameters()).dtype
+             device = next(vae.parameters()).device
+             image = torch.from_numpy(image).float() / 127.5 - 1
+             image = image.permute(2, 0, 1).unsqueeze(0).to(device=device, dtype=dtype)
+             latents = vae.encode(image)['latent_dist'].mean
+             latents = latents * 0.18215
+     return latents
+
+ def latent2image(vae, latents, return_type='np'):
+     assert isinstance(latents, torch.Tensor)
+     latents = 1 / 0.18215 * latents.detach()
+     image = vae.decode(latents)['sample']
+     if return_type in ['np', 'pil']:
+         image = (image / 2 + 0.5).clamp(0, 1)
+         image = image.cpu().permute(0, 2, 3, 1).numpy()
+         image = (image * 255).astype(np.uint8)
+     if return_type == 'pil':
+         pilim = [PIL.Image.fromarray(imi) for imi in image]
+         pilim = pilim[0] if len(pilim)==1 else pilim
+         return pilim
+     else:
+         return image
+
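+ # 0.18215 is the Stable Diffusion VAE scaling factor: latents are multiplied
+ # by it after encoding and divided by it before decoding, so the UNet sees
+ # roughly unit-variance inputs.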
+ def init_latent(latent, model, height, width, generator, batch_size):
+     if latent is None:
+         latent = torch.randn(
+             (1, model.unet.in_channels, height // 8, width // 8),
+             generator=generator,
+         )
+     latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
+     return latent, latents
+
+ def txt_to_emb(model, prompt):
+     text_input = model.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",)
+     text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
+     return text_embeddings
+
+ @torch.no_grad()
+ def text2image_ldm(
+         model,
+         prompt: List[str],
+         num_inference_steps: int = 50,
+         guidance_scale: Optional[float] = 7.5,
+         generator: Optional[torch.Generator] = None,
+         latent: Optional[torch.FloatTensor] = None,
+         uncond_embeddings=None,
+         start_time=50,
+         return_type='pil', ):
+
+     batch_size = len(prompt)
+     height = width = 512
+     if latent is not None:
+         height = latent.shape[-2] * 8
+         width = latent.shape[-1] * 8
+
+     text_input = model.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",)
+     text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
+     max_length = text_input.input_ids.shape[-1]
+     if uncond_embeddings is None:
+         uncond_input = model.tokenizer(
+             [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt",)
+         uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
+     else:
+         uncond_embeddings_ = None
+
+     latent, latents = init_latent(latent, model, height, width, generator, batch_size)
+     model.scheduler.set_timesteps(num_inference_steps)
+     for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
+         if uncond_embeddings_ is None:
+             context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
+         else:
+             context = torch.cat([uncond_embeddings_, text_embeddings])
+         latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)
+
+     if return_type in ['pil', 'np']:
+         image = latent2image(model.vae, latents, return_type=return_type)
+     else:
+         image = latents
+     return image, latent
+
+ @torch.no_grad()
+ def text2image_ldm_imedit(
+         model,
+         thresh,
+         prompt: List[str],
+         target_prompt: List[str],
+         num_inference_steps: int = 50,
+         guidance_scale: Optional[float] = 7.5,
+         generator: Optional[torch.Generator] = None,
+         latent: Optional[torch.FloatTensor] = None,
+         uncond_embeddings=None,
+         start_time=50,
+         return_type='pil'
+ ):
+     batch_size = len(prompt)
+     height = width = 512
+     if latent is not None:  # match text2image_ldm: derive the size from the inverted latent
+         height = latent.shape[-2] * 8
+         width = latent.shape[-1] * 8
+
+     text_input = model.tokenizer(
+         prompt,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     target_text_input = model.tokenizer(
+         target_prompt,
+         padding="max_length",
+         max_length=model.tokenizer.model_max_length,
+         truncation=True,
+         return_tensors="pt",
+     )
+     text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
+     target_text_embeddings = model.text_encoder(target_text_input.input_ids.to(model.device))[0]
+
+     max_length = text_input.input_ids.shape[-1]
+     if uncond_embeddings is None:
+         uncond_input = model.tokenizer(
+             [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
+         )
+         uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
+     else:
+         uncond_embeddings_ = None
+
+     latent, latents = init_latent(latent, model, height, width, generator, batch_size)
+     model.scheduler.set_timesteps(num_inference_steps)
+     for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
+         if i < (1 - thresh) * num_inference_steps:
+             if uncond_embeddings_ is None:
+                 context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
+             else:
+                 context = torch.cat([uncond_embeddings_, text_embeddings])
+             latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)
+         else:
+             if uncond_embeddings_ is None:
+                 context = torch.cat([uncond_embeddings[i].expand(*target_text_embeddings.shape), target_text_embeddings])
+             else:
+                 context = torch.cat([uncond_embeddings_, target_text_embeddings])
+             latents = diffusion_step(model, latents, context, t, guidance_scale, low_resource=False)
+
+     if return_type in ['pil', 'np']:
+         image = latent2image(model.vae, latents, return_type=return_type)
+     else:
+         image = latents
+     return image, latent
+
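+ # Edit schedule: the first (1 - thresh) fraction of denoising steps follows
+ # the source prompt (fixing layout and structure), after which conditioning
+ # switches to the target prompt; a larger threshold edits more aggressively.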
+ ###########
+ # wrapper #
+ ###########
+
+ class NullInversion(object):
+     def __init__(self, model, num_ddim_steps, guidance_scale, device='cuda'):
+         self.model = model
+         self.device = device
+         self.num_ddim_steps = num_ddim_steps
+         self.guidance_scale = guidance_scale
+         self.tokenizer = self.model.tokenizer
+         self.prompt = None
+         self.context = None
+
+     def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray]):
+         prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
+         alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
+         alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
+         beta_prod_t = 1 - alpha_prod_t
+         pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
+         pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
+         prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
+         return prev_sample
+
+     def next_step(self, noise_pred, timestep, sample):
+         timestep, next_timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
+         alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
+         alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
+         beta_prod_t = 1 - alpha_prod_t
+         next_original_sample = (sample - beta_prod_t ** 0.5 * noise_pred) / alpha_prod_t ** 0.5
+         next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * noise_pred
+         next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
+         return next_sample
+
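+     # Both steps apply the deterministic DDIM update
+     #     x_s = sqrt(alpha_s) * x0_pred + sqrt(1 - alpha_s) * eps,
+     # where x0_pred = (x_t - sqrt(1 - alpha_t) * eps) / sqrt(alpha_t);
+     # prev_step walks toward the image, next_step toward noise (inversion).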
+     def get_noise_pred_single(self, latents, t, context):
+         noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
+         return noise_pred
+
+     def get_noise_pred(self, latents, t, is_forward=True, context=None):
+         latents_input = torch.cat([latents] * 2)
+         if context is None:
+             context = self.context
+         guidance_scale = 1 if is_forward else self.guidance_scale
+         noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+         noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
+         if is_forward:
+             latents = self.next_step(noise_pred, t, latents)
+         else:
+             latents = self.prev_step(noise_pred, t, latents)
+         return latents
+
+     @torch.no_grad()
+     def init_prompt(self, prompt: str):
+         uncond_input = self.model.tokenizer(
+             [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
+             return_tensors="pt"
+         )
+         uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
+         text_input = self.model.tokenizer(
+             [prompt],
+             padding="max_length",
+             max_length=self.model.tokenizer.model_max_length,
+             truncation=True,
+             return_tensors="pt",
+         )
+         text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
+         self.context = torch.cat([uncond_embeddings, text_embeddings])
+         self.prompt = prompt
+
+     @torch.no_grad()
+     def ddim_loop(self, latent, emb):
+         # uncond_embeddings, cond_embeddings = self.context.chunk(2)
+         all_latent = [latent]
+         latent = latent.clone().detach()
+         for i in range(self.num_ddim_steps):
+             t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
+             noise_pred = self.get_noise_pred_single(latent, t, emb)
+             latent = self.next_step(noise_pred, t, latent)
+             all_latent.append(latent)
+         return all_latent
+
+     @property
+     def scheduler(self):
+         return self.model.scheduler
+
+     @torch.no_grad()
+     def ddim_invert(self, image, prompt):
+         assert isinstance(image, PIL.Image.Image)
+
+         scheduler_save = self.model.scheduler
+         self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
+         self.model.scheduler.set_timesteps(self.num_ddim_steps)
+
+         with torch.no_grad():
+             emb = txt_to_emb(self.model, prompt)
+             latent = image2latent(self.model.vae, image)
+             ddim_latents = self.ddim_loop(latent, emb)
+
+         self.model.scheduler = scheduler_save
+         return ddim_latents[-1]
+
+     def null_optimization(self, latents, emb, nemb=None, num_inner_steps=10, epsilon=1e-5):
+         # force fp32
+         dtype = latents[0].dtype
+         uncond_embeddings = nemb.float() if nemb is not None else txt_to_emb(self.model, "").float()
+         cond_embeddings = emb.float()
+         latents = [li.float() for li in latents]
+         self.model.unet.to(torch.float32)
+
+         uncond_embeddings_list = []
+         latent_cur = latents[-1]
+         bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
+         for i in range(self.num_ddim_steps):
+             uncond_embeddings = uncond_embeddings.clone().detach()
+             uncond_embeddings.requires_grad = True
+             optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
+             latent_prev = latents[len(latents) - i - 2]
+             t = self.model.scheduler.timesteps[i]
+             with torch.no_grad():
+                 noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
+             for j in range(num_inner_steps):
+                 noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
+                 noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
+                 latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
+                 loss = nnf.mse_loss(latents_prev_rec, latent_prev)
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+                 loss_item = loss.item()
+                 bar.update()
+                 if loss_item < epsilon + i * 2e-5:
+                     break
+             for j in range(j + 1, num_inner_steps):
+                 bar.update()
+             uncond_embeddings_list.append(uncond_embeddings[:1].detach())
+             with torch.no_grad():
+                 context = torch.cat([uncond_embeddings, cond_embeddings])
+                 latent_cur = self.get_noise_pred(latent_cur, t, False, context)
+         bar.close()
+
+         uncond_embeddings_list = [ui.to(dtype) for ui in uncond_embeddings_list]
+         self.model.unet.to(dtype)
+         return uncond_embeddings_list
+
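+     # Null-text optimization (Mokady et al., 2022): at each DDIM timestep the
+     # unconditional embedding is tuned so the guided reverse step lands on the
+     # recorded forward latent, giving one null embedding per timestep and a
+     # near-exact reconstruction under classifier-free guidance.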
+     def null_invert(self, im, txt, ntxt=None, num_inner_steps=10, early_stop_epsilon=1e-5):
+         assert isinstance(im, PIL.Image.Image)
+
+         scheduler_save = self.model.scheduler
+         self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
+         self.model.scheduler.set_timesteps(self.num_ddim_steps)
+
+         with torch.no_grad():
+             nemb = txt_to_emb(self.model, ntxt) \
+                 if ntxt is not None else txt_to_emb(self.model, "")
+             emb = txt_to_emb(self.model, txt)
+             latent = image2latent(self.model.vae, im)
+
+         # ddim inversion
+         ddim_latents = self.ddim_loop(latent, emb)
+         # nulltext inversion
+         uncond_embeddings = self.null_optimization(
+             ddim_latents, emb, nemb, num_inner_steps, early_stop_epsilon)
+
+         self.model.scheduler = scheduler_save
+         return ddim_latents[-1], uncond_embeddings
+
+     def null_optimization_dual(
+             self, latents0, latents1, emb0, emb1, nemb=None,
+             num_inner_steps=10, epsilon=1e-5):
+
+         # force fp32
+         dtype = latents0[0].dtype
+         uncond_embeddings = nemb.float() if nemb is not None else txt_to_emb(self.model, "").float()
+         cond_embeddings0, cond_embeddings1 = emb0.float(), emb1.float()
+         latents0 = [li.float() for li in latents0]
+         latents1 = [li.float() for li in latents1]
+         self.model.unet.to(torch.float32)
+
+         uncond_embeddings_list = []
+         latent_cur0 = latents0[-1]
+         latent_cur1 = latents1[-1]
+
+         bar = tqdm(total=num_inner_steps * self.num_ddim_steps)
+         for i in range(self.num_ddim_steps):
+             uncond_embeddings = uncond_embeddings.clone().detach()
+             uncond_embeddings.requires_grad = True
+             optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
+
+             latent_prev0 = latents0[len(latents0) - i - 2]
+             latent_prev1 = latents1[len(latents1) - i - 2]
+
+             t = self.model.scheduler.timesteps[i]
+             with torch.no_grad():
+                 noise_pred_cond0 = self.get_noise_pred_single(latent_cur0, t, cond_embeddings0)
+                 noise_pred_cond1 = self.get_noise_pred_single(latent_cur1, t, cond_embeddings1)
+             for j in range(num_inner_steps):
+                 noise_pred_uncond0 = self.get_noise_pred_single(latent_cur0, t, uncond_embeddings)
+                 noise_pred_uncond1 = self.get_noise_pred_single(latent_cur1, t, uncond_embeddings)
+
+                 noise_pred0 = noise_pred_uncond0 + self.guidance_scale*(noise_pred_cond0-noise_pred_uncond0)
+                 noise_pred1 = noise_pred_uncond1 + self.guidance_scale*(noise_pred_cond1-noise_pred_uncond1)
+
+                 latents_prev_rec0 = self.prev_step(noise_pred0, t, latent_cur0)
+                 latents_prev_rec1 = self.prev_step(noise_pred1, t, latent_cur1)
+
+                 loss = nnf.mse_loss(latents_prev_rec0, latent_prev0) + \
+                     nnf.mse_loss(latents_prev_rec1, latent_prev1)
+
+                 optimizer.zero_grad()
+                 loss.backward()
+                 optimizer.step()
+                 loss_item = loss.item()
+                 bar.update()
+                 if loss_item < epsilon + i * 2e-5:
+                     break
+             for j in range(j + 1, num_inner_steps):
+                 bar.update()
+             uncond_embeddings_list.append(uncond_embeddings[:1].detach())
+
+             with torch.no_grad():
+                 context0 = torch.cat([uncond_embeddings, cond_embeddings0])
+                 context1 = torch.cat([uncond_embeddings, cond_embeddings1])
+                 latent_cur0 = self.get_noise_pred(latent_cur0, t, False, context0)
+                 latent_cur1 = self.get_noise_pred(latent_cur1, t, False, context1)
+
+         bar.close()
+
+         uncond_embeddings_list = [ui.to(dtype) for ui in uncond_embeddings_list]
+         self.model.unet.to(dtype)
+         return uncond_embeddings_list
+
+     def null_invert_dual(
+             self, im0, im1, txt0, txt1, ntxt=None,
+             num_inner_steps=10, early_stop_epsilon=1e-5, ):
+         assert isinstance(im0, PIL.Image.Image)
+         assert isinstance(im1, PIL.Image.Image)
+
+         scheduler_save = self.model.scheduler
+         self.model.scheduler = DDIMScheduler.from_config(self.model.scheduler.config)
+         self.model.scheduler.set_timesteps(self.num_ddim_steps)
+
+         with torch.no_grad():
+             nemb = txt_to_emb(self.model, ntxt) \
+                 if ntxt is not None else txt_to_emb(self.model, "")
+             latent0 = image2latent(self.model.vae, im0)
+             latent1 = image2latent(self.model.vae, im1)
+             emb0 = txt_to_emb(self.model, txt0)
+             emb1 = txt_to_emb(self.model, txt1)
+
+         # ddim inversion
+         ddim_latents_0 = self.ddim_loop(latent0, emb0)
+         ddim_latents_1 = self.ddim_loop(latent1, emb1)
+
+         # nulltext inversion
+         nembs = self.null_optimization_dual(
+             ddim_latents_0, ddim_latents_1, emb0, emb1, nemb, num_inner_steps, early_stop_epsilon)
+
+         self.model.scheduler = scheduler_save
+         return ddim_latents_0[-1], ddim_latents_1[-1], nembs
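+     # The dual variant optimizes one shared null embedding for both endpoint
+     # images at every timestep, so interpolated frames can follow a single
+     # consistent unconditional trajectory between the two inversions.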
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ accelerate==0.20.3
+ bitsandbytes==0.42.0
+ datasets==2.14.4
+ diffusers==0.20.1
+ easydict==1.11
+ gradio==4.19.2
+ huggingface_hub==0.19.3
+ moviepy==1.0.3
+ opencv_python==4.7.0.72
+ packaging==23.2
+ pypatchify==0.1.4
+ safetensors==0.3.1
+ tqdm==4.65.0
+ transformers==4.30.1
+ wandb==0.16.3
+ xformers==0.0.17