Update src/streamlit_app.py

src/streamlit_app.py  CHANGED  (+173 −36)
@@ -29,7 +29,12 @@ STRIDE = 4
 
 # Keep LoRA normalization from your local code
 IMAGENET_MEAN = (0.430, 0.411, 0.296)
-IMAGENET_STD
+IMAGENET_STD = (0.213, 0.156, 0.143)
+
+# Example files inside the HF Space repo
+EXAMPLE_RGB_FILENAME = "examples/example_rgb.tif"
+EXAMPLE_PRIOR_FILENAME = "examples/example_prior.tif"
+
 
 # ============================================================
 # HELPERS
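For context on the two constant tuples: they are channel-wise statistics applied to RGB crops before the backbone sees them. A minimal sketch of how such constants are typically used; `normalize_rgb` is a hypothetical helper, not part of this diff:

```python
import torch

IMAGENET_MEAN = (0.430, 0.411, 0.296)
IMAGENET_STD = (0.213, 0.156, 0.143)

def normalize_rgb(x: torch.Tensor) -> torch.Tensor:
    # x: (B, 3, H, W) float tensor already scaled to [0, 1]
    mean = torch.tensor(IMAGENET_MEAN, device=x.device).view(1, 3, 1, 1)
    std = torch.tensor(IMAGENET_STD, device=x.device).view(1, 3, 1, 1)
    return (x - mean) / std
```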
@@ -58,6 +63,65 @@ def preview_rgb(rgb_raw):
     rgb = rgb / (np.percentile(rgb, 98) + 1e-6)
     return np.clip(rgb, 0, 1)
 
+
+def draw_roi_preview(viz_rgb, x1, y1, x2, y2):
+    preview = (np.clip(viz_rgb, 0, 1) * 255).astype(np.uint8).copy()
+    cv2.rectangle(preview, (x1, y1), (x2, y2), (255, 0, 0), 2)
+    return preview
+
+
+@st.cache_data(show_spinner=False)
+def load_tiff_from_hf(repo_id, filename, repo_type="space"):
+    return hf_hub_download(repo_id=repo_id, filename=filename, repo_type=repo_type)
+
+
+def read_rgb_tiff(path_or_bytes):
+    if isinstance(path_or_bytes, (str, os.PathLike)):
+        with rasterio.open(path_or_bytes) as src:
+            rgb_raw = src.read([1, 2, 3])
+            h_f, w_f = src.height, src.width
+            meta = src.meta.copy()
+    else:
+        with rasterio.open(BytesIO(path_or_bytes)) as src:
+            rgb_raw = src.read([1, 2, 3])
+            h_f, w_f = src.height, src.width
+            meta = src.meta.copy()
+    return rgb_raw, h_f, w_f, meta
+
+
+def read_prior_tiff(path_or_bytes):
+    if isinstance(path_or_bytes, (str, os.PathLike)):
+        with rasterio.open(path_or_bytes) as src:
+            prior_raw = src.read(1).astype(np.float32)
+            meta = src.meta.copy()
+    else:
+        with rasterio.open(BytesIO(path_or_bytes)) as src:
+            prior_raw = src.read(1).astype(np.float32)
+            meta = src.meta.copy()
+    return prior_raw, meta
+
+
+def init_roi_state(h_f, w_f):
+    if "x_center" not in st.session_state:
+        st.session_state["x_center"] = w_f // 2
+    if "y_center" not in st.session_state:
+        st.session_state["y_center"] = h_f // 2
+    if "bbox_size" not in st.session_state:
+        st.session_state["bbox_size"] = min(200, min(h_f, w_f))
+    if "use_normalized_rel" not in st.session_state:
+        st.session_state["use_normalized_rel"] = True
+    if "loaded_shape" not in st.session_state:
+        st.session_state["loaded_shape"] = (h_f, w_f)
+
+    prev_shape = st.session_state.get("loaded_shape", None)
+    if prev_shape != (h_f, w_f):
+        st.session_state["x_center"] = w_f // 2
+        st.session_state["y_center"] = h_f // 2
+        st.session_state["bbox_size"] = min(200, min(h_f, w_f))
+        st.session_state["use_normalized_rel"] = True
+        st.session_state["loaded_shape"] = (h_f, w_f)
+
+
 # ============================================================
 # MODELS
 # ============================================================
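The new readers accept either a filesystem path (from `hf_hub_download`) or raw bytes (from `st.file_uploader`), returning identical shapes either way. A usage sketch, assuming rasterio is installed and the path shown is illustrative:

```python
# Path branch (e.g., a file downloaded from the Space repo).
rgb_raw, h_f, w_f, meta = read_rgb_tiff("examples/example_rgb.tif")

# Bytes branch (e.g., st.file_uploader().read()).
with open("examples/example_rgb.tif", "rb") as f:
    rgb_raw_b, h_b, w_b, meta_b = read_rgb_tiff(f.read())

assert rgb_raw.shape == (3, h_f, w_f)
assert rgb_raw_b.shape == rgb_raw.shape
```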
@@ -138,9 +202,9 @@ def get_lora_params(model):
             params.extend(list(module.B.parameters()))
     return params
 
+
 # ============================================================
 # MODEL LOADING
-# Uses old app paths / loading logic
 # ============================================================
 @st.cache_resource
 def load_models(repo_id, dav_file, dino_file):
@@ -196,15 +260,14 @@ def load_models(repo_id, dav_file, dino_file):
 
     return dav_model, dino_model
 
+
 # ============================================================
 # DEPTH ANYTHING INFERENCE
-# Use exactly the old app style: raw DA output, no disparity2depth
 # ============================================================
 @st.cache_data(show_spinner=False)
 def run_dav_inference(_dav, rgb_raw, h_f, w_f):
     img_448 = cv2.resize(rgb_raw.transpose(1, 2, 0), (448, 448))
 
-    # Use torch.tensor for HF stability
     dav_in = torch.tensor(img_448, device=DEVICE).permute(2, 0, 1).unsqueeze(0).float() / 255.0
 
     with torch.no_grad():
@@ -217,17 +280,17 @@ def run_dav_inference(_dav, rgb_raw, h_f, w_f):
         size=(h_f, w_f),
         mode="bilinear",
         align_corners=False
-    ).squeeze(1)
+    ).squeeze(1)
 
     raw_depth_map = raw_depth[0].detach().float().cpu().numpy()
 
-    # optional normalized version for more stable LoRA fitting / plotting
     valid = np.isfinite(raw_depth_map)
     raw_depth_01 = normalize_01(raw_depth_map, valid)
     raw_depth_01[~valid] = np.nan
 
     return raw_depth_map, raw_depth_01
 
+
 # ============================================================
 # MAIN LORA PIPELINE
 # ============================================================
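`normalize_01` is called here but defined elsewhere in the file. A minimal implementation consistent with this call site (masked min-max scaling to [0, 1]) would look like the sketch below; the actual helper may differ:

```python
import numpy as np

def normalize_01(arr, valid):
    # Min-max scale over the valid mask only; the caller overwrites
    # invalid pixels with NaN afterwards.
    out = arr.astype(np.float32).copy()
    vmin = out[valid].min()
    vmax = out[valid].max()
    return (out - vmin) / (vmax - vmin + 1e-8)
```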
@@ -253,7 +316,6 @@ def run_lora_pipeline(
     anchor_mask = anchor_mask_cpu.to(DEVICE)
     prior_gpu = prior_raw_t.to(DEVICE)
 
-    # make fresh copy each run
     dino = copy.deepcopy(dino_base)
     dino = inject_lora(dino, r=lora_r, alpha=lora_alpha).to(DEVICE).train()
 
@@ -273,8 +335,8 @@ def run_lora_pipeline(
 
     Hp, Wp = H // PATCH_SIZE, W // PATCH_SIZE
     prior_p = F.interpolate(prior_gpu.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
-    rel_p
-    mask_p
+    rel_p = F.interpolate(rel_cpu.view(1, 1, H, W), size=(Hp, Wp), mode="bilinear").flatten()
+    mask_p = F.interpolate(anchor_mask.float().view(1, 1, H, W), size=(Hp, Wp), mode="area").flatten() > 0.5
 
     loss_hist = []
     prog = st.progress(0, text="Running LoRA TTO...")
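Why `mode="area"` plus a 0.5 threshold for the anchor mask: area interpolation averages the binary mask over each patch cell, so the comparison keeps only patches that are mostly covered by anchors. A self-contained illustration with toy dimensions:

```python
import torch
import torch.nn.functional as F

H = W = 64
mask = torch.zeros(H, W, dtype=torch.bool)
mask[:40, :40] = True                      # anchors cover the top-left block

Hp, Wp = 4, 4                              # patch grid
frac = F.interpolate(mask.float().view(1, 1, H, W), size=(Hp, Wp), mode="area")
mask_p = frac.flatten() > 0.5              # True only where a patch is mostly anchored
```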
@@ -297,7 +359,6 @@ def run_lora_pipeline(
 
     prog.empty()
 
-    # dense stride-4 inference
    dino.eval()
    mlp_head.eval()
 
@@ -323,11 +384,11 @@ def run_lora_pipeline(
 
             sb_local = mlp_head(t).t().reshape(2, hc // p, wc // p)
 
-            sb_acc[:, dy//stride:dy//stride + (hc//p)*(p//stride):p//stride,
-                   dx//stride:dx//stride + (wc//p)*(p//stride):p//stride] += sb_local
+            sb_acc[:, dy // stride:dy // stride + (hc // p) * (p // stride):p // stride,
+                   dx // stride:dx // stride + (wc // p) * (p // stride):p // stride] += sb_local
 
-            cnt_acc[:, dy//stride:dy//stride + (hc//p)*(p//stride):p//stride,
-                    dx//stride:dx//stride + (wc//p)*(p//stride):p//stride] += 1
+            cnt_acc[:, dy // stride:dy // stride + (hc // p) * (p // stride):p // stride,
+                    dx // stride:dx // stride + (wc // p) * (p // stride):p // stride] += 1
 
     sb_dense = sb_acc / (cnt_acc + 1e-8)
     offset = (p - (p // 2)) // stride + 1
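The reformatted slices implement shifted-window averaging: each shifted crop scatters its patch predictions onto a stride-4 grid, and overlapping writes are averaged through the count buffer. A toy version of the scatter/average pattern (dimensions are illustrative, not the app's):

```python
import torch

p, stride = 16, 4                      # patch size and inference stride
acc = torch.zeros(2, 32, 32)           # (channels, H // stride, W // stride)
cnt = torch.zeros(1, 32, 32)

for dy in range(0, p, stride):         # shifted crops
    for dx in range(0, p, stride):
        pred = torch.rand(2, 8, 8)     # one crop's (channels, Hp, Wp) output
        acc[:, dy // stride::p // stride, dx // stride::p // stride] += pred
        cnt[:, dy // stride::p // stride, dx // stride::p // stride] += 1

dense = acc / (cnt + 1e-8)             # averages wherever crops overlap
```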
@@ -350,15 +411,54 @@ def run_lora_pipeline(
 
     return final_dsm, loss_hist, anchor_mask_cpu.cpu().numpy()
 
+
 # ============================================================
 # UI
 # ============================================================
 st.title("Prior2DSM | LoRA")
 
+st.markdown(
+    f"""
+    **Example TIFFs**
+    - [Download example RGB TIFF](https://huggingface.co/spaces/osherr/Prior2DSM/resolve/main/{EXAMPLE_RGB_FILENAME})
+    - [Download example Prior TIFF](https://huggingface.co/spaces/osherr/Prior2DSM/resolve/main/{EXAMPLE_PRIOR_FILENAME})
+    """
+)
+
 with st.sidebar:
     st.header("📂 Data")
-
-
+
+    data_mode = st.radio(
+        "Data source",
+        ["Upload TIFFs", "Use example TIFFs"],
+        index=0
+    )
+
+    rgb_file = None
+    prior_file = None
+    rgb_example_path = None
+    prior_example_path = None
+
+    if data_mode == "Upload TIFFs":
+        rgb_file = st.file_uploader("RGB Image", type=["tif", "tiff"])
+        prior_file = st.file_uploader("LiDAR Prior", type=["tif", "tiff"])
+    else:
+        st.caption("Load demo RGB/Prior TIFFs from the Hugging Face Space.")
+        if st.button("Load example TIFFs"):
+            st.session_state["use_examples"] = True
+
+    if st.session_state.get("use_examples", False):
+        rgb_example_path = load_tiff_from_hf(
+            repo_id="osherr/Prior2DSM",
+            filename=EXAMPLE_RGB_FILENAME,
+            repo_type="space"
+        )
+        prior_example_path = load_tiff_from_hf(
+            repo_id="osherr/Prior2DSM",
+            filename=EXAMPLE_PRIOR_FILENAME,
+            repo_type="space"
+        )
+        st.success("Example TIFFs loaded.")
 
     st.divider()
     st.write("#### LoRA / TTO")
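The `use_examples` flag has to live in `st.session_state` because a Streamlit button is only True for the single rerun it triggers; persisting the intent is the standard pattern. Minimal sketch in isolation:

```python
import streamlit as st

if st.button("Load example TIFFs"):
    st.session_state["use_examples"] = True   # persist past this rerun

if st.session_state.get("use_examples", False):
    st.write("Examples remain active on every subsequent rerun.")
```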
@@ -367,20 +467,29 @@ with st.sidebar:
     tto_steps = st.slider("TTO steps", 10, 300, 100, step=10)
     tto_lr = st.select_slider("TTO LR", options=[1e-4, 3e-4, 1e-3, 3e-3], value=1e-3)
 
-
+has_uploaded = (rgb_file is not None and prior_file is not None)
+has_examples = (
+    data_mode == "Use example TIFFs"
+    and st.session_state.get("use_examples", False)
+    and rgb_example_path is not None
+    and prior_example_path is not None
+)
+
+if has_uploaded or has_examples:
     dav_m, dino_base = load_models(
         repo_id="osherr/Prior2DSM",
         dav_file="depth_anything_v2_vitl.pth",
         dino_file="dinov3_vitl16_pretrain_sat493m-eadcf0ff.pth"
     )
 
-
-    rgb_raw =
-
 
-
-    prior_raw = src.read(1).astype(np.float32)
-    prior_meta = src.meta.copy()
+    if has_uploaded:
+        rgb_raw, h_f, w_f, _ = read_rgb_tiff(rgb_file.read())
+        prior_raw, prior_meta = read_prior_tiff(prior_file.read())
+    else:
+        rgb_raw, h_f, w_f, _ = read_rgb_tiff(rgb_example_path)
+        prior_raw, prior_meta = read_prior_tiff(prior_example_path)
 
+    init_roi_state(h_f, w_f)
 
     with st.spinner("Generating relative depth with Depth Anything V2..."):
         rel_depth_map, rel_depth_01 = run_dav_inference(dav_m, rgb_raw, h_f, w_f)
@@ -388,15 +497,45 @@ if rgb_file and prior_file:
     st.subheader("1. ROI Selection")
 
     viz_rgb = preview_rgb(rgb_raw)
-
     col_img, col_ctrl = st.columns([1.2, 0.8])
 
     with col_ctrl:
-
-
-
-
-
+        with st.form("roi_form", clear_on_submit=False):
+            x_center_form = st.slider(
+                "X center",
+                0, w_f - 1,
+                int(st.session_state["x_center"])
+            )
+            y_center_form = st.slider(
+                "Y center",
+                0, h_f - 1,
+                int(st.session_state["y_center"])
+            )
+            bbox_size_form = st.slider(
+                "BBox Size (px)",
+                50, min(400, min(h_f, w_f)),
+                int(st.session_state["bbox_size"])
+            )
+            use_normalized_rel_form = st.checkbox(
+                "Use normalized relative depth for LoRA",
+                value=bool(st.session_state["use_normalized_rel"])
+            )
+
+            c1, c2 = st.columns(2)
+            with c1:
+                update_roi = st.form_submit_button("Update ROI")
+            with c2:
+                run_btn = st.form_submit_button("🚀 Run LoRA Pipeline", type="primary")
+
+        if update_roi or run_btn:
+            st.session_state["x_center"] = x_center_form
+            st.session_state["y_center"] = y_center_form
+            st.session_state["bbox_size"] = bbox_size_form
+            st.session_state["use_normalized_rel"] = use_normalized_rel_form
+        x_center = int(st.session_state["x_center"])
+        y_center = int(st.session_state["y_center"])
+        bbox_size = int(st.session_state["bbox_size"])
+        use_normalized_rel = bool(st.session_state["use_normalized_rel"])
 
         half_s = bbox_size // 2
         x1, x2 = max(0, x_center - half_s), min(w_f, x_center + half_s)
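Moving the ROI widgets into `st.form` means dragging a slider no longer triggers a rerun per tick; values are only delivered when a submit button fires, and the commit pairs that with session state so the chosen ROI persists across reruns. The pattern in isolation:

```python
import streamlit as st

st.session_state.setdefault("x_center", 50)

with st.form("roi_demo"):
    x = st.slider("X center", 0, 100, int(st.session_state["x_center"]))
    submitted = st.form_submit_button("Update ROI")

if submitted:
    st.session_state["x_center"] = x          # commit only on submit
st.write("current:", st.session_state["x_center"])
```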
@@ -406,12 +545,8 @@ if rgb_file and prior_file:
         bbox_mask[y1:y2, x1:x2] = True
 
     with col_img:
-
-
-        ax_roi.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="red", lw=2))
-        ax_roi.set_title("ROI Preview")
-        ax_roi.axis("off")
-        st.pyplot(fig_roi)
+        roi_preview = draw_roi_preview(viz_rgb, x1, y1, x2, y2)
+        st.image(roi_preview, caption="ROI Preview", use_container_width=True)
 
     if run_btn:
         rel_for_lora = rel_depth_01 if use_normalized_rel else rel_depth_map
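Replacing the Matplotlib figure with `draw_roi_preview` avoids allocating a figure on every rerun, and since `st.image` interprets arrays as RGB, the (255, 0, 0) rectangle renders red without any BGR conversion. Usage sketch, assuming the helper added earlier in this diff:

```python
import numpy as np

viz_rgb = np.random.rand(128, 128, 3)             # float preview in [0, 1]
roi = draw_roi_preview(viz_rgb, 16, 16, 96, 96)   # uint8 RGB with a red box
# st.image(roi, caption="ROI Preview", use_container_width=True)
```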
@@ -512,4 +647,6 @@ if rgb_file and prior_file:
             "Download Georeferenced DSM",
             out_buf.getvalue(),
             file_name="lora_refined_dsm_georef.tif"
-        )
+        )
+else:
+    st.info("Upload RGB and Prior TIFFs, or switch to example TIFFs in the sidebar.")