Spaces:
Sleeping
Sleeping
| """Generalisation Data Lab β Stage 1 of the Generalisation pipeline.""" | |
| import streamlit as st | |
| import cv2 | |
| import numpy as np | |
| import os | |
| from utils.middlebury_loader import ( | |
| DEFAULT_MIDDLEBURY_ROOT, get_scene_groups, load_single_view, | |
| read_pfm_bytes, | |
| ) | |
| # ------------------------------------------------------------------ | |
| # Helpers (shared with stereo data lab) | |
| # ------------------------------------------------------------------ | |
def _augment(img, brightness, contrast, rotation,
             flip_h, flip_v, noise, blur, shift_x, shift_y):
    """Return an augmented ``uint8`` copy of *img*.

    The steps run in a fixed order: contrast/brightness, additive Gaussian
    noise, Gaussian blur, rotation about the image centre, translation,
    horizontal flip, vertical flip. The input array is never modified.

    Args:
        img: Image array; it is converted to ``float32`` for the arithmetic.
        brightness: Offset added to every value after contrast scaling.
        contrast: Multiplier applied to every value.
        rotation: Angle handed to ``cv2.getRotationMatrix2D`` (degrees in the
            OpenCV API); ``0`` skips the rotation.
        flip_h: Mirror around the vertical axis (``cv2.flip`` code ``1``).
        flip_v: Mirror around the horizontal axis (``cv2.flip`` code ``0``).
        noise: Standard deviation of zero-mean Gaussian noise; values
            ``<= 0`` disable it. Uses NumPy's global random state, so the
            result differs from call to call when enabled.
        blur: Blur radius; the Gaussian kernel is ``2 * blur + 1`` pixels
            square and ``0`` disables blurring.
        shift_x: Horizontal translation in pixels.
        shift_y: Vertical translation in pixels.

    Returns:
        A new ``uint8`` array with the same shape as *img*.
    """
    # An empty crop has nothing to augment; return early so the blur/warp
    # branches below never hand a zero-sized array to OpenCV.
    if img.size == 0:
        return img.astype(np.uint8)

    out = img.astype(np.float32)
    out = np.clip(contrast * out + brightness, 0, 255)
    if noise > 0:
        out = np.clip(out + np.random.normal(0, noise, out.shape), 0, 255)
    # Plain cast: the fractional part is dropped, values are already in range.
    out = out.astype(np.uint8)

    k = blur * 2 + 1  # odd kernel size; k == 1 means "no blur"
    if k > 1:
        out = cv2.GaussianBlur(out, (k, k), 0)

    if rotation != 0:
        h, w = out.shape[:2]
        M = cv2.getRotationMatrix2D((w / 2, h / 2), rotation, 1.0)
        # Output keeps the input size; uncovered borders are mirror-filled.
        out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    if shift_x != 0 or shift_y != 0:
        h, w = out.shape[:2]
        M = np.float32([[1, 0, shift_x], [0, 1, shift_y]])
        out = cv2.warpAffine(out, M, (w, h), borderMode=cv2.BORDER_REFLECT)

    if flip_h:
        out = cv2.flip(out, 1)
    if flip_v:
        out = cv2.flip(out, 0)
    return out
# Per-ROI drawing colours, cycled by ROI index in render(). The tuples are
# passed straight to cv2 drawing calls on an image that is later converted
# with COLOR_BGR2RGB for display, i.e. they are consumed in BGR order.
ROI_COLORS = [
    (0, 255, 0),
    (255, 0, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 255, 0),
    (255, 128, 0),
]

