| | """ |
| | Echo Tool Managers |
| | |
| | This module provides tool manager classes for various echo tools. |
| | """ |
| |
|
| | import os |
| | import sys |
| | import shutil |
| | import zipfile |
| | import urllib.request |
| | from typing import Dict, List, Any, Optional, Type, Tuple |
| | from pathlib import Path |
| |
|
| | import torch |
| | import numpy as np |
| | import cv2 |
| |
|
| | from utils.video_utils import convert_video_to_h264 |
| |
|
| | |
| | sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) |
| |
|
| | from pydantic import BaseModel, Field |
| | from langchain_core.tools import BaseTool |
| | from tools.general.base_tool_manager import BaseToolManager, ToolConfig, ToolStatus |
| |
|
| |
|
| | |
# Shared cache keyed by model name; populated lazily by the load_* helpers.
_model_cache = {}

_THIS_FILE = Path(__file__).resolve()

# Candidate locations of the vendored tool repositories, in priority order:
# two levels relative to this file, then an optional workspace override.
_TOOL_REPO_BASES = [
    _THIS_FILE.parents[2] / "tool_repos",
    _THIS_FILE.parents[3] / "tool_repos",
]
workspace_root_env = os.getenv("ECHO_WORKSPACE_ROOT")
if workspace_root_env:
    _TOOL_REPO_BASES.append(Path(workspace_root_env) / "tool_repos")

# Drop duplicate bases while keeping first-seen priority order.
_TOOL_REPO_BASES = list(dict.fromkeys(_TOOL_REPO_BASES))
| |
|
| |
|
def _resolve_tool_repo(repo_names: List[str]) -> Path:
    """Return the first existing path for the given tool repo names."""
    # Try every (name, base) combination in priority order: names outrank bases.
    for candidate_name in repo_names:
        for repo_base in _TOOL_REPO_BASES:
            resolved = repo_base / candidate_name
            if resolved.exists():
                return resolved

    # Nothing exists yet: still return a deterministic location under the
    # primary base so callers have a stable (if not-yet-created) path.
    fallback_base = _TOOL_REPO_BASES[0] if _TOOL_REPO_BASES else Path.cwd()
    return fallback_base / repo_names[0]
| |
|
| |
|
# Local checkouts of the tool repositories (both "-main" and bare clone names
# are accepted; the first existing variant wins).
MEDSAM_REPO_ROOT = _resolve_tool_repo(["MedSAM2-main", "MedSAM2"])
ECHOPRIME_REPO_ROOT = _resolve_tool_repo(["EchoPrime-main", "EchoPrime"])

# GitHub release hosting the EchoPrime weights and candidate embeddings.
ECHO_PRIME_RELEASE_BASE = "https://github.com/echonet/EchoPrime/releases/download/v1.0.0"
ECHO_PRIME_MODEL_ZIP_URL = f"{ECHO_PRIME_RELEASE_BASE}/model_data.zip"
# Candidate-embedding shards fetched individually (they ship outside the zip;
# file name -> download URL).
ECHO_PRIME_EMBEDDING_FILES = {
    "candidate_embeddings_p1.pt": f"{ECHO_PRIME_RELEASE_BASE}/candidate_embeddings_p1.pt",
    "candidate_embeddings_p2.pt": f"{ECHO_PRIME_RELEASE_BASE}/candidate_embeddings_p2.pt",
}

# Default MedSAM2 segmentation assets expected inside the repo checkout.
DEFAULT_ECHO_SEGMENTATION_MASK = MEDSAM_REPO_ROOT / "0108.png"
DEFAULT_ECHO_SEGMENTATION_MASK_DIR = MEDSAM_REPO_ROOT / "default_masks"
# Cardiac structure code -> default first-frame mask file name.
DEFAULT_ECHO_SEGMENTATION_STRUCTURES = {
    "LV": "LV.png",
    "MYO": "MYO.png",
    "LA": "LA.png",
    "RV": "RV.png",
    "RA": "RA.png",
}
DEFAULT_ECHO_SEGMENTATION_CHECKPOINT = MEDSAM_REPO_ROOT / "checkpoints" / "MedSAM2_US_Heart.pt"
| |
|
| |
|
| | def _download_file(url: str, destination: Path) -> bool: |
| | """Download a file from a URL to the destination path.""" |
| | try: |
| | destination.parent.mkdir(parents=True, exist_ok=True) |
| | print(f"⬇️ Downloading {url} -> {destination}") |
| | request = urllib.request.Request(url, headers={"User-Agent": "EchoAgent/1.0"}) |
| | with urllib.request.urlopen(request) as response, open(destination, "wb") as output_file: |
| | shutil.copyfileobj(response, output_file) |
| | print(f"✅ Downloaded {destination.name}") |
| | return True |
| | except Exception as download_error: |
| | print(f"❌ Failed to download {url}: {download_error}") |
| | if destination.exists(): |
| | destination.unlink() |
| | return False |
| |
|
| |
|
def ensure_echoprime_assets(echo_prime_path: Path) -> bool:
    """Ensure required EchoPrime assets are available, downloading when missing."""
    model_data_dir = echo_prime_path / "model_data"
    weights_dir = model_data_dir / "weights"
    candidates_dir = model_data_dir / "candidates_data"

    weight_names = ("echo_prime_encoder.pt", "view_classifier.pt")
    required_files = [weights_dir / name for name in weight_names] + [
        candidates_dir / name for name in ECHO_PRIME_EMBEDDING_FILES
    ]

    def _have_everything() -> bool:
        return all(path.exists() for path in required_files)

    if _have_everything():
        return True

    print("⚠️ EchoPrime assets missing; attempting automatic download...")

    # Encoder/classifier weights ship together in a single release zip.
    if not all((weights_dir / name).exists() for name in weight_names):
        archive_path = echo_prime_path / "model_data.zip"
        if _download_file(ECHO_PRIME_MODEL_ZIP_URL, archive_path):
            try:
                with zipfile.ZipFile(archive_path, "r") as archive:
                    archive.extractall(echo_prime_path)
                print("✅ Extracted model_data.zip")
            except zipfile.BadZipFile as zip_error:
                print(f"❌ model_data.zip appears corrupted: {zip_error}")
            finally:
                # The archive is transient either way.
                archive_path.unlink(missing_ok=True)

    # Candidate embeddings are published as individual files.
    for embedding_name, embedding_url in ECHO_PRIME_EMBEDDING_FILES.items():
        target_path = candidates_dir / embedding_name
        if not target_path.exists():
            _download_file(embedding_url, target_path)

    if _have_everything():
        return True
    print("❌ Required EchoPrime assets are still missing after download attempts.")
    return False
| |
|
| |
|
def load_panecho_model():
    """Load PanEcho model for real predictions with caching."""
    cache_key = "panecho"
    if cache_key in _model_cache:
        print("✅ Using cached PanEcho model")
        return _model_cache[cache_key]

    try:
        from models.model_factory import get_model

        print("🔄 Loading PanEcho model...")
        loaded_model = get_model(cache_key)
        if loaded_model is None:
            raise RuntimeError("PanEcho model not available")
        _model_cache[cache_key] = loaded_model
        print("✅ PanEcho model loaded and cached")
        return loaded_model
    except Exception as e:
        # Any failure (import, factory, None result) surfaces as RuntimeError.
        print(f"PanEcho loading failed: {e}")
        raise RuntimeError(f"PanEcho model not available: {e}")
| |
|
def load_medsam2_model():
    """Load MedSAM2 model for segmentation with caching.

    Note: unlike the other loaders this caches and returns the checkpoint
    *path*, not a model object.
    """
    cache_key = "medsam2"
    if cache_key in _model_cache:
        print("✅ Using cached MedSAM2 model path")
        return _model_cache[cache_key]

    try:
        from models.model_factory import get_model

        print("🔄 Loading MedSAM2 model...")
        checkpoint_path = get_model(cache_key)
        if checkpoint_path is None:
            raise RuntimeError("MedSAM2 model not available")
        _model_cache[cache_key] = checkpoint_path
        print(f"✅ MedSAM2 model loaded and cached: {checkpoint_path}")
        return checkpoint_path
    except Exception as e:
        # Any failure (import, factory, None result) surfaces as RuntimeError.
        print(f"MedSAM2 loading failed: {e}")
        raise RuntimeError(f"MedSAM2 model not available: {e}")
| |
|
def load_echoflow_model():
    """Load EchoFlow model for generation with caching."""
    cache_key = "echoflow"
    if cache_key in _model_cache:
        print("✅ Using cached EchoFlow model")
        return _model_cache[cache_key]

    try:
        from models.model_factory import get_model

        print("🔄 Loading EchoFlow model...")
        loaded_model = get_model(cache_key)
        if loaded_model is None:
            raise RuntimeError("EchoFlow model not available")
        _model_cache[cache_key] = loaded_model
        print("✅ EchoFlow model loaded and cached")
        return loaded_model
    except Exception as e:
        # Any failure (import, factory, None result) surfaces as RuntimeError.
        print(f"EchoFlow loading failed: {e}")
        raise RuntimeError(f"EchoFlow model not available: {e}")
| |
|
def clear_model_cache():
    """Empty the shared model cache so loaded models can be garbage-collected."""
    # In-place .clear() mutates the shared dict; no rebinding (or `global`)
    # is needed, and existing references to the dict stay valid.
    _model_cache.clear()
    print("🧹 Model cache cleared")
