lev1 commited on
Commit
2d7762b
1 Parent(s): a681a6f

T2V Tab improvements

Browse files
app_text_to_video.py CHANGED
@@ -1,16 +1,17 @@
1
  import gradio as gr
2
  from model import Model
 
3
 
4
  examples = [
5
- "an astronaut waving the arm on the moon",
6
- "a sloth surfing on a wakeboard",
7
- "an astronaut walking on a street",
8
- "a cute cat walking on grass",
9
- "a horse is galloping on a street",
10
- "an astronaut is skiing down the hill",
11
- "a gorilla walking alone down the street"
12
- "a gorilla dancing on times square",
13
- "A panda dancing dancing like crazy on Times Square",
14
  ]
15
 
16
 
@@ -24,17 +25,35 @@ def create_demo(model: Model):
24
  with gr.Column():
25
  prompt = gr.Textbox(label='Prompt')
26
  run_button = gr.Button(label='Run')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  with gr.Column():
28
  result = gr.Video(label="Generated Video")
29
  inputs = [
30
- prompt,
 
 
 
31
  ]
32
 
33
  gr.Examples(examples=examples,
34
  inputs=inputs,
35
  outputs=result,
36
- cache_examples=False,
37
- #cache_examples=os.getenv('SYSTEM') == 'spaces')
38
  run_on_click=False,
39
  )
40
 
1
  import gradio as gr
2
  from model import Model
3
+ from functools import partial
4
 
5
  examples = [
6
+ ["an astronaut waving the arm on the moon"],
7
+ ["a sloth surfing on a wakeboard"],
8
+ ["an astronaut walking on a street"],
9
+ ["a cute cat walking on grass"],
10
+ ["a horse is galloping on a street"],
11
+ ["an astronaut is skiing down the hill"],
12
+ ["a gorilla walking alone down the street"],
13
+ ["a gorilla dancing on times square"],
14
+ ["A panda dancing dancing like crazy on Times Square"],
15
  ]
16
 
17
 
25
  with gr.Column():
26
  prompt = gr.Textbox(label='Prompt')
27
  run_button = gr.Button(label='Run')
28
+ with gr.Accordion('Advanced options', open=False):
29
+ motion_field_strength_x = gr.Slider(label='Global Translation $\delta_{x}$',
30
+ minimum=-20,
31
+ maximum=20,
32
+ value=12,
33
+ step=1)
34
+
35
+ motion_field_strength_y = gr.Slider(label='Global Translation $\delta_{y}$',
36
+ minimum=-20,
37
+ maximum=20,
38
+ value=12,
39
+ step=1)
40
+ # a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
41
+ n_prompt = gr.Textbox(label="Optional Negative Prompt",
42
+ value='')
43
  with gr.Column():
44
  result = gr.Video(label="Generated Video")
45
  inputs = [
46
+ prompt,
47
+ motion_field_strength_x,
48
+ motion_field_strength_y,
49
+ n_prompt
50
  ]
51
 
52
  gr.Examples(examples=examples,
53
  inputs=inputs,
54
  outputs=result,
55
+ # cache_examples=False,
56
+ cache_examples=os.getenv('SYSTEM') == 'spaces',
57
  run_on_click=False,
58
  )
59
 
model.py CHANGED
@@ -255,26 +255,71 @@ class Model:
255
  )
256
  return utils.create_video(result, fps)
257
 
258
- def process_text2video(self, prompt, resolution=512, seed=24, num_frames=8, fps=4, t0=881, t1=941,
259
- use_cf_attn=True, use_motion_field=True, use_foreground_motion_field=False,
260
- smooth_bg=False, smooth_bg_strength=0.4, motion_field_strength=12):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  if self.model_type != ModelType.Text2Video:
263
- unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
264
  self.set_model(ModelType.Text2Video, model_id="runwayml/stable-diffusion-v1-5", unet=unet)
265
  self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
266
- self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
 
267
  self.generator.manual_seed(seed)
268
-
269
 
270
  added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
271
- self.generator.manual_seed(seed)
272
 
273
  prompt = prompt.rstrip()
274
  if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."):
275
  prompt = prompt.rstrip()[:-1]
276
  prompt = prompt.rstrip()
277
  prompt = prompt + ", "+added_prompt
 
 
 
 
278
 
