osherr committed on
Commit
c6dd5dc
·
verified ·
1 Parent(s): 6029516

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +173 -36
src/streamlit_app.py CHANGED
@@ -29,7 +29,12 @@ STRIDE = 4
29
 
30
  # Keep LoRA normalization from your local code
31
  IMAGENET_MEAN = (0.430, 0.411, 0.296)
32
- IMAGENET_STD = (0.213, 0.156, 0.143)
 
 
 
 
 
33
 
34
  # ============================================================
35
  # HELPERS
@@ -58,6 +63,65 @@ def preview_rgb(rgb_raw):
58
  rgb = rgb / (np.percentile(rgb, 98) + 1e-6)
59
  return np.clip(rgb, 0, 1)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # ============================================================
62
  # MODELS
63
  # ============================================================
@@ -138,9 +202,9 @@ def get_lora_params(model):
138
  params.extend(list(module.B.parameters()))
139
  return params
140
 
 
141
  # ============================================================
142
  # MODEL LOADING
143
- # Uses old app paths / loading logic
144
  # ============================================================
145
  @st.cache_resource
146
  def load_models(repo_id, dav_file, dino_file):
@@ -196,15 +260,14 @@ def load_models(repo_id, dav_file, dino_file):
196
 
197
  return dav_model, dino_model
198
 
 
199
  # ============================================================
200
  # DEPTH ANYTHING INFERENCE
201
- # Use exactly the old app style: raw DA output, no disparity2depth
202
  # ============================================================
203
  @st.cache_data(show_spinner=False)
204
  def run_dav_inference(_dav, rgb_raw, h_f, w_f):
205
  img_448 = cv2.resize(rgb_raw.transpose(1, 2, 0), (448, 448))
206
 
207
- # Use torch.tensor for HF stability
208
  dav_in = torch.tensor(img_448, device=DEVICE).permute(2, 0, 1).unsqueeze(0).float() / 255.0
209
 
210
  with torch.no_grad():
@@ -217,17 +280,17 @@ def run_dav_inference(_dav, rgb_raw, h_f, w_f):
217
  size=(h_f, w_f),
218
  mode="bilinear",
219
  align_corners=False
220
- ).squeeze(1) # [1,H,W]
221
 
222
  raw_depth_map = raw_depth[0].detach().float().cpu().numpy()
223
 
224
- # optional normalized version for more stable LoRA fitting / plotting
225
  valid = np.isfinite(raw_depth_map)
226
  raw_depth_01 = normalize_01(raw_depth_map, valid)
227
  raw_depth_01[~valid] = np.nan
228
 
229
  return raw_depth_map, raw_depth_01
230
 
 
231
  # ============================================================
232
  # MAIN LORA PIPELINE
233
  # ============================================================
