lemonaddie commited on
Commit
f3991d8
1 Parent(s): 825b5cb

Update models/depth_normal_pipeline_clip.py

Browse files
models/depth_normal_pipeline_clip.py CHANGED
@@ -79,7 +79,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
79
  match_input_res:bool =True,
80
  batch_size:int = 0,
81
  domain: str = "indoor",
82
- seed: int = 0,
83
  color_map: str="Spectral",
84
  show_progress_bar:bool = True,
85
  ensemble_kwargs: Dict = None,
@@ -148,7 +148,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
148
  input_rgb=batched_image,
149
  num_inference_steps=denoising_steps,
150
  domain=domain,
151
- seed=seed,
152
  show_pbar=show_progress_bar,
153
  )
154
  depth_pred_ls.append(depth_pred_raw.detach().clone())
@@ -232,7 +232,7 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
232
  def single_infer(self,input_rgb:torch.Tensor,
233
  num_inference_steps:int,
234
  domain:str,
235
- seed: int,
236
  show_pbar:bool,):
237
 
238
  device = input_rgb.device
@@ -245,8 +245,8 @@ class DepthNormalEstimationPipeline(DiffusionPipeline):
245
  rgb_latent = self.encode_RGB(input_rgb)
246
 
247
  # Initial depth map (Guassian noise)
248
- if seed >= 0:
249
- torch.manual_seed(0)
250
  geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
251
  rgb_latent = rgb_latent.repeat(2,1,1,1)
252
 
 
79
  match_input_res:bool =True,
80
  batch_size:int = 0,
81
  domain: str = "indoor",
82
+ #seed: int = 0,
83
  color_map: str="Spectral",
84
  show_progress_bar:bool = True,
85
  ensemble_kwargs: Dict = None,
 
148
  input_rgb=batched_image,
149
  num_inference_steps=denoising_steps,
150
  domain=domain,
151
+ #seed=seed,
152
  show_pbar=show_progress_bar,
153
  )
154
  depth_pred_ls.append(depth_pred_raw.detach().clone())
 
232
  def single_infer(self,input_rgb:torch.Tensor,
233
  num_inference_steps:int,
234
  domain:str,
235
+ #seed: int,
236
  show_pbar:bool,):
237
 
238
  device = input_rgb.device
 
245
  rgb_latent = self.encode_RGB(input_rgb)
246
 
247
  # Initial depth map (Guassian noise)
248
+ #if seed >= 0:
249
+ #torch.manual_seed(0)
250
  geo_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype).repeat(2,1,1,1)
251
  rgb_latent = rgb_latent.repeat(2,1,1,1)
252