| import re |
| from typing import Any, Dict, Iterable, List, Optional, Tuple |
|
|
| import numpy as np |
| import cv2 |
| import copy |
|
|
|
|
# Matches one "<...>" placeholder in a subgoal template.
_PLACEHOLDER_PATTERN = re.compile(r"<[^>]*>")


def _center_of_ids(label_mask: np.ndarray, ids: List[int]) -> Optional[List[int]]:
    """Return the mean [row, col] (truncated to int) of pixels labelled with *ids*.

    Returns None when *ids* is empty or when none of them occur in *label_mask*.
    """
    if not ids:
        return None
    coords = np.argwhere(np.isin(label_mask, ids))
    if coords.shape[0] == 0:
        return None
    return [int(coords[:, 0].mean()), int(coords[:, 1].mean())]


def _fill_placeholders(
    template: str,
    centers: List[Optional[List[int]]],
    fallback: str,
) -> str:
    """Replace each "<...>" in *template* with the matching "<row, col>" center.

    A single center is reused for every placeholder.  If any placeholder has
    no usable center (missing or None), *fallback* is returned instead of a
    partially filled string.  A template without placeholders is returned
    unchanged.
    """
    placeholder_count = len(_PLACEHOLDER_PATTERN.findall(template))
    if placeholder_count == 0 or not centers:
        return template

    labels = [
        None if center is None else f"<{center[0]}, {center[1]}>"
        for center in centers
    ]
    if len(labels) == 1:
        labels = labels * placeholder_count

    needed = labels[:placeholder_count]
    if len(needed) < placeholder_count or any(label is None for label in needed):
        return fallback

    pending = iter(needed)
    return _PLACEHOLDER_PATTERN.sub(lambda _match: next(pending), template)