279
  result = self.inference(prompt=[prompt],
280
  video_length=num_frames,
@@ -285,12 +330,13 @@ class Model:
285
  guidance_stop_step=1.0,
286
  t0=t0,
287
  t1=t1,
288
- use_foreground_motion_field=use_foreground_motion_field,
289
- motion_field_strength=motion_field_strength,
290
  use_motion_field=use_motion_field,
291
  smooth_bg=smooth_bg,
292
  smooth_bg_strength=smooth_bg_strength,
293
  seed=seed,
294
  output_type='numpy',
 
295
  )
296
- return utils.create_video(result, fps)
255
  )
256
  return utils.create_video(result, fps)
257
 
258
+ # def process_text2video(self, prompt, resolution=512, seed=24, num_frames=8, fps=4, t0=881, t1=941,
259
+ # use_cf_attn=True, use_motion_field=True, use_foreground_motion_field=False,
260
+ # smooth_bg=False, smooth_bg_strength=0.4, motion_field_strength=12):
261
+
262
+ # if self.model_type != ModelType.Text2Video:
263
+ # unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
264
+ # self.set_model(ModelType.Text2Video, model_id="runwayml/stable-diffusion-v1-5", unet=unet)
265
+ # self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
266
+ # self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
267
+ # self.generator.manual_seed(seed)
268
+
269
+
270
+ # added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
271
+ # self.generator.manual_seed(seed)
272
+
273
+ # prompt = prompt.rstrip()
274
+ # if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."):
275
+ # prompt = prompt.rstrip()[:-1]
276
+ # prompt = prompt.rstrip()
277
+ # prompt = prompt + ", "+added_prompt
278
+
279
+ # result = self.inference(prompt=[prompt],
280
+ # video_length=num_frames,
281
+ # height=resolution,
282
+ # width=resolution,
283
+ # num_inference_steps=50,
284
+ # guidance_scale=7.5,
285
+ # guidance_stop_step=1.0,
286
+ # t0=t0,
287
+ # t1=t1,
288
+ # use_foreground_motion_field=use_foreground_motion_field,
289
+ # motion_field_strength=motion_field_strength,
290
+ # use_motion_field=use_motion_field,
291
+ # smooth_bg=smooth_bg,
292
+ # smooth_bg_strength=smooth_bg_strength,
293
+ # seed=seed,
294
+ # output_type='numpy',
295
+ # )
296
+ # return utils.create_video(result, fps)
297
+
298
+ def process_text2video(self, prompt, motion_field_strength_x=12,motion_field_strength_y=12, n_prompt="", resolution=512, seed=24, num_frames=8, fps=4, t0=881, t1=941,
299
+ use_cf_attn=True, use_motion_field=True,
300
+ smooth_bg=False, smooth_bg_strength=0.4 ):
301
 
302
  if self.model_type != ModelType.Text2Video:
303
+ unet = UNet2DConditionModel.from_pretrained('runwayml/stable-diffusion-v1-5', subfolder="unet")
304
  self.set_model(ModelType.Text2Video, model_id="runwayml/stable-diffusion-v1-5", unet=unet)
305
  self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
306
+ if use_cf_attn:
307
+ self.pipe.unet.set_attn_processor(processor=self.text2video_attn_proc)
308
  self.generator.manual_seed(seed)
309
+
310
 
311
  added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
312
+ negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
313
 
314
  prompt = prompt.rstrip()
315
  if len(prompt) > 0 and (prompt[-1] == "," or prompt[-1] == "."):
316
  prompt = prompt.rstrip()[:-1]
317
  prompt = prompt.rstrip()
318
  prompt = prompt + ", "+added_prompt
319
+ if len(n_prompt)>0:
320
+ negative_prompt = [n_prompt]
321
+ else:
322
+ negative_prompt = None
323
 
324
  result = self.inference(prompt=[prompt],
325
  video_length=num_frames,
330
  guidance_stop_step=1.0,
331
  t0=t0,
332
  t1=t1,
333
+ motion_field_strength_x=motion_field_strength_x,
334
+ motion_field_strength_y=motion_field_strength_y,
335
  use_motion_field=use_motion_field,
336
  smooth_bg=smooth_bg,
337
  smooth_bg_strength=smooth_bg_strength,
338
  seed=seed,
339
  output_type='numpy',
340
+ negative_prompt = negative_prompt,
341
  )
342
+ return utils.create_video(result, fps)
text_to_video/text_to_video_generator.py CHANGED
@@ -13,6 +13,8 @@ class TextToVideo():
13
  g.manual_seed(22)
