fffiloni committed on
Commit 7983b33 (verified)
Parent: 0d04150

Update app.py

Files changed (1)
  1. app.py +22 -51
app.py CHANGED
@@ -1,7 +1,6 @@
 import sys
 import os
 from pathlib import Path
-import gc
 
 # Add the StableCascade and CSD directories to the Python path
 app_dir = Path(__file__).parent
@@ -28,29 +27,12 @@ from gdf.schedulers import CosineSchedule
 from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
 from gdf.targets import EpsilonTarget
 
-# Enable mixed precision
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
 # Device configuration
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(device)
 
 # Flag for low VRAM usage
-low_vram = True  # Set to True to enable low VRAM optimizations
-
-# Function to clear GPU cache
-def clear_gpu_cache():
-    torch.cuda.empty_cache()
-    gc.collect()
-
-# Function to move model to CPU
-def to_cpu(model):
-    return model.cpu()
-
-# Function to move model to GPU
-def to_gpu(model):
-    return model.cuda()
+low_vram = False
 
 # Function definition for low VRAM usage
 if low_vram:
@@ -71,7 +53,7 @@ if low_vram:
                 print(f"Change device of '{attr_name}' to {device}")
                 attr_value.to(device)
 
-    clear_gpu_cache()
+    torch.cuda.empty_cache()
 
 # Stage C model configuration
 config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
@@ -126,7 +108,7 @@ models_b.generator.bfloat16().eval().requires_grad_(False)
 # Off-load old generator (low VRAM mode)
 if low_vram:
     models.generator.to("cpu")
-    clear_gpu_cache()
+    torch.cuda.empty_cache()
 
 # Load and configure new generator
 generator_rbm = StageCRBM()
@@ -149,7 +131,6 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    clear_gpu_cache()  # Clear cache before inference
 
     height=1024
     width=1024
@@ -185,22 +166,19 @@ def infer(style_description, ref_style_file, caption):
     models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
     # Stage C reverse process.
-    with torch.cuda.amp.autocast():  # Use mixed precision
-        sampling_c = extras.gdf.sample(
-            models_rbm.generator, conditions, stage_c_latent_shape,
-            unconditions, device=device,
-            **extras.sampling_configs,
-            x0_style_forward=x0_style_forward,
-            apply_pushforward=False, tau_pushforward=8,
-            num_iter=3, eta=0.1, tau=20, eval_csd=True,
-            extras=extras, models=models_rbm,
-            lam_style=1, lam_txt_alignment=1.0,
-            use_ddim_sampler=True,
-        )
-        for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
-            sampled_c = sampled_c
-
-    clear_gpu_cache()  # Clear cache between stages
+    sampling_c = extras.gdf.sample(
+        models_rbm.generator, conditions, stage_c_latent_shape,
+        unconditions, device=device,
+        **extras.sampling_configs,
+        x0_style_forward=x0_style_forward,
+        apply_pushforward=False, tau_pushforward=8,
+        num_iter=3, eta=0.1, tau=20, eval_csd=True,
+        extras=extras, models=models_rbm,
+        lam_style=1, lam_txt_alignment=1.0,
+        use_ddim_sampler=True,
+    )
+    for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+        sampled_c = sampled_c
 
     # Stage B reverse process.
     with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
@@ -216,21 +194,14 @@ def infer(style_description, ref_style_file, caption):
     sampled = models_b.stage_a.decode(sampled_b).float()
 
     sampled = torch.cat([
-        torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
+        torch.nn.functional.interpolate(ref_style.cpu(), size=height),
         sampled.cpu(),
-    ], dim=0)
-
-    # Remove the batch dimension and keep only the generated image
-    sampled = sampled[1]  # This selects the generated image, discarding the reference style image
-
-    # Ensure the tensor is in [C, H, W] format
-    if sampled.dim() == 3 and sampled.shape[0] == 3:
-        sampled_image = T.ToPILImage()(sampled)  # Convert tensor to PIL image
-        sampled_image.save(output_file)  # Save the image as a PNG
-    else:
-        raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-    clear_gpu_cache()  # Clear cache after inference
+    ],
+    dim=0)
+
+    # Save the sampled image to a file
+    sampled_image = T.ToPILImage()(sampled.squeeze(0))  # Convert tensor to PIL image
+    sampled_image.save(output_file)  # Save the image
 
     return output_file  # Return the path to the saved image
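Note on the cache handling: at both call sites the diff swaps the deleted clear_gpu_cache() helper for a bare torch.cuda.empty_cache(). The only behavioral difference is the dropped gc.collect(); a minimal standalone sketch of the two patterns, assuming a CUDA device is present:

import gc
import torch

def clear_gpu_cache():
    # The removed helper: release cached CUDA allocator blocks,
    # then collect unreachable Python objects still holding GPU tensors.
    torch.cuda.empty_cache()
    gc.collect()

# The replacement used in this commit: cache release only; tensors kept
# alive by stray Python references are not freed.
torch.cuda.empty_cache()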
 
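On the new saving path: T.ToPILImage() expects a [C, H, W] tensor, and squeeze(0) only drops the leading dimension when it has size 1. A self-contained sketch of the conversion (the random tensor and the clamp are illustrative assumptions, not taken from app.py):

import torch
import torchvision.transforms as T

# Stand-in for a decoded Stage A output: one RGB image in [1, C, H, W].
decoded = torch.rand(1, 3, 1024, 1024)

# squeeze(0) removes the size-1 batch dimension, leaving the [C, H, W]
# layout ToPILImage expects; clamp keeps float values in [0, 1].
image = T.ToPILImage()(decoded.squeeze(0).clamp(0, 1))
image.save("output.png")

Since the preceding torch.cat([...], dim=0) stacks the reference and the generated image into a batch of two, squeeze(0) is a no-op on that tensor; the previous revision selected the generated image with sampled[1] before converting.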