| |
|
def load_echo_prime_model():
    """Load EchoPrime model for comprehensive analysis with caching.

    Returns:
        The cached/initialized ``EchoPrime`` instance, or ``None`` when the
        repository checkout or its weight files are unavailable.

    NOTE(review): unlike the other ``load_*`` helpers (which raise
    ``RuntimeError`` on failure), this one returns ``None``; the ``_run``
    methods below use the result without a None check — confirm intended.
    """
    if "echo_prime" in _model_cache:
        print("✅ Using cached EchoPrime model")
        return _model_cache["echo_prime"]

    try:
        # Resolve the local EchoPrime checkout; without it nothing can load.
        echo_prime_path = ECHOPRIME_REPO_ROOT
        if not echo_prime_path.exists():
            print(f"❌ EchoPrime directory not found: {echo_prime_path}")
            return None

        # Make sure the weight/embedding files exist (downloads when missing).
        if not ensure_echoprime_assets(echo_prime_path):
            print("❌ EchoPrime assets unavailable; cannot initialize model.")
            return None

        # The checkout is not pip-installed, so it must be importable via
        # sys.path; inserted at the front to win over similarly named modules.
        if str(echo_prime_path) not in sys.path:
            sys.path.insert(0, str(echo_prime_path))

        from echo_prime.model import EchoPrime

        # Prefer GPU when available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        echo_prime_model = EchoPrime(device=device)

        _model_cache["echo_prime"] = echo_prime_model

        print("✅ EchoPrime model loaded successfully")
        return echo_prime_model

    except Exception as e:
        # Best-effort loader: log the full traceback but signal failure with
        # None rather than raising.
        print(f"❌ Failed to load EchoPrime model: {e}")
        import traceback
        traceback.print_exc()
        return None
| |
|
| |
|
class EchoDiseasePredictionInput(BaseModel):
    """Input schema for echo disease prediction (EchoDiseasePredictionTool)."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    # None means "process every video found in input_dir".
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    save_csv: bool = Field(True, description="Save results to CSV file")
    include_confidence: bool = Field(True, description="Include confidence scores in output")
| |
|
| |
|
class EchoImageVideoGenerationInput(BaseModel):
    """Input schema for echo image/video generation (EchoImageVideoGenerationTool)."""
    views: List[str] = Field(..., description="List of echo views to generate")
    # Target ejection fractions for the generated clips.
    efs: List[float] = Field(..., description="List of ejection fractions")
    # None falls back to the tool's default "temp/echo_generated" directory.
    outdir: Optional[str] = Field(None, description="Output directory")
    num_samples: int = Field(10, description="Number of samples to generate")
| |
|
| |
|
class EchoMeasurementPredictionInput(BaseModel):
    """Input schema for echo measurement prediction (EchoMeasurementPredictionTool)."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    # None means "process every video found in input_dir".
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    include_report: bool = Field(True, description="Include detailed report")
    save_csv: bool = Field(True, description="Save measurements to CSV")
| |
|
| |
|
class EchoReportGenerationInput(BaseModel):
    """Input schema for echo report generation (EchoReportGenerationTool)."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    # When True, view classification also renders visualization images.
    visualize_views: bool = Field(False, description="Generate view visualizations")
    # None means "process every video found in input_dir".
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
    include_sections: bool = Field(True, description="Include all report sections")
| |
|
| |
|
class EchoSegmentationInput(BaseModel):
    """Input schema for echo segmentation (EchoSegmentationTool / MedSAM2).

    Point and box prompts refer to the first frame and use normalized
    [0, 1] coordinates; per-structure initial masks may be supplied either
    as a directory of <CODE>.png files or an explicit code->path mapping.
    """

    video_path: str = Field(..., description="Path to echo video file")
    # How the first-frame prompt is produced: automatic, user points, a
    # bounding box, or a full initial mask.
    prompt_mode: str = Field("auto", description="Prompt mode for segmentation (auto, points, box, mask)")
    target_name: str = Field(
        "all",
        description="Target structure name (all, LV, RV, LA, RA, MV, TV, AV, PV, IVS, LVPW, AORoot, PA)",
    )
    save_mask_video: bool = Field(True, description="Save mask video")
    save_overlay_video: bool = Field(True, description="Save overlay video")
    points: Optional[List[List[float]]] = Field(
        None,
        description="List of [x, y, label] triples in normalized coordinates for the first frame (label: 1 foreground, 0 background)",
    )
    box: Optional[List[float]] = Field(
        None,
        description="Normalized box as [x1, y1, x2, y2] for the first frame",
    )
    mask_path: Optional[str] = Field(
        None,
        description="Path to an initial segmentation mask for the first frame (for 'mask' mode)",
    )
    # Frame subsampling for speed; 1 processes every frame.
    sample_rate: int = Field(1, description="Process every Nth frame for speed (1 = every frame)")
    output_fps: Optional[int] = Field(None, description="FPS for output video. Defaults to source FPS")
    progress_callback: Optional[Any] = Field(None, description="Progress callback function for UI updates")

    # Per-structure initial-mask inputs (directory form and explicit mapping).
    initial_masks_dir: Optional[str] = Field(
        None,
        description="Directory containing first-frame masks per structure (e.g., LV.png, RV.png)",
    )
    initial_mask_paths: Optional[Dict[str, str]] = Field(
        None,
        description="Mapping of structure code (e.g., 'LV') to first-frame mask file path",
    )
    initial_mask_frame_idx: int = Field(0, description="Frame index the provided masks correspond to (default 0)")
    use_auto_masks_if_missing: bool = Field(
        True,
        description="If a structure mask is missing, fall back to auto coarse prompt",
    )
| |
|
| |
|
class EchoViewClassificationInput(BaseModel):
    """Input schema for echo view classification."""
    input_dir: str = Field(..., description="Directory containing echo videos")
    visualize: bool = Field(False, description="Generate visualizations")
    # None means "process every video found in input_dir".
    max_videos: Optional[int] = Field(None, description="Maximum number of videos to process")
| |
|
| |
|
class EchoDiseasePredictionTool(BaseTool):
    """Echo disease prediction tool.

    Runs the PanEcho video model over every ``*.mp4`` in a directory and
    converts the raw per-task outputs (regression values, binary scores and
    multi-class logits) into a structured prediction report.
    """

    name: str = "echo_disease_prediction"
    description: str = "Predict cardiac diseases from echo videos using PanEcho."
    args_schema: Type[BaseModel] = EchoDiseasePredictionInput

    def _get_task_units(self, task_name: str) -> str:
        """Return the measurement units for *task_name* ('N/A' when unknown)."""
        units_map = {
            'EF': '%',
            'GLS': '%',
            'LVEDV': 'cm³',
            'LVESV': 'cm³',
            'LVSV': 'cm³',
            'IVSd': 'cm',
            'LVPWd': 'cm',
            'LVIDs': 'cm',
            'LVIDd': 'cm',
            'LVOTDiam': 'cm',
            'E|EAvg': 'ratio',
            'RVSP': 'mmHg',
            'RVIDd': 'cm',
            'TAPSE': 'cm',
            'RVSVel': 'cm/s',
            'LAIDs2D': 'cm',
            'LAVol': 'cm³',
            'RADimensionM-L(cm)': 'cm',
            'AVPkVel(m/s)': 'm/s',
            'TVPkGrad': 'mmHg',
            'AORoot': 'cm'
        }
        return units_map.get(task_name, 'N/A')

    def _get_class_names(self, task_name: str) -> list:
        """Return the ordered class labels for multi-class tasks ([] if unknown)."""
        class_names_map = {
            'LVSize': ['Mildly Increased', 'Moderately|Severely Increased', 'Normal'],
            'LVSystolicFunction': ['Mildly Decreased', 'Moderately|Severely Decreased', 'Normal|Hyperdynamic'],
            'LVDiastolicFunction': ['Mild|Indeterminate', 'Moderate|Severe', 'Normal'],
            'RVSize': ['Mildly Increased', 'Moderately|Severely Increased', 'Normal'],
            'LASize': ['Mildly Dilated', 'Moderately|Severely Dilated', 'Normal'],
            'AVStenosis': ['Mild|Moderate', 'None', 'Severe'],
            'AVRegurg': ['Mild', 'Moderate|Severe', 'None|Trace'],
            'MVRegurgitation': ['Mild', 'Moderate|Severe', 'None|Trace'],
            'TVRegurgitation': ['Mild', 'Moderate|Severe', 'None|Trace']
        }
        return class_names_map.get(task_name, [])

    def _load_clip(self, video_path: str, num_frames: int = 16):
        """Read a video into a normalized (1, 3, T, 224, 224) float tensor.

        Frames are resized to 224x224, converted BGR->RGB, scaled to [0, 1],
        normalized with ImageNet statistics and padded (by repeating the last
        frame) to exactly ``num_frames`` frames.

        Raises:
            RuntimeError: When the video yields no readable frames.
        """
        import torchvision.transforms as transforms

        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )

        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while len(frames) < num_frames and cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.resize(frame, (224, 224))
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
        finally:
            cap.release()

        if not frames:
            # Previously an empty video raised IndexError from frames[-1];
            # raise a clear error instead (still caught per-video by _run).
            raise RuntimeError(f"No readable frames in {video_path}")

        # Pad short clips by repeating the final frame.
        while len(frames) < num_frames:
            frames.append(frames[-1])

        frames_array = np.array(frames, dtype=np.float32) / 255.0
        frames_tensor = torch.tensor(frames_array).permute(0, 3, 1, 2).unsqueeze(0)
        frames_tensor = normalize(
            frames_tensor.view(-1, 3, 224, 224)
        ).view(1, num_frames, 3, 224, 224)
        # (batch, frames, channels, H, W) -> (batch, channels, frames, H, W)
        return frames_tensor.permute(0, 2, 1, 3, 4)

    def _format_prediction(self, task_name, pred_value, regression_tasks, task_description):
        """Convert one raw PanEcho output into a structured result dict.

        Shape (1, 1) tensors are scalar tasks (regression or binary score);
        tensors with a class dimension are multi-class; anything else falls
        back to a mean-pooled regression value.
        """
        if torch.is_tensor(pred_value):
            if pred_value.shape == (1, 1):
                raw_value = float(pred_value[0, 0].item())
                if task_name in regression_tasks:
                    value = raw_value
                    task_type = 'regression'
                    confidence = 0.85  # fixed heuristic confidence for regressions
                else:
                    # Binary score: confidence is distance from the 0.5 boundary.
                    value = raw_value
                    task_type = 'binary_classification'
                    confidence = max(value, 1.0 - value)
            elif pred_value.shape[1] > 1:
                probs = pred_value[0]
                predicted_class = int(probs.argmax().item())
                confidence = float(probs.max().item())
                class_names = self._get_class_names(task_name)
                if class_names and predicted_class < len(class_names):
                    value = class_names[predicted_class]
                else:
                    value = predicted_class
                task_type = 'multi-class_classification'
            else:
                value = float(pred_value.flatten().mean().item())
                task_type = 'regression'
                confidence = 0.85
        else:
            value = float(pred_value) if isinstance(pred_value, (int, float)) else 0.0
            task_type = 'unknown'
            confidence = 0.0

        return {
            'value': value,
            'description': task_description,
            'confidence': confidence,
            'task_type': task_type,
            'units': self._get_task_units(task_name),
            'raw_prediction': float(pred_value[0, 0].item())
            if torch.is_tensor(pred_value) and pred_value.shape == (1, 1)
            else None
        }

    def _run(
        self,
        input_dir: str,
        max_videos: Optional[int] = None,
        save_csv: bool = True,
        include_confidence: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run echo disease prediction using the real PanEcho model.

        Args:
            input_dir: Directory containing ``*.mp4`` echo videos.
            max_videos: Optional cap on the number of videos processed.
            save_csv / include_confidence: Accepted for schema compatibility;
                not used by this implementation.
            run_manager: Optional LangChain callback manager (unused).

        Returns:
            Dict with status, per-video predictions and a summary message.

        Raises:
            RuntimeError: When the model cannot load, no videos are found, or
                every video fails to process.
        """
        try:
            panecho_model = load_panecho_model()

            import glob
            video_files = glob.glob(os.path.join(input_dir, "*.mp4"))
            if max_videos:
                video_files = video_files[:max_videos]
            if not video_files:
                raise RuntimeError(f"No MP4 videos found in {input_dir}")

            # Task key -> human-readable description (constant; hoisted out of
            # the per-video loop).
            task_descriptions = {
                'pericardial-effusion': 'Pericardial Effusion',
                'EF': 'Ejection Fraction (%)',
                'GLS': 'Global Longitudinal Strain (%)',
                'LVEDV': 'LV End-Diastolic Volume (cm³)',
                'LVESV': 'LV End-Systolic Volume (cm³)',
                'LVSV': 'LV Stroke Volume (cm³)',
                'LVSize': 'LV Size',
                'LVWallThickness-increased-any': 'LV Wall Thickness - Any Increase',
                'LVWallThickness-increased-modsev': 'LV Wall Thickness - Moderate/Severe Increase',
                'LVSystolicFunction': 'LV Systolic Function',
                'LVWallMotionAbnormalities': 'LV Wall Motion Abnormalities',
                'IVSd': 'Interventricular Septum Diastole (cm)',
                'LVPWd': 'LV Posterior Wall Diastole (cm)',
                'LVIDs': 'LV Internal Diameter Systole (cm)',
                'LVIDd': 'LV Internal Diameter Diastole (cm)',
                'LVOTDiam': 'LV Outflow Tract Diameter (cm)',
                'LVDiastolicFunction': 'LV Diastolic Function',
                'E|EAvg': 'E/e\' Ratio',
                'RVSP': 'RV Systolic Pressure (mmHg)',
                'RVSize': 'RV Size',
                'RVSystolicFunction': 'RV Systolic Function',
                'RVIDd': 'RV Internal Diameter Diastole (cm)',
                'TAPSE': 'Tricuspid Annular Plane Systolic Excursion (cm)',
                'RVSVel': 'RV Systolic Excursion Velocity (cm/s)',
                'LASize': 'Left Atrial Size',
                'LAIDs2D': 'LA Internal Diameter Systole 2D (cm)',
                'LAVol': 'LA Volume (cm³)',
                'RASize': 'Right Atrial Size',
                'RADimensionM-L(cm)': 'RA Major Dimension (cm)',
                'AVStructure': 'Aortic Valve Structure',
                'AVStenosis': 'Aortic Valve Stenosis',
                'AVPkVel(m/s)': 'Aortic Valve Peak Velocity (m/s)',
                'AVRegurg': 'Aortic Valve Regurgitation',
                'LVOT20mmHg': 'Elevated LV Outflow Tract Pressure',
                'MVStenosis': 'Mitral Valve Stenosis',
                'MVRegurgitation': 'Mitral Valve Regurgitation',
                'TVRegurgitation': 'Tricuspid Valve Regurgitation',
                'TVPkGrad': 'Tricuspid Valve Peak Gradient (mmHg)',
                'RAP-8-or-higher': 'Elevated RA Pressure',
                'AORoot': 'Aortic Root Diameter (cm)'
            }

            # Single-scalar tasks that are physical measurements rather than
            # probabilities.  BUG FIX: the original listed 'AVPkVel(m|s)'
            # here, which never matched the 'AVPkVel(m/s)' key used by the
            # units map and descriptions above, so the aortic peak velocity
            # task was mis-handled as a binary classification.
            regression_tasks = {
                'EF', 'GLS', 'LVEDV', 'LVESV', 'LVSV', 'IVSd', 'LVPWd',
                'LVIDs', 'LVIDd', 'LVOTDiam', 'E|EAvg', 'RVSP', 'RVIDd',
                'TAPSE', 'RVSVel', 'LAIDs2D', 'LAVol', 'RADimensionM-L(cm)',
                'AVPkVel(m/s)', 'TVPkGrad', 'AORoot',
            }

            device = next(panecho_model.parameters()).device

            all_predictions = []
            for video_path in video_files:
                try:
                    clip = self._load_clip(video_path).to(device)

                    with torch.no_grad():
                        predictions = panecho_model(clip)

                    disease_predictions = {}
                    for task_name, pred_value in predictions.items():
                        task_description = task_descriptions.get(task_name, f"{task_name} (Unknown Task)")
                        try:
                            disease_predictions[task_name] = self._format_prediction(
                                task_name, pred_value, regression_tasks, task_description
                            )
                        except Exception as e:
                            # One malformed task output must not sink the video.
                            print(f"Error processing {task_name}: {e}")
                            disease_predictions[task_name] = {
                                'value': 0.0,
                                'description': task_description,
                                'confidence': 0.0,
                                'task_type': 'unknown',
                                'units': 'unknown',
                                'error': str(e)
                            }

                    all_predictions.append({
                        "video": os.path.basename(video_path),
                        "predictions": disease_predictions
                    })

                except Exception as e:
                    # Skip unreadable/failed videos; report only successes.
                    print(f"Error processing {video_path}: {e}")
                    continue

            if not all_predictions:
                raise RuntimeError("No videos processed successfully")

            return {
                "status": "success",
                "model": "PanEcho",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(all_predictions),
                "predictions": all_predictions,
                "message": f"Disease prediction completed for {len(all_predictions)} videos using real PanEcho model"
            }

        except Exception as e:
            print(f"PanEcho prediction failed: {e}")
            raise RuntimeError(f"Disease prediction failed: {e}")