14
  self.g = g
15
 
 
 
16
  print(f"Loading model SD-Net model file from {sd_path}")
17
 
18
  self.dtype = torch.float16
13
  g.manual_seed(22)
14
  self.g = g
15
 
16
+ assert sd_path is not None
17
+
18
  print(f"Loading model SD-Net model file from {sd_path}")
19
 
20
  self.dtype = torch.float16
text_to_video/text_to_video_pipeline.py CHANGED
@@ -142,7 +142,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
142
  with self.progress_bar(total=num_inference_steps) as progress_bar:
143
  for i, t in enumerate(timesteps):
144
  if t > skip_t:
145
- # print("Skipping frame!")
146
  continue
147
  else:
148
  if not entered:
@@ -235,19 +234,20 @@ class TextToVideoPipeline(StableDiffusionPipeline):
235
  List[torch.Generator]]] = None,
236
  xT: Optional[torch.FloatTensor] = None,
237
  null_embs: Optional[torch.FloatTensor] = None,
238
- motion_field_strength: float = 12,
 
 
239
  output_type: Optional[str] = "tensor",
240
  return_dict: bool = True,
241
  callback: Optional[Callable[[
242
  int, int, torch.FloatTensor], None]] = None,
243
  callback_steps: Optional[int] = 1,
244
- use_foreground_motion_field: bool = True,
245
  use_motion_field: bool = True,
246
  smooth_bg: bool = True,
247
  smooth_bg_strength: float = 0.4,
248
  **kwargs,
249
  ):
250
-
251
  print(f" Use: Motion field = {use_motion_field}")
252
  print(f" Use: Background smoothing = {smooth_bg}")
253
  # Default height and width to unet
@@ -349,7 +349,9 @@ class TextToVideoPipeline(StableDiffusionPipeline):
349
  reference_flow = torch.zeros(
350
  (video_length-1, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype)
351
  for fr_idx in range(video_length-1):
352
- reference_flow[fr_idx, :, :, :] = motion_field_strength*(fr_idx+1)
 
 
353
 
354
  for idx, latent in enumerate(x_t0_k):
355
  x_t0_k[idx] = self.warp_latents_independently(
@@ -379,63 +381,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
379
  x_t0_k = x_t0_1[:, :, 1:, :, :].clone()
380
  x_t0_1 = x_t0_1[:,:,:1,:,:].clone()
381
 
382
-
383
- move_object = use_foreground_motion_field
384
- if move_object:
385
- h, w = x0.shape[3], x0.shape[4]
386
- # Move object
387
- # reference_flow = torch.zeros(
388
- # (video_length-1, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype)
389
- reference_flow_obj = torch.zeros(
390
- (batch_size, video_length, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype)
391
-
392
- for batch_idx, x0_b in enumerate(x0):
393
- tmp = x0_b[None]
394
- z0_b = []
395
- for fr_split in range(tmp.shape[2]):
396
- z0_b.append(self.decode_latents(
397
- tmp[:, :, fr_split, None]).detach())
398
- z0_b = torch.cat(z0_b, dim=2)
399
- z0_b = rearrange(z0_b[0], "c f h w -> f h w c")
400
- shift = (-5 - 5) * torch.rand(2,
401
- device=x0.device, dtype=x0.dtype) + 5
402
- for frame_idx, z0_f in enumerate(z0_b):
403
- if frame_idx > 0:
404
-
405
- z0_f = torch.round(
406
- z0_f * 255).cpu().numpy().astype(np.uint8)
407
-
408
- # apply SOD detection to obtain mask of foreground object
409
- m_f = torch.tensor(self.sod_model.process_data(
410
- z0_f), device=x0.device).to(x0.dtype)
411
- kernel = torch.ones(
412
- 5, 5, device=x0.device, dtype=x0.dtype)
413
- mask = dilation(
414
- m_f[None, None].to(x0.device), kernel)[0]
415
- for coord_idx in range(2):
416
- reference_flow_obj[batch_idx, frame_idx,
417
- coord_idx, :, :] = (1+frame_idx) * shift[coord_idx] * mask
418
-
419
-
420
-
421
- for idx, x_t0_k_b in enumerate(x_t0_k):
422
- x_t0_k[idx] = self.warp_latents_independently(
423
- x_t0_k_b[None], reference_flow_obj[idx, 1:])
424
-
425
- x_t1_k = self.DDPM_forward(
426
- x0=x_t0_k, t0=t0, tMax=t1, device=device, shape=shape, text_embeddings=text_embeddings, generator=generator)
427
-
428
- if x_t1_1 is None:
429
- raise Exception
430
- x_t1 = torch.cat([x_t1_1, x_t1_k], dim=2)
431
-
432
- # del latent
433
- ddim_res = self.DDIM_backward(num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=t1, t0=-1, t1=-1, do_classifier_free_guidance=do_classifier_free_guidance,
434
- null_embs=null_embs, text_embeddings=text_embeddings, latents_local=x_t1, latents_dtype=dtype, guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step, callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)
435
- x0 = ddim_res["x0"].detach()
436
- del ddim_res
437
-
438
-
439
  # smooth background
440
  if smooth_bg:
441
  h, w = x0.shape[3], x0.shape[4]
@@ -474,9 +419,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
474
  x_t1_fg_masked_b, reference_flow)
475
  else:
476
  x_t1_fg_masked_b = x_t1_fg_masked_b[None]
477
- if move_object:
478
- x_t1_fg_masked_b = self.warp_latents_independently(
479
- x_t1_fg_masked_b, reference_flow_obj[batch_idx, 1:])
480
 
481
  x_t1_fg_masked_b = torch.cat(
482
  [x_t1_1_fg_masked_b[None], x_t1_fg_masked_b], dim=2)
@@ -493,9 +435,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
493
  if use_motion_field:
494
  m_fg_b = self.warp_latents_independently(
495
  m_fg_b.clone(), reference_flow)
496
- if move_object:
497
- m_fg_b = self.warp_latents_independently(
498
- m_fg_b, reference_flow_obj[batch_idx, 1:])
499
  M_FG_warped.append(
500
  torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
501
 
142
  with self.progress_bar(total=num_inference_steps) as progress_bar:
143
  for i, t in enumerate(timesteps):
144
  if t > skip_t:
 
145
  continue
146
  else:
147
  if not entered:
234
  List[torch.Generator]]] = None,
235
  xT: Optional[torch.FloatTensor] = None,
236
  null_embs: Optional[torch.FloatTensor] = None,
237
+ #motion_field_strength: float = 12,
238
+ motion_field_strength_x: float = 12,
239
+ motion_field_strength_y: float = 12,
240
  output_type: Optional[str] = "tensor",
241
  return_dict: bool = True,
242
  callback: Optional[Callable[[
243
  int, int, torch.FloatTensor], None]] = None,
244
  callback_steps: Optional[int] = 1,
 
245
  use_motion_field: bool = True,
246
  smooth_bg: bool = True,
247
  smooth_bg_strength: float = 0.4,
248
  **kwargs,
249
  ):
