pcuenq (HF staff) committed
Commit 4ebb8b5
1 Parent(s): 7865d26

Use PIL instead of path for image component.


According to my tests, this creates two temporary files in /tmp instead of three.
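
For reference, the whole change hinges on the `type` argument of Gradio's `gr.Image` component. A minimal standalone sketch of the two modes (not part of app.py):

    import gradio as gr

    # type="filepath": Gradio writes the upload to a temporary file and passes its path.
    # type="pil": Gradio passes a PIL.Image.Image directly, skipping one temp file.
    image_as_path = gr.Image(label="Upload your image", type="filepath")
    image_as_pil = gr.Image(label="Upload your image", type="pil")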

Files changed (1)
  1. app.py +111 -147
app.py CHANGED

@@ -69,7 +69,7 @@ model, filter = load_model(
 )
 
 def sample(
-    input_path: str = "assets/test_image.png",  # Can either be image file or folder with image files
+    image: Image,
     seed: Optional[int] = None,
     randomize_seed: bool = True,
     motion_bucket_id: int = 127,
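
Note that the new annotation uses `Image`, which is the imported PIL module; the concrete class is `Image.Image`, though annotations are not enforced at runtime. A hypothetical direct call under the new signature (the file name and arguments are illustrative):

    from PIL import Image

    # The caller now opens the file itself and hands over a PIL image,
    # instead of passing a path string for sample() to open.
    img = Image.open("assets/test_image.png")
    video_path, seed = sample(img, seed=42, randomize_seed=False)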
@@ -81,152 +81,118 @@ def sample(
     output_folder: str = "outputs",
     progress=gr.Progress(track_tqdm=True)
 ):
-    """
-    Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
-    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
-    """
-    try:
-        if input_path is None:
-            raise ValueError("No image")
-
-        if(randomize_seed):
-            seed = random.randint(0, max_64_bit_int)
-
-        torch.manual_seed(seed)
+    if(randomize_seed):
+        seed = random.randint(0, max_64_bit_int)
 
-        path = Path(input_path)
-        all_img_paths = []
-        if path.is_file():
-            if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
-                all_img_paths = [input_path]
-            else:
-                raise ValueError("Unsupported image type.")
-        elif path.is_dir():
-            all_img_paths = sorted(
-                [
-                    f
-                    for f in path.iterdir()
-                    if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
-                ]
+    torch.manual_seed(seed)
+
+    if image.mode == "RGBA":
+        image = image.convert("RGB")
+    w, h = image.size
+
+    if h % 64 != 0 or w % 64 != 0:
+        width, height = map(lambda x: x - x % 64, (w, h))
+        image = image.resize((width, height))
+        print(
+            f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
+        )
+
+    image = ToTensor()(image)
+    image = image * 2.0 - 1.0
+    image = image.unsqueeze(0).to(device)
+    H, W = image.shape[2:]
+    assert image.shape[1] == 3
+    F = 8
+    C = 4
+    shape = (num_frames, C, H // F, W // F)
+    if (H, W) != (576, 1024):
+        print(
+            "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
+        )
+    if motion_bucket_id > 255:
+        print(
+            "WARNING: High motion bucket! This may lead to suboptimal performance."
+        )
+
+    if fps_id < 5:
+        print("WARNING: Small fps value! This may lead to suboptimal performance.")
+
+    if fps_id > 30:
+        print("WARNING: Large fps value! This may lead to suboptimal performance.")
+
+    value_dict = {}
+    value_dict["motion_bucket_id"] = motion_bucket_id
+    value_dict["fps_id"] = fps_id
+    value_dict["cond_aug"] = cond_aug
+    value_dict["cond_frames_without_noise"] = image
+    value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
+    value_dict["cond_aug"] = cond_aug
+
+    with torch.no_grad():
+        with torch.autocast(device):
+            batch, batch_uc = get_batch(
+                get_unique_embedder_keys_from_conditioner(model.conditioner),
+                value_dict,
+                [1, num_frames],
+                T=num_frames,
+                device=device,
             )
-            if len(all_img_paths) == 0:
-                raise ValueError("Folder does not contain any images.")
-        else:
-            raise ValueError("No image")
-
-        for input_img_path in all_img_paths:
-            with Image.open(input_img_path) as image:
-                if image.mode == "RGBA":
-                    image = image.convert("RGB")
-                w, h = image.size
-
-                if h % 64 != 0 or w % 64 != 0:
-                    width, height = map(lambda x: x - x % 64, (w, h))
-                    image = image.resize((width, height))
-                    print(
-                        f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
-                    )
-
-                image = ToTensor()(image)
-                image = image * 2.0 - 1.0
-
-            image = image.unsqueeze(0).to(device)
-            H, W = image.shape[2:]
-            assert image.shape[1] == 3
-            F = 8
-            C = 4
-            shape = (num_frames, C, H // F, W // F)
-            if (H, W) != (576, 1024):
-                print(
-                    "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
-                )
-            if motion_bucket_id > 255:
-                print(
-                    "WARNING: High motion bucket! This may lead to suboptimal performance."
+            c, uc = model.conditioner.get_unconditional_conditioning(
+                batch,
+                batch_uc=batch_uc,
+                force_uc_zero_embeddings=[
+                    "cond_frames",
+                    "cond_frames_without_noise",
+                ],
+            )
+
+            for k in ["crossattn", "concat"]:
+                uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
+                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
+                c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
+                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
+
+            randn = torch.randn(shape, device=device)
+
+            additional_model_inputs = {}
+            additional_model_inputs["image_only_indicator"] = torch.zeros(
+                2, num_frames
+            ).to(device)
+            additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
+
+            def denoiser(input, sigma, c):
+                return model.denoiser(
+                    model.model, input, sigma, c, **additional_model_inputs
                 )
 
-            if fps_id < 5:
-                print("WARNING: Small fps value! This may lead to suboptimal performance.")
-
-            if fps_id > 30:
-                print("WARNING: Large fps value! This may lead to suboptimal performance.")
-
-            value_dict = {}
-            value_dict["motion_bucket_id"] = motion_bucket_id
-            value_dict["fps_id"] = fps_id
-            value_dict["cond_aug"] = cond_aug
-            value_dict["cond_frames_without_noise"] = image
-            value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
-            value_dict["cond_aug"] = cond_aug
-
-            with torch.no_grad():
-                with torch.autocast(device):
-                    batch, batch_uc = get_batch(
-                        get_unique_embedder_keys_from_conditioner(model.conditioner),
-                        value_dict,
-                        [1, num_frames],
-                        T=num_frames,
-                        device=device,
-                    )
-                    c, uc = model.conditioner.get_unconditional_conditioning(
-                        batch,
-                        batch_uc=batch_uc,
-                        force_uc_zero_embeddings=[
-                            "cond_frames",
-                            "cond_frames_without_noise",
-                        ],
-                    )
-
-                    for k in ["crossattn", "concat"]:
-                        uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
-                        uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
-                        c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
-                        c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
-
-                    randn = torch.randn(shape, device=device)
-
-                    additional_model_inputs = {}
-                    additional_model_inputs["image_only_indicator"] = torch.zeros(
-                        2, num_frames
-                    ).to(device)
-                    additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
-
-                    def denoiser(input, sigma, c):
-                        return model.denoiser(
-                            model.model, input, sigma, c, **additional_model_inputs
-                        )
-
-                    samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
-                    model.en_and_decode_n_samples_a_time = decoding_t
-                    samples_x = model.decode_first_stage(samples_z)
-                    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
-
-                    os.makedirs(output_folder, exist_ok=True)
-                    base_count = len(glob(os.path.join(output_folder, "*.mp4")))
-                    video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
-                    writer = cv2.VideoWriter(
-                        video_path,
-                        cv2.VideoWriter_fourcc(*"mp4v"),
-                        fps_id + 1,
-                        (samples.shape[-1], samples.shape[-2]),
-                    )
-
-                    samples = embed_watermark(samples)
-                    samples = filter(samples)
-                    vid = (
-                        (rearrange(samples, "t c h w -> t h w c") * 255)
-                        .cpu()
-                        .numpy()
-                        .astype(np.uint8)
-                    )
-                    for frame in vid:
-                        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-                        writer.write(frame)
-                    writer.release()
-
-            return video_path, seed
-    except Exception as e:
-        raise gr.Error(e.args[0] if len(e.args) > 0 else "Sampling error")
+            samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
+            model.en_and_decode_n_samples_a_time = decoding_t
+            samples_x = model.decode_first_stage(samples_z)
+            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+
+            os.makedirs(output_folder, exist_ok=True)
+            base_count = len(glob(os.path.join(output_folder, "*.mp4")))
+            video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
+            writer = cv2.VideoWriter(
+                video_path,
+                cv2.VideoWriter_fourcc(*"mp4v"),
+                fps_id + 1,
+                (samples.shape[-1], samples.shape[-2]),
+            )
+
+            samples = embed_watermark(samples)
+            samples = filter(samples)
+            vid = (
+                (rearrange(samples, "t c h w -> t h w c") * 255)
+                .cpu()
+                .numpy()
+                .astype(np.uint8)
+            )
+            for frame in vid:
+                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+                writer.write(frame)
+            writer.release()
+            return video_path, seed
 
 def get_unique_embedder_keys_from_conditioner(conditioner):
     return list(set([x.input_key for x in conditioner.embedders]))
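
The conditioning-frame preprocessing is unchanged apart from the dropped file handling; as a standalone sketch (the function name is illustrative, not part of app.py):

    import torch
    from PIL import Image
    from torchvision.transforms import ToTensor

    def preprocess_frame(image: Image.Image) -> torch.Tensor:
        # Drop the alpha channel if present; the model expects 3 channels.
        if image.mode == "RGBA":
            image = image.convert("RGB")
        # Round both dimensions down to the nearest multiple of 64,
        # matching the check in sample() above.
        w, h = image.size
        if h % 64 != 0 or w % 64 != 0:
            w, h = w - w % 64, h - h % 64
            image = image.resize((w, h))
        tensor = ToTensor()(image)   # (3, H, W), values in [0, 1]
        return tensor * 2.0 - 1.0    # rescale to [-1, 1] as the model expects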
@@ -272,8 +238,7 @@ def get_batch(keys, value_dict, N, T, device):
         batch_uc[key] = torch.clone(batch[key])
     return batch, batch_uc
 
-def resize_image(image_path, output_size=(1024, 576)):
-    image = Image.open(image_path)
+def resize_image(image, output_size=(1024, 576)):
     # Calculate aspect ratios
     target_aspect = output_size[0] / output_size[1]  # Aspect ratio of the desired size
     image_aspect = image.width / image.height  # Aspect ratio of the original image
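
The body of `resize_image` between these two hunks is elided from the diff; a generic resize-then-center-crop consistent with the visible header and footer might look like this (a sketch, not the committed code):

    from PIL import Image

    def resize_and_center_crop(image: Image.Image, output_size=(1024, 576)) -> Image.Image:
        target_aspect = output_size[0] / output_size[1]
        image_aspect = image.width / image.height
        if image_aspect > target_aspect:
            # Wider than the target: match heights, crop excess width.
            new_h = output_size[1]
            new_w = round(new_h * image_aspect)
        else:
            # Taller than the target: match widths, crop excess height.
            new_w = output_size[0]
            new_h = round(new_w / image_aspect)
        resized = image.resize((new_w, new_h))
        left = (new_w - output_size[0]) // 2
        top = (new_h - output_size[1]) // 2
        return resized.crop((left, top, left + output_size[0], top + output_size[1]))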
@@ -302,7 +267,6 @@ def resize_image(image_path, output_size=(1024, 576)):
 
     # Crop the image
     cropped_image = resized_image.crop((left, top, right, bottom))
-
     return cropped_image
 
 with gr.Blocks() as demo:
@@ -311,7 +275,7 @@ with gr.Blocks() as demo:
     ''')
     with gr.Row():
         with gr.Column():
-            image = gr.Image(label="Upload your image", type="filepath")
+            image = gr.Image(label="Upload your image", type="pil")
             generate_btn = gr.Button("Generate")
         video = gr.Video()
     with gr.Accordion("Advanced options", open=False):
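
With type="pil", the click handler receives the PIL image directly. The event wiring itself is outside this diff; a hypothetical hookup consistent with the components above (the extra input components are assumptions):

    # Hypothetical wiring; sample() returns (video_path, seed) as shown above.
    generate_btn.click(
        fn=sample,
        inputs=[image, seed, randomize_seed, motion_bucket_id],  # illustrative subset
        outputs=[video, seed],
    )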
 