| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ |
| | PitVQA Multi-Agent Orchestration System |
| | |
| | Specialized agents for methodologically rigorous VLM pipeline management: |
| | 1. JobMonitorAgent - Track HuggingFace Jobs status |
| | 2. CurationAgent - Quality-filter showcase examples |
| | 3. DatasetAgent - Validate image-embedded dataset |
| | 4. ModelVerifierAgent - Test merged model outputs |
| | 5. DemoSyncAgent - Update Gradio Space with results |
| | |
| | Run with: python pitvqa_agent_orchestrator.py |
| | """ |
| |
|
| | import os |
| | import json |
| | import time |
| | from dataclasses import dataclass |
| | from typing import Dict, List, Optional, Any |
| | from datetime import datetime |
| | from enum import Enum |
| |
|
| | |
| | |
| | |
| |
|
class AgentStatus(Enum):
    """Lifecycle states an agent can report during or after a run."""

    IDLE = "idle"          # not started yet
    RUNNING = "running"    # currently executing
    SUCCESS = "success"    # finished and met its goal
    FAILED = "failed"      # finished without meeting its goal
    WAITING = "waiting"    # blocked on an external dependency (job, file, repo)
| |
|
@dataclass
class AgentResult:
    """Outcome of a single agent run (plain data record)."""

    agent_name: str               # name of the agent that produced this result
    status: AgentStatus           # terminal status of the run
    message: str                  # human-readable summary
    data: Optional[Dict] = None   # optional structured payload
    timestamp: str = ""           # ISO-8601 creation time; auto-filled if empty

    def __post_init__(self):
        # Stamp the creation time when the caller did not supply one.
        self.timestamp = self.timestamp or datetime.now().isoformat()
| |
|
| | |
| | |
| | |
| |
|
class BaseAgent:
    """Base class for all PitVQA agents.

    Provides a display name, a lifecycle status flag, an accumulated result
    list, and a small console logging helper. Subclasses implement ``run``.
    """

    def __init__(self, name: str):
        self.name = name                       # display name used in log lines
        self.status = AgentStatus.IDLE         # current lifecycle state
        self.results: List[AgentResult] = []   # accumulated run results

    def log(self, message: str, level: str = "INFO"):
        """Print *message* prefixed with the agent name and a level icon.

        Fixed: the level icons were mojibake (UTF-8 emoji decoded with a
        legacy codec); restored to the intended symbols.
        """
        icon = {"INFO": "ℹ️", "SUCCESS": "✅", "ERROR": "❌", "WARN": "⚠️"}.get(level, "📝")
        print(f"[{self.name}] {icon} {message}")

    def run(self) -> AgentResult:
        """Execute the agent once; subclasses must override."""
        raise NotImplementedError

    def report(self) -> Dict:
        """Return a JSON-serializable summary of this agent's runs."""
        return {
            "agent": self.name,
            "status": self.status.value,
            "results": [r.__dict__ for r in self.results]
        }
| |
|
| | |
| | |
| | |
| |
|
class JobMonitorAgent(BaseAgent):
    """Monitors HuggingFace Jobs and reports status."""

    def __init__(self, job_ids: List[str]):
        super().__init__("JobMonitor")
        self.job_ids = job_ids   # jobs to poll
        self.job_status = {}     # job_id -> last observed status dict

    def check_job(self, job_id: str) -> Dict:
        """Fetch a single job's status from the HF API (best-effort)."""
        try:
            from huggingface_hub import HfApi

            job = HfApi().get_job(job_id)
            raw = job.status
            stage = raw.stage if hasattr(raw, 'stage') else str(raw)
            note = raw.message if hasattr(raw, 'message') else None
            return {"id": job_id, "status": stage, "message": note}
        except Exception as e:
            # Any API failure is reported as UNKNOWN rather than raised.
            return {"id": job_id, "status": "UNKNOWN", "error": str(e)}

    def run(self) -> AgentResult:
        """Poll every tracked job and aggregate into one pipeline status."""
        self.status = AgentStatus.RUNNING
        self.log(f"Checking {len(self.job_ids)} jobs...")

        stages = []
        for job_id in self.job_ids:
            info = self.check_job(job_id)
            self.job_status[job_id] = info
            stage = info.get("status", "UNKNOWN")
            self.log(f"Job {job_id[:8]}: {stage}")
            stages.append(stage)

        # A job counts as done only in a terminal success stage; FAILED/ERROR
        # anywhere marks the whole batch as failed.
        any_failed = any(s in ("FAILED", "ERROR") for s in stages)
        all_complete = all(s in ("COMPLETED", "SUCCESS") for s in stages)

        if any_failed:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, "Some jobs failed", self.job_status)
        if all_complete:
            self.status = AgentStatus.SUCCESS
            return AgentResult(self.name, AgentStatus.SUCCESS, "All jobs complete", self.job_status)
        self.status = AgentStatus.WAITING
        return AgentResult(self.name, AgentStatus.WAITING, "Jobs still running", self.job_status)