| |
|
| |
|
class EchoImageVideoGenerationTool(BaseTool):
    """Synthetic echocardiogram generation backed by EchoFlow."""

    name: str = "echo_image_video_generation"
    description: str = "Generate synthetic echo images and videos using EchoFlow."
    args_schema: Type[BaseModel] = EchoImageVideoGenerationInput

    def _run(
        self,
        views: List[str],
        efs: List[float],
        outdir: Optional[str] = None,
        num_samples: int = 10,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Generate synthetic echo videos for the requested views/EF targets."""
        try:
            generator = load_echoflow_model()

            # Default output location when the caller does not supply one.
            target_dir = Path(outdir) if outdir else Path("temp/echo_generated")
            target_dir.mkdir(parents=True, exist_ok=True)

            produced_files = generator.generate_synthetic_video(
                views=views,
                efs=efs,
                num_samples=num_samples,
                output_dir=str(target_dir),
            )

            produced_count = len(produced_files)
            return {
                "status": "success",
                "model": "EchoFlow",
                "views": views,
                "efs": efs,
                "num_samples": num_samples,
                "successful_generations": produced_count,
                "generated_files": produced_files,
                "output_dir": str(target_dir),
                "message": f"Generated {produced_count} synthetic echo videos using real EchoFlow model",
            }

        except Exception as e:
            print(f"EchoFlow generation failed: {e}")
            raise RuntimeError(f"Echo generation failed: {e}")
| |
|
| |
|
class EchoMeasurementPredictionTool(BaseTool):
    """Echo measurement prediction tool.

    Embeds every video in a directory with EchoPrime and regresses a set of
    standard echocardiographic measurements from the pooled study embedding.
    """

    name: str = "echo_measurement_prediction"
    description: str = "Extract echocardiography measurements using EchoPrime."
    args_schema: Type[BaseModel] = EchoMeasurementPredictionInput

    def _run(
        self,
        input_dir: str,
        max_videos: Optional[int] = None,
        include_report: bool = True,
        save_csv: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run echo measurement prediction using real EchoPrime model.

        Args:
            input_dir: Directory containing the echo ``*.mp4`` studies.
            max_videos: Informational only; echoed back in the result.
            include_report / save_csv: Accepted for schema compatibility; not
                used by this implementation.
            run_manager: Optional LangChain callback manager (unused).

        Returns:
            Dict with status, per-study formatted measurements and a message.

        Raises:
            RuntimeError: When the model cannot be loaded or no videos process.
        """
        try:
            echo_prime_model = load_echo_prime_model()
            # BUG FIX: load_echo_prime_model() returns None on failure; the
            # original dereferenced it unchecked, producing an opaque
            # AttributeError.  Fail fast with a clear message instead (the
            # outer handler still converts this to the same RuntimeError type).
            if echo_prime_model is None:
                raise RuntimeError("EchoPrime model could not be loaded")

            print(f"🔄 Processing videos from {input_dir}...")
            stack_of_videos = echo_prime_model.process_mp4s(input_dir)

            if len(stack_of_videos) == 0:
                raise RuntimeError("No videos processed successfully")

            print(f"✅ Processed {len(stack_of_videos)} videos")

            print("🔄 Predicting measurements...")

            video_features = echo_prime_model.embed_videos(stack_of_videos)
            view_encodings = echo_prime_model.get_views(stack_of_videos)

            # A single video yields a 1-D encoding; restore the batch dim.
            if view_encodings.dim() == 1:
                view_encodings = view_encodings.unsqueeze(0)

            # Study embedding = per-video features concatenated with view codes.
            study_embedding = torch.cat((video_features, view_encodings), dim=1)

            measurements = echo_prime_model.predict_metrics(study_embedding)

            # Keep only finite numeric metrics and attach display units.
            formatted_measurements = {}
            for key, value in measurements.items():
                if isinstance(value, (int, float)) and not np.isnan(value):
                    # Heuristic unit assignment: EF is a percentage, keys
                    # containing 'd' (diastolic diameters etc.) are cm,
                    # everything else mL.
                    # NOTE(review): crude — *any* key containing 'd' maps to
                    # cm; verify against EchoPrime's actual metric names.
                    unit = "%" if key == "EF" else "cm" if "d" in key else "mL"
                    formatted_measurements[key] = {
                        "value": float(value),
                        "unit": unit,
                        "confidence": 0.85,  # fixed heuristic confidence
                    }

            all_measurements = [{
                "video": "study_measurements",
                "measurements": formatted_measurements
            }]

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(stack_of_videos),
                "measurements": all_measurements,
                "message": f"Measurement prediction completed for {len(stack_of_videos)} videos using real EchoPrime model"
            }

        except Exception as e:
            print(f"EchoPrime measurement prediction failed: {e}")
            raise RuntimeError(f"Measurement prediction failed: {e}")
