Chao Xu commited on
Commit
d74847a
โ€ข
1 Parent(s): 3c3d4fa

model half

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -251,11 +251,12 @@ def stage1_run(models, device, cam_vis, tmp_dir,
251
  input_im, scale, ddim_steps, elev=None, rerun_all=[],
252
  *btn_retrys):
253
  is_rerun = True if cam_vis is None else False
 
254
 
255
  stage1_dir = os.path.join(tmp_dir, "stage1_8")
256
  if not is_rerun:
257
  os.makedirs(stage1_dir, exist_ok=True)
258
- output_ims = predict_stage1_gradio(models['turncam'], input_im, save_path=stage1_dir, adjust_set=list(range(4)), device=device, ddim_steps=ddim_steps, scale=scale)
259
  stage2_steps = 50 # ddim_steps
260
  zero123_infer(models['turncam'], tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
261
  elev_output = estimate_elev(tmp_dir)
@@ -266,9 +267,9 @@ def stage1_run(models, device, cam_vis, tmp_dir,
266
 
267
  flag_lower_cam = elev_output <= 75
268
  if flag_lower_cam:
269
- output_ims_2 = predict_stage1_gradio(models['turncam'], input_im, save_path=stage1_dir, adjust_set=list(range(4,8)), device=device, ddim_steps=ddim_steps, scale=scale)
270
  else:
271
- output_ims_2 = predict_stage1_gradio(models['turncam'], input_im, save_path=stage1_dir, adjust_set=list(range(8,12)), device=device, ddim_steps=ddim_steps, scale=scale)
272
  torch.cuda.empty_cache()
273
  return (90-elev_output, new_fig, *output_ims, *output_ims_2)
274
  else:
@@ -283,7 +284,7 @@ def stage1_run(models, device, cam_vis, tmp_dir,
283
  if idx not in rerun_all:
284
  rerun_all.append(idx)
285
  print("rerun_idx", rerun_all)
286
- output_ims = predict_stage1_gradio(models['turncam'], input_im, save_path=stage1_dir, adjust_set=rerun_idx_in, device=device, ddim_steps=ddim_steps, scale=scale)
287
  outputs = [gr.update(visible=True)] * 8
288
  for idx, view_idx in enumerate(rerun_idx):
289
  outputs[view_idx] = output_ims[idx]
@@ -296,11 +297,12 @@ def stage2_run(models, device, tmp_dir,
296
  # print("elev", elev)
297
  flag_lower_cam = 90-int(elev["label"]) <= 75
298
  is_rerun = True if rerun_all else False
 
299
  if not is_rerun:
300
  if flag_lower_cam:
301
- zero123_infer(models['turncam'], tmp_dir, indices=list(range(1,8)), device=device, ddim_steps=stage2_steps, scale=scale)
302
  else:
303
- zero123_infer(models['turncam'], tmp_dir, indices=list(range(1,4))+list(range(8,12)), device=device, ddim_steps=stage2_steps, scale=scale)
304
  else:
305
  print("rerun_idx", rerun_all)
306
  zero123_infer(models['turncam'], tmp_dir, indices=rerun_all, device=device, ddim_steps=stage2_steps, scale=scale)
 
251
  input_im, scale, ddim_steps, elev=None, rerun_all=[],
252
  *btn_retrys):
253
  is_rerun = True if cam_vis is None else False
254
+ model = models['turncam'].half()
255
 
256
  stage1_dir = os.path.join(tmp_dir, "stage1_8")
257
  if not is_rerun:
258
  os.makedirs(stage1_dir, exist_ok=True)
259
+ output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4)), device=device, ddim_steps=ddim_steps, scale=scale)
260
  stage2_steps = 50 # ddim_steps
261
  zero123_infer(models['turncam'], tmp_dir, indices=[0], device=device, ddim_steps=stage2_steps, scale=scale)
262
  elev_output = estimate_elev(tmp_dir)
 
267
 
268
  flag_lower_cam = elev_output <= 75
269
  if flag_lower_cam:
270
+ output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(4,8)), device=device, ddim_steps=ddim_steps, scale=scale)
271
  else:
272
+ output_ims_2 = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=list(range(8,12)), device=device, ddim_steps=ddim_steps, scale=scale)
273
  torch.cuda.empty_cache()
274
  return (90-elev_output, new_fig, *output_ims, *output_ims_2)
275
  else:
 
284
  if idx not in rerun_all:
285
  rerun_all.append(idx)
286
  print("rerun_idx", rerun_all)
287
+ output_ims = predict_stage1_gradio(model, input_im, save_path=stage1_dir, adjust_set=rerun_idx_in, device=device, ddim_steps=ddim_steps, scale=scale)
288
  outputs = [gr.update(visible=True)] * 8
289
  for idx, view_idx in enumerate(rerun_idx):
290
  outputs[view_idx] = output_ims[idx]
 
297
  # print("elev", elev)
298
  flag_lower_cam = 90-int(elev["label"]) <= 75
299
  is_rerun = True if rerun_all else False
300
+ model = models['turncam'].half()
301
  if not is_rerun:
302
  if flag_lower_cam:
303
+ zero123_infer(model, tmp_dir, indices=list(range(1,8)), device=device, ddim_steps=stage2_steps, scale=scale)
304
  else:
305
+ zero123_infer(model, tmp_dir, indices=list(range(1,4))+list(range(8,12)), device=device, ddim_steps=stage2_steps, scale=scale)
306
  else:
307
  print("rerun_idx", rerun_all)
308
  zero123_infer(models['turncam'], tmp_dir, indices=rerun_all, device=device, ddim_steps=stage2_steps, scale=scale)