Spaces:
Sleeping
Sleeping
Update clip_slider_pipeline.py
Browse files- clip_slider_pipeline.py +5 -4
clip_slider_pipeline.py
CHANGED
@@ -18,7 +18,7 @@ class CLIPSlider:
|
|
18 |
):
|
19 |
|
20 |
self.device = device
|
21 |
-
self.pipe = sd_pipe.to(self.device)
|
22 |
self.iterations = iterations
|
23 |
if target_word != "" or opposite != "":
|
24 |
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
@@ -280,13 +280,14 @@ class CLIPSliderXL(CLIPSlider):
|
|
280 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
281 |
prompt_embeds_list.append(prompt_embeds)
|
282 |
|
283 |
-
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
284 |
-
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
|
285 |
end_time = time.time()
|
|
|
286 |
print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
|
287 |
torch.manual_seed(seed)
|
288 |
start_time = time.time()
|
289 |
-
image = self.pipe(prompt_embeds=prompt_embeds
|
290 |
**pipeline_kwargs).images[0]
|
291 |
end_time = time.time()
|
292 |
print(f"generation time - pipe: {end_time - start_time:.2f} ms")
|
|
|
18 |
):
|
19 |
|
20 |
self.device = device
|
21 |
+
self.pipe = sd_pipe.to(self.device, torch.float16)
|
22 |
self.iterations = iterations
|
23 |
if target_word != "" or opposite != "":
|
24 |
self.avg_diff = self.find_latent_direction(target_word, opposite)
|
|
|
280 |
prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
|
281 |
prompt_embeds_list.append(prompt_embeds)
|
282 |
|
283 |
+
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(torch.float16)
|
284 |
+
pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(torch.float16)
|
285 |
end_time = time.time()
|
286 |
+
print("prompt_embeds", prompt_embeds.dtype)
|
287 |
print(f"generation time - before pipe: {end_time - start_time:.2f} ms")
|
288 |
torch.manual_seed(seed)
|
289 |
start_time = time.time()
|
290 |
+
image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
|
291 |
**pipeline_kwargs).images[0]
|
292 |
end_time = time.time()
|
293 |
print(f"generation time - pipe: {end_time - start_time:.2f} ms")
|