# Largest accepted size of a single uploaded image file (50 MiB).
MAX_UPLOAD_BYTES = 50 * 1024 ** 2
def render():
    """Render the Generalisation Data Lab page (stage 1 of the pipeline).

    Flow of the page:

    1. Pick a train/test image pair, either two variants of a Middlebury
       scene group or two uploaded files.
    2. Define up to 20 labelled rectangular ROIs on the train image. They
       live in ``st.session_state["gen_rois"]`` as dicts with the keys
       ``label``, ``x0``, ``y0``, ``x1``, ``y1``.
    3. Choose augmentation parameters, applied to every ROI crop through
       ``_augment``.
    4. On "Lock", store everything in ``st.session_state["gen_pipeline"]``.

    The function returns early when no usable image pair is available and
    calls ``st.stop()`` while any ROI has ``end <= start``.

    NOTE(review): several UI strings below look like mis-decoded UTF-8
    (emoji / dashes). They are kept exactly as found; confirm against the
    original file before normalising them.
    """
    # Single definition of the radio options so the comparisons below can
    # never drift from the labels shown to the user.
    middlebury_option = "π¦ Middlebury Multi-View"
    upload_option = "π Upload your own files"

    st.header("π§ͺ Data Lab β Generalisation")
    st.info("**How this works:** Train on one image, test on a completely "
            "different image of the same object. No stereo geometry β "
            "pure recognition generalisation.")

    source = st.radio("Data source", [middlebury_option, upload_option],
                      horizontal=True, key="gen_source")

    # ===================================================================
    # Middlebury multi-view
    # ===================================================================
    if source == middlebury_option:
        groups = get_scene_groups()
        if not groups:
            st.error("No valid Middlebury scenes found in ./data/middlebury/")
            return

        group_name = st.selectbox("Scene group", list(groups.keys()),
                                  key="gen_group")
        variants = groups[group_name]

        gc1, gc2 = st.columns(2)
        train_scene = gc1.selectbox("Training scene", variants,
                                    key="gen_train_scene")
        available_test = [v for v in variants if v != train_scene]
        if not available_test:
            st.error("Need at least 2 variants in a group.")
            return
        test_scene = gc2.selectbox("Test scene", available_test,
                                   key="gen_test_scene")

        train_path = os.path.join(DEFAULT_MIDDLEBURY_ROOT, train_scene)
        test_path = os.path.join(DEFAULT_MIDDLEBURY_ROOT, test_scene)
        img_train = load_single_view(train_path)
        img_test = load_single_view(test_path)
        # NOTE(review): load_single_view is defined elsewhere; this guard
        # assumes it may return None for an unreadable scene - confirm.
        if img_train is None or img_test is None:
            st.error("Could not load the selected Middlebury scene images.")
            return

        st.markdown("*Both images show the same scene type captured under different "
                    "conditions. The model trains on one variant and must recognise "
                    "the same object class in the other β testing genuine appearance "
                    "generalisation.*")

        c1, c2 = st.columns(2)
        c1.image(cv2.cvtColor(img_train, cv2.COLOR_BGR2RGB),
                 caption=f"π¦ TRAIN IMAGE ({train_scene})",
                 use_container_width=True)
        c2.image(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB),
                 caption=f"π₯ TEST IMAGE ({test_scene})",
                 use_container_width=True)
        scene_group = group_name

    # ===================================================================
    # Custom upload
    # ===================================================================
    else:
        def _decode_upload(uploaded):
            # Decode an uploaded file to a BGR image; None when undecodable.
            buf = np.frombuffer(uploaded.read(), np.uint8)
            uploaded.seek(0)  # leave the upload readable for later reruns
            if buf.size == 0:
                return None
            return cv2.imdecode(buf, cv2.IMREAD_COLOR)

        uc1, uc2 = st.columns(2)
        with uc1:
            up_train = st.file_uploader("Train Image",
                                        type=["png", "jpg", "jpeg"],
                                        key="gen_up_train")
        with uc2:
            up_test = st.file_uploader("Test Image",
                                       type=["png", "jpg", "jpeg"],
                                       key="gen_up_test")

        if not (up_train and up_test):
            st.info("Upload a train and test image to proceed.")
            return
        if up_train.size > MAX_UPLOAD_BYTES or up_test.size > MAX_UPLOAD_BYTES:
            st.error("Image too large (max 50 MB).")
            return

        img_train = _decode_upload(up_train)
        img_test = _decode_upload(up_test)
        # A corrupt or mislabelled file decodes to None; stop here instead of
        # crashing inside cv2.cvtColor below.
        if img_train is None or img_test is None:
            st.error("Could not decode the uploaded image(s). "
                     "Please upload valid PNG or JPEG files.")
            return

        c1, c2 = st.columns(2)
        c1.image(cv2.cvtColor(img_train, cv2.COLOR_BGR2RGB),
                 caption="π¦ TRAIN IMAGE", use_container_width=True)
        c2.image(cv2.cvtColor(img_test, cv2.COLOR_BGR2RGB),
                 caption="π₯ TEST IMAGE", use_container_width=True)
        train_scene = "custom_train"
        test_scene = "custom_test"
        scene_group = "custom"

    # ===================================================================
    # ROI Definition (on TRAIN image)
    # ===================================================================
    st.divider()
    st.subheader("Step 2: Crop Region(s) of Interest")
    st.write("Define bounding boxes on the **TRAIN image**.")

    H, W = img_train.shape[:2]
    st.caption(f"π Image size: **{W} Γ {H}** px (X: 0 β {W-1}, Y: 0 β {H-1})")

    if "gen_rois" not in st.session_state:
        st.session_state["gen_rois"] = [
            {"label": "object", "x0": 0, "y0": 0,
             "x1": min(W, 100), "y1": min(H, 100)}
        ]

    # (widget-key suffix, ROI dict field) for every per-ROI input widget.
    widget_fields = (("lbl", "label"), ("x0", "x0"), ("y0", "y0"),
                     ("x1", "x1"), ("y1", "y1"))

    def _clamp(value, low, high):
        # Keep a stored coordinate inside the bounds of its number_input.
        return max(low, min(high, value))

    def _add_roi():
        if len(st.session_state["gen_rois"]) >= 20:
            return
        st.session_state["gen_rois"].append(
            {"label": f"object_{len(st.session_state['gen_rois'])+1}",
             "x0": 0, "y0": 0,
             "x1": min(W, 100), "y1": min(H, 100)})

    def _remove_roi(idx):
        current = st.session_state["gen_rois"]
        if len(current) <= 1:
            return
        # Capture edits that reached the widgets but not yet the ROI dicts.
        for j, entry in enumerate(current):
            for suffix, field in widget_fields:
                key = f"gen_roi_{suffix}_{j}"
                if key in st.session_state:
                    entry[field] = st.session_state[key]
        current.pop(idx)
        # Widget state is keyed by list position. Drop it for every slot at
        # or after the removed one so those widgets re-initialise from the
        # shifted ROI dicts instead of showing the previous occupant's values.
        for j in range(idx, len(current) + 1):
            for suffix, _field in widget_fields:
                key = f"gen_roi_{suffix}_{j}"
                if key in st.session_state:
                    del st.session_state[key]

    rois = st.session_state["gen_rois"]
    # Upper bounds for the start coordinates, floored at 0 so a 1-px image
    # does not yield max_value < min_value.
    x_start_max = max(W - 2, 0)
    y_start_max = max(H - 2, 0)

    for i, roi in enumerate(rois):
        color = ROI_COLORS[i % len(ROI_COLORS)]
        # The overlay treats these tuples as BGR, so reverse them for the
        # CSS hex swatch to show the same colour as the drawn rectangle.
        color_hex = "#{:02x}{:02x}{:02x}".format(*color[::-1])
        with st.container(border=True):
            hc1, hc2, hc3 = st.columns([3, 6, 1])
            hc1.markdown(f"**ROI {i+1}** <span style='color:{color_hex}'>β </span>",
                         unsafe_allow_html=True)
            roi["label"] = hc2.text_input("Class Label", roi["label"],
                                          key=f"gen_roi_lbl_{i}")
            if len(rois) > 1:
                hc3.button("β", key=f"gen_roi_del_{i}",
                           on_click=_remove_roi, args=(i,))

            cr1, cr2, cr3, cr4 = st.columns(4)
            # Defaults are clamped because the stored ROI may come from a
            # previously selected, larger image.
            roi["x0"] = int(cr1.number_input(
                "X start", 0, x_start_max,
                _clamp(int(roi["x0"]), 0, x_start_max),
                step=1, key=f"gen_roi_x0_{i}"))
            roi["y0"] = int(cr2.number_input(
                "Y start", 0, y_start_max,
                _clamp(int(roi["y0"]), 0, y_start_max),
                step=1, key=f"gen_roi_y0_{i}"))
            roi["x1"] = int(cr3.number_input(
                "X end", 0, W,
                _clamp(int(roi["x1"]), 0, W),
                step=1, key=f"gen_roi_x1_{i}"))
            roi["y1"] = int(cr4.number_input(
                "Y end", 0, H,
                _clamp(int(roi["y1"]), 0, H),
                step=1, key=f"gen_roi_y1_{i}"))

            if roi["x1"] <= roi["x0"] or roi["y1"] <= roi["y0"]:
                st.error(f"ROI {i+1}: end must be greater than start "
                         f"(X: {roi['x0']}β{roi['x1']}, Y: {roi['y0']}β{roi['y1']}). "
                         f"Adjust the values above.")

    st.button("β Add Another ROI", on_click=_add_roi,
              disabled=len(rois) >= 20,
              key="gen_add_roi")

    # Validate all ROIs before drawing
    roi_valid = all(r["x1"] > r["x0"] and r["y1"] > r["y0"] for r in rois)
    if not roi_valid:
        st.warning("β οΈ Fix the invalid ROI coordinates above before proceeding.")
        st.stop()

    overlay = img_train.copy()
    ruler_color = (200, 200, 200)
    # Draw pixel ruler (tick marks every 100 px)
    for px in range(0, W, 100):
        cv2.line(overlay, (px, 0), (px, 12), ruler_color, 1)
        cv2.putText(overlay, str(px), (px + 2, 11),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, ruler_color, 1)
    for py in range(0, H, 100):
        cv2.line(overlay, (0, py), (12, py), ruler_color, 1)
        cv2.putText(overlay, str(py), (1, py + 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.3, ruler_color, 1)

    crops = []
    for i, roi in enumerate(rois):
        color = ROI_COLORS[i % len(ROI_COLORS)]
        x0, y0, x1, y1 = roi["x0"], roi["y0"], roi["x1"], roi["y1"]
        cv2.rectangle(overlay, (x0, y0), (x1, y1), color, 2)
        # Keep the label baseline inside the image when the box touches the
        # top edge (the default ROI starts at y0 == 0).
        cv2.putText(overlay, roi["label"], (x0, max(y0 - 6, 14)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
        crops.append(img_train[y0:y1, x0:x1].copy())

    ov1, ov2 = st.columns([3, 2])
    ov1.image(cv2.cvtColor(overlay, cv2.COLOR_BGR2RGB),
              caption="TRAIN image β ROIs highlighted",
              use_container_width=True)
    with ov2:
        for c, roi in zip(crops, rois):
            st.image(cv2.cvtColor(c, cv2.COLOR_BGR2RGB),
                     caption=f"{roi['label']} ({c.shape[1]}Γ{c.shape[0]})",
                     width=160)

    # The first ROI doubles as the "primary" ROI in the stored pipeline dict.
    first_roi = rois[0]
    crop_bgr = crops[0]
    x0, y0 = first_roi["x0"], first_roi["y0"]
    x1, y1 = first_roi["x1"], first_roi["y1"]

    # ===================================================================
    # Augmentation
    # ===================================================================
    st.divider()
    st.subheader("Step 3: Data Augmentation")
    ac1, ac2 = st.columns(2)
    with ac1:
        brightness = st.slider("Brightness offset", -100, 100, 0, key="gen_bright")
        contrast = st.slider("Contrast scale", 0.5, 3.0, 1.0, 0.05, key="gen_contrast")
        rotation = st.slider("Rotation (Β°)", -180, 180, 0, key="gen_rot")
        noise = st.slider("Gaussian noise Ο", 0, 50, 0, key="gen_noise")
    with ac2:
        blur = st.slider("Blur kernel (0=off)", 0, 10, 0, key="gen_blur")
        shift_x = st.slider("Shift X (px)", -100, 100, 0, key="gen_sx")
        shift_y = st.slider("Shift Y (px)", -100, 100, 0, key="gen_sy")
        flip_h = st.checkbox("Flip Horizontal", key="gen_fh")
        flip_v = st.checkbox("Flip Vertical", key="gen_fv")

    all_augs = [_augment(c, brightness, contrast, rotation,
                         flip_h, flip_v, noise, blur, shift_x, shift_y)
                for c in crops]
    # Reuse the first result rather than augmenting ROI 1 a second time:
    # with noise enabled a second call would produce a different image, so
    # the preview, "crop_aug" and rois[0]["crop_aug"] would disagree.
    aug = all_augs[0]

    ag1, ag2 = st.columns(2)
    ag1.image(cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2RGB),
              caption="Original Crop (ROI 1)", use_container_width=True)
    ag2.image(cv2.cvtColor(aug, cv2.COLOR_BGR2RGB),
              caption="Augmented Crop (ROI 1)", use_container_width=True)

    # ===================================================================
    # Lock & Store
    # ===================================================================
    st.divider()
    if st.button("π Lock Data & Proceed", key="gen_lock"):
        rois_data = [
            {
                "label": roi["label"],
                "bbox": (roi["x0"], roi["y0"], roi["x1"], roi["y1"]),
                "crop": crop,
                "crop_aug": crop_aug,
            }
            for roi, crop, crop_aug in zip(rois, crops, all_augs)
        ]
        st.session_state["gen_pipeline"] = {
            "train_image": img_train,
            "test_image": img_test,
            "roi": {"x": x0, "y": y0, "w": x1 - x0, "h": y1 - y0,
                    "label": first_roi["label"]},
            "crop": crop_bgr,
            "crop_aug": aug,
            "crop_bbox": (x0, y0, x1, y1),
            "rois": rois_data,
            "source": "middlebury" if source == middlebury_option else "custom",
            "scene_group": scene_group,
            "train_scene": train_scene,
            "test_scene": test_scene,
        }
        st.success(f"β Data locked with **{len(rois_data)} ROI(s)**! "
                   f"Proceed to Feature Lab.")