Linoy Tsaban committed on
Commit 4e5195b
1 Parent(s): ba508b5

Update app.py

Files changed (1)
  1. app.py +202 -75
app.py CHANGED
@@ -1,18 +1,14 @@
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
- from utils import *
-
-
-
-
-

# load sd model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model_id = "stabilityai/stable-diffusion-2-1-base"
- inv_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
- inv_pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

def randomize_seed_fn():
    seed = random.randint(0, np.iinfo(np.int32).max)
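randomize_seed_fn is cut off by the hunk boundary above. A minimal self-contained version of what it appears to do, assuming the function simply returns the drawn seed (the random and numpy imports are assumptions; they are not visible in this diff):

import random
import numpy as np

def randomize_seed_fn():
    # draw a fresh 32-bit seed and return it to the caller
    seed = random.randint(0, np.iinfo(np.int32).max)
    return seed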
@@ -21,69 +17,173 @@ def randomize_seed_fn():
def reset_do_inversion():
    return True

- def get_example():
-     case = [
-         [
-             'examples/wolf.mp4',

-         ],
-         [
-             'examples/woman-running.mp4',

-         ],

-     ]
-     return case

- def preprocess_and_invert(video,
                            frames,
                            latents,
                            inverted_latents,
                            seed,
                            randomize_seed,
                            do_inversion,
-                           height:int = 512,
-                           weidth: int = 512,
                            # save_dir: str = "latents",
                            steps: int = 500,
                            batch_size: int = 8,
                            n_frames: int = 40,
                            inversion_prompt:str = '',
-                           save_steps: int = 50,
                            ):

    if do_inversion or randomize_seed:

-         # save_video_frames(data_path, img_size=(height, weidth))
-         frames = video_to_frames(video, img_size=(height, weidth))
-         # data_path = os.path.join('data', Path(video_path).stem)
-
-         toy_scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
-         toy_scheduler.set_timesteps(save_steps)
-         timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=save_steps,
-                                                                strength=1.0,
-                                                                device=device)
        if randomize_seed:
            seed = randomize_seed_fn()
        seed_everything(seed)

-         frames, latents = get_data(inv_pipe, frames, n_frames)
-
-         inverted_latents = extract_latents(inv_pipe, num_steps = steps,
-                                            latent_frames = latents,
-                                            batch_size = batch_size,
-                                            timesteps_to_save = timesteps_to_save,
-                                            inversion_prompt = inversion_prompt,)
        frames = gr.State(value=frames)
        latents = gr.State(value=latents)
-         inverted_latents = gr.State(value=inverted_latents)
        do_inversion = False

-     # temp to check something
-     output_vid = frames
-     return frames, latents, inverted_latents, do_inversion, output_vid



########
# demo #
@@ -107,14 +207,14 @@ with gr.Blocks(css="style.css") as demo:
    do_inversion = gr.State(value=True)

    with gr.Row():
-         input_vid = gr.Video(label="Input Video", interactive=True, elem_id="input_video")
-         output_vid = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
-         input_vid.style(height=365, width=365)
-         output_vid.style(height=365, width=365)


    with gr.Row():
-         tar_prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1, value=""
        )
@@ -132,33 +232,41 @@ with gr.Blocks(css="style.css") as demo:
    run_button = gr.Button("Edit your video!", visible=True)

    with gr.Accordion("Advanced Options", open=False):
-         with gr.Tabs() as tabs:
-
-             with gr.TabItem('General options', id=2):
-                 with gr.Row():
-                     with gr.Column(min_width=100):
-                         seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
-                         randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
-                         steps = gr.Slider(label='Inversion steps', minimum=100, maximum=500,
-                                           value=500, step=1, interactive=True)
-                     with gr.Column(min_width=100):
-                         inversion_prompt = gr.Textbox(lines=1, label="Inversion prompt", interactive=True, placeholder="")
-                         batch_size = gr.Slider(label='Batch size', minimum=1, maximum=10,
-                                                value=8, step=1, interactive=True)
-                         n_frames = gr.Slider(label='Num frames', minimum=20, maximum=200,
-                                              value=40, step=1, interactive=True)


-     input_vid.change(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False)

-     input_vid.upload(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False).then(fn = preprocess_and_invert,
-                           inputs = [input_vid,
                                      frames,
                                      latents,
                                      inverted_latents,
@@ -173,19 +281,38 @@ with gr.Blocks(css="style.css") as demo:
                            outputs = [frames,
                                       latents,
                                       inverted_latents,
-                                      do_inversion,
-                                      output_vid

                            ])

-     gr.Examples(
-         examples=get_example(),
-         label='Examples',
-         inputs=[input_vid],
-         outputs=[input_vid]
-     )



    demo.queue()
- demo.launch()
 
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
+ # from utils import *
+ from diffusers.utils import export_to_video

# load sd model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ # model_id = "stabilityai/stable-diffusion-2-1-base"
+ # inv_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
+ # inv_pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")

def randomize_seed_fn():
    seed = random.randint(0, np.iinfo(np.int32).max)
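With `from utils import *` now commented out, the rest of the new file still calls video_to_frames, seed_everything, get_timesteps, add_dict_to_yaml_file, Preprocess, TokenFlow and save_video, and uses random, numpy, os and Path. A hedged sketch of the imports the new version appears to assume (module locations are assumptions; they are not shown in this diff):

import os
import random
from pathlib import Path

import numpy as np

# assumed: these helpers still live in the Space's own utils module,
# which this commit stops star-importing
from utils import (video_to_frames, seed_everything, get_timesteps,
                   add_dict_to_yaml_file, save_video)
# Preprocess and TokenFlow are referenced below, but the modules that
# define them are not visible in this diff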
 
def reset_do_inversion():
    return True

+ # def get_example():
+ #     case = [
+ #         [
+ #             'examples/wolf.mp4',

+ #         ],
+ #         [
+ #             'examples/woman-running.mp4',

+ #         ],
+
+ #     ]
+ #     return case
+
+
+ def prep(config):
+     # timesteps to save
+     if config["sd_version"] == '2.1':
+         model_key = "stabilityai/stable-diffusion-2-1-base"
+     elif config["sd_version"] == '2.0':
+         model_key = "stabilityai/stable-diffusion-2-base"
+     elif config["sd_version"] == '1.5' or config["sd_version"] == 'ControlNet':
+         model_key = "runwayml/stable-diffusion-v1-5"
+     elif config["sd_version"] == 'depth':
+         model_key = "stabilityai/stable-diffusion-2-depth"
+     toy_scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler")
+     toy_scheduler.set_timesteps(config["save_steps"])
+     timesteps_to_save, num_inference_steps = get_timesteps(toy_scheduler, num_inference_steps=config["save_steps"],
+                                                            strength=1.0,
+                                                            device=device)

+     # seed_everything(config["seed"])
+     if not config["frames"]: # original non demo setting
+         save_path = os.path.join(config["save_dir"],
+                                  f'sd_{config["sd_version"]}',
+                                  Path(config["data_path"]).stem,
+                                  f'steps_{config["steps"]}',
+                                  f'nframes_{config["n_frames"]}')
+         os.makedirs(os.path.join(save_path, f'latents'), exist_ok=True)
+         add_dict_to_yaml_file(os.path.join(config["save_dir"], 'inversion_prompts.yaml'), Path(config["data_path"]).stem, config["inversion_prompt"])
+         # save inversion prompt in a txt file
+         with open(os.path.join(save_path, 'inversion_prompt.txt'), 'w') as f:
+             f.write(config["inversion_prompt"])
+     else:
+         save_path = None

+     model = Preprocess(device, config)
+     print(type(model.config["batch_size"]))
+     frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
+         num_steps=model.config["steps"],
+         save_path=save_path,
+         batch_size=model.config["batch_size"],
+         timesteps_to_save=timesteps_to_save,
+         inversion_prompt=model.config["inversion_prompt"],
+     )
+
+
+     return frames, latents, total_inverted_latents, rgb_reconstruction
+
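For reference, a hedged sketch of the config dict that prep expects, mirroring the preprocess_config built in preprocess_and_invert below; every value here is illustrative, not taken from the commit:

# illustrative values only
example_config = {
    "sd_version": "2.1",            # selects the Stable Diffusion checkpoint above
    "H": 512, "W": 512,
    "save_dir": "latents",
    "steps": 500,                   # DDIM inversion steps
    "save_steps": 50,               # how many timesteps to keep latents for
    "batch_size": 8,
    "n_frames": 40,
    "seed": 0,
    "inversion_prompt": "",
    "frames": video_to_frames("examples/wolf.mp4"),  # demo path: frames passed directly
    "data_path": "examples/wolf",   # used only when "frames" is falsy
}
# frames, latents, inverted_latents, rgb_reconstruction = prep(example_config)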
+ def preprocess_and_invert(input_video,
                            frames,
                            latents,
                            inverted_latents,
                            seed,
                            randomize_seed,
                            do_inversion,
                            # save_dir: str = "latents",
                            steps: int = 500,
                            batch_size: int = 8,
                            n_frames: int = 40,
                            inversion_prompt:str = '',
+
                            ):
+     sd_version = "2.1"
+     height = 512
+     weidth: int = 512
+     save_steps = 50

    if do_inversion or randomize_seed:
+         preprocess_config = {}
+         preprocess_config['H'] = height
+         preprocess_config['W'] = weidth
+         preprocess_config['save_dir'] = 'latents'
+         preprocess_config['sd_version'] = sd_version
+         preprocess_config['steps'] = steps
+         preprocess_config['batch_size'] = batch_size
+         preprocess_config['save_steps'] = save_steps
+         preprocess_config['n_frames'] = n_frames
+         preprocess_config['seed'] = seed
+         preprocess_config['inversion_prompt'] = inversion_prompt
+         preprocess_config['frames'] = video_to_frames(input_video)
+         preprocess_config['data_path'] = input_video.split(".")[0]

+
        if randomize_seed:
            seed = randomize_seed_fn()
        seed_everything(seed)

+         frames, latents, total_inverted_latents, rgb_reconstruction = prep(preprocess_config)
        frames = gr.State(value=frames)
        latents = gr.State(value=latents)
+         inverted_latents = gr.State(value=total_inverted_latents)
        do_inversion = False

+     return frames, latents, inverted_latents, do_inversion
+
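preprocess_and_invert takes plain values from gr.State inputs and wraps its results in fresh gr.State objects before returning them, so that edit_with_pnp can later read them via .value. A minimal sketch of that round-trip in isolation (function and component names are illustrative, not part of the commit):

import gradio as gr

def invert(frames_state):
    # stand-in for preprocess_and_invert: wrap the result in a fresh gr.State
    computed_frames = ["frame0", "frame1"]   # placeholder payload
    return gr.State(value=computed_frames)

def edit(frames_state):
    # stand-in for edit_with_pnp: unwrap the stored object via .value
    return f"editing {len(frames_state.value)} frames"

with gr.Blocks() as sketch:
    frames = gr.State(value=None)
    status = gr.Textbox()
    invert_btn = gr.Button("invert")
    edit_btn = gr.Button("edit")
    invert_btn.click(invert, inputs=[frames], outputs=[frames])
    edit_btn.click(edit, inputs=[frames], outputs=[status])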
 

+ def edit_with_pnp(input_video,
+                   frames,
+                   latents,
+                   inverted_latents,
+                   seed,
+                   randomize_seed,
+                   do_inversion,
+                   steps,
+                   prompt: str = "a marble sculpture of a woman running, Venus de Milo",
+                   # negative_prompt: str = "ugly, blurry, low res, unrealistic, unaesthetic",
+                   pnp_attn_t: float = 0.5,
+                   pnp_f_t: float = 0.8,
+                   batch_size: int = 8, #needs to be the same as for preprocess
+                   n_frames: int = 40,#needs to be the same as for preprocess
+                   n_timesteps: int = 50,
+                   gudiance_scale: float = 7.5,
+                   inversion_prompt: str = ""#needs to be the same as for preprocess
+                   ):
+     config = {}
+
+     config["sd_version"] = "2.1"
+     config["device"] = device
+     config["n_timesteps"] = n_timesteps
+     config["n_frames"] = n_frames
+     config["batch_size"] = batch_size
+     config["guidance_scale"] = gudiance_scale
+     config["prompt"] = prompt
+     config["negative_prompt"] = "ugly, blurry, low res, unrealistic, unaesthetic",
+     config["pnp_attn_t"] = pnp_attn_t
+     config["pnp_f_t"] = pnp_f_t
+     config["pnp_inversion_prompt"] = inversion_prompt
+
+
+     if do_inversion:
+         frames, latents, inverted_latents, do_inversion = preprocess_and_invert(
+             input_video,
+             frames,
+             latents,
+             inverted_latents,
+             seed,
+             randomize_seed,
+             do_inversion,
+             steps,
+             batch_size,
+             n_frames,
+             inversion_prompt)
+         do_inversion = False
+
+
+     if randomize_seed:
+         seed = randomize_seed_fn()
+     seed_everything(seed)
+
+
+     editor = TokenFlow(config=config, frames=frames.value, inverted_latents=inverted_latents.value)
+     edited_frames = editor.edit_video()

+     save_video(edited_frames, 'tokenflow_PnP_fps_30.mp4', fps=30)
+     # path = export_to_video(edited_frames)
+     return 'tokenflow_PnP_fps_30.mp4', frames, latents, inverted_latents, do_inversion
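edit_with_pnp writes the result with save_video at a fixed path and fps and returns that path, which the gr.Video output accepts; the export_to_video route from diffusers.utils is imported above but left commented out. A hedged sketch of that alternative, assuming edited_frames is (or can be converted to) the list of RGB numpy frames export_to_video expects:

from diffusers.utils import export_to_video

# assumption: edited_frames already matches what export_to_video expects;
# fps handling differs across diffusers versions, so the default is kept
video_path = export_to_video(edited_frames)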

########
# demo #

    do_inversion = gr.State(value=True)

    with gr.Row():
+         input_video = gr.Video(label="Input Video", interactive=True, elem_id="input_video")
+         output_video = gr.Video(label="Edited Video", interactive=False, elem_id="output_video")
+         input_video.style(height=365, width=365)
+         output_video.style(height=365, width=365)


    with gr.Row():
+         prompt = gr.Textbox(
            label="Describe your edited video",
            max_lines=1, value=""
        )
 
    run_button = gr.Button("Edit your video!", visible=True)

    with gr.Accordion("Advanced Options", open=False):
+         with gr.Tabs() as tabs:
+             with gr.TabItem('General options', id=2):
+                 with gr.Row():
+                     with gr.Column(min_width=100):
+                         seed = gr.Number(value=0, precision=0, label="Seed", interactive=True)
+                         randomize_seed = gr.Checkbox(label='Randomize seed', value=False)
+                         gudiance_scale = gr.Slider(label='Guidance Scale', minimum=1, maximum=30,
+                                                    value=7.5, step=0.5, interactive=True)
+                         steps = gr.Slider(label='Inversion steps', minimum=100, maximum=500,
+                                           value=500, step=1, interactive=True)
+                         n_timesteps = gr.Slider(label='Diffusion steps', minimum=25, maximum=100,
+                                                 value=50, step=1, interactive=True)
+
+                     with gr.Column(min_width=100):
+                         inversion_prompt = gr.Textbox(lines=1, label="Inversion prompt", interactive=True, placeholder="")
+                         batch_size = gr.Slider(label='Batch size', minimum=1, maximum=10,
+                                                value=8, step=1, interactive=True)
+                         n_frames = gr.Slider(label='Num frames', minimum=20, maximum=200,
+                                              value=40, step=1, interactive=True)
+                         pnp_attn_t = gr.Slider(label='pnp attention threshold', minimum=0, maximum=1,
+                                                value=0.5, step=0.5, interactive=True)
+                         pnp_f_t = gr.Slider(label='pnp feature threshold', minimum=0, maximum=1,
+                                             value=0.8, step=0.05, interactive=True)


+     input_video.change(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False)

+     input_video.upload(
        fn = reset_do_inversion,
        outputs = [do_inversion],
        queue = False).then(fn = preprocess_and_invert,
+                           inputs = [input_video,
                                      frames,
                                      latents,
                                      inverted_latents,
 
                            outputs = [frames,
                                       latents,
                                       inverted_latents,
+                                      do_inversion

                            ])
+
+     run_button.click(fn = edit_with_pnp,
+                      inputs = [input_video,
+                                frames,
+                                latents,
+                                inverted_latents,
+                                seed,
+                                randomize_seed,
+                                do_inversion,
+                                steps,
+                                prompt,
+                                pnp_attn_t,
+                                pnp_f_t,
+                                batch_size,
+                                n_frames,
+                                n_timesteps,
+                                gudiance_scale,
+                                inversion_prompt ],
+                      outputs = [output_video, frames, latents, inverted_latents, do_inversion]
+                      )

+     # gr.Examples(
+     #     examples=get_example(),
+     #     label='Examples',
+     #     inputs=[input_vid],
+     #     outputs=[input_vid]
+     # )



    demo.queue()
+ demo.launch(share=True)