Jordan Legg committed on
Commit
9d86930
β€’
1 Parent(s): 86f0308

Re-added projection layer to match the x_embedder input

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -1,8 +1,8 @@
1
- import spaces
2
  import gradio as gr
3
  import numpy as np
4
  import random
5
  import torch
 
6
  from PIL import Image
7
  from torchvision import transforms
8
  from diffusers import DiffusionPipeline
@@ -19,6 +19,9 @@ pipe.enable_model_cpu_offload()
19
  pipe.vae.enable_slicing()
20
  pipe.vae.enable_tiling()
21
 
 
 
 
22
  def preprocess_image(image, image_size):
23
  preprocess = transforms.Compose([
24
  transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
@@ -40,6 +43,10 @@ def process_latents(latents, height, width):
40
  latents = latents.permute(0, 2, 3, 1).reshape(1, -1, latents.shape[1])
41
  print(f"Reshaped latent shape: {latents.shape}")
42
 
 
 
 
 
43
  return latents
44
 
45
  @spaces.GPU()
 
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
  import torch
5
+ import torch.nn as nn
6
  from PIL import Image
7
  from torchvision import transforms
8
  from diffusers import DiffusionPipeline
 
19
  pipe.vae.enable_slicing()
20
  pipe.vae.enable_tiling()
21
 
22
+ # Add a projection layer to match x_embedder input
23
+ projection = nn.Linear(16, 64).to(device).to(dtype)
24
+
25
  def preprocess_image(image, image_size):
26
  preprocess = transforms.Compose([
27
  transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
 
43
  latents = latents.permute(0, 2, 3, 1).reshape(1, -1, latents.shape[1])
44
  print(f"Reshaped latent shape: {latents.shape}")
45
 
46
+ # Project latents from 16 to 64 dimensions
47
+ latents = projection(latents)
48
+ print(f"Projected latent shape: {latents.shape}")
49
+
50
  return latents
51
 
52
  @spaces.GPU()