poipiii
commited on
Commit
•
e611cbb
1
Parent(s):
6dfb556
test interpolate latent
Browse files- pipeline.py +7 -1
pipeline.py
CHANGED
@@ -665,6 +665,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
665 |
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
666 |
height: int = 512,
|
667 |
width: int = 512,
|
|
|
668 |
num_inference_steps: int = 50,
|
669 |
guidance_scale: float = 7.5,
|
670 |
strength: float = 0.8,
|
@@ -841,10 +842,15 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
|
|
841 |
return None
|
842 |
print(latents)
|
843 |
print(latents.shape)
|
|
|
|
|
|
|
|
|
|
|
|
|
844 |
# 9. Post-processing
|
845 |
image = self.decode_latents(latents)
|
846 |
|
847 |
-
#do latent upscale here
|
848 |
|
849 |
# 10. Run safety checker
|
850 |
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
|
|
665 |
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
|
666 |
height: int = 512,
|
667 |
width: int = 512,
|
668 |
+
resize_scale: float = 1.2,
|
669 |
num_inference_steps: int = 50,
|
670 |
guidance_scale: float = 7.5,
|
671 |
strength: float = 0.8,
|
|
|
842 |
return None
|
843 |
print(latents)
|
844 |
print(latents.shape)
|
845 |
+
resized_image = torch.nn.functional.interpolate(
|
846 |
+
latents, size=(int(latents.shape[2]*resize_scale)//8, int(latents.shape[3]*resize_scale)//8))
|
847 |
+
|
848 |
+
print(resized_image.shape)
|
849 |
+
#do latent upscale here
|
850 |
+
|
851 |
# 9. Post-processing
|
852 |
image = self.decode_latents(latents)
|
853 |
|
|
|
854 |
|
855 |
# 10. Run safety checker
|
856 |
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|