| |
|
| | |
| | |
| | |
| |
|
class CurationAgent(BaseAgent):
    """Curates showcase examples based on quality criteria."""

    # Declarative statement of the curation policy; `curate` enforces the
    # same constraints procedurally.
    QUALITY_CRITERIA = {
        "coordinate_validity": lambda x, y: 0 <= x <= 100 and 0 <= y <= 100,
        "coordinate_diversity": lambda coords: len(set(coords)) > len(coords) * 0.5,
        "video_diversity": lambda vids: len(set(vids)) >= min(5, len(vids)),
        "frame_diversity": lambda frames: len(set(frames)) >= min(8, len(frames)),
    }

    def __init__(self, results_path: str = "./curation_review/all_results.json"):
        super().__init__("Curation")
        self.results_path = results_path     # JSON file produced by the curation job
        self.curated_examples = []           # filled by run()

    def load_results(self) -> List[Dict]:
        """Load raw curation results; empty list if the file is not there yet."""
        try:
            with open(self.results_path) as f:
                return json.load(f)
        except FileNotFoundError:
            self.log("Results file not found - job may still be running", "WARN")
            return []

    def score_example(self, example: Dict) -> float:
        """Score a single example (0-1).

        Components: success flag (0.3), plausible geometry (0.1-0.3),
        structured output tags (0.2), target mentioned in response (0.2).
        """
        score = 0.0

        if example.get("success"):
            score += 0.3

        if example.get("task") == "point":
            x, y = example.get("x"), example.get("y")
            # Fixed: compare against None so a valid edge coordinate of 0
            # is not discarded by truthiness.
            if x is not None and y is not None:
                # Prefer points away from the image border.
                if 10 < x < 90 and 10 < y < 90:
                    score += 0.3
                else:
                    score += 0.1
        elif example.get("task") == "bbox":
            bbox = example.get("bbox")
            if bbox and len(bbox) == 4:
                # Reward boxes of a reasonable area.
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                if 100 < area < 5000:
                    score += 0.3
                else:
                    score += 0.1

        response = example.get("response", "")
        if "<point" in response or "<box" in response:
            score += 0.2

        # Fixed: lowercase both sides and skip empty targets — an empty
        # string is a substring of everything and inflated every score.
        target = example.get("target", "")
        if target and target.lower() in response.lower():
            score += 0.2

        return min(score, 1.0)

    def curate(self, results: List[Dict], top_k: int = 12) -> List[Dict]:
        """Select the best `top_k` examples under diversity constraints:
        at most 2 per video, unique (video, frame) pairs, balanced tasks."""
        if not results:
            return []

        scored = [(self.score_example(ex), ex) for ex in results if ex.get("success")]
        scored.sort(key=lambda x: x[0], reverse=True)

        curated = []
        # Fixed: the original called .count() on a set, which raises
        # AttributeError; a per-video counter implements the intended
        # "max 2 examples per video" rule.
        video_counts = {}
        used_frames = set()
        used_tasks = {"point": 0, "bbox": 0}

        for score, ex in scored:
            if len(curated) >= top_k:
                break

            video = ex.get("video_id")
            frame = ex.get("frame_idx")
            task = ex.get("task")

            if video_counts.get(video, 0) >= 2:
                continue
            if (video, frame) in used_frames:
                continue
            if used_tasks.get(task, 0) >= top_k // 2:
                continue

            curated.append({**ex, "quality_score": score})
            video_counts[video] = video_counts.get(video, 0) + 1
            used_frames.add((video, frame))
            used_tasks[task] = used_tasks.get(task, 0) + 1

        return curated

    def run(self) -> AgentResult:
        """Load, score, and curate; SUCCESS requires >= 8 surviving examples."""
        self.status = AgentStatus.RUNNING
        self.log("Loading curation results...")

        results = self.load_results()
        if not results:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "No results available yet")

        self.log(f"Scoring {len(results)} examples...")
        self.curated_examples = self.curate(results)

        if len(self.curated_examples) >= 8:
            self.status = AgentStatus.SUCCESS

            videos = set(ex["video_id"] for ex in self.curated_examples)
            frames = set(ex["frame_idx"] for ex in self.curated_examples)

            self.log(f"Curated {len(self.curated_examples)} examples", "SUCCESS")
            self.log(f" Videos: {len(videos)} unique")
            self.log(f" Frames: {len(frames)} unique")

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Curated {len(self.curated_examples)} high-quality diverse examples",
                {"examples": self.curated_examples}
            )
        else:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Only {len(self.curated_examples)} examples passed quality checks"
            )
