nicolas-dufour commited on
Commit
cf60dfb
Β·
1 Parent(s): 99dc3ef

initial commit

Browse files
Files changed (3) hide show
  1. app.py +45 -31
  2. models/samplers/riemannian_flow_sampler.py +3 -2
  3. pipe.py +22 -18
app.py CHANGED
@@ -52,11 +52,32 @@ def predict_location(image, model_name, cfg=0.0, num_samples=256):
52
 
53
  pipe = PIPES[model_name]
54
 
55
- # Get regular predictions
56
- predicted_gps = pipe(img, batch_size=num_samples, cfg=cfg, num_steps=16)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # Get single high-confidence prediction
 
59
  high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
 
 
 
 
 
60
  return {
61
  "lat": predicted_gps[:, 0].astype(float).tolist(),
62
  "lon": predicted_gps[:, 1].astype(float).tolist(),
@@ -234,14 +255,13 @@ def main():
234
  st.markdown("</div>", unsafe_allow_html=True)
235
 
236
  if st.button("πŸ” Predict Location", key="predict_upload"):
237
- with st.spinner("🌍 Analyzing image and predicting locations..."):
238
- predictions = predict_location(
239
- original_image,
240
- model_name=model_name,
241
- cfg=cfg_value,
242
- num_samples=num_samples,
243
- )
244
- st.session_state["predictions"] = predictions
245
 
246
  with tab2:
247
  url = st.text_input("Enter image URL:", key="image_url")
@@ -261,16 +281,13 @@ def main():
261
  st.markdown("</div>", unsafe_allow_html=True)
262
 
263
  if st.button("πŸ” Predict Location", key="predict_url"):
264
- with st.spinner(
265
- "🌍 Analyzing image and predicting locations..."
266
- ):
267
- predictions = predict_location(
268
- image,
269
- model_name=model_name,
270
- cfg=cfg_value,
271
- num_samples=num_samples,
272
- )
273
- st.session_state["predictions"] = predictions
274
 
275
  with tab3:
276
  examples = load_example_images()
@@ -290,17 +307,14 @@ def main():
290
  help=f"Click to predict location for {name}",
291
  use_container_width=True,
292
  ):
293
- with st.spinner(
294
- "🌍 Analyzing image and predicting locations..."
295
- ):
296
- predictions = predict_location(
297
- original_image,
298
- model_name=model_name,
299
- cfg=cfg_value,
300
- num_samples=num_samples,
301
- )
302
- st.session_state["predictions"] = predictions
303
- st.rerun()
304
 
305
  st.image(display_image, caption=name, use_container_width=True)
306
  st.markdown("</div>", unsafe_allow_html=True)
 
52
 
53
  pipe = PIPES[model_name]
54
 
55
+ # Create a progress bar
56
+ progress_bar = st.progress(0)
57
+ status_text = st.empty()
58
+
59
+ def update_progress(step, total_steps):
60
+ progress = float(step) / float(total_steps)
61
+ progress_bar.progress(progress)
62
+ status_text.text(f"Sampling step {step + 1}/{total_steps}")
63
+
64
+ # Get regular predictions with progress updates
65
+ predicted_gps = pipe(
66
+ img,
67
+ batch_size=num_samples,
68
+ cfg=cfg,
69
+ num_steps=16,
70
+ callback=update_progress
71
+ )
72
 
73
  # Get single high-confidence prediction
74
+ status_text.text("Generating high-confidence prediction...")
75
  high_conf_gps = pipe(img, batch_size=1, cfg=2.0, num_steps=16)
76
+
77
+ # Clear the status text and progress bar
78
+ status_text.empty()
79
+ progress_bar.empty()
80
+
81
  return {
82
  "lat": predicted_gps[:, 0].astype(float).tolist(),
83
  "lon": predicted_gps[:, 1].astype(float).tolist(),
 
255
  st.markdown("</div>", unsafe_allow_html=True)
256
 
257
  if st.button("πŸ” Predict Location", key="predict_upload"):
258
+ predictions = predict_location(
259
+ original_image,
260
+ model_name=model_name,
261
+ cfg=cfg_value,
262
+ num_samples=num_samples,
263
+ )
264
+ st.session_state["predictions"] = predictions
 
265
 
266
  with tab2:
267
  url = st.text_input("Enter image URL:", key="image_url")
 
281
  st.markdown("</div>", unsafe_allow_html=True)
282
 
283
  if st.button("πŸ” Predict Location", key="predict_url"):
284
+ predictions = predict_location(
285
+ image,
286
+ model_name=model_name,
287
+ cfg=cfg_value,
288
+ num_samples=num_samples,
289
+ )
290
+ st.session_state["predictions"] = predictions
 
 
 
291
 
292
  with tab3:
293
  examples = load_example_images()
 
307
  help=f"Click to predict location for {name}",
308
  use_container_width=True,
309
  ):
