lunarring committed on
Commit
c958d73
1 Parent(s): 1d2444d
Files changed (5)
  1. gradio_ui.py +492 -0
  2. latent_blending.py +213 -579
  3. movie_util.py +46 -54
  4. stable_diffusion_holder.py +87 -355
  5. utils.py +260 -0
gradio_ui.py ADDED
@@ -0,0 +1,492 @@
1
+ # Copyright 2022 Lunar Ring. All rights reserved.
2
+ # Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ import torch
18
+ torch.backends.cudnn.benchmark = False
19
+ torch.set_grad_enabled(False)
20
+ import numpy as np
21
+ import warnings
22
+ warnings.filterwarnings('ignore')
23
+ import warnings
24
+ from tqdm.auto import tqdm
25
+ from PIL import Image
26
+ from movie_util import MovieSaver, concatenate_movies
27
+ from latent_blending import LatentBlending
28
+ from stable_diffusion_holder import StableDiffusionHolder
29
+ import gradio as gr
30
+ from dotenv import find_dotenv, load_dotenv
31
+ import shutil
32
+ import random
33
+ from utils import get_time, add_frames_linear_interp
34
+ from huggingface_hub import hf_hub_download
35
+
36
+
37
+ class BlendingFrontend():
38
+ def __init__(
39
+ self,
40
+ sdh,
41
+ share=False):
42
+ r"""
43
+ Gradio Helper Class to collect UI data and start latent blending.
44
+ Args:
45
+ sdh:
46
+ StableDiffusionHolder
47
+ share: bool
48
+ Set true to get a shareable gradio link (e.g. for running a remote server)
49
+ """
50
+ self.share = share
51
+
52
+ # UI Defaults
53
+ self.num_inference_steps = 30
54
+ self.depth_strength = 0.25
55
+ self.seed1 = 420
56
+ self.seed2 = 420
57
+ self.prompt1 = ""
58
+ self.prompt2 = ""
59
+ self.negative_prompt = ""
60
+ self.fps = 30
61
+ self.duration_video = 8
62
+ self.t_compute_max_allowed = 10
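+ # Compute budget in seconds for a single transition (exposed below as the 'compute budget' slider)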
63
+
64
+ self.lb = LatentBlending(sdh)
65
+ self.lb.sdh.num_inference_steps = self.num_inference_steps
66
+ self.init_parameters_from_lb()
67
+ self.init_save_dir()
68
+
69
+ # Vars
70
+ self.list_fp_imgs_current = []
71
+ self.recycle_img1 = False
72
+ self.recycle_img2 = False
73
+ self.list_all_segments = []
74
+ self.dp_session = ""
75
+ self.user_id = None
76
+
77
+ def init_parameters_from_lb(self):
78
+ r"""
79
+ Automatically init parameters from latentblending instance
80
+ """
81
+ self.height = self.lb.sdh.height
82
+ self.width = self.lb.sdh.width
83
+ self.guidance_scale = self.lb.guidance_scale
84
+ self.guidance_scale_mid_damper = self.lb.guidance_scale_mid_damper
85
+ self.mid_compression_scaler = self.lb.mid_compression_scaler
86
+ self.branch1_crossfeed_power = self.lb.branch1_crossfeed_power
87
+ self.branch1_crossfeed_range = self.lb.branch1_crossfeed_range
88
+ self.branch1_crossfeed_decay = self.lb.branch1_crossfeed_decay
89
+ self.parental_crossfeed_power = self.lb.parental_crossfeed_power
90
+ self.parental_crossfeed_range = self.lb.parental_crossfeed_range
91
+ self.parental_crossfeed_power_decay = self.lb.parental_crossfeed_power_decay
92
+
93
+ def init_save_dir(self):
94
+ r"""
95
+ Initializes the directory where images and movies are saved.
96
+ You can specify this directory in a ".env" file in your latentblending root, setting
97
+ DIR_OUT='/path/to/saving'
98
+ """
99
+ load_dotenv(find_dotenv(), verbose=False)
100
+ self.dp_out = os.getenv("DIR_OUT")
101
+ if self.dp_out is None:
102
+ self.dp_out = ""
103
+ self.dp_imgs = os.path.join(self.dp_out, "imgs")
104
+ os.makedirs(self.dp_imgs, exist_ok=True)
105
+ self.dp_movies = os.path.join(self.dp_out, "movies")
106
+ os.makedirs(self.dp_movies, exist_ok=True)
107
+ self.save_empty_image()
108
+
109
+ def save_empty_image(self):
110
+ r"""
111
+ Saves an empty/black dummy image.
112
+ """
113
+ self.fp_img_empty = os.path.join(self.dp_imgs, 'empty.jpg')
114
+ Image.fromarray(np.zeros((self.height, self.width, 3), dtype=np.uint8)).save(self.fp_img_empty, quality=5)
115
+
116
+ def randomize_seed1(self):
117
+ r"""
118
+ Randomizes the first seed
119
+ """
120
+ seed = np.random.randint(0, 10000000)
121
+ self.seed1 = int(seed)
122
+ print(f"randomize_seed1: new seed = {self.seed1}")
123
+ return seed
124
+
125
+ def randomize_seed2(self):
126
+ r"""
127
+ Randomizes the second seed
128
+ """
129
+ seed = np.random.randint(0, 10000000)
130
+ self.seed2 = int(seed)
131
+ print(f"randomize_seed2: new seed = {self.seed2}")
132
+ return seed
133
+
134
+ def setup_lb(self, list_ui_vals):
135
+ r"""
136
+ Sets all parameters from the UI. Since gradio does not support passing dictionaries,
137
+ we have to instead pass keys (list_ui_keys, global) and values (list_ui_vals)
138
+ """
139
+ # Collect latent blending variables
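+ # list_ui_keys is built in the __main__ block below and pairs positionally with list_ui_vals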
140
+ self.lb.set_width(list_ui_vals[list_ui_keys.index('width')])
141
+ self.lb.set_height(list_ui_vals[list_ui_keys.index('height')])
142
+ self.lb.set_prompt1(list_ui_vals[list_ui_keys.index('prompt1')])
143
+ self.lb.set_prompt2(list_ui_vals[list_ui_keys.index('prompt2')])
144
+ self.lb.set_negative_prompt(list_ui_vals[list_ui_keys.index('negative_prompt')])
145
+ self.lb.guidance_scale = list_ui_vals[list_ui_keys.index('guidance_scale')]
146
+ self.lb.guidance_scale_mid_damper = list_ui_vals[list_ui_keys.index('guidance_scale_mid_damper')]
147
+ self.t_compute_max_allowed = list_ui_vals[list_ui_keys.index('duration_compute')]
148
+ self.lb.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
149
+ self.lb.sdh.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
150
+ self.duration_video = list_ui_vals[list_ui_keys.index('duration_video')]
151
+ self.lb.seed1 = list_ui_vals[list_ui_keys.index('seed1')]
152
+ self.lb.seed2 = list_ui_vals[list_ui_keys.index('seed2')]
153
+ self.lb.branch1_crossfeed_power = list_ui_vals[list_ui_keys.index('branch1_crossfeed_power')]
154
+ self.lb.branch1_crossfeed_range = list_ui_vals[list_ui_keys.index('branch1_crossfeed_range')]
155
+ self.lb.branch1_crossfeed_decay = list_ui_vals[list_ui_keys.index('branch1_crossfeed_decay')]
156
+ self.lb.parental_crossfeed_power = list_ui_vals[list_ui_keys.index('parental_crossfeed_power')]
157
+ self.lb.parental_crossfeed_range = list_ui_vals[list_ui_keys.index('parental_crossfeed_range')]
158
+ self.lb.parental_crossfeed_power_decay = list_ui_vals[list_ui_keys.index('parental_crossfeed_power_decay')]
159
+ self.num_inference_steps = list_ui_vals[list_ui_keys.index('num_inference_steps')]
160
+ self.depth_strength = list_ui_vals[list_ui_keys.index('depth_strength')]
161
+
162
+ if len(list_ui_vals[list_ui_keys.index('user_id')]) > 1:
163
+ self.user_id = list_ui_vals[list_ui_keys.index('user_id')]
164
+ else:
165
+ # generate new user id
166
+ self.user_id = ''.join((random.choice('ABCDEFGHIJKLMNOPQRSTUVWXYZ') for i in range(8)))
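+ # The random 8-letter id is embedded in all per-user filenames (images, latents, movies), keeping concurrent users apart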
167
+ print(f"made new user_id: {self.user_id} at {get_time('second')}")
168
+
169
+ def save_latents(self, fp_latents, list_latents):
170
+ r"""
171
+ Saves a latent trajectory on disk, in npy format.
172
+ """
173
+ list_latents_cpu = [l.cpu().numpy() for l in list_latents]
174
+ np.save(fp_latents, list_latents_cpu)
175
+
176
+ def load_latents(self, fp_latents):
177
+ r"""
178
+ Loads a latent trajectory from disk, converts to torch tensor.
179
+ """
180
+ list_latents_cpu = np.load(fp_latents)
181
+ list_latents = [torch.from_numpy(l).to(self.lb.device) for l in list_latents_cpu]
182
+ return list_latents
183
+
184
+ def compute_img1(self, *args):
185
+ r"""
186
+ Computes the first transition image and returns it for display.
187
+ Sets all other transition images and the last image to empty (they become obsolete with this operation)
188
+ """
189
+ list_ui_vals = args
190
+ self.setup_lb(list_ui_vals)
191
+ fp_img1 = os.path.join(self.dp_imgs, f"img1_{self.user_id}")
192
+ img1 = Image.fromarray(self.lb.compute_latents1(return_image=True))
193
+ img1.save(fp_img1 + ".jpg")
194
+ self.save_latents(fp_img1 + ".npy", self.lb.tree_latents[0])
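+ # Persist the latent trajectory so compute_img2 / compute_transition can reload and recycle it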
195
+ self.recycle_img1 = True
196
+ self.recycle_img2 = False
197
+ return [fp_img1 + ".jpg", self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
198
+
199
+ def compute_img2(self, *args):
200
+ r"""
201
+ Computes the last transition image and returns it for display.
202
+ Sets all other transition images to empty (as they are obsolete with this operation)
203
+ """
204
+ if not os.path.isfile(os.path.join(self.dp_imgs, f"img1_{self.user_id}.jpg")): # don't do anything
205
+ return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, self.user_id]
206
+ list_ui_vals = args
207
+ self.setup_lb(list_ui_vals)
208
+
209
+ self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
210
+ fp_img2 = os.path.join(self.dp_imgs, f"img2_{self.user_id}")
211
+ img2 = Image.fromarray(self.lb.compute_latents2(return_image=True))
212
+ img2.save(fp_img2 + '.jpg')
213
+ self.save_latents(fp_img2 + ".npy", self.lb.tree_latents[-1])
214
+ self.recycle_img2 = True
215
+ # fixme save seeds. change filenames?
216
+ return [self.fp_img_empty, self.fp_img_empty, self.fp_img_empty, fp_img2 + ".jpg", self.user_id]
217
+
218
+ def compute_transition(self, *args):
219
+ r"""
220
+ Computes transition images and movie.
221
+ """
222
+ list_ui_vals = args
223
+ self.setup_lb(list_ui_vals)
224
+ print("STARTING TRANSITION...")
225
+ fixed_seeds = [self.seed1, self.seed2]
226
+ # Inject loaded latents (other user interference)
227
+ self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
228
+ self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
229
+ imgs_transition = self.lb.run_transition(
230
+ recycle_img1=self.recycle_img1,
231
+ recycle_img2=self.recycle_img2,
232
+ num_inference_steps=self.num_inference_steps,
233
+ depth_strength=self.depth_strength,
234
+ t_compute_max_allowed=self.t_compute_max_allowed,
235
+ fixed_seeds=fixed_seeds)
236
+ print(f"Latent Blending pass finished ({get_time('second')}). Resulted in {len(imgs_transition)} images")
237
+
238
+ # Subselect three preview images
239
+ idx_img_prev = np.round(np.linspace(0, len(imgs_transition) - 1, 5)[1:-1]).astype(np.int32)
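+ # linspace over 5 points minus the two endpoints yields 3 evenly spaced interior preview indices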
240
+
241
+ list_imgs_preview = []
242
+ for j in idx_img_prev:
243
+ list_imgs_preview.append(Image.fromarray(imgs_transition[j]))
244
+
245
+ # Save the preview imgs as jpgs on disk so we are not sending uncompressed data around
246
+ current_timestamp = get_time('second')
247
+ self.list_fp_imgs_current = []
248
+ for i in range(len(list_imgs_preview)):
249
+ fp_img = os.path.join(self.dp_imgs, f"img_preview_{i}_{current_timestamp}.jpg")
250
+ list_imgs_preview[i].save(fp_img)
251
+ self.list_fp_imgs_current.append(fp_img)
252
+ # Insert cheap frames for the movie
253
+ imgs_transition_ext = add_frames_linear_interp(imgs_transition, self.duration_video, self.fps)
254
+
255
+ # Save as movie
256
+ self.fp_movie = self.get_fp_video_last()
257
+ if os.path.isfile(self.fp_movie):
258
+ os.remove(self.fp_movie)
259
+ ms = MovieSaver(self.fp_movie, fps=self.fps)
260
+ for img in tqdm(imgs_transition_ext):
261
+ ms.write_frame(img)
262
+ ms.finalize()
263
+ print("DONE SAVING MOVIE! SENDING BACK...")
264
+
265
+ # Assemble output, updating the preview images and the movie
266
+ list_return = self.list_fp_imgs_current + [self.fp_movie]
267
+ return list_return
268
+
269
+ def stack_forward(self, prompt2, seed2):
270
+ r"""
271
+ Allows generating multi-segment movies. Sets last image -> first image with all
272
+ relevant parameters.
273
+ """
274
+ # Save preview images, prompts and seeds into dictionary for stacking
275
+ if len(self.list_all_segments) == 0:
276
+ timestamp_session = get_time('second')
277
+ self.dp_session = os.path.join(self.dp_out, f"session_{timestamp_session}")
278
+ os.makedirs(self.dp_session)
279
+
280
+ idx_segment = len(self.list_all_segments)
281
+ dp_segment = os.path.join(self.dp_session, f"segment_{str(idx_segment).zfill(3)}")
282
+
283
+ self.list_all_segments.append(dp_segment)
284
+ self.lb.write_imgs_transition(dp_segment)
285
+
286
+ fp_movie_last = self.get_fp_video_last()
287
+ fp_movie_next = self.get_fp_video_next()
288
+
289
+ shutil.copyfile(fp_movie_last, fp_movie_next)
290
+
291
+ self.lb.tree_latents[0] = self.load_latents(os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
292
+ self.lb.tree_latents[-1] = self.load_latents(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"))
293
+ self.lb.swap_forward()
294
+
295
+ shutil.copyfile(os.path.join(self.dp_imgs, f"img2_{self.user_id}.npy"), os.path.join(self.dp_imgs, f"img1_{self.user_id}.npy"))
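+ # The last image's latents now serve as the first image of the next segment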
296
+ fp_multi = self.multi_concat()
297
+ list_out = [fp_multi]
298
+
299
+ list_out.extend([os.path.join(self.dp_imgs, f"img2_{self.user_id}.jpg")])
300
+ list_out.extend([self.fp_img_empty] * 4)
301
+ list_out.append(gr.update(interactive=False, value=prompt2))
302
+ list_out.append(gr.update(interactive=False, value=seed2))
303
+ list_out.append("")
304
+ list_out.append(np.random.randint(0, 10000000))
305
+ print(f"stack_forward: fp_multi {fp_multi}")
306
+ return list_out
307
+
308
+ def multi_concat(self):
309
+ r"""
310
+ Concatenates all stacked segments into one long movie.
311
+ """
312
+ list_fp_movies = self.get_fp_video_all()
313
+ # Concatenate movies and save
314
+ fp_final = os.path.join(self.dp_session, f"concat_{self.user_id}.mp4")
315
+ concatenate_movies(fp_final, list_fp_movies)
316
+ return fp_final
317
+
318
+ def get_fp_video_all(self):
319
+ r"""
320
+ Collects all stacked movie segments.
321
+ """
322
+ list_all = os.listdir(self.dp_movies)
323
+ str_beg = f"movie_{self.user_id}_"
324
+ list_user = [l for l in list_all if str_beg in l]
325
+ list_user.sort()
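+ # Zero-padded segment indices make the lexicographic sort chronological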
326
+ list_user = [os.path.join(self.dp_movies, l) for l in list_user]
327
+ return list_user
328
+
329
+ def get_fp_video_next(self):
330
+ r"""
331
+ Gets the filepath of the next movie segment.
332
+ """
333
+ list_videos = self.get_fp_video_all()
334
+ if len(list_videos) == 0:
335
+ idx_next = 0
336
+ else:
337
+ idx_next = len(list_videos)
338
+ fp_video_next = os.path.join(self.dp_movies, f"movie_{self.user_id}_{str(idx_next).zfill(3)}.mp4")
339
+ return fp_video_next
340
+
341
+ def get_fp_video_last(self):
342
+ r"""
343
+ Gets the filepath of the last saved video.
344
+ """
345
+ fp_video_last = os.path.join(self.dp_movies, f"last_{self.user_id}.mp4")
346
+ return fp_video_last
347
+
348
+
349
+ if __name__ == "__main__":
350
+ fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1-base", filename="v2-1_512-ema-pruned.ckpt")
351
+ # fp_ckpt = hf_hub_download(repo_id="stabilityai/stable-diffusion-2-1", filename="v2-1_768-ema-pruned.ckpt")
352
+ bf = BlendingFrontend(StableDiffusionHolder(fp_ckpt))
353
+ # self = BlendingFrontend(None)
354
+
355
+ with gr.Blocks() as demo:
356
+ with gr.Row():
357
+ prompt1 = gr.Textbox(label="prompt 1")
358
+ prompt2 = gr.Textbox(label="prompt 2")
359
+
360
+ with gr.Row():
361
+ duration_compute = gr.Slider(5, 200, bf.t_compute_max_allowed, step=1, label='compute budget', interactive=True)
362
+ duration_video = gr.Slider(1, 100, bf.duration_video, step=0.1, label='video duration', interactive=True)
363
+ height = gr.Slider(256, 2048, bf.height, step=128, label='height', interactive=True)
364
+ width = gr.Slider(256, 2048, bf.width, step=128, label='width', interactive=True)
365
+
366
+ with gr.Accordion("Advanced Settings (click to expand)", open=False):
367
+
368
+ with gr.Accordion("Diffusion settings", open=True):
369
+ with gr.Row():
370
+ num_inference_steps = gr.Slider(5, 100, bf.num_inference_steps, step=1, label='num_inference_steps', interactive=True)
371
+ guidance_scale = gr.Slider(1, 25, bf.guidance_scale, step=0.1, label='guidance_scale', interactive=True)
372
+ negative_prompt = gr.Textbox(label="negative prompt")
373
+
374
+ with gr.Accordion("Seed control: adjust seeds for first and last images", open=True):
375
+ with gr.Row():
376
+ b_newseed1 = gr.Button("randomize seed 1", variant='secondary')
377
+ seed1 = gr.Number(bf.seed1, label="seed 1", interactive=True)
378
+ seed2 = gr.Number(bf.seed2, label="seed 2", interactive=True)
379
+ b_newseed2 = gr.Button("randomize seed 2", variant='secondary')
380
+
381
+ with gr.Accordion("Last image crossfeeding.", open=True):
382
+ with gr.Row():
383
+ branch1_crossfeed_power = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_power, step=0.01, label='branch1 crossfeed power', interactive=True)
384
+ branch1_crossfeed_range = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_range, step=0.01, label='branch1 crossfeed range', interactive=True)
385
+ branch1_crossfeed_decay = gr.Slider(0.0, 1.0, bf.branch1_crossfeed_decay, step=0.01, label='branch1 crossfeed decay', interactive=True)
386
+
387
+ with gr.Accordion("Transition settings", open=True):
388
+ with gr.Row():
389
+ parental_crossfeed_power = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power, step=0.01, label='parental crossfeed power', interactive=True)
390
+ parental_crossfeed_range = gr.Slider(0.0, 1.0, bf.parental_crossfeed_range, step=0.01, label='parental crossfeed range', interactive=True)
391
+ parental_crossfeed_power_decay = gr.Slider(0.0, 1.0, bf.parental_crossfeed_power_decay, step=0.01, label='parental crossfeed decay', interactive=True)
392
+ with gr.Row():
393
+ depth_strength = gr.Slider(0.01, 0.99, bf.depth_strength, step=0.01, label='depth_strength', interactive=True)
394
+ guidance_scale_mid_damper = gr.Slider(0.01, 2.0, bf.guidance_scale_mid_damper, step=0.01, label='guidance_scale_mid_damper', interactive=True)
395
+
396
+ with gr.Row():
397
+ b_compute1 = gr.Button('compute first image', variant='primary')
398
+ b_compute_transition = gr.Button('compute transition', variant='primary')
399
+ b_compute2 = gr.Button('compute last image', variant='primary')
400
+
401
+ with gr.Row():
402
+ img1 = gr.Image(label="1/5")
403
+ img2 = gr.Image(label="2/5", show_progress=False)
404
+ img3 = gr.Image(label="3/5", show_progress=False)
405
+ img4 = gr.Image(label="4/5", show_progress=False)
406
+ img5 = gr.Image(label="5/5")
407
+
408
+ with gr.Row():
409
+ vid_single = gr.Video(label="current single trans")
410
+ vid_multi = gr.Video(label="concatenated multi trans")
411
+
412
+ with gr.Row():
413
+ b_stackforward = gr.Button('append last movie segment (left) to multi movie (right)', variant='primary')
414
+
415
+ with gr.Row():
416
+ gr.Markdown(
417
+ """
418
+ # Parameters
419
+ ## Main
420
+ - compute budget: set your waiting time for the transition. high values = better quality
421
+ - video duration: seconds per segment
422
+ - height/width: in pixels
423
+
424
+ ## Diffusion settings
425
+ - num_inference_steps: number of diffusion steps
426
+ - guidance_scale: latent blending seems to prefer lower values here
427
+ - negative prompt: enter negative prompt here, applied for all images
428
+
429
+ ## Last image crossfeeding
430
+ - branch1_crossfeed_power: Controls the level of cross-feeding between the first and last image branch. For preserving structures.
431
+ - branch1_crossfeed_range: Sets the duration of active crossfeed during development. High values enforce strong structural similarity.
432
+ - branch1_crossfeed_decay: Sets decay for branch1_crossfeed_power. Lower values make the decay stronger across the range.
433
+
434
+ ## Transition settings
435
+ - parental_crossfeed_power: Similar to branch1_crossfeed_power, but applied to the images within the transition.
436
+ - parental_crossfeed_range: Similar to branch1_crossfeed_range, but applied to the images within the transition.
437
+ - parental_crossfeed_power_decay: Similar to branch1_crossfeed_decay, but applied to the images within the transition.
438
+ - depth_strength: Determines when the blending process will begin in terms of diffusion steps. Lower values are more inventive but can cause motion.
439
+ - guidance_scale_mid_damper: Decreases the guidance scale in the middle of a transition.
440
+ """)
441
+
442
+ with gr.Row():
443
+ user_id = gr.Textbox(label="user id", interactive=False)
444
+
445
+ # Collect all UI elements in a list to easily pass as inputs in gradio
446
+ dict_ui_elem = {}
447
+ dict_ui_elem["prompt1"] = prompt1
448
+ dict_ui_elem["negative_prompt"] = negative_prompt
449
+ dict_ui_elem["prompt2"] = prompt2
450
+
451
+ dict_ui_elem["duration_compute"] = duration_compute
452
+ dict_ui_elem["duration_video"] = duration_video
453
+ dict_ui_elem["height"] = height
454
+ dict_ui_elem["width"] = width
455
+
456
+ dict_ui_elem["depth_strength"] = depth_strength
457
+ dict_ui_elem["branch1_crossfeed_power"] = branch1_crossfeed_power
458
+ dict_ui_elem["branch1_crossfeed_range"] = branch1_crossfeed_range
459
+ dict_ui_elem["branch1_crossfeed_decay"] = branch1_crossfeed_decay
460
+
461
+ dict_ui_elem["num_inference_steps"] = num_inference_steps
462
+ dict_ui_elem["guidance_scale"] = guidance_scale
463
+ dict_ui_elem["guidance_scale_mid_damper"] = guidance_scale_mid_damper
464
+ dict_ui_elem["seed1"] = seed1
465
+ dict_ui_elem["seed2"] = seed2
466
+
467
+ dict_ui_elem["parental_crossfeed_range"] = parental_crossfeed_range
468
+ dict_ui_elem["parental_crossfeed_power"] = parental_crossfeed_power
469
+ dict_ui_elem["parental_crossfeed_power_decay"] = parental_crossfeed_power_decay
470
+ dict_ui_elem["user_id"] = user_id
471
+
472
+ # Convert to list, as gradio doesn't seem to accept dicts
473
+ list_ui_vals = []
474
+ list_ui_keys = []
475
+ for k in dict_ui_elem.keys():
476
+ list_ui_vals.append(dict_ui_elem[k])
477
+ list_ui_keys.append(k)
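+ # Keys and values keep the same ordering, so setup_lb can look up a value via list_ui_keys.index(key)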
478
+ bf.list_ui_keys = list_ui_keys
479
+
480
+ b_newseed1.click(bf.randomize_seed1, outputs=seed1)
481
+ b_newseed2.click(bf.randomize_seed2, outputs=seed2)
482
+ b_compute1.click(bf.compute_img1, inputs=list_ui_vals, outputs=[img1, img2, img3, img4, img5, user_id])
483
+ b_compute2.click(bf.compute_img2, inputs=list_ui_vals, outputs=[img2, img3, img4, img5, user_id])
484
+ b_compute_transition.click(bf.compute_transition,
485
+ inputs=list_ui_vals,
486
+ outputs=[img2, img3, img4, vid_single])
487
+
488
+ b_stackforward.click(bf.stack_forward,
489
+ inputs=[prompt2, seed2],
490
+ outputs=[vid_multi, img1, img2, img3, img4, img5, prompt1, seed1, prompt2])
491
+
492
+ demo.launch(share=bf.share, inbrowser=True, inline=False)
latent_blending.py CHANGED
@@ -13,48 +13,31 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- import os, sys
17
- dp_git = "/home/lugo/git/"
18
- sys.path.append('util')
19
- # sys.path.append('../stablediffusion/ldm')
20
  import torch
21
  torch.backends.cudnn.benchmark = False
 
22
  import numpy as np
23
  import warnings
24
  warnings.filterwarnings('ignore')
25
  import time
26
- import subprocess
27
  import warnings
28
- import torch
29
  from tqdm.auto import tqdm
30
  from PIL import Image
31
- # import matplotlib.pyplot as plt
32
- import torch
33
  from movie_util import MovieSaver
34
- import datetime
35
- from typing import Callable, List, Optional, Union
36
- import inspect
37
- from threading import Thread
38
- torch.set_grad_enabled(False)
39
- from omegaconf import OmegaConf
40
- from torch import autocast
41
- from contextlib import nullcontext
42
-
43
- from ldm.models.diffusion.ddim import DDIMSampler
44
- from ldm.util import instantiate_from_config
45
  from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
46
- from stable_diffusion_holder import StableDiffusionHolder
47
- import yaml
48
  import lpips
49
- #%%
 
 
50
  class LatentBlending():
51
  def __init__(
52
- self,
53
  sdh: None,
54
  guidance_scale: float = 4,
55
  guidance_scale_mid_damper: float = 0.5,
56
- mid_compression_scaler: float = 1.2,
57
- ):
58
  r"""
59
  Initializes the latent blending class.
60
  Args:
@@ -71,9 +54,10 @@ class LatentBlending():
71
  Increases the sampling density in the middle (where most changes happen). Higher value
72
  imply more values in the middle. However the inflection point can occur outside the middle,
73
  thus high values can give rough transitions. Values around 2 should be fine.
74
-
75
  """
76
- assert guidance_scale_mid_damper>0 and guidance_scale_mid_damper<=1.0, f"guidance_scale_mid_damper neees to be in interval (0,1], you provided {guidance_scale_mid_damper}"
 
 
77
 
78
  self.sdh = sdh
79
  self.device = self.sdh.device
@@ -81,20 +65,20 @@ class LatentBlending():
81
  self.height = self.sdh.height
82
  self.guidance_scale_mid_damper = guidance_scale_mid_damper
83
  self.mid_compression_scaler = mid_compression_scaler
84
- self.seed1 = 0
85
  self.seed2 = 0
86
-
87
  # Initialize vars
88
  self.prompt1 = ""
89
  self.prompt2 = ""
90
  self.negative_prompt = ""
91
-
92
  self.tree_latents = [None, None]
93
  self.tree_fracts = None
94
  self.idx_injection = []
95
  self.tree_status = None
96
  self.tree_final_imgs = []
97
-
98
  self.list_nmb_branches_prev = []
99
  self.list_injection_idx_prev = []
100
  self.text_embedding1 = None
@@ -106,25 +90,23 @@ class LatentBlending():
106
  self.noise_level_upscaling = 20
107
  self.list_injection_idx = None
108
  self.list_nmb_branches = None
109
-
110
  # Mixing parameters
111
  self.branch1_crossfeed_power = 0.1
112
  self.branch1_crossfeed_range = 0.6
113
  self.branch1_crossfeed_decay = 0.8
114
-
115
  self.parental_crossfeed_power = 0.1
116
  self.parental_crossfeed_range = 0.8
117
- self.parental_crossfeed_power_decay = 0.8
118
-
119
  self.set_guidance_scale(guidance_scale)
120
  self.init_mode()
121
  self.multi_transition_img_first = None
122
  self.multi_transition_img_last = None
123
  self.dt_per_diff = 0
124
  self.spatial_mask = None
125
-
126
  self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
127
-
128
 
129
  def init_mode(self):
130
  r"""
@@ -138,7 +120,7 @@ class LatentBlending():
138
  self.mode = 'inpaint'
139
  else:
140
  self.mode = 'standard'
141
-
142
  def set_guidance_scale(self, guidance_scale):
143
  r"""
144
  sets the guidance scale.
@@ -146,25 +128,24 @@ class LatentBlending():
146
  self.guidance_scale_base = guidance_scale
147
  self.guidance_scale = guidance_scale
148
  self.sdh.guidance_scale = guidance_scale
149
-
150
  def set_negative_prompt(self, negative_prompt):
151
  r"""Set the negative prompt. Currently only one negative prompt is supported
152
  """
153
  self.negative_prompt = negative_prompt
154
  self.sdh.set_negative_prompt(negative_prompt)
155
-
156
  def set_guidance_mid_dampening(self, fract_mixing):
157
  r"""
158
- Tunes the guidance scale down as a linear function of fract_mixing,
159
  towards 0.5 the minimum will be reached.
160
  """
161
- mid_factor = 1 - np.abs(fract_mixing - 0.5)/ 0.5
162
- max_guidance_reduction = self.guidance_scale_base * (1-self.guidance_scale_mid_damper) - 1
163
- guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction*mid_factor
164
  self.guidance_scale = guidance_scale_effective
165
  self.sdh.guidance_scale = guidance_scale_effective
166
 
167
-
168
  def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
169
  r"""
170
  Sets the crossfeed parameters for the first branch to the last branch.
@@ -179,14 +160,13 @@ class LatentBlending():
179
  self.branch1_crossfeed_power = np.clip(crossfeed_power, 0, 1)
180
  self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
181
  self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
182
-
183
-
184
  def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
185
  r"""
186
  Sets the crossfeed parameters for all transition images (within the first and last branch).
187
  Args:
188
  crossfeed_power: float [0,1]
189
- Controls the level of cross-feeding from the parental branches
190
  crossfeed_range: float [0,1]
191
  Sets the duration of active crossfeed during development.
192
  crossfeed_decay: float [0,1]
@@ -196,7 +176,6 @@ class LatentBlending():
196
  self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
197
  self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
198
 
199
-
200
  def set_prompt1(self, prompt: str):
201
  r"""
202
  Sets the first prompt (for the first keyframe) including text embeddings.
@@ -207,8 +186,7 @@ class LatentBlending():
207
  prompt = prompt.replace("_", " ")
208
  self.prompt1 = prompt
209
  self.text_embedding1 = self.get_text_embeddings(self.prompt1)
210
-
211
-
212
  def set_prompt2(self, prompt: str):
213
  r"""
214
  Sets the second prompt (for the second keyframe) including text embeddings.
@@ -219,7 +197,7 @@ class LatentBlending():
219
  prompt = prompt.replace("_", " ")
220
  self.prompt2 = prompt
221
  self.text_embedding2 = self.get_text_embeddings(self.prompt2)
222
-
223
  def set_image1(self, image: Image):
224
  r"""
225
  Sets the first image (keyframe), relevant for the upscaling model transitions.
@@ -227,7 +205,7 @@ class LatentBlending():
227
  image: Image
228
  """
229
  self.image1_lowres = image
230
-
231
  def set_image2(self, image: Image):
232
  r"""
233
  Sets the second image (keyframe), relevant for the upscaling model transitions.
@@ -235,17 +213,16 @@ class LatentBlending():
235
  image: Image
236
  """
237
  self.image2_lowres = image
238
-
239
  def run_transition(
240
  self,
241
- recycle_img1: Optional[bool] = False,
242
- recycle_img2: Optional[bool] = False,
243
  num_inference_steps: Optional[int] = 30,
244
  depth_strength: Optional[float] = 0.3,
245
  t_compute_max_allowed: Optional[float] = None,
246
  nmb_max_branches: Optional[int] = None,
247
- fixed_seeds: Optional[List[int]] = None,
248
- ):
249
  r"""
250
  Function for computing transitions.
251
  Returns a list of transition images using spherical latent blending.
@@ -257,79 +234,77 @@ class LatentBlending():
257
  num_inference_steps:
258
  Number of diffusion steps. Higher values will take more compute time.
259
  depth_strength:
260
- Determines how deep the first injection will happen.
261
  Deeper injections will cause (unwanted) formation of new structures,
262
  more shallow values will go into alpha-blendy land.
263
  t_compute_max_allowed:
264
- Either provide t_compute_max_allowed or nmb_max_branches.
265
- The maximum time allowed for computation. Higher values give better results but take longer.
266
  nmb_max_branches: int
267
  Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
268
- results. Use this if you want to have controllable results independent
269
  of your computer.
270
  fixed_seeds: Optional[List[int)]:
271
  You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
272
  Otherwise random seeds will be taken.
273
-
274
  """
275
-
276
  # Sanity checks first
277
  assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
278
  assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
279
-
280
  # Random seeds
281
  if fixed_seeds is not None:
282
  if fixed_seeds == 'randomize':
283
  fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
284
  else:
285
- assert len(fixed_seeds)==2, "Supply a list with len = 2"
286
-
287
  self.seed1 = fixed_seeds[0]
288
  self.seed2 = fixed_seeds[1]
289
-
290
  # Ensure correct num_inference_steps in holder
291
  self.num_inference_steps = num_inference_steps
292
  self.sdh.num_inference_steps = num_inference_steps
293
-
294
  # Compute / Recycle first image
295
  if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
296
  list_latents1 = self.compute_latents1()
297
  else:
298
  list_latents1 = self.tree_latents[0]
299
-
300
  # Compute / Recycle last image
301
  if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
302
  list_latents2 = self.compute_latents2()
303
  else:
304
  list_latents2 = self.tree_latents[-1]
305
-
306
  # Reset the tree, injecting the edge latents1/2 we just generated/recycled
307
- self.tree_latents = [list_latents1, list_latents2]
308
  self.tree_fracts = [0.0, 1.0]
309
  self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
310
  self.tree_idx_injection = [0, 0]
311
-
312
  # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
313
  self.spatial_mask = None
314
-
315
  # Set up branching scheme (dependent on provided compute time)
316
  list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
317
 
318
- # Run iteratively, starting with the longest trajectory.
319
  # Always inserting new branches where they are needed most according to image similarity
320
  for s_idx in tqdm(range(len(list_idx_injection))):
321
  nmb_stems = list_nmb_stems[s_idx]
322
  idx_injection = list_idx_injection[s_idx]
323
-
324
  for i in range(nmb_stems):
325
  fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
326
  self.set_guidance_mid_dampening(fract_mixing)
327
  list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
328
  self.insert_into_tree(fract_mixing, idx_injection, list_latents)
329
  # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
330
-
331
  return self.tree_final_imgs
332
-
333
 
334
  def compute_latents1(self, return_image=False):
335
  r"""
@@ -343,18 +318,17 @@ class LatentBlending():
343
  t0 = time.time()
344
  latents_start = self.get_noise(self.seed1)
345
  list_latents1 = self.run_diffusion(
346
- list_conditionings,
347
- latents_start = latents_start,
348
- idx_start = 0
349
- )
350
  t1 = time.time()
351
- self.dt_per_diff = (t1-t0) / self.num_inference_steps
352
  self.tree_latents[0] = list_latents1
353
  if return_image:
354
  return self.sdh.latent2image(list_latents1[-1])
355
  else:
356
  return list_latents1
357
-
358
  def compute_latents2(self, return_image=False):
359
  r"""
360
  Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
@@ -368,28 +342,26 @@ class LatentBlending():
368
  # Influence from branch1
369
  if self.branch1_crossfeed_power > 0.0:
370
  # Set up the mixing_coeffs
371
- idx_mixing_stop = int(round(self.num_inference_steps*self.branch1_crossfeed_range))
372
- mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power*self.branch1_crossfeed_decay, idx_mixing_stop))
373
- mixing_coeffs.extend((self.num_inference_steps-idx_mixing_stop)*[0])
374
  list_latents_mixing = self.tree_latents[0]
375
  list_latents2 = self.run_diffusion(
376
- list_conditionings,
377
- latents_start = latents_start,
378
- idx_start = 0,
379
- list_latents_mixing = list_latents_mixing,
380
- mixing_coeffs = mixing_coeffs
381
- )
382
  else:
383
  list_latents2 = self.run_diffusion(list_conditionings, latents_start)
384
  self.tree_latents[-1] = list_latents2
385
-
386
  if return_image:
387
  return self.sdh.latent2image(list_latents2[-1])
388
  else:
389
- return list_latents2
390
 
391
-
392
- def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
393
  r"""
394
  Runs a diffusion trajectory, using the latents from the respective parents
395
  Args:
@@ -403,9 +375,9 @@ class LatentBlending():
403
  the index in terms of diffusion steps, where the next insertion will start.
404
  """
405
  list_conditionings = self.get_mixed_conditioning(fract_mixing)
406
- fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
407
  # idx_reversed = self.num_inference_steps - idx_injection
408
-
409
  list_latents_parental_mix = []
410
  for i in range(self.num_inference_steps):
411
  latents_p1 = self.tree_latents[b_parent1][i]
@@ -416,22 +388,19 @@ class LatentBlending():
416
  latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
417
  list_latents_parental_mix.append(latents_parental)
418
 
419
- idx_mixing_stop = int(round(self.num_inference_steps*self.parental_crossfeed_range))
420
- mixing_coeffs = idx_injection*[self.parental_crossfeed_power]
421
  nmb_mixing = idx_mixing_stop - idx_injection
422
  if nmb_mixing > 0:
423
- mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power*self.parental_crossfeed_power_decay, nmb_mixing)))
424
- mixing_coeffs.extend((self.num_inference_steps-len(mixing_coeffs))*[0])
425
-
426
- latents_start = list_latents_parental_mix[idx_injection-1]
427
  list_latents = self.run_diffusion(
428
- list_conditionings,
429
- latents_start = latents_start,
430
- idx_start = idx_injection,
431
- list_latents_mixing = list_latents_parental_mix,
432
- mixing_coeffs = mixing_coeffs
433
- )
434
-
435
  return list_latents
436
 
437
  def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
@@ -441,48 +410,46 @@ class LatentBlending():
441
  Either provide t_compute_max_allowed or nmb_max_branches
442
  Args:
443
  depth_strength:
444
- Determines how deep the first injection will happen.
445
  Deeper injections will cause (unwanted) formation of new structures,
446
  more shallow values will go into alpha-blendy land.
447
  t_compute_max_allowed: float
448
  The maximum time allowed for computation. Higher values give better results
449
- but take longer. Use this if you want to fix your waiting time for the results.
450
  nmb_max_branches: int
451
  The maximum number of branches to be computed. Higher values give better
452
- results. Use this if you want to have controllable results independent
453
  of your computer.
454
  """
455
- idx_injection_base = int(round(self.num_inference_steps*depth_strength))
456
- list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps-1, 3)
457
  list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
458
  t_compute = 0
459
-
460
  if nmb_max_branches is None:
461
  assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
462
  stop_criterion = "t_compute_max_allowed"
463
  elif t_compute_max_allowed is None:
464
  assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
465
  stop_criterion = "nmb_max_branches"
466
- nmb_max_branches -= 2 # discounting the outer frames
467
  else:
468
  raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
469
-
470
  stop_criterion_reached = False
471
  is_first_iteration = True
472
-
473
  while not stop_criterion_reached:
474
  list_compute_steps = self.num_inference_steps - list_idx_injection
475
  list_compute_steps *= list_nmb_stems
476
- t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15*np.sum(list_nmb_stems)
477
  increase_done = False
478
- for s_idx in range(len(list_nmb_stems)-1):
479
- if list_nmb_stems[s_idx+1] / list_nmb_stems[s_idx] >= 2:
480
  list_nmb_stems[s_idx] += 1
481
  increase_done = True
482
  break
483
  if not increase_done:
484
  list_nmb_stems[-1] += 1
485
-
486
  if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
487
  stop_criterion_reached = True
488
  elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
@@ -493,7 +460,7 @@ class LatentBlending():
493
  list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
494
  else:
495
  is_first_iteration = False
496
-
497
  # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
498
  return list_idx_injection, list_nmb_stems
499
 
@@ -508,13 +475,13 @@ class LatentBlending():
508
  """
509
  # get_lpips_similarity
510
  similarities = []
511
- for i in range(len(self.tree_final_imgs)-1):
512
- similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i+1]))
513
  b_closest1 = np.argmax(similarities)
514
- b_closest2 = b_closest1+1
515
  fract_closest1 = self.tree_fracts[b_closest1]
516
  fract_closest2 = self.tree_fracts[b_closest2]
517
-
518
  # Ensure that the parents are indeed older!
519
  b_parent1 = b_closest1
520
  while True:
@@ -522,23 +489,15 @@ class LatentBlending():
522
  break
523
  else:
524
  b_parent1 -= 1
525
-
526
  b_parent2 = b_closest2
527
  while True:
528
  if self.tree_idx_injection[b_parent2] < idx_injection:
529
  break
530
  else:
531
  b_parent2 += 1
532
-
533
- # print(f"\n\nb_closest: {b_closest1} {b_closest2} fract_closest1 {fract_closest1} fract_closest2 {fract_closest2}")
534
- # print(f"b_parent: {b_parent1} {b_parent2}")
535
- # print(f"similarities {similarities}")
536
- # print(f"idx_injection {idx_injection} tree_idx_injection {self.tree_idx_injection}")
537
-
538
- fract_mixing = (fract_closest1 + fract_closest2) /2
539
  return fract_mixing, b_parent1, b_parent2
540
-
541
-
542
  def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
543
  r"""
544
  Inserts all necessary parameters into the trajectory tree.
@@ -550,31 +509,28 @@ class LatentBlending():
550
  list_latents: list
551
  list of the latents to be inserted
552
  """
553
- b_parent1, b_parent2 = get_closest_idx(fract_mixing, self.tree_fracts)
554
- self.tree_latents.insert(b_parent1+1, list_latents)
555
- self.tree_final_imgs.insert(b_parent1+1, self.sdh.latent2image(list_latents[-1]))
556
- self.tree_fracts.insert(b_parent1+1, fract_mixing)
557
- self.tree_idx_injection.insert(b_parent1+1, idx_injection)
558
-
559
-
560
- def get_spatial_mask_template(self):
561
  r"""
562
- Experimental helper function to get a spatial mask template.
563
  """
564
  shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
565
  C, H, W = shape_latents
566
  return np.ones((H, W))
567
-
568
  def set_spatial_mask(self, img_mask):
569
  r"""
570
- Experimental helper function to set a spatial mask.
571
  The mask forces latents to be overwritten.
572
  Args:
573
- img_mask:
574
  mask image [0,1]. You can get a template using get_spatial_mask_template
575
-
576
  """
577
-
578
  shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
579
  C, H, W = shape_latents
580
  img_mask = np.asarray(img_mask)
@@ -584,18 +540,15 @@ class LatentBlending():
584
  assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
585
  spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
586
  spatial_mask = torch.unsqueeze(spatial_mask, 0)
587
- spatial_mask = spatial_mask.repeat((C,1,1))
588
  spatial_mask = torch.unsqueeze(spatial_mask, 0)
589
-
590
  self.spatial_mask = spatial_mask
591
-
592
-
593
  def get_noise(self, seed):
594
  r"""
595
  Helper function to get noise given seed.
596
  Args:
597
  seed: int
598
-
599
  """
600
  generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
601
  if self.mode == 'standard':
@@ -606,87 +559,81 @@ class LatentBlending():
606
  h = self.image1_lowres.size[1]
607
  shape_latents = [self.sdh.model.channels, h, w]
608
  C, H, W = shape_latents
609
-
610
  return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
611
 
612
-
613
  @torch.no_grad()
614
  def run_diffusion(
615
- self,
616
- list_conditionings,
617
- latents_start: torch.FloatTensor = None,
618
- idx_start: int = 0,
619
- list_latents_mixing = None,
620
- mixing_coeffs = 0.0,
621
- return_image: Optional[bool] = False
622
- ):
623
-
624
  r"""
625
  Wrapper function for diffusion runners.
626
  Depending on the mode, the correct one will be executed.
627
-
628
  Args:
629
  list_conditionings: list
630
  List of all conditionings for the diffusion model.
631
- latents_start: torch.FloatTensor
632
  Latents that are used for injection
633
  idx_start: int
634
  Index of the diffusion process start and where the latents_for_injection are injected
635
- list_latents_mixing: torch.FloatTensor
636
  List of latents (latent trajectories) that are used for mixing
637
  mixing_coeffs: float or list
638
  Coefficients, how strong each element of list_latents_mixing will be mixed in.
639
  return_image: Optional[bool]
640
  Optionally return image directly
641
  """
642
-
643
  # Ensure correct num_inference_steps in Holder
644
  self.sdh.num_inference_steps = self.num_inference_steps
645
  assert type(list_conditionings) is list, "list_conditionings need to be a list"
646
-
647
  if self.mode == 'standard':
648
  text_embeddings = list_conditionings[0]
649
  return self.sdh.run_diffusion_standard(
650
- text_embeddings = text_embeddings,
651
- latents_start = latents_start,
652
- idx_start = idx_start,
653
- list_latents_mixing = list_latents_mixing,
654
- mixing_coeffs = mixing_coeffs,
655
- spatial_mask = self.spatial_mask,
656
- return_image = return_image,
657
- )
658
-
659
  elif self.mode == 'upscale':
660
  cond = list_conditionings[0]
661
  uc_full = list_conditionings[1]
662
  return self.sdh.run_diffusion_upscaling(
663
- cond,
664
- uc_full,
665
- latents_start=latents_start,
666
- idx_start=idx_start,
667
- list_latents_mixing = list_latents_mixing,
668
- mixing_coeffs = mixing_coeffs,
669
  return_image=return_image)
670
 
671
-
672
  def run_upscaling(
673
- self,
674
  dp_img: str,
675
  depth_strength: float = 0.65,
676
  num_inference_steps: int = 100,
677
  nmb_max_branches_highres: int = 5,
678
  nmb_max_branches_lowres: int = 6,
679
- duration_single_segment = 3,
680
- fixed_seeds: Optional[List[int]] = None,
681
- ):
682
  r"""
683
  Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
684
-
685
  Args:
686
  dp_img: str
687
  Path to the low-res transition path (as saved in write_imgs_transition)
688
  depth_strength:
689
- Determines how deep the first injection will happen.
690
  Deeper injections will cause (unwanted) formation of new structures,
691
  more shallow values will go into alpha-blendy land.
692
  num_inference_steps:
@@ -699,68 +646,59 @@ class LatentBlending():
699
  Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
700
  duration_single_segment: float
701
  The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
 
 
702
  fixed_seeds: Optional[List[int)]:
703
  You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
704
  Otherwise random seeds will be taken.
705
  """
706
  fp_yml = os.path.join(dp_img, "lowres.yaml")
707
  fp_movie = os.path.join(dp_img, "movie_highres.mp4")
708
- fps = 24
709
  ms = MovieSaver(fp_movie, fps=fps)
710
  assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
711
  dict_stuff = yml_load(fp_yml)
712
-
713
  # load lowres images
714
  nmb_images_lowres = dict_stuff['nmb_images']
715
  prompt1 = dict_stuff['prompt1']
716
  prompt2 = dict_stuff['prompt2']
717
- idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres-1, nmb_max_branches_lowres)).astype(np.int32)
718
  imgs_lowres = []
719
  for i in idx_img_lowres:
720
  fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
721
  assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
722
  imgs_lowres.append(Image.open(fp_img_lowres))
723
-
724
 
725
  # set up upscaling
726
  text_embeddingA = self.sdh.get_text_embedding(prompt1)
727
  text_embeddingB = self.sdh.get_text_embedding(prompt2)
728
-
729
- list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres-1)
730
-
731
- for i in range(nmb_max_branches_lowres-1):
732
  print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
733
-
734
  self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
735
- self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1-list_fract_mixing[i])
736
-
737
- if i==0:
738
- recycle_img1 = False
739
  else:
740
  self.swap_forward()
741
- recycle_img1 = True
742
-
743
  self.set_image1(imgs_lowres[i])
744
- self.set_image2(imgs_lowres[i+1])
745
-
746
  list_imgs = self.run_transition(
747
- recycle_img1 = recycle_img1,
748
- recycle_img2 = False,
749
- num_inference_steps = num_inference_steps,
750
- depth_strength = depth_strength,
751
- nmb_max_branches = nmb_max_branches_highres,
752
- )
753
-
754
  list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
755
-
756
  # Save movie frame
757
  for img in list_imgs_interp:
758
  ms.write_frame(img)
759
-
760
  ms.finalize()
761
-
762
 
763
-
764
  @torch.no_grad()
765
  def get_mixed_conditioning(self, fract_mixing):
766
  if self.mode == 'standard':
@@ -782,9 +720,8 @@ class LatentBlending():
782
 
783
  @torch.no_grad()
784
  def get_text_embeddings(
785
- self,
786
- prompt: str
787
- ):
788
  r"""
789
  Computes the text embeddings provided a prompt string.
790
  Adapted from stable diffusion repo
@@ -792,9 +729,7 @@ class LatentBlending():
792
  prompt: str
793
  ABC trending on artstation painted by Old Greg.
794
  """
795
-
796
  return self.sdh.get_text_embedding(prompt)
797
-
798
 
799
  def write_imgs_transition(self, dp_img):
800
  r"""
@@ -809,10 +744,9 @@ class LatentBlending():
809
  for i, img in enumerate(imgs_transition):
810
  img_leaf = Image.fromarray(img)
811
  img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
812
-
813
- fp_yml = os.path.join(dp_img, "lowres.yaml")
814
  self.save_statedict(fp_yml)
815
-
816
  def write_movie_transition(self, fp_movie, duration_transition, fps=30):
817
  r"""
818
  Writes the transition movie to fp_movie, using the given duration and fps.
@@ -824,9 +758,8 @@ class LatentBlending():
824
  duration of the movie in seconds
825
  fps: int
826
  fps of the movie
827
-
828
  """
829
-
830
  # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
831
  imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)
832
 
@@ -838,15 +771,13 @@ class LatentBlending():
838
  ms.write_frame(img)
839
  ms.finalize()
840
 
841
-
842
-
843
  def save_statedict(self, fp_yml):
844
  # Dump everything relevant into yaml
845
  imgs_transition = self.tree_final_imgs
846
  state_dict = self.get_state_dict()
847
  state_dict['nmb_images'] = len(imgs_transition)
848
  yml_save(fp_yml, state_dict)
849
-
850
  def get_state_dict(self):
851
  state_dict = {}
852
  grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
@@ -860,391 +791,94 @@ class LatentBlending():
860
  state_dict[v] = int(getattr(self, v))
861
  elif v == 'guidance_scale':
862
  state_dict[v] = float(getattr(self, v))
863
-
864
  else:
865
  try:
866
  state_dict[v] = getattr(self, v)
867
- except Exception as e:
868
  pass
869
-
870
  return state_dict
871
-
872
  def randomize_seed(self):
873
  r"""
874
  Set a random seed for a fresh start.
875
- """
876
  seed = np.random.randint(999999999)
877
  self.set_seed(seed)
878
-
879
  def set_seed(self, seed: int):
880
  r"""
881
  Set a the seed for a fresh start.
882
- """
883
  self.seed = seed
884
  self.sdh.seed = seed
885
-
886
  def set_width(self, width):
887
  r"""
888
  Set the width of the resulting image.
889
- """
890
  assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
891
  self.width = width
892
  self.sdh.width = width
893
-
894
  def set_height(self, height):
895
  r"""
896
  Set the height of the resulting image.
897
- """
898
  assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
899
  self.height = height
900
  self.sdh.height = height
901
-
902
 
903
  def swap_forward(self):
904
  r"""
905
  Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
906
  as in run_multi_transition()
907
- """
908
  # Move over all latents
909
  self.tree_latents[0] = self.tree_latents[-1]
910
-
911
  # Move over prompts and text embeddings
912
  self.prompt1 = self.prompt2
913
  self.text_embedding1 = self.text_embedding2
914
-
915
  # Final cleanup for extra sanity
916
- self.tree_final_imgs = []
917
-
918
-
919
  def get_lpips_similarity(self, imgA, imgB):
920
  r"""
921
- Computes the image similarity between two images imgA and imgB.
922
  Used to determine the optimal point of insertion to create smooth transitions.
923
  High values indicate low similarity.
924
- """
925
  tensorA = torch.from_numpy(imgA).float().cuda(self.device)
926
- tensorA = 2*tensorA/255.0 - 1
927
- tensorA = tensorA.permute([2,0,1]).unsqueeze(0)
928
-
929
  tensorB = torch.from_numpy(imgB).float().cuda(self.device)
930
- tensorB = 2*tensorB/255.0 - 1
931
- tensorB = tensorB.permute([2,0,1]).unsqueeze(0)
932
  lploss = self.lpips(tensorA, tensorB)
933
  lploss = float(lploss[0][0][0][0])
934
-
935
  return lploss
936
-
937
-
938
- # Auxiliary functions
939
- def get_closest_idx(
940
- fract_mixing: float,
941
- list_fract_mixing_prev: List[float],
942
- ):
943
- r"""
944
- Helper function to retrieve the parents for any given mixing.
945
- Example: fract_mixing = 0.4 and list_fract_mixing_prev = [0, 0.3, 0.6, 1.0]
946
- Will return the two closest values from list_fract_mixing_prev, i.e. [1, 2]
947
- """
948
-
949
- pdist = fract_mixing - np.asarray(list_fract_mixing_prev)
950
- pdist_pos = pdist.copy()
951
- pdist_pos[pdist_pos<0] = np.inf
952
- b_parent1 = np.argmin(pdist_pos)
953
- pdist_neg = -pdist.copy()
954
- pdist_neg[pdist_neg<=0] = np.inf
955
- b_parent2= np.argmin(pdist_neg)
956
-
957
- if b_parent1 > b_parent2:
958
- tmp = b_parent2
959
- b_parent2 = b_parent1
960
- b_parent1 = tmp
961
-
962
- return b_parent1, b_parent2
963
-
964
- @torch.no_grad()
965
- def interpolate_spherical(p0, p1, fract_mixing: float):
966
- r"""
967
- Helper function to correctly mix two random variables using spherical interpolation.
968
- See https://en.wikipedia.org/wiki/Slerp
969
- The function will always cast up to float64 for sake of extra 4.
970
- Args:
971
- p0:
972
- First tensor for interpolation
973
- p1:
974
- Second tensor for interpolation
975
- fract_mixing: float
976
- Mixing coefficient of interval [0, 1].
977
- 0 will return in p0
978
- 1 will return in p1
979
- 0.x will return a mix between both preserving angular velocity.
980
- """
981
-
982
- if p0.dtype == torch.float16:
983
- recast_to = 'fp16'
984
- else:
985
- recast_to = 'fp32'
986
-
987
- p0 = p0.double()
988
- p1 = p1.double()
989
- norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
990
- epsilon = 1e-7
991
- dot = torch.sum(p0 * p1) / norm
992
- dot = dot.clamp(-1+epsilon, 1-epsilon)
993
-
994
- theta_0 = torch.arccos(dot)
995
- sin_theta_0 = torch.sin(theta_0)
996
- theta_t = theta_0 * fract_mixing
997
- s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
998
- s1 = torch.sin(theta_t) / sin_theta_0
999
- interp = p0*s0 + p1*s1
1000
-
1001
- if recast_to == 'fp16':
1002
- interp = interp.half()
1003
- elif recast_to == 'fp32':
1004
- interp = interp.float()
1005
-
1006
- return interp
1007
-
1008
-
1009
- def interpolate_linear(p0, p1, fract_mixing):
1010
- r"""
1011
- Helper function to mix two variables using standard linear interpolation.
1012
- Args:
1013
- p0:
1014
- First tensor / np.ndarray for interpolation
1015
- p1:
1016
- Second tensor / np.ndarray for interpolation
1017
- fract_mixing: float
1018
- Mixing coefficient of interval [0, 1].
1019
- 0 will return in p0
1020
- 1 will return in p1
1021
- 0.x will return a linear mix between both.
1022
- """
1023
- reconvert_uint8 = False
1024
- if type(p0) is np.ndarray and p0.dtype == 'uint8':
1025
- reconvert_uint8 = True
1026
- p0 = p0.astype(np.float64)
1027
-
1028
- if type(p1) is np.ndarray and p1.dtype == 'uint8':
1029
- reconvert_uint8 = True
1030
- p1 = p1.astype(np.float64)
1031
-
1032
- interp = (1-fract_mixing) * p0 + fract_mixing * p1
1033
-
1034
- if reconvert_uint8:
1035
- interp = np.clip(interp, 0, 255).astype(np.uint8)
1036
-
1037
- return interp
1038
-
1039
-
1040
- def add_frames_linear_interp(
1041
- list_imgs: List[np.ndarray],
1042
- fps_target: Union[float, int] = None,
1043
- duration_target: Union[float, int] = None,
1044
- nmb_frames_target: int=None,
1045
- ):
1046
- r"""
1047
- Helper function to cheaply increase the number of frames given a list of images,
1048
- by virtue of standard linear interpolation.
1049
- The number of inserted frames will be automatically adjusted so that the total of number
1050
- of frames can be fixed precisely, using a random shuffling technique.
1051
- The function allows 1:1 comparisons between transitions as videos.
1052
-
1053
- Args:
1054
- list_imgs: List[np.ndarray)
1055
- List of images, between each image new frames will be inserted via linear interpolation.
1056
- fps_target:
1057
- OptionA: specify here the desired frames per second.
1058
- duration_target:
1059
- OptionA: specify here the desired duration of the transition in seconds.
1060
- nmb_frames_target:
1061
- OptionB: directly fix the total number of frames of the output.
1062
- """
1063
-
1064
- # Sanity
1065
- if nmb_frames_target is not None and fps_target is not None:
1066
- raise ValueError("You cannot specify both fps_target and nmb_frames_target")
1067
- if fps_target is None:
1068
- assert nmb_frames_target is not None, "Either specify nmb_frames_target or nmb_frames_target"
1069
- if nmb_frames_target is None:
1070
- assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
1071
- assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
1072
- nmb_frames_target = fps_target*duration_target
1073
-
1074
- # Get number of frames that are missing
1075
- nmb_frames_diff = len(list_imgs)-1
1076
- nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
1077
-
1078
- if nmb_frames_missing < 1:
1079
- return list_imgs
1080
-
1081
- list_imgs_float = [img.astype(np.float32) for img in list_imgs]
1082
- # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
1083
- mean_nmb_frames_insert = nmb_frames_missing/nmb_frames_diff
1084
- constfact = np.floor(mean_nmb_frames_insert)
1085
- remainder_x = 1-(mean_nmb_frames_insert - constfact)
1086
-
1087
- nmb_iter = 0
1088
- while True:
1089
- nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
1090
- nmb_frames_to_insert[nmb_frames_to_insert<=remainder_x] = 0
1091
- nmb_frames_to_insert[nmb_frames_to_insert>remainder_x] = 1
1092
- nmb_frames_to_insert += constfact
1093
- if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
1094
- break
1095
- nmb_iter += 1
1096
- if nmb_iter > 100000:
1097
- print("add_frames_linear_interp: issue with inserting the right number of frames")
1098
- break
1099
-
1100
- nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
1101
- list_imgs_interp = []
1102
- for i in range(len(list_imgs_float)-1):#, desc="STAGE linear interp"):
1103
- img0 = list_imgs_float[i]
1104
- img1 = list_imgs_float[i+1]
1105
- list_imgs_interp.append(img0.astype(np.uint8))
1106
- list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i]+2)[1:-1]
1107
- for fract_linblend in list_fracts_linblend:
1108
- img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
1109
- list_imgs_interp.append(img_blend.astype(np.uint8))
1110
-
1111
- if i==len(list_imgs_float)-2:
1112
- list_imgs_interp.append(img1.astype(np.uint8))
1113
-
1114
- return list_imgs_interp
1115
-
1116
-
1117
- def get_spacing(nmb_points: int, scaling: float):
1118
- """
1119
- Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
1120
- Args:
1121
- nmb_points: int
1122
- Number of points between [0, 1]
1123
- scaling: float
1124
- Higher values will return higher sampling density around 0.5
1125
-
1126
- """
1127
- if scaling < 1.7:
1128
- return np.linspace(0, 1, nmb_points)
1129
- nmb_points_per_side = nmb_points//2 + 1
1130
- if np.mod(nmb_points, 2) != 0: # uneven case
1131
- left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
1132
- right_side = 1-left_side[::-1][1:]
1133
- else:
1134
- left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
1135
- right_side = 1-left_side[::-1]
1136
- all_fracts = np.hstack([left_side, right_side])
1137
- return all_fracts
1138
-
1139
-
1140
- def get_time(resolution=None):
1141
- """
1142
- Helper function returning an nicely formatted time string, e.g. 221117_1620
1143
- """
1144
- if resolution==None:
1145
- resolution="second"
1146
- if resolution == "day":
1147
- t = time.strftime('%y%m%d', time.localtime())
1148
- elif resolution == "minute":
1149
- t = time.strftime('%y%m%d_%H%M', time.localtime())
1150
- elif resolution == "second":
1151
- t = time.strftime('%y%m%d_%H%M%S', time.localtime())
1152
- elif resolution == "millisecond":
1153
- t = time.strftime('%y%m%d_%H%M%S', time.localtime())
1154
- t += "_"
1155
- t += str("{:03d}".format(int(int(datetime.utcnow().strftime('%f'))/1000)))
1156
- else:
1157
- raise ValueError("bad resolution provided: %s" %resolution)
1158
- return t
1159
-
1160
- def compare_dicts(a, b):
1161
- """
1162
- Compares two dictionaries a and b and returns a dictionary c, with all
1163
- keys,values that have shared keys in a and b but same values in a and b.
1164
- The values of a and b are stacked together in the output.
1165
- Example:
1166
- a = {}; a['bobo'] = 4
1167
- b = {}; b['bobo'] = 5
1168
- c = dict_compare(a,b)
1169
- c = {"bobo",[4,5]}
1170
- """
1171
- c = {}
1172
- for key in a.keys():
1173
- if key in b.keys():
1174
- val_a = a[key]
1175
- val_b = b[key]
1176
- if val_a != val_b:
1177
- c[key] = [val_a, val_b]
1178
- return c
1179
-
1180
- def yml_load(fp_yml, print_fields=False):
1181
- """
1182
- Helper function for loading yaml files
1183
- """
1184
- with open(fp_yml) as f:
1185
- data = yaml.load(f, Loader=yaml.loader.SafeLoader)
1186
- dict_data = dict(data)
1187
- print("load: loaded {}".format(fp_yml))
1188
- return dict_data
1189
-
1190
- def yml_save(fp_yml, dict_stuff):
1191
- """
1192
- Helper function for saving yaml files
1193
- """
1194
- with open(fp_yml, 'w') as f:
1195
- data = yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
1196
- print("yml_save: saved {}".format(fp_yml))
1197
-
1198
-
1199
- #%% le main
1200
- if __name__ == "__main__":
1201
- # xxxx
1202
-
1203
- #%% First let us spawn a stable diffusion holder
1204
- device = "cuda"
1205
- fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_512-ema-pruned.ckpt"
1206
-
1207
- sdh = StableDiffusionHolder(fp_ckpt)
1208
-
1209
- xxx
1210
-
1211
-
1212
- #%% Next let's set up all parameters
1213
- depth_strength = 0.3 # Specifies how deep (in terms of diffusion iterations the first branching happens)
1214
- fixed_seeds = [697164, 430214]
1215
-
1216
- prompt1 = "photo of a desert and a sky"
1217
- prompt2 = "photo of a tree with a lake"
1218
-
1219
- duration_transition = 12 # In seconds
1220
- fps = 30
1221
-
1222
- # Spawn latent blending
1223
- self = LatentBlending(sdh)
1224
-
1225
- self.set_prompt1(prompt1)
1226
- self.set_prompt2(prompt2)
1227
-
1228
- # Run latent blending
1229
- self.branch1_crossfeed_power = 0.3
1230
- self.branch1_crossfeed_range = 0.4
1231
- # self.run_transition(depth_strength=depth_strength, fixed_seeds=fixed_seeds)
1232
- self.seed1=21312
1233
- img1 =self.compute_latents1(True)
1234
- #%
1235
- self.seed2=1234121
1236
- self.branch1_crossfeed_power = 0.7
1237
- self.branch1_crossfeed_range = 0.3
1238
- self.branch1_crossfeed_decay = 0.3
1239
- img2 =self.compute_latents2(True)
1240
- # Image.fromarray(np.concatenate((img1, img2), axis=1))
1241
-
1242
- #%%
1243
- t0 = time.time()
1244
- self.t_compute_max_allowed = 30
1245
- self.parental_crossfeed_range = 1.0
1246
- self.parental_crossfeed_power = 0.0
1247
- self.parental_crossfeed_power_decay = 1.0
1248
- imgs_transition = self.run_transition(recycle_img1=True, recycle_img2=True)
1249
- t1 = time.time()
1250
- print(f"took: {t1-t0}s")
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
+ import os
 
 
 
17
  import torch
18
  torch.backends.cudnn.benchmark = False
19
+ torch.set_grad_enabled(False)
20
  import numpy as np
21
  import warnings
22
  warnings.filterwarnings('ignore')
23
  import time
 
24
  import warnings
 
25
  from tqdm.auto import tqdm
26
  from PIL import Image
 
 
27
  from movie_util import MovieSaver
28
+ from typing import List, Optional
 
 
 
 
 
 
 
 
 
 
29
  from ldm.models.diffusion.ddpm import LatentUpscaleDiffusion, LatentInpaintDiffusion
 
 
30
  import lpips
31
+ from utils import interpolate_spherical, interpolate_linear, add_frames_linear_interp, yml_load, yml_save
32
+
33
+
34
  class LatentBlending():
35
  def __init__(
36
+ self,
37
  sdh: None,
38
  guidance_scale: float = 4,
39
  guidance_scale_mid_damper: float = 0.5,
40
+ mid_compression_scaler: float = 1.2):
 
41
  r"""
42
  Initializes the latent blending class.
43
  Args:
54
  Increases the sampling density in the middle (where most changes happen). Higher values
55
  imply more values in the middle. However the inflection point can occur outside the middle,
56
  thus high values can give rough transitions. Values around 2 should be fine.
 
57
  """
58
+ assert guidance_scale_mid_damper > 0 \
59
+ and guidance_scale_mid_damper <= 1.0, \
60
+ f"guidance_scale_mid_damper needs to be in interval (0,1], you provided {guidance_scale_mid_damper}"
61
 
62
  self.sdh = sdh
63
  self.device = self.sdh.device
65
  self.height = self.sdh.height
66
  self.guidance_scale_mid_damper = guidance_scale_mid_damper
67
  self.mid_compression_scaler = mid_compression_scaler
68
+ self.seed1 = 0
69
  self.seed2 = 0
70
+
71
  # Initialize vars
72
  self.prompt1 = ""
73
  self.prompt2 = ""
74
  self.negative_prompt = ""
75
+
76
  self.tree_latents = [None, None]
77
  self.tree_fracts = None
78
  self.idx_injection = []
79
  self.tree_status = None
80
  self.tree_final_imgs = []
81
+
82
  self.list_nmb_branches_prev = []
83
  self.list_injection_idx_prev = []
84
  self.text_embedding1 = None
90
  self.noise_level_upscaling = 20
91
  self.list_injection_idx = None
92
  self.list_nmb_branches = None
93
+
94
  # Mixing parameters
95
  self.branch1_crossfeed_power = 0.1
96
  self.branch1_crossfeed_range = 0.6
97
  self.branch1_crossfeed_decay = 0.8
98
+
99
  self.parental_crossfeed_power = 0.1
100
  self.parental_crossfeed_range = 0.8
101
+ self.parental_crossfeed_power_decay = 0.8
102
+
103
  self.set_guidance_scale(guidance_scale)
104
  self.init_mode()
105
  self.multi_transition_img_first = None
106
  self.multi_transition_img_last = None
107
  self.dt_per_diff = 0
108
  self.spatial_mask = None
 
109
  self.lpips = lpips.LPIPS(net='alex').cuda(self.device)
 
110
 
111
  def init_mode(self):
112
  r"""
120
  self.mode = 'inpaint'
121
  else:
122
  self.mode = 'standard'
123
+
124
  def set_guidance_scale(self, guidance_scale):
125
  r"""
126
  sets the guidance scale.
128
  self.guidance_scale_base = guidance_scale
129
  self.guidance_scale = guidance_scale
130
  self.sdh.guidance_scale = guidance_scale
131
+
132
  def set_negative_prompt(self, negative_prompt):
133
  r"""Set the negative prompt. Currently only one negative prompt is supported
134
  """
135
  self.negative_prompt = negative_prompt
136
  self.sdh.set_negative_prompt(negative_prompt)
137
+
138
  def set_guidance_mid_dampening(self, fract_mixing):
139
  r"""
140
+ Tunes the guidance scale down as a linear function of fract_mixing,
141
  reaching its minimum at fract_mixing = 0.5.
142
  """
143
+ mid_factor = 1 - np.abs(fract_mixing - 0.5) / 0.5
144
+ max_guidance_reduction = self.guidance_scale_base * (1 - self.guidance_scale_mid_damper) - 1
145
+ guidance_scale_effective = self.guidance_scale_base - max_guidance_reduction * mid_factor
146
  self.guidance_scale = guidance_scale_effective
147
  self.sdh.guidance_scale = guidance_scale_effective
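# Worked example of the damping above (a sketch, assuming the defaults
# guidance_scale=4 and guidance_scale_mid_damper=0.5):
#   max_guidance_reduction = 4 * (1 - 0.5) - 1 = 1
#   fract_mixing = 0.0 or 1.0  ->  mid_factor = 0  ->  guidance_scale stays 4
#   fract_mixing = 0.5         ->  mid_factor = 1  ->  guidance_scale drops to 3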
148
 
 
149
  def set_branch1_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
150
  r"""
151
  Sets the crossfeed parameters for the first branch to the last branch.
160
  self.branch1_crossfeed_power = np.clip(crossfeed_power, 0, 1)
161
  self.branch1_crossfeed_range = np.clip(crossfeed_range, 0, 1)
162
  self.branch1_crossfeed_decay = np.clip(crossfeed_decay, 0, 1)
163
+
 
164
  def set_parental_crossfeed(self, crossfeed_power, crossfeed_range, crossfeed_decay):
165
  r"""
166
  Sets the crossfeed parameters for all transition images (within the first and last branch).
167
  Args:
168
  crossfeed_power: float [0,1]
169
+ Controls the level of cross-feeding from the parental branches
170
  crossfeed_range: float [0,1]
171
  Sets the duration of active crossfeed during development.
172
  crossfeed_decay: float [0,1]
176
  self.parental_crossfeed_range = np.clip(crossfeed_range, 0, 1)
177
  self.parental_crossfeed_power_decay = np.clip(crossfeed_decay, 0, 1)
178
 
 
179
  def set_prompt1(self, prompt: str):
180
  r"""
181
  Sets the first prompt (for the first keyframe) including text embeddings.
186
  prompt = prompt.replace("_", " ")
187
  self.prompt1 = prompt
188
  self.text_embedding1 = self.get_text_embeddings(self.prompt1)
189
+
 
190
  def set_prompt2(self, prompt: str):
191
  r"""
192
  Sets the second prompt (for the second keyframe) including text embeddings.
197
  prompt = prompt.replace("_", " ")
198
  self.prompt2 = prompt
199
  self.text_embedding2 = self.get_text_embeddings(self.prompt2)
200
+
201
  def set_image1(self, image: Image):
202
  r"""
203
  Sets the first image (keyframe), relevant for the upscaling model transitions.
205
  image: Image
206
  """
207
  self.image1_lowres = image
208
+
209
  def set_image2(self, image: Image):
210
  r"""
211
  Sets the second image (keyframe), relevant for the upscaling model transitions.
213
  image: Image
214
  """
215
  self.image2_lowres = image
216
+
217
  def run_transition(
218
  self,
219
+ recycle_img1: Optional[bool] = False,
220
+ recycle_img2: Optional[bool] = False,
221
  num_inference_steps: Optional[int] = 30,
222
  depth_strength: Optional[float] = 0.3,
223
  t_compute_max_allowed: Optional[float] = None,
224
  nmb_max_branches: Optional[int] = None,
225
+ fixed_seeds: Optional[List[int]] = None):
 
226
  r"""
227
  Function for computing transitions.
228
  Returns a list of transition images using spherical latent blending.
234
  num_inference_steps:
235
  Number of diffusion steps. Higher values will take more compute time.
236
  depth_strength:
237
+ Determines how deep the first injection will happen.
238
  Deeper injections will cause (unwanted) formation of new structures,
239
  more shallow values will go into alpha-blendy land.
240
  t_compute_max_allowed:
241
+ Either provide t_compute_max_allowed or nmb_max_branches.
242
+ The maximum time allowed for computation. Higher values give better results but take longer.
243
  nmb_max_branches: int
244
  Either provide t_compute_max_allowed or nmb_max_branches. The maximum number of branches to be computed. Higher values give better
245
+ results. Use this if you want to have controllable results independent
246
  of your computer.
247
  fixed_seeds: Optional[List[int]]:
248
  You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
249
  Otherwise random seeds will be taken.
 
250
  """
251
+
252
  # Sanity checks first
253
  assert self.text_embedding1 is not None, 'Set the first text embedding with .set_prompt1(...) before'
254
  assert self.text_embedding2 is not None, 'Set the second text embedding with .set_prompt2(...) before'
255
+
256
  # Random seeds
257
  if fixed_seeds is not None:
258
  if fixed_seeds == 'randomize':
259
  fixed_seeds = list(np.random.randint(0, 1000000, 2).astype(np.int32))
260
  else:
261
+ assert len(fixed_seeds) == 2, "Supply a list with len = 2"
262
+
263
  self.seed1 = fixed_seeds[0]
264
  self.seed2 = fixed_seeds[1]
265
+
266
  # Ensure correct num_inference_steps in holder
267
  self.num_inference_steps = num_inference_steps
268
  self.sdh.num_inference_steps = num_inference_steps
269
+
270
  # Compute / Recycle first image
271
  if not recycle_img1 or len(self.tree_latents[0]) != self.num_inference_steps:
272
  list_latents1 = self.compute_latents1()
273
  else:
274
  list_latents1 = self.tree_latents[0]
275
+
276
  # Compute / Recycle second image
277
  if not recycle_img2 or len(self.tree_latents[-1]) != self.num_inference_steps:
278
  list_latents2 = self.compute_latents2()
279
  else:
280
  list_latents2 = self.tree_latents[-1]
281
+
282
  # Reset the tree, injecting the edge latents1/2 we just generated/recycled
283
+ self.tree_latents = [list_latents1, list_latents2]
284
  self.tree_fracts = [0.0, 1.0]
285
  self.tree_final_imgs = [self.sdh.latent2image((self.tree_latents[0][-1])), self.sdh.latent2image((self.tree_latents[-1][-1]))]
286
  self.tree_idx_injection = [0, 0]
287
+
288
  # Hard-fix. Apply spatial mask only for list_latents2 but not for transition. WIP...
289
  self.spatial_mask = None
290
+
291
  # Set up branching scheme (dependent on provided compute time)
292
  list_idx_injection, list_nmb_stems = self.get_time_based_branching(depth_strength, t_compute_max_allowed, nmb_max_branches)
293
 
294
+ # Run iteratively, starting with the longest trajectory.
295
  # Always inserting new branches where they are needed most according to image similarity
296
  for s_idx in tqdm(range(len(list_idx_injection))):
297
  nmb_stems = list_nmb_stems[s_idx]
298
  idx_injection = list_idx_injection[s_idx]
299
+
300
  for i in range(nmb_stems):
301
  fract_mixing, b_parent1, b_parent2 = self.get_mixing_parameters(idx_injection)
302
  self.set_guidance_mid_dampening(fract_mixing)
303
  list_latents = self.compute_latents_mix(fract_mixing, b_parent1, b_parent2, idx_injection)
304
  self.insert_into_tree(fract_mixing, idx_injection, list_latents)
305
  # print(f"fract_mixing: {fract_mixing} idx_injection {idx_injection}")
306
+
307
  return self.tree_final_imgs
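# Minimal usage sketch for run_transition. The checkpoint path is a placeholder,
# the prompts/seeds are only illustrative, and it assumes
# `from stable_diffusion_holder import StableDiffusionHolder`.
sdh = StableDiffusionHolder("path/to/v2-1_512-ema-pruned.ckpt")
lb = LatentBlending(sdh)
lb.set_prompt1("photo of a desert and a sky")
lb.set_prompt2("photo of a tree with a lake")
imgs_transition = lb.run_transition(
    depth_strength=0.3,
    t_compute_max_allowed=30,
    fixed_seeds=[697164, 430214])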
 
308
 
309
  def compute_latents1(self, return_image=False):
310
  r"""
318
  t0 = time.time()
319
  latents_start = self.get_noise(self.seed1)
320
  list_latents1 = self.run_diffusion(
321
+ list_conditionings,
322
+ latents_start=latents_start,
323
+ idx_start=0)
 
324
  t1 = time.time()
325
+ self.dt_per_diff = (t1 - t0) / self.num_inference_steps
326
  self.tree_latents[0] = list_latents1
327
  if return_image:
328
  return self.sdh.latent2image(list_latents1[-1])
329
  else:
330
  return list_latents1
331
+
332
  def compute_latents2(self, return_image=False):
333
  r"""
334
  Runs a diffusion trajectory for the last image, which may be affected by the first image's trajectory.
342
  # Influence from branch1
343
  if self.branch1_crossfeed_power > 0.0:
344
  # Set up the mixing_coeffs
345
+ idx_mixing_stop = int(round(self.num_inference_steps * self.branch1_crossfeed_range))
346
+ mixing_coeffs = list(np.linspace(self.branch1_crossfeed_power, self.branch1_crossfeed_power * self.branch1_crossfeed_decay, idx_mixing_stop))
347
+ mixing_coeffs.extend((self.num_inference_steps - idx_mixing_stop) * [0])
348
  list_latents_mixing = self.tree_latents[0]
349
  list_latents2 = self.run_diffusion(
350
+ list_conditionings,
351
+ latents_start=latents_start,
352
+ idx_start=0,
353
+ list_latents_mixing=list_latents_mixing,
354
+ mixing_coeffs=mixing_coeffs)
 
355
  else:
356
  list_latents2 = self.run_diffusion(list_conditionings, latents_start)
357
  self.tree_latents[-1] = list_latents2
358
+
359
  if return_image:
360
  return self.sdh.latent2image(list_latents2[-1])
361
  else:
362
+ return list_latents2
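# Numeric sketch of the branch1 crossfeed schedule above, assuming the defaults
# (power=0.1, range=0.6, decay=0.8) and num_inference_steps=30:
#   idx_mixing_stop = round(30 * 0.6) = 18
#   mixing_coeffs   = linspace(0.1, 0.08, 18) followed by 12 zeros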
363
 
364
+ def compute_latents_mix(self, fract_mixing, b_parent1, b_parent2, idx_injection):
 
365
  r"""
366
  Runs a diffusion trajectory, using the latents from the respective parents
367
  Args:
375
  the index in terms of diffusion steps, where the next insertion will start.
376
  """
377
  list_conditionings = self.get_mixed_conditioning(fract_mixing)
378
+ fract_mixing_parental = (fract_mixing - self.tree_fracts[b_parent1]) / (self.tree_fracts[b_parent2] - self.tree_fracts[b_parent1])
379
  # idx_reversed = self.num_inference_steps - idx_injection
380
+
381
  list_latents_parental_mix = []
382
  for i in range(self.num_inference_steps):
383
  latents_p1 = self.tree_latents[b_parent1][i]
388
  latents_parental = interpolate_spherical(latents_p1, latents_p2, fract_mixing_parental)
389
  list_latents_parental_mix.append(latents_parental)
390
 
391
+ idx_mixing_stop = int(round(self.num_inference_steps * self.parental_crossfeed_range))
392
+ mixing_coeffs = idx_injection * [self.parental_crossfeed_power]
393
  nmb_mixing = idx_mixing_stop - idx_injection
394
  if nmb_mixing > 0:
395
+ mixing_coeffs.extend(list(np.linspace(self.parental_crossfeed_power, self.parental_crossfeed_power * self.parental_crossfeed_power_decay, nmb_mixing)))
396
+ mixing_coeffs.extend((self.num_inference_steps - len(mixing_coeffs)) * [0])
397
+ latents_start = list_latents_parental_mix[idx_injection - 1]
 
398
  list_latents = self.run_diffusion(
399
+ list_conditionings,
400
+ latents_start=latents_start,
401
+ idx_start=idx_injection,
402
+ list_latents_mixing=list_latents_parental_mix,
403
+ mixing_coeffs=mixing_coeffs)
 
 
404
  return list_latents
405
 
406
  def get_time_based_branching(self, depth_strength, t_compute_max_allowed=None, nmb_max_branches=None):
410
  Either provide t_compute_max_allowed or nmb_max_branches
411
  Args:
412
  depth_strength:
413
+ Determines how deep the first injection will happen.
414
  Deeper injections will cause (unwanted) formation of new structures,
415
  more shallow values will go into alpha-blendy land.
416
  t_compute_max_allowed: float
417
  The maximum time allowed for computation. Higher values give better results
418
+ but take longer. Use this if you want to fix your waiting time for the results.
419
  nmb_max_branches: int
420
  The maximum number of branches to be computed. Higher values give better
421
+ results. Use this if you want to have controllable results independent
422
  of your computer.
423
  """
424
+ idx_injection_base = int(round(self.num_inference_steps * depth_strength))
425
+ list_idx_injection = np.arange(idx_injection_base, self.num_inference_steps - 1, 3)
426
  list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
427
  t_compute = 0
428
+
429
  if nmb_max_branches is None:
430
  assert t_compute_max_allowed is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
431
  stop_criterion = "t_compute_max_allowed"
432
  elif t_compute_max_allowed is None:
433
  assert nmb_max_branches is not None, "Either specify t_compute_max_allowed or nmb_max_branches"
434
  stop_criterion = "nmb_max_branches"
435
+ nmb_max_branches -= 2 # Discounting the outer frames
436
  else:
437
  raise ValueError("Either specify t_compute_max_allowed or nmb_max_branches")
 
438
  stop_criterion_reached = False
439
  is_first_iteration = True
 
440
  while not stop_criterion_reached:
441
  list_compute_steps = self.num_inference_steps - list_idx_injection
442
  list_compute_steps *= list_nmb_stems
443
+ t_compute = np.sum(list_compute_steps) * self.dt_per_diff + 0.15 * np.sum(list_nmb_stems)
444
  increase_done = False
445
+ for s_idx in range(len(list_nmb_stems) - 1):
446
+ if list_nmb_stems[s_idx + 1] / list_nmb_stems[s_idx] >= 2:
447
  list_nmb_stems[s_idx] += 1
448
  increase_done = True
449
  break
450
  if not increase_done:
451
  list_nmb_stems[-1] += 1
452
+
453
  if stop_criterion == "t_compute_max_allowed" and t_compute > t_compute_max_allowed:
454
  stop_criterion_reached = True
455
  elif stop_criterion == "nmb_max_branches" and np.sum(list_nmb_stems) >= nmb_max_branches:
460
  list_nmb_stems = np.ones(len(list_idx_injection), dtype=np.int32)
461
  else:
462
  is_first_iteration = False
463
+
464
  # print(f"t_compute {t_compute} list_nmb_stems {list_nmb_stems}")
465
  return list_idx_injection, list_nmb_stems
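# Numeric sketch of the branching schedule, assuming num_inference_steps=30
# and depth_strength=0.3:
#   idx_injection_base = round(30 * 0.3) = 9
#   list_idx_injection = [9, 12, 15, 18, 21, 24, 27]
#   list_nmb_stems then grows one branch at a time until t_compute_max_allowed
#   (or nmb_max_branches) is exceeded.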
466
 
475
  """
476
  # get_lpips_similarity
477
  similarities = []
478
+ for i in range(len(self.tree_final_imgs) - 1):
479
+ similarities.append(self.get_lpips_similarity(self.tree_final_imgs[i], self.tree_final_imgs[i + 1]))
480
  b_closest1 = np.argmax(similarities)
481
+ b_closest2 = b_closest1 + 1
482
  fract_closest1 = self.tree_fracts[b_closest1]
483
  fract_closest2 = self.tree_fracts[b_closest2]
484
+
485
  # Ensure that the parents are indeed older!
486
  b_parent1 = b_closest1
487
  while True:
489
  break
490
  else:
491
  b_parent1 -= 1
 
492
  b_parent2 = b_closest2
493
  while True:
494
  if self.tree_idx_injection[b_parent2] < idx_injection:
495
  break
496
  else:
497
  b_parent2 += 1
498
+ fract_mixing = (fract_closest1 + fract_closest2) / 2
 
 
 
 
 
 
499
  return fract_mixing, b_parent1, b_parent2
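# Illustration with hypothetical values: if tree_fracts = [0.0, 0.5, 1.0] and the
# LPIPS values between neighbours are [0.2, 0.6], the pair (0.5, 1.0) is the least
# similar, so the next branch is inserted at fract_mixing = 0.75.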
500
+
 
501
  def insert_into_tree(self, fract_mixing, idx_injection, list_latents):
502
  r"""
503
  Inserts all necessary parameters into the trajectory tree.
509
  list_latents: list
510
  list of the latents to be inserted
511
  """
512
+ b_parent1, b_parent2 = self.get_closest_idx(fract_mixing)
513
+ self.tree_latents.insert(b_parent1 + 1, list_latents)
514
+ self.tree_final_imgs.insert(b_parent1 + 1, self.sdh.latent2image(list_latents[-1]))
515
+ self.tree_fracts.insert(b_parent1 + 1, fract_mixing)
516
+ self.tree_idx_injection.insert(b_parent1 + 1, idx_injection)
517
+
518
+ def get_spatial_mask_template(self):
 
519
  r"""
520
+ Experimental helper function to get a spatial mask template.
521
  """
522
  shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
523
  C, H, W = shape_latents
524
  return np.ones((H, W))
525
+
526
  def set_spatial_mask(self, img_mask):
527
  r"""
528
+ Experimental helper function to set a spatial mask.
529
  The mask forces latents to be overwritten.
530
  Args:
531
+ img_mask:
532
  mask image [0,1]. You can get a template using get_spatial_mask_template
 
533
  """
 
534
  shape_latents = [self.sdh.C, self.sdh.height // self.sdh.f, self.sdh.width // self.sdh.f]
535
  C, H, W = shape_latents
536
  img_mask = np.asarray(img_mask)
540
  assert img_mask.shape[1] == W, f"Your mask needs to be of dimension {H} x {W}"
541
  spatial_mask = torch.from_numpy(img_mask).to(device=self.device)
542
  spatial_mask = torch.unsqueeze(spatial_mask, 0)
543
+ spatial_mask = spatial_mask.repeat((C, 1, 1))
544
  spatial_mask = torch.unsqueeze(spatial_mask, 0)
 
545
  self.spatial_mask = spatial_mask
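# Usage sketch for the experimental spatial mask (assumes an existing
# LatentBlending instance `lb`; the half-image mask is only an illustration):
mask = lb.get_spatial_mask_template()      # np.ones((H, W)) in latent resolution
mask[:, mask.shape[1] // 2:] = 0           # keep only the left half
lb.set_spatial_mask(mask)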
546
+
 
547
  def get_noise(self, seed):
548
  r"""
549
  Helper function to get noise given seed.
550
  Args:
551
  seed: int
 
552
  """
553
  generator = torch.Generator(device=self.sdh.device).manual_seed(int(seed))
554
  if self.mode == 'standard':
559
  h = self.image1_lowres.size[1]
560
  shape_latents = [self.sdh.model.channels, h, w]
561
  C, H, W = shape_latents
 
562
  return torch.randn((1, C, H, W), generator=generator, device=self.sdh.device)
563
 
 
564
  @torch.no_grad()
565
  def run_diffusion(
566
+ self,
567
+ list_conditionings,
568
+ latents_start: torch.FloatTensor = None,
569
+ idx_start: int = 0,
570
+ list_latents_mixing=None,
571
+ mixing_coeffs=0.0,
572
+ return_image: Optional[bool] = False):
 
 
573
  r"""
574
  Wrapper function for diffusion runners.
575
  Depending on the mode, the correct one will be executed.
576
+
577
  Args:
578
  list_conditionings: list
579
  List of all conditionings for the diffusion model.
580
+ latents_start: torch.FloatTensor
581
  Latents that are used for injection
582
  idx_start: int
583
  Index of the diffusion process start and where the latents_for_injection are injected
584
+ list_latents_mixing: torch.FloatTensor
585
  List of latents (latent trajectories) that are used for mixing
586
  mixing_coeffs: float or list
587
  Coefficients, how strong each element of list_latents_mixing will be mixed in.
588
  return_image: Optional[bool]
589
  Optionally return image directly
590
  """
591
+
592
  # Ensure correct num_inference_steps in Holder
593
  self.sdh.num_inference_steps = self.num_inference_steps
594
  assert type(list_conditionings) is list, "list_conditionings need to be a list"
595
+
596
  if self.mode == 'standard':
597
  text_embeddings = list_conditionings[0]
598
  return self.sdh.run_diffusion_standard(
599
+ text_embeddings=text_embeddings,
600
+ latents_start=latents_start,
601
+ idx_start=idx_start,
602
+ list_latents_mixing=list_latents_mixing,
603
+ mixing_coeffs=mixing_coeffs,
604
+ spatial_mask=self.spatial_mask,
605
+ return_image=return_image)
606
+
 
607
  elif self.mode == 'upscale':
608
  cond = list_conditionings[0]
609
  uc_full = list_conditionings[1]
610
  return self.sdh.run_diffusion_upscaling(
611
+ cond,
612
+ uc_full,
613
+ latents_start=latents_start,
614
+ idx_start=idx_start,
615
+ list_latents_mixing=list_latents_mixing,
616
+ mixing_coeffs=mixing_coeffs,
617
  return_image=return_image)
618
 
 
619
  def run_upscaling(
620
+ self,
621
  dp_img: str,
622
  depth_strength: float = 0.65,
623
  num_inference_steps: int = 100,
624
  nmb_max_branches_highres: int = 5,
625
  nmb_max_branches_lowres: int = 6,
626
+ duration_single_segment=3,
627
+ fps=24,
628
+ fixed_seeds: Optional[List[int]] = None):
629
  r"""
630
  Runs upscaling with the x4 model. Requires that you run a transition before with a low-res model and save the results using write_imgs_transition.
631
+
632
  Args:
633
  dp_img: str
634
  Path to the low-res transition path (as saved in write_imgs_transition)
635
  depth_strength:
636
+ Determines how deep the first injection will happen.
637
  Deeper injections will cause (unwanted) formation of new structures,
638
  more shallow values will go into alpha-blendy land.
639
  num_inference_steps:
646
  Setting this number lower (e.g. 6) will decrease the compute time but not affect the results too much.
647
  duration_single_segment: float
648
  The duration of each high-res movie segment. You will have nmb_max_branches_lowres-1 segments in total.
649
+ fps: float
650
+ frames per second of movie
651
  fixed_seeds: Optional[List[int]]:
652
  You can supply two seeds that are used for the first and second keyframe (prompt1 and prompt2).
653
  Otherwise random seeds will be taken.
654
  """
655
  fp_yml = os.path.join(dp_img, "lowres.yaml")
656
  fp_movie = os.path.join(dp_img, "movie_highres.mp4")
 
657
  ms = MovieSaver(fp_movie, fps=fps)
658
  assert os.path.isfile(fp_yml), "lowres.yaml does not exist. did you forget run_upscaling_step1?"
659
  dict_stuff = yml_load(fp_yml)
660
+
661
  # load lowres images
662
  nmb_images_lowres = dict_stuff['nmb_images']
663
  prompt1 = dict_stuff['prompt1']
664
  prompt2 = dict_stuff['prompt2']
665
+ idx_img_lowres = np.round(np.linspace(0, nmb_images_lowres - 1, nmb_max_branches_lowres)).astype(np.int32)
666
  imgs_lowres = []
667
  for i in idx_img_lowres:
668
  fp_img_lowres = os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg")
669
  assert os.path.isfile(fp_img_lowres), f"{fp_img_lowres} does not exist. did you forget run_upscaling_step1?"
670
  imgs_lowres.append(Image.open(fp_img_lowres))
 
671
 
672
  # set up upscaling
673
  text_embeddingA = self.sdh.get_text_embedding(prompt1)
674
  text_embeddingB = self.sdh.get_text_embedding(prompt2)
675
+ list_fract_mixing = np.linspace(0, 1, nmb_max_branches_lowres - 1)
676
+ for i in range(nmb_max_branches_lowres - 1):
 
 
677
  print(f"Starting movie segment {i+1}/{nmb_max_branches_lowres-1}")
 
678
  self.text_embedding1 = interpolate_linear(text_embeddingA, text_embeddingB, list_fract_mixing[i])
679
+ self.text_embedding2 = interpolate_linear(text_embeddingA, text_embeddingB, 1 - list_fract_mixing[i])
680
+ if i == 0:
681
+ recycle_img1 = False
 
682
  else:
683
  self.swap_forward()
684
+ recycle_img1 = True
685
+
686
  self.set_image1(imgs_lowres[i])
687
+ self.set_image2(imgs_lowres[i + 1])
688
+
689
  list_imgs = self.run_transition(
690
+ recycle_img1=recycle_img1,
691
+ recycle_img2=False,
692
+ num_inference_steps=num_inference_steps,
693
+ depth_strength=depth_strength,
694
+ nmb_max_branches=nmb_max_branches_highres)
 
 
695
  list_imgs_interp = add_frames_linear_interp(list_imgs, fps, duration_single_segment)
696
+
697
  # Save movie frame
698
  for img in list_imgs_interp:
699
  ms.write_frame(img)
 
700
  ms.finalize()
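# Usage sketch for run_upscaling (paths are placeholders; assumes the low-res
# transition was saved with write_imgs_transition and that the x4 upscaler
# checkpoint is loaded into the holder):
sdh = StableDiffusionHolder("path/to/x4-upscaler-ema.ckpt")
lb = LatentBlending(sdh)
lb.run_upscaling(
    "path/to/lowres_transition_dir",
    depth_strength=0.65,
    num_inference_steps=100,
    nmb_max_branches_highres=5,
    nmb_max_branches_lowres=6)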
 
701
 
 
702
  @torch.no_grad()
703
  def get_mixed_conditioning(self, fract_mixing):
704
  if self.mode == 'standard':
720
 
721
  @torch.no_grad()
722
  def get_text_embeddings(
723
+ self,
724
+ prompt: str):
 
725
  r"""
726
  Computes the text embeddings provided a string with a prompts.
727
  Adapted from stable diffusion repo
729
  prompt: str
730
  ABC trending on artstation painted by Old Greg.
731
  """
 
732
  return self.sdh.get_text_embedding(prompt)
 
733
 
734
  def write_imgs_transition(self, dp_img):
735
  r"""
744
  for i, img in enumerate(imgs_transition):
745
  img_leaf = Image.fromarray(img)
746
  img_leaf.save(os.path.join(dp_img, f"lowres_img_{str(i).zfill(4)}.jpg"))
747
+ fp_yml = os.path.join(dp_img, "lowres.yaml")
 
748
  self.save_statedict(fp_yml)
749
+
750
  def write_movie_transition(self, fp_movie, duration_transition, fps=30):
751
  r"""
752
  Writes the transition movie to fp_movie, using the given duration and fps.
758
  duration of the movie in seconds
759
  fps: int
760
  fps of the movie
 
761
  """
762
+
763
  # Let's get more cheap frames via linear interpolation (duration_transition*fps frames)
764
  imgs_transition_ext = add_frames_linear_interp(self.tree_final_imgs, duration_transition, fps)
765
 
771
  ms.write_frame(img)
772
  ms.finalize()
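# Sketch for exporting results after run_transition (directory and file names
# are placeholders; assumes a LatentBlending instance `lb` with a finished tree):
lb.write_imgs_transition("transition_output")   # frames + lowres.yaml for upscaling
lb.write_movie_transition("movie.mp4", duration_transition=12, fps=30)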
773
 
 
 
774
  def save_statedict(self, fp_yml):
775
  # Dump everything relevant into yaml
776
  imgs_transition = self.tree_final_imgs
777
  state_dict = self.get_state_dict()
778
  state_dict['nmb_images'] = len(imgs_transition)
779
  yml_save(fp_yml, state_dict)
780
+
781
  def get_state_dict(self):
782
  state_dict = {}
783
  grab_vars = ['prompt1', 'prompt2', 'seed1', 'seed2', 'height', 'width',
791
  state_dict[v] = int(getattr(self, v))
792
  elif v == 'guidance_scale':
793
  state_dict[v] = float(getattr(self, v))
794
+
795
  else:
796
  try:
797
  state_dict[v] = getattr(self, v)
798
+ except Exception:
799
  pass
 
800
  return state_dict
801
+
802
  def randomize_seed(self):
803
  r"""
804
  Set a random seed for a fresh start.
805
+ """
806
  seed = np.random.randint(999999999)
807
  self.set_seed(seed)
808
+
809
  def set_seed(self, seed: int):
810
  r"""
811
  Set the seed for a fresh start.
812
+ """
813
  self.seed = seed
814
  self.sdh.seed = seed
815
+
816
  def set_width(self, width):
817
  r"""
818
  Set the width of the resulting image.
819
+ """
820
  assert np.mod(width, 64) == 0, "set_width: value needs to be divisible by 64"
821
  self.width = width
822
  self.sdh.width = width
823
+
824
  def set_height(self, height):
825
  r"""
826
  Set the height of the resulting image.
827
+ """
828
  assert np.mod(height, 64) == 0, "set_height: value needs to be divisible by 64"
829
  self.height = height
830
  self.sdh.height = height
 
831
 
832
  def swap_forward(self):
833
  r"""
834
  Moves over keyframe two -> keyframe one. Useful for making a sequence of transitions
835
  as in run_multi_transition()
836
+ """
837
  # Move over all latents
838
  self.tree_latents[0] = self.tree_latents[-1]
 
839
  # Move over prompts and text embeddings
840
  self.prompt1 = self.prompt2
841
  self.text_embedding1 = self.text_embedding2
 
842
  # Final cleanup for extra sanity
843
+ self.tree_final_imgs = []
844
+
 
845
  def get_lpips_similarity(self, imgA, imgB):
846
  r"""
847
+ Computes the image similarity between two images imgA and imgB.
848
  Used to determine the optimal point of insertion to create smooth transitions.
849
  High values indicate low similarity.
850
+ """
851
  tensorA = torch.from_numpy(imgA).float().cuda(self.device)
852
+ tensorA = 2 * tensorA / 255.0 - 1
853
+ tensorA = tensorA.permute([2, 0, 1]).unsqueeze(0)
 
854
  tensorB = torch.from_numpy(imgB).float().cuda(self.device)
855
+ tensorB = 2 * tensorB / 255.0 - 1
856
+ tensorB = tensorB.permute([2, 0, 1]).unsqueeze(0)
857
  lploss = self.lpips(tensorA, tensorB)
858
  lploss = float(lploss[0][0][0][0])
 
859
  return lploss
860
+
861
+ # Auxiliary functions
862
+ def get_closest_idx(
863
+ self,
864
+ fract_mixing: float):
865
+ r"""
866
+ Helper function to retrieve the parents for any given mixing.
867
+ Example: fract_mixing = 0.4 and self.tree_fracts = [0, 0.3, 0.6, 1.0]
868
+ Will return the indices of the two closest values, i.e. (1, 2)
869
+ """
870
+
871
+ pdist = fract_mixing - np.asarray(self.tree_fracts)
872
+ pdist_pos = pdist.copy()
873
+ pdist_pos[pdist_pos < 0] = np.inf
874
+ b_parent1 = np.argmin(pdist_pos)
875
+ pdist_neg = -pdist.copy()
876
+ pdist_neg[pdist_neg <= 0] = np.inf
877
+ b_parent2 = np.argmin(pdist_neg)
878
+
879
+ if b_parent1 > b_parent2:
880
+ tmp = b_parent2
881
+ b_parent2 = b_parent1
882
+ b_parent1 = tmp
883
+
884
+ return b_parent1, b_parent2
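# Worked example for get_closest_idx, matching the docstring above:
#   self.tree_fracts = [0, 0.3, 0.6, 1.0], fract_mixing = 0.4
#   pdist = [0.4, 0.1, -0.2, -0.6]  ->  closest below: index 1, closest above: index 2
#   returns (1, 2)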
 
 
 
 
 
 
 
 
 
movie_util.py CHANGED
@@ -1,5 +1,6 @@
1
  # Copyright 2022 Lunar Ring. All rights reserved.
2
- #
 
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
5
  # You may obtain a copy of the License at
@@ -17,26 +18,24 @@ import os
17
  import numpy as np
18
  from tqdm import tqdm
19
  import cv2
20
- from typing import Callable, List, Optional, Union
21
- import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
 
22
 
23
- #%%
24
-
25
  class MovieSaver():
26
  def __init__(
27
- self,
28
- fp_out: str,
29
- fps: int = 24,
30
  shape_hw: List[int] = None,
31
  crf: int = 24,
32
  codec: str = 'libx264',
33
- preset: str ='fast',
34
- pix_fmt: str = 'yuv420p',
35
- silent_ffmpeg: bool = True
36
- ):
37
  r"""
38
  Initializes movie saver class - a human friendly ffmpeg wrapper.
39
- After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
40
  Don't forget to finalize the movie file with moviesaver.finalize().
41
  Args:
42
  fp_out: str
@@ -47,22 +46,22 @@ class MovieSaver():
47
  Output shape, optional argument. Can be initialized automatically when first frame is written.
48
  crf: int
49
  ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
50
- (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
51
- A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
52
- Consider 17 or 18 to be visually lossless or nearly so;
53
- it should look the same or nearly the same as the input but it isn't technically lossless.
54
- The range is exponential, so increasing the CRF value +6 results in
55
- roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
56
  codec: str
57
  Codec used by ffmpeg for encoding, e.g. 'libx264'.
58
  preset: str
59
  Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
60
- ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
61
- to compression ratio. A slower preset will provide better compression
62
- (compression is quality per filesize).
63
- This means that, for example, if you target a certain file size or constant bit rate,
64
  you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
65
- you will simply save bitrate by choosing a slower preset.
66
  pix_fmt: str
67
  Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
68
  silent_ffmpeg: bool
@@ -70,7 +69,7 @@ class MovieSaver():
70
  """
71
  if len(os.path.split(fp_out)[0]) > 0:
72
  assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"
73
-
74
  self.fp_out = fp_out
75
  self.fps = fps
76
  self.crf = crf
@@ -78,10 +77,10 @@ class MovieSaver():
78
  self.codec = codec
79
  self.preset = preset
80
  self.silent_ffmpeg = silent_ffmpeg
81
-
82
  if os.path.isfile(fp_out):
83
  os.remove(fp_out)
84
-
85
  self.init_done = False
86
  self.nmb_frames = 0
87
  if shape_hw is None:
@@ -91,11 +90,9 @@ class MovieSaver():
91
  shape_hw.append(3)
92
  self.shape_hw = shape_hw
93
  self.initialize()
94
-
95
-
96
  print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
97
-
98
-
99
  def initialize(self):
100
  args = (
101
  ffmpeg
@@ -111,8 +108,7 @@ class MovieSaver():
111
  self.init_done = True
112
  self.shape_hw = tuple(self.shape_hw)
113
  print(f"Initialization done. Movie shape: {self.shape_hw}")
114
-
115
-
116
  def write_frame(self, out_frame: np.ndarray):
117
  r"""
118
  Function to dump a numpy array as frame of a movie.
@@ -123,18 +119,17 @@ class MovieSaver():
123
  Dim 1: x
124
  Dim 2: RGB
125
  """
126
-
127
  assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
128
  assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
129
  assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
130
-
131
  if not self.init_done:
132
  self.shape_hw = out_frame.shape
133
  self.initialize()
134
-
135
  assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"
136
 
137
- # write frame
138
  self.ffmpg_process.stdin.write(
139
  out_frame
140
  .astype(np.uint8)
@@ -142,8 +137,7 @@ class MovieSaver():
142
  )
143
 
144
  self.nmb_frames += 1
145
-
146
-
147
  def finalize(self):
148
  r"""
149
  Call this function to finalize the movie. If you forget to call it your movie will be garbage.
@@ -157,7 +151,6 @@ class MovieSaver():
157
  print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
158
 
159
 
160
-
161
  def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
162
  r"""
163
  Concatenate multiple movie segments into one long movie, using ffmpeg.
@@ -167,13 +160,13 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
167
  fp_final : str
168
  Full path of the final movie file. Should end with .mp4
169
  list_fp_movies : list[str]
170
- List of full paths of movie segments.
171
  """
172
  assert fp_final[-4] == ".", f"fp_final seems to miss file extension: {fp_final}"
173
  for fp in list_fp_movies:
174
  assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
175
  assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"
176
-
177
  if os.path.isfile(fp_final):
178
  os.remove(fp_final)
179
 
@@ -181,32 +174,32 @@ def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
181
  list_concat = []
182
  for fp_part in list_fp_movies:
183
  list_concat.append(f"""file '{fp_part}'""")
184
-
185
  # save this list
186
  fp_list = "tmp_move.txt"
187
  with open(fp_list, "w") as fa:
188
  for item in list_concat:
189
  fa.write("%s\n" % item)
190
-
191
  cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
192
- dp_movie = os.path.split(fp_final)[0]
193
  subprocess.call(cmd, shell=True)
194
  os.remove(fp_list)
195
  if os.path.isfile(fp_final):
196
  print(f"concatenate_movies: success! Watch here: {fp_final}")
197
 
198
-
199
  class MovieReader():
200
  r"""
201
  Class to read in a movie.
202
  """
 
203
  def __init__(self, fp_movie):
204
  self.video_player_object = cv2.VideoCapture(fp_movie)
205
  self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
206
  self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
207
- self.shape = [100,100,3]
208
  self.shape_is_set = False
209
-
210
  def get_next_frame(self):
211
  success, image = self.video_player_object.read()
212
  if success:
@@ -217,19 +210,18 @@ class MovieReader():
217
  else:
218
  return np.zeros(self.shape)
219
 
220
- #%%
221
- if __name__ == "__main__":
222
- fps=2
223
  list_fp_movies = []
224
  for k in range(4):
225
  fp_movie = f"/tmp/my_random_movie_{k}.mp4"
226
  list_fp_movies.append(fp_movie)
227
  ms = MovieSaver(fp_movie, fps=fps)
228
  for fn in tqdm(range(30)):
229
- img = (np.random.rand(512, 1024, 3)*255).astype(np.uint8)
230
  ms.write_frame(img)
231
  ms.finalize()
232
-
233
  fp_final = "/tmp/my_concatenated_movie.mp4"
234
  concatenate_movies(fp_final, list_fp_movies)
235
-
1
  # Copyright 2022 Lunar Ring. All rights reserved.
2
+ # Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
3
+
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
6
  # You may obtain a copy of the License at
18
  import numpy as np
19
  from tqdm import tqdm
20
  import cv2
21
+ from typing import List
22
+ import ffmpeg # pip install ffmpeg-python. if error with broken pipe: conda update ffmpeg
23
+
24
 
 
 
25
  class MovieSaver():
26
  def __init__(
27
+ self,
28
+ fp_out: str,
29
+ fps: int = 24,
30
  shape_hw: List[int] = None,
31
  crf: int = 24,
32
  codec: str = 'libx264',
33
+ preset: str = 'fast',
34
+ pix_fmt: str = 'yuv420p',
35
+ silent_ffmpeg: bool = True):
 
36
  r"""
37
  Initializes movie saver class - a human friendly ffmpeg wrapper.
38
+ After you init the class, you can dump numpy arrays x into moviesaver.write_frame(x).
39
  Don't forget to finalize the movie file with moviesaver.finalize().
40
  Args:
41
  fp_out: str
46
  Output shape, optional argument. Can be initialized automatically when first frame is written.
47
  crf: int
48
  ffmpeg doc: the range of the CRF scale is 0–51, where 0 is lossless
49
+ (for 8 bit only, for 10 bit use -qp 0), 23 is the default, and 51 is worst quality possible.
50
+ A lower value generally leads to higher quality, and a subjectively sane range is 17–28.
51
+ Consider 17 or 18 to be visually lossless or nearly so;
52
+ it should look the same or nearly the same as the input but it isn't technically lossless.
53
+ The range is exponential, so increasing the CRF value +6 results in
54
+ roughly half the bitrate / file size, while -6 leads to roughly twice the bitrate.
55
  codec: str
56
  Codec used by ffmpeg for encoding, e.g. 'libx264'.
57
  preset: str
58
  Choose between ultrafast, superfast, veryfast, faster, fast, medium, slow, slower, veryslow.
59
+ ffmpeg doc: A preset is a collection of options that will provide a certain encoding speed
60
+ to compression ratio. A slower preset will provide better compression
61
+ (compression is quality per filesize).
62
+ This means that, for example, if you target a certain file size or constant bit rate,
63
  you will achieve better quality with a slower preset. Similarly, for constant quality encoding,
64
+ you will simply save bitrate by choosing a slower preset.
65
  pix_fmt: str
66
  Pixel format. Run 'ffmpeg -pix_fmts' in your shell to see all options.
67
  silent_ffmpeg: bool
69
  """
70
  if len(os.path.split(fp_out)[0]) > 0:
71
  assert os.path.isdir(os.path.split(fp_out)[0]), "Directory does not exist!"
72
+
73
  self.fp_out = fp_out
74
  self.fps = fps
75
  self.crf = crf
77
  self.codec = codec
78
  self.preset = preset
79
  self.silent_ffmpeg = silent_ffmpeg
80
+
81
  if os.path.isfile(fp_out):
82
  os.remove(fp_out)
83
+
84
  self.init_done = False
85
  self.nmb_frames = 0
86
  if shape_hw is None:
90
  shape_hw.append(3)
91
  self.shape_hw = shape_hw
92
  self.initialize()
93
+
 
94
  print(f"MovieSaver initialized. fps={fps} crf={crf} pix_fmt={pix_fmt} codec={codec} preset={preset}")
95
+
 
96
  def initialize(self):
97
  args = (
98
  ffmpeg
108
  self.init_done = True
109
  self.shape_hw = tuple(self.shape_hw)
110
  print(f"Initialization done. Movie shape: {self.shape_hw}")
111
+
 
112
  def write_frame(self, out_frame: np.ndarray):
113
  r"""
114
  Function to dump a numpy array as frame of a movie.
119
  Dim 1: x
120
  Dim 2: RGB
121
  """
 
122
  assert out_frame.dtype == np.uint8, "Convert to np.uint8 before"
123
  assert len(out_frame.shape) == 3, "out_frame needs to be three dimensional, Y X C"
124
  assert out_frame.shape[2] == 3, f"need three color channels, but you provided {out_frame.shape[2]}."
125
+
126
  if not self.init_done:
127
  self.shape_hw = out_frame.shape
128
  self.initialize()
129
+
130
  assert self.shape_hw == out_frame.shape, f"You cannot change the image size after init. Initialized with {self.shape_hw}, out_frame {out_frame.shape}"
131
 
132
+ # write frame
133
  self.ffmpg_process.stdin.write(
134
  out_frame
135
  .astype(np.uint8)
137
  )
138
 
139
  self.nmb_frames += 1
140
+
 
141
  def finalize(self):
142
  r"""
143
  Call this function to finalize the movie. If you forget to call it your movie will be garbage.
151
  print(f"Movie saved, {duration}s playtime, watch here: \n{self.fp_out}")
152
 
153
 
 
154
  def concatenate_movies(fp_final: str, list_fp_movies: List[str]):
155
  r"""
156
  Concatenate multiple movie segments into one long movie, using ffmpeg.
160
  fp_final : str
161
  Full path of the final movie file. Should end with .mp4
162
  list_fp_movies : list[str]
163
+ List of full paths of movie segments.
164
  """
165
  assert fp_final[-4] == ".", f"fp_final seems to miss file extension: {fp_final}"
166
  for fp in list_fp_movies:
167
  assert os.path.isfile(fp), f"Input movie does not exist: {fp}"
168
  assert os.path.getsize(fp) > 100, f"Input movie seems empty: {fp}"
169
+
170
  if os.path.isfile(fp_final):
171
  os.remove(fp_final)
172
 
174
  list_concat = []
175
  for fp_part in list_fp_movies:
176
  list_concat.append(f"""file '{fp_part}'""")
177
+
178
  # save this list
179
  fp_list = "tmp_move.txt"
180
  with open(fp_list, "w") as fa:
181
  for item in list_concat:
182
  fa.write("%s\n" % item)
183
+
184
  cmd = f'ffmpeg -f concat -safe 0 -i {fp_list} -c copy {fp_final}'
 
185
  subprocess.call(cmd, shell=True)
186
  os.remove(fp_list)
187
  if os.path.isfile(fp_final):
188
  print(f"concatenate_movies: success! Watch here: {fp_final}")
189
 
190
+
191
  class MovieReader():
192
  r"""
193
  Class to read in a movie.
194
  """
195
+
196
  def __init__(self, fp_movie):
197
  self.video_player_object = cv2.VideoCapture(fp_movie)
198
  self.nmb_frames = int(self.video_player_object.get(cv2.CAP_PROP_FRAME_COUNT))
199
  self.fps_movie = int(self.video_player_object.get(cv2.CAP_PROP_FPS))
200
+ self.shape = [100, 100, 3]
201
  self.shape_is_set = False
202
+
203
  def get_next_frame(self):
204
  success, image = self.video_player_object.read()
205
  if success:
210
  else:
211
  return np.zeros(self.shape)
212
 
213
+
214
+ if __name__ == "__main__":
215
+ fps = 2
216
  list_fp_movies = []
217
  for k in range(4):
218
  fp_movie = f"/tmp/my_random_movie_{k}.mp4"
219
  list_fp_movies.append(fp_movie)
220
  ms = MovieSaver(fp_movie, fps=fps)
221
  for fn in tqdm(range(30)):
222
+ img = (np.random.rand(512, 1024, 3) * 255).astype(np.uint8)
223
  ms.write_frame(img)
224
  ms.finalize()
225
+
226
  fp_final = "/tmp/my_concatenated_movie.mp4"
227
  concatenate_movies(fp_final, list_fp_movies)
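# Sketch for reading the frames back with MovieReader (assumes the movie
# written above exists):
mr = MovieReader(fp_final)
for _ in range(mr.nmb_frames):
    frame = mr.get_next_frame()  # numpy array; shape is set after the first frame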
 
stable_diffusion_holder.py CHANGED
@@ -13,36 +13,25 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
- import os, sys
17
- dp_git = "/home/lugo/git/"
18
- sys.path.append(os.path.join(dp_git,'garden4'))
19
- sys.path.append('util')
20
  import torch
21
  torch.backends.cudnn.benchmark = False
 
22
  import numpy as np
23
  import warnings
24
  warnings.filterwarnings('ignore')
25
- import time
26
- import subprocess
27
  import warnings
28
  import torch
29
- from tqdm.auto import tqdm
30
  from PIL import Image
31
- # import matplotlib.pyplot as plt
32
  import torch
33
- from movie_util import MovieSaver
34
- import datetime
35
- from typing import Callable, List, Optional, Union
36
- import inspect
37
- from threading import Thread
38
- torch.set_grad_enabled(False)
39
  from omegaconf import OmegaConf
40
  from torch import autocast
41
  from contextlib import nullcontext
42
  from ldm.util import instantiate_from_config
43
  from ldm.models.diffusion.ddim import DDIMSampler
44
  from einops import repeat, rearrange
45
- #%%
46
 
47
 
48
  def pad_image(input_image):
@@ -53,41 +42,11 @@ def pad_image(input_image):
53
  return im_padded
54
 
55
 
56
-
57
- def make_batch_inpaint(
58
- image,
59
- mask,
60
- txt,
61
- device,
62
- num_samples=1):
63
- image = np.array(image.convert("RGB"))
64
- image = image[None].transpose(0, 3, 1, 2)
65
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
66
-
67
- mask = np.array(mask.convert("L"))
68
- mask = mask.astype(np.float32) / 255.0
69
- mask = mask[None, None]
70
- mask[mask < 0.5] = 0
71
- mask[mask >= 0.5] = 1
72
- mask = torch.from_numpy(mask)
73
-
74
- masked_image = image * (mask < 0.5)
75
-
76
- batch = {
77
- "image": repeat(image.to(device=device), "1 ... -> n ...", n=num_samples),
78
- "txt": num_samples * [txt],
79
- "mask": repeat(mask.to(device=device), "1 ... -> n ...", n=num_samples),
80
- "masked_image": repeat(masked_image.to(device=device), "1 ... -> n ...", n=num_samples),
81
- }
82
- return batch
83
-
84
-
85
  def make_batch_superres(
86
  image,
87
  txt,
88
  device,
89
- num_samples=1,
90
- ):
91
  image = np.array(image.convert("RGB"))
92
  image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
93
  batch = {
@@ -107,14 +66,14 @@ def make_noise_augmentation(model, batch, noise_level=None):
107
 
108
 
109
  class StableDiffusionHolder:
110
- def __init__(self,
111
- fp_ckpt: str = None,
112
  fp_config: str = None,
113
- num_inference_steps: int = 30,
114
  height: Optional[int] = None,
115
  width: Optional[int] = None,
116
  device: str = None,
117
- precision: str='autocast',
118
  ):
119
  r"""
120
  Initializes the stable diffusion holder, which contains the models and sampler.
@@ -122,26 +81,26 @@ class StableDiffusionHolder:
122
  fp_ckpt: File pointer to the .ckpt model file
123
  fp_config: File pointer to the .yaml config file
124
  num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
125
- height: Height of the resulting image.
126
- width: Width of the resulting image.
127
  device: Device to run the model on.
128
  precision: Precision to run the model on.
129
  """
130
  self.seed = 42
131
  self.guidance_scale = 5.0
132
-
133
  if device is None:
134
  self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
135
  else:
136
  self.device = device
137
  self.precision = precision
138
  self.init_model(fp_ckpt, fp_config)
139
-
140
- self.f = 8 #downsampling factor, most often 8 or 16",
141
  self.C = 4
142
  self.ddim_eta = 0
143
  self.num_inference_steps = num_inference_steps
144
-
145
  if height is None and width is None:
146
  self.init_auto_res()
147
  else:
@@ -149,53 +108,44 @@ class StableDiffusionHolder:
149
  assert width is not None, "specify both width and height"
150
  self.height = height
151
  self.width = width
152
-
153
- # Inpainting inits
154
- self.mask_empty = Image.fromarray(255*np.ones([self.width, self.height], dtype=np.uint8))
155
- self.image_empty = Image.fromarray(np.zeros([self.width, self.height, 3], dtype=np.uint8))
156
-
157
  self.negative_prompt = [""]
158
-
159
-
160
  def init_model(self, fp_ckpt, fp_config):
161
  r"""Loads the models and sampler.
162
  """
163
 
164
  assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
165
  self.fp_ckpt = fp_ckpt
166
-
167
  # Auto init the config?
168
  if fp_config is None:
169
  fn_ckpt = os.path.basename(fp_ckpt)
170
  if 'depth' in fn_ckpt:
171
  fp_config = 'configs/v2-midas-inference.yaml'
172
- elif 'inpain' in fn_ckpt:
173
- fp_config = 'configs/v2-inpainting-inference.yaml'
174
  elif 'upscaler' in fn_ckpt:
175
- fp_config = 'configs/x4-upscaling.yaml'
176
  elif '512' in fn_ckpt:
177
- fp_config = 'configs/v2-inference.yaml'
178
- elif '768'in fn_ckpt:
179
- fp_config = 'configs/v2-inference-v.yaml'
180
  elif 'v1-5' in fn_ckpt:
181
- fp_config = 'configs/v1-inference.yaml'
182
  else:
183
  raise ValueError("auto detect of config failed. please specify fp_config manually!")
184
-
185
  assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
186
-
187
  assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
188
-
189
 
190
  config = OmegaConf.load(fp_config)
191
-
192
  self.model = instantiate_from_config(config.model)
193
  self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
194
 
195
  self.model = self.model.to(self.device)
196
  self.sampler = DDIMSampler(self.model)
197
-
198
-
199
  def init_auto_res(self):
200
  r"""Automatically set the resolution to the one used in training.
201
  """
@@ -205,7 +155,7 @@ class StableDiffusionHolder:
205
  else:
206
  self.height = 512
207
  self.width = 512
208
-
209
  def set_negative_prompt(self, negative_prompt):
210
  r"""Set the negative prompt. Currenty only one negative prompt is supported
211
  """
@@ -214,51 +164,46 @@ class StableDiffusionHolder:
214
  self.negative_prompt = [negative_prompt]
215
  else:
216
  self.negative_prompt = negative_prompt
217
-
218
  if len(self.negative_prompt) > 1:
219
  self.negative_prompt = [self.negative_prompt[0]]
220
 
221
-
222
  def get_text_embedding(self, prompt):
223
  c = self.model.get_learned_conditioning(prompt)
224
  return c
225
-
226
  @torch.no_grad()
227
  def get_cond_upscaling(self, image, text_embedding, noise_level):
228
  r"""
229
  Initializes the conditioning for the x4 upscaling model.
230
  """
231
-
232
  image = pad_image(image) # resize to integer multiple of 32
233
  w, h = image.size
234
  noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
235
  batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
236
 
237
  x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
238
-
239
  cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
240
  # uncond cond
241
  uc_cross = self.model.get_unconditional_conditioning(1, "")
242
  uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
243
-
244
  return cond, uc_full
245
 
246
  @torch.no_grad()
247
  def run_diffusion_standard(
248
- self,
249
- text_embeddings: torch.FloatTensor,
250
  latents_start: torch.FloatTensor,
251
- idx_start: int = 0,
252
- list_latents_mixing = None,
253
- mixing_coeffs = 0.0,
254
- spatial_mask = None,
255
- return_image: Optional[bool] = False,
256
- ):
257
  r"""
258
- Diffusion standard version.
259
-
260
  Args:
261
- text_embeddings: torch.FloatTensor
262
  Text embeddings used for diffusion
263
  latents_for_injection: torch.FloatTensor or list
264
  Latents that are used for injection
@@ -270,41 +215,32 @@ class StableDiffusionHolder:
270
  experimental feature for enforcing pixels from list_latents_mixing
271
  return_image: Optional[bool]
272
  Optionally return image directly
273
-
274
  """
275
-
276
  # Asserts
277
  if type(mixing_coeffs) == float:
278
- list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
279
  elif type(mixing_coeffs) == list:
280
  assert len(mixing_coeffs) == self.num_inference_steps
281
  list_mixing_coeffs = mixing_coeffs
282
  else:
283
  raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
284
-
285
  if np.sum(list_mixing_coeffs) > 0:
286
  assert len(list_latents_mixing) == self.num_inference_steps
287
-
288
-
289
  precision_scope = autocast if self.precision == "autocast" else nullcontext
290
-
291
  with precision_scope("cuda"):
292
  with self.model.ema_scope():
293
  if self.guidance_scale != 1.0:
294
  uc = self.model.get_learned_conditioning(self.negative_prompt)
295
  else:
296
  uc = None
297
-
298
- self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
299
-
300
  latents = latents_start.clone()
301
-
302
  timesteps = self.sampler.ddim_timesteps
303
-
304
  time_range = np.flip(timesteps)
305
  total_steps = timesteps.shape[0]
306
-
307
- # collect latents
308
  list_latents_out = []
309
  for i, step in enumerate(time_range):
310
  # Set the right starting latents
@@ -313,83 +249,71 @@ class StableDiffusionHolder:
313
  continue
314
  elif i == idx_start:
315
  latents = latents_start.clone()
316
-
317
- # Mix the latents.
318
- if i > 0 and list_mixing_coeffs[i]>0:
319
- latents_mixtarget = list_latents_mixing[i-1].clone()
320
  latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
321
-
322
  if spatial_mask is not None and list_latents_mixing is not None:
323
- latents = interpolate_spherical(latents, list_latents_mixing[i-1], 1-spatial_mask)
324
- # latents[:,:,-15:,:] = latents_mixtarget[:,:,-15:,:]
325
-
326
  index = total_steps - i - 1
327
  ts = torch.full((1,), step, device=self.device, dtype=torch.long)
328
  outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
329
- quantize_denoised=False, temperature=1.0,
330
- noise_dropout=0.0, score_corrector=None,
331
- corrector_kwargs=None,
332
- unconditional_guidance_scale=self.guidance_scale,
333
- unconditional_conditioning=uc,
334
- dynamic_threshold=None)
335
  latents, pred_x0 = outs
336
  list_latents_out.append(latents.clone())
337
-
338
- if return_image:
339
  return self.latent2image(latents)
340
  else:
341
  return list_latents_out
342
-
343
-
344
  @torch.no_grad()
345
  def run_diffusion_upscaling(
346
- self,
347
  cond,
348
  uc_full,
349
- latents_start: torch.FloatTensor,
350
- idx_start: int = -1,
351
- list_latents_mixing = None,
352
- mixing_coeffs = 0.0,
353
- return_image: Optional[bool] = False
354
- ):
355
  r"""
356
- Diffusion upscaling version.
357
  """
358
-
359
  # Asserts
360
  if type(mixing_coeffs) == float:
361
- list_mixing_coeffs = self.num_inference_steps*[mixing_coeffs]
362
  elif type(mixing_coeffs) == list:
363
  assert len(mixing_coeffs) == self.num_inference_steps
364
  list_mixing_coeffs = mixing_coeffs
365
  else:
366
  raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
367
-
368
  if np.sum(list_mixing_coeffs) > 0:
369
  assert len(list_latents_mixing) == self.num_inference_steps
370
-
371
  precision_scope = autocast if self.precision == "autocast" else nullcontext
372
-
373
- h = uc_full['c_concat'][0].shape[2]
374
- w = uc_full['c_concat'][0].shape[3]
375
-
376
  with precision_scope("cuda"):
377
  with self.model.ema_scope():
378
 
379
  shape_latents = [self.model.channels, h, w]
380
-
381
- self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=self.ddim_eta, verbose=False)
382
  C, H, W = shape_latents
383
  size = (1, C, H, W)
384
  b = size[0]
385
-
386
  latents = latents_start.clone()
387
-
388
  timesteps = self.sampler.ddim_timesteps
389
-
390
  time_range = np.flip(timesteps)
391
  total_steps = timesteps.shape[0]
392
-
393
  # collect latents
394
  list_latents_out = []
395
  for i, step in enumerate(time_range):
@@ -399,232 +323,40 @@ class StableDiffusionHolder:
399
  continue
400
  elif i == idx_start:
401
  latents = latents_start.clone()
402
-
403
- # Mix the latents.
404
- if i > 0 and list_mixing_coeffs[i]>0:
405
- latents_mixtarget = list_latents_mixing[i-1].clone()
406
  latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
407
-
408
  # print(f"diffusion iter {i}")
409
  index = total_steps - i - 1
410
  ts = torch.full((b,), step, device=self.device, dtype=torch.long)
411
  outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
412
- quantize_denoised=False, temperature=1.0,
413
- noise_dropout=0.0, score_corrector=None,
414
- corrector_kwargs=None,
415
- unconditional_guidance_scale=self.guidance_scale,
416
- unconditional_conditioning=uc_full,
417
- dynamic_threshold=None)
418
  latents, pred_x0 = outs
419
  list_latents_out.append(latents.clone())
420
-
421
- if return_image:
422
- return self.latent2image(latents)
423
- else:
424
- return list_latents_out
425
-
426
- @torch.no_grad()
427
- def run_diffusion_inpaint(
428
- self,
429
- text_embeddings: torch.FloatTensor,
430
- latents_for_injection: torch.FloatTensor = None,
431
- idx_start: int = -1,
432
- idx_stop: int = -1,
433
- return_image: Optional[bool] = False
434
- ):
435
- r"""
436
- Runs inpaint-based diffusion. Returns a list of latents that were computed.
437
- Adaptations allow to supply
438
- a) starting index for diffusion
439
- b) stopping index for diffusion
440
- c) latent representations that are injected at the starting index
441
- Furthermore the intermittent latents are collected and returned.
442
-
443
- Adapted from diffusers (https://github.com/huggingface/diffusers)
444
- Args:
445
- text_embeddings: torch.FloatTensor
446
- Text embeddings used for diffusion
447
- latents_for_injection: torch.FloatTensor
448
- Latents that are used for injection
449
- idx_start: int
450
- Index of the diffusion process start and where the latents_for_injection are injected
451
- idx_stop: int
452
- Index of the diffusion process end.
453
- return_image: Optional[bool]
454
- Optionally return image directly
455
-
456
- """
457
-
458
- if latents_for_injection is None:
459
- do_inject_latents = False
460
- else:
461
- do_inject_latents = True
462
-
463
- precision_scope = autocast if self.precision == "autocast" else nullcontext
464
- generator = torch.Generator(device=self.device).manual_seed(int(self.seed))
465
 
466
- with precision_scope("cuda"):
467
- with self.model.ema_scope():
468
-
469
- batch = make_batch_inpaint(self.image_source, self.mask_image, txt="willbereplaced", device=self.device, num_samples=1)
470
- c = text_embeddings
471
- c_cat = list()
472
- for ck in self.model.concat_keys:
473
- cc = batch[ck].float()
474
- if ck != self.model.masked_image_key:
475
- bchw = [1, 4, self.height // 8, self.width // 8]
476
- cc = torch.nn.functional.interpolate(cc, size=bchw[-2:])
477
- else:
478
- cc = self.model.get_first_stage_encoding(self.model.encode_first_stage(cc))
479
- c_cat.append(cc)
480
- c_cat = torch.cat(c_cat, dim=1)
481
-
482
- # cond
483
- cond = {"c_concat": [c_cat], "c_crossattn": [c]}
484
-
485
- # uncond cond
486
- uc_cross = self.model.get_unconditional_conditioning(1, "")
487
- uc_full = {"c_concat": [c_cat], "c_crossattn": [uc_cross]}
488
-
489
- shape_latents = [self.model.channels, self.height // 8, self.width // 8]
490
-
491
- self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps-1, ddim_eta=0., verbose=False)
492
- # sampling
493
- C, H, W = shape_latents
494
- size = (1, C, H, W)
495
-
496
- device = self.model.betas.device
497
- b = size[0]
498
- latents = torch.randn(size, generator=generator, device=device)
499
-
500
- timesteps = self.sampler.ddim_timesteps
501
-
502
- time_range = np.flip(timesteps)
503
- total_steps = timesteps.shape[0]
504
-
505
- # collect latents
506
- list_latents_out = []
507
- for i, step in enumerate(time_range):
508
- if do_inject_latents:
509
- # Inject latent at right place
510
- if i < idx_start:
511
- continue
512
- elif i == idx_start:
513
- latents = latents_for_injection.clone()
514
-
515
- if i == idx_stop:
516
- return list_latents_out
517
-
518
- index = total_steps - i - 1
519
- ts = torch.full((b,), step, device=device, dtype=torch.long)
520
-
521
- outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
522
- quantize_denoised=False, temperature=1.0,
523
- noise_dropout=0.0, score_corrector=None,
524
- corrector_kwargs=None,
525
- unconditional_guidance_scale=self.guidance_scale,
526
- unconditional_conditioning=uc_full,
527
- dynamic_threshold=None)
528
- latents, pred_x0 = outs
529
- list_latents_out.append(latents.clone())
530
-
531
- if return_image:
532
  return self.latent2image(latents)
533
  else:
534
  return list_latents_out
535
 
536
  @torch.no_grad()
537
  def latent2image(
538
- self,
539
- latents: torch.FloatTensor
540
- ):
541
  r"""
542
  Returns an image provided a latent representation from diffusion.
543
  Args:
544
  latents: torch.FloatTensor
545
- Result of the diffusion process.
546
  """
547
  x_sample = self.model.decode_first_stage(latents)
548
  x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
549
- x_sample = 255 * x_sample[0,:,:].permute([1,2,0]).cpu().numpy()
550
  image = x_sample.astype(np.uint8)
551
  return image
552
-
553
- @torch.no_grad()
554
- def interpolate_spherical(p0, p1, fract_mixing: float):
555
- r"""
556
- Helper function to correctly mix two random variables using spherical interpolation.
557
- See https://en.wikipedia.org/wiki/Slerp
558
- The function will always cast up to float64 for sake of extra 4.
559
- Args:
560
- p0:
561
- First tensor for interpolation
562
- p1:
563
- Second tensor for interpolation
564
- fract_mixing: float
565
- Mixing coefficient of interval [0, 1].
566
- 0 will return in p0
567
- 1 will return in p1
568
- 0.x will return a mix between both preserving angular velocity.
569
- """
570
-
571
- if p0.dtype == torch.float16:
572
- recast_to = 'fp16'
573
- else:
574
- recast_to = 'fp32'
575
-
576
- p0 = p0.double()
577
- p1 = p1.double()
578
- norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
579
- epsilon = 1e-7
580
- dot = torch.sum(p0 * p1) / norm
581
- dot = dot.clamp(-1+epsilon, 1-epsilon)
582
-
583
- theta_0 = torch.arccos(dot)
584
- sin_theta_0 = torch.sin(theta_0)
585
- theta_t = theta_0 * fract_mixing
586
- s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
587
- s1 = torch.sin(theta_t) / sin_theta_0
588
- interp = p0*s0 + p1*s1
589
-
590
- if recast_to == 'fp16':
591
- interp = interp.half()
592
- elif recast_to == 'fp32':
593
- interp = interp.float()
594
-
595
- return interp
596
-
597
-
598
- if __name__ == "__main__":
599
-
600
-
601
-
602
-
603
-
604
-
605
- num_inference_steps = 20 # Number of diffusion interations
606
-
607
- # fp_ckpt = "../stable_diffusion_models/ckpt/768-v-ema.ckpt"
608
- # fp_config = '../stablediffusion/configs/stable-diffusion/v2-inference-v.yaml'
609
-
610
- # fp_ckpt= "../stable_diffusion_models/ckpt/512-inpainting-ema.ckpt"
611
- # fp_config = '../stablediffusion/configs//stable-diffusion/v2-inpainting-inference.yaml'
612
-
613
- fp_ckpt = "../stable_diffusion_models/ckpt/v2-1_768-ema-pruned.ckpt"
614
- # fp_config = 'configs/v2-inference-v.yaml'
615
-
616
-
617
- self = StableDiffusionHolder(fp_ckpt, num_inference_steps=num_inference_steps)
618
-
619
- xxx
620
-
621
- #%%
622
- self.width = 1536
623
- self.height = 768
624
- prompt = "360 degree equirectangular, a huge rocky hill full of pianos and keyboards, musical instruments, cinematic, masterpiece 8 k, artstation"
625
- self.set_negative_prompt("out of frame, faces, rendering, blurry")
626
- te = self.get_text_embedding(prompt)
627
-
628
- img = self.run_diffusion_standard(te, return_image=True)
629
- Image.fromarray(img).show()
630
-
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
 
16
+ import os
 
 
 
17
  import torch
18
  torch.backends.cudnn.benchmark = False
19
+ torch.set_grad_enabled(False)
20
  import numpy as np
21
  import warnings
22
  warnings.filterwarnings('ignore')
 
 
23
  import warnings
24
  import torch
 
25
  from PIL import Image
 
26
  import torch
27
+ from typing import Optional
28
  from omegaconf import OmegaConf
29
  from torch import autocast
30
  from contextlib import nullcontext
31
  from ldm.util import instantiate_from_config
32
  from ldm.models.diffusion.ddim import DDIMSampler
33
  from einops import repeat, rearrange
34
+ from utils import interpolate_spherical
35
 
36
 
37
  def pad_image(input_image):
42
  return im_padded
43
 
44
45
  def make_batch_superres(
46
  image,
47
  txt,
48
  device,
49
+ num_samples=1):
 
50
  image = np.array(image.convert("RGB"))
51
  image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
52
  batch = {
66
 
67
 
68
  class StableDiffusionHolder:
69
+ def __init__(self,
70
+ fp_ckpt: str = None,
71
  fp_config: str = None,
72
+ num_inference_steps: int = 30,
73
  height: Optional[int] = None,
74
  width: Optional[int] = None,
75
  device: str = None,
76
+ precision: str = 'autocast',
77
  ):
78
  r"""
79
  Initializes the stable diffusion holder, which contains the models and sampler.
81
  fp_ckpt: File pointer to the .ckpt model file
82
  fp_config: File pointer to the .yaml config file
83
  num_inference_steps: Number of diffusion iterations. Will be overwritten by latent blending.
84
+ height: Height of the resulting image.
85
+ width: Width of the resulting image.
86
  device: Device to run the model on.
87
  precision: Precision to run the model on.
88
  """
89
  self.seed = 42
90
  self.guidance_scale = 5.0
91
+
92
  if device is None:
93
  self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
94
  else:
95
  self.device = device
96
  self.precision = precision
97
  self.init_model(fp_ckpt, fp_config)
98
+
99
+ self.f = 8  # downsampling factor, most often 8 or 16
100
  self.C = 4
101
  self.ddim_eta = 0
102
  self.num_inference_steps = num_inference_steps
103
+
104
  if height is None and width is None:
105
  self.init_auto_res()
106
  else:
108
  assert width is not None, "specify both width and height"
109
  self.height = height
110
  self.width = width
111
+
112
  self.negative_prompt = [""]
113
+
 
114
  def init_model(self, fp_ckpt, fp_config):
115
  r"""Loads the models and sampler.
116
  """
117
 
118
  assert os.path.isfile(fp_ckpt), f"Your model checkpoint file does not exist: {fp_ckpt}"
119
  self.fp_ckpt = fp_ckpt
120
+
121
  # Auto init the config?
122
  if fp_config is None:
123
  fn_ckpt = os.path.basename(fp_ckpt)
124
  if 'depth' in fn_ckpt:
125
  fp_config = 'configs/v2-midas-inference.yaml'
 
 
126
  elif 'upscaler' in fn_ckpt:
127
+ fp_config = 'configs/x4-upscaling.yaml'
128
  elif '512' in fn_ckpt:
129
+ fp_config = 'configs/v2-inference.yaml'
130
+ elif '768' in fn_ckpt:
131
+ fp_config = 'configs/v2-inference-v.yaml'
132
  elif 'v1-5' in fn_ckpt:
133
+ fp_config = 'configs/v1-inference.yaml'
134
  else:
135
  raise ValueError("auto detect of config failed. please specify fp_config manually!")
136
+
137
  assert os.path.isfile(fp_config), "Auto-init of the config file failed. Please specify manually."
138
+
139
  assert os.path.isfile(fp_config), f"Your config file does not exist: {fp_config}"
 
140
 
141
  config = OmegaConf.load(fp_config)
142
+
143
  self.model = instantiate_from_config(config.model)
144
  self.model.load_state_dict(torch.load(fp_ckpt)["state_dict"], strict=False)
145
 
146
  self.model = self.model.to(self.device)
147
  self.sampler = DDIMSampler(self.model)
148
+
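A minimal usage sketch of the config auto-detection above (checkpoint paths are placeholders, not part of this commit): if the checkpoint filename contains none of the recognized markers ('depth', 'upscaler', '512', '768', 'v1-5'), fp_config has to be passed explicitly.

# Auto-detected config: works because the filename contains "768"
sdh = StableDiffusionHolder(fp_ckpt="checkpoints/v2-1_768-ema-pruned.ckpt")

# Explicit config: required when the filename carries no recognizable marker
sdh = StableDiffusionHolder(
    fp_ckpt="checkpoints/my_finetune.ckpt",
    fp_config="configs/v2-inference-v.yaml")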
 
149
  def init_auto_res(self):
150
  r"""Automatically set the resolution to the one used in training.
151
  """
155
  else:
156
  self.height = 512
157
  self.width = 512
158
+
159
  def set_negative_prompt(self, negative_prompt):
160
  r"""Set the negative prompt. Currenty only one negative prompt is supported
161
  """
164
  self.negative_prompt = [negative_prompt]
165
  else:
166
  self.negative_prompt = negative_prompt
167
+
168
  if len(self.negative_prompt) > 1:
169
  self.negative_prompt = [self.negative_prompt[0]]
170
 
 
171
  def get_text_embedding(self, prompt):
172
  c = self.model.get_learned_conditioning(prompt)
173
  return c
174
+
175
  @torch.no_grad()
176
  def get_cond_upscaling(self, image, text_embedding, noise_level):
177
  r"""
178
  Initializes the conditioning for the x4 upscaling model.
179
  """
 
180
  image = pad_image(image) # resize to integer multiple of 32
181
  w, h = image.size
182
  noise_level = torch.Tensor(1 * [noise_level]).to(self.sampler.model.device).long()
183
  batch = make_batch_superres(image, txt="placeholder", device=self.device, num_samples=1)
184
 
185
  x_augment, noise_level = make_noise_augmentation(self.model, batch, noise_level)
186
+
187
  cond = {"c_concat": [x_augment], "c_crossattn": [text_embedding], "c_adm": noise_level}
188
  # uncond cond
189
  uc_cross = self.model.get_unconditional_conditioning(1, "")
190
  uc_full = {"c_concat": [x_augment], "c_crossattn": [uc_cross], "c_adm": noise_level}
 
191
  return cond, uc_full
192
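A hedged sketch of how this conditioning is typically consumed together with run_diffusion_upscaling below, assuming an x4-upscaler checkpoint is loaded into an instance sdh; the image path and noise level are illustrative only.

low_res = Image.open("example_lowres.jpg")  # placeholder input image
text_embedding = sdh.get_text_embedding("a photo, highly detailed")
cond, uc_full = sdh.get_cond_upscaling(low_res, text_embedding, noise_level=20)

# latent shape follows the concatenated conditioning, as in run_diffusion_upscaling
h, w = uc_full['c_concat'][0].shape[2], uc_full['c_concat'][0].shape[3]
latents0 = torch.randn(1, sdh.model.channels, h, w, device=sdh.device)
list_latents = sdh.run_diffusion_upscaling(cond, uc_full, latents0, idx_start=0)
img_upscaled = sdh.latent2image(list_latents[-1])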
 
193
  @torch.no_grad()
194
  def run_diffusion_standard(
195
+ self,
196
+ text_embeddings: torch.FloatTensor,
197
  latents_start: torch.FloatTensor,
198
+ idx_start: int = 0,
199
+ list_latents_mixing=None,
200
+ mixing_coeffs=0.0,
201
+ spatial_mask=None,
202
+ return_image: Optional[bool] = False):
 
203
  r"""
204
+ Diffusion standard version.
 
205
  Args:
206
+ text_embeddings: torch.FloatTensor
207
  Text embeddings used for diffusion
208
  latents_for_injection: torch.FloatTensor or list
209
  Latents that are used for injection
215
  experimental feature for enforcing pixels from list_latents_mixing
216
  return_image: Optional[bool]
217
  Optionally return image directly
 
218
  """
 
219
  # Asserts
220
  if type(mixing_coeffs) == float:
221
+ list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
222
  elif type(mixing_coeffs) == list:
223
  assert len(mixing_coeffs) == self.num_inference_steps
224
  list_mixing_coeffs = mixing_coeffs
225
  else:
226
  raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
227
+
228
  if np.sum(list_mixing_coeffs) > 0:
229
  assert len(list_latents_mixing) == self.num_inference_steps
230
+
 
231
  precision_scope = autocast if self.precision == "autocast" else nullcontext
 
232
  with precision_scope("cuda"):
233
  with self.model.ema_scope():
234
  if self.guidance_scale != 1.0:
235
  uc = self.model.get_learned_conditioning(self.negative_prompt)
236
  else:
237
  uc = None
238
+ self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
 
 
239
  latents = latents_start.clone()
 
240
  timesteps = self.sampler.ddim_timesteps
 
241
  time_range = np.flip(timesteps)
242
  total_steps = timesteps.shape[0]
243
+ # Collect latents
 
244
  list_latents_out = []
245
  for i, step in enumerate(time_range):
246
  # Set the right starting latents
249
  continue
250
  elif i == idx_start:
251
  latents = latents_start.clone()
252
+ # Mix latents
253
+ if i > 0 and list_mixing_coeffs[i] > 0:
254
+ latents_mixtarget = list_latents_mixing[i - 1].clone()
 
255
  latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
256
+
257
  if spatial_mask is not None and list_latents_mixing is not None:
258
+ latents = interpolate_spherical(latents, list_latents_mixing[i - 1], 1 - spatial_mask)
259
+
 
260
  index = total_steps - i - 1
261
  ts = torch.full((1,), step, device=self.device, dtype=torch.long)
262
  outs = self.sampler.p_sample_ddim(latents, text_embeddings, ts, index=index, use_original_steps=False,
263
+ quantize_denoised=False, temperature=1.0,
264
+ noise_dropout=0.0, score_corrector=None,
265
+ corrector_kwargs=None,
266
+ unconditional_guidance_scale=self.guidance_scale,
267
+ unconditional_conditioning=uc,
268
+ dynamic_threshold=None)
269
  latents, pred_x0 = outs
270
  list_latents_out.append(latents.clone())
271
+ if return_image:
 
272
  return self.latent2image(latents)
273
  else:
274
  return list_latents_out
275
+
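A hedged usage sketch of the injection mechanism above, assuming an instance sdh of the holder (prompt and step index are illustrative): the method returns the whole latent trajectory, and a stored intermediate latent can later be re-injected via idx_start to resume diffusion from that step.

te = sdh.get_text_embedding(["a photo of a forest at dawn"])
latents0 = torch.randn(1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f, device=sdh.device)

trajectory = sdh.run_diffusion_standard(te, latents0)      # one latent per diffusion step
img = sdh.latent2image(trajectory[-1])

# Resume from step 10, re-using the latent that was computed there
resumed = sdh.run_diffusion_standard(te, trajectory[9], idx_start=10)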
 
276
  @torch.no_grad()
277
  def run_diffusion_upscaling(
278
+ self,
279
  cond,
280
  uc_full,
281
+ latents_start: torch.FloatTensor,
282
+ idx_start: int = -1,
283
+ list_latents_mixing: list = None,
284
+ mixing_coeffs: float = 0.0,
285
+ return_image: Optional[bool] = False):
 
286
  r"""
287
+ Diffusion upscaling version.
288
  """
289
+
290
  # Asserts
291
  if type(mixing_coeffs) == float:
292
+ list_mixing_coeffs = self.num_inference_steps * [mixing_coeffs]
293
  elif type(mixing_coeffs) == list:
294
  assert len(mixing_coeffs) == self.num_inference_steps
295
  list_mixing_coeffs = mixing_coeffs
296
  else:
297
  raise ValueError("mixing_coeffs should be float or list with len=num_inference_steps")
298
+
299
  if np.sum(list_mixing_coeffs) > 0:
300
  assert len(list_latents_mixing) == self.num_inference_steps
301
+
302
  precision_scope = autocast if self.precision == "autocast" else nullcontext
303
+ h = uc_full['c_concat'][0].shape[2]
304
+ w = uc_full['c_concat'][0].shape[3]
 
 
305
  with precision_scope("cuda"):
306
  with self.model.ema_scope():
307
 
308
  shape_latents = [self.model.channels, h, w]
309
+ self.sampler.make_schedule(ddim_num_steps=self.num_inference_steps - 1, ddim_eta=self.ddim_eta, verbose=False)
 
310
  C, H, W = shape_latents
311
  size = (1, C, H, W)
312
  b = size[0]
 
313
  latents = latents_start.clone()
 
314
  timesteps = self.sampler.ddim_timesteps
 
315
  time_range = np.flip(timesteps)
316
  total_steps = timesteps.shape[0]
 
317
  # collect latents
318
  list_latents_out = []
319
  for i, step in enumerate(time_range):
323
  continue
324
  elif i == idx_start:
325
  latents = latents_start.clone()
326
+ # Mix the latents.
327
+ if i > 0 and list_mixing_coeffs[i] > 0:
328
+ latents_mixtarget = list_latents_mixing[i - 1].clone()
 
329
  latents = interpolate_spherical(latents, latents_mixtarget, list_mixing_coeffs[i])
 
330
  # print(f"diffusion iter {i}")
331
  index = total_steps - i - 1
332
  ts = torch.full((b,), step, device=self.device, dtype=torch.long)
333
  outs = self.sampler.p_sample_ddim(latents, cond, ts, index=index, use_original_steps=False,
334
+ quantize_denoised=False, temperature=1.0,
335
+ noise_dropout=0.0, score_corrector=None,
336
+ corrector_kwargs=None,
337
+ unconditional_guidance_scale=self.guidance_scale,
338
+ unconditional_conditioning=uc_full,
339
+ dynamic_threshold=None)
340
  latents, pred_x0 = outs
341
  list_latents_out.append(latents.clone())
342
 
343
+ if return_image:
344
  return self.latent2image(latents)
345
  else:
346
  return list_latents_out
347
 
348
  @torch.no_grad()
349
  def latent2image(
350
+ self,
351
+ latents: torch.FloatTensor):
 
352
  r"""
353
  Returns an image provided a latent representation from diffusion.
354
  Args:
355
  latents: torch.FloatTensor
356
+ Result of the diffusion process.
357
  """
358
  x_sample = self.model.decode_first_stage(latents)
359
  x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
360
+ x_sample = 255 * x_sample[0, :, :].permute([1, 2, 0]).cpu().numpy()
361
  image = x_sample.astype(np.uint8)
362
  return image
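For reference, a compact end-to-end sketch of the holder, standing in for the interactive __main__ block that this commit removes (the checkpoint path is a placeholder):

sdh = StableDiffusionHolder("checkpoints/v2-1_768-ema-pruned.ckpt", num_inference_steps=20)
sdh.set_negative_prompt("out of frame, blurry, rendering")
te = sdh.get_text_embedding("a huge rocky hill full of pianos and keyboards, cinematic")
latents0 = torch.randn(1, sdh.C, sdh.height // sdh.f, sdh.width // sdh.f, device=sdh.device)
img = sdh.run_diffusion_standard(te, latents0, return_image=True)
Image.fromarray(img).save("sample.jpg")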
utils.py ADDED
@@ -0,0 +1,260 @@
1
+ # Copyright 2022 Lunar Ring. All rights reserved.
2
+ # Written by Johannes Stelzer, email stelzer@lunar-ring.ai twitter @j_stelzer
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ torch.backends.cudnn.benchmark = False
18
+ import numpy as np
19
+ import warnings
20
+ warnings.filterwarnings('ignore')
21
+ import time
22
+ import warnings
23
+ import datetime
24
+ from typing import List, Union
25
+ torch.set_grad_enabled(False)
26
+ import yaml
27
+
28
+
29
+ @torch.no_grad()
30
+ def interpolate_spherical(p0, p1, fract_mixing: float):
31
+ r"""
32
+ Helper function to correctly mix two random variables using spherical interpolation.
33
+ See https://en.wikipedia.org/wiki/Slerp
34
+ The function will always cast up to float64 for the sake of extra precision.
35
+ Args:
36
+ p0:
37
+ First tensor for interpolation
38
+ p1:
39
+ Second tensor for interpolation
40
+ fract_mixing: float
41
+ Mixing coefficient of interval [0, 1].
42
+ 0 will return in p0
43
+ 1 will return in p1
44
+ 0.x will return a mix between both preserving angular velocity.
45
+ """
46
+
47
+ if p0.dtype == torch.float16:
48
+ recast_to = 'fp16'
49
+ else:
50
+ recast_to = 'fp32'
51
+
52
+ p0 = p0.double()
53
+ p1 = p1.double()
54
+ norm = torch.linalg.norm(p0) * torch.linalg.norm(p1)
55
+ epsilon = 1e-7
56
+ dot = torch.sum(p0 * p1) / norm
57
+ dot = dot.clamp(-1 + epsilon, 1 - epsilon)
58
+
59
+ theta_0 = torch.arccos(dot)
60
+ sin_theta_0 = torch.sin(theta_0)
61
+ theta_t = theta_0 * fract_mixing
62
+ s0 = torch.sin(theta_0 - theta_t) / sin_theta_0
63
+ s1 = torch.sin(theta_t) / sin_theta_0
64
+ interp = p0 * s0 + p1 * s1
65
+
66
+ if recast_to == 'fp16':
67
+ interp = interp.half()
68
+ elif recast_to == 'fp32':
69
+ interp = interp.float()
70
+
71
+ return interp
72
+
73
+
74
+ def interpolate_linear(p0, p1, fract_mixing):
75
+ r"""
76
+ Helper function to mix two variables using standard linear interpolation.
77
+ Args:
78
+ p0:
79
+ First tensor / np.ndarray for interpolation
80
+ p1:
81
+ Second tensor / np.ndarray for interpolation
82
+ fract_mixing: float
83
+ Mixing coefficient of interval [0, 1].
84
+ 0 will return in p0
85
+ 1 will return in p1
86
+ 0.x will return a linear mix between both.
87
+ """
88
+ reconvert_uint8 = False
89
+ if type(p0) is np.ndarray and p0.dtype == 'uint8':
90
+ reconvert_uint8 = True
91
+ p0 = p0.astype(np.float64)
92
+
93
+ if type(p1) is np.ndarray and p1.dtype == 'uint8':
94
+ reconvert_uint8 = True
95
+ p1 = p1.astype(np.float64)
96
+
97
+ interp = (1 - fract_mixing) * p0 + fract_mixing * p1
98
+
99
+ if reconvert_uint8:
100
+ interp = np.clip(interp, 0, 255).astype(np.uint8)
101
+
102
+ return interp
103
+
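A tiny sketch of the linear counterpart, showing the uint8 round-trip handled above:

img0 = np.zeros((64, 64, 3), dtype=np.uint8)
img1 = np.full((64, 64, 3), 255, dtype=np.uint8)
mid = interpolate_linear(img0, img1, 0.5)    # returned as uint8 again, values around 127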
104
+
105
+ def add_frames_linear_interp(
106
+ list_imgs: List[np.ndarray],
107
+ fps_target: Union[float, int] = None,
108
+ duration_target: Union[float, int] = None,
109
+ nmb_frames_target: int = None):
110
+ r"""
111
+ Helper function to cheaply increase the number of frames given a list of images,
112
+ by virtue of standard linear interpolation.
113
+ The number of inserted frames will be automatically adjusted so that the total number
114
+ of frames can be fixed precisely, using a random shuffling technique.
115
+ This allows exact 1:1 comparisons between transitions rendered as videos.
116
+
117
+ Args:
118
+ list_imgs: List[np.ndarray]
119
+ List of images, between each image new frames will be inserted via linear interpolation.
120
+ fps_target:
121
+ OptionA: specify here the desired frames per second.
122
+ duration_target:
123
+ OptionA: specify here the desired duration of the transition in seconds.
124
+ nmb_frames_target:
125
+ OptionB: directly fix the total number of frames of the output.
126
+ """
127
+
128
+ # Sanity
129
+ if nmb_frames_target is not None and fps_target is not None:
130
+ raise ValueError("You cannot specify both fps_target and nmb_frames_target")
131
+ if fps_target is None:
132
+ assert nmb_frames_target is not None, "Either specify fps_target or nmb_frames_target"
133
+ if nmb_frames_target is None:
134
+ assert fps_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
135
+ assert duration_target is not None, "Either specify duration_target and fps_target OR nmb_frames_target"
136
+ nmb_frames_target = fps_target * duration_target
137
+
138
+ # Get number of frames that are missing
139
+ nmb_frames_diff = len(list_imgs) - 1
140
+ nmb_frames_missing = nmb_frames_target - nmb_frames_diff - 1
141
+
142
+ if nmb_frames_missing < 1:
143
+ return list_imgs
144
+
145
+ list_imgs_float = [img.astype(np.float32) for img in list_imgs]
146
+ # Distribute missing frames, append nmb_frames_to_insert(i) frames for each frame
147
+ mean_nmb_frames_insert = nmb_frames_missing / nmb_frames_diff
148
+ constfact = np.floor(mean_nmb_frames_insert)
149
+ remainder_x = 1 - (mean_nmb_frames_insert - constfact)
150
+ nmb_iter = 0
151
+ while True:
152
+ nmb_frames_to_insert = np.random.rand(nmb_frames_diff)
153
+ nmb_frames_to_insert[nmb_frames_to_insert <= remainder_x] = 0
154
+ nmb_frames_to_insert[nmb_frames_to_insert > remainder_x] = 1
155
+ nmb_frames_to_insert += constfact
156
+ if np.sum(nmb_frames_to_insert) == nmb_frames_missing:
157
+ break
158
+ nmb_iter += 1
159
+ if nmb_iter > 100000:
160
+ print("add_frames_linear_interp: issue with inserting the right number of frames")
161
+ break
162
+
163
+ nmb_frames_to_insert = nmb_frames_to_insert.astype(np.int32)
164
+ list_imgs_interp = []
165
+ for i in range(len(list_imgs_float) - 1):
166
+ img0 = list_imgs_float[i]
167
+ img1 = list_imgs_float[i + 1]
168
+ list_imgs_interp.append(img0.astype(np.uint8))
169
+ list_fracts_linblend = np.linspace(0, 1, nmb_frames_to_insert[i] + 2)[1:-1]
170
+ for fract_linblend in list_fracts_linblend:
171
+ img_blend = interpolate_linear(img0, img1, fract_linblend).astype(np.uint8)
172
+ list_imgs_interp.append(img_blend.astype(np.uint8))
173
+ if i == len(list_imgs_float) - 2:
174
+ list_imgs_interp.append(img1.astype(np.uint8))
175
+
176
+ return list_imgs_interp
177
+
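A hedged usage sketch (random keyframes stand in for rendered transition images): stretching 10 keyframes to a 2-second clip at 30 fps yields exactly fps_target * duration_target frames.

keyframes = [np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8) for _ in range(10)]
frames = add_frames_linear_interp(keyframes, fps_target=30, duration_target=2)
# len(frames) == 60, i.e. fps_target * duration_target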
178
+
179
+ def get_spacing(nmb_points: int, scaling: float):
180
+ """
181
+ Helper function for getting nonlinear spacing between 0 and 1, symmetric around 0.5
182
+ Args:
183
+ nmb_points: int
184
+ Number of points between [0, 1]
185
+ scaling: float
186
+ Higher values will return higher sampling density around 0.5
187
+ """
188
+ if scaling < 1.7:
189
+ return np.linspace(0, 1, nmb_points)
190
+ nmb_points_per_side = nmb_points // 2 + 1
191
+ if np.mod(nmb_points, 2) != 0: # Uneven case
192
+ left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)
193
+ right_side = 1 - left_side[::-1][1:]
194
+ else:
195
+ left_side = np.abs(np.linspace(1, 0, nmb_points_per_side)**scaling / 2 - 0.5)[0:-1]
196
+ right_side = 1 - left_side[::-1]
197
+ all_fracts = np.hstack([left_side, right_side])
198
+ return all_fracts
199
+
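A small sketch of the spacing behaviour: below the scaling threshold of 1.7 the points are uniform, above it they concentrate around 0.5.

print(get_spacing(5, scaling=1.0))   # -> [0, 0.25, 0.5, 0.75, 1]
print(get_spacing(5, scaling=3.0))   # -> [0, 0.4375, 0.5, 0.5625, 1], denser around 0.5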
200
+
201
+ def get_time(resolution=None):
202
+ """
203
+ Helper function returning a nicely formatted time string, e.g. 221117_1620
204
+ """
205
+ if resolution is None:
206
+ resolution = "second"
207
+ if resolution == "day":
208
+ t = time.strftime('%y%m%d', time.localtime())
209
+ elif resolution == "minute":
210
+ t = time.strftime('%y%m%d_%H%M', time.localtime())
211
+ elif resolution == "second":
212
+ t = time.strftime('%y%m%d_%H%M%S', time.localtime())
213
+ elif resolution == "millisecond":
214
+ t = time.strftime('%y%m%d_%H%M%S', time.localtime())
215
+ t += "_"
216
+ t += str("{:03d}".format(int(int(datetime.datetime.utcnow().strftime('%f')) / 1000)))
217
+ else:
218
+ raise ValueError("bad resolution provided: %s" % resolution)
219
+ return t
220
+
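Example outputs for the different resolutions (the concrete values depend on the local time, of course):

print(get_time("day"))      # e.g. '221117'
print(get_time("minute"))   # e.g. '221117_1620'
print(get_time())           # defaults to seconds, e.g. '221117_162033'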
221
+
222
+ def compare_dicts(a, b):
223
+ """
224
+ Compares two dictionaries a and b and returns a dictionary c, with all
225
+ keys that are shared between a and b but whose values differ in a and b.
226
+ The differing values of a and b are stacked together in the output.
227
+ Example:
228
+ a = {}; a['bobo'] = 4
229
+ b = {}; b['bobo'] = 5
230
+ c = compare_dicts(a, b)
231
+ c == {"bobo": [4, 5]}
232
+ """
233
+ c = {}
234
+ for key in a.keys():
235
+ if key in b.keys():
236
+ val_a = a[key]
237
+ val_b = b[key]
238
+ if val_a != val_b:
239
+ c[key] = [val_a, val_b]
240
+ return c
241
+
242
+
243
+ def yml_load(fp_yml, print_fields=False):
244
+ """
245
+ Helper function for loading yaml files
246
+ """
247
+ with open(fp_yml) as f:
248
+ data = yaml.load(f, Loader=yaml.loader.SafeLoader)
249
+ dict_data = dict(data)
250
+ print("load: loaded {}".format(fp_yml))
251
+ return dict_data
252
+
253
+
254
+ def yml_save(fp_yml, dict_stuff):
255
+ """
256
+ Helper function for saving yaml files
257
+ """
258
+ with open(fp_yml, 'w') as f:
259
+ yaml.dump(dict_stuff, f, sort_keys=False, default_flow_style=False)
260
+ print("yml_save: saved {}".format(fp_yml))