FrozenBurning commited on
Commit
670f57e
1 Parent(s): 9d573a0

update inference

Browse files
Files changed (1) hide show
  1. inference.py +10 -7
inference.py CHANGED
@@ -86,17 +86,18 @@ def extract_texmesh(args, model, output_path, device):
86
  # Prepare directory
87
  ins_dir = output_path
88
  # Noise Filter
89
- srt_param = model.srt_param.clone()
90
- prim_position = srt_param[:, 1:4]
91
- prim_scale = srt_param[:, 0:1]
 
92
  dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
93
- dist += torch.eye(prim_position.shape[0]).to(srt_param)
94
  min_dist, min_indices = dist.min(1)
95
  dst_prim_scale = prim_scale[min_indices, :]
96
- min_scale_converage = prim_scale * 1.414 + dst_prim_scale * 1.414
97
  prim_mask = min_dist < min_scale_converage[:, 0]
98
- filtered_srt_param = srt_param[prim_mask, :]
99
- filtered_feat_param = model.feat_param.clone()[prim_mask, ...]
100
  model.srt_param.data = filtered_srt_param
101
  model.feat_param.data = filtered_feat_param
102
  print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
@@ -210,6 +211,8 @@ def extract_texmesh(args, model, output_path, device):
210
 
211
  target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
212
  target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
 
 
213
 
214
  def main(config):
215
  logging.basicConfig(level=logging.INFO)
 
86
  # Prepare directory
87
  ins_dir = output_path
88
  # Noise Filter
89
+ raw_srt_param = model.srt_param.clone()
90
+ raw_feat_param = model.feat_param.clone()
91
+ prim_position = raw_srt_param[:, 1:4]
92
+ prim_scale = raw_srt_param[:, 0:1]
93
  dist = torch.sqrt(torch.sum((prim_position[:, None, :] - prim_position[None, :, :]) ** 2, dim=-1))
94
+ dist += torch.eye(prim_position.shape[0]).to(raw_srt_param)
95
  min_dist, min_indices = dist.min(1)
96
  dst_prim_scale = prim_scale[min_indices, :]
97
+ min_scale_converage = prim_scale * 1. + dst_prim_scale * 1.
98
  prim_mask = min_dist < min_scale_converage[:, 0]
99
+ filtered_srt_param = raw_srt_param[prim_mask, :]
100
+ filtered_feat_param = raw_feat_param[prim_mask, ...]
101
  model.srt_param.data = filtered_srt_param
102
  model.feat_param.data = filtered_feat_param
103
  print(f'[INFO] Mesh Extraction on PrimX: srt={model.srt_param.shape} feat={model.feat_param.shape}')
 
211
 
212
  target_mesh = Mesh(v=torch.from_numpy(v_np).contiguous(), f=torch.from_numpy(f_np).contiguous(), ft=ft.contiguous(), vt=torch.from_numpy(vt_np).contiguous(), albedo=torch.from_numpy(feats[..., :3]) / 255, metallicRoughness=torch.from_numpy(feats[..., 3:]) / 255)
213
  target_mesh.write(os.path.join(ins_dir, f'pbr_mesh.glb'))
214
+ model.srt_param.data = raw_srt_param
215
+ model.feat_param.data = raw_feat_param
216
 
217
  def main(config):
218
  logging.basicConfig(level=logging.INFO)