@@ -253,7 +316,6 @@ def run_lora_pipeline(
253
  anchor_mask = anchor_mask_cpu.to(DEVICE)
254
  prior_gpu = prior_raw_t.to(DEVICE)
255
 
256
- # make fresh copy each run
257
  dino = copy.deepcopy(dino_base)
258
  dino = inject_lora(dino, r=lora_r, alpha=lora_alpha).to(DEVICE).train()
259
 
@@ -273,8 +335,8 @@ def run_lora_pipeline(
273
 
274
  Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
275
  prior_p = F.interpolate(prior_gpu.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
276
- rel_p = F.interpolate(rel_cpu.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
277
- mask_p = F.interpolate(anchor_mask.float().view(1, 1, H, W), size=(Hp, Wp), mode="area").flatten() > 0.5
278
 
279
  loss_hist = []
280
  prog = st.progress(0, text="Running LoRA TTO...")
@@ -297,7 +359,6 @@ def run_lora_pipeline(
297
 
298
  prog.empty()
299
 
300
- # dense stride-4 inference
301
  dino.eval()
302
  mlp_head.eval()
303
 
@@ -323,11 +384,11 @@ def run_lora_pipeline(
323
 
324
  sb_local = mlp_head(t).t().reshape(2, hc // p, wc // p)
325
 
326
- sb_acc[:, dy//stride:dy//stride + (hc//p)*(p//stride):p//stride,
327
- dx//stride:dx//stride + (wc//p)*(p//stride):p//stride] += sb_local
328
 
329
- cnt_acc[:, dy//stride:dy//stride + (hc//p)*(p//stride):p//stride,
330
- dx//stride:dx//stride + (wc//p)*(p//stride):p//stride] += 1
331
 
332
  sb_dense = sb_acc / (cnt_acc + 1e-8)
333
  offset = (p - (p // 2)) // stride + 1
@@ -350,15 +411,54 @@ def run_lora_pipeline(
350
 
351
  return final_dsm, loss_hist, anchor_mask_cpu.cpu().numpy()
352
 
 
353
  # ============================================================
354
  # UI
355
  # ============================================================
356
  st.title("Prior2DSM | LoRA")
357
 
 
 
 
 
 
 
 
 
358
  with st.sidebar:
359
  st.header("📂 Data")
360
- rgb_file = st.file_uploader("RGB Image", type=["tif", "tiff"])
361
- prior_file = st.file_uploader("LiDAR Prior", type=["tif", "tiff"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
362
 
363
  st.divider()
364
  st.write("#### LoRA / TTO")
@@ -367,20 +467,29 @@ with st.sidebar:
367
  tto_steps = st.slider("TTO steps", 10, 300, 100, step=10)
368
  tto_lr = st.select_slider("TTO LR", options=[1e-4, 3e-4, 1e-3, 3e-3], value=1e-3)
369
 
370
- if rgb_file and prior_file:
 
 
 
 
 
 
 
 
371
  dav_m, dino_base = load_models(
372
  repo_id="osherr/Prior2DSM",
373
  dav_file="depth_anything_v2_vitl.pth",
374
  dino_file="dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
375
  )
376
 
377
- with rasterio.open(BytesIO(rgb_file.read())) as src:
378
- rgb_raw = src.read([1, 2, 3])
379
- h_f, w_f = src.height, src.width
 
 
 
380
 
381
- with rasterio.open(BytesIO(prior_file.read())) as src:
382
- prior_raw = src.read(1).astype(np.float32)
383
- prior_meta = src.meta.copy()
384
 
385
  with st.spinner("Generating relative depth with Depth Anything V2..."):
386
  rel_depth_map, rel_depth_01 = run_dav_inference(dav_m, rgb_raw, h_f, w_f)
@@ -388,15 +497,45 @@ if rgb_file and prior_file:
388
  st.subheader("1. ROI Selection")
389
 
390
  viz_rgb = preview_rgb(rgb_raw)
391
-
392
  col_img, col_ctrl = st.columns([1.2, 0.8])
393
 
394
  with col_ctrl:
395
- x_center = st.slider("X center", 0, w_f - 1, w_f // 2)
396
- y_center = st.slider("Y center", 0, h_f - 1, h_f // 2)
397
- bbox_size = st.slider("BBox Size (px)", 50, min(400, min(h_f, w_f)), 200)
398
- use_normalized_rel = st.checkbox("Use normalized relative depth for LoRA", value=True)
399
- run_btn = st.button("🚀 Run LoRA Pipeline", type="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
 
401
  half_s = bbox_size // 2
402
  x1, x2 = max(0, x_center - half_s), min(w_f, x_center + half_s)
@@ -406,12 +545,8 @@ if rgb_file and prior_file:
406
  bbox_mask[y1:y2, x1:x2] = True
407
 
408
  with col_img:
409
- fig_roi, ax_roi = plt.subplots(figsize=(6, 6))
410
- ax_roi.imshow(viz_rgb)
411
- ax_roi.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", lw=2))
412
- ax_roi.set_title("ROI Preview")
413
- ax_roi.axis("off")
414
- st.pyplot(fig_roi)
415
 
416
  if run_btn:
417
  rel_for_lora = rel_depth_01 if use_normalized_rel else rel_depth_map
@@ -512,4 +647,6 @@ if rgb_file and prior_file:
512
  "Download Georeferenced DSM",
513
  out_buf.getvalue(),
514
  file_name="lora_refined_dsm_georef.tif"
515
- )
 
 
 
29
 
30
  # Keep LoRA normalization from your local code
31
  IMAGENET_MEAN = (0.430, 0.411, 0.296)
32
+ IMAGENET_STD = (0.213, 0.156, 0.143)
33
+
34
+ # Example files inside the HF Space repo
35
+ EXAMPLE_RGB_FILENAME = "examples/example_rgb.tif"
36
+ EXAMPLE_PRIOR_FILENAME = "examples/example_prior.tif"
37
+
38
 
39
  # ============================================================
40
  # HELPERS
 
63
  rgb = rgb / (np.percentile(rgb, 98) + 1e-6)
64
  return np.clip(rgb, 0, 1)
65
 
66
+
67
+ def draw_roi_preview(viz_rgb, x1, y1, x2, y2):
68
+ preview = (np.clip(viz_rgb, 0, 1) * 255).astype(np.uint8).copy()
69
+ cv2.rectangle(preview, (x1, y1), (x2, y2), (255, 0, 0), 2)
70
+ return preview
71
+
72
+
73
+ @st.cache_data(show_spinner=False)
74
+ def load_tiff_from_hf(repo_id, filename, repo_type="space"):
75
+ return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
76
+
77
+
78
+ def read_rgb_tiff(path_or_bytes):
79
+ if isinstance(path_or_bytes, (str, os.PathLike)):
80
+ with rasterio.open(path_or_bytes) as src:
81
+ rgb_raw = src.read([1, 2, 3])
82
+ h_f, w_f = src.height, src.width
83
+ meta = src.meta.copy()
84
+ else:
85
+ with rasterio.open(BytesIO(path_or_bytes)) as src:
86
+ rgb_raw = src.read([1, 2, 3])
87
+ h_f, w_f = src.height, src.width
88
+ meta = src.meta.copy()
89
+ return rgb_raw, h_f, w_f, meta
90
+
91
+
92
+ def read_prior_tiff(path_or_bytes):
93
+ if isinstance(path_or_bytes, (str, os.PathLike)):
94
+ with rasterio.open(path_or_bytes) as src:
95
+ prior_raw = src.read(1).astype(np.float32)
96
+ meta = src.meta.copy()
97
+ else:
98
+ with rasterio.open(BytesIO(path_or_bytes)) as src:
99
+ prior_raw = src.read(1).astype(np.float32)
100
+ meta = src.meta.copy()
101
+ return prior_raw, meta
102
+
103
+
104
+ def init_roi_state(h_f, w_f):
105
+ if "x_center" not in st.session_state:
106
+ st.session_state["x_center"] = w_f // 2
107
+ if "y_center" not in st.session_state:
108
+ st.session_state["y_center"] = h_f // 2
109
+ if "bbox_size" not in st.session_state:
110
+ st.session_state["bbox_size"] = min(200, min(h_f, w_f))
111
+ if "use_normalized_rel" not in st.session_state:
112
+ st.session_state["use_normalized_rel"] = True
113
+ if "loaded_shape" not in st.session_state:
114
+ st.session_state["loaded_shape"] = (h_f, w_f)
115
+
116
+ prev_shape = st.session_state.get("loaded_shape", None)
117
+ if prev_shape != (h_f, w_f):
118
+ st.session_state["x_center"] = w_f // 2
119
+ st.session_state["y_center"] = h_f // 2
120
+ st.session_state["bbox_size"] = min(200, min(h_f, w_f))
121
+ st.session_state["use_normalized_rel"] = True
122
+ st.session_state["loaded_shape"] = (h_f, w_f)
123
+
124
+
125
  # ============================================================
126
  # MODELS
127
  # ============================================================
 
202
  params.extend(list(module.B.parameters()))
203
  return params
204
 
205
+
206
  # ============================================================
207
  # MODEL LOADING
 
208
  # ============================================================
209
  @st.cache_resource
210
  def load_models(repo_id, dav_file, dino_file):
 
260
 
261
  return dav_model, dino_model
262
 
263
+
264
  # ============================================================
265
  # DEPTH ANYTHING INFERENCE
 
266
  # ============================================================
267
  @st.cache_data(show_spinner=False)
268
  def run_dav_inference(_dav, rgb_raw, h_f, w_f):
269
  img_448 = cv2.resize(rgb_raw.transpose(1, 2, 0), (448, 448))
270
 
 
271
  dav_in = torch.tensor(img_448, device=DEVICE).permute(2, 0, 1).unsqueeze(0).float() / 255.0
272
 
273
  with torch.no_grad():
 
280
  size=(h_f, w_f),
281
  mode="bilinear",
282
  align_corners=False
283
+ ).squeeze(1)
284
 
285
  raw_depth_map = raw_depth[0].detach().float().cpu().numpy()
286
 
 
287
  valid = np.isfinite(raw_depth_map)
288
  raw_depth_01 = normalize_01(raw_depth_map, valid)
289
  raw_depth_01[~valid] = np.nan
290
 
291
  return raw_depth_map, raw_depth_01
292
 
293
+
294
  # ============================================================
295
  # MAIN LORA PIPELINE
296
  # ============================================================
 
316
  anchor_mask = anchor_mask_cpu.to(DEVICE)
317
  prior_gpu = prior_raw_t.to(DEVICE)
318
 
 
319
  dino = copy.deepcopy(dino_base)
320
  dino = inject_lora(dino, r=lora_r, alpha=lora_alpha).to(DEVICE).train()
321
 
 
335
 
336
  Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
337
  prior_p = F.interpolate(prior_gpu.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
338
+ rel_p = F.interpolate(rel_cpu.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
339
+ mask_p = F.interpolate(anchor_mask.float().view(1, 1, H, W), size=(Hp, Wp), mode="area").flatten() > 0.5
340
 
341
  loss_hist = []
342
  prog = st.progress(0, text="Running LoRA TTO...")
 
359
 
360
  prog.empty()
361
 
 
362
  dino.eval()
363
  mlp_head.eval()
364
 
 
384
 
385
  sb_local = mlp_head(t).t().reshape(2, hc // p, wc // p)
386
 
387
+ sb_acc[:, dy // stride:dy // stride + (hc // p) * (p // stride):p // stride,
388
+ dx // stride:dx // stride + (wc // p) * (p // stride):p // stride] += sb_local
389
 
390
+ cnt_acc[:, dy // stride:dy // stride + (hc // p) * (p // stride):p // stride,
391
+ dx // stride:dx // stride + (wc // p) * (p // stride):p // stride] += 1
392
 
393
  sb_dense = sb_acc / (cnt_acc + 1e-8)
394
  offset = (p - (p // 2)) // stride + 1
 
411
 
412
  return final_dsm, loss_hist, anchor_mask_cpu.cpu().numpy()
413
 
414
+
415
  # ============================================================
416
  # UI
417
  # ============================================================
418
  st.title("Prior2DSM | LoRA")
419
 
420
+ st.markdown(
421
+ f"""
422
+ **Example TIFFs**
423
+ - [Download example RGB TIFF](https://huggingface.co/spaces/osherr/Prior2DSM/resolve/main/{EXAMPLE_RGB_FILENAME})
424
+ - [Download example Prior TIFF](https://huggingface.co/spaces/osherr/Prior2DSM/resolve/main/{EXAMPLE_PRIOR_FILENAME})
425
+ """
426
+ )
427
+
428
  with st.sidebar:
429
  st.header("📂 Data")
430
+
431
+ data_mode = st.radio(
432
+ "Data source",
433
+ ["Upload TIFFs", "Use example TIFFs"],
434
+ index=0
435
+ )
436
+
437
+ rgb_file = None
438
+ prior_file = None
439
+ rgb_example_path = None
440
+ prior_example_path = None
441
+
442
+ if data_mode == "Upload TIFFs":
443
+ rgb_file = st.file_uploader("RGB Image", type=["tif", "tiff"])
444
+ prior_file = st.file_uploader("LiDAR Prior", type=["tif", "tiff"])
445
+ else:
446
+ st.caption("Load demo RGB/Prior TIFFs from the Hugging Face Space.")
447
+ if st.button("Load example TIFFs"):
448
+ st.session_state["use_examples"] = True
449
+
450
+ if st.session_state.get("use_examples", False):
451
+ rgb_example_path = load_tiff_from_hf(
452
+ repo_id="osherr/Prior2DSM",
453
+ filename=EXAMPLE_RGB_FILENAME,
454
+ repo_type="space"
455
+ )
456
+ prior_example_path = load_tiff_from_hf(
457
+ repo_id="osherr/Prior2DSM",
458
+ filename=EXAMPLE_PRIOR_FILENAME,
459
+ repo_type="space"
460
+ )
461
+ st.success("Example TIFFs loaded.")
462
 
463
  st.divider()
464
  st.write("#### LoRA / TTO")
 
467
  tto_steps = st.slider("TTO steps", 10, 300, 100, step=10)
468
  tto_lr = st.select_slider("TTO LR", options=[1e-4, 3e-4, 1e-3, 3e-3], value=1e-3)
469
 
470
+ has_uploaded = (rgb_file is not None and prior_file is not None)
471
+ has_examples = (
472
+ data_mode == "Use example TIFFs"
473
+ and st.session_state.get("use_examples", False)
474
+ and rgb_example_path is not None
475
+ and prior_example_path is not None
476
+ )
477
+
478
+ if has_uploaded or has_examples:
479
  dav_m, dino_base = load_models(
480
  repo_id="osherr/Prior2DSM",
481
  dav_file="depth_anything_v2_vitl.pth",
482
  dino_file="dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
483
  )
484
 
485
+ if has_uploaded:
486
+ rgb_raw, h_f, w_f, _ = read_rgb_tiff(rgb_file.read())
487
+ prior_raw, prior_meta = read_prior_tiff(prior_file.read())
488
+ else:
489
+ rgb_raw, h_f, w_f, _ = read_rgb_tiff(rgb_example_path)
490
+ prior_raw, prior_meta = read_prior_tiff(prior_example_path)
491
 
492
+ init_roi_state(h_f, w_f)
 
 
493
 
494
  with st.spinner("Generating relative depth with Depth Anything V2..."):
495
  rel_depth_map, rel_depth_01 = run_dav_inference(dav_m, rgb_raw, h_f, w_f)
 
497
  st.subheader("1. ROI Selection")
498
 
499
  viz_rgb = preview_rgb(rgb_raw)
 
500
  col_img, col_ctrl = st.columns([1.2, 0.8])
501
 
502
  with col_ctrl:
503
+ with st.form("roi_form", clear_on_submit=False):
504
+ x_center_form = st.slider(
505
+ "X center",
506
+ 0, w_f - 1,
507
+ int(st.session_state["x_center"])
508
+ )
509
+ y_center_form = st.slider(
510
+ "Y center",
511
+ 0, h_f - 1,
512
+ int(st.session_state["y_center"])
513
+ )
514
+ bbox_size_form = st.slider(
515
+ "BBox Size (px)",
516
+ 50, min(400, min(h_f, w_f)),
517
+ int(st.session_state["bbox_size"])
518
+ )
519
+ use_normalized_rel_form = st.checkbox(
520
+ "Use normalized relative depth for LoRA",
521
+ value=bool(st.session_state["use_normalized_rel"])
522
+ )
523
+
524
+ c1, c2 = st.columns(2)
525
+ with c1:
526
+ update_roi = st.form_submit_button("Update ROI")
527
+ with c2:
528
+ run_btn = st.form_submit_button("🚀 Run LoRA Pipeline", type="primary")
529
+
530
+ if update_roi or run_btn:
531
+ st.session_state["x_center"] = x_center_form
532
+ st.session_state["y_center"] = y_center_form
533
+ st.session_state["bbox_size"] = bbox_size_form
534
+ st.session_state["use_normalized_rel"] = use_normalized_rel_form
535
+ x_center = int(st.session_state["x_center"])
536
+ y_center = int(st.session_state["y_center"])
537
+ bbox_size = int(st.session_state["bbox_size"])
538
+ use_normalized_rel = bool(st.session_state["use_normalized_rel"])
539
 
540
  half_s = bbox_size // 2
541
  x1, x2 = max(0, x_center - half_s), min(w_f, x_center + half_s)
 
545
  bbox_mask[y1:y2, x1:x2] = True
546
 
547
  with col_img:
548
+ roi_preview = draw_roi_preview(viz_rgb, x1, y1, x2, y2)
549
+ st.image(roi_preview, caption="ROI Preview", use_container_width=True)
 
 
 
 
550
 
551
  if run_btn:
552
  rel_for_lora = rel_depth_01 if use_normalized_rel else rel_depth_map
 
647
  "Download Georeferenced DSM",
648
  out_buf.getvalue(),
649
  file_name="lora_refined_dsm_georef.tif"
650
+ )
651
+ else:
652
+ st.info("Upload RGB and Prior TIFFs, or switch to example TIFFs in the sidebar.")