zxhezexin commited on
Commit
206e6cf
1 Parent(s): 7af53da

Update lrm/inferrer.py

Browse files
Files changed (1) hide show
  1. lrm/inferrer.py +16 -11
lrm/inferrer.py CHANGED
@@ -45,14 +45,19 @@ class LRMInferrer:
45
 
46
  def _load_checkpoint(self, model_name: str, cache_dir = './.cache'):
47
  # download checkpoint if not exists
48
- if not os.path.exists(cache_dir):
49
- os.makedirs(cache_dir, exist_ok=True)
50
- if not os.path.exists(os.path.join(cache_dir, f'{model_name}.pth')):
51
- # TODO: on-the-fly download not supported yet, plz download manually
52
  # os.system(f'wget -O {os.path.join(cache_dir, f"{model_name}.pth")} https://zxhezexin.com/modelzoo/openlrm/{model_name}.pth')
53
- raise FileNotFoundError(f"Checkpoint {model_name} not found in {cache_dir}")
54
- local_path = os.path.join(cache_dir, f'{model_name}.pth')
55
- checkpoint = torch.load(local_path, map_location=self.device)
 
 
 
 
 
56
  return checkpoint
57
 
58
  def _build_model(self, model_kwargs, model_weights):
@@ -193,11 +198,11 @@ class LRMInferrer:
193
 
194
  image = torch.tensor(np.array(Image.open(source_image))).permute(2, 0, 1).unsqueeze(0) / 255.0
195
  # if RGBA, blend to RGB
196
- print(f"[DEBUG] check 1.")
197
  if image.shape[1] == 4:
198
  image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
199
  print(f"[DEBUG] image.shape={image.shape} and image[0,0,0,0]={image[0,0,0,0]}")
200
- print(f"[DEBUG] check 2.")
201
  image = torch.nn.functional.interpolate(image, size=(source_image_size, source_image_size), mode='bicubic', align_corners=True)
202
  image = torch.clamp(image, 0, 1)
203
  results = self.infer_single(
@@ -240,11 +245,11 @@ if __name__ == '__main__':
240
 
241
  """
242
  Example usage:
243
- python -m lrm.inferrer --model_name lrm-base-obj-v1 --source_image ./assets/sample_input/owl.png --export_video --export_mesh
244
  """
245
 
246
  parser = argparse.ArgumentParser()
247
- parser.add_argument('--model_name', type=str, default='lrm-base-obj-v1')
248
  parser.add_argument('--source_image', type=str, default='./assets/sample_input/owl.png')
249
  parser.add_argument('--dump_path', type=str, default='./dumps')
250
  parser.add_argument('--source_size', type=int, default=-1)
 
45
 
46
  def _load_checkpoint(self, model_name: str, cache_dir = './.cache'):
47
  # download checkpoint if not exists
48
+ local_dir = os.path.join(cache_dir, model_name)
49
+ if not os.path.exists(local_dir):
50
+ os.makedirs(local_dir, exist_ok=True)
51
+ if not os.path.exists(os.path.join(local_dir, f'model.pth')):
52
  # os.system(f'wget -O {os.path.join(cache_dir, f"{model_name}.pth")} https://zxhezexin.com/modelzoo/openlrm/{model_name}.pth')
53
+ # raise FileNotFoundError(f"Checkpoint {model_name} not found in {cache_dir}")
54
+ from huggingface_hub import hf_hub_download
55
+ repo_id = f'zxhezexin/{model_name}'
56
+ config_path = hf_hub_download(repo_id=repo_id, filename='config.json', local_dir=local_dir)
57
+ model_path = hf_hub_download(repo_id=repo_id, filename=f'model.pth', local_dir=local_dir)
58
+ else:
59
+ model_path = os.path.join(local_dir, f'model.pth')
60
+ checkpoint = torch.load(model_path, map_location=self.device)
61
  return checkpoint
62
 
63
  def _build_model(self, model_kwargs, model_weights):
 
198
 
199
  image = torch.tensor(np.array(Image.open(source_image))).permute(2, 0, 1).unsqueeze(0) / 255.0
200
  # if RGBA, blend to RGB
201
+ # print(f"[DEBUG] check 1.")
202
  if image.shape[1] == 4:
203
  image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
204
  print(f"[DEBUG] image.shape={image.shape} and image[0,0,0,0]={image[0,0,0,0]}")
205
+ # print(f"[DEBUG] check 2.")
206
  image = torch.nn.functional.interpolate(image, size=(source_image_size, source_image_size), mode='bicubic', align_corners=True)
207
  image = torch.clamp(image, 0, 1)
208
  results = self.infer_single(
 
245
 
246
  """
247
  Example usage:
248
+ python -m lrm.inferrer --model_name openlrm-base-obj-1.0 --source_image ./assets/sample_input/owl.png --export_video --export_mesh
249
  """
250
 
251
  parser = argparse.ArgumentParser()
252
+ parser.add_argument('--model_name', type=str, default='openlrm-base-obj-1.0')
253
  parser.add_argument('--source_image', type=str, default='./assets/sample_input/owl.png')
254
  parser.add_argument('--dump_path', type=str, default='./dumps')
255
  parser.add_argument('--source_size', type=int, default=-1)