| |
|
| | |
| | |
| | |
| |
|
class DatasetValidatorAgent(BaseAgent):
    """Validates image-embedded dataset quality."""

    def __init__(self, dataset_id: str = "mmrech/pitvqa-spatial-with-images"):
        super().__init__("DatasetValidator")
        self.dataset_id = dataset_id   # HF Hub dataset to spot-check

    def run(self) -> AgentResult:
        """Spot-check the first 10 rows: required columns plus usable images."""
        self.status = AgentStatus.RUNNING
        self.log(f"Validating dataset: {self.dataset_id}")

        try:
            from datasets import load_dataset

            sample = load_dataset(self.dataset_id, split="train[:10]")

            # Schema check: both columns must be present.
            missing = [field for field in ("image", "messages") if field not in sample.features]
            if missing:
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    f"Missing fields: {missing}"
                )

            # Image check: each row needs a non-degenerate image object.
            valid_images = sum(
                1 for row in sample
                if (img := row.get("image")) and hasattr(img, "size") and img.size[0] > 0
            )
            total = len(sample)

            if valid_images == total:
                self.status = AgentStatus.SUCCESS
                return AgentResult(
                    self.name,
                    AgentStatus.SUCCESS,
                    f"Dataset valid: {valid_images}/{total} images OK",
                    {"sample_count": total, "valid_images": valid_images}
                )
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Invalid images: {total - valid_images}/{total}"
            )

        except Exception as e:
            # Load failure (missing repo, auth, network) means "not ready yet".
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Dataset not yet available: {e}"
            )
