Jordan Legg commited on
Commit
86f0308
β€’
1 Parent(s): 448d742

remove projection layer and let x embedder handle it

Browse files
Files changed (1) hide show
  1. app.py +12 -14
app.py CHANGED
@@ -3,7 +3,6 @@ import gradio as gr
3
  import numpy as np
4
  import random
5
  import torch
6
- import torch.nn as nn
7
  from PIL import Image
8
  from torchvision import transforms
9
  from diffusers import DiffusionPipeline
@@ -20,9 +19,6 @@ pipe.enable_model_cpu_offload()
20
  pipe.vae.enable_slicing()
21
  pipe.vae.enable_tiling()
22
 
23
- # Add a projection layer to match x_embedder input
24
- projection = nn.Linear(32 * 128 * 128, 64).to(device).to(dtype)
25
-
26
  def preprocess_image(image, image_size):
27
  preprocess = transforms.Compose([
28
  transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
@@ -33,19 +29,18 @@ def preprocess_image(image, image_size):
33
  return image
34
 
35
  def process_latents(latents, height, width):
36
- # Ensure latents are the correct shape (should be [1, 32, 128, 128])
37
- latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
38
- print(f"Latent shape after interpolation: {latents.shape}")
39
 
40
- # Flatten the latents
41
- latents_flat = latents.reshape(1, -1)
42
- print(f"Flattened latent shape: {latents_flat.shape}")
 
43
 
44
- # Project to 64 dimensions
45
- latents_projected = projection(latents_flat)
46
- print(f"Projected latent shape: {latents_projected.shape}")
47
 
48
- return latents_projected
49
 
50
  @spaces.GPU()
51
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
@@ -76,6 +71,9 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
76
  # Process latents to match x_embedder input
77
  latents = process_latents(latents, height, width)
78
 
 
 
 
79
  image = pipe(
80
  prompt=prompt,
81
  height=height,
 
3
  import numpy as np
4
  import random
5
  import torch
 
6
  from PIL import Image
7
  from torchvision import transforms
8
  from diffusers import DiffusionPipeline
 
19
  pipe.vae.enable_slicing()
20
  pipe.vae.enable_tiling()
21
 
 
 
 
22
  def preprocess_image(image, image_size):
23
  preprocess = transforms.Compose([
24
  transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
 
29
  return image
30
 
31
  def process_latents(latents, height, width):
32
+ print(f"Input latent shape: {latents.shape}")
 
 
33
 
34
+ # Ensure latents are the correct shape
35
+ if latents.shape[2:] != (height // 8, width // 8):
36
+ latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
37
+ print(f"Latent shape after potential interpolation: {latents.shape}")
38
 
39
+ # Reshape latents to [batch_size, seq_len, channels]
40
+ latents = latents.permute(0, 2, 3, 1).reshape(1, -1, latents.shape[1])
41
+ print(f"Reshaped latent shape: {latents.shape}")
42
 
43
+ return latents
44
 
45
  @spaces.GPU()
46
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
71
  # Process latents to match x_embedder input
72
  latents = process_latents(latents, height, width)
73
 
74
+ print(f"x_embedder weight shape: {pipe.transformer.x_embedder.weight.shape}")
75
+ print(f"First transformer block input shape: {pipe.transformer.transformer_blocks[0].attn.to_q.weight.shape}")
76
+
77
  image = pipe(
78
  prompt=prompt,
79
  height=height,