nigeljw commited on
Commit
738e1c7
1 Parent(s): 00972fe

Switched back to slerp for latents interpolation

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -45,6 +45,17 @@ def InitializeOutpainting():
45
  pipeline = StableDiffusionInpaintPipeline.from_pretrained(modelNames[modelIndex])
46
  #safety_checker=lambda images, **kwargs: (images, False))
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps):
49
  global lastImage, lastSeed, generator, oldLatentWalk, activeLatents
50
 
@@ -55,7 +66,7 @@ def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInf
55
  GenerateNewLatentsForInference()
56
 
57
  if oldLatentWalk != latentWalk:
58
- activeLatents = torch.lerp(oldLatents, latents, latentWalk)
59
  oldLatentWalk = latentWalk
60
 
61
  if lastSeed != generatorSeed:
 
45
  pipeline = StableDiffusionInpaintPipeline.from_pretrained(modelNames[modelIndex])
46
  #safety_checker=lambda images, **kwargs: (images, False))
47
 
48
+ # Based on: https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
49
+ # Further optimized to trade a divide operation for a multiply
50
+ def slerp(start, end, alpha):
51
+ start_norm = torch.norm(start, dim=1, keepdim=True)
52
+ end_norm = torch.norm(end, dim=1, keepdim=True)
53
+ omega = torch.acos((start*end/(start_norm*end_norm)).sum(1))
54
+ sinOmega = torch.sin(omega)
55
+ first = torch.sin((1.0-alpha)*omega)/sinOmega
56
+ second = torch.sin(alpha*omega)/sinOmega
57
+ return first.unsqueeze(1)*start + second.unsqueeze(1)*end
58
+
59
  def diffuse(latentWalk, staticLatents, generatorSeed, inputImage, mask, pauseInference, prompt, negativePrompt, guidanceScale, numInferenceSteps):
60
  global lastImage, lastSeed, generator, oldLatentWalk, activeLatents
61
 
 
66
  GenerateNewLatentsForInference()
67
 
68
  if oldLatentWalk != latentWalk:
69
+ activeLatents = slerp(oldLatents, latents, latentWalk)
70
  oldLatentWalk = latentWalk
71
 
72
  if lastSeed != generatorSeed: