guangkaixu commited on
Commit
a123370
1 Parent(s): 12daec9
Files changed (2) hide show
  1. app.py +3 -3
  2. pipeline_genpercept.py +1 -5
app.py CHANGED
@@ -79,8 +79,8 @@ def process_image(
79
  show_progress_bar=False,
80
  )
81
 
82
- depth_pred = pipe_out.depth_np
83
- depth_colored = pipe_out.depth_colored
84
  depth_16bit = (depth_pred * 65535.0).astype(np.uint16)
85
 
86
  np.save(path_out_fp32, depth_pred)
@@ -266,7 +266,7 @@ def main():
266
 
267
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
268
 
269
- dtype = torch.float16
270
 
271
  vae = AutoencoderKL.from_pretrained("guangkaixu/GenPercept", subfolder='vae').to(dtype)
272
  unet_depth_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_depth_v1").to(dtype)
 
79
  show_progress_bar=False,
80
  )
81
 
82
+ depth_pred = pipe_out.pred_np
83
+ depth_colored = pipe_out.pred_colored
84
  depth_16bit = (depth_pred * 65535.0).astype(np.uint16)
85
 
86
  np.save(path_out_fp32, depth_pred)
 
266
 
267
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
268
 
269
+ dtype = torch.float32
270
 
271
  vae = AutoencoderKL.from_pretrained("guangkaixu/GenPercept", subfolder='vae').to(dtype)
272
  unet_depth_v1 = UNet2DConditionModel.from_pretrained('guangkaixu/GenPercept', subfolder="unet_depth_v1").to(dtype)
pipeline_genpercept.py CHANGED
@@ -148,14 +148,10 @@ class GenPerceptPipeline(DiffusionPipeline):
148
  # Normalize rgb values
149
  rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
150
  rgb_norm = rgb / 255.0 * 2.0 - 1.0
151
- rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
152
  rgb_norm = rgb_norm[None].to(device)
153
  assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
154
  bs_imgs = 1
155
-
156
- print('rgb_norm :', rgb_norm.dtype)
157
- print('unet :', self.unet.dtype)
158
- print('vae :', self.vae.dtype)
159
 
160
  # ----------------- Predicting depth -----------------
161
 
 
148
  # Normalize rgb values
149
  rgb = np.transpose(image, (2, 0, 1)) # [H, W, rgb] -> [rgb, H, W]
150
  rgb_norm = rgb / 255.0 * 2.0 - 1.0
151
+ rgb_norm = torch.from_numpy(rgb_norm).to(self.unet.dtype)
152
  rgb_norm = rgb_norm[None].to(device)
153
  assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
154
  bs_imgs = 1
 
 
 
 
155
 
156
  # ----------------- Predicting depth -----------------
157