310
+ predictions = predict_location(
311
+ original_image,
312
+ model_name=model_name,
313
+ cfg=cfg_value,
314
+ num_samples=num_samples,
315
+ )
316
+ st.session_state["predictions"] = predictions
317
+ st.rerun()
 
 
 
318
 
319
  st.image(display_image, caption=name, use_container_width=True)
320
  st.markdown("</div>", unsafe_allow_html=True)
models/samplers/riemannian_flow_sampler.py CHANGED
@@ -13,6 +13,7 @@ def riemannian_flow_sampler(
13
  cfg_rate=0,
14
  generator=None,
15
  return_trajectories=False,
 
16
  ):
17
  if scheduler is None:
18
  raise ValueError("Scheduler must be provided")
@@ -35,13 +36,13 @@ def riemannian_flow_sampler(
35
  if cfg_rate > 0 and conditioning_keys is not None:
36
  stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
37
  stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
38
- denoised_all = net(stacked_batch)
39
  denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
40
  denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
41
  else:
42
  batch["y"] = x_cur
43
  batch["gamma"] = gamma_now.expand(x_cur.shape[0])
44
- denoised = net(batch)
45
 
46
  dt = gamma_next - gamma_now
47
  x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised)
 
13
  cfg_rate=0,
14
  generator=None,
15
  return_trajectories=False,
16
+ callback=None,
17
  ):
18
  if scheduler is None:
19
  raise ValueError("Scheduler must be provided")
 
36
  if cfg_rate > 0 and conditioning_keys is not None:
37
  stacked_batch["y"] = torch.cat([x_cur, x_cur], dim=0)
38
  stacked_batch["gamma"] = gamma_now.expand(x_cur.shape[0] * 2)
39
+ denoised_all = net(stacked_batch, current_step=step)
40
  denoised_cond, denoised_uncond = denoised_all.chunk(2, dim=0)
41
  denoised = denoised_cond * (1 + cfg_rate) - denoised_uncond * cfg_rate
42
  else:
43
  batch["y"] = x_cur
44
  batch["gamma"] = gamma_now.expand(x_cur.shape[0])
45
+ denoised = net(batch, current_step=step)
46
 
47
  dt = gamma_next - gamma_now
48
  x_next = x_cur + dt * denoised # manifold.expmap(x_cur, dt * denoised)
pipe.py CHANGED
@@ -216,6 +216,7 @@ class PlonkPipeline:
216
  scheduler=None,
217
  cfg=0,
218
  generator=None,
 
219
  ):
220
  """Sample from the model given conditioning.
221
 
@@ -228,6 +229,7 @@ class PlonkPipeline:
228
  scheduler: Custom scheduler function (uses default if not provided)
229
  cfg: Classifier-free guidance scale (default 15)
230
  generator: Random number generator
 
231
 
232
  Returns:
233
  Sampled GPS coordinates after postprocessing
@@ -264,26 +266,28 @@ class PlonkPipeline:
264
  sampler = self.sampler
265
  if scheduler is None:
266
  scheduler = self.scheduler
 
267
  # Sample from model
268
  if num_steps is None:
269
- output = sampler(
270
- self.model,
271
- batch,
272
- conditioning_keys="emb",
273
- scheduler=scheduler,
274
- cfg_rate=cfg,
275
- generator=generator,
276
- )
277
- else:
278
- output = sampler(
279
- self.model,
280
- batch,
281
- conditioning_keys="emb",
282
- scheduler=scheduler,
283
- num_steps=num_steps,
284
- cfg_rate=cfg,
285
- generator=generator,
286
- )
 
287
 
288
  # Apply postprocessing and return
289
  output = self.postprocessing(output)
 
216
  scheduler=None,
217
  cfg=0,
218
  generator=None,
219
+ callback=None,
220
  ):
221
  """Sample from the model given conditioning.
222
 
 
229
  scheduler: Custom scheduler function (uses default if not provided)
230
  cfg: Classifier-free guidance scale (default 15)
231
  generator: Random number generator
232
+ callback: Optional callback function to report progress (step, total_steps)
233
 
234
  Returns:
235
  Sampled GPS coordinates after postprocessing
 
266
  sampler = self.sampler
267
  if scheduler is None:
268
  scheduler = self.scheduler
269
+
270
  # Sample from model
271
  if num_steps is None:
272
+ num_steps = 16 # Default number of steps
273
+
274
+ # Create a wrapper for the model that updates progress
275
+ def model_with_progress(*args, **kwargs):
276
+ step = kwargs.pop('current_step', 0)
277
+ if callback:
278
+ callback(step, num_steps)
279
+ return self.model(*args, **kwargs)
280
+
281
+ output = sampler(
282
+ model_with_progress,
283
+ batch,
284
+ conditioning_keys="emb",
285
+ scheduler=scheduler,
286
+ num_steps=num_steps,
287
+ cfg_rate=cfg,
288
+ generator=generator,
289
+ callback=callback,
290
+ )
291
 
292
  # Apply postprocessing and return
293
  output = self.postprocessing(output)