Jordan Legg commited on
Commit
6b927be
Β·
1 Parent(s): bc9da49

added more console logging

Browse files
Files changed (1) hide show
  1. app.py +18 -26
app.py CHANGED
@@ -29,32 +29,13 @@ def preprocess_image(image, image_size):
29
  return image
30
 
31
  def check_shapes(latents):
32
- # Get the shape of the latents
33
- latent_shape = latents.shape
34
- print(f"Latent shape: {latent_shape}")
35
-
36
- # Get the expected shape for the transformer input
37
- expected_shape = (1, latent_shape[1] * latent_shape[2] * latent_shape[3])
38
- print(f"Expected transformer input shape: {expected_shape}")
39
-
40
- # Try to get the shape of the transformer's weight matrix
41
- try:
42
- # Assuming the first layer of the transformer has a linear projection
43
- if hasattr(pipe.transformer, 'blocks'):
44
- weight_shape = pipe.transformer.blocks[0].attn.to_q.weight.shape
45
- else:
46
- print("Unable to determine transformer weight shape.")
47
- return
48
- print(f"Transformer weight shape: {weight_shape}")
49
-
50
- # Check if the shapes are compatible for matrix multiplication
51
- if expected_shape[1] == weight_shape[1]:
52
- print("Shapes are compatible for matrix multiplication.")
53
- else:
54
- print("Warning: Shapes are not compatible for matrix multiplication.")
55
- print(f"Expected: {expected_shape[1]}, Got: {weight_shape[1]}")
56
- except AttributeError as e:
57
- print(f"Unable to access transformer weights: {e}")
58
 
59
  @spaces.GPU()
60
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
@@ -93,6 +74,15 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
93
  # Check shapes after reshaping
94
  check_shapes(latents)
95
 
 
 
 
 
 
 
 
 
 
96
  image = pipe(
97
  prompt=prompt,
98
  height=height,
@@ -106,6 +96,8 @@ def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, he
106
  return image, seed
107
  except Exception as e:
108
  print(f"Error during inference: {e}")
 
 
109
  return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
110
 
111
  # Gradio interface setup
 
29
  return image
30
 
31
  def check_shapes(latents):
32
+ print(f"Latent shape: {latents.shape}")
33
+ if len(latents.shape) == 4:
34
+ print(f"Expected transformer input shape: {(1, latents.shape[1] * latents.shape[2] * latents.shape[3])}")
35
+ elif len(latents.shape) == 2:
36
+ print(f"Reshaped latent shape: {latents.shape}")
37
+ else:
38
+ print(f"Unexpected latent shape: {latents.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  @spaces.GPU()
41
  def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
 
74
  # Check shapes after reshaping
75
  check_shapes(latents)
76
 
77
+ # Print the type and shape of each argument
78
+ print(f"prompt type: {type(prompt)}, value: {prompt}")
79
+ print(f"height type: {type(height)}, value: {height}")
80
+ print(f"width type: {type(width)}, value: {width}")
81
+ print(f"num_inference_steps type: {type(num_inference_steps)}, value: {num_inference_steps}")
82
+ print(f"generator type: {type(generator)}")
83
+ print(f"guidance_scale type: {type(0.0)}, value: 0.0")
84
+ print(f"latents type: {type(latents)}, shape: {latents.shape}")
85
+
86
  image = pipe(
87
  prompt=prompt,
88
  height=height,
 
96
  return image, seed
97
  except Exception as e:
98
  print(f"Error during inference: {e}")
99
+ import traceback
100
+ traceback.print_exc()
101
  return Image.new("RGB", (width, height), (255, 0, 0)), seed # Red fallback image
102
 
103
  # Gradio interface setup