Spaces:
Runtime error
Runtime error
Chao Xu
commited on
Commit
โข
d74847a
1
Parent(s):
3c3d4fa
model half
Browse files
app.py
CHANGED
@@ -251,11 +251,12 @@ def stage1_run(models, device, cam_vis, tmp_dir,
|
|
251 |
input_im, scale, ddim_steps, elev=None, rerun_all=[],
|
252 |
*btn_retrys):
|
253 |
is_rerun = True if cam_vis is None else False
|
|
|
254 |
|
255 |
stage1_dir = os.path.join(tmp_dir, "stage1_8")
|
256 |
if not is_rerun:
|
257 |
os.makedirs(stage1_dir, exist_ok=True)
|
258 |
-
output_ims = predict_stage1_gradio(
|
259 |
stage2_steps = 50 # ddim_steps
|
260 |
zero123_infer(models['turncam'], tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
|
261 |
elev_output = estimate_elev(tmp_dir)
|
@@ -266,9 +267,9 @@ def stage1_run(models, device, cam_vis, tmp_dir,
|
|
266 |
|
267 |
flag_lower_cam = elev_output <= 75
|
268 |
if flag_lower_cam:
|
269 |
-
output_ims_2 = predict_stage1_gradio(
|
270 |
else:
|
271 |
-
output_ims_2 = predict_stage1_gradio(
|
272 |
torch.cuda.empty_cache()
|
273 |
return (90-elev_output, new_fig, *output_ims, *output_ims_2)
|
274 |
else:
|
@@ -283,7 +284,7 @@ def stage1_run(models, device, cam_vis, tmp_dir,
|
|
283 |
if idx not in rerun_all:
|
284 |
rerun_all.append(idx)
|
285 |
print("rerun_idx", rerun_all)
|
286 |
-
output_ims = predict_stage1_gradio(
|
287 |
outputs = [gr.update(visible=True)] * 8
|
288 |
for idx, view_idx in enumerate(rerun_idx):
|
289 |
outputs[view_idx] = output_ims[idx]
|
@@ -296,11 +297,12 @@ def stage2_run(models, device, tmp_dir,
|
|
296 |
# print("elev", elev)
|
297 |
flag_lower_cam = 90-int(elev["label"]) <= 75
|
298 |
is_rerun = True if rerun_all else False
|
|
|
299 |
if not is_rerun:
|
300 |
if flag_lower_cam:
|
301 |
-
zero123_infer(
|
302 |
else:
|
303 |
-
zero123_infer(
|
304 |
else:
|
305 |
print("rerun_idx", rerun_all)
|
306 |
zero123_infer(models['turncam'], tmp_dir, indices=rerun_all, device=device, ddim_steps=stage2_steps, scale=scale)
|
|
|
251 |
input_im, scale, ddim_steps, elev=None, rerun_all=[],
|
252 |
*btn_retrys):
|
253 |
is_rerun = True if cam_vis is None else False
|
254 |
+
model = models['turncam'].half()
|
255 |
|
256 |
stage1_dir = os.path.join(tmp_dir, "stage1_8")
|
257 |
if not is_rerun:
|
258 |
os.makedirs(stage1_dir, exist_ok=True)
|
259 |
+
output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4)), device=device, ddim_steps=ddim_steps, scale=scale)
|
260 |
stage2_steps = 50 # ddim_steps
|
261 |
zero123_infer(models['turncam'], tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
|
262 |
elev_output = estimate_elev(tmp_dir)
|
|
|
267 |
|
268 |
flag_lower_cam = elev_output <= 75
|
269 |
if flag_lower_cam:
|
270 |
+
output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4,8)), device=device, ddim_steps=ddim_steps, scale=scale)
|
271 |
else:
|
272 |
+
output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(8,12)), device=device, ddim_steps=ddim_steps, scale=scale)
|
273 |
torch.cuda.empty_cache()
|
274 |
return (90-elev_output, new_fig, *output_ims, *output_ims_2)
|
275 |
else:
|
|
|
284 |
if idx not in rerun_all:
|
285 |
rerun_all.append(idx)
|
286 |
print("rerun_idx", rerun_all)
|
287 |
+
output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=rerun_idx_in, device=device, ddim_steps=ddim_steps, scale=scale)
|
288 |
outputs = [gr.update(visible=True)] * 8
|
289 |
for idx, view_idx in enumerate(rerun_idx):
|
290 |
outputs[view_idx] = output_ims[idx]
|
|
|
297 |
# print("elev", elev)
|
298 |
flag_lower_cam = 90-int(elev["label"]) <= 75
|
299 |
is_rerun = True if rerun_all else False
|
300 |
+
model = models['turncam'].half()
|
301 |
if not is_rerun:
|
302 |
if flag_lower_cam:
|
303 |
+
zero123_infer(model, tmp_dir, indices=list(range(1,8)), device=device, ddim_steps=stage2_steps, scale=scale)
|
304 |
else:
|
305 |
+
zero123_infer(model, tmp_dir, indices=list(range(1,4))+list(range(8,12)), device=device, ddim_steps=stage2_steps, scale=scale)
|
306 |
else:
|
307 |
print("rerun_idx", rerun_all)
|
308 |
zero123_infer(models['turncam'], tmp_dir, indices=rerun_all, device=device, ddim_steps=stage2_steps, scale=scale)
|