Jordan Legg committed
Commit da39f41 · Parent: f071803

shaping latents

Files changed (1): app.py +50 -102
app.py CHANGED
@@ -8,37 +8,18 @@ from torchvision import transforms
 from diffusers import DiffusionPipeline
 
 # Define constants
+dtype = torch.bfloat16
+device = "cuda" if torch.cuda.is_available() else "cpu"
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 2048
-MIN_IMAGE_SIZE = 256
-DEFAULT_IMAGE_SIZE = 1024
-MAX_PROMPT_LENGTH = 256  # Changed to 256 as per FLUX.1-schnell requirements
-
-# Check for GPU availability
-device = "cuda" if torch.cuda.is_available() else "cpu"
-if device == "cpu":
-    print("Warning: Running on CPU. This may be very slow.")
-
-dtype = torch.float16 if device == "cuda" else torch.float32
-
-def load_model():
-    try:
-        pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
-        pipe.to(device)
-        pipe.enable_model_cpu_offload()
-        pipe.vae.enable_slicing()
-        pipe.vae.enable_tiling()
-        return pipe
-    except Exception as e:
-        raise RuntimeError(f"Failed to load the model: {str(e)}")
 
 # Load the diffusion pipeline
-pipe = load_model()
+pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
-def preprocess_image(image, target_size):
+def preprocess_image(image):
     # Preprocess the image for the VAE
     preprocess = transforms.Compose([
-        transforms.Resize(target_size, interpolation=transforms.InterpolationMode.LANCZOS),
+        transforms.Resize((512, 512)),  # Adjust the size as needed
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])
     ])
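Note on this hunk: the rewritten preprocess_image resizes every init image to a fixed 512×512, while infer (below) later interpolates the resulting latents up to height // 8 × width // 8. A minimal alternative sketch, assuming a PIL input and the device/dtype globals defined above, that resizes straight to the requested generation size so the latents leave the VAE already in the right spatial shape:

from torchvision import transforms

def preprocess_image(image, height, width):
    # Resize directly to the target size so the VAE yields latents of
    # shape (1, C, height // 8, width // 8) and no interpolation is needed.
    preprocess = transforms.Compose([
        transforms.Resize((height, width)),
        transforms.ToTensor(),               # PIL -> float tensor in [0, 1]
        transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1], as the VAE expects
    ])
    return preprocess(image).unsqueeze(0).to(device, dtype=dtype)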
@@ -51,76 +32,51 @@ def encode_image(image, vae):
     latents = vae.encode(image).latent_dist.sample() * 0.18215
     return latents
 
-def validate_inputs(prompt, width, height, num_inference_steps):
-    if not prompt or len(prompt) > MAX_PROMPT_LENGTH:
-        raise ValueError(f"Prompt must be between 1 and {MAX_PROMPT_LENGTH} characters.")
-    if width % 8 != 0 or height % 8 != 0:
-        raise ValueError("Width and height must be divisible by 8.")
-    if width < MIN_IMAGE_SIZE or width > MAX_IMAGE_SIZE or height < MIN_IMAGE_SIZE or height > MAX_IMAGE_SIZE:
-        raise ValueError(f"Image dimensions must be between {MIN_IMAGE_SIZE} and {MAX_IMAGE_SIZE}.")
-    if num_inference_steps < 1 or num_inference_steps > 50:
-        raise ValueError("Number of inference steps must be between 1 and 50.")
-
 @spaces.GPU()
-def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=DEFAULT_IMAGE_SIZE, height=DEFAULT_IMAGE_SIZE, num_inference_steps=4, strength=0.8, progress=gr.Progress(track_tqdm=True)):
-    try:
-        validate_inputs(prompt, width, height, num_inference_steps)
-
-        if randomize_seed:
-            seed = random.randint(0, MAX_SEED)
-        generator = torch.Generator(device=device).manual_seed(seed)
-
-        # Ensure max_sequence_length is not more than 256
-        max_sequence_length = min(MAX_PROMPT_LENGTH, len(prompt))
-
-        if init_image is not None:
-            # Process img2img
-            init_image = init_image.convert("RGB")
-            init_image = preprocess_image(init_image, (height, width))
-
-            # Encode the image using the VAE
-            init_latents = encode_image(init_image, pipe.vae)
-
-            # Ensure latents are correctly shaped
-            init_latents = torch.nn.functional.interpolate(init_latents, size=(height // 8, width // 8), mode='bilinear', align_corners=False)
-
-            # Add noise to latents
-            noise = torch.randn_like(init_latents)
-            latents = noise + strength * (init_latents - noise)
-
-            image = pipe(
-                prompt=prompt,
-                height=height,
-                width=width,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                guidance_scale=0.0,
-                latents=latents,
-                max_sequence_length=max_sequence_length
-            ).images[0]
-        else:
-            # Process text2img
-            image = pipe(
-                prompt=prompt,
-                height=height,
-                width=width,
-                num_inference_steps=num_inference_steps,
-                generator=generator,
-                guidance_scale=0.0,
-                max_sequence_length=max_sequence_length
-            ).images[0]
-
-        return image, seed
-    except Exception as e:
-        raise gr.Error(str(e))
+def infer(prompt, init_image=None, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+    generator = torch.Generator().manual_seed(seed)
+
+    if init_image is not None:
+        # Process img2img
+        init_image = init_image.convert("RGB")
+        init_image = preprocess_image(init_image)
+        latents = encode_image(init_image, pipe.vae)
+        # Ensure latents are correctly shaped and adjusted
+        latents = torch.nn.functional.interpolate(latents, size=(height // 8, width // 8))
+        latents = latents * 0.18215  # Adjust latent scaling factor if necessary
+
+        # Ensure latents are reshaped to match the expected input dimensions of the model
+        latents = latents.view(1, -1, height // 8, width // 8)
+
+        image = pipe(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            guidance_scale=0.0,
+            latents=latents
+        ).images[0]
+    else:
+        # Process text2img
+        image = pipe(
+            prompt=prompt,
+            height=height,
+            width=width,
+            num_inference_steps=num_inference_steps,
+            generator=generator,
+            guidance_scale=0.0
+        ).images[0]
+
+    return image, seed
 
 # Define example prompts
 examples = [
     "a tiny astronaut hatching from an egg on the moon",
     "a cat holding a sign that says hello world",
     "an anime illustration of a wiener schnitzel",
-    "a surreal landscape with floating islands and waterfalls",
-    "a steampunk-inspired cityscape at sunset"
 ]
 
 # CSS styling for the Japanese-inspired interface
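The commit title is "shaping latents", and this hunk is where the shape actually matters. Two caveats worth flagging. First, latents.view(1, -1, height // 8, width // 8) keeps a 4D grid, but FLUX's transformer consumes "packed" latents: a 3D tensor of 2×2 latent patches with shape (batch, (H/16)·(W/16), 64). Second, encode_image already multiplies by a scaling factor, so the extra * 0.18215 here scales twice, and 0.18215 is the Stable Diffusion VAE factor rather than the one in FLUX's VAE config. A minimal sketch of both fixes, assuming vae.config exposes scaling_factor and shift_factor as in recent diffusers releases:

def encode_image(image, vae):
    # Use the model's own factors instead of the hard-coded SD value 0.18215.
    latents = vae.encode(image).latent_dist.sample()
    shift = getattr(vae.config, "shift_factor", 0.0) or 0.0
    return (latents - shift) * vae.config.scaling_factor

def pack_latents(latents):
    # (B, C, H, W) -> (B, (H//2)*(W//2), C*4): flatten 2x2 latent patches
    # into tokens, the layout FLUX's transformer expects.
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

Recent diffusers versions ship this packing as FluxPipeline._pack_latents, so the helper above mainly documents the layout.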
@@ -173,7 +129,7 @@ with gr.Blocks(css=css) as demo:
             label="Prompt",
             show_label=False,
             max_lines=1,
-            placeholder=f"Enter your prompt (max {MAX_PROMPT_LENGTH} characters)",
+            placeholder="Enter your prompt",
             container=False,
         )
         run_button = gr.Button("Run", scale=0)
@@ -195,17 +151,17 @@ with gr.Blocks(css=css) as demo:
         with gr.Row():
             width = gr.Slider(
                 label="Width",
-                minimum=MIN_IMAGE_SIZE,
+                minimum=256,
                 maximum=MAX_IMAGE_SIZE,
-                step=8,
-                value=DEFAULT_IMAGE_SIZE,
+                step=32,
+                value=1024,
             )
             height = gr.Slider(
                 label="Height",
-                minimum=MIN_IMAGE_SIZE,
+                minimum=256,
                 maximum=MAX_IMAGE_SIZE,
-                step=8,
-                value=DEFAULT_IMAGE_SIZE,
+                step=32,
+                value=1024,
             )
 
         with gr.Row():
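The slider step moves from 8 to 32 here: FLUX's VAE downsamples by 8 and the transformer then packs 2×2 latent patches, so pixel dimensions need to be multiples of 16, and a step of 32 from a multiple-of-32 minimum keeps every reachable value valid. A hypothetical guard making the constraint explicit (an assumption mirroring the slider settings, not code from this app):

# Hypothetical guard: FLUX needs pixel dimensions divisible by 16
# (VAE stride 8 x 2x2 latent patch packing).
if width % 16 or height % 16:
    raise gr.Error("Width and height must be multiples of 16.")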
@@ -216,13 +172,6 @@ with gr.Blocks(css=css) as demo:
                 step=1,
                 value=4,
             )
-            strength = gr.Slider(
-                label="Strength (for img2img)",
-                minimum=0.0,
-                maximum=1.0,
-                step=0.01,
-                value=0.8,
-            )
 
     gr.Examples(
         examples=examples,
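With the slider gone, nothing feeds strength anymore; for reference, the removed infer code blended init latents and noise as latents = noise + strength * (init_latents - noise). Restated as a sketch (same formula, pure algebra):

# The removed blend, rewritten: strength weights the init-image latents.
# Note this is inverted relative to the usual img2img convention, where
# higher strength means more noise and less influence from the init image.
noise = torch.randn_like(init_latents)
latents = (1.0 - strength) * noise + strength * init_latents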
@@ -235,9 +184,8 @@ with gr.Blocks(css=css) as demo:
     gr.on(
         triggers=[run_button.click, prompt.submit],
         fn=infer,
-        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps, strength],
+        inputs=[prompt, init_image, seed, randomize_seed, width, height, num_inference_steps],
         outputs=[result, seed]
     )
 
-if __name__ == "__main__":
-    demo.launch()
+demo.launch()
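For the img2img path as a whole, note that diffusers also ships a dedicated FluxImg2ImgPipeline that handles the resize / encode / scale / pack / noise steps internally. A minimal sketch, assuming a diffusers release that includes it:

import torch
from diffusers import FluxImg2ImgPipeline

img2img = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16
).to("cuda")

image = img2img(
    prompt="a cat holding a sign that says hello world",
    image=init_image,      # PIL image; the pipeline resizes and encodes it
    strength=0.8,          # fraction of the schedule spent denoising
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]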
 