fffiloni commited on
Commit
bd18e87
1 Parent(s): 8020398

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -27,6 +27,7 @@ from utils import WurstCoreCRBM
27
  from gdf.schedulers import CosineSchedule
28
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
29
  from gdf.targets import EpsilonTarget
 
30
 
31
  # Device configuration
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -132,7 +133,7 @@ models_rbm = core.Models(
132
  models_rbm.generator.eval().requires_grad_(False)
133
 
134
  def reset_inference_state():
135
- global models_rbm, models_b, extras, extras_b
136
 
137
  # Reset sampling configurations
138
  extras.sampling_configs['cfg'] = 5
@@ -145,13 +146,16 @@ def reset_inference_state():
145
  extras_b.sampling_configs['timesteps'] = 10
146
  extras_b.sampling_configs['t_start'] = 1.0
147
 
148
- # Move models back to initial state
149
  if low_vram:
150
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
151
  models_b.generator.to("cpu")
152
  else:
153
- models_to(models_rbm, device="cuda")
154
- models_b.generator.to("cuda")
 
 
 
155
 
156
  # Clear CUDA cache
157
  torch.cuda.empty_cache()
@@ -181,7 +185,9 @@ def infer(style_description, ref_style_file, caption):
181
  batch = {'captions': [caption] * batch_size}
182
  batch['style'] = ref_style
183
 
184
- x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style.to(device)))
 
 
185
 
186
  conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
187
  unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
 
27
  from gdf.schedulers import CosineSchedule
28
  from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
29
  from gdf.targets import EpsilonTarget
30
+ import PIL
31
 
32
  # Device configuration
33
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
133
  models_rbm.generator.eval().requires_grad_(False)
134
 
135
  def reset_inference_state():
136
+ global models_rbm, models_b, extras, extras_b, device
137
 
138
  # Reset sampling configurations
139
  extras.sampling_configs['cfg'] = 5
 
146
  extras_b.sampling_configs['timesteps'] = 10
147
  extras_b.sampling_configs['t_start'] = 1.0
148
 
149
+ # Move models to the correct device
150
  if low_vram:
151
  models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
152
  models_b.generator.to("cpu")
153
  else:
154
+ models_to(models_rbm, device=device)
155
+ models_b.generator.to(device)
156
+
157
+ # Ensure effnet is on the correct device
158
+ models_rbm.effnet.to(device)
159
 
160
  # Clear CUDA cache
161
  torch.cuda.empty_cache()
 
185
  batch = {'captions': [caption] * batch_size}
186
  batch['style'] = ref_style
187
 
188
+ # Ensure effnet is on the correct device
189
+ models_rbm.effnet.to(device)
190
+ x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
191
 
192
  conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
193
  unconditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)