250
+ print(motion_field_strength_x,motion_field_strength_y)
251
  print(f" Use: Motion field = {use_motion_field}")
252
  print(f" Use: Background smoothing = {smooth_bg}")
253
  # Default height and width to unet
349
  reference_flow = torch.zeros(
350
  (video_length-1, 2, 512, 512), device=x_t0_1.device, dtype=x_t0_1.dtype)
351
  for fr_idx in range(video_length-1):
352
+ #reference_flow[fr_idx, :, :, :] = motion_field_strength*(fr_idx+1)
353
+ reference_flow[fr_idx, 0, :, :] = motion_field_strength_x*(fr_idx+1)
354
+ reference_flow[fr_idx, 1, :, :] = motion_field_strength_y*(fr_idx+1)
355
 
356
  for idx, latent in enumerate(x_t0_k):
357
  x_t0_k[idx] = self.warp_latents_independently(
381
  x_t0_k = x_t0_1[:, :, 1:, :, :].clone()
382
  x_t0_1 = x_t0_1[:,:,:1,:,:].clone()
383
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384
  # smooth background
385
  if smooth_bg:
386
  h, w = x0.shape[3], x0.shape[4]
419
  x_t1_fg_masked_b, reference_flow)
420
  else:
421
  x_t1_fg_masked_b = x_t1_fg_masked_b[None]
 
 
 
422
 
423
  x_t1_fg_masked_b = torch.cat(
424
  [x_t1_1_fg_masked_b[None], x_t1_fg_masked_b], dim=2)
435
  if use_motion_field:
436
  m_fg_b = self.warp_latents_independently(
437
  m_fg_b.clone(), reference_flow)
 
 
 
438
  M_FG_warped.append(
439
  torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))
440