AP123 committed
Commit: 78c118e
Parent(s): 9379ea4

Update app.py

Files changed (1)
  1. app.py +14 -23
app.py CHANGED
@@ -150,46 +150,37 @@ def call(
 
     # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
     extra_step_kwargs = pipe.prepare_extra_step_kwargs(generator, eta)
-
+
     # 7. Prepare added time ids & embeddings
     add_text_embeds = pooled_prompt_embeds
     add_text2_embeds = pooled_prompt2_embeds
-
+
+    # Default dtype if prompt_embeds or prompt2_embeds are None
+    default_dtype = torch.float32
+
+    # Check and set dtype for add_time_ids
+    dtype_for_add_time_ids = prompt_embeds.dtype if prompt_embeds is not None else default_dtype
     add_time_ids = pipe._get_add_time_ids(
-        original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+        original_size, crops_coords_top_left, target_size, dtype=dtype_for_add_time_ids
     )
+
+    # Check and set dtype for add_time2_ids
+    dtype_for_add_time2_ids = prompt2_embeds.dtype if prompt2_embeds is not None else default_dtype
     add_time2_ids = pipe._get_add_time_ids(
-        original_size, crops_coords_top_left, target_size, dtype=prompt2_embeds.dtype
+        original_size, crops_coords_top_left, target_size, dtype=dtype_for_add_time2_ids
     )
-
+
     if negative_original_size is not None and negative_target_size is not None:
         negative_add_time_ids = pipe._get_add_time_ids(
             negative_original_size,
             negative_crops_coords_top_left,
             negative_target_size,
-            dtype=prompt_embeds.dtype,
+            dtype=dtype_for_add_time_ids,  # Use the same default dtype for negative prompts
         )
     else:
         negative_add_time_ids = add_time_ids
         negative_add_time2_ids = add_time2_ids
 
-    if do_classifier_free_guidance:
-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
-        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
-
-        prompt2_embeds = torch.cat([negative_prompt2_embeds, prompt2_embeds], dim=0)
-        add_text2_embeds = torch.cat([negative_pooled_prompt2_embeds, add_text2_embeds], dim=0)
-        add_time2_ids = torch.cat([negative_add_time2_ids, add_time2_ids], dim=0)
-
-    prompt_embeds = prompt_embeds.to(device)
-    add_text_embeds = add_text_embeds.to(device)
-    add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
-
-    prompt2_embeds = prompt2_embeds.to(device)
-    add_text2_embeds = add_text2_embeds.to(device)
-    add_time2_ids = add_time2_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
-
     # 8. Denoising loop
     num_warmup_steps = max(len(timesteps) - num_inference_steps * pipe.scheduler.order, 0)
 
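
For readers skimming the diff, the core of the change is a guarded dtype lookup: the old code read prompt_embeds.dtype (and prompt2_embeds.dtype) directly, which raises AttributeError when the embeddings are None, while the new code falls back to torch.float32. A minimal, self-contained sketch of that pattern follows; the helper name pick_dtype is hypothetical and not part of app.py:

import torch
from typing import Optional

def pick_dtype(embeds: Optional[torch.Tensor], default_dtype: torch.dtype = torch.float32) -> torch.dtype:
    # Mirror the diff's fallback: use the embedding tensor's dtype when it exists,
    # otherwise the float32 default chosen in the commit.
    return embeds.dtype if embeds is not None else default_dtype

# Usage, corresponding to dtype_for_add_time_ids / dtype_for_add_time2_ids above:
prompt_embeds = torch.randn(2, 77, 2048, dtype=torch.float16)
assert pick_dtype(prompt_embeds) == torch.float16  # tensor present -> its dtype
assert pick_dtype(None) == torch.float32           # tensor missing -> default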