lopho commited on
Commit
8c4daf1
1 Parent(s): 84726ba

gradio app

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app.py +259 -0
  3. packages.txt +0 -0
  4. pre-requirements.txt +5 -0
  5. requirements.txt +11 -0
.gitattributes CHANGED
@@ -1,3 +1,4 @@
 
1
  *.7z filter=lfs diff=lfs merge=lfs -text
2
  *.arrow filter=lfs diff=lfs merge=lfs -text
3
  *.bin filter=lfs diff=lfs merge=lfs -text
 
1
+ *.webp filter=lfs diff=lfs merge=lfs -text
2
  *.7z filter=lfs diff=lfs merge=lfs -text
3
  *.arrow filter=lfs diff=lfs merge=lfs -text
4
  *.bin filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from io import BytesIO
4
+ import base64
5
+ from functools import partial
6
+
7
+ from PIL import Image, ImageOps
8
+ import gradio as gr
9
+
10
+ from makeavid_sd.inference import InferenceUNetPseudo3D, FlaxDPMSolverMultistepScheduler, jnp
11
+
12
+
13
+ _preheat: bool = False
14
+
15
+ _seen_compilations = set()
16
+
17
+ _model = InferenceUNetPseudo3D(
18
+ model_path = 'TempoFunk/makeavid-sd-jax',
19
+ scheduler_cls = FlaxDPMSolverMultistepScheduler,
20
+ dtype = jnp.float16,
21
+ hf_auth_token = os.environ.get('HUGGING_FACE_HUB_TOKEN', None)
22
+ )
23
+
24
+ # gradio is illiterate. type hints make it go poopoo in pantsu.
25
+ def generate(
26
+ prompt = 'An elderly man having a great time in the park.',
27
+ neg_prompt = '',
28
+ image = { 'image': None, 'mask': None },
29
+ inference_steps = 20,
30
+ cfg = 12.0,
31
+ seed = 0,
32
+ fps = 24,
33
+ num_frames = 24,
34
+ height = 512,
35
+ width = 512
36
+ ) -> str:
37
+ height = int(height)
38
+ width = int(width)
39
+ num_frames = int(num_frames)
40
+ seed = int(seed)
41
+ if seed < 0:
42
+ seed = -seed
43
+ inference_steps = int(inference_steps)
44
+ if image is not None:
45
+ hint_image = image['image']
46
+ mask_image = image['mask']
47
+ else:
48
+ hint_image = None
49
+ mask_image = None
50
+ if hint_image is not None:
51
+ if hint_image.mode != 'RGB':
52
+ hint_image = hint_image.convert('RGB')
53
+ if hint_image.size != (width, height):
54
+ hint_image = ImageOps.fit(hint_image, (width, height), method = Image.Resampling.LANCZOS)
55
+ if mask_image is not None:
56
+ if mask_image.mode != 'L':
57
+ mask_image = mask_image.convert('L')
58
+ if mask_image.size != (width, height):
59
+ mask_image = ImageOps.fit(mask_image, (width, height), method = Image.Resampling.LANCZOS)
60
+ images = _model.generate(
61
+ prompt = [prompt] * _model.device_count,
62
+ neg_prompt = neg_prompt,
63
+ hint_image = hint_image,
64
+ mask_image = mask_image,
65
+ inference_steps = inference_steps,
66
+ cfg = cfg,
67
+ height = height,
68
+ width = width,
69
+ num_frames = num_frames,
70
+ seed = seed
71
+ )
72
+ _seen_compilations.add((hint_image is None, inference_steps, height, width, num_frames))
73
+ buffer = BytesIO()
74
+ images[0].save(
75
+ buffer,
76
+ format = 'webp',
77
+ save_all = True,
78
+ append_images = images[1:],
79
+ loop = 0,
80
+ duration = round(1000 / fps),
81
+ allow_mixed = True
82
+ )
83
+ data = base64.b64encode(buffer.getvalue()).decode()
84
+ data = 'data:image/webp;base64,' + data
85
+ buffer.close()
86
+ return data
87
+
88
+ def check_if_compiled(image, inference_steps, height, width, num_frames, message):
89
+ height = int(height)
90
+ width = int(width)
91
+ hint_image = None if image is None else image['image']
92
+ if (hint_image is None, inference_steps, height, width, num_frames) in _seen_compilations:
93
+ return ''
94
+ else:
95
+ return f"""{message}"""
96
+
97
+ if _preheat:
98
+ print('\npreheating the oven')
99
+ generate(
100
+ prompt = 'preheating the oven',
101
+ neg_prompt = '',
102
+ image = { 'image': None, 'mask': None },
103
+ inference_steps = 20,
104
+ cfg = 12.0,
105
+ seed = 0
106
+ )
107
+ print('Entertaining the guests with sailor songs played on an old piano.')
108
+ dada = generate(
109
+ prompt = 'Entertaining the guests with sailor songs played on an old harmonium.',
110
+ neg_prompt = '',
111
+ image = { 'image': Image.new('RGB', size = (512, 512), color = (0, 0, 0)), 'mask': None },
112
+ inference_steps = 20,
113
+ cfg = 12.0,
114
+ seed = 0
115
+ )
116
+ print('dinner is ready\n')
117
+
118
+ with gr.Blocks(title = 'Make-A-Video Stable Diffusion JAX', analytics_enabled = False) as demo:
119
+ variant = 'panel'
120
+ with gr.Row():
121
+ with gr.Column():
122
+ intro1 = gr.Markdown("""
123
+ # Make-A-Video Stable Diffusion JAX
124
+ **Please be patient. The model might have to compile with current parameters.**
125
+
126
+ This can take up to 5 minutes on the first run, and 2-3 minutes on later runs.
127
+ The compilation will be cached and consecutive runs with the same parameters
128
+ will be much faster.
129
+ """)
130
+ with gr.Column():
131
+ intro2 = gr.Markdown("""
132
+ The following parameters require the model to compile
133
+ - Number of frames
134
+ - Width & Height
135
+ - Steps
136
+ - Input image vs. no input image
137
+ """)
138
+
139
+ with gr.Row(variant = variant):
140
+ with gr.Column(variant = variant):
141
+ with gr.Row():
142
+ cancel_button = gr.Button(value = 'Cancel')
143
+ submit_button = gr.Button(value = 'Make A Video', variant = 'primary')
144
+ prompt_input = gr.Textbox(
145
+ label = 'Prompt',
146
+ value = 'They are dancing in the club while sweat drips from the ceiling.',
147
+ interactive = True
148
+ )
149
+ neg_prompt_input = gr.Textbox(
150
+ label = 'Negative prompt (optional)',
151
+ value = '',
152
+ interactive = True
153
+ )
154
+ inference_steps_input = gr.Slider(
155
+ label = 'Steps',
156
+ minimum = 1,
157
+ maximum = 100,
158
+ value = 20,
159
+ step = 1
160
+ )
161
+ cfg_input = gr.Slider(
162
+ label = 'Guidance scale',
163
+ minimum = 1.0,
164
+ maximum = 20.0,
165
+ step = 0.1,
166
+ value = 15.0,
167
+ interactive = True
168
+ )
169
+ seed_input = gr.Number(
170
+ label = 'Random seed',
171
+ value = 0,
172
+ interactive = True,
173
+ precision = 0
174
+ )
175
+ image_input = gr.Image(
176
+ label = 'Input image (optional)',
177
+ interactive = True,
178
+ image_mode = 'RGB',
179
+ type = 'pil',
180
+ optional = True,
181
+ source = 'upload',
182
+ tool = 'sketch'
183
+ )
184
+ num_frames_input = gr.Slider(
185
+ label = 'Number of frames to generate',
186
+ minimum = 1,
187
+ maximum = 24,
188
+ step = 1,
189
+ value = 24
190
+ )
191
+ width_input = gr.Slider(
192
+ label = 'Width',
193
+ minimum = 64,
194
+ maximum = 512,
195
+ step = 1,
196
+ value = 448
197
+ )
198
+ height_input = gr.Slider(
199
+ label = 'Height',
200
+ minimum = 64,
201
+ maximum = 512,
202
+ step = 1,
203
+ value = 448
204
+ )
205
+ fps_input = gr.Slider(
206
+ label = 'Output FPS',
207
+ minimum = 1,
208
+ maximum = 1000,
209
+ step = 1,
210
+ value = 12
211
+ )
212
+ with gr.Column(variant = variant):
213
+ will_trigger = gr.Markdown('')
214
+ patience = gr.Markdown('')
215
+ image_output = gr.Image(
216
+ label = 'Output',
217
+ value = 'example.webp',
218
+ interactive = False
219
+ )
220
+ trigger_inputs = [ image_input, inference_steps_input, height_input, width_input, num_frames_input ]
221
+ trigger_check_fun = partial(check_if_compiled, message = 'Current parameters will trigger compilation.')
222
+ height_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
223
+ width_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
224
+ num_frames_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
225
+ inference_steps_input.change(fn = trigger_check_fun, inputs = trigger_inputs, outputs = will_trigger)
226
+ will_trigger.value = trigger_check_fun(image_input.value, inference_steps_input.value, height_input.value, width_input.value, num_frames_input.value)
227
+ ev = submit_button.click(
228
+ fn = partial(
229
+ check_if_compiled,
230
+ message = 'Please be patient. The model has to be compiled with current parameters.'
231
+ ),
232
+ inputs = trigger_inputs,
233
+ outputs = patience
234
+ ).then(
235
+ fn = generate,
236
+ inputs = [
237
+ prompt_input,
238
+ neg_prompt_input,
239
+ image_input,
240
+ inference_steps_input,
241
+ cfg_input,
242
+ seed_input,
243
+ fps_input,
244
+ num_frames_input,
245
+ height_input,
246
+ width_input
247
+ ],
248
+ outputs = image_output,
249
+ postprocess = False
250
+ ).then(
251
+ fn = trigger_check_fun,
252
+ inputs = trigger_inputs,
253
+ outputs = will_trigger
254
+ )
255
+ cancel_button(cancels = ev)
256
+
257
+ demo.queue(concurrency_count = 1, max_size = 16)
258
+ demo.launch()
259
+
packages.txt ADDED
File without changes
pre-requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pip
2
+ setuptools
3
+ wheel
4
+ ninja
5
+ cmake
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ pillow
3
+ transformers
4
+ diffusers
5
+ einops
6
+ git+https://github.com/lopho/makeavid-sd-tpu.git
7
+ -f https://download.pytorch.org/whl/cpu/torch
8
+ torch[cpu]
9
+ -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
10
+ jax[cuda11_cudnn805] #jax[cuda11_cudnn86] #jax[cuda11_cudnn805]
11
+ flax