Jordan Legg committed on
Commit
448d742
•
1 Parent(s): 6b927be

mapped weights and tried transform projection

Browse files
Files changed (1) hide show
  1. app.py +21 -28
app.py CHANGED
@@ -3,6 +3,7 @@ import gradio as gr
3
  import numpy as np
4
  import random
5
  import torch
 
6
  from PIL import Image
7
  from torchvision import transforms
8
  from diffusers import DiffusionPipeline
@@ -19,6 +20,9 @@ pipe.enable_model_cpu_offload()
19
  pipe.vae.enable_slicing()
20
  pipe.vae.enable_tiling()
21
 
 
 
 
22
  def preprocess_image(image, image_size):
23
  preprocess = transforms.Compose([
24
  transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
@@ -28,14 +32,20 @@ def preprocess_image(image, image_size):
28
  image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
29
  return image
30
 
31
def check_shapes(latents):
    """Print debugging information about the shape of ``latents``.

    The raw shape is always printed first. A second line then depends on
    the number of dimensions:

    * 4 dimensions: the shape obtained by collapsing dimensions 1-3 into a
      single axis behind a leading 1.
    * 2 dimensions: the shape is reported as already reshaped.
    * anything else: the shape is reported as unexpected.

    Args:
        latents: Any object exposing a ``shape`` sequence (e.g. a tensor).

    Returns:
        None. The function only writes to standard output.
    """
    shape = latents.shape
    rank = len(shape)
    print(f"Latent shape: {shape}")

    if rank == 4:
        # Product of every non-leading dimension.
        flat_features = shape[1] * shape[2] * shape[3]
        print(f"Expected transformer input shape: {(1, flat_features)}")
        return

    if rank == 2:
        print(f"Reshaped latent shape: {shape}")
        return

    print(f"Unexpected latent shape: {shape}")
 
 
 
 
 
 
39
 
40
  @spaces.GPU()
41
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
@@ -61,27 +71,10 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
61
 
62
  # Encode the image using FLUX VAE
63
  latents = pipe.vae.encode(init_image).latent_dist.sample() * 0.18215
 
64
 
65
- # Ensure latents are the correct shape
66
- latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
67
-
68
- # Check shapes before reshaping
69
- check_shapes(latents)
70
-
71
- # Reshape latents to match the expected input shape of the transformer
72
- latents = latents.reshape(1, -1)
73
-
74
- # Check shapes after reshaping
75
- check_shapes(latents)
76
-
77
- # Print the type and shape of each argument
78
- print(f"prompt type: {type(prompt)}, value: {prompt}")
79
- print(f"height type: {type(height)}, value: {height}")
80
- print(f"width type: {type(width)}, value: {width}")
81
- print(f"num_inference_steps type: {type(num_inference_steps)}, value: {num_inference_steps}")
82
- print(f"generator type: {type(generator)}")
83
- print(f"guidance_scale type: {type(0.0)}, value: 0.0")
84
- print(f"latents type: {type(latents)}, shape: {latents.shape}")
85
 
86
  image = pipe(
87
  prompt=prompt,
 
3
  import numpy as np
4
  import random
5
  import torch
6
+ import torch.nn as nn
7
  from PIL import Image
8
  from torchvision import transforms
9
  from diffusers import DiffusionPipeline
 
20
  pipe.vae.enable_slicing()
21
  pipe.vae.enable_tiling()
22
 
23
+ # Add a projection layer to match x_embedder input
24
+ projection = nn.Linear(32 * 128 * 128, 64).to(device).to(dtype)
25
+
26
  def preprocess_image(image, image_size):
27
  preprocess = transforms.Compose([
28
  transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.LANCZOS),
 
32
  image = preprocess(image).unsqueeze(0).to(device, dtype=dtype)
33
  return image
34
 
35
def process_latents(latents, height, width):
    """Resize latents, flatten each sample, and apply the module-level projection.

    Steps:
      1. Bilinearly resample the two trailing spatial dimensions to
         ``(height // 8, width // 8)``.
      2. Flatten every sample to one row, keeping the batch dimension. The
         previous ``reshape(1, -1)`` hard-coded a batch of one and silently
         merged all samples of a larger batch into a single row.
      3. Feed the rows through the module-level ``projection`` layer.

    Args:
        latents: 4-D tensor ``(batch, channels, h, w)``, as required by
            bilinear ``interpolate``.
            NOTE(review): the file sizes ``projection`` for 32 channels at
            128x128 -- confirm this matches what the VAE actually emits.
        height: Target image height in pixels; the latent grid uses ``height // 8``.
        width: Target image width in pixels; the latent grid uses ``width // 8``.

    Returns:
        Tensor of shape ``(batch, projection.out_features)``.

    Raises:
        RuntimeError: If the per-sample feature count after flattening does
            not equal ``projection.in_features``. The exception type is the
            same one the underlying matmul would raise, but the message
            names the actual mismatch.
    """
    # Resample the spatial dimensions to one eighth of the requested size.
    latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8), mode='bilinear')
    print(f"Latent shape after interpolation: {latents.shape}")

    # Flatten per sample; for a batch of one this equals reshape(1, -1).
    latents_flat = latents.flatten(start_dim=1)
    print(f"Flattened latent shape: {latents_flat.shape}")

    # Fail early with an actionable message instead of an opaque matmul error.
    expected_features = projection.in_features
    actual_features = latents_flat.shape[1]
    if actual_features != expected_features:
        raise RuntimeError(
            f"Flattened latents have {actual_features} features per sample, "
            f"but projection expects in_features={expected_features} "
            f"(latent shape after interpolation: {tuple(latents.shape)})"
        )

    # Project each flattened sample down to the projection's output width.
    latents_projected = projection(latents_flat)
    print(f"Projected latent shape: {latents_projected.shape}")

    return latents_projected
49
 
50
  @spaces.GPU()
51
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
71
 
72
  # Encode the image using FLUX VAE
73
  latents = pipe.vae.encode(init_image).latent_dist.sample() * 0.18215
74
+ print(f"Initial latent shape from VAE: {latents.shape}")
75
 
76
+ # Process latents to match x_embedder input
77
+ latents = process_latents(latents, height, width)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
  image = pipe(
80
  prompt=prompt,