| |
|
| |
|
class EchoReportGenerationTool(BaseTool):
    """Echo report generation tool.

    Embeds an echo study with EchoPrime and produces a narrative report,
    the regressed measurements, and a per-video view distribution.
    """

    name: str = "echo_report_generation"
    description: str = "Generate comprehensive echo report using EchoPrime."
    args_schema: Type[BaseModel] = EchoReportGenerationInput

    def _run(
        self,
        input_dir: str,
        visualize_views: bool = False,
        max_videos: Optional[int] = None,
        include_sections: bool = True,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run echo report generation using real EchoPrime model.

        Args:
            input_dir: Directory containing the echo ``*.mp4`` studies.
            visualize_views: Forwarded to EchoPrime's view classifier to also
                render visualizations.
            max_videos: Informational only; echoed back in the result.
            include_sections: Accepted for schema compatibility; not used here.
            run_manager: Optional LangChain callback manager (unused).

        Returns:
            Dict with status, the generated report, and a study analysis.

        Raises:
            RuntimeError: When the model cannot be loaded or no videos process.
        """
        try:
            echo_prime_model = load_echo_prime_model()
            # BUG FIX: load_echo_prime_model() returns None on failure; the
            # original dereferenced it unchecked, producing an opaque
            # AttributeError.  Fail fast with a clear message instead (the
            # outer handler still converts this to the same RuntimeError type).
            if echo_prime_model is None:
                raise RuntimeError("EchoPrime model could not be loaded")

            print(f"🔄 Processing videos from {input_dir}...")
            stack_of_videos = echo_prime_model.process_mp4s(input_dir)

            if len(stack_of_videos) == 0:
                raise RuntimeError("No videos processed successfully")

            print(f"✅ Processed {len(stack_of_videos)} videos")

            print("🔄 Generating comprehensive report...")

            video_features = echo_prime_model.embed_videos(stack_of_videos)
            view_encodings = echo_prime_model.get_views(stack_of_videos, visualize=visualize_views)

            # A single video yields a 1-D encoding; restore the batch dim.
            if view_encodings.dim() == 1:
                view_encodings = view_encodings.unsqueeze(0)

            # Study embedding = per-video features concatenated with view codes.
            study_embedding = torch.cat((video_features, view_encodings), dim=1)

            report = echo_prime_model.generate_report(study_embedding)

            measurements = echo_prime_model.predict_metrics(study_embedding)

            # Per-video view labels for the distribution summary.
            views = echo_prime_model.get_views(stack_of_videos, return_view_list=True)

            analysis = {
                "video": "study_analysis",
                "view_classification": {
                    "predicted_views": views,
                    "view_distribution": {view: views.count(view) for view in set(views)}
                },
                "measurements": measurements,
                "disease_predictions": {},
                "quality_assessment": {
                    "confidence": 0.85,  # fixed heuristic confidence
                    "model_used": "EchoPrime"
                },
                "confidence": 0.85
            }

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(stack_of_videos),
                "report": report,
                "analysis": analysis,
                "message": f"Report generation completed for {len(stack_of_videos)} videos using real EchoPrime model"
            }

        except Exception as e:
            print(f"EchoPrime report generation failed: {e}")
            raise RuntimeError(f"Report generation failed: {e}")

    def _generate_comprehensive_report(self, analyses, include_sections):
        """Generate a summary report dict from per-video analysis dicts.

        NOTE(review): not called by ``_run`` in the current implementation —
        confirm whether this legacy path is still needed.
        """
        all_measurements = []
        all_disease_predictions = []
        view_distribution = {}

        for analysis in analyses:
            all_measurements.append(analysis.get("measurements", {}))
            all_disease_predictions.append(analysis.get("disease_predictions", {}))

            view = analysis.get("view_classification", {}).get("predicted_view", "unknown")
            view_distribution[view] = view_distribution.get(view, 0) + 1

        # Average the key measurements across videos (values may be raw
        # numbers or {"value": ...} dicts).
        avg_measurements = {}
        measurement_keys = ['EF', 'LVEDV', 'LVESV', 'GLS', 'IVSd', 'LVPWd', 'LVIDs', 'LVIDd']

        for key in measurement_keys:
            values = [m.get(key, {}).get("value", 0) if isinstance(m.get(key), dict) else m.get(key, 0) for m in all_measurements if m]
            if values:
                avg_measurements[key] = np.mean(values)

        # Qualitative EF banding.
        ef = avg_measurements.get("EF", 0)
        if ef > 55:
            ef_status = "Normal"
        elif ef > 45:
            ef_status = "Mildly reduced"
        else:
            ef_status = "Moderately to severely reduced"

        summary = f"Left ventricular ejection fraction is {ef_status} ({ef:.1f}%). "
        if "LVEDV" in avg_measurements:
            summary += f"Left ventricular end-diastolic volume is {avg_measurements['LVEDV']:.1f} mL. "
        if "GLS" in avg_measurements:
            summary += f"Global longitudinal strain is {avg_measurements['GLS']:.1f}%."

        recommendations = []
        if ef < 50:
            recommendations.append("Consider cardiology consultation")
        # NOTE(review): GLS is conventionally negative and *less* negative
        # values indicate impairment; this threshold direction (< -18 flags
        # monitoring) looks inverted — confirm intended clinical logic.
        if avg_measurements.get("GLS", 0) < -18:
            recommendations.append("Monitor for heart failure")
        if not recommendations:
            recommendations.append("Routine follow-up in 1 year")

        sections = []
        if include_sections:
            sections = ["findings", "measurements", "view_analysis", "recommendations"]

        report = {
            "summary": summary,
            "recommendations": recommendations,
            "sections": sections,
            "measurements": {k: f"{v:.1f}" for k, v in avg_measurements.items()},
            "view_distribution": view_distribution,
            "processed_videos": len(analyses),
            "overall_confidence": np.mean([a.get("confidence", 0) for a in analyses])
        }

        return report

    def _create_view_visualization(self, analyses, input_dir):
        """Render a pie chart of the view distribution into *input_dir*.

        Returns the PNG path on success, or None when plotting fails
        (e.g. matplotlib is unavailable).

        NOTE(review): not called by ``_run`` in the current implementation.
        """
        try:
            import matplotlib.pyplot as plt

            view_counts = {}
            for analysis in analyses:
                view = analysis.get("view_classification", {}).get("predicted_view", "unknown")
                view_counts[view] = view_counts.get(view, 0) + 1

            plt.figure(figsize=(8, 6))
            plt.pie(view_counts.values(), labels=view_counts.keys(), autopct='%1.1f%%')
            plt.title("Echo View Distribution")

            output_path = Path(input_dir) / "view_distribution.png"
            plt.savefig(output_path, dpi=300, bbox_inches='tight')
            plt.close()

            return str(output_path)

        except Exception as e:
            # Visualization is best-effort; failures must not break reporting.
            print(f"Visualization creation failed: {e}")
            return None
| |
|
| |
|
| | class EchoSegmentationTool(BaseTool): |
| | """Echo segmentation tool.""" |
| | |
| | name: str = "echo_segmentation" |
| | description: str = "Segment cardiac chambers in echo videos using MedSAM2." |
| | args_schema: Type[BaseModel] = EchoSegmentationInput |
| | |
    def _run(
        self,
        video_path: str,
        prompt_mode: str = "auto",
        target_name: str = "all",
        save_mask_video: bool = True,
        save_overlay_video: bool = True,
        points: Optional[List[List[float]]] = None,
        box: Optional[List[float]] = None,
        mask_path: Optional[str] = None,
        sample_rate: int = 1,
        output_fps: Optional[int] = None,
        progress_callback: Optional[callable] = None,
        # First-frame prompt-mask inputs (optional alternatives to points/box).
        initial_masks_dir: Optional[str] = None,
        initial_mask_paths: Optional[Dict[str, str]] = None,
        initial_mask_frame_idx: int = 0,
        use_auto_masks_if_missing: bool = True,
        run_manager: Optional[Any] = None,
        query: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run echo segmentation using real MedSAM2 model.

        Pipeline: normalize user prompts → load the MedSAM2 checkpoint →
        read all frames from ``video_path`` → resolve first-frame prompt
        masks (explicit args, then config annotations, then a configured
        default) → run enhanced multi-structure segmentation (falling back
        to the basic segmenter on failure) → write mask/overlay videos.

        Returns a dict with ``status`` ("success"/"error") and, on success,
        an ``outputs`` dict holding video paths, mask/frame counts and any
        structure metadata. NOTE(review): ``prompt_mode``, ``mask_path``,
        ``sample_rate``, ``output_fps``, ``initial_mask_frame_idx``,
        ``use_auto_masks_if_missing`` and ``query`` are accepted but not
        used in this body — confirm whether they are consumed elsewhere.
        """
        try:
            # Coerce user-supplied click prompts into (x, y, label) tuples,
            # silently dropping malformed entries.
            normalized_points: Optional[List[Tuple[float, float, int]]] = None
            if points:
                normalized_points = []
                for entry in points:
                    if not isinstance(entry, (list, tuple)) or len(entry) < 3:
                        continue
                    try:
                        x_val = float(entry[0])
                        y_val = float(entry[1])
                        label_val = int(entry[2])
                        normalized_points.append((x_val, y_val, label_val))
                    except (TypeError, ValueError):
                        continue
                if not normalized_points:
                    normalized_points = None

            # Coerce a bounding-box prompt into a 4-float tuple, or None.
            # NOTE(review): normalized_points/normalized_box are computed but
            # not passed to the segmenter below — confirm intended.
            normalized_box: Optional[Tuple[float, float, float, float]] = None
            if box and isinstance(box, (list, tuple)) and len(box) >= 4:
                try:
                    normalized_box = (
                        float(box[0]),
                        float(box[1]),
                        float(box[2]),
                        float(box[3]),
                    )
                except (TypeError, ValueError):
                    normalized_box = None

            # Resolve the MedSAM2 checkpoint path (cached by the loader).
            medsam2_model_path = load_medsam2_model()

            # Read every frame of the input video into memory (BGR).
            cap = cv2.VideoCapture(video_path)
            masks = []
            frames = []
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 1e-3:
                fps = 30.0  # fall back when the container reports no rate

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)

            cap.release()

            if not frames:
                raise RuntimeError(f"No frames found in video: {video_path}")

            # Enhanced multi-structure segmentation with a basic fallback.
            try:
                frames_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]

                height, width = frames[0].shape[:2]
                # Prompt-mask resolution, in priority order:
                # 1) explicit dir/paths arguments,
                # 2) config-declared annotations for this video,
                # 3) a configured default single-structure mask.
                provided_masks = None
                if (initial_masks_dir or initial_mask_paths):
                    provided_masks = self._load_initial_masks(
                        height,
                        width,
                        initial_masks_dir=initial_masks_dir,
                        initial_mask_paths=initial_mask_paths,
                    )

                if provided_masks:
                    print(f"✅ Using {len(provided_masks)} annotation masks: {list(provided_masks.keys())}")
                else:
                    annotation_prompts = self._load_annotation_prompts_from_config(
                        height,
                        width,
                        video_path,
                    )
                    if annotation_prompts:
                        provided_masks = annotation_prompts
                        print(f"✅ Loaded config-based annotation masks: {list(provided_masks.keys())}")
                    else:
                        print("⚠️ No annotation masks found for video; falling back to auto prompts")

                if not provided_masks:
                    # Last resort: a single default mask declared in Config.
                    # Best-effort — any failure here is deliberately ignored.
                    try:
                        from config import Config
                        default_path = getattr(Config, 'DEFAULT_INITIAL_MASK_PATH', '')
                        default_structure = getattr(Config, 'DEFAULT_INITIAL_MASK_STRUCTURE', 'LV')
                        if isinstance(default_path, str) and default_path:
                            import os as _os
                            if _os.path.exists(default_path):
                                provided_masks = self._load_initial_masks(
                                    height,
                                    width,
                                    initial_mask_paths={str(default_structure).upper(): default_path},
                                )
                                print(f"⚠️ Using default single-structure mask for {default_structure}")
                    except Exception:
                        pass

                segmentation_result = self._segment_with_medsam2(
                    frames_rgb,
                    medsam2_model_path,
                    progress_callback,
                    initial_masks=provided_masks,
                )

                # Flatten the per-structure masks into one combined binary
                # mask per frame (resized to the frame resolution).
                masks = []
                for frame_idx in range(len(frames)):
                    combined = np.zeros((height, width), dtype=np.uint8)
                    frame_masks = segmentation_result['masks'].get(frame_idx, {})
                    for obj_mask in frame_masks.values():
                        mask_array = obj_mask
                        if mask_array.shape != (height, width):
                            mask_array = cv2.resize(mask_array, (width, height), interpolation=cv2.INTER_NEAREST)
                        combined = np.maximum(combined, mask_array)
                    masks.append(combined)

                # Kept for the enhanced-video step further below.
                self._enhanced_segmentation_result = segmentation_result

            except Exception as e:
                print(f"Error in enhanced MedSAM2 video segmentation: {e}")
                import traceback
                traceback.print_exc()
                # Fallback: basic single-target segmenter.
                # NOTE(review): if cvtColor above raised, frames_rgb is
                # unbound here and this fallback raises NameError — confirm.
                try:
                    from tools.echo.medsam2_integration import MedSAM2VideoSegmenter
                    segmenter = MedSAM2VideoSegmenter(medsam2_model_path)
                    masks = segmenter.segment_video(frames_rgb, target_name, progress_callback)
                    self._enhanced_segmentation_result = None
                except Exception as e2:
                    # Both segmenters failed — surface the original error.
                    raise

        except Exception as e:
            print(f"Error loading MedSAM2 model or processing video: {e}")
            return {
                "status": "error",
                "error": str(e),
                "video_path": video_path,
                "target_name": target_name
            }

        # Write requested videos under temp/segmentation_outputs.
        outputs = {}
        if save_mask_video or save_overlay_video:
            output_dir = Path("temp") / "segmentation_outputs"
            output_dir.mkdir(parents=True, exist_ok=True)

            if save_mask_video:
                mask_video_path = output_dir / f"mask_{target_name}_{Path(video_path).stem}.mp4"
                mask_video_path = self._save_mask_video(frames, masks, str(mask_video_path))
                outputs["mask_video"] = mask_video_path

            if save_overlay_video:
                overlay_video_path = output_dir / f"overlay_{target_name}_{Path(video_path).stem}.mp4"
                overlay_video_path = self._save_overlay_video(frames, masks, str(overlay_video_path))
                outputs["overlay_video"] = overlay_video_path

        # Multi-structure combined overlay (only when the enhanced path ran).
        enhanced_outputs = {}
        if hasattr(self, '_enhanced_segmentation_result') and self._enhanced_segmentation_result:
            try:
                enhanced_outputs = self._create_enhanced_videos(frames, self._enhanced_segmentation_result, output_dir, fps=fps)
            except Exception as e:
                print(f"Error creating enhanced videos: {e}")
                enhanced_outputs = {}

        final_outputs = {
            **outputs,
            **enhanced_outputs,
            "masks": len(masks),
            "frames_processed": len(frames)
        }

        # Back-compat aliases expected by downstream consumers.
        if "mask_video" in outputs:
            final_outputs["segmented_video"] = outputs["mask_video"]
        if "overlay_video" in outputs:
            final_outputs["overlay_video"] = outputs["overlay_video"]

        # For "all" targets, prefer the combined multi-structure video.
        if target_name == "all" and "combined_segmentation_video" in enhanced_outputs:
            final_outputs["segmented_video"] = enhanced_outputs["combined_segmentation_video"]
            final_outputs["overlay_video"] = enhanced_outputs["combined_segmentation_video"]

        # Surface structure metadata without clobbering existing keys.
        if hasattr(self, '_enhanced_segmentation_result') and self._enhanced_segmentation_result:
            final_outputs.setdefault(
                "structures",
                self._enhanced_segmentation_result.get("structures", []),
            )
            final_outputs.setdefault(
                "structure_info",
                self._enhanced_segmentation_result.get("structure_info", {}),
            )

        return {
            "status": "success",
            "model": "MedSAM2",
            "video_path": video_path,
            "target_name": target_name,
            "prompt_mode": prompt_mode,
            "outputs": final_outputs,
            "message": f"Enhanced segmentation completed with {len(masks)} frames processed using MedSAM2 model"
        }
