FrozenBurning commited on
Commit
9d573a0
1 Parent(s): ae08466

update inference

Browse files
Files changed (1) hide show
  1. inference.py +15 -3
inference.py CHANGED
@@ -85,6 +85,21 @@ def resize_foreground(
85
  def extract_texmesh(args, model, output_path, device):
86
  # Prepare directory
87
  ins_dir = output_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
  # Get SDFs
90
  with torch.no_grad():
@@ -350,9 +365,6 @@ if __name__ == "__main__":
350
  # manually enable tf32 to get speedup on A100 GPUs
351
  torch.backends.cuda.matmul.allow_tf32 = True
352
  torch.backends.cudnn.allow_tf32 = True
353
- os.environ["CC"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/gcc"
354
- os.environ["CPP"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/g++"
355
- os.environ["CXX"] = "/mnt/lustre/share/gcc/gcc-8.5.0/bin/g++"
356
  # set config
357
  config = OmegaConf.load(str(sys.argv[1]))
358
  config_cli = OmegaConf.from_cli(args_list=sys.argv[2:])
 
85
  def extract_texmesh(args, model, output_path, device):
86
  # Prepare directory
87
  ins_dir = output_path
88
+ # Noise Filter
89
+ srt_param = model.srt_param.clone()
90
+ prim_position = srt_param[:, 1:4]
91
+ prim_scale = srt_param[:, 0:1]
92
+ dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
93
+ dist += torch.eye(prim_position.shape[0]).to(srt_param)
94
+ min_dist, min_indices = dist.min(1)
95
+ dst_prim_scale = prim_scale[min_indices, :]
96
+ min_scale_converage = prim_scale * 1.414 + dst_prim_scale * 1.414
97
+ prim_mask = min_dist < min_scale_converage[:, 0]
98
+ filtered_srt_param = srt_param[prim_mask, :]
99
+ filtered_feat_param = model.feat_param.clone()[prim_mask, ...]
100
+ model.srt_param.data = filtered_srt_param
101
+ model.feat_param.data = filtered_feat_param
102
+ print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
103
 
104
  # Get SDFs
105
  with torch.no_grad():
 
365
  # manually enable tf32 to get speedup on A100 GPUs
366
  torch.backends.cuda.matmul.allow_tf32 = True
367
  torch.backends.cudnn.allow_tf32 = True
 
 
 
368
  # set config
369
  config = OmegaConf.load(str(sys.argv[1]))
370
  config_cli = OmegaConf.from_cli(args_list=sys.argv[2:])