pcuenq HF staff commited on
Commit
7865d26
1 Parent(s): b8850be

Improve error reporting

Browse files
Files changed (1) hide show
  1. app.py +131 -125
app.py CHANGED
@@ -85,142 +85,148 @@ def sample(
85
  Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
86
  image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
87
  """
88
- if(randomize_seed):
89
- seed = random.randint(0, max_64_bit_int)
 
 
 
 
 
 
90
 
91
- torch.manual_seed(seed)
92
-
93
- path = Path(input_path)
94
- all_img_paths = []
95
- if path.is_file():
96
- if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
97
- all_img_paths = [input_path]
 
 
 
 
 
 
 
 
 
 
98
  else:
99
- raise ValueError("Path is not valid image file.")
100
- elif path.is_dir():
101
- all_img_paths = sorted(
102
- [
103
- f
104
- for f in path.iterdir()
105
- if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
106
- ]
107
- )
108
- if len(all_img_paths) == 0:
109
- raise ValueError("Folder does not contain any images.")
110
- else:
111
- raise ValueError
 
112
 
113
- for input_img_path in all_img_paths:
114
- with Image.open(input_img_path) as image:
115
- if image.mode == "RGBA":
116
- image = image.convert("RGB")
117
- w, h = image.size
118
 
119
- if h % 64 != 0 or w % 64 != 0:
120
- width, height = map(lambda x: x - x % 64, (w, h))
121
- image = image.resize((width, height))
 
 
 
 
122
  print(
123
- f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
124
  )
125
-
126
- image = ToTensor()(image)
127
- image = image * 2.0 - 1.0
128
-
129
- image = image.unsqueeze(0).to(device)
130
- H, W = image.shape[2:]
131
- assert image.shape[1] == 3
132
- F = 8
133
- C = 4
134
- shape = (num_frames, C, H // F, W // F)
135
- if (H, W) != (576, 1024):
136
- print(
137
- "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
138
- )
139
- if motion_bucket_id > 255:
140
- print(
141
- "WARNING: High motion bucket! This may lead to suboptimal performance."
142
- )
143
-
144
- if fps_id < 5:
145
- print("WARNING: Small fps value! This may lead to suboptimal performance.")
146
-
147
- if fps_id > 30:
148
- print("WARNING: Large fps value! This may lead to suboptimal performance.")
149
-
150
- value_dict = {}
151
- value_dict["motion_bucket_id"] = motion_bucket_id
152
- value_dict["fps_id"] = fps_id
153
- value_dict["cond_aug"] = cond_aug
154
- value_dict["cond_frames_without_noise"] = image
155
- value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
156
- value_dict["cond_aug"] = cond_aug
157
-
158
- with torch.no_grad():
159
- with torch.autocast(device):
160
- batch, batch_uc = get_batch(
161
- get_unique_embedder_keys_from_conditioner(model.conditioner),
162
- value_dict,
163
- [1, num_frames],
164
- T=num_frames,
165
- device=device,
166
- )
167
- c, uc = model.conditioner.get_unconditional_conditioning(
168
- batch,
169
- batch_uc=batch_uc,
170
- force_uc_zero_embeddings=[
171
- "cond_frames",
172
- "cond_frames_without_noise",
173
- ],
174
  )
175
 
176
- for k in ["crossattn", "concat"]:
177
- uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
178
- uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
179
- c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
180
- c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
181
-
182
- randn = torch.randn(shape, device=device)
183
-
184
- additional_model_inputs = {}
185
- additional_model_inputs["image_only_indicator"] = torch.zeros(
186
- 2, num_frames
187
- ).to(device)
188
- additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
189
-
190
- def denoiser(input, sigma, c):
191
- return model.denoiser(
192
- model.model, input, sigma, c, **additional_model_inputs
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  )
194
 
195
- samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
196
- model.en_and_decode_n_samples_a_time = decoding_t
197
- samples_x = model.decode_first_stage(samples_z)
198
- samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
199
-
200
- os.makedirs(output_folder, exist_ok=True)
201
- base_count = len(glob(os.path.join(output_folder, "*.mp4")))
202
- video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
203
- writer = cv2.VideoWriter(
204
- video_path,
205
- cv2.VideoWriter_fourcc(*"mp4v"),
206
- fps_id + 1,
207
- (samples.shape[-1], samples.shape[-2]),
208
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
 
210
- samples = embed_watermark(samples)
211
- samples = filter(samples)
212
- vid = (
213
- (rearrange(samples, "t c h w -> t h w c") * 255)
214
- .cpu()
215
- .numpy()
216
- .astype(np.uint8)
217
- )
218
- for frame in vid:
219
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
220
- writer.write(frame)
221
- writer.release()
222
-
223
- return video_path, seed
 
 
224
 
225
  def get_unique_embedder_keys_from_conditioner(conditioner):
226
  return list(set([x.input_key for x in conditioner.embedders]))
 
85
  Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
86
  image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
87
  """
88
+ try:
89
+ if input_path is None:
90
+ raise ValueError("No image")
91
+
92
+ if(randomize_seed):
93
+ seed = random.randint(0, max_64_bit_int)
94
+
95
+ torch.manual_seed(seed)
96
 
97
+ path = Path(input_path)
98
+ all_img_paths = []
99
+ if path.is_file():
100
+ if any([input_path.endswith(x) for x in ["jpg", "jpeg", "png"]]):
101
+ all_img_paths = [input_path]
102
+ else:
103
+ raise ValueError("Unsupported image type.")
104
+ elif path.is_dir():
105
+ all_img_paths = sorted(
106
+ [
107
+ f
108
+ for f in path.iterdir()
109
+ if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
110
+ ]
111
+ )
112
+ if len(all_img_paths) == 0:
113
+ raise ValueError("Folder does not contain any images.")
114
  else:
115
+ raise ValueError("No image")
116
+
117
+ for input_img_path in all_img_paths:
118
+ with Image.open(input_img_path) as image:
119
+ if image.mode == "RGBA":
120
+ image = image.convert("RGB")
121
+ w, h = image.size
122
+
123
+ if h % 64 != 0 or w % 64 != 0:
124
+ width, height = map(lambda x: x - x % 64, (w, h))
125
+ image = image.resize((width, height))
126
+ print(
127
+ f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
128
+ )
129
 
130
+ image = ToTensor()(image)
131
+ image = image * 2.0 - 1.0
 
 
 
132
 
133
+ image = image.unsqueeze(0).to(device)
134
+ H, W = image.shape[2:]
135
+ assert image.shape[1] == 3
136
+ F = 8
137
+ C = 4
138
+ shape = (num_frames, C, H // F, W // F)
139
+ if (H, W) != (576, 1024):
140
  print(
141
+ "WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
142
  )
143
+ if motion_bucket_id > 255:
144
+ print(
145
+ "WARNING: High motion bucket! This may lead to suboptimal performance."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
 
148
+ if fps_id < 5:
149
+ print("WARNING: Small fps value! This may lead to suboptimal performance.")
150
+
151
+ if fps_id > 30:
152
+ print("WARNING: Large fps value! This may lead to suboptimal performance.")
153
+
154
+ value_dict = {}
155
+ value_dict["motion_bucket_id"] = motion_bucket_id
156
+ value_dict["fps_id"] = fps_id
157
+ value_dict["cond_aug"] = cond_aug
158
+ value_dict["cond_frames_without_noise"] = image
159
+ value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
160
+ value_dict["cond_aug"] = cond_aug
161
+
162
+ with torch.no_grad():
163
+ with torch.autocast(device):
164
+ batch, batch_uc = get_batch(
165
+ get_unique_embedder_keys_from_conditioner(model.conditioner),
166
+ value_dict,
167
+ [1, num_frames],
168
+ T=num_frames,
169
+ device=device,
170
+ )
171
+ c, uc = model.conditioner.get_unconditional_conditioning(
172
+ batch,
173
+ batch_uc=batch_uc,
174
+ force_uc_zero_embeddings=[
175
+ "cond_frames",
176
+ "cond_frames_without_noise",
177
+ ],
178
  )
179
 
180
+ for k in ["crossattn", "concat"]:
181
+ uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
182
+ uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
183
+ c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
184
+ c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
185
+
186
+ randn = torch.randn(shape, device=device)
187
+
188
+ additional_model_inputs = {}
189
+ additional_model_inputs["image_only_indicator"] = torch.zeros(
190
+ 2, num_frames
191
+ ).to(device)
192
+ additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
193
+
194
+ def denoiser(input, sigma, c):
195
+ return model.denoiser(
196
+ model.model, input, sigma, c, **additional_model_inputs
197
+ )
198
+
199
+ samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
200
+ model.en_and_decode_n_samples_a_time = decoding_t
201
+ samples_x = model.decode_first_stage(samples_z)
202
+ samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
203
+
204
+ os.makedirs(output_folder, exist_ok=True)
205
+ base_count = len(glob(os.path.join(output_folder, "*.mp4")))
206
+ video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
207
+ writer = cv2.VideoWriter(
208
+ video_path,
209
+ cv2.VideoWriter_fourcc(*"mp4v"),
210
+ fps_id + 1,
211
+ (samples.shape[-1], samples.shape[-2]),
212
+ )
213
 
214
+ samples = embed_watermark(samples)
215
+ samples = filter(samples)
216
+ vid = (
217
+ (rearrange(samples, "t c h w -> t h w c") * 255)
218
+ .cpu()
219
+ .numpy()
220
+ .astype(np.uint8)
221
+ )
222
+ for frame in vid:
223
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
224
+ writer.write(frame)
225
+ writer.release()
226
+
227
+ return video_path, seed
228
+ except Exception as e:
229
+ raise gr.Error(e.args[0] if len(e.args) > 0 else "Sampling error")
230
 
231
  def get_unique_embedder_keys_from_conditioner(conditioner):
232
  return list(set([x.input_key for x in conditioner.embedders]))