Update app.py
app.py CHANGED
@@ -1,7 +1,6 @@
 import sys
 import os
 from pathlib import Path
-import gc
 
 # Add the StableCascade and CSD directories to the Python path
 app_dir = Path(__file__).parent
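The hunk above drops the now-unused `import gc` (its only consumers are deleted below). For context, the path setup that these unchanged lines begin presumably continues along the lines sketched here; only `third_party/StableCascade` is confirmed by the config path later in this diff, and the CSD sub-directory name is an assumption:

```python
import sys
from pathlib import Path

# Hypothetical continuation of the path setup (not shown in this diff):
# prepend the vendored repos so their modules are importable.
app_dir = Path(__file__).parent
sys.path.insert(0, str(app_dir / "third_party" / "StableCascade"))  # confirmed dir
sys.path.insert(0, str(app_dir / "third_party" / "CSD"))            # assumed dir
```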
@@ -28,29 +27,12 @@ from gdf.schedulers import CosineSchedule
 from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
 from gdf.targets import EpsilonTarget
 
-# Enable mixed precision
-torch.backends.cuda.matmul.allow_tf32 = True
-torch.backends.cudnn.allow_tf32 = True
-
 # Device configuration
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 print(device)
 
 # Flag for low VRAM usage
-low_vram =
-
-# Function to clear GPU cache
-def clear_gpu_cache():
-    torch.cuda.empty_cache()
-    gc.collect()
-
-# Function to move model to CPU
-def to_cpu(model):
-    return model.cpu()
-
-# Function to move model to GPU
-def to_gpu(model):
-    return model.cuda()
+low_vram = False
 
 # Function definition for low VRAM usage
 if low_vram:
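This hunk pins `low_vram = False` and deletes the TF32 switches together with three small helpers (the `clear_gpu_cache()` calls inside `infer` are removed in later hunks). Reassembled from the deletion lines above, the removed helpers were the standard PyTorch memory-hygiene pattern:

```python
import gc
import torch

def clear_gpu_cache():
    torch.cuda.empty_cache()  # return cached, unused CUDA blocks to the driver
    gc.collect()              # collect Python objects still holding tensor refs

def to_cpu(model):
    return model.cpu()        # off-load parameters and buffers to host RAM

def to_gpu(model):
    return model.cuda()       # move them back to the default CUDA device
```

With `low_vram` hard-coded to `False` on the A100 this Space runs on, the offloading branches never execute, which is presumably why the helpers were dropped.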
@@ -71,7 +53,7 @@ if low_vram:
             print(f"Change device of '{attr_name}' to {device}")
             attr_value.to(device)
 
-
+    torch.cuda.empty_cache()
 
 # Stage C model configuration
 config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
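Only the tail of `models_to` is visible here; the added `torch.cuda.empty_cache()` runs once after the loop has finished moving modules. A minimal sketch of what the full helper plausibly looks like, reconstructed around the two context lines (the signature and loop structure are assumptions, not the Space's actual code):

```python
import torch

def models_to(model, device="cpu", excepts=None):
    # Assumed shape of the helper: move every nn.Module attribute of a
    # container to `device`, skipping attributes named in `excepts`.
    for attr_name in dir(model):
        attr_value = getattr(model, attr_name, None)
        if isinstance(attr_value, torch.nn.Module):
            if excepts and attr_name in excepts:
                continue
            print(f"Change device of '{attr_name}' to {device}")
            attr_value.to(device)
    torch.cuda.empty_cache()  # the line this commit adds
```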
@@ -126,7 +108,7 @@ models_b.generator.bfloat16().eval().requires_grad_(False)
 # Off-load old generator (low VRAM mode)
 if low_vram:
     models.generator.to("cpu")
-
+    torch.cuda.empty_cache()
 
 # Load and configure new generator
 generator_rbm = StageCRBM()
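The `empty_cache()` added after the off-load matters because `.to("cpu")` only shrinks the allocated counter; the freed blocks stay in PyTorch's caching allocator (visible via `memory_reserved()`) until explicitly released. A self-contained illustration, with an `nn.Linear` standing in for `models.generator`:

```python
import torch
import torch.nn as nn

generator = nn.Linear(4096, 4096).cuda()  # stand-in for models.generator
print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")
generator.to("cpu")                       # weights move to host RAM
torch.cuda.empty_cache()                  # cached blocks return to the driver
print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MiB, "
      f"reserved: {torch.cuda.memory_reserved() / 2**20:.1f} MiB")
```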
@@ -149,7 +131,6 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
-    clear_gpu_cache() # Clear cache before inference
 
     height=1024
     width=1024
@@ -185,22 +166,19 @@ def infer(style_description, ref_style_file, caption):
     models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
     # Stage C reverse process.
-
-
-
-
-
-
-
-
-
-
-
-
-
-    sampled_c = sampled_c
-
-    clear_gpu_cache() # Clear cache between stages
+    sampling_c = extras.gdf.sample(
+        models_rbm.generator, conditions, stage_c_latent_shape,
+        unconditions, device=device,
+        **extras.sampling_configs,
+        x0_style_forward=x0_style_forward,
+        apply_pushforward=False, tau_pushforward=8,
+        num_iter=3, eta=0.1, tau=20, eval_csd=True,
+        extras=extras, models=models_rbm,
+        lam_style=1, lam_txt_alignment=1.0,
+        use_ddim_sampler=True,
+    )
+    for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+        sampled_c = sampled_c
 
     # Stage B reverse process.
     with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
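The rewritten Stage C block treats the object returned by `extras.gdf.sample` as an iterator: each denoising step yields a tuple whose first element is the current latent, and the loop simply keeps the last one (hence the no-op body `sampled_c = sampled_c`). The generic consumption pattern, sketched with illustrative names rather than the gdf API:

```python
from tqdm import tqdm

def run_to_completion(sampling_generator, timesteps):
    latent = None
    # Mirrors `for (sampled_c, _, _) in tqdm(sampling_c, ...)`: iterate all
    # steps, discard the extra outputs, keep only the final latent.
    for (latent, _, _) in tqdm(sampling_generator, total=timesteps):
        pass
    return latent
```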
@@ -216,21 +194,14 @@ def infer(style_description, ref_style_file, caption):
         sampled = models_b.stage_a.decode(sampled_b).float()
 
         sampled = torch.cat([
-            torch.nn.functional.interpolate(ref_style.cpu(), size=
+            torch.nn.functional.interpolate(ref_style.cpu(), size=height),
             sampled.cpu(),
-
-
-    # Remove the batch dimension and keep only the generated image
-    sampled = sampled[1] # This selects the generated image, discarding the reference style image
-
-    # Ensure the tensor is in [C, H, W] format
-    if sampled.dim() == 3 and sampled.shape[0] == 3:
-        sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
-        sampled_image.save(output_file) # Save the image as a PNG
-    else:
-        raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+        ],
+        dim=0)
 
-
+    # Save the sampled image to a file
+    sampled_image = T.ToPILImage()(sampled.squeeze(0)) # Convert tensor to PIL image
+    sampled_image.save(output_file) # Save the image
 
     return output_file # Return the path to the saved image
 
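The final hunk replaces the shape-checking save logic with a straight-line version: resize the reference style image (`F.interpolate` broadcasts a scalar `size` across both spatial dims, so `size=height` produces a square resize), concatenate it with the decoded batch along `dim=0`, then convert to PIL and save. A minimal sketch of that save path, assuming a single [1, 3, H, W] image in [0, 1]:

```python
import torch
import torchvision.transforms as T  # assuming `T` is torchvision.transforms

sampled = torch.rand(1, 3, 1024, 1024)      # stand-in for the decoded batch
image = T.ToPILImage()(sampled.squeeze(0))  # [1, 3, H, W] -> [3, H, W] -> PIL
image.save("output.png")
```

One caveat worth flagging: after the `torch.cat`, the batch holds both the reference and the generated image, and `squeeze(0)` does not remove a batch dimension of size 2, so `ToPILImage` would reject that tensor; the deleted `sampled = sampled[1]` indexing is what previously selected a single image.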