| |
|
| | |
| | |
| | |
| |
|
class ModelVerifierAgent(BaseAgent):
    """Verifies merged model outputs are correct."""

    # Smoke-test prompts per task type (reserved for future inference checks).
    TEST_PROMPTS = [
        ("Point to the suction device", "point"),
        ("Draw a bounding box around the surgical instrument", "bbox"),
        ("What surgical phase is this?", "classification"),
    ]

    def __init__(self, model_id: str = "mmrech/pitvqa-qwen2vl-merged"):
        super().__init__("ModelVerifier")
        self.model_id = model_id   # HF Hub model repo to verify

    def run(self) -> AgentResult:
        """Check the model repo exists and contains weights plus a config.

        SUCCESS when a weights file and config.json are both present,
        WAITING when the repo is not (yet) reachable, FAILED otherwise.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Verifying model: {self.model_id}")

        try:
            from huggingface_hub import HfApi
            api = HfApi()

            try:
                info = api.model_info(self.model_id)
                self.log(f"Model found: {info.modelId}")

                files = [f.rfilename for f in info.siblings]

                # Fixed: removed the unused `required` list; presence is
                # determined by the two flags below.
                has_model = any("safetensors" in f or "pytorch" in f for f in files)
                has_config = "config.json" in files

                if has_model and has_config:
                    self.status = AgentStatus.SUCCESS
                    return AgentResult(
                        self.name,
                        AgentStatus.SUCCESS,
                        f"Model verified: {len(files)} files present",
                        {"files": files[:10]}
                    )
                else:
                    self.status = AgentStatus.FAILED
                    return AgentResult(
                        self.name,
                        AgentStatus.FAILED,
                        f"Missing model files (has_model={has_model}, has_config={has_config})"
                    )

            except Exception as e:
                # Repo lookup failed -> model likely not pushed yet.
                self.status = AgentStatus.WAITING
                return AgentResult(
                    self.name,
                    AgentStatus.WAITING,
                    f"Model not yet available: {e}"
                )

        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")
| |
|
| | |
| | |
| | |
| |
|
class TrainingSpecialistAgent(BaseAgent):
    """
    Specialist in HuggingFace LLM Training (TRL/SFT/LoRA/DPO).

    Responsibilities:
    - Validate training configurations
    - Check adapter quality
    - Recommend training improvements
    - Verify LoRA/PEFT setup
    """

    # Reference glossary of supported post-training methods.
    TRAINING_METHODS = {
        "SFT": "Supervised Fine-Tuning - learning from (input, output) pairs",
        "LoRA": "Low-Rank Adaptation - parameter-efficient adapters",
        "DPO": "Direct Preference Optimization - learning from preferences",
        "RLHF": "Reinforcement Learning from Human Feedback",
    }

    # Baseline hyperparameters this project considers sane for LoRA SFT.
    OPTIMAL_CONFIG = {
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 1,
        "gradient_accumulation_steps": 16,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    }

    def __init__(self, adapter_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2"):
        super().__init__("TrainingSpecialist")
        self.adapter_repo = adapter_repo   # HF Hub repo holding the LoRA adapter

    def validate_adapter_config(self) -> Dict:
        """Download and sanity-check the stage-4 adapter_config.json.

        Returns the parsed config with blocking ``issues``, soft
        ``recommendations`` and a ``valid`` flag; on any error returns
        ``{"error": ..., "valid": False}``.
        """
        try:
            # Fixed: removed the redundant local `import json` that shadowed
            # the module-level import.
            from huggingface_hub import hf_hub_download

            config_path = hf_hub_download(
                repo_id=self.adapter_repo,
                filename="stage4/adapter_config.json"
            )

            with open(config_path) as f:
                config = json.load(f)

            issues = []
            recommendations = []

            # LoRA rank: too small underfits; very large wastes parameters.
            if config.get("r", 0) < 8:
                issues.append("LoRA rank too low (r < 8)")
            elif config.get("r", 0) > 64:
                recommendations.append("Consider reducing LoRA rank for efficiency")

            # Adapters must target at least one projection layer.
            target_modules = config.get("target_modules", [])
            if not any("proj" in m for m in target_modules):
                issues.append("No projection layers targeted")

            return {
                "config": config,
                "issues": issues,
                "recommendations": recommendations,
                "valid": len(issues) == 0
            }

        except Exception as e:
            return {"error": str(e), "valid": False}

    def recommend_next_training(self, current_metrics: Optional[Dict] = None) -> Dict:
        """Recommend next training steps based on current metrics.

        Without metrics the single recommendation is to run an evaluation;
        with metrics the advice is keyed off the ``accuracy`` value.
        """
        recommendations = []

        if not current_metrics:
            recommendations.append({
                "priority": "HIGH",
                "action": "Run evaluation to get baseline metrics",
                "method": "scripts/evaluate_unified_vlm.py"
            })
        else:
            accuracy = current_metrics.get("accuracy", 0)

            if accuracy < 0.7:
                recommendations.append({
                    "priority": "HIGH",
                    "action": "Increase training epochs or data",
                    "method": "SFT with more epochs"
                })

            if 0.7 <= accuracy < 0.85:
                recommendations.append({
                    "priority": "MEDIUM",
                    "action": "Consider DPO for preference learning",
                    "method": "Create chosen/rejected pairs from predictions"
                })

            if accuracy >= 0.85:
                recommendations.append({
                    "priority": "LOW",
                    "action": "Model performing well - focus on inference optimization",
                    "method": "Merge adapters, quantize for deployment"
                })

        return {"recommendations": recommendations}

    def run(self) -> AgentResult:
        """Validate the adapter config and attach next-step recommendations."""
        self.status = AgentStatus.RUNNING
        self.log(f"Validating training setup: {self.adapter_repo}")

        validation = self.validate_adapter_config()

        if validation.get("valid"):
            self.status = AgentStatus.SUCCESS
            recommendations = self.recommend_next_training()

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Training config valid. LoRA r={validation['config'].get('r')}",
                {
                    "config": validation["config"],
                    "recommendations": recommendations["recommendations"]
                }
            )
        elif validation.get("error"):
            # Could not even fetch the config -> treat as not ready yet.
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Could not load adapter: {validation['error']}"
            )
        else:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Issues found: {validation['issues']}",
                validation
            )
| |
|
| | |
| | |
| | |
| |
|
class EvaluationSpecialistAgent(BaseAgent):
    """
    Specialist in Model Evaluation (metrics, benchmarks, validation).

    Responsibilities:
    - Compute accuracy, F1, precision, recall
    - Validate coordinate predictions (MAE, quadrant accuracy)
    - Compare against baselines
    - Generate evaluation reports
    """

    # Metric families per task type (documents what a full evaluation covers).
    METRICS = {
        "classification": ["accuracy", "f1", "precision", "recall"],
        "localization": ["mae", "quadrant_accuracy", "distance_error"],
        "detection": ["iou", "ap", "ar"],
    }

    # Pass/fail gates; `mae` is an upper bound, the others are lower bounds.
    THRESHOLDS = {
        "quadrant_accuracy": 0.75,
        "mae": 15.0,
        "classification_accuracy": 0.80,
    }

    def __init__(self, model_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2"):
        super().__init__("EvaluationSpecialist")
        self.model_repo = model_repo
        self.metrics = {}   # latest computed/loaded metrics

    def load_evaluation_results(self) -> Dict:
        """Load existing evaluation results if available (else empty dict)."""
        try:
            with open("evaluation_results.json") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def compute_quick_metrics(self, predictions: List[Dict]) -> Dict:
        """Compute quick metrics from predictions.

        Produces `valid_rate` for pointing tasks, plus `mae` and
        `quadrant_accuracy` when ground truth is present, and
        `classification_accuracy` for classification tasks.
        """
        if not predictions:
            return {}

        metrics = {}

        # Pointing / localization metrics.
        coord_preds = [p for p in predictions if p.get("task") in ["point", "pointing"]]
        if coord_preds:
            valid = [p for p in coord_preds if p.get("x") is not None]
            metrics["valid_rate"] = len(valid) / len(coord_preds)

            errors = []
            for p in valid:
                # Fixed: compare ground truth against None so coordinate 0
                # (a valid edge position) is not skipped by truthiness.
                if p.get("gt_x") is not None and p.get("gt_y") is not None:
                    err = ((p["x"] - p["gt_x"])**2 + (p["y"] - p["gt_y"])**2)**0.5
                    errors.append(err)

            if errors:
                metrics["mae"] = sum(errors) / len(errors)
                # Counts predictions within 25 units (Euclidean) as correct.
                metrics["quadrant_accuracy"] = sum(1 for e in errors if e < 25) / len(errors)

        # Classification metrics.
        class_preds = [p for p in predictions if p.get("task") == "classification"]
        if class_preds:
            correct = sum(1 for p in class_preds if p.get("prediction") == p.get("ground_truth"))
            metrics["classification_accuracy"] = correct / len(class_preds)

        return metrics

    def evaluate_against_thresholds(self, metrics: Dict) -> Dict:
        """Check metrics against quality thresholds."""
        results = {"passed": [], "failed": [], "warnings": []}

        for metric, threshold in self.THRESHOLDS.items():
            if metric in metrics:
                value = metrics[metric]
                # MAE is an error metric: lower is better; everything else:
                # higher is better.
                if metric == "mae":
                    passed = value <= threshold
                else:
                    passed = value >= threshold

                entry = {"metric": metric, "value": value, "threshold": threshold}
                if passed:
                    results["passed"].append(entry)
                else:
                    results["failed"].append(entry)

        return results

    def generate_report(self, metrics: Dict, threshold_results: Dict) -> str:
        """Generate a human-readable evaluation report.

        Fixed: section icons were mojibake (UTF-8 emoji decoded with a
        legacy codec); restored to the intended symbols.
        """
        report = []
        report.append("=" * 50)
        report.append("EVALUATION REPORT")
        report.append("=" * 50)

        report.append("\n📊 METRICS:")
        for k, v in metrics.items():
            report.append(f" {k}: {v:.4f}" if isinstance(v, float) else f" {k}: {v}")

        report.append("\n✅ PASSED:")
        for item in threshold_results["passed"]:
            report.append(f" {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        if threshold_results["failed"]:
            report.append("\n❌ FAILED:")
            for item in threshold_results["failed"]:
                report.append(f" {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        return "\n".join(report)

    def run(self, predictions: Optional[List[Dict]] = None) -> AgentResult:
        """Evaluate saved or supplied predictions against the thresholds."""
        self.status = AgentStatus.RUNNING
        self.log("Running evaluation...")

        # Prefer previously saved results; fall back to fresh predictions.
        existing = self.load_evaluation_results()

        if existing:
            self.log("Found existing evaluation results")
            self.metrics = existing
        elif predictions:
            self.log(f"Computing metrics from {len(predictions)} predictions")
            self.metrics = self.compute_quick_metrics(predictions)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No predictions available for evaluation"
            )

        threshold_results = self.evaluate_against_thresholds(self.metrics)

        report = self.generate_report(self.metrics, threshold_results)
        self.log(f"\n{report}")

        if threshold_results["failed"]:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"{len(threshold_results['failed'])} metrics below threshold",
                {"metrics": self.metrics, "thresholds": threshold_results}
            )
        else:
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"All {len(threshold_results['passed'])} metrics passed",
                {"metrics": self.metrics, "thresholds": threshold_results}
            )
| |
|
| | |
| | |
| | |
| |
|
class DemoSyncAgent(BaseAgent):
    """Syncs curated examples to Gradio Space."""

    def __init__(self, space_id: str = "mmrech/pitvqa-surgical-vlm"):
        super().__init__("DemoSync")
        self.space_id = space_id   # target Gradio Space repo

    def run(self, curated_examples: Optional[List[Dict]] = None) -> AgentResult:
        """Check the Space is running and report how many examples are ready.

        NOTE(review): this only verifies readiness; no files are actually
        pushed to the Space yet.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Syncing to Space: {self.space_id}")

        if not curated_examples:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No curated examples to sync"
            )

        try:
            from huggingface_hub import HfApi
            api = HfApi()

            try:
                info = api.space_info(self.space_id)
                runtime = info.runtime

                if runtime and runtime.stage == "RUNNING":
                    # Fixed: dropped the pointless f-string prefix and the
                    # unused `examples_json` serialization (dead code).
                    self.log("Space is running", "SUCCESS")

                    self.status = AgentStatus.SUCCESS
                    return AgentResult(
                        self.name,
                        AgentStatus.SUCCESS,
                        f"Space running, {len(curated_examples)} examples ready for sync",
                        {"space_status": "RUNNING", "examples_count": len(curated_examples)}
                    )
                else:
                    self.status = AgentStatus.WAITING
                    return AgentResult(
                        self.name,
                        AgentStatus.WAITING,
                        f"Space not running: {runtime.stage if runtime else 'unknown'}"
                    )

            except Exception as e:
                self.status = AgentStatus.FAILED
                return AgentResult(self.name, AgentStatus.FAILED, f"Space error: {e}")

        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")
