lev1 committed on
Commit 687b293
1 Parent(s): 9adc565

Enabling Token Merging for fast inference

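This change enables Token Merging (ToMe) for Stable Diffusion via the tomesd package: redundant spatial tokens are merged before self-attention, which reduces memory use and speeds up inference at the cost of some detail. A minimal sketch of the pattern this commit adds in Model.inference() (the pipeline and model id below are illustrative, not the Space's actual configuration):

```python
import tomesd
import torch
from diffusers import StableDiffusionPipeline

# Illustrative pipeline; the Space loads its own ControlNet / InstructPix2Pix /
# text-to-video pipelines instead.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")

merging_ratio = 0.5  # 0.0 disables merging; up to 0.9 is exposed in the new UI sliders
tomesd.remove_patch(pipe)  # clear any patch left over from a previous run
if merging_ratio > 0:
    # Patch the UNet so attention runs on a merged (smaller) token set.
    tomesd.apply_patch(pipe, ratio=merging_ratio)
```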
app.py CHANGED
@@ -23,7 +23,7 @@ with gr.Blocks(css='style.css') as demo:
  """
  <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
  <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
- Text2Video-Zero
+ <a href="https://github.com/Picsart-AI-Research/Text2Video-Zero" style="color:blue;">Text2Video-Zero</a>
  </h1>
  <h2 style="font-weight: 450; font-size: 1rem; margin: 0rem">
  Levon Khachatryan<sup>1*</sup>, Andranik Movsisyan<sup>1*</sup>, Vahram Tadevosyan<sup>1*</sup>, Roberto Henschel<sup>1*</sup>, Zhangyang Wang<sup>1,2</sup>, Shant Navasardyan<sup>1</sup>
@@ -62,7 +62,8 @@ with gr.Blocks(css='style.css') as demo:
  create_demo_canny(model)
  with gr.Tab('Edge Conditional and Dreambooth Specialized'):
  create_demo_canny_db(model)
-
+ '''
+ '''
  gr.HTML(
  """
  <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
@@ -90,5 +91,5 @@ if on_huggingspace:
  demo.launch(debug=True)
  else:
  _, _, link = demo.queue(api_open=False).launch(
- file_directories=['temporal'], share=args.public_access or on_huggingspace)
+ file_directories=['temporal'], share=args.public_access)
  print(link)
app_canny.py CHANGED
@@ -47,7 +47,11 @@ def create_demo(model: Model):
  watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero",
  "None"], label="Watermark", value='Picsart AI Research')
  chunk_size = gr.Slider(
- label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)
+ label="Chunk size", minimum=2, maximum=16, value=8, step=1, visible=not on_huggingspace,
+ info="Number of frames processed at once. Reduce for lower memory usage.")
+ merging_ratio = gr.Slider(
+ label="Merging ratio", minimum=0.0, maximum=0.9, step=0.1, value=0.0, visible=not on_huggingspace,
+ info="Ratio of how many tokens are merged. The higher the more compression (less memory and faster inference).")
  with gr.Column():
  result = gr.Video(label="Generated Video").style(height="auto")

@@ -56,6 +60,7 @@ def create_demo(model: Model):
  prompt,
  chunk_size,
  watermark,
+ merging_ratio,
  ]

  gr.Examples(examples=examples,
app_canny_db.py CHANGED
@@ -51,7 +51,11 @@ def create_demo(model: Model):
  watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero",
  "None"], label="Watermark", value='Picsart AI Research')
  chunk_size = gr.Slider(
- label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)
+ label="Chunk size", minimum=2, maximum=16, value=8, step=1, visible=not on_huggingspace,
+ info="Number of frames processed at once. Reduce for lower memory usage.")
+ merging_ratio = gr.Slider(
+ label="Merging ratio", minimum=0.0, maximum=0.9, step=0.1, value=0.0, visible=not on_huggingspace,
+ info="Ratio of how many tokens are merged. The higher the more compression (less memory and faster inference).")
  with gr.Column():
  result = gr.Image(label="Generated Video").style(height=400)

@@ -79,6 +83,7 @@ def create_demo(model: Model):
  prompt,
  chunk_size,
  watermark,
+ merging_ratio,
  ]

  gr.Examples(examples=examples,
app_pix2pix_video.py CHANGED
@@ -48,9 +48,10 @@ def create_demo(model: Model):
  value=512,
  step=64)
  seed = gr.Slider(label='Seed',
- minimum=0,
+ minimum=-1,
  maximum=65536,
  value=0,
+ info="-1 for random seed on each run. Otherwise the seed will be fixed",
  step=1)
  image_guidance = gr.Slider(label='Image guidance scale',
  minimum=0.5,
@@ -73,7 +74,11 @@ def create_demo(model: Model):
  value=-1,
  step=1)
  chunk_size = gr.Slider(
- label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)
+ label="Chunk size", minimum=2, maximum=16, value=8, step=1, visible=not on_huggingspace,
+ info="Number of frames processed at once. Reduce for lower memory usage.")
+ merging_ratio = gr.Slider(
+ label="Merging ratio", minimum=0.0, maximum=0.9, step=0.1, value=0.0, visible=not on_huggingspace,
+ info="Ratio of how many tokens are merged. The higher the more compression (less memory and faster inference).")
  with gr.Column():
  result = gr.Video(label='Output', show_label=True)
  inputs = [
@@ -86,7 +91,8 @@ def create_demo(model: Model):
  end_t,
  out_fps,
  chunk_size,
- watermark
+ watermark,
+ merging_ratio
  ]

  gr.Examples(examples=examples,
app_pose.py CHANGED
@@ -35,7 +35,11 @@ def create_demo(model: Model):
  watermark = gr.Radio(["Picsart AI Research", "Text2Video-Zero",
  "None"], label="Watermark", value='Picsart AI Research')
  chunk_size = gr.Slider(
- label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)
+ label="Chunk size", minimum=2, maximum=16, value=8, step=1, visible=not on_huggingspace,
+ info="Number of frames processed at once. Reduce for lower memory usage.")
+ merging_ratio = gr.Slider(
+ label="Merging ratio", minimum=0.0, maximum=0.9, step=0.1, value=0.0, visible=not on_huggingspace,
+ info="Ratio of how many tokens are merged. The higher the more compression (less memory and faster inference).")
  with gr.Column():
  result = gr.Image(label="Generated Video")

@@ -48,6 +52,7 @@ def create_demo(model: Model):
  prompt,
  chunk_size,
  watermark,
+ merging_ratio,
  ]

  gr.Examples(examples=examples,
app_text_to_video.py CHANGED
@@ -39,6 +39,7 @@ def create_demo(model: Model):
  label="Model",
  choices=get_model_list(),
  value="dreamlike-art/dreamlike-photoreal-2.0",
+
  )
  prompt = gr.Textbox(label='Prompt')
  run_button = gr.Button(label='Run')
@@ -52,21 +53,41 @@ def create_demo(model: Model):
  else:
  video_length = gr.Number(
  label="Video length", value=8, precision=0)
- chunk_size = gr.Slider(
- label="Chunk size", minimum=2, maximum=16, value=12 if on_huggingspace else 8, step=1, visible=not on_huggingspace)
+
+ n_prompt = gr.Textbox(
+ label="Optional Negative Prompt", value='')
+ seed = gr.Slider(label='Seed',
+ info="-1 for random seed on each run. Otherwise, the seed will be fixed.",
+ minimum=-1,
+ maximum=65536,
+ value=0,
+ step=1)

  motion_field_strength_x = gr.Slider(
- label='Global Translation $\delta_{x}$', minimum=-20, maximum=20, value=12, step=1)
+ label='Global Translation $\\delta_{x}$', minimum=-20, maximum=20,
+ value=12,
+ step=1)
  motion_field_strength_y = gr.Slider(
- label='Global Translation $\delta_{y}$', minimum=-20, maximum=20, value=12, step=1)
+ label='Global Translation $\\delta_{y}$', minimum=-20, maximum=20,
+ value=12,
+ step=1)

  t0 = gr.Slider(label="Timestep t0", minimum=0,
- maximum=49, value=44, step=1)
- t1 = gr.Slider(label="Timestep t1", minimum=0,
- maximum=49, value=47, step=1)
+ maximum=47, value=44, step=1,
+ info="Perform DDPM steps from t0 to t1. The larger the gap between t0 and t1, the more variance between the frames. Ensure t0 < t1 ",
+ )
+ t1 = gr.Slider(label="Timestep t1", minimum=1,
+ info="Perform DDPM steps from t0 to t1. The larger the gap between t0 and t1, the more variance between the frames. Ensure t0 < t1",
+ maximum=48, value=47, step=1)
+ chunk_size = gr.Slider(
+ label="Chunk size", minimum=2, maximum=16, value=8, step=1, visible=not on_huggingspace,
+ info="Number of frames processed at once. Reduce for lower memory usage."
+ )
+ merging_ratio = gr.Slider(
+ label="Merging ratio", minimum=0.0, maximum=0.9, step=0.1, value=0.0, visible=not on_huggingspace,
+ info="Ratio of how many tokens are merged. The higher the more compression (less memory and faster inference)."
+ )

- n_prompt = gr.Textbox(
- label="Optional Negative Prompt", value='')
  with gr.Column():
  result = gr.Video(label="Generated Video")

@@ -81,6 +102,8 @@ def create_demo(model: Model):
  chunk_size,
  video_length,
  watermark,
+ merging_ratio,
+ seed,
  ]

  gr.Examples(examples=examples,
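The new t0/t1 sliders above encode the constraint that the pipeline change further below also asserts: t0 must stay strictly below t1, and a larger gap produces more variance between frames. A tiny hedged sketch of that check (the helper name is made up for illustration):

```python
def validate_timesteps(t0: int, t1: int) -> None:
    # Mirrors the new slider bounds (t0 in [0, 47], t1 in [1, 48]) and the
    # `assert t0 < t1` added to TextToVideoPipeline.__call__.
    assert 0 <= t0 <= 47 and 1 <= t1 <= 48
    assert t0 < t1, "a larger t1 - t0 gap yields more variation across frames"

validate_timesteps(44, 47)  # the defaults exposed in the UI
```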
gradio_utils.py CHANGED
@@ -8,19 +8,19 @@ def edge_path_to_video_path(edge_path):

  vid_name = edge_path.split("/")[-1]
  if vid_name == "butterfly.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/butterfly.mp4"
+ video_path = "__assets__/canny_videos_mp4/butterfly.mp4"
  elif vid_name == "deer.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/deer.mp4"
+ video_path = "__assets__/canny_videos_mp4/deer.mp4"
  elif vid_name == "fox.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/fox.mp4"
+ video_path = "__assets__/canny_videos_mp4/fox.mp4"
  elif vid_name == "girl_dancing.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/girl_dancing.mp4"
+ video_path = "__assets__/canny_videos_mp4/girl_dancing.mp4"
  elif vid_name == "girl_turning.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/girl_turning.mp4"
+ video_path = "__assets__/canny_videos_mp4/girl_turning.mp4"
  elif vid_name == "halloween.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/halloween.mp4"
+ video_path = "__assets__/canny_videos_mp4/halloween.mp4"
  elif vid_name == "santa.mp4":
- video_path = "__assets__/canny_videos_mp4_2fps/santa.mp4"
+ video_path = "__assets__/canny_videos_mp4/santa.mp4"

  assert os.path.isfile(video_path)
  return video_path
model.py CHANGED
@@ -1,7 +1,7 @@
  from enum import Enum
  import gc
  import numpy as np
-
+ import tomesd
  import torch

  from diffusers import StableDiffusionInstructPix2PixPipeline, StableDiffusionControlNetPipeline, ControlNetModel, UNet2DConditionModel
@@ -45,6 +45,7 @@ class Model:
  self.model_type = None

  self.states = {}
+ self.model_name = ""

  def set_model(self, model_type: ModelType, model_id: str, **kwargs):
  if self.pipe is not None:
@@ -55,6 +56,7 @@ class Model:
  self.pipe = self.pipe_dict[model_type].from_pretrained(
  model_id, safety_checker=safety_checker, **kwargs).to(self.device).to(self.dtype)
  self.model_type = model_type
+ self.model_name = model_id

  def inference_chunk(self, frame_ids, **kwargs):
  if self.pipe is None:
@@ -80,6 +82,13 @@ class Model:
  def inference(self, split_to_chunks=False, chunk_size=8, **kwargs):
  if self.pipe is None:
  return
+ tomesd.remove_patch(self.pipe)
+ if "merging_ratio" in kwargs:
+ merging_ratio = kwargs.pop("merging_ratio")
+
+ if merging_ratio > 0:
+
+ tomesd.apply_patch(self.pipe, ratio=merging_ratio)
  seed = kwargs.pop('seed', 0)
  if seed < 0:
  seed = self.generator.seed()
@@ -116,6 +125,7 @@ class Model:
  result = np.concatenate(result)
  return result
  else:
+ self.generator.manual_seed(seed)
  return self.pipe(prompt=prompt, negative_prompt=negative_prompt, generator=self.generator, **kwargs).images

  def process_controlnet_canny(self,
@@ -123,6 +133,7 @@ class Model:
  prompt,
  chunk_size=8,
  watermark='Picsart AI Research',
+ merging_ratio=0.0,
  num_inference_steps=20,
  controlnet_conditioning_scale=1.0,
  guidance_scale=9.0,
@@ -133,6 +144,7 @@ class Model:
  resolution=512,
  use_cf_attn=True,
  save_path=None):
+ print("Processing Canny")
  video_path = gradio_utils.edge_path_to_video_path(video_path)
  if self.model_type != ModelType.ControlNetCanny:
  controlnet = ControlNetModel.from_pretrained(
@@ -173,6 +185,7 @@ class Model:
  output_type='numpy',
  split_to_chunks=True,
  chunk_size=chunk_size,
+ merging_ratio=merging_ratio,
  )
  return utils.create_video(result, fps, path=save_path, watermark=gradio_utils.logo_name_to_path(watermark))

@@ -181,6 +194,7 @@ class Model:
  prompt,
  chunk_size=8,
  watermark='Picsart AI Research',
+ merging_ratio=0.0,
  num_inference_steps=20,
  controlnet_conditioning_scale=1.0,
  guidance_scale=9.0,
@@ -189,6 +203,7 @@ class Model:
  resolution=512,
  use_cf_attn=True,
  save_path=None):
+ print("Processing Pose")
  video_path = gradio_utils.motion_to_video_path(video_path)
  if self.model_type != ModelType.ControlNetPose:
  controlnet = ControlNetModel.from_pretrained(
@@ -232,6 +247,7 @@ class Model:
  output_type='numpy',
  split_to_chunks=True,
  chunk_size=chunk_size,
+ merging_ratio=merging_ratio,
  )
  return utils.create_gif(result, fps, path=save_path, watermark=gradio_utils.logo_name_to_path(watermark))

@@ -241,6 +257,7 @@ class Model:
  prompt,
  chunk_size=8,
  watermark='Picsart AI Research',
+ merging_ratio=0.0,
  num_inference_steps=20,
  controlnet_conditioning_scale=1.0,
  guidance_scale=9.0,
@@ -251,6 +268,7 @@ class Model:
  resolution=512,
  use_cf_attn=True,
  save_path=None):
+ print("Processing Canny_DB")
  db_path = gradio_utils.get_model_from_db_selection(db_path)
  video_path = gradio_utils.get_video_from_canny_selection(video_path)
  # Load db and controlnet weights
@@ -295,6 +313,7 @@ class Model:
  output_type='numpy',
  split_to_chunks=True,
  chunk_size=chunk_size,
+ merging_ratio=merging_ratio,
  )
  return utils.create_gif(result, fps, path=save_path, watermark=gradio_utils.logo_name_to_path(watermark))

@@ -309,8 +328,10 @@ class Model:
  out_fps=-1,
  chunk_size=8,
  watermark='Picsart AI Research',
+ merging_ratio=0.0,
  use_cf_attn=True,
  save_path=None,):
+ print("Processing Pix2Pix")
  if self.model_type != ModelType.Pix2Pix_Video:
  self.set_model(ModelType.Pix2Pix_Video,
  model_id="timbrooks/instruct-pix2pix")
@@ -330,6 +351,7 @@ class Model:
  image_guidance_scale=image_guidance_scale,
  split_to_chunks=True,
  chunk_size=chunk_size,
+ merging_ratio=merging_ratio
  )
  return utils.create_video(result, fps, path=save_path, watermark=gradio_utils.logo_name_to_path(watermark))

@@ -344,17 +366,18 @@ class Model:
  chunk_size=8,
  video_length=8,
  watermark='Picsart AI Research',
- inject_noise_to_warp=False,
+ merging_ratio=0.0,
+ seed=0,
  resolution=512,
- seed=-1,
  fps=2,
  use_cf_attn=True,
  use_motion_field=True,
  smooth_bg=False,
  smooth_bg_strength=0.4,
  path=None):
-
- if self.model_type != ModelType.Text2Video:
+ print("Processing Text2Video")
+ if self.model_type != ModelType.Text2Video or model_name != self.model_name:
+ print("Model update")
  unet = UNet2DConditionModel.from_pretrained(
  model_name, subfolder="unet")
  self.set_model(ModelType.Text2Video,
@@ -364,7 +387,7 @@ class Model:
  if use_cf_attn:
  self.pipe.unet.set_attn_processor(
  processor=self.text2video_attn_proc)
- self.generator.manual_seed(seed)
+ self.generator.manual_seed(seed)

  added_prompt = "high quality, HD, 8K, trending on artstation, high focus, dramatic lighting"
  negative_prompts = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer difits, cropped, worst quality, low quality, deformed body, bloated, ugly, unrealistic'
@@ -396,7 +419,7 @@ class Model:
  seed=seed,
  output_type='numpy',
  negative_prompt=negative_prompt,
- inject_noise_to_warp=inject_noise_to_warp,
+ merging_ratio=merging_ratio,
  split_to_chunks=True,
  chunk_size=chunk_size,
  )
requirements.txt CHANGED
@@ -34,3 +34,4 @@ yapf==0.32.0
  safetensors==0.2.7
  beautifulsoup4
  bs4
+ tomesd
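tomesd is added without a version pin. If the dependency might be absent (for example in a stripped-down local install), a guarded import such as the sketch below is one way to keep the app running without token merging; this hardening is not part of the commit:

```python
# Hypothetical fallback, not part of this commit: disable token merging
# when tomesd is not installed instead of failing at import time.
try:
    import tomesd
except ImportError:
    tomesd = None

def maybe_apply_token_merging(pipe, merging_ratio: float) -> None:
    if tomesd is None or merging_ratio <= 0:
        return
    tomesd.apply_patch(pipe, ratio=merging_ratio)
```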
text_to_video_pipeline.py CHANGED
@@ -53,8 +53,10 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  if x0 is None:
  return torch.randn(shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype).to(device)
  else:
- eps = torch.randn_like(x0, dtype=text_embeddings.dtype).to(device)
+ eps = torch.randn(x0.shape, dtype=text_embeddings.dtype, generator=generator,
+ device=rand_device)
  alpha_vec = torch.prod(self.scheduler.alphas[t0:tMax])
+
  xt = torch.sqrt(alpha_vec) * x0 + \
  torch.sqrt(1-alpha_vec) * eps
  return xt
@@ -89,7 +91,7 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  latents = latents * self.scheduler.init_noise_sigma
  return latents

- def warp_latents_independently(self, latents, reference_flow, inject_noise=False):
+ def warp_latents_independently(self, latents, reference_flow):
  _, _, H, W = reference_flow.size()
  b, _, f, h, w = latents.size()
  assert b == 1
@@ -109,15 +111,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  warped = grid_sample(latents_0, coords_t0,
  mode='nearest', padding_mode='reflection')

- if inject_noise:
- idx = torch.logical_or(coords_t0 >= 1, coords_t0 < -1)
- reset_noise = torch.randn(idx.shape)
- idx = torch.logical_or(idx[:, :, :, 0], idx[:, :, :, 1])
- idx = repeat(idx, "f w h -> f c w h", c=warped.shape[1])
- reset_noise = torch.randn(
- size=warped.shape, dtype=warped.dtype, device=warped.device)
- warped[idx] = reset_noise[idx]
-
  warped = rearrange(warped, '(b f) c h w -> b c f h w', f=f)
  return warped

@@ -212,20 +205,20 @@ class TextToVideoPipeline(StableDiffusionPipeline):

  reference_flow = torch.zeros(
  (video_length-1, 2, 512, 512), device=latents.device, dtype=latents.dtype)
- for fr_idx in range(video_length-1):
+ for fr_idx, frame_id in enumerate(frame_ids):
  reference_flow[fr_idx, 0, :,
- :] = motion_field_strength_x*(frame_ids[fr_idx]+1)
+ :] = motion_field_strength_x*(frame_id)
  reference_flow[fr_idx, 1, :,
- :] = motion_field_strength_y*(frame_ids[fr_idx]+1)
+ :] = motion_field_strength_y*(frame_id)
  return reference_flow

- def create_motion_field_and_warp_latents(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, inject_noise_to_warp, latents):
+ def create_motion_field_and_warp_latents(self, motion_field_strength_x, motion_field_strength_y, frame_ids, video_length, latents):

  motion_field = self.create_motion_field(motion_field_strength_x=motion_field_strength_x,
  motion_field_strength_y=motion_field_strength_y, latents=latents, video_length=video_length, frame_ids=frame_ids)
  for idx, latent in enumerate(latents):
  latents[idx] = self.warp_latents_independently(
- latent[None], motion_field, inject_noise=inject_noise_to_warp)
+ latent[None], motion_field)
  return motion_field, latents

  @torch.no_grad()
@@ -255,13 +248,12 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  use_motion_field: bool = True,
  smooth_bg: bool = False,
  smooth_bg_strength: float = 0.4,
- inject_noise_to_warp: bool = False,
  t0: int = 44,
  t1: int = 47,
  **kwargs,
  ):
  frame_ids = kwargs.pop("frame_ids", list(range(video_length)))
-
+ assert t0 < t1
  assert num_videos_per_prompt == 1
  assert isinstance(prompt, list) and len(prompt) > 0
  assert isinstance(negative_prompt, list) or negative_prompt is None
@@ -280,11 +272,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  prompt = prompt_types[0]
  negative_prompt = prompt_types[1]

- print(
- f" Motion field strength x = {motion_field_strength_x}, y = {motion_field_strength_y}")
- print(f" Use: Motion field = {use_motion_field}")
- print(f" Use: Background smoothing = {smooth_bg}")
- print(f"Inject noise to warp = {inject_noise_to_warp}")
  # Default height and width to unet
  height = height or self.unet.config.sample_size * self.vae_scale_factor
  width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -355,6 +342,7 @@ class TextToVideoPipeline(StableDiffusionPipeline):

  t0 = timesteps_ddpm[t0]
  t1 = timesteps_ddpm[t1]
+
  print(f"t0 = {t0} t1 = {t1}")
  x_t1_1 = None

@@ -366,14 +354,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):

  shape = (batch_size, num_channels_latents, 1, height //
  self.vae_scale_factor, width // self.vae_scale_factor)
- if inject_noise_to_warp and use_motion_field:
- # if we inject to noise to warp function, we do it for timesteps T = 1000
-
- x_t0_k = xT[:, :, :1, :, :].repeat(1, 1, video_length-1, 1, 1)
-
- # reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(motion_field_strength_x=motion_field_strength_x, motion_field_strength_y=motion_field_strength_y,
- # frame_ids=frame_ids,video_length=video_length,inject_noise_to_warp=inject_noise_to_warp,latents = x_t0_k)
- # xT =torch.cat([xT, x_t0_k], dim=2).clone().detach()

  ddim_res = self.DDIM_backward(num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=1000, t0=t0, t1=t1, do_classifier_free_guidance=do_classifier_free_guidance,
  null_embs=null_embs, text_embeddings=text_embeddings, latents_local=xT, latents_dtype=dtype, guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step,
@@ -387,37 +367,13 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  x_t1_1 = ddim_res["x_t1_1"].detach()
  del ddim_res
  del xT
-
- if inject_noise_to_warp and use_motion_field:
- # DDPM forward to allow for more motion
- if t1 > t0:
- x_t1_k = self.DDPM_forward(
- x0=x_t0_1, t0=t0, tMax=t1, device=device, shape=shape, text_embeddings=text_embeddings, generator=generator)
- else:
- x_t1_k = x_t0_k
-
- if x_t1_1 is None:
- raise Exception
-
- x_t1 = x_t1_k.clone().detach()
-
- ddim_res = self.DDIM_backward(num_inference_steps=num_inference_steps, timesteps=timesteps, skip_t=t1, t0=-1, t1=-1, do_classifier_free_guidance=do_classifier_free_guidance,
- null_embs=null_embs, text_embeddings=text_embeddings, latents_local=x_t1, latents_dtype=dtype, guidance_scale=guidance_scale, guidance_stop_step=guidance_stop_step,
- callback=callback, callback_steps=callback_steps, extra_step_kwargs=extra_step_kwargs, num_warmup_steps=num_warmup_steps)
-
- x0 = ddim_res["x0"].detach()
- del ddim_res
- del x_t1
- del x_t1_k
-
- if use_motion_field and not inject_noise_to_warp:
+ if use_motion_field:
  del x0

  x_t0_k = x_t0_1[:, :, :1, :, :].repeat(1, 1, video_length-1, 1, 1)

  reference_flow, x_t0_k = self.create_motion_field_and_warp_latents(
- motion_field_strength_x=motion_field_strength_x, motion_field_strength_y=motion_field_strength_y, latents=x_t0_k, video_length=video_length,
- inject_noise_to_warp=inject_noise_to_warp, frame_ids=frame_ids)
+ motion_field_strength_x=motion_field_strength_x, motion_field_strength_y=motion_field_strength_y, latents=x_t0_k, video_length=video_length, frame_ids=frame_ids[1:])

  # assuming t0=t1=1000, if t0 = 1000
  if t1 > t0:
@@ -440,7 +396,6 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  del x_t1
  del x_t1_1
  del x_t1_k
-
  else:
  x_t1 = x_t1_1.clone()
  x_t1_1 = x_t1_1[:, :, :1, :, :].clone()
@@ -481,7 +436,7 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  if use_motion_field:
  x_t1_fg_masked_b = x_t1_fg_masked_b[None]
  x_t1_fg_masked_b = self.warp_latents_independently(
- x_t1_fg_masked_b, reference_flow, inject_noise=False)
+ x_t1_fg_masked_b, reference_flow)
  else:
  x_t1_fg_masked_b = x_t1_fg_masked_b[None]

@@ -499,7 +454,7 @@ class TextToVideoPipeline(StableDiffusionPipeline):
  m_fg_b = m_fg_1_b.repeat(1, 1, video_length-1, 1, 1)
  if use_motion_field:
  m_fg_b = self.warp_latents_independently(
- m_fg_b.clone(), reference_flow, inject_noise=False)
+ m_fg_b.clone(), reference_flow)
  M_FG_warped.append(
  torch.cat([m_fg_1_b[:1, 0], m_fg_b[:1, 0]], dim=1))

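For reference, the motion-field change above replaces the old per-chunk offset frame_ids[fr_idx] + 1 with the frame id itself, so the synthetic translation grows with each frame's position in the full video even when frames are processed in chunks (frame 0 stays as the unwarped anchor, hence frame_ids[1:]). A standalone sketch of the new construction, with sizes matching the pipeline's 512x512 flow field:

```python
import torch

def create_reference_flow(frame_ids, strength_x: float, strength_y: float) -> torch.Tensor:
    # One constant 2D translation per warped frame, proportional to its frame id.
    flow = torch.zeros(len(frame_ids), 2, 512, 512)
    for fr_idx, frame_id in enumerate(frame_ids):
        flow[fr_idx, 0] = strength_x * frame_id  # horizontal displacement
        flow[fr_idx, 1] = strength_y * frame_id  # vertical displacement
    return flow

# e.g. an 8-frame video: frames 1..7 are warped copies of frame 0
flow = create_reference_flow(list(range(1, 8)), strength_x=12, strength_y=12)
print(flow.shape)  # torch.Size([7, 2, 512, 512])
```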