| |
|
| | def _segment_with_medsam2(self, frames, model_path, progress_callback=None, initial_masks: Optional[Dict[str, np.ndarray]] = None): |
| | """Segment video using enhanced MedSAM2 model with multi-structure support.""" |
| | try: |
| | |
| | import sys |
| | import os |
| | current_dir = os.path.dirname(os.path.abspath(__file__)) |
| | project_root = os.path.dirname(os.path.dirname(os.path.dirname(current_dir))) |
| | sys.path.insert(0, project_root) |
| | |
| | from tools.echo.enhanced_medsam2_integration import EnhancedMedSAM2VideoSegmenter |
| | |
| | print("✅ Successfully imported EnhancedMedSAM2VideoSegmenter") |
| | if progress_callback: |
| | progress_callback(10, "Initializing enhanced MedSAM2 model...") |
| | |
| | |
| | segmenter = EnhancedMedSAM2VideoSegmenter(model_path) |
| | |
| | if progress_callback: |
| | progress_callback(20, "Starting multi-structure video segmentation...") |
| |
|
| | |
| | result = segmenter.segment_video_multi_structure(frames, progress_callback, initial_masks=initial_masks) |
| |
|
| | print(f"✅ Generated multi-structure masks for {result['total_frames']} frames") |
| | print(f"🎯 Segmented structures: {result['structures']}") |
| | if progress_callback: |
| | progress_callback(100, "Multi-structure segmentation completed!") |
| | |
| | return result |
| | |
| | except Exception as e: |
| | print(f"❌ Enhanced MedSAM2 segmentation error: {e}") |
| | import traceback |
| | traceback.print_exc() |
| | if progress_callback: |
| | progress_callback(0, f"Segmentation failed: {e}") |
| | raise RuntimeError(f"Enhanced MedSAM2 segmentation failed: {e}") |
| |
|
| | def _load_initial_masks( |
| | self, |
| | height: int, |
| | width: int, |
| | initial_masks_dir: Optional[str] = None, |
| | initial_mask_paths: Optional[Dict[str, str]] = None, |
| | ) -> Dict[str, np.ndarray]: |
| | """Load first-frame masks from a directory or explicit mapping. |
| | |
| | Supported structure keys: 'LV','MYO','LA','RV','RA' (others ignored for now). |
| | Images are read as grayscale; any non-zero treated as foreground; resized to (width,height). |
| | """ |
| | import os |
| | import glob |
| | valid_structures = {"LV", "MYO", "LA", "RV", "RA"} |
| | paths: Dict[str, str] = {} |
| |
|
| | def read_mask(path: str) -> Optional[np.ndarray]: |
| | try: |
| | if path.lower().endswith(".npy"): |
| | arr = np.load(path) |
| | if arr.ndim == 3: |
| | arr = arr.squeeze() |
| | arr = (arr > 0).astype(np.uint8) * 255 |
| | else: |
| | img = cv2.imread(path, cv2.IMREAD_GRAYSCALE) |
| | if img is None: |
| | return None |
| | arr = (img > 0).astype(np.uint8) * 255 |
| | if arr.shape != (height, width): |
| | arr = cv2.resize(arr, (width, height), interpolation=cv2.INTER_NEAREST) |
| | return arr |
| | except Exception: |
| | return None |
| |
|
| | |
| | if initial_mask_paths: |
| | for k, v in initial_mask_paths.items(): |
| | key = str(k).upper() |
| | if key in valid_structures and isinstance(v, str) and os.path.exists(v): |
| | paths[key] = v |
| |
|
| | |
| | if initial_masks_dir and os.path.isdir(initial_masks_dir): |
| | candidates = {} |
| | for ext in ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.tif", "*.tiff", "*.npy"): |
| | for p in glob.glob(os.path.join(initial_masks_dir, ext)): |
| | name = os.path.splitext(os.path.basename(p))[0].lower() |
| | |
| | for s in valid_structures: |
| | s_lower = s.lower() |
| | if name == s_lower or name.startswith(s_lower) or name.endswith(s_lower): |
| | candidates.setdefault(s, p) |
| | for s, p in candidates.items(): |
| | paths.setdefault(s, p) |
| |
|
| | loaded: Dict[str, np.ndarray] = {} |
| | for s, p in paths.items(): |
| | mask = read_mask(p) |
| | if mask is not None and mask.any(): |
| | loaded[s] = mask |
| | return loaded |
| |
|
    def _load_annotation_prompts_from_config(self, height: int, width: int, video_path: str) -> Dict[str, np.ndarray]:
        """Load annotation-derived first-frame masks using Config.ANNOTATION_PROMPTS.

        Looks up the video (by stem, name, absolute path, or the original
        name from ECHO_ORIGINAL_VIDEO_NAME) in the config mapping, reads the
        referenced annotation frame, and splits it into per-structure binary
        masks via the entry's ``label_map``. Returns ``{}`` whenever the
        config, entry, or annotation image is unavailable.
        """
        try:
            from config import Config
        except Exception:
            # No project config importable — no annotation prompts.
            return {}

        mapping = getattr(Config, "ANNOTATION_PROMPTS", {}) or {}
        if not mapping:
            return {}

        # Build lookup keys for this video: stem, filename, absolute path,
        # plus the original upload name (and its stem) when provided via env.
        candidates = []
        from pathlib import Path
        stem = Path(video_path).stem
        name = Path(video_path).name
        candidates.extend([stem, name, os.path.abspath(video_path)])

        original_name = os.getenv("ECHO_ORIGINAL_VIDEO_NAME")
        if original_name:
            candidates.append(original_name)
            candidates.append(Path(original_name).stem)

        print(f"🔍 Annotation lookup candidates: {candidates}")
        print(f"🔍 Available annotation entries: {list(mapping.keys())}")

        # First candidate that matches a mapping key wins.
        entry = None
        for key in candidates:
            if key in mapping:
                entry = mapping[key]
                break
        if entry is None:
            return {}

        frames_dir = entry.get("frames_dir")
        frame_index = int(entry.get("frame_index", 0))
        label_map = entry.get("label_map", {})

        if not frames_dir:
            return {}

        # frames_dir may be a directory of numbered PNGs or a single file.
        if os.path.isdir(frames_dir):
            frame_path = os.path.join(frames_dir, f"{frame_index:04d}.png")
        else:
            frame_path = frames_dir

        if not os.path.exists(frame_path):
            print(f"⚠️ Annotation prompt not found: {frame_path}")
            return {}

        mask_img = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
        if mask_img is None:
            print(f"⚠️ Failed to read annotation prompt: {frame_path}")
            return {}

        # Match the working resolution (nearest-neighbor keeps labels intact).
        if mask_img.shape != (height, width):
            mask_img = cv2.resize(mask_img, (width, height), interpolation=cv2.INTER_NEAREST)

        loaded: Dict[str, np.ndarray] = {}
        if label_map:
            # Each label value selects one structure's binary mask.
            for raw_value, structure in label_map.items():
                try:
                    value = int(raw_value)
                except Exception:
                    continue
                structure_key = str(structure).upper()
                mask = (mask_img == value).astype(np.uint8) * 255
                if mask.any():
                    loaded[structure_key] = mask
        else:
            # No label map: treat any non-zero pixel as a single LV mask.
            mask = (mask_img > 0).astype(np.uint8) * 255
            if mask.any():
                loaded["LV"] = mask

        if loaded:
            print(f"✅ Loaded annotation prompts for {video_path} from {frame_path}")

        return loaded