| |
|
| | |
| | |
| | |
| |
|
class PitVQAOrchestrator:
    """Coordinates all agents for the PitVQA pipeline."""

    def __init__(self, job_ids: List[str]):
        # Agent registry, keyed by pipeline role.
        self.agents = {
            "monitor": JobMonitorAgent(job_ids),
            "curation": CurationAgent(),
            "dataset": DatasetValidatorAgent(),
            "model": ModelVerifierAgent(),
            "training": TrainingSpecialistAgent(),
            "evaluation": EvaluationSpecialistAgent(),
            "demo": DemoSyncAgent(),
        }
        self.results = {}    # role -> latest AgentResult
        self.run_count = 0   # completed orchestration cycles

    def run_cycle(self) -> Dict:
        """Run one orchestration cycle.

        Monitoring and training validation always run; phases 3-7 run only
        when no job has failed outright.

        Fixed: the phase icons were mojibake (UTF-8 emoji decoded with a
        legacy codec); restored to plausible emoji — exact originals are
        unrecoverable, verify against project branding if it matters.
        """
        self.run_count += 1
        print(f"\n{'='*60}")
        print(f"🔄 ORCHESTRATION CYCLE {self.run_count}")
        print(f"{'='*60}")

        # Phase 1: always check job state first.
        print("\n📊 Phase 1: Job Monitoring")
        monitor_result = self.agents["monitor"].run()
        self.results["monitor"] = monitor_result

        # Phase 2: training config validation is independent of job state.
        print("\n📚 Phase 2: Training Validation (HF-LLM-Trainer)")
        training_result = self.agents["training"].run()
        self.results["training"] = training_result

        # Phases 3-7 only make sense while jobs are healthy.
        if monitor_result.status in [AgentStatus.SUCCESS, AgentStatus.WAITING]:

            print("\n🎨 Phase 3: Curation")
            curation_result = self.agents["curation"].run()
            self.results["curation"] = curation_result

            print("\n📦 Phase 4: Dataset Validation")
            dataset_result = self.agents["dataset"].run()
            self.results["dataset"] = dataset_result

            print("\n🤖 Phase 5: Model Verification")
            model_result = self.agents["model"].run()
            self.results["model"] = model_result

            print("\n📊 Phase 6: Evaluation (Metrics & Quality)")
            curated = curation_result.data.get("examples", []) if curation_result.data else []
            eval_result = self.agents["evaluation"].run(predictions=curated)
            self.results["evaluation"] = eval_result

            print("\n🔄 Phase 7: Demo Sync")
            demo_result = self.agents["demo"].run(curated)
            self.results["demo"] = demo_result

        return self.generate_report()

    def generate_report(self) -> Dict:
        """Generate comprehensive status report (JSON-serializable)."""
        report = {
            "timestamp": datetime.now().isoformat(),
            "cycle": self.run_count,
            "overall_status": self._compute_overall_status(),
            "agents": {}
        }

        for name, result in self.results.items():
            report["agents"][name] = {
                "status": result.status.value,
                "message": result.message
            }

        return report

    def _compute_overall_status(self) -> str:
        """Compute overall pipeline status from all recorded agent results.

        NOTE(review): with no results yet, `all()` on an empty list is True,
        so this reports COMPLETE before the first cycle — harmless as
        currently called, but worth confirming if called earlier.
        """
        statuses = [r.status for r in self.results.values()]

        if all(s == AgentStatus.SUCCESS for s in statuses):
            return "COMPLETE"
        elif any(s == AgentStatus.FAILED for s in statuses):
            return "NEEDS_ATTENTION"
        elif any(s == AgentStatus.WAITING for s in statuses):
            return "IN_PROGRESS"
        else:
            return "UNKNOWN"

    def print_summary(self, report: Dict):
        """Print human-readable summary (icons restored from mojibake)."""
        print(f"\n{'='*60}")
        print("📊 ORCHESTRATION SUMMARY")
        print(f"{'='*60}")
        print(f"Time: {report['timestamp']}")
        print(f"Cycle: {report['cycle']}")
        print(f"Overall: {report['overall_status']}")
        print("\nAgent Status:")
        for name, info in report["agents"].items():
            icon = {"success": "✅", "failed": "❌", "waiting": "⏳", "running": "🔄"}.get(info["status"], "❓")
            print(f" {icon} {name}: {info['status']} - {info['message'][:50]}")
| |
|
| | |
| | |
| | |
| |
|
def main():
    """Run a single orchestration cycle, print the summary, and persist
    the report to orchestration_report.json.

    Fixed: banner icons were mojibake (UTF-8 emoji decoded with a legacy
    codec); restored, and a no-placeholder f-string prefix dropped.
    """
    print("🚀 PitVQA Multi-Agent Orchestrator Starting...")

    # Jobs launched earlier on HF infrastructure (hard-coded for this run).
    job_ids = [
        "696cfe9946affbb321046bd9",
        "696cfebf57a10a9d296ca042",
    ]

    orchestrator = PitVQAOrchestrator(job_ids)

    report = orchestrator.run_cycle()
    orchestrator.print_summary(report)

    # Persist the report for later inspection.
    with open("orchestration_report.json", "w") as f:
        json.dump(report, f, indent=2)
    print("\n💾 Report saved to orchestration_report.json")

    return report
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|