multimodalart HF staff commited on
Commit
88a2ed3
1 Parent(s): a767e49

revert to working state

Browse files
Files changed (2) hide show
  1. app.py +16 -81
  2. utils.py +3 -13
app.py CHANGED
@@ -6,7 +6,6 @@ from utils import video_to_frames, add_dict_to_yaml_file, save_video, seed_every
6
  from tokenflow_pnp import TokenFlow
7
  from preprocess_utils import *
8
  from tokenflow_utils import *
9
- import math
10
  # load sd model
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  model_id = "stabilityai/stable-diffusion-2-1-base"
@@ -52,11 +51,6 @@ def get_example():
52
  ]
53
  return case
54
 
55
- def largest_divisor(n):
56
- for i in range(2, int(math.sqrt(n)) + 1):
57
- if n % i == 0:
58
- return n // i
59
- return n
60
 
61
  def prep(config):
62
  # timesteps to save
@@ -108,26 +102,7 @@ def prep(config):
108
 
109
 
110
  return frames, latents, total_inverted_latents, rgb_reconstruction
111
-
112
-
113
- def calculate_fps(input_video, batch_size):
114
- frames, frames_per_second = video_to_frames(input_video)
115
- #total_vid_frames = len(frames)
116
- #total_vid_duration = total_vid_frames/frames_per_second
117
-
118
- #if(total_vid_duration < 1):
119
- # frames_to_process = total_vid_frames
120
- #else:
121
- # frames_to_process = int(frames_per_second/n_seconds)
122
- #
123
- #if frames_to_process % batch_size != 0:
124
- # batch_size = largest_divisor(batch_size)
125
- #print("total vid duration", total_vid_duration)
126
- #print("frames to process", frames_to_process)
127
- #print("batch size", batch_size)
128
- print("fps", frames_per_second)
129
- return frames, frames_per_second
130
-
131
  def preprocess_and_invert(input_video,
132
  frames,
133
  latents,
@@ -140,8 +115,6 @@ def preprocess_and_invert(input_video,
140
  n_timesteps = 50,
141
  batch_size: int = 8,
142
  n_frames: int = 40,
143
- n_seconds: int = 1,
144
- n_fps_input: int = 40,
145
  inversion_prompt:str = '',
146
 
147
  ):
@@ -161,31 +134,10 @@ def preprocess_and_invert(input_video,
161
  preprocess_config['n_frames'] = n_frames
162
  preprocess_config['seed'] = seed
163
  preprocess_config['inversion_prompt'] = inversion_prompt
164
- not_processed = False
165
- if(not frames):
166
- preprocess_config['frames'],n_fps_input = video_to_frames(input_video)
167
- not_processed = True
168
- else:
169
- preprocess_config['frames'] = frames
170
-
171
- print("pre-process fps ", n_fps_input)
172
  preprocess_config['data_path'] = input_video.split(".")[0]
173
 
174
- total_vid_frames = len(preprocess_config['frames'])
175
- print("total frames", total_vid_frames)
176
- total_vid_duration = total_vid_frames/n_fps_input
177
-
178
- if(total_vid_duration < 1):
179
- preprocess_config['n_frames'] = total_vid_frames
180
- else:
181
- preprocess_config['n_frames'] = int(n_fps_input/n_seconds)
182
-
183
- if preprocess_config['n_frames'] % batch_size != 0:
184
- preprocess_config['batch_size'] = largest_divisor(batch_size)
185
 
186
- print("Running with batch size of ", preprocess_config['batch_size'])
187
- print("Total vid frames", preprocess_config['n_frames'])
188
-
189
  if randomize_seed:
190
  seed = randomize_seed_fn()
191
  seed_everything(seed)
@@ -198,7 +150,7 @@ def preprocess_and_invert(input_video,
198
  inverted_latents = gr.State(value=total_inverted_latents)
199
  do_inversion = False
200
 
201
- return frames, latents, inverted_latents, do_inversion, preprocess_config['batch_size'], preprocess_config['n_frames']
202
 
203
 
204
  def edit_with_pnp(input_video,
@@ -215,8 +167,6 @@ def edit_with_pnp(input_video,
215
  pnp_f_t: float = 0.8,
216
  batch_size: int = 8, #needs to be the same as for preprocess
217
  n_frames: int = 40,#needs to be the same as for preprocess
218
- n_seconds: int = 1,
219
- n_fps_input: int = 40,
220
  n_timesteps: int = 50,
221
  gudiance_scale: float = 7.5,
222
  inversion_prompt: str = "", #needs to be the same as for preprocess
@@ -236,12 +186,10 @@ def edit_with_pnp(input_video,
236
  config["pnp_attn_t"] = pnp_attn_t
237
  config["pnp_f_t"] = pnp_f_t
238
  config["pnp_inversion_prompt"] = inversion_prompt
239
-
240
- print("Running with batch size of ", config['batch_size'])
241
- print("Total vid frames", config['n_frames'])
242
 
243
  if do_inversion:
244
- frames, latents, inverted_latents, do_inversion, batch_size, n_frames = preprocess_and_invert(
245
  input_video,
246
  frames,
247
  latents,
@@ -253,11 +201,7 @@ def edit_with_pnp(input_video,
253
  n_timesteps,
254
  batch_size,
255
  n_frames,
256
- n_seconds,
257
- n_fps_input,
258
  inversion_prompt)
259
- config["batch_size"] = batch_size
260
- config["n_frames"] = n_frames
261
  do_inversion = False
262
 
263
 
@@ -277,6 +221,7 @@ def edit_with_pnp(input_video,
277
  # demo #
278
  ########
279
 
 
280
  intro = """
281
  <div style="text-align:center">
282
  <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
@@ -288,6 +233,8 @@ intro = """
288
  </div>
289
  """
290
 
 
 
291
  with gr.Blocks(css="style.css") as demo:
292
 
293
  gr.HTML(intro)
@@ -295,8 +242,7 @@ with gr.Blocks(css="style.css") as demo:
295
  inverted_latents = gr.State()
296
  latents = gr.State()
297
  do_inversion = gr.State(value=True)
298
- n_fps_input = gr.State()
299
-
300
  with gr.Row():
301
  input_video = gr.Video(label="Input Video", interactive=True, elem_id="input_video")
302
  output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
@@ -336,19 +282,15 @@ with gr.Blocks(css="style.css") as demo:
336
 
337
  with gr.Column(min_width=100):
338
  inversion_prompt = gr.Textbox(lines=1, label="Inversion prompt", interactive=True, placeholder="")
339
- batch_size = gr.Slider(label='Batch size', minimum=1, maximum=100,
340
- value=8, step=1, interactive=True, visible=False)
341
  n_frames = gr.Slider(label='Num frames', minimum=2, maximum=200,
342
- value=24, step=1, interactive=True, visible=False)
343
- n_seconds = gr.Slider(label='Num seconds', info="How many seconds of your video to process",
344
- minimum=1, maximum=2, step=1)
345
  n_timesteps = gr.Slider(label='Diffusion steps', minimum=25, maximum=100,
346
  value=50, step=25, interactive=True)
347
- #n_fps_input = gr.Slider(label="Input frames per second", value=40, minimum=1, maximum=120)
348
- n_fps = gr.Slider(label='Output frames per second', minimum=1, maximum=60,
349
  value=10, step=1, interactive=True)
350
 
351
-
352
  with gr.TabItem('Plug-and-Play Parameters'):
353
  with gr.Column(min_width=100):
354
  pnp_attn_t = gr.Slider(label='pnp attention threshold', minimum=0, maximum=1,
@@ -382,7 +324,7 @@ with gr.Blocks(css="style.css") as demo:
382
  input_video.upload(
383
  fn = reset_do_inversion,
384
  outputs = [do_inversion],
385
- queue = False).then(fn = calculate_fps, inputs=[input_video], outputs=[frames, n_fps_input], queue=False).then(fn = preprocess_and_invert,
386
  inputs = [input_video,
387
  frames,
388
  latents,
@@ -394,20 +336,15 @@ with gr.Blocks(css="style.css") as demo:
394
  n_timesteps,
395
  batch_size,
396
  n_frames,
397
- n_seconds,
398
- n_fps_input,
399
  inversion_prompt
400
  ],
401
  outputs = [frames,
402
  latents,
403
  inverted_latents,
404
- do_inversion,
405
- batch_size,
406
- n_frames
407
  ])
408
 
409
- input_video.change(fn = calculate_fps, inputs=[input_video], outputs=[frames, n_fps_input], queue=False)
410
-
411
  run_button.click(fn = edit_with_pnp,
412
  inputs = [input_video,
413
  frames,
@@ -422,8 +359,6 @@ with gr.Blocks(css="style.css") as demo:
422
  pnp_f_t,
423
  batch_size,
424
  n_frames,
425
- n_seconds,
426
- n_fps_input,
427
  n_timesteps,
428
  gudiance_scale,
429
  inversion_prompt,
 
6
  from tokenflow_pnp import TokenFlow
7
  from preprocess_utils import *
8
  from tokenflow_utils import *
 
9
  # load sd model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
  model_id = "stabilityai/stable-diffusion-2-1-base"
 
51
  ]
52
  return case
53
 
 
 
 
 
 
54
 
55
  def prep(config):
56
  # timesteps to save
 
102
 
103
 
104
  return frames, latents, total_inverted_latents, rgb_reconstruction
105
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def preprocess_and_invert(input_video,
107
  frames,
108
  latents,
 
115
  n_timesteps = 50,
116
  batch_size: int = 8,
117
  n_frames: int = 40,
 
 
118
  inversion_prompt:str = '',
119
 
120
  ):
 
134
  preprocess_config['n_frames'] = n_frames
135
  preprocess_config['seed'] = seed
136
  preprocess_config['inversion_prompt'] = inversion_prompt
137
+ preprocess_config['frames'] = video_to_frames(input_video)
 
 
 
 
 
 
 
138
  preprocess_config['data_path'] = input_video.split(".")[0]
139
 
 
 
 
 
 
 
 
 
 
 
 
140
 
 
 
 
141
  if randomize_seed:
142
  seed = randomize_seed_fn()
143
  seed_everything(seed)
 
150
  inverted_latents = gr.State(value=total_inverted_latents)
151
  do_inversion = False
152
 
153
+ return frames, latents, inverted_latents, do_inversion
154
 
155
 
156
  def edit_with_pnp(input_video,
 
167
  pnp_f_t: float = 0.8,
168
  batch_size: int = 8, #needs to be the same as for preprocess
169
  n_frames: int = 40,#needs to be the same as for preprocess
 
 
170
  n_timesteps: int = 50,
171
  gudiance_scale: float = 7.5,
172
  inversion_prompt: str = "", #needs to be the same as for preprocess
 
186
  config["pnp_attn_t"] = pnp_attn_t
187
  config["pnp_f_t"] = pnp_f_t
188
  config["pnp_inversion_prompt"] = inversion_prompt
189
+
 
 
190
 
191
  if do_inversion:
192
+ frames, latents, inverted_latents, do_inversion = preprocess_and_invert(
193
  input_video,
194
  frames,
195
  latents,
 
201
  n_timesteps,
202
  batch_size,
203
  n_frames,
 
 
204
  inversion_prompt)
 
 
205
  do_inversion = False
206
 
207
 
 
221
  # demo #
222
  ########
223
 
224
+
225
  intro = """
226
  <div style="text-align:center">
227
  <h1 style="font-weight: 1400; text-align: center; margin-bottom: 7px;">
 
233
  </div>
234
  """
235
 
236
+
237
+
238
  with gr.Blocks(css="style.css") as demo:
239
 
240
  gr.HTML(intro)
 
242
  inverted_latents = gr.State()
243
  latents = gr.State()
244
  do_inversion = gr.State(value=True)
245
+
 
246
  with gr.Row():
247
  input_video = gr.Video(label="Input Video", interactive=True, elem_id="input_video")
248
  output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
 
282
 
283
  with gr.Column(min_width=100):
284
  inversion_prompt = gr.Textbox(lines=1, label="Inversion prompt", interactive=True, placeholder="")
285
+ batch_size = gr.Slider(label='Batch size', minimum=1, maximum=10,
286
+ value=8, step=1, interactive=True)
287
  n_frames = gr.Slider(label='Num frames', minimum=2, maximum=200,
288
+ value=24, step=1, interactive=True)
 
 
289
  n_timesteps = gr.Slider(label='Diffusion steps', minimum=25, maximum=100,
290
  value=50, step=25, interactive=True)
291
+ n_fps = gr.Slider(label='Frames per second', minimum=1, maximum=60,
 
292
  value=10, step=1, interactive=True)
293
 
 
294
  with gr.TabItem('Plug-and-Play Parameters'):
295
  with gr.Column(min_width=100):
296
  pnp_attn_t = gr.Slider(label='pnp attention threshold', minimum=0, maximum=1,
 
324
  input_video.upload(
325
  fn = reset_do_inversion,
326
  outputs = [do_inversion],
327
+ queue = False).then(fn = preprocess_and_invert,
328
  inputs = [input_video,
329
  frames,
330
  latents,
 
336
  n_timesteps,
337
  batch_size,
338
  n_frames,
 
 
339
  inversion_prompt
340
  ],
341
  outputs = [frames,
342
  latents,
343
  inverted_latents,
344
+ do_inversion
345
+
 
346
  ])
347
 
 
 
348
  run_button.click(fn = edit_with_pnp,
349
  inputs = [input_video,
350
  frames,
 
359
  pnp_f_t,
360
  batch_size,
361
  n_frames,
 
 
362
  n_timesteps,
363
  gudiance_scale,
364
  inversion_prompt,
utils.py CHANGED
@@ -16,7 +16,7 @@ from kornia.utils.grid import create_meshgrid
16
  import cv2
17
 
18
  def save_video_frames(video_path, img_size=(512,512)):
19
- video, _, = read_video(video_path, output_format="TCHW")
20
  # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
21
  if video_path.endswith('.mov'):
22
  video = T.functional.rotate(video, -90)
@@ -29,8 +29,7 @@ def save_video_frames(video_path, img_size=(512,512)):
29
  image_resized.save(f'data/{video_name}/{ind}.png')
30
 
31
  def video_to_frames(video_path, img_size=(512,512)):
32
- video, _, video_info = read_video(video_path, output_format="TCHW")
33
-
34
  # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
35
  if video_path.endswith('.mov'):
36
  video = T.functional.rotate(video, -90)
@@ -40,19 +39,10 @@ def video_to_frames(video_path, img_size=(512,512)):
40
  for i in range(len(video)):
41
  ind = str(i).zfill(5)
42
  image = T.ToPILImage()(video[i])
43
-
44
- # get new height and width to maintain aspect ratio
45
- height, width = image.size
46
- new_height = math.floor(img_size[0] * height / width)
47
- new_width = math.floor(img_size[1] * width / height)
48
-
49
- # pad
50
- image = Image.new(image.mode, (new_width, new_height), (0, 0, 0))
51
-
52
  image_resized = image.resize((img_size), resample=Image.Resampling.LANCZOS)
53
  # image_resized.save(f'data/{video_name}/{ind}.png')
54
  frames.append(image_resized)
55
- return frames, video_info["video_fps"]
56
 
57
  def add_dict_to_yaml_file(file_path, key, value):
58
  data = {}
 
16
  import cv2
17
 
18
  def save_video_frames(video_path, img_size=(512,512)):
19
+ video, _, _ = read_video(video_path, output_format="TCHW")
20
  # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
21
  if video_path.endswith('.mov'):
22
  video = T.functional.rotate(video, -90)
 
29
  image_resized.save(f'data/{video_name}/{ind}.png')
30
 
31
  def video_to_frames(video_path, img_size=(512,512)):
32
+ video, _, _ = read_video(video_path, output_format="TCHW")
 
33
  # rotate video -90 degree if video is .mov format. this is a weird bug in torchvision
34
  if video_path.endswith('.mov'):
35
  video = T.functional.rotate(video, -90)
 
39
  for i in range(len(video)):
40
  ind = str(i).zfill(5)
41
  image = T.ToPILImage()(video[i])
 
 
 
 
 
 
 
 
 
42
  image_resized = image.resize((img_size), resample=Image.Resampling.LANCZOS)
43
  # image_resized.save(f'data/{video_name}/{ind}.png')
44
  frames.append(image_resized)
45
+ return frames
46
 
47
  def add_dict_to_yaml_file(file_path, key, value):
48
  data = {}