lopho committed
Commit b83ebfb
Parent: 5b09d17

nicer defaults, selectable scheduler, image cfg separate

Files changed (5)
  1. README.md +1 -1
  2. app.py +127 -73
  3. example.webp +2 -2
  4. example_input.png +0 -0
  5. makeavid_sd/inference.py +91 -55
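
In short, this commit drops the fixed `scheduler_cls` constructor argument, adds a `SCHEDULERS` registry so the noise scheduler can be selected at generation time, and splits the guidance scale into `cfg` (video) and `cfg_image` (hint image). The following is a minimal sketch of how the updated entry points fit together after this change; it is assembled from the new signatures in the diff, and the model id is a placeholder (the commit itself points at a local checkout):

```python
# Sketch based on the new signatures in this commit, not a verbatim excerpt.
from makeavid_sd.inference import InferenceUNetPseudo3D, SCHEDULERS, jnp

# scheduler_cls is gone; one scheduler per SCHEDULERS entry is loaded internally.
model = InferenceUNetPseudo3D(
    model_path = 'TempoFunk/makeavid-sd-jax',  # placeholder path, assumption only
    dtype = jnp.float16,
)

frames = model.generate(
    prompt = ['An elderly man having a great time in the park.'] * model.device_count,
    neg_prompt = '',
    hint_image = None,        # optional PIL image; None lets the image model generate the hint
    inference_steps = 20,
    cfg = 15.0,               # guidance scale for the video pass (new default)
    cfg_image = 9.0,          # separate guidance scale for the generated hint image (new)
    num_frames = 24,
    height = 512,
    width = 512,
    seed = 0,
    scheduler_type = 'DPM',   # any key of SCHEDULERS, currently 'DPM' or 'DDIM'
)
```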
README.md CHANGED
@@ -12,7 +12,7 @@ library_name: diffusers
 pipeline_tag: text-to-video
 datasets:
 - TempoFunk/tempofunk-sdance
-- TempoFunk/tempofunk-m
+- TempoFunk/small
 models:
 - TempoFunk/makeavid-sd-jax
 - runwayml/stable-diffusion-v1-5
app.py CHANGED
@@ -7,7 +7,11 @@ from functools import partial
 from PIL import Image, ImageOps
 import gradio as gr
 
-from makeavid_sd.inference import InferenceUNetPseudo3D, FlaxDPMSolverMultistepScheduler, jnp
+from makeavid_sd.inference import (
+    InferenceUNetPseudo3D,
+    jnp,
+    SCHEDULERS
+)
 
 print(os.environ.get('XLA_PYTHON_CLIENT_PREALLOCATE', 'NotSet'))
 print(os.environ.get('XLA_PYTHON_CLIENT_ALLOCATOR', 'NotSet'))
@@ -17,8 +21,7 @@ _preheat: bool = False
 _seen_compilations = set()
 
 _model = InferenceUNetPseudo3D(
-    model_path = 'TempoFunk/makeavid-sd-jax',
-    scheduler_cls = FlaxDPMSolverMultistepScheduler,
+    model_path = '/mnt/work1/make_a_vid/makeavid-space/model/model',
     dtype = jnp.float16,
     hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
 )
@@ -30,69 +33,85 @@ if _model.failed != False:
 
     demo.launch()
 
+_output_formats = (
+    'webp', 'gif'
+)
+
 # gradio is illiterate. type hints make it go poopoo in pantsu.
 def generate(
         prompt = 'An elderly man having a great time in the park.',
         neg_prompt = '',
-        image = None,
+        hint_image = None,
         inference_steps = 20,
-        cfg = 12.0,
+        cfg = 15.0,
+        cfg_image = 9.0,
         seed = 0,
         fps = 24,
         num_frames = 24,
         height = 512,
-        width = 512
+        width = 512,
+        scheduler_type = 'DPM',
+        output_format = 'webp'
 ) -> str:
+    num_frames = int(num_frames)
+    inference_steps = int(inference_steps)
     height = int(height)
     width = int(width)
-    num_frames = int(num_frames)
-    seed = int(seed)
     height = (height // 64) * 64
     width = (width // 64) * 64
+    cfg = max(cfg, 1.0)
+    cfg_image = max(cfg_image, 1.0)
+    seed = int(seed)
     if seed < 0:
        seed = -seed
-    inference_steps = int(inference_steps)
-    hint_image = image
     if hint_image is not None:
         if hint_image.mode != 'RGB':
             hint_image = hint_image.convert('RGB')
         if hint_image.size != (width, height):
             hint_image = ImageOps.fit(hint_image, (width, height), method = Image.Resampling.LANCZOS)
+    if scheduler_type not in SCHEDULERS:
+        scheduler_type = 'DPM'
+    output_format = output_format.lower()
+    if output_format not in _output_formats:
+        output_format = 'webp'
+    mask_image = None
     images = _model.generate(
         prompt = [prompt] * _model.device_count,
         neg_prompt = neg_prompt,
         hint_image = hint_image,
-        mask_image = None,
+        mask_image = mask_image,
         inference_steps = inference_steps,
         cfg = cfg,
+        cfg_image = cfg_image,
         height = height,
         width = width,
         num_frames = num_frames,
-        seed = seed
+        seed = seed,
+        scheduler_type = scheduler_type
     )
     _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))
     buffer = BytesIO()
-    images[0].save(
+    images[1].save(
         buffer,
-        format = 'webp',
+        format = output_format,
         save_all = True,
-        append_images = images[1:],
+        append_images = images[2:],
         loop = 0,
         duration = round(1000 / fps),
         allow_mixed = True
     )
     data = base64.b64encode(buffer.getvalue()).decode()
-    data = 'data:image/webp;base64,' + data
     buffer.close()
+    data = f'data:image/{output_format};base64,' + data
     return data
 
-def check_if_compiled(image, inference_steps, height, width, num_frames, message):
+def check_if_compiled(hint_image, inference_steps, height, width, num_frames, scheduler_type, message):
     height = int(height)
     width = int(width)
+    inference_steps = int(inference_steps)
     height = (height // 64) * 64
     width = (width // 64) * 64
-    hint_image = image
-    if (hint_image is None, inference_steps, height, width, num_frames) in _seen_compilations:
+    if (hint_image is None, inference_steps, height, width, num_frames, scheduler_type) in _seen_compilations:
         return ''
     else:
         return f"""{message}"""
@@ -126,19 +145,19 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     # Make-A-Video Stable Diffusion JAX
 
     We have extended a pretrained LDM inpainting image generation model with temporal convolutions and attention.
-    We take advantage of the extra 5 input channels of the inpaint model to guide the video generation with a hint image and mask.
-    The hint image can be given by the user, otherwise it is generated by an generative image model.
+    By taking advantage of the extra 5 input channels of the inpaint model, we guide the video generation with a hint image.
+    In this demo the hint image can be given by the user, otherwise it is generated by an generative image model.
 
-    The temporal convolution and attention is a port of [Make-A-Video Pytorch](https://github.com/lucidrains/make-a-video-pytorch/blob/main/make_a_video_pytorch) to FLAX.
-    It is a pseudo 3D convolution that seperately convolves accross the spatial dimension in 2D and over the temporal dimension in 1D.
-    Temporal attention is purely self attention and also separately attends to time and space.
+    The temporal layers are a port of [Make-A-Video PyTorch](https://github.com/lucidrains/make-a-video-pytorch) to FLAX.
+    The convolution is pseudo 3D and seperately convolves accross the spatial dimension in 2D and over the temporal dimension in 1D.
+    Temporal attention is purely self attention and also separately attends to time.
 
     Only the new temporal layers have been fine tuned on a dataset of videos themed around dance.
-    The model has been trained for 60 epochs on a dataset of 10,000 Videos with 120 frames each, randomly selecting a 24 frame range from each sample.
+    The model has been trained for 80 epochs on a dataset of 18,000 Videos with 120 frames each, randomly selecting a 24 frame range from each sample.
 
     See model and dataset links in the metadata.
 
-    Model implementation and training code can be found at [https://github.com/lopho/makeavid-sd-tpu](https://github.com/lopho/makeavid-sd-tpu)
+    Model implementation and training code can be found at <https://github.com/lopho/makeavid-sd-tpu>
     """)
     with gr.Column():
         intro3 = gr.Markdown("""
@@ -151,40 +170,44 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
     Changes to the following parameters require the model to compile
     - Number of frames
     - Width & Height
-    - Steps
+    - Inference steps
     - Input image vs. no input image
+    - Noise scheduler type
+
+    If you encounter any issues, please report them here: [Space discussions](https://huggingface.co/spaces/TempoFunk/makeavid-sd-jax/discussions)
     """)
 
     with gr.Row(variant = variant):
-        with gr.Column(variant = variant):
+        with gr.Column():
             with gr.Row():
                 #cancel_button = gr.Button(value = 'Cancel')
                 submit_button = gr.Button(value = 'Make A Video', variant = 'primary')
             prompt_input = gr.Textbox(
                 label = 'Prompt',
-                value = 'They are dancing in the club while sweat drips from the ceiling.',
+                value = 'They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.',
                 interactive = True
             )
             neg_prompt_input = gr.Textbox(
                 label = 'Negative prompt (optional)',
-                value = '',
+                value = 'monochrome, saturated',
                 interactive = True
             )
-            inference_steps_input = gr.Slider(
-                label = 'Steps',
-                minimum = 2,
-                maximum = 100,
-                value = 20,
-                step = 1
-            )
             cfg_input = gr.Slider(
-                label = 'Guidance scale',
+                label = 'Guidance scale video',
                 minimum = 1.0,
                 maximum = 20.0,
                 step = 0.1,
                 value = 15.0,
                 interactive = True
             )
+            cfg_image_input = gr.Slider(
+                label = 'Guidance scale hint (no effect with input image)',
+                minimum = 1.0,
+                maximum = 20.0,
+                step = 0.1,
+                value = 9.0,
+                interactive = True
+            )
             seed_input = gr.Number(
                 label = 'Random seed',
                 value = 0,
@@ -192,43 +215,68 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
                 precision = 0
             )
             image_input = gr.Image(
-                label = 'Input image (optional)',
+                label = 'Hint image (optional)',
                 interactive = True,
                 image_mode = 'RGB',
                 type = 'pil',
                 optional = True,
-                source = 'upload'
+                source = 'upload',
+                value = 'example_input.png'
+            )
+            inference_steps_input = gr.Slider(
+                label = 'Steps',
+                minimum = 2,
+                maximum = 100,
+                value = 20,
+                step = 1,
+                interactive = True
             )
             num_frames_input = gr.Slider(
                 label = 'Number of frames to generate',
                 minimum = 1,
                 maximum = 24,
                 step = 1,
-                value = 24
+                value = 24,
+                interactive = True
             )
             width_input = gr.Slider(
                 label = 'Width',
                 minimum = 64,
-                maximum = 512,
+                maximum = 576,
                 step = 64,
-                value = 448
+                value = 512,
+                interactive = True
             )
             height_input = gr.Slider(
                 label = 'Height',
                 minimum = 64,
-                maximum = 512,
+                maximum = 576,
                 step = 64,
-                value = 448
+                value = 512,
+                interactive = True
             )
-            fps_input = gr.Slider(
-                label = 'Output FPS',
-                minimum = 1,
-                maximum = 1000,
-                step = 1,
-                value = 12
+            scheduler_input = gr.Dropdown(
+                label = 'Noise scheduler',
+                choices = list(SCHEDULERS.keys()),
+                value = 'DPM',
+                interactive = True
             )
-        with gr.Column(variant = variant):
-            #no_gpu = gr.Markdown('**Until a GPU is assigned expect extremely long runtimes up to 1h+**')
+            with gr.Row():
+                fps_input = gr.Slider(
+                    label = 'Output FPS',
+                    minimum = 1,
+                    maximum = 1000,
+                    step = 1,
+                    value = 12,
+                    interactive = True
+                )
+                output_format = gr.Dropdown(
+                    label = 'Output format',
+                    choices = _output_formats,
+                    value = 'gif',
+                    interactive = True
+                )
+        with gr.Column():
             #will_trigger = gr.Markdown('')
             patience = gr.Markdown('**Please be patient. The model might have to compile with current parameters.**')
             image_output = gr.Image(
@@ -236,33 +284,39 @@ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled =
                 value = 'example.webp',
                 interactive = False
             )
-            #trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
-            #trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
+            #trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input, scheduler_input ]
+            #trigger_check_fun = partial(check_if_compiled, message = 'Current parameters need compilation.')
             #height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
             #width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
             #num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
             #image_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
             #inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
-            #will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)
-            ev = submit_button.click(
-                fn = generate,
-                inputs = [
-                    prompt_input,
-                    neg_prompt_input,
-                    image_input,
-                    inference_steps_input,
-                    cfg_input,
-                    seed_input,
-                    fps_input,
-                    num_frames_input,
-                    height_input,
-                    width_input
-                ],
-                outputs = image_output,
-                postprocess = False
+            #scheduler_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
+            submit_button.click(
+                fn = generate,
+                inputs = [
+                    prompt_input,
+                    neg_prompt_input,
+                    image_input,
+                    inference_steps_input,
+                    cfg_input,
+                    cfg_image_input,
+                    seed_input,
+                    fps_input,
+                    num_frames_input,
+                    height_input,
+                    width_input,
+                    scheduler_input,
+                    output_format
+                ],
+                outputs = image_output,
+                postprocess = False
             )
             #cancel_button.click(fn = lambda: None, cancels = ev)
 
-demo.queue(concurrency_count = 1, max_size = 32)
+demo.queue(concurrency_count = 1, max_size = 12)
 demo.launch()
 
+# Photorealistic fantasy oil painting of the angry minotaur in a threatening pose by Randy Vargas.
+# A girl is dancing by a beautiful lake by sophie anderson and greg rutkowski and alphonse mucha.
+# They are dancing in the club but everybody is a 3d cg hairy monster wearing a hairy costume.
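
The intro text in the diff above describes the temporal layers as pseudo-3D: a 2D convolution applied over space followed by a 1D convolution applied over time, with attention split the same way. Below is a minimal Flax sketch of that factorization, for illustration only; the module name and tensor layout are assumptions, not the layer from `makeavid_sd.flax_impl`.

```python
# Illustrative pseudo-3D convolution: 2D over (height, width), then 1D over frames.
# This is a sketch of the idea described above, not the project's implementation.
import flax.linen as nn
import einops

class PseudoConv3D(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):  # x: (batch, frames, height, width, channels), an assumed layout
        b, f, h, w, c = x.shape
        # spatial 2D convolution, applied to every frame independently
        x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
        x = nn.Conv(self.features, kernel_size = (3, 3), padding = 'SAME')(x)
        # temporal 1D convolution, applied to every spatial position independently
        x = einops.rearrange(x, '(b f) h w c -> (b h w) f c', b = b, f = f)
        x = nn.Conv(self.features, kernel_size = (3,), padding = 'SAME')(x)
        return einops.rearrange(x, '(b h w) f c -> b f h w c', b = b, h = h, w = w)
```

Factorizing the kernel this way keeps the pretrained spatial weights intact and leaves only the temporal part to train, which matches the note that only the new temporal layers were fine tuned.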
example.webp CHANGED (Git LFS)

Before:
  • SHA256: e04074345eb8c6157398eef5db65167ebaa29356c16a087555d4058cbe2cad6a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB

After:
  • SHA256: ffd7cb93989a8e311395799f6d6e566e698ad7654f9f5a471196d8c781f46c1f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
example_input.png ADDED
makeavid_sd/inference.py CHANGED
@@ -1,5 +1,5 @@
 
-from typing import Any, Union, Tuple, List, Dict
+from typing import Any, Union, Optional, Tuple, List, Dict
 import os
 import gc
 from functools import partial
@@ -17,13 +17,14 @@ import einops
 from diffusers import FlaxAutoencoderKL, FlaxUNet2DConditionModel
 from diffusers import (
     FlaxDDIMScheduler,
-    FlaxDDPMScheduler,
     FlaxPNDMScheduler,
     FlaxLMSDiscreteScheduler,
     FlaxDPMSolverMultistepScheduler,
-    FlaxKarrasVeScheduler,
-    FlaxScoreSdeVeScheduler
 )
+from diffusers.schedulers.scheduling_ddim_flax import DDIMSchedulerState
+from diffusers.schedulers.scheduling_pndm_flax import PNDMSchedulerState
+from diffusers.schedulers.scheduling_lms_discrete_flax import LMSDiscreteSchedulerState
+from diffusers.schedulers.scheduling_dpmsolver_multistep_flax import DPMSolverMultistepSchedulerState
 
 from transformers import FlaxCLIPTextModel, CLIPTokenizer
 
@@ -31,14 +32,31 @@ from .flax_impl.flax_unet_pseudo3d_condition import UNetPseudo3DConditionModel
 
 SchedulerType = Union[
     FlaxDDIMScheduler,
-    FlaxDDPMScheduler,
     FlaxPNDMScheduler,
     FlaxLMSDiscreteScheduler,
     FlaxDPMSolverMultistepScheduler,
-    FlaxKarrasVeScheduler,
-    FlaxScoreSdeVeScheduler
 ]
 
+SchedulerStateType = Union[
+    DDIMSchedulerState,
+    PNDMSchedulerState,
+    LMSDiscreteSchedulerState,
+    DPMSolverMultistepSchedulerState,
+]
+
+SCHEDULERS: Dict[str, SchedulerType] = {
+    'DPM': FlaxDPMSolverMultistepScheduler, # husbando
+    'DDIM': FlaxDDIMScheduler,
+    #'PLMS': FlaxPNDMScheduler, # its not correctly implemented in diffusers, output is bad, but at least it "works"
+    #'LMS': FlaxLMSDiscreteScheduler, # borked
+    # image_latents, image_scheduler_state = scheduler.step(
+    #   File "/mnt/work1/make_a_vid/makeavid-space/.venv/lib/python3.10/site-packages/diffusers/schedulers/scheduling_lms_discrete_flax.py", line 255, in step
+    #     order = min(timestep + 1, order)
+    #   jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=1/1)>
+    #   The problem arose with the `bool` function.
+    #   The error occurred while tracing the function scanned_fun at /mnt/work1/make_a_vid/makeavid-space/.venv/lib/python3.10/site-packages/jax/_src/lax/control_flow/loops.py:1668 for scan. This concrete value was not available in Python because it depends on the values of the arguments loop_carry[0] and loop_carry[1][1].timesteps
+}
+
 def dtypestr(x: jnp.dtype):
     if x == jnp.float32: return 'float32'
     elif x == jnp.float16: return 'float16'
@@ -53,7 +71,6 @@ def castto(dtype, m, x):
 class InferenceUNetPseudo3D:
     def __init__(self,
             model_path: str,
-            scheduler_cls: SchedulerType = FlaxDDIMScheduler,
             dtype: jnp.dtype = jnp.float16,
             hf_auth_token: Union[str, None] = None
     ) -> None:
@@ -129,28 +146,27 @@ class InferenceUNetPseudo3D:
             subfolder = 'tokenizer',
             use_auth_token = self.hf_auth_token
         )
-        scheduler, scheduler_state = scheduler_cls.from_pretrained(
-            self.model_path,
-            subfolder = 'scheduler',
-            dtype = jnp.float32,
-            use_auth_token = self.hf_auth_token
-        )
-        self.scheduler: scheduler_cls = scheduler
-        self.params['scheduler'] = scheduler_state
+        self.schedulers: Dict[str, Dict[str, SchedulerType]] = {}
+        for scheduler_name in SCHEDULERS:
+            if scheduler_name not in ['KarrasVe', 'SDEVe']:
+                scheduler, scheduler_state = SCHEDULERS[scheduler_name].from_pretrained(
+                    self.model_path,
+                    subfolder = 'scheduler',
+                    dtype = jnp.float32,
+                    use_auth_token = self.hf_auth_token
+                )
+            else:
+                scheduler, scheduler_state = SCHEDULERS[scheduler_name].from_pretrained(
+                    self.model_path,
+                    subfolder = 'scheduler',
+                    use_auth_token = self.hf_auth_token
+                )
+            self.schedulers[scheduler_name] = scheduler
+            self.params[scheduler_name] = scheduler_state
         self.vae_scale_factor: int = int(2 ** (len(self.vae.config.block_out_channels) - 1))
         self.device_count = jax.device_count()
         gc.collect()
 
-    def set_scheduler(self, scheduler_cls: SchedulerType) -> None:
-        scheduler, scheduler_state = scheduler_cls.from_pretrained(
-            self.model_path,
-            subfolder = 'scheduler',
-            dtype = jnp.float32,
-            use_auth_token = self.hf_auth_token
-        )
-        self.scheduler: scheduler_cls = scheduler
-        self.params['scheduler'] = scheduler_state
-
     def prepare_inputs(self,
             prompt: List[str],
             neg_prompt: List[str],
@@ -213,11 +229,13 @@ class InferenceUNetPseudo3D:
             hint_image: Union[Image.Image, List[Image.Image], None] = None,
             mask_image: Union[Image.Image, List[Image.Image], None] = None,
             neg_prompt: Union[str, List[str]] = '',
-            cfg: float = 10.0,
+            cfg: float = 15.0,
+            cfg_image: Optional[float] = None,
             num_frames: int = 24,
             width: int = 512,
             height: int = 512,
-            seed: int = 0
+            seed: int = 0,
+            scheduler_type: str = 'DDIM'
     ) -> List[List[Image.Image]]:
         assert inference_steps > 0, f'number of inference steps must be > 0 but is {inference_steps}'
         assert num_frames > 0, f'number of frames must be > 0 but is {num_frames}'
@@ -243,6 +261,7 @@ class InferenceUNetPseudo3D:
         if isinstance(neg_prompt, str):
             neg_prompt = [ neg_prompt ] * batch_size
         assert len(neg_prompt) == batch_size, f'number of negative prompts must be equal to batch size {batch_size} but is {len(neg_prompt)}'
+        assert scheduler_type in SCHEDULERS, f'unknown type of noise scheduler: {scheduler_type}, must be one of {list(SCHEDULERS.keys())}'
         tokens, neg_tokens, hint, mask = self.prepare_inputs(
             prompt = prompt,
             neg_prompt = neg_prompt,
@@ -251,11 +270,14 @@ class InferenceUNetPseudo3D:
             width = width,
             height = height
         )
+        if cfg_image is None:
+            cfg_image = cfg
+        #params['scheduler'] = scheduler_state
         # NOTE splitting rngs is not deterministic,
         # running on different device counts gives different seeds
         #rng = jax.random.PRNGKey(seed)
         #rngs = jax.random.split(rng, self.device_count)
-        # manually assign seeded RNGs to devices for reproducability
+        # manually assign seeded RNGs to devices for reproducability
         rngs = jnp.array([ jax.random.PRNGKey(seed + i) for i in range(self.device_count) ])
         params = jax_utils.replicate(self.params)
         tokens = shard(tokens)
@@ -272,9 +294,11 @@ class InferenceUNetPseudo3D:
             height,
             width,
             cfg,
+            cfg_image,
             rngs,
             params,
-            use_imagegen
+            use_imagegen,
+            scheduler_type,
         )
         if images.ndim == 5:
             images = einops.rearrange(images, 'd f c h w -> (d f) h w c')
@@ -295,9 +319,11 @@ class InferenceUNetPseudo3D:
             height,
             width,
             cfg: float,
+            cfg_image: float,
             rng: jax.random.KeyArray,
             params: Union[Dict[str, Any], FrozenDict[str, Any]],
-            use_imagegen: bool
+            use_imagegen: bool,
+            scheduler_type: str
     ) -> List[Image.Image]:
         batch_size = tokens.shape[0]
         latent_h = height // self.vae_scale_factor
@@ -312,15 +338,18 @@ class InferenceUNetPseudo3D:
         encoded_prompt = self.text_encoder(tokens, params = params['text_encoder'])[0]
         encoded_neg_prompt = self.text_encoder(neg_tokens, params = params['text_encoder'])[0]
 
+        scheduler = self.schedulers[scheduler_type]
+        scheduler_state = params[scheduler_type]
+
         if use_imagegen:
             image_latent_shape = (batch_size, self.vae.config.latent_channels, latent_h, latent_w)
             image_latents = jax.random.normal(
                 rng,
                 shape = image_latent_shape,
                 dtype = jnp.float32
-            ) * params['scheduler'].init_noise_sigma
-            image_scheduler_state = self.scheduler.set_timesteps(
-                params['scheduler'],
+            ) * scheduler_state.init_noise_sigma
+            image_scheduler_state = scheduler.set_timesteps(
+                scheduler_state,
                 num_inference_steps = inference_steps,
                 shape = image_latents.shape
             )
@@ -328,21 +357,21 @@ class InferenceUNetPseudo3D:
                 image_latents, image_scheduler_state = args
                 t = image_scheduler_state.timesteps[step]
                 tt = jnp.broadcast_to(t, image_latents.shape[0])
-                latents_input = self.scheduler.scale_model_input(image_scheduler_state, image_latents, t)
+                latents_input = scheduler.scale_model_input(image_scheduler_state, image_latents, t)
                 noise_pred = self.imunet.apply(
-                    {'params': params['imunet']},
+                    { 'params': params['imunet']} ,
                     latents_input,
                     tt,
                     encoder_hidden_states = encoded_prompt
                 ).sample
                 noise_pred_uncond = self.imunet.apply(
-                    {'params': params['imunet']},
+                    { 'params': params['imunet'] },
                    latents_input,
                     tt,
                     encoder_hidden_states = encoded_neg_prompt
                 ).sample
                 noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
-                image_latents, image_scheduler_state = self.scheduler.step(
+                image_latents, image_scheduler_state = scheduler.step(
                     image_scheduler_state,
                     noise_pred.astype(jnp.float32),
                     t,
@@ -357,7 +386,7 @@ class InferenceUNetPseudo3D:
            hint = image_latents
         else:
             hint = self.vae.apply(
-                {'params': params['vae']},
+                { 'params': params['vae'] },
                 hint,
                 method = self.vae.encode
             ).latent_dist.mean * self.vae.config.scaling_factor
@@ -375,9 +404,9 @@ class InferenceUNetPseudo3D:
             rng,
             shape = latent_shape,
             dtype = jnp.float32
-        ) * params['scheduler'].init_noise_sigma
-        scheduler_state = self.scheduler.set_timesteps(
-            params['scheduler'],
+        ) * scheduler_state.init_noise_sigma
+        scheduler_state = scheduler.set_timesteps(
+            scheduler_state,
             num_inference_steps = inference_steps,
             shape = latents.shape
         )
@@ -386,7 +415,7 @@ class InferenceUNetPseudo3D:
             latents, scheduler_state = args
             t = scheduler_state.timesteps[step]#jnp.array(scheduler_state.timesteps, dtype = jnp.int32)[step]
             tt = jnp.broadcast_to(t, latents.shape[0])
-            latents_input = self.scheduler.scale_model_input(scheduler_state, latents, t)
+            latents_input = scheduler.scale_model_input(scheduler_state, latents, t)
             latents_input = jnp.concatenate([latents_input, mask, hint], axis = 1)
             noise_pred = self.unet.apply(
                 { 'params': params['unet'] },
@@ -401,7 +430,7 @@ class InferenceUNetPseudo3D:
                 encoded_neg_prompt
             ).sample
             noise_pred = noise_pred_uncond + cfg * (noise_pred - noise_pred_uncond)
-            latents, scheduler_state = self.scheduler.step(
+            latents, scheduler_state = scheduler.step(
                 scheduler_state,
                 noise_pred.astype(jnp.float32),
                 t,
@@ -453,9 +482,11 @@ class InferenceUNetPseudo3D:
         None, # 7 height
         None, # 8 width
         None, # 9 cfg
-        0, # 10 rng
-        0, # 11 params
-        None, # 12 use_imagegen
+        None, # 10 cfg_image
+        0, # 11 rng
+        0, # 12 params
+        None, # 13 use_imagegen
+        None, # 14 scheduler_type
     ),
     static_broadcasted_argnums = ( # trigger recompilation on change
         0, # inference_class
@@ -463,7 +494,8 @@ class InferenceUNetPseudo3D:
         6, # num_frames
         7, # height
        8, # width
-        12, # use_imagegen
+        13, # use_imagegen
+        14, # scheduler_type
     )
 )
 def _p_generate(
@@ -472,14 +504,16 @@ def _p_generate(
         neg_tokens,
         hint,
         mask,
-        inference_steps,
-        num_frames,
-        height,
-        width,
-        cfg,
+        inference_steps: int,
+        num_frames: int,
+        height: int,
+        width: int,
+        cfg: float,
+        cfg_image: float,
         rng,
         params,
-        use_imagegen
+        use_imagegen: bool,
+        scheduler_type: str
 ):
     return inference_class._generate(
         tokens,
@@ -491,8 +525,10 @@ def _p_generate(
         height,
         width,
         cfg,
+        cfg_image,
        rng,
         params,
-        use_imagegen
+        use_imagegen,
+        scheduler_type
     )
 
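
The pmap setup at the end of inference.py lists `use_imagegen` and the new `scheduler_type` argument under `static_broadcasted_argnums`, which is why the UI warns that changing the noise scheduler forces a recompile: a Python string cannot be traced, so it has to be a compile-time constant, and every new value triggers a fresh compilation. A small self-contained illustration of that behaviour (hypothetical names and values, not project code):

```python
# Toy example: a string argument selecting behaviour must be static under pmap,
# and changing a static argument triggers a new compilation.
from functools import partial
import jax
import jax.numpy as jnp

SCALES = {'DPM': 2.0, 'DDIM': 0.5}  # hypothetical lookup keyed by scheduler name

@partial(jax.pmap, static_broadcasted_argnums = (1,))
def scale(x, scheduler_type):
    # scheduler_type is a plain Python string here, usable for dict lookups and branching
    return x * SCALES[scheduler_type]

x = jnp.ones((jax.device_count(), 4))
y = scale(x, 'DPM')   # first call compiles for 'DPM'
z = scale(x, 'DDIM')  # recompiles, because a static argument changed
```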