Spaces:
Paused
Paused
Switched back to slerp for latents interpolation
Browse files
app.py
CHANGED
@@ -45,6 +45,17 @@ def InitializeOutpainting():
|
|
45 |
pipeline = StableDiffusionInpaintPipeline.from_pretrained(modelNames[modelIndex])
|
46 |
#safety_checker=lambda images, **kwargs: (images, False))
|
47 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps):
|
49 |
global lastImage, lastSeed, generator, oldLatentWalk, activeLatents
|
50 |
|
@@ -55,7 +66,7 @@ def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInf
|
|
55 |
GenerateNewLatentsForInference()
|
56 |
|
57 |
if oldLatentWalk != latentWalk:
|
58 |
-
activeLatents =
|
59 |
oldLatentWalk = latentWalk
|
60 |
|
61 |
if lastSeed != generatorSeed:
|
|
|
45 |
pipeline = StableDiffusionInpaintPipeline.from_pretrained(modelNames[modelIndex])
|
46 |
#safety_checker=lambda images, **kwargs: (images, False))
|
47 |
|
48 |
+
# Based on: https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
|
49 |
+
# Further optimized to trade a divide operation for a multiply
|
50 |
+
def slerp(start, end, alpha):
|
51 |
+
start_norm = torch.norm(start, dim=1, keepdim=True)
|
52 |
+
end_norm = torch.norm(end, dim=1, keepdim=True)
|
53 |
+
omega = torch.acos((start*end/(start_norm*end_norm)).sum(1))
|
54 |
+
sinOmega = torch.sin(omega)
|
55 |
+
first = torch.sin((1.0-alpha)*omega)/sinOmega
|
56 |
+
second = torch.sin(alpha*omega)/sinOmega
|
57 |
+
return first.unsqueeze(1)*start + second.unsqueeze(1)*end
|
58 |
+
|
59 |
def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps):
|
60 |
global lastImage, lastSeed, generator, oldLatentWalk, activeLatents
|
61 |
|
|
|
66 |
GenerateNewLatentsForInference()
|
67 |
|
68 |
if oldLatentWalk != latentWalk:
|
69 |
+
activeLatents = slerp(oldLatents, latents, latentWalk)
|
70 |
oldLatentWalk = latentWalk
|
71 |
|
72 |
if lastSeed != generatorSeed:
|