Upload folder using huggingface_hub
Browse files- geocalib/extractor.py +2 -7
- gradio_app.py +8 -5
- siclib/models/extractor.py +2 -7
geocalib/extractor.py
CHANGED
@@ -22,14 +22,9 @@ class GeoCalib(nn.Module):
|
|
22 |
weights (str): trained variant, "pinhole" (default) or "distorted".
|
23 |
"""
|
24 |
super().__init__()
|
25 |
-
if weights
|
26 |
-
url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
|
27 |
-
elif weights == "distorted":
|
28 |
-
url = (
|
29 |
-
"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
|
30 |
-
)
|
31 |
-
else:
|
32 |
raise ValueError(f"Unknown weights: {weights}")
|
|
|
33 |
|
34 |
# load checkpoint
|
35 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
|
|
22 |
weights (str): trained variant, "pinhole" (default) or "distorted".
|
23 |
"""
|
24 |
super().__init__()
|
25 |
+
if weights not in {"pinhole", "distorted"}:
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
raise ValueError(f"Unknown weights: {weights}")
|
27 |
+
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"
|
28 |
|
29 |
# load checkpoint
|
30 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
gradio_app.py
CHANGED
@@ -8,7 +8,7 @@ import numpy as np
|
|
8 |
import spaces
|
9 |
import torch
|
10 |
|
11 |
-
from geocalib import viz2d
|
12 |
from geocalib.camera import camera_models
|
13 |
from geocalib.extractor import GeoCalib
|
14 |
from geocalib.perspective_fields import get_perspective_field
|
@@ -77,7 +77,9 @@ def format_output(results):
|
|
77 |
@spaces.GPU(duration=10)
|
78 |
def inference(img, camera_model):
|
79 |
out = model.calibrate(img.to(device), camera_model=camera_model)
|
80 |
-
save_keys = ["camera", "gravity"] + [
|
|
|
|
|
81 |
res = {k: v.cpu() for k, v in out.items() if k in save_keys}
|
82 |
# not converting to numpy results in gpu abort
|
83 |
res["up_confidence"] = out["up_confidence"].cpu().numpy()
|
@@ -100,10 +102,9 @@ def process_results(
|
|
100 |
raise gr.Error("Please upload an image first.")
|
101 |
|
102 |
img = model.load_image(image_path)
|
103 |
-
print("Running inference...")
|
104 |
start = time()
|
105 |
inference_result = inference(img, camera_model)
|
106 |
-
|
107 |
inference_result["image"] = img.cpu()
|
108 |
|
109 |
if inference_result is None:
|
@@ -158,7 +159,9 @@ def update_plot(
|
|
158 |
viz2d.plot_confidences([torch.tensor(inference_result["up_confidence"][0])], axes=[ax[0]])
|
159 |
|
160 |
if plot_latitude_confidence:
|
161 |
-
viz2d.plot_confidences(
|
|
|
|
|
162 |
|
163 |
fig.canvas.draw()
|
164 |
img = np.array(fig.canvas.renderer.buffer_rgba())
|
|
|
8 |
import spaces
|
9 |
import torch
|
10 |
|
11 |
+
from geocalib import logger, viz2d
|
12 |
from geocalib.camera import camera_models
|
13 |
from geocalib.extractor import GeoCalib
|
14 |
from geocalib.perspective_fields import get_perspective_field
|
|
|
77 |
@spaces.GPU(duration=10)
|
78 |
def inference(img, camera_model):
|
79 |
out = model.calibrate(img.to(device), camera_model=camera_model)
|
80 |
+
save_keys = ["camera", "gravity"] + [
|
81 |
+
f"{k}_uncertainty" for k in ["roll", "pitch", "vfov", "focal"]
|
82 |
+
]
|
83 |
res = {k: v.cpu() for k, v in out.items() if k in save_keys}
|
84 |
# not converting to numpy results in gpu abort
|
85 |
res["up_confidence"] = out["up_confidence"].cpu().numpy()
|
|
|
102 |
raise gr.Error("Please upload an image first.")
|
103 |
|
104 |
img = model.load_image(image_path)
|
|
|
105 |
start = time()
|
106 |
inference_result = inference(img, camera_model)
|
107 |
+
logger.info(f"Calibration took {time() - start:.2f} sec. ({camera_model})")
|
108 |
inference_result["image"] = img.cpu()
|
109 |
|
110 |
if inference_result is None:
|
|
|
159 |
viz2d.plot_confidences([torch.tensor(inference_result["up_confidence"][0])], axes=[ax[0]])
|
160 |
|
161 |
if plot_latitude_confidence:
|
162 |
+
viz2d.plot_confidences(
|
163 |
+
[torch.tensor(inference_result["latitude_confidence"][0])], axes=[ax[0]]
|
164 |
+
)
|
165 |
|
166 |
fig.canvas.draw()
|
167 |
img = np.array(fig.canvas.renderer.buffer_rgba())
|
siclib/models/extractor.py
CHANGED
@@ -22,14 +22,9 @@ class GeoCalib(nn.Module):
|
|
22 |
weights (str, optional): Weights to load. Defaults to "pinhole".
|
23 |
"""
|
24 |
super().__init__()
|
25 |
-
if weights
|
26 |
-
url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
|
27 |
-
elif weights == "distorted":
|
28 |
-
url = (
|
29 |
-
"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
|
30 |
-
)
|
31 |
-
else:
|
32 |
raise ValueError(f"Unknown weights: {weights}")
|
|
|
33 |
|
34 |
# load checkpoint
|
35 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|
|
|
22 |
weights (str, optional): Weights to load. Defaults to "pinhole".
|
23 |
"""
|
24 |
super().__init__()
|
25 |
+
if weights not in {"pinhole", "distorted"}:
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
raise ValueError(f"Unknown weights: {weights}")
|
27 |
+
url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"
|
28 |
|
29 |
# load checkpoint
|
30 |
model_dir = f"{torch.hub.get_dir()}/geocalib"
|