lemonaddie commited on
Commit
9ba96aa
·
verified ·
1 Parent(s): f12775a

Update models/depth_normal_pipeline_clip.py

Browse files
models/depth_normal_pipeline_clip.py CHANGED
@@ -79,6 +79,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
79
  match_input_res:bool =True,
80
  batch_size:int = 0,
81
  domain: str = "indoor",
 
82
  color_map: str="Spectral",
83
  show_progress_bar:bool = True,
84
  ensemble_kwargs: Dict = None,
@@ -147,6 +148,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
147
  input_rgb=batched_image,
148
  num_inference_steps=denoising_steps,
149
  domain=domain,
 
150
  show_pbar=show_progress_bar,
151
  )
152
  depth_pred_ls.append(depth_pred_raw.detach().clone())
@@ -230,6 +232,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
230
  def single_infer(self,input_rgb:torch.Tensor,
231
  num_inference_steps:int,
232
  domain:str,
 
233
  show_pbar:bool,):
234
 
235
  device = input_rgb.device
@@ -242,6 +245,8 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
242
  rgb_latent = self.encode_RGB(input_rgb)
243
 
244
  # Initial depth map (Guassian noise)
 
 
245
  geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
246
  rgb_latent = rgb_latent.repeat(2,1,1,1)
247
 
 
79
  match_input_res:bool =True,
80
  batch_size:int = 0,
81
  domain: str = "indoor",
82
+ seed: int = 0,
83
  color_map: str="Spectral",
84
  show_progress_bar:bool = True,
85
  ensemble_kwargs: Dict = None,
 
148
  input_rgb=batched_image,
149
  num_inference_steps=denoising_steps,
150
  domain=domain,
151
+ seed=seed,
152
  show_pbar=show_progress_bar,
153
  )
154
  depth_pred_ls.append(depth_pred_raw.detach().clone())
 
232
  def single_infer(self,input_rgb:torch.Tensor,
233
  num_inference_steps:int,
234
  domain:str,
235
+ seed: int,
236
  show_pbar:bool,):
237
 
238
  device = input_rgb.device
 
245
  rgb_latent = self.encode_RGB(input_rgb)
246
 
247
  # Initial depth map (Guassian noise)
248
+ if seed >= 0:
249
+ torch.manual_seed(0)
250
  geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
251
  rgb_latent = rgb_latent.repeat(2,1,1,1)
252