fffiloni committed
Commit fac5f0b
Parent(s): b04fd93

Update app.py

Files changed (1): app.py (+91 -88)
app.py CHANGED
@@ -149,96 +149,99 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    # Ensure all models are moved back to the correct device
-    core_b.generator.to(device)
-    models_rbm.generator.to(device)
-
-    clear_gpu_cache() # Clear cache before inference
-
-    height=1024
-    width=1024
-    batch_size=1
-    output_file='output.png'
-
-    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-    extras.sampling_configs['cfg'] = 4
-    extras.sampling_configs['shift'] = 2
-    extras.sampling_configs['timesteps'] = 20
-    extras.sampling_configs['t_start'] = 1.0
-
-    extras_b.sampling_configs['cfg'] = 1.1
-    extras_b.sampling_configs['shift'] = 1
-    extras_b.sampling_configs['timesteps'] = 10
-    extras_b.sampling_configs['t_start'] = 1.0
-
-    ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
-
-    batch = {'captions': [caption] * batch_size}
-    batch['style'] = ref_style
-
-    x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
-
-    conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-    unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-    conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
-    unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-    if low_vram:
-        # Offload non-essential models to CPU for memory savings
-        models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
-    # Stage C reverse process
-    with torch.cuda.amp.autocast(): # Use mixed precision
-        sampling_c = extras.gdf.sample(
-            models_rbm.generator, conditions, stage_c_latent_shape,
-            unconditions, device=device,
-            **extras.sampling_configs,
-            x0_style_forward=x0_style_forward,
-            apply_pushforward=False, tau_pushforward=8,
-            num_iter=3, eta=0.1, tau=20, eval_csd=True,
-            extras=extras, models=models_rbm,
-            lam_style=1, lam_txt_alignment=1.0,
-            use_ddim_sampler=True,
-        )
-        for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
-            sampled_c = sampled_c
-
-    clear_gpu_cache() # Clear cache between stages
-
-    # Ensure all models are on the right device again
-    models_b.generator.to(device)
-
-    # Stage B reverse process
-    with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-        conditions_b['effnet'] = sampled_c
-        unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
-        sampling_b = extras_b.gdf.sample(
-            models_b.generator, conditions_b, stage_b_latent_shape,
-            unconditions_b, device=device, **extras_b.sampling_configs,
-        )
-        for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
-            sampled_b = sampled_b
-        sampled = models_b.stage_a.decode(sampled_b).float()
-
-    # Post-process and save the image
-    sampled = torch.cat([
-        torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
-        sampled.cpu(),
-    ], dim=0)
-
-    # Remove the batch dimension and keep only the generated image
-    sampled = sampled[1] # This selects the generated image, discarding the reference style image
-
-    # Ensure the tensor is in [C, H, W] format
-    if sampled.dim() == 3 and sampled.shape[0] == 3:
-        sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
-        sampled_image.save(output_file) # Save the image as a PNG
-    else:
-        raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-    clear_gpu_cache() # Clear cache after inference
+    try:
+        # Ensure all models are moved back to the correct device
+        models_rbm.generator.to(device)
+        models_b.generator.to(device)
+
+        clear_gpu_cache() # Clear cache before inference
+
+        height = 1024
+        width = 1024
+        batch_size = 1
+        output_file = 'output.png'
+
+        stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+        extras.sampling_configs['cfg'] = 4
+        extras.sampling_configs['shift'] = 2
+        extras.sampling_configs['timesteps'] = 20
+        extras.sampling_configs['t_start'] = 1.0
+
+        extras_b.sampling_configs['cfg'] = 1.1
+        extras_b.sampling_configs['shift'] = 1
+        extras_b.sampling_configs['timesteps'] = 10
+        extras_b.sampling_configs['t_start'] = 1.0
+
+        ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+        batch = {'captions': [caption] * batch_size}
+        batch['style'] = ref_style
+
+        x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+        conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+        unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+        conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+        unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+        if low_vram:
+            # Offload non-essential models to CPU for memory savings
+            models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+        # Stage C reverse process
+        with torch.cuda.amp.autocast(): # Use mixed precision
+            sampling_c = extras.gdf.sample(
+                models_rbm.generator, conditions, stage_c_latent_shape,
+                unconditions, device=device,
+                **extras.sampling_configs,
+                x0_style_forward=x0_style_forward,
+                apply_pushforward=False, tau_pushforward=8,
+                num_iter=3, eta=0.1, tau=20, eval_csd=True,
+                extras=extras, models=models_rbm,
+                lam_style=1, lam_txt_alignment=1.0,
+                use_ddim_sampler=True,
+            )
+            for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+                sampled_c = sampled_c
+
+        clear_gpu_cache() # Clear cache between stages
+
+        # Ensure all models are on the right device again
+        models_b.generator.to(device)
+
+        # Stage B reverse process
+        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+            conditions_b['effnet'] = sampled_c
+            unconditions_b['effnet'] = torch.zeros_like(sampled_c)
+
+            sampling_b = extras_b.gdf.sample(
+                models_b.generator, conditions_b, stage_b_latent_shape,
+                unconditions_b, device=device, **extras_b.sampling_configs,
+            )
+            for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+                sampled_b = sampled_b
+            sampled = models_b.stage_a.decode(sampled_b).float()
+
+        # Post-process and save the image
+        sampled = sampled.cpu() # Move to CPU before processing
+
+        # Ensure the tensor is in [C, H, W] format
+        if sampled.dim() == 4 and sampled.size(0) == 1:
+            sampled = sampled.squeeze(0)
+
+        if sampled.dim() == 3 and sampled.shape[0] == 3:
+            sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
+            sampled_image.save(output_file) # Save the image as a PNG
+        else:
+            raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+    except Exception as e:
+        print(f"An error occurred during inference: {str(e)}")
+        return None
+
+    finally:
+        clear_gpu_cache() # Always clear cache after inference
 
     return output_file # Return the path to the saved image
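
The net effect of this change is that the whole inference body now runs inside a try/except/finally, so the GPU cache is cleared on every exit path and a failed run returns None instead of crashing the app. A minimal, self-contained sketch of that pattern follows; the clear_gpu_cache stub is an assumption standing in for the helper defined elsewhere in app.py, and guarded_infer is a hypothetical name used only for illustration.

import torch

def clear_gpu_cache():
    # Hypothetical stand-in for the app's helper: release cached GPU memory
    # so the next run starts from a clean slate.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def guarded_infer(step):
    # `step` is any GPU-heavy callable that returns an output path.
    try:
        result = step()
    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return None
    finally:
        clear_gpu_cache()  # runs on success and on failure alike
    return result

Note that in the updated code the final `return output_file` sits after the try/except/finally, so a successful run clears the cache before the saved path is handed back to the caller.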