fffiloni committed
Commit 0d04150 · verified · 1 Parent(s): c24728c

Update app.py

Files changed (1)
  1. app.py +146 -137
app.py CHANGED
@@ -2,7 +2,6 @@ import sys
  import os
  from pathlib import Path
  import gc
- import traceback
 
  # Add the StableCascade and CSD directories to the Python path
  app_dir = Path(__file__).parent
@@ -28,7 +27,6 @@ from utils import WurstCoreCRBM
  from gdf.schedulers import CosineSchedule
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
  from gdf.targets import EpsilonTarget
- import PIL
 
  # Enable mixed precision
  torch.backends.cuda.matmul.allow_tf32 = True
@@ -75,160 +73,171 @@ if low_vram:
 
  clear_gpu_cache()
 
- # Load configurations
+ # Stage C model configuration
  config_file = 'third_party/StableCascade/configs/inference/stage_c_3b.yaml'
  with open(config_file, "r", encoding="utf-8") as file:
      loaded_config = yaml.safe_load(file)
 
+ core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
+
+ # Stage B model configuration
  config_file_b = 'third_party/StableCascade/configs/inference/stage_b_3b.yaml'
  with open(config_file_b, "r", encoding="utf-8") as file:
      config_file_b = yaml.safe_load(file)
-
- def initialize_models():
-     global models_rbm, models_b, extras, extras_b, core, core_b
-
-     # Clear any existing models from memory
-     models_rbm = None
-     models_b = None
-     extras = None
-     extras_b = None
-
-     # Clear GPU cache
-     clear_gpu_cache()
-
-     # Initialize models
-     core = WurstCoreCRBM(config_dict=loaded_config, device=device, training=False)
-     core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
-
-     extras = core.setup_extras_pre()
-     models = core.setup_models(extras)
-
-     extras_b = core_b.setup_extras_pre()
-     models_b = core_b.setup_models(extras_b, skip_clip=True)
-     models_b = WurstCoreB.Models(
-         **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
-     )
-
-     # Initialize models_rbm
-     generator_rbm = StageCRBM()
-     for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
-         set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
-
-     generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
-     generator_rbm = core.load_model(generator_rbm, 'generator')
-
-     models_rbm = core.Models(
-         effnet=models.effnet,
-         previewer=models.previewer,
-         generator=generator_rbm,
-         generator_ema=models.generator_ema,
-         tokenizer=models.tokenizer,
-         text_model=models.text_model,
-         image_model=models.image_model
-     )
-
-     # Move models to appropriate devices
-     models_rbm.generator.to(device).eval().requires_grad_(False)
-     models_b.generator.to(device).eval().requires_grad_(False)
 
+ core_b = WurstCoreB(config_dict=config_file_b, device=device, training=False)
+
+ # Setup extras and models for Stage C
+ extras = core.setup_extras_pre()
+
+ gdf_rbm = RBM(
+     schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
+     input_scaler=VPScaler(), target=EpsilonTarget(),
+     noise_cond=CosineTNoiseCond(),
+     loss_weight=AdaptiveLossWeight(),
+ )
+
+ sampling_configs = {
+     "cfg": 5,
+     "sampler": DDPMSampler(gdf_rbm),
+     "shift": 1,
+     "timesteps": 20
+ }
+
+ extras = core.Extras(
+     gdf=gdf_rbm,
+     sampling_configs=sampling_configs,
+     transforms=extras.transforms,
+     effnet_preprocess=extras.effnet_preprocess,
+     clip_preprocess=extras.clip_preprocess
+ )
+
+ models = core.setup_models(extras)
+ models.generator.eval().requires_grad_(False)
+
+ # Setup extras and models for Stage B
+ extras_b = core_b.setup_extras_pre()
+ models_b = core_b.setup_models(extras_b, skip_clip=True)
+ models_b = WurstCoreB.Models(
+     **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
+ )
+ models_b.generator.bfloat16().eval().requires_grad_(False)
+
+ # Off-load old generator (low VRAM mode)
+ if low_vram:
+     models.generator.to("cpu")
  clear_gpu_cache()
 
- def infer(style_description, ref_style_file, caption):
-     try:
-         # Clear GPU cache before inference
-         clear_gpu_cache()
+ # Load and configure new generator
+ generator_rbm = StageCRBM()
+ for param_name, param in load_or_fail(core.config.generator_checkpoint_path).items():
+     set_module_tensor_to_device(generator_rbm, param_name, "cpu", value=param)
+
+ generator_rbm = generator_rbm.to(getattr(torch, core.config.dtype)).to(device)
+ generator_rbm = core.load_model(generator_rbm, 'generator')
+
+ # Create models_rbm instance
+ models_rbm = core.Models(
+     effnet=models.effnet,
+     previewer=models.previewer,
+     generator=generator_rbm,
+     generator_ema=models.generator_ema,
+     tokenizer=models.tokenizer,
+     text_model=models.text_model,
+     image_model=models.image_model
+ )
+ models_rbm.generator.eval().requires_grad_(False)
 
-         # Ensure models are on the correct device
-         models_rbm.to(device)
-         models_b.to(device)
+ def infer(style_description, ref_style_file, caption):
+     clear_gpu_cache() # Clear cache before inference
 
-         height = 1024
-         width = 1024
-         batch_size = 1
-         output_file = 'output.png'
-
-         stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
-
-         ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
-
-         batch = {'captions': [caption] * batch_size}
-         batch['style'] = ref_style
-
-         x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
-
-         conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
-         unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
-         conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
-         unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
-
-         if low_vram:
-             # Offload non-essential models to CPU for memory savings
-             models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
-
-         # Stage C reverse process
-         with torch.cuda.amp.autocast():
-             sampling_c = extras.gdf.sample(
-                 models_rbm.generator, conditions, stage_c_latent_shape,
-                 unconditions, device=device,
-                 **extras.sampling_configs,
-                 x0_style_forward=x0_style_forward,
-                 apply_pushforward=False, tau_pushforward=8,
-                 num_iter=3, eta=0.1, tau=20, eval_csd=True,
-                 extras=extras, models=models_rbm,
-                 lam_style=1, lam_txt_alignment=1.0,
-                 use_ddim_sampler=True,
-             )
-             for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
-                 sampled_c = sampled_c
-
-         clear_gpu_cache() # Clear cache between stages
-
-         # Ensure models_b is on the correct device
-         models_b.to(device)
-
-         # Stage B reverse process
-         with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
-             conditions_b['effnet'] = sampled_c
-             unconditions_b['effnet'] = torch.zeros_like(sampled_c)
-
-             sampling_b = extras_b.gdf.sample(
-                 models_b.generator, conditions_b, stage_b_latent_shape,
-                 unconditions_b, device=device, **extras_b.sampling_configs,
-             )
-             for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
-                 sampled_b = sampled_b
-             sampled = models_b.stage_a.decode(sampled_b).float()
-
-         # Post-process and save the image
-         sampled = sampled.cpu() # Move to CPU before processing
-
-         # Ensure the tensor is in [C, H, W] format
-         if sampled.dim() == 4 and sampled.size(0) == 1:
-             sampled = sampled.squeeze(0)
+     height=1024
+     width=1024
+     batch_size=1
+     output_file='output.png'
+
+     stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
+
+     extras.sampling_configs['cfg'] = 4
+     extras.sampling_configs['shift'] = 2
+     extras.sampling_configs['timesteps'] = 20
+     extras.sampling_configs['t_start'] = 1.0
+
+     extras_b.sampling_configs['cfg'] = 1.1
+     extras_b.sampling_configs['shift'] = 1
+     extras_b.sampling_configs['timesteps'] = 10
+     extras_b.sampling_configs['t_start'] = 1.0
+
+     ref_style = resize_image(PIL.Image.open(ref_style_file).convert("RGB")).unsqueeze(0).expand(batch_size, -1, -1, -1).to(device)
+
+     batch = {'captions': [caption] * batch_size}
+     batch['style'] = ref_style
+
+     x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
+
+     conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
+     unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
+     conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
+     unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
+
+     if low_vram:
+         # The sampling process uses more vram, so we offload everything except two modules to the cpu.
+         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
+
+     # Stage C reverse process.
+     with torch.cuda.amp.autocast(): # Use mixed precision
+         sampling_c = extras.gdf.sample(
+             models_rbm.generator, conditions, stage_c_latent_shape,
+             unconditions, device=device,
+             **extras.sampling_configs,
+             x0_style_forward=x0_style_forward,
+             apply_pushforward=False, tau_pushforward=8,
+             num_iter=3, eta=0.1, tau=20, eval_csd=True,
+             extras=extras, models=models_rbm,
+             lam_style=1, lam_txt_alignment=1.0,
+             use_ddim_sampler=True,
+         )
+         for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
+             sampled_c = sampled_c
+
+     clear_gpu_cache() # Clear cache between stages
+
+     # Stage B reverse process.
+     with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
+         conditions_b['effnet'] = sampled_c
+         unconditions_b['effnet'] = torch.zeros_like(sampled_c)
 
-         if sampled.dim() == 3 and sampled.shape[0] == 3:
-             sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
-             sampled_image.save(output_file) # Save the image as a PNG
-         else:
-             raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
-
-     except Exception as e:
-         print(f"An error occurred during inference: {str(e)}")
-         traceback.print_exc() # This will print the full traceback
-         return None
-
-     finally:
-         clear_gpu_cache() # Always clear cache after inference
+         sampling_b = extras_b.gdf.sample(
+             models_b.generator, conditions_b, stage_b_latent_shape,
+             unconditions_b, device=device, **extras_b.sampling_configs,
+         )
+         for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
+             sampled_b = sampled_b
+         sampled = models_b.stage_a.decode(sampled_b).float()
+
+     sampled = torch.cat([
+         torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
+         sampled.cpu(),
+     ], dim=0)
+
+     # Remove the batch dimension and keep only the generated image
+     sampled = sampled[1] # This selects the generated image, discarding the reference style image
+
+     # Ensure the tensor is in [C, H, W] format
+     if sampled.dim() == 3 and sampled.shape[0] == 3:
+         sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
+         sampled_image.save(output_file) # Save the image as a PNG
+     else:
+         raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
+
+     clear_gpu_cache() # Clear cache after inference
 
      return output_file # Return the path to the saved image
 
  import gradio as gr
 
- def gradio_interface(style_description, ref_style_file, caption):
-     return infer(style_description, ref_style_file, caption)
-
  gr.Interface(
-     fn=gradio_interface,
+     fn = infer,
      inputs=[gr.Textbox(label="style description"), gr.Image(label="Ref Style File", type="filepath"), gr.Textbox(label="caption")],
      outputs=[gr.Image()]
  ).launch()
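Note: clear_gpu_cache() is called before, between, and after the two sampling stages, but its definition lies outside the changed hunks. A minimal sketch of such a helper, assuming it simply combines Python garbage collection with emptying the CUDA allocator cache (the body below is an assumption, not the file's actual definition):

import gc
import torch

def clear_gpu_cache():
    # Drop unreferenced Python objects first, then release cached CUDA
    # blocks back to the driver so the next stage can allocate them.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()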
+
218
+ sampled = torch.cat([
219
+ torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
220
+ sampled.cpu(),
221
+ ], dim=0)
222
+
223
+ # Remove the batch dimension and keep only the generated image
224
+ sampled = sampled[1] # This selects the generated image, discarding the reference style image
225
+
226
+ # Ensure the tensor is in [C, H, W] format
227
+ if sampled.dim() == 3 and sampled.shape[0] == 3:
228
+ sampled_image = T.ToPILImage()(sampled) # Convert tensor to PIL image
229
+ sampled_image.save(output_file) # Save the image as a PNG
230
+ else:
231
+ raise ValueError(f"Expected tensor of shape [3, H, W] but got {sampled.shape}")
232
+
233
+ clear_gpu_cache() # Clear cache after inference
234
 
235
  return output_file # Return the path to the saved image
236
 
237
  import gradio as gr
238
 
 
 
 
239
  gr.Interface(
240
+ fn = infer,
241
  inputs=[gr.Textbox(label="style description"), gr.Image(label="Ref Style File", type="filepath"), gr.Textbox(label="caption")],
242
  outputs=[gr.Image()]
243
  ).launch()
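Note: calculate_latent_sizes(height, width, batch_size=batch_size) yields the two latent shapes the samplers iterate over. In StableCascade, Stage C works in a heavily compressed latent space (roughly 42:1 spatially) and Stage B in a 4:1 space; a sketch along the lines of the repo's inference utils, with the compression factors stated here as assumptions:

import math

def calculate_latent_sizes(height=1024, width=1024, batch_size=1,
                           compression_factor_b=42.67, compression_factor_a=4.0):
    # Stage C: 16-channel latent, about 24x24 for a 1024x1024 image.
    stage_c_latent_shape = (batch_size, 16,
                            math.ceil(height / compression_factor_b),
                            math.ceil(width / compression_factor_b))
    # Stage B: 4-channel latent at quarter resolution, which Stage A decodes to pixels.
    stage_b_latent_shape = (batch_size, 4,
                            math.ceil(height / compression_factor_a),
                            math.ceil(width / compression_factor_a))
    return stage_c_latent_shape, stage_b_latent_shape

Under these assumptions, the defaults used in infer() (a single 1024x1024 image) would give stage_c_latent_shape == (1, 16, 24, 24) and stage_b_latent_shape == (1, 4, 256, 256).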