Spaces:
Running
Running
nicolas-dufour
commited on
Commit
Β·
cf60dfb
1
Parent(s):
99dc3ef
initial commit
Browse files- app.py +45 -31
- models/samplers/riemannian_flow_sampler.py +3 -2
- pipe.py +22 -18
app.py
CHANGED
@@ -52,11 +52,32 @@ def predict_location(image, model_name, cfg=0.0, num_samples=256):
|
|
52 |
|
53 |
pipe = PIPES[model_name]
|
54 |
|
55 |
-
#
|
56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
|
58 |
# Get single high-confidence prediction
|
|
|
59 |
high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
|
|
|
|
|
|
|
|
|
|
|
60 |
return {
|
61 |
"lat": predicted_gps[:, 0].astype(float).tolist(),
|
62 |
"lon": predicted_gps[:, 1].astype(float).tolist(),
|
@@ -234,14 +255,13 @@ def main():
|
|
234 |
st.markdown("</div>", unsafe_allow_html=True)
|
235 |
|
236 |
if st.button("π Predict Location", key="predict_upload"):
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
st.session_state["predictions"] = predictions
|
245 |
|
246 |
with tab2:
|
247 |
url = st.text_input("Enter image URL:", key="image_url")
|
@@ -261,16 +281,13 @@ def main():
|
|
261 |
st.markdown("</div>", unsafe_allow_html=True)
|
262 |
|
263 |
if st.button("π Predict Location", key="predict_url"):
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
num_samples=num_samples,
|
272 |
-
)
|
273 |
-
st.session_state["predictions"] = predictions
|
274 |
|
275 |
with tab3:
|
276 |
examples = load_example_images()
|
@@ -290,17 +307,14 @@ def main():
|
|
290 |
help=f"Click to predict location for {name}",
|
291 |
use_container_width=True,
|
292 |
):
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
-
)
|
302 |
-
st.session_state["predictions"] = predictions
|
303 |
-
st.rerun()
|
304 |
|
305 |
st.image(display_image, caption=name, use_container_width=True)
|
306 |
st.markdown("</div>", unsafe_allow_html=True)
|
|
|
52 |
|
53 |
pipe = PIPES[model_name]
|
54 |
|
55 |
+
# Create a progress bar
|
56 |
+
progress_bar = st.progress(0)
|
57 |
+
status_text = st.empty()
|
58 |
+
|
59 |
+
def update_progress(step, total_steps):
|
60 |
+
progress = float(step) / float(total_steps)
|
61 |
+
progress_bar.progress(progress)
|
62 |
+
status_text.text(f"Sampling step {step + 1}/{total_steps}")
|
63 |
+
|
64 |
+
# Get regular predictions with progress updates
|
65 |
+
predicted_gps = pipe(
|
66 |
+
img,
|
67 |
+
batch_size=num_samples,
|
68 |
+
cfg=cfg,
|
69 |
+
num_steps=16,
|
70 |
+
callback=update_progress
|
71 |
+
)
|
72 |
|
73 |
# Get single high-confidence prediction
|
74 |
+
status_text.text("Generating high-confidence prediction...")
|
75 |
high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
|
76 |
+
|
77 |
+
# Clear the status text and progress bar
|
78 |
+
status_text.empty()
|
79 |
+
progress_bar.empty()
|
80 |
+
|
81 |
return {
|
82 |
"lat": predicted_gps[:, 0].astype(float).tolist(),
|
83 |
"lon": predicted_gps[:, 1].astype(float).tolist(),
|
|
|
255 |
st.markdown("</div>", unsafe_allow_html=True)
|
256 |
|
257 |
if st.button("π Predict Location", key="predict_upload"):
|
258 |
+
predictions = predict_location(
|
259 |
+
original_image,
|
260 |
+
model_name=model_name,
|
261 |
+
cfg=cfg_value,
|
262 |
+
num_samples=num_samples,
|
263 |
+
)
|
264 |
+
st.session_state["predictions"] = predictions
|
|
|
265 |
|
266 |
with tab2:
|
267 |
url = st.text_input("Enter image URL:", key="image_url")
|
|
|
281 |
st.markdown("</div>", unsafe_allow_html=True)
|
282 |
|
283 |
if st.button("π Predict Location", key="predict_url"):
|
284 |
+
predictions = predict_location(
|
285 |
+
image,
|
286 |
+
model_name=model_name,
|
287 |
+
cfg=cfg_value,
|
288 |
+
num_samples=num_samples,
|
289 |
+
)
|
290 |
+
st.session_state["predictions"] = predictions
|
|
|
|
|
|
|
291 |
|
292 |
with tab3:
|
293 |
examples = load_example_images()
|
|
|
307 |
help=f"Click to predict location for {name}",
|
308 |
use_container_width=True,
|
309 |
):
|
310 |
+
predictions = predict_location(
|
311 |
+
original_image,
|
312 |
+
model_name=model_name,
|
313 |
+
cfg=cfg_value,
|
314 |
+
num_samples=num_samples,
|
315 |
+
)
|
316 |
+
st.session_state["predictions"] = predictions
|
317 |
+
st.rerun()
|
|
|
|
|
|
|
318 |
|
319 |
st.image(display_image, caption=name, use_container_width=True)
|
320 |
st.markdown("</div>", unsafe_allow_html=True)
|
models/samplers/riemannian_flow_sampler.py
CHANGED
@@ -13,6 +13,7 @@ def riemannian_flow_sampler(
|
|
13 |
cfg_rate=0,
|
14 |
generator=None,
|
15 |
return_trajectories=False,
|
|
|
16 |
):
|
17 |
if scheduler is None:
|
18 |
raise ValueError("Scheduler must be provided")
|
@@ -35,13 +36,13 @@ def riemannian_flow_sampler(
|
|
35 |
if cfg_rate > 0 and conditioning_keys is not None:
|
36 |
stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
|
37 |
stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
|
38 |
-
denoised_all = net(stacked_batch)
|
39 |
denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
|
40 |
denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
|
41 |
else:
|
42 |
batch["y"] = x_cur
|
43 |
batch["gamma"] = gamma_now.expand(x_cur.shape[0])
|
44 |
-
denoised = net(batch)
|
45 |
|
46 |
dt = gamma_next - gamma_now
|
47 |
x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised)
|
|
|
13 |
cfg_rate=0,
|
14 |
generator=None,
|
15 |
return_trajectories=False,
|
16 |
+
callback=None,
|
17 |
):
|
18 |
if scheduler is None:
|
19 |
raise ValueError("Scheduler must be provided")
|
|
|
36 |
if cfg_rate > 0 and conditioning_keys is not None:
|
37 |
stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
|
38 |
stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
|
39 |
+
denoised_all = net(stacked_batch, current_step=step)
|
40 |
denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
|
41 |
denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
|
42 |
else:
|
43 |
batch["y"] = x_cur
|
44 |
batch["gamma"] = gamma_now.expand(x_cur.shape[0])
|
45 |
+
denoised = net(batch, current_step=step)
|
46 |
|
47 |
dt = gamma_next - gamma_now
|
48 |
x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised)
|
pipe.py
CHANGED
@@ -216,6 +216,7 @@ class PlonkPipeline:
|
|
216 |
scheduler=None,
|
217 |
cfg=0,
|
218 |
generator=None,
|
|
|
219 |
):
|
220 |
"""Sample from the model given conditioning.
|
221 |
|
@@ -228,6 +229,7 @@ class PlonkPipeline:
|
|
228 |
scheduler: Custom scheduler function (uses default if not provided)
|
229 |
cfg: Classifier-free guidance scale (default 15)
|
230 |
generator: Random number generator
|
|
|
231 |
|
232 |
Returns:
|
233 |
Sampled GPS coordinates after postprocessing
|
@@ -264,26 +266,28 @@ class PlonkPipeline:
|
|
264 |
sampler = self.sampler
|
265 |
if scheduler is None:
|
266 |
scheduler = self.scheduler
|
|
|
267 |
# Sample from model
|
268 |
if num_steps is None:
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
)
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
|
|
287 |
|
288 |
# Apply postprocessing and return
|
289 |
output = self.postprocessing(output)
|
|
|
216 |
scheduler=None,
|
217 |
cfg=0,
|
218 |
generator=None,
|
219 |
+
callback=None,
|
220 |
):
|
221 |
"""Sample from the model given conditioning.
|
222 |
|
|
|
229 |
scheduler: Custom scheduler function (uses default if not provided)
|
230 |
cfg: Classifier-free guidance scale (default 15)
|
231 |
generator: Random number generator
|
232 |
+
callback: Optional callback function to report progress (step, total_steps)
|
233 |
|
234 |
Returns:
|
235 |
Sampled GPS coordinates after postprocessing
|
|
|
266 |
sampler = self.sampler
|
267 |
if scheduler is None:
|
268 |
scheduler = self.scheduler
|
269 |
+
|
270 |
# Sample from model
|
271 |
if num_steps is None:
|
272 |
+
num_steps = 16 # Default number of steps
|
273 |
+
|
274 |
+
# Create a wrapper for the model that updates progress
|
275 |
+
def model_with_progress(*args, **kwargs):
|
276 |
+
step = kwargs.pop('current_step', 0)
|
277 |
+
if callback:
|
278 |
+
callback(step, num_steps)
|
279 |
+
return self.model(*args, **kwargs)
|
280 |
+
|
281 |
+
output = sampler(
|
282 |
+
model_with_progress,
|
283 |
+
batch,
|
284 |
+
conditioning_keys="emb",
|
285 |
+
scheduler=scheduler,
|
286 |
+
num_steps=num_steps,
|
287 |
+
cfg_rate=cfg,
|
288 |
+
generator=generator,
|
289 |
+
callback=callback,
|
290 |
+
)
|
291 |
|
292 |
# Apply postprocessing and return
|
293 |
output = self.postprocessing(output)
|