| | |
| | def _create_enhanced_videos(self, frames, segmentation_result, output_dir, fps: float = 30.0): |
| | """Create overlay video showing all segmented structures.""" |
| | try: |
| | structures = segmentation_result['structures'] |
| | structure_info = segmentation_result['structure_info'] |
| | all_masks = segmentation_result['masks'] |
| |
|
| | combined_video_path = output_dir / "combined_segmentation_video.avi" |
| | combined_final_path = self._save_combined_overlay_video( |
| | frames, |
| | all_masks, |
| | structures, |
| | structure_info, |
| | str(combined_video_path), |
| | fps=fps, |
| | ) |
| |
|
| | combined_final_path = convert_video_to_h264(str(combined_final_path)) |
| |
|
| | return {"combined_segmentation_video": str(combined_final_path)} |
| |
|
| | except Exception as e: |
| | print(f"❌ Error creating enhanced videos: {e}") |
| | return {} |
| |
|
    def _save_combined_overlay_video(self, frames, all_masks, structures, structure_info, output_path, fps: float = 30.0) -> str:
        """Save a single AVI where all structures are overlaid in unique colors.

        Args:
            frames: BGR frames of the source video.
            all_masks: ``{frame_index: {obj_id: mask}}`` per-frame masks.
            structures: Ordered structure ids; assumes obj_id is 1-based
                into this list — TODO confirm against the segmenter.
            structure_info: Per-structure metadata; ``color`` is used here.
            output_path: Target path; forced to a ``.avi`` extension.
            fps: Output frame rate.

        Returns:
            The final (possibly extension-adjusted) output path.
        """
        if not frames or not all_masks:
            return output_path

        height, width = frames[0].shape[:2]
        # XVID writing requires an .avi container.
        final_path = output_path if output_path.lower().endswith(".avi") else os.path.splitext(output_path)[0] + ".avi"

        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        writer = cv2.VideoWriter(final_path, fourcc, fps, (width, height))

        for frame_index, frame in enumerate(frames):
            if frame_index in all_masks:
                base_frame = frame.copy()
                color_layer = np.zeros_like(frame)
                contour_masks = []
                for obj_id, mask in all_masks[frame_index].items():
                    if obj_id <= len(structures):
                        # Map the (1-based) object id to its structure color.
                        structure_id = structures[obj_id - 1]
                        color = structure_info[structure_id]['color']

                        if mask.shape != (height, width):
                            mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

                        mask_bool = mask > 0
                        if not np.any(mask_bool):
                            continue
                        # Later structures overwrite earlier ones on overlap.
                        color_layer[mask_bool] = color
                        contour_masks.append(mask_bool.astype(np.uint8))

                # Blend the color layer onto the frame, then draw white
                # contours so structure boundaries stay visible.
                overlay = cv2.addWeighted(base_frame, 0.6, color_layer, 0.4, 0)
                for mask_bool in contour_masks:
                    contours, _ = cv2.findContours(mask_bool, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                    cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2)

                writer.write(overlay)
            else:
                # No masks for this frame: pass the original frame through.
                writer.write(frame)

        writer.release()
        print(f"✅ Saved combined overlay video: {final_path}")
        return final_path