def process_segmentation(
    segmentation: np.ndarray,
    segmentation_id_map: Optional[Dict[int, Any]],
    color_map: Dict[int, List[int]],
    current_segment: Any,
    current_subgoal_segment: Optional[str],
    previous_subgoal_segment: Optional[str],
    current_task_name: str,
    existing_points: Optional[List[List[int]]] = None,
    existing_subgoal_filled: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Filter a segmentation mask to the current targets and ground the subgoal text.

    Ids from ``segmentation_id_map`` whose object *is* (identity comparison)
    one of the targets in ``current_segment`` are kept; every other pixel of
    the mask is set to 0.  ``current_segment`` may be a single object, a
    list/tuple of objects, or None (no targets).

    Target centers and the filled subgoal string are only recomputed when
    ``current_subgoal_segment`` differs from ``previous_subgoal_segment``;
    otherwise ``existing_points`` / ``existing_subgoal_filled`` are passed
    through unchanged.

    Side effect: ``color_map`` is mutated in place -- every id whose object
    has ``name == "table-workspace"`` is mapped to ``[0, 0, 0]``.

    Returns a dict with:
        - segmentation_result: segmentation mask filtered to visible ids
        - segmentation_result_2d: squeezed version of segmentation_result
        - segmentation_points: [row, col] centers found for the current targets
        - current_subgoal_segment_filled: subgoal string with each "<...>"
          placeholder replaced by "<row, col>"; falls back to
          ``current_task_name`` when a placeholder has no center
        - no_object_flag: True when a target has ids but none of them
          appear in the mask
        - updated_previous_subgoal_segment: equals current_subgoal_segment for caller caching
        - vis_obj_id_list: ids kept in segmentation_result
    """
    segmentation_2d = segmentation.squeeze() if segmentation.ndim > 2 else segmentation

    if isinstance(current_segment, (list, tuple)):
        active_segments = list(current_segment)
    elif current_segment is None:
        active_segments = []
    else:
        active_segments = [current_segment]

    # segment_ids[i] holds the mask ids that belong to active_segments[i].
    segment_ids: List[List[int]] = [[] for _ in active_segments]
    vis_obj_id_list: List[int] = []
    if isinstance(segmentation_id_map, dict):
        for obj_id, obj in sorted(segmentation_id_map.items()):
            # Credit the id to every position that references this object so
            # a target listed twice still gets a center for each placeholder,
            # while the visible-id list receives the id only once.
            matched = [
                idx for idx, target in enumerate(active_segments) if obj is target
            ]
            if matched:
                vis_obj_id_list.append(obj_id)
                for idx in matched:
                    segment_ids[idx].append(obj_id)
            if getattr(obj, "name", None) == "table-workspace":
                color_map[obj_id] = [0, 0, 0]

    segmentation_result = np.where(
        np.isin(segmentation_2d, vis_obj_id_list), segmentation_2d, 0
    )
    segmentation_result_2d = segmentation_result.squeeze()

    # Defaults for the "subgoal unchanged" path: reuse the caller's cache.
    segmentation_points = existing_points or []
    current_subgoal_segment_filled = existing_subgoal_filled
    no_object_flag = False

    if current_subgoal_segment != previous_subgoal_segment:
        # Without targets a single (empty) id group is evaluated, so the
        # center list always has at least one entry.
        id_groups = segment_ids if active_segments else [vis_obj_id_list]
        segment_centers = [
            _center_of_ids(segmentation_2d, ids) for ids in id_groups
        ]
        no_object_flag = any(
            bool(ids) and center is None
            for ids, center in zip(id_groups, segment_centers)
        )
        segmentation_points = [
            center for center in segment_centers if center is not None
        ]

        if current_subgoal_segment:
            current_subgoal_segment_filled = _fill_placeholders(
                current_subgoal_segment, segment_centers, current_task_name
            )
        else:
            current_subgoal_segment_filled = current_subgoal_segment

    return {
        "segmentation_result": segmentation_result,
        "segmentation_result_2d": segmentation_result_2d,
        "segmentation_points": segmentation_points,
        "current_subgoal_segment_filled": current_subgoal_segment_filled,
        "no_object_flag": no_object_flag,
        "updated_previous_subgoal_segment": current_subgoal_segment,
        "vis_obj_id_list": vis_obj_id_list,
    }
|
|
|
|
def _colorize_labels(labels: np.ndarray, color_map: Dict[int, List[int]]) -> np.ndarray:
    """Paint every positive id in *labels* with its color_map entry (white if absent)."""
    plane = labels.squeeze() if labels.ndim > 2 else labels
    # Size the canvas from the squeezed plane so singleton batch/channel axes
    # (e.g. (1, H, W, 1)) cannot desynchronize canvas and mask shapes.  If the
    # squeeze collapses below 2-D, keep the leading dims of the raw input.
    canvas_hw = plane.shape[:2] if plane.ndim >= 2 else labels.shape[:2]
    canvas = np.zeros((*canvas_hw, 3), dtype=np.uint8)
    for seg_id in np.unique(plane):
        if seg_id > 0:
            canvas[plane == seg_id] = color_map.get(seg_id, [255, 255, 255])
    return canvas


def _match_frame_size(image: np.ndarray, frame: np.ndarray) -> np.ndarray:
    """Resize *image* to the height/width of *frame* (nearest neighbour) if they differ."""
    if image.shape[:2] == frame.shape[:2]:
        return image
    return cv2.resize(
        image,
        (frame.shape[1], frame.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )


def create_segmentation_visuals(
    segmentation: np.ndarray,
    segmentation_result: np.ndarray,
    base_frame: np.ndarray,
    color_map: Dict[int, List[int]],
    segmentation_points: List[List[int]],
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Build colored segmentation visualizations and target overlay for video export.

    Both label arrays are colorized with ``color_map``: ids <= 0 stay black,
    ids without an entry are drawn white.  Each colorized image is resized to
    the height/width of ``base_frame`` when its size differs.  The overlay is
    a copy of ``base_frame`` (the input frame is not modified) with a filled
    circle of radius 5 and color (255, 0, 0) drawn at every
    ``[row, col]`` entry of ``segmentation_points``.
    NOTE(review): whether (255, 0, 0) renders red or blue depends on the
    channel order of ``base_frame`` -- confirm against the caller.

    Returns (segmentation_vis, segmentation_result_vis, target_for_video).
    """
    segmentation_vis = _match_frame_size(
        _colorize_labels(segmentation, color_map), base_frame
    )
    segmentation_result_vis = _match_frame_size(
        _colorize_labels(segmentation_result, color_map), base_frame
    )

    # cv2.circle draws in place, so work on a private copy of the frame.
    target_for_video = copy.deepcopy(base_frame)
    if segmentation_points:
        radius = 5
        for center_y, center_x in segmentation_points:
            # OpenCV expects the center as (x, y) built from plain Python ints.
            cv2.circle(
                target_for_video,
                (int(center_x), int(center_y)),
                radius,
                (255, 0, 0),
                -1,
            )

    return segmentation_vis, segmentation_result_vis, target_for_video
|
|