Spaces:
Running
on
A100
Running
on
A100
Update app.py
Browse files
app.py
CHANGED
@@ -125,7 +125,8 @@ models_rbm = core.Models(
|
|
125 |
text_model=models.text_model,
|
126 |
tokenizer=models.tokenizer,
|
127 |
generator=generator_rbm,
|
128 |
-
previewer=models.previewer
|
|
|
129 |
)
|
130 |
|
131 |
def reset_inference_state():
|
@@ -160,8 +161,10 @@ def reset_inference_state():
|
|
160 |
|
161 |
models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
|
162 |
|
163 |
-
# Ensure effnet
|
164 |
models_rbm.effnet.to(device)
|
|
|
|
|
165 |
|
166 |
# Reset model states
|
167 |
models_rbm.generator.eval().requires_grad_(False)
|
@@ -204,8 +207,11 @@ def infer(style_description, ref_style_file, caption):
|
|
204 |
|
205 |
models_b.generator.to(device)
|
206 |
|
207 |
-
# Ensure effnet
|
208 |
models_rbm.effnet.to(device)
|
|
|
|
|
|
|
209 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
210 |
|
211 |
conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
|
|
|
125 |
text_model=models.text_model,
|
126 |
tokenizer=models.tokenizer,
|
127 |
generator=generator_rbm,
|
128 |
+
previewer=models.previewer,
|
129 |
+
image_model=models.image_model # Add this line
|
130 |
)
|
131 |
|
132 |
def reset_inference_state():
|
|
|
161 |
|
162 |
models_b.generator.to("cpu") # Keep Stage B generator on CPU for now
|
163 |
|
164 |
+
# Ensure effnet and image_model are on the correct device
|
165 |
models_rbm.effnet.to(device)
|
166 |
+
if models_rbm.image_model is not None:
|
167 |
+
models_rbm.image_model.to(device)
|
168 |
|
169 |
# Reset model states
|
170 |
models_rbm.generator.eval().requires_grad_(False)
|
|
|
207 |
|
208 |
models_b.generator.to(device)
|
209 |
|
210 |
+
# Ensure effnet and image_model are on the correct device
|
211 |
models_rbm.effnet.to(device)
|
212 |
+
if models_rbm.image_model is not None:
|
213 |
+
models_rbm.image_model.to(device)
|
214 |
+
|
215 |
x0_style_forward = models_rbm.effnet(extras.effnet_preprocess(ref_style))
|
216 |
|
217 |
conditions = core.get_conditions(batch, models_rbm, extras, is_eval=True, is_unconditional=False, eval_image_embeds=True, eval_style=True, eval_csd=False)
|