| |
|
| | def _overlay_mask(self, frame: np.ndarray, mask: np.ndarray, color=(0, 255, 0), alpha=0.35): |
| | """Create overlay with proper alpha blending and contour visualization.""" |
| | overlay = frame.copy() |
| | |
| | |
| | while len(mask.shape) > 2: |
| | mask = mask.squeeze() |
| | |
| | |
| | if mask.shape != frame.shape[:2]: |
| | mask = cv2.resize(mask, (frame.shape[1], frame.shape[0]), interpolation=cv2.INTER_NEAREST) |
| | |
| | |
| | binary_mask = (mask > 0).astype(np.uint8) |
| | |
| | |
| | colored_overlay = np.zeros_like(frame) |
| | colored_overlay[binary_mask > 0] = color |
| | |
| | |
| | overlay = cv2.addWeighted(overlay, 1 - alpha, colored_overlay, alpha, 0) |
| | |
| | |
| | contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| | cv2.drawContours(overlay, contours, -1, (255, 255, 255), 2) |
| | |
| | return overlay |
| | |
| | def _save_mask_video(self, frames, masks, output_path): |
| | """Save mask video and return a browser-friendly H.264 path.""" |
| | if not frames or not masks: |
| | return output_path |
| | fourcc = cv2.VideoWriter_fourcc(*'mp4v') |
| | out = cv2.VideoWriter(output_path, fourcc, 30.0, (frames[0].shape[1], frames[0].shape[0])) |
| | |
| | for mask in masks: |
| | |
| | mask_3ch = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) |
| | |
| | |
| | if mask.max() > 0: |
| | contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) |
| | cv2.drawContours(mask_3ch, contours, -1, (0, 255, 0), 2) |
| | |
| | out.write(mask_3ch) |
| | |
| | out.release() |
| | converted_path = convert_video_to_h264(output_path) |
| | print(f"✅ Saved mask video: {converted_path}") |
| | return converted_path |
| |
|
| | def _save_overlay_video(self, frames, masks, output_path): |
| | """Save overlay video with corrected overlay logic and return H.264 path.""" |
| | if not frames or not masks: |
| | return output_path |
| | fourcc = cv2.VideoWriter_fourcc(*'mp4v') |
| | out = cv2.VideoWriter(output_path, fourcc, 30.0, (frames[0].shape[1], frames[0].shape[0])) |
| | |
| | for frame, mask in zip(frames, masks): |
| | overlay = self._overlay_mask(frame, mask, color=(0, 255, 0), alpha=0.35) |
| | out.write(overlay) |
| | |
| | out.release() |
| | converted_path = convert_video_to_h264(output_path) |
| | print(f"✅ Saved overlay video: {converted_path}") |
| | return converted_path |
| | |
| | def segment_all_structures(self, video_path: str, output_dir: str, progress_callback=None) -> Dict[str, Any]: |
| | """Segment all cardiac structures in the video.""" |
| | cardiac_structures = [ |
| | "LV", |
| | "RV", |
| | "LA", |
| | "RA", |
| | "MV", |
| | "TV", |
| | "AV", |
| | "PV", |
| | "IVS", |
| | "LVPW", |
| | "AORoot", |
| | "PA", |
| | ] |
| | |
| | results = { |
| | "status": "success", |
| | "segmented_structures": {}, |
| | "combined_video": None, |
| | "individual_videos": {}, |
| | "overlay_videos": {} |
| | } |
| | |
| | try: |
| | |
| | import os |
| | os.makedirs(output_dir, exist_ok=True) |
| | |
| | total_structures = len(cardiac_structures) |
| | |
| | for i, structure in enumerate(cardiac_structures): |
| | if progress_callback: |
| | progress_callback( |
| | int((i / total_structures) * 100), |
| | f"Segmenting {structure}..." |
| | ) |
| | |
| | try: |
| | |
| | structure_result = self._run( |
| | video_path=video_path, |
| | target_name=structure, |
| | save_mask_video=True, |
| | save_overlay_video=True, |
| | sample_rate=2, |
| | progress_callback=lambda p, msg: None |
| | ) |
| | |
| | |
| | if structure_result.get("status") == "success": |
| | if "mask_video" in structure_result: |
| | old_mask_path = structure_result["mask_video"] |
| | new_mask_path = os.path.join(output_dir, f"mask_{structure}.mp4") |
| | if os.path.exists(old_mask_path): |
| | import shutil |
| | shutil.move(old_mask_path, new_mask_path) |
| | structure_result["mask_video"] = convert_video_to_h264(new_mask_path) |
| | |
| | if "overlay_video" in structure_result: |
| | old_overlay_path = structure_result["overlay_video"] |
| | new_overlay_path = os.path.join(output_dir, f"overlay_{structure}.mp4") |
| | if os.path.exists(old_overlay_path): |
| | import shutil |
| | shutil.move(old_overlay_path, new_overlay_path) |
| | structure_result["overlay_video"] = convert_video_to_h264(new_overlay_path) |
| | |
| | if structure_result.get("status") == "success": |
| | results["segmented_structures"][structure] = structure_result |
| | |
| | |
| | if "mask_video" in structure_result: |
| | results["individual_videos"][structure] = structure_result["mask_video"] |
| | if "overlay_video" in structure_result: |
| | results["overlay_videos"][structure] = structure_result["overlay_video"] |
| | |
| | except Exception as e: |
| | print(f"❌ Failed to segment {structure}: {e}") |
| | results["segmented_structures"][structure] = { |
| | "status": "failed", |
| | "error": str(e) |
| | } |
| | |
| | |
| | if progress_callback: |
| | progress_callback(90, "Creating combined segmentation...") |
| | |
| | results["combined_video"] = self._create_combined_segmentation( |
| | results["individual_videos"], |
| | output_dir |
| | ) |
| | |
| | if progress_callback: |
| | progress_callback(100, "Segmentation completed!") |
| | |
| | return results |
| | |
| | except Exception as e: |
| | results["status"] = "failed" |
| | results["error"] = str(e) |
| | return results |
| | |
    def _create_combined_segmentation(self, individual_videos: Dict[str, str], output_dir: str) -> Optional[str]:
        """Create a combined video showing all segmented structures.

        Reads each per-structure mask video in lockstep, tints every
        structure's non-black pixels with a fixed color, stacks the layers
        additively, and labels the visible structures in a legend. Returns
        the H.264-converted combined video path, or ``None`` if there is
        nothing to combine or on any failure.
        """
        try:
            if not individual_videos:
                return None

            import numpy as np

            # Use the first video to establish output dimensions and fps.
            # NOTE(review): fps may be 0 if the container reports none —
            # unlike _run, there is no 30.0 fallback here; confirm.
            first_video = list(individual_videos.values())[0]
            cap = cv2.VideoCapture(first_video)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            fps = cap.get(cv2.CAP_PROP_FPS)
            cap.release()

            output_path = os.path.join(output_dir, "combined_segmentation.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

            # Fixed BGR color per structure (legend uses the same palette).
            colors = {
                "LV": (0, 255, 0),
                "RV": (255, 0, 0),
                "LA": (0, 255, 255),
                "RA": (255, 0, 255),
                "MV": (255, 255, 0),
                "TV": (128, 0, 128),
                "AV": (255, 165, 0),
                "PV": (0, 128, 128),
                "IVS": (128, 128, 0),
                "LVPW": (128, 0, 0),
                "AORoot": (0, 128, 0),
                "PA": (0, 0, 128),
            }

            # Open one capture per structure video that exists on disk.
            video_caps = {}
            for structure, video_path in individual_videos.items():
                if os.path.exists(video_path):
                    video_caps[structure] = cv2.VideoCapture(video_path)

            if not video_caps:
                return None

            # Advance all captures together; stop once every one is drained.
            frame_count = 0
            while True:
                frames = {}
                all_done = True

                for structure, cap in video_caps.items():
                    ret, frame = cap.read()
                    if ret:
                        frames[structure] = frame
                        all_done = False
                    else:
                        frames[structure] = None

                if all_done:
                    break

                # Compose this time step on a black canvas.
                combined_frame = np.zeros((height, width, 3), dtype=np.uint8)

                for structure, frame in frames.items():
                    if frame is not None:
                        # Treat any non-black pixel of the mask video as
                        # foreground. NOTE(review): mask videos also carry
                        # green contour lines, which are tinted too — confirm
                        # this is acceptable.
                        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                        color = colors.get(structure, (255, 255, 255))

                        colored_mask = np.zeros_like(combined_frame)
                        colored_mask[gray > 0] = color

                        # Additive blend so overlapping structures mix.
                        combined_frame = cv2.addWeighted(combined_frame, 1.0, colored_mask, 0.7, 0)

                # Legend: list visible structures in a fixed order.
                y_offset = 30
                for structure in ["LV", "RV", "LA", "RA", "MV", "TV", "AV", "PV", "IVS", "LVPW", "AORoot", "PA"]:
                    if structure in frames and frames[structure] is not None:
                        color = colors.get(structure, (255, 255, 255))
                        cv2.putText(combined_frame, structure, (10, y_offset),
                                   cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)
                        y_offset += 25

                out.write(combined_frame)
                frame_count += 1

            for cap in video_caps.values():
                cap.release()
            out.release()

            # Only convert/report when at least one frame was written.
            if frame_count > 0:
                return convert_video_to_h264(output_path)
            return None

        except Exception as e:
            print(f"❌ Failed to create combined segmentation: {e}")
            return None
| |
|
| |
|
class EchoViewClassificationTool(BaseTool):
    """Echo view classification tool.

    Wraps the EchoPrime model to classify the echocardiographic view of
    every MP4 found in an input directory.
    """

    name: str = "echo_view_classification"
    description: str = "Classify echocardiography video views using EchoPrime."
    args_schema: Type[BaseModel] = EchoViewClassificationInput

    def _run(
        self,
        input_dir: str,
        visualize: bool = False,
        max_videos: Optional[int] = None,
        run_manager: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run echo view classification using real EchoPrime model.

        Args:
            input_dir: Directory containing MP4 echo videos.
            visualize: Forwarded to the model's view extraction.
            max_videos: Optional cap on how many videos are classified.
            run_manager: Unused langchain callback hook.

        Returns:
            Dict with per-video classifications, per-view aggregates and a
            status message.

        Raises:
            RuntimeError: If no videos are found or classification fails.
        """
        try:
            # Cached model loader (see load_echo_prime_model).
            echo_prime_model = load_echo_prime_model()

            # Load all MP4s in the directory into one video stack.
            print(f"🔄 Processing videos from {input_dir}...")
            stack_of_videos = echo_prime_model.process_mp4s(input_dir)

            if stack_of_videos.shape[0] == 0:
                raise RuntimeError(f"No valid MP4 videos found in {input_dir}")

            # Optionally limit how many videos are classified.
            if max_videos and stack_of_videos.shape[0] > max_videos:
                stack_of_videos = stack_of_videos[:max_videos]

            print(f"✅ Processed {stack_of_videos.shape[0]} videos")

            # Predict one view label per video.
            print("🔄 Classifying views...")
            view_encodings = echo_prime_model.get_views(stack_of_videos, visualize=visualize, return_view_list=True)

            # NOTE(review): the 0.85/0.15 confidence values below are
            # hardcoded placeholders, not model probabilities — confirm
            # before surfacing them as real confidences. Video names are
            # synthesized (video_1.mp4, ...), not the actual filenames.
            all_classifications = []
            for i, view in enumerate(view_encodings):
                classification = {
                    "video": f"video_{i+1}.mp4",
                    "predicted_view": view,
                    "confidence": 0.85,
                    "view_probabilities": {
                        view: 0.85,
                        "other": 0.15
                    }
                }
                all_classifications.append(classification)

            if not all_classifications:
                raise RuntimeError("No videos processed successfully")

            # Aggregate counts and max confidence per predicted view.
            view_counts = {}
            for classification in all_classifications:
                view = classification["predicted_view"]
                if view not in view_counts:
                    view_counts[view] = {"count": 0, "confidence": 0.0}
                view_counts[view]["count"] += 1
                view_counts[view]["confidence"] = max(
                    view_counts[view]["confidence"],
                    classification["confidence"]
                )

            return {
                "status": "success",
                "model": "EchoPrime",
                "input_dir": input_dir,
                "max_videos": max_videos,
                "processed_videos": len(all_classifications),
                "classifications": view_counts,
                "detailed_results": all_classifications,
                "message": f"View classification completed for {len(all_classifications)} videos using real EchoPrime model"
            }

        except Exception as e:
            # Re-wrap every failure as RuntimeError for the caller.
            print(f"EchoPrime view classification failed: {e}")
            raise RuntimeError(f"View classification failed: {e}")
| |
|
| |
|
class EchoDiseasePredictionManager(BaseToolManager):
    """Manager for echo disease prediction tool.

    Thin lifecycle wrapper: builds the tool, tracks availability status,
    and forwards run requests to the underlying tool instance.
    """

    def __init__(self, model_manager=None):
        self.model_manager = model_manager
        super().__init__(
            ToolConfig(
                name="echo_disease_prediction",
                tool_type="disease_prediction",
                description="Echo disease prediction tool",
            )
        )
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the disease prediction tool."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the disease prediction tool."""
        return EchoDiseasePredictionTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same implementation as the primary)."""
        return EchoDiseasePredictionTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the disease prediction tool, returning an error dict on failure."""
        if not self.tool:
            return {"error": "Tool not available"}

        try:
            return self.tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
| |
|
| |
|
class EchoImageVideoGenerationManager(BaseToolManager):
    """Manager for echo image/video generation tool.

    Thin lifecycle wrapper: builds the tool, tracks availability status,
    and forwards run requests to the underlying tool instance.
    """

    def __init__(self, model_manager=None):
        self.model_manager = model_manager
        super().__init__(
            ToolConfig(
                name="echo_image_video_generation",
                tool_type="generation",
                description="Echo image/video generation tool",
            )
        )
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the image/video generation tool."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the image/video generation tool."""
        return EchoImageVideoGenerationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same implementation as the primary)."""
        return EchoImageVideoGenerationTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the image/video generation tool, returning an error dict on failure."""
        if not self.tool:
            return {"error": "Tool not available"}

        try:
            return self.tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
