Spaces:
Sleeping
Sleeping
Update lrm/inferrer.py
Browse files- lrm/inferrer.py +16 -11
lrm/inferrer.py
CHANGED
@@ -45,14 +45,19 @@ class LRMInferrer:
|
|
45 |
|
46 |
def _load_checkpoint(self, model_name: str, cache_dir = './.cache'):
|
47 |
# download checkpoint if not exists
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
# os.system(f'wget -O {os.path.join(cache_dir, f"{model_name}.pth")} https://zxhezexin.com/modelzoo/openlrm/{model_name}.pth')
|
53 |
-
raise FileNotFoundError(f"Checkpoint {model_name} not found in {cache_dir}")
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
56 |
return checkpoint
|
57 |
|
58 |
def _build_model(self, model_kwargs, model_weights):
|
@@ -193,11 +198,11 @@ class LRMInferrer:
|
|
193 |
|
194 |
image = torch.tensor(np.array(Image.open(source_image))).permute(2, 0, 1).unsqueeze(0) / 255.0
|
195 |
# if RGBA, blend to RGB
|
196 |
-
print(f"[DEBUG] check 1.")
|
197 |
if image.shape[1] == 4:
|
198 |
image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
|
199 |
print(f"[DEBUG] image.shape={image.shape} and image[0,0,0,0]={image[0,0,0,0]}")
|
200 |
-
print(f"[DEBUG] check 2.")
|
201 |
image = torch.nn.functional.interpolate(image, size=(source_image_size, source_image_size), mode='bicubic', align_corners=True)
|
202 |
image = torch.clamp(image, 0, 1)
|
203 |
results = self.infer_single(
|
@@ -240,11 +245,11 @@ if __name__ == '__main__':
|
|
240 |
|
241 |
"""
|
242 |
Example usage:
|
243 |
-
python -m lrm.inferrer --model_name
|
244 |
"""
|
245 |
|
246 |
parser = argparse.ArgumentParser()
|
247 |
-
parser.add_argument('--model_name', type=str, default='
|
248 |
parser.add_argument('--source_image', type=str, default='./assets/sample_input/owl.png')
|
249 |
parser.add_argument('--dump_path', type=str, default='./dumps')
|
250 |
parser.add_argument('--source_size', type=int, default=-1)
|
|
|
45 |
|
46 |
def _load_checkpoint(self, model_name: str, cache_dir = './.cache'):
|
47 |
# download checkpoint if not exists
|
48 |
+
local_dir = os.path.join(cache_dir, model_name)
|
49 |
+
if not os.path.exists(local_dir):
|
50 |
+
os.makedirs(local_dir, exist_ok=True)
|
51 |
+
if not os.path.exists(os.path.join(local_dir, f'model.pth')):
|
52 |
# os.system(f'wget -O {os.path.join(cache_dir, f"{model_name}.pth")} https://zxhezexin.com/modelzoo/openlrm/{model_name}.pth')
|
53 |
+
# raise FileNotFoundError(f"Checkpoint {model_name} not found in {cache_dir}")
|
54 |
+
from huggingface_hub import hf_hub_download
|
55 |
+
repo_id = f'zxhezexin/{model_name}'
|
56 |
+
config_path = hf_hub_download(repo_id=repo_id, filename='config.json', local_dir=local_dir)
|
57 |
+
model_path = hf_hub_download(repo_id=repo_id, filename=f'model.pth', local_dir=local_dir)
|
58 |
+
else:
|
59 |
+
model_path = os.path.join(local_dir, f'model.pth')
|
60 |
+
checkpoint = torch.load(model_path, map_location=self.device)
|
61 |
return checkpoint
|
62 |
|
63 |
def _build_model(self, model_kwargs, model_weights):
|
|
|
198 |
|
199 |
image = torch.tensor(np.array(Image.open(source_image))).permute(2, 0, 1).unsqueeze(0) / 255.0
|
200 |
# if RGBA, blend to RGB
|
201 |
+
# print(f"[DEBUG] check 1.")
|
202 |
if image.shape[1] == 4:
|
203 |
image = image[:, :3, ...] * image[:, 3:, ...] + (1 - image[:, 3:, ...])
|
204 |
print(f"[DEBUG] image.shape={image.shape} and image[0,0,0,0]={image[0,0,0,0]}")
|
205 |
+
# print(f"[DEBUG] check 2.")
|
206 |
image = torch.nn.functional.interpolate(image, size=(source_image_size, source_image_size), mode='bicubic', align_corners=True)
|
207 |
image = torch.clamp(image, 0, 1)
|
208 |
results = self.infer_single(
|
|
|
245 |
|
246 |
"""
|
247 |
Example usage:
|
248 |
+
python -m lrm.inferrer --model_name openlrm-base-obj-1.0 --source_image ./assets/sample_input/owl.png --export_video --export_mesh
|
249 |
"""
|
250 |
|
251 |
parser = argparse.ArgumentParser()
|
252 |
+
parser.add_argument('--model_name', type=str, default='openlrm-base-obj-1.0')
|
253 |
parser.add_argument('--source_image', type=str, default='./assets/sample_input/owl.png')
|
254 |
parser.add_argument('--dump_path', type=str, default='./dumps')
|
255 |
parser.add_argument('--source_size', type=int, default=-1)
|