| |
|
| |
|
class EchoMeasurementPredictionManager(BaseToolManager):
    """Manager for echo measurement prediction tool.

    Thin lifecycle wrapper: builds the tool, tracks availability status,
    and forwards run requests to the underlying tool instance.
    """

    def __init__(self, model_manager=None):
        self.model_manager = model_manager
        super().__init__(
            ToolConfig(
                name="echo_measurement_prediction",
                tool_type="measurement",
                description="Echo measurement prediction tool",
            )
        )
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the measurement prediction tool."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the measurement prediction tool."""
        return EchoMeasurementPredictionTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same implementation as the primary)."""
        return EchoMeasurementPredictionTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the measurement prediction tool, returning an error dict on failure."""
        if not self.tool:
            return {"error": "Tool not available"}

        try:
            return self.tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
| |
|
| |
|
class EchoReportGenerationManager(BaseToolManager):
    """Manager for echo report generation tool.

    Thin lifecycle wrapper: builds the tool, tracks availability status,
    and forwards run requests to the underlying tool instance.
    """

    def __init__(self, model_manager=None):
        self.model_manager = model_manager
        super().__init__(
            ToolConfig(
                name="echo_report_generation",
                tool_type="report",
                description="Echo report generation tool",
            )
        )
        self._initialize_tool()

    def _initialize_tool(self):
        """Initialize the report generation tool."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the report generation tool."""
        return EchoReportGenerationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same implementation as the primary)."""
        return EchoReportGenerationTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the report generation tool, returning an error dict on failure."""
        if not self.tool:
            return {"error": "Tool not available"}

        try:
            return self.tool._run(**input_data)
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
| |
|
| |
|
class EchoSegmentationManager(BaseToolManager):
    """Manager for echo segmentation tool."""

    def __init__(self, model_manager=None):
        """Register the segmentation tool config and construct the tool."""
        self.model_manager = model_manager
        tool_config = ToolConfig(
            name="echo_segmentation",
            tool_type="segmentation",
            description="Echo segmentation tool",
        )
        super().__init__(tool_config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Build the segmentation tool and record its availability."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as exc:
            print(f"Error initializing {self.config.name}: {exc}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Instantiate and return the segmentation tool."""
        return EchoSegmentationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Return the fallback tool (same implementation as the primary)."""
        return EchoSegmentationTool()

    @staticmethod
    def _inject_default_masks(prepared_inputs: Dict[str, Any]) -> None:
        """Add default per-structure mask prompts when the caller gave none.

        Mutates ``prepared_inputs`` in place; prints a warning (and leaves the
        inputs untouched) when the default mask assets are missing.
        """
        prompt_keys = ("initial_masks_dir", "initial_mask_paths", "mask_path", "points", "box")
        if any(prepared_inputs.get(key) for key in prompt_keys):
            # Caller supplied explicit prompts — nothing to inject.
            return
        if not DEFAULT_ECHO_SEGMENTATION_MASK.exists():
            print(
                f"⚠️ Default segmentation mask not found at {DEFAULT_ECHO_SEGMENTATION_MASK}; "
                "falling back to configured prompt."
            )
            return
        if not DEFAULT_ECHO_SEGMENTATION_MASK_DIR.exists():
            print(
                f"⚠️ Default mask directory not found at {DEFAULT_ECHO_SEGMENTATION_MASK_DIR}; "
                "consider generating per-structure masks."
            )
            return
        mask_dir = DEFAULT_ECHO_SEGMENTATION_MASK_DIR
        available_paths = {
            structure: str(mask_dir / filename)
            for structure, filename in DEFAULT_ECHO_SEGMENTATION_STRUCTURES.items()
            if (mask_dir / filename).exists()
        }
        if available_paths:
            # setdefault: never overwrite prompts a caller might add later.
            prepared_inputs.setdefault("initial_mask_paths", available_paths)

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the segmentation tool, supplying default mask prompts if needed."""
        if not self.tool:
            return {"error": "Tool not available"}

        try:
            prepared_inputs = dict(input_data) if input_data else {}
            self._inject_default_masks(prepared_inputs)
            return self.tool._run(**prepared_inputs)
        except Exception as exc:
            return {"error": f"Tool execution failed: {str(exc)}"}
| |
|
| |
|
class EchoViewClassificationManager(BaseToolManager):
    """Manager for the echo view classification tool.

    Wraps ``EchoViewClassificationTool`` behind the common ``BaseToolManager``
    interface and tracks its availability status.
    """

    def __init__(self, model_manager=None):
        # model_manager is kept for interface parity with the other managers;
        # this manager does not currently use it directly.
        self.model_manager = model_manager
        config = ToolConfig(
            name="echo_view_classification",
            tool_type="classification",
            description="Echo view classification tool"
        )
        super().__init__(config)
        self._initialize_tool()

    def _initialize_tool(self):
        """Build the view classification tool and record its availability."""
        try:
            self.tool = self._create_tool()
            self._set_status(ToolStatus.AVAILABLE)
        except Exception as e:
            # Best-effort: report the failure and mark the tool unavailable.
            print(f"Error initializing {self.config.name}: {e}")
            self._set_status(ToolStatus.NOT_AVAILABLE)

    def _create_tool(self) -> BaseTool:
        """Create the view classification tool."""
        return EchoViewClassificationTool()

    def _create_fallback_tool(self) -> BaseTool:
        """Create fallback tool (same implementation as the primary)."""
        return EchoViewClassificationTool()

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the view classification tool.

        Args:
            input_data: Keyword arguments forwarded to the tool's ``_run``
                method. ``None`` is treated as an empty mapping, matching
                ``EchoSegmentationManager.run``'s handling.

        Returns:
            The tool's result dict, or a dict with an ``"error"`` key when
            the tool is unavailable or execution fails.
        """
        if not self.tool:
            return {"error": "Tool not available"}

        try:
            # Guard against a missing payload so ``**`` unpacking cannot fail.
            return self.tool._run(**(input_data or {}))
        except Exception as e:
            return {"error": f"Tool execution failed: {str(e)}"}
| |
|