import torch
import numpy as np
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from ultralytics import YOLO
from typing import Dict, List, Tuple, Union, Optional
from dataclasses import dataclass

@dataclass
class SegmentationResult:
    """A single segmented object: its label, score, mask and box.

    Instances are plain containers that ``ObjectSegmenter`` builds and returns.
    """

    # Text label of the detected object.
    label: str
    # Detection score, stored as a Python float.
    confidence: float
    # Segmentation mask selected for this object.
    mask: np.ndarray
    # Integer box coordinates ordered as [x_min, y_min, x_max, y_max].
    bounding_box: List[int]

class ObjectSegmenter:
    """Zero-shot object detection (Grounding DINO) paired with YOLO instance masks.

    Grounding DINO finds boxes for free-text object names; each box is then
    matched to the YOLOv8-seg mask whose bounding box overlaps it most (IoU).
    """

    # Model identifiers; subclasses may override these to use other checkpoints.
    DINO_MODEL_ID = "IDEA-Research/grounding-dino-tiny"
    YOLO_WEIGHTS = "yolov8n-seg.pt"

    def __init__(self, device: Optional[str] = None):
        """Load both models.

        Args:
            device: Torch device string for Grounding DINO ("cuda", "cpu", ...).
                Defaults to "cuda" when available, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            # Drop cached allocator blocks before loading the models.
            torch.cuda.empty_cache()
        self._init_models()

    def _init_models(self):
        """Load the Grounding DINO processor/model and the YOLO segmentation model."""
        # Grounding DINO: moved to self.device and switched to eval mode for inference.
        self.dino_processor = AutoProcessor.from_pretrained(self.DINO_MODEL_ID)
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            self.DINO_MODEL_ID
        ).to(self.device).eval()

        # NOTE(review): the YOLO model is not explicitly placed on self.device;
        # confirm it runs on the intended device.
        self.yolo_model = YOLO(self.YOLO_WEIGHTS)

    def segment_objects(
        self,
        image: Union[Image.Image, np.ndarray, str],
        objects: Union[str, List[str]],
        box_threshold: float = 0.4,
        text_threshold: float = 0.3
    ) -> List[SegmentationResult]:
        """Detect and segment the requested objects in an image.

        Args:
            image: A PIL image, a numpy array accepted by ``Image.fromarray``,
                or a path to an image file.
            objects: One object name, or a list of names, to look for.
            box_threshold: Minimum box score for Grounding DINO post-processing.
            text_threshold: Minimum text score for Grounding DINO post-processing.

        Returns:
            One ``SegmentationResult`` per DINO detection that overlaps at least
            one YOLO mask; detections with no overlapping mask are dropped.
        """
        image = self._to_rgb_image(image)
        text_prompt = self._build_text_prompt(objects)

        dino_results = self._get_dino_detections(
            image, text_prompt, box_threshold, text_threshold
        )

        # retina_masks=True makes YOLO return masks at the original image
        # resolution, i.e. in the same pixel frame as the DINO boxes (which are
        # scaled to image.size via target_sizes). Without it, masks.data is at
        # the letterboxed inference resolution and IoU matching is skewed.
        yolo_results = self.yolo_model(image, verbose=False, retina_masks=True)[0]

        return self._process_results(dino_results, yolo_results)

    @staticmethod
    def _to_rgb_image(image: Union[Image.Image, np.ndarray, str]) -> Image.Image:
        """Normalize a path, numpy array or PIL image into an RGB PIL image."""
        if isinstance(image, str):
            # The context manager closes the file handle; convert() loads the
            # pixels and returns a new image, so the result outlives the file.
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        if image.mode != 'RGB':
            image = image.convert('RGB')
        return image

    @staticmethod
    def _build_text_prompt(objects: Union[str, List[str]]) -> str:
        """Join object names into a period-separated prompt ending with '.'."""
        if isinstance(objects, (list, tuple)):
            text_prompt = ". ".join(objects)
        else:
            text_prompt = objects
        if not text_prompt.endswith('.'):
            text_prompt += '.'
        return text_prompt

    @torch.no_grad()
    def _get_dino_detections(
        self,
        image: Image.Image,
        text_prompt: str,
        box_threshold: float,
        text_threshold: float
    ) -> dict:
        """Run Grounding DINO on one image and post-process its output.

        Returns:
            The post-processed result dict for the single input image. Boxes are
            in pixel coordinates of ``image`` because ``target_sizes`` is set to
            its (height, width).
        """
        inputs = self.dino_processor(
            images=image,
            text=text_prompt,
            return_tensors="pt"
        ).to(self.device)

        outputs = self.dino_model(**inputs)
        # Post-processing returns one dict per image; exactly one image was passed.
        return self.dino_processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            # PIL's size is (width, height); reversed to (height, width).
            target_sizes=[image.size[::-1]]
        )[0]

    def _process_results(
        self,
        dino_results: dict,
        yolo_results
    ) -> List[SegmentationResult]:
        """Pair every DINO detection with its best-overlapping YOLO mask.

        Detections whose box has zero IoU with every mask box are dropped.
        """
        # Mask boxes depend only on the YOLO output: compute them once rather
        # than once per DINO detection.
        candidates = self._mask_candidates(yolo_results)

        # NOTE(review): newer transformers releases expose the label strings
        # under "text_labels"; fall back to "labels" when that key is absent.
        label_key = "text_labels" if "text_labels" in dino_results else "labels"

        segmentation_results = []
        for box, score, label in zip(
            dino_results["boxes"],
            dino_results["scores"],
            dino_results[label_key]
        ):
            int_box = [int(x) for x in box.tolist()]

            best_mask = self._find_best_mask(int_box, yolo_results, candidates)
            if best_mask is None:
                continue

            segmentation_results.append(
                SegmentationResult(
                    label=label,
                    confidence=float(score),
                    mask=best_mask,
                    bounding_box=int_box
                )
            )

        return segmentation_results

    @staticmethod
    def _mask_candidates(yolo_results) -> List[Tuple[List[int], np.ndarray]]:
        """Return ``(bounding_box, mask)`` for every non-empty YOLO instance mask.

        Boxes are ``[x_min, y_min, x_max, y_max]`` with exclusive max edges
        (a mask covering pixel columns 3..5 gets x_min=3, x_max=6), so that
        ``x_max - x_min`` is the pixel width used by ``_calculate_iou``.
        """
        masks = yolo_results.masks
        # masks is None when YOLO detects nothing; len(None) would raise.
        if masks is None or len(masks) == 0:
            return []

        candidates = []
        for mask_np in masks.data.cpu().numpy():
            ys, xs = np.nonzero(mask_np > 0)
            if ys.size == 0:
                continue
            mask_box = [
                int(xs.min()),
                int(ys.min()),
                int(xs.max()) + 1,
                int(ys.max()) + 1
            ]
            candidates.append((mask_box, mask_np))
        return candidates

    def _find_best_mask(
        self,
        box: List[int],
        yolo_results,
        candidates: Optional[List[Tuple[List[int], np.ndarray]]] = None
    ) -> Optional[np.ndarray]:
        """Return the YOLO mask whose bounding box has the highest IoU with ``box``.

        Args:
            box: Query box ``[x_min, y_min, x_max, y_max]``.
            yolo_results: YOLO result for the image; only read when
                ``candidates`` is not supplied.
            candidates: Optional precomputed output of ``_mask_candidates``.

        Returns:
            The best-matching mask, or None when no mask overlaps ``box``.
        """
        if candidates is None:
            candidates = self._mask_candidates(yolo_results)

        best_iou = 0.0
        best_mask = None
        for mask_box, mask_np in candidates:
            iou = self._calculate_iou(box, mask_box)
            # Strict ">" keeps the first mask on ties and rejects zero overlap.
            if iou > best_iou:
                best_iou, best_mask = iou, mask_np

        return best_mask

    @staticmethod
    def _calculate_iou(box1: List[int], box2: List[int]) -> float:
        """Intersection over Union of two ``[x_min, y_min, x_max, y_max]`` boxes.

        Returns 0.0 when the union area is not positive (degenerate boxes)
        instead of dividing by zero.
        """
        inter_w = max(0, min(box1[2], box2[2]) - max(box1[0], box2[0]))
        inter_h = max(0, min(box1[3], box2[3]) - max(box1[1], box2[1]))
        intersection = inter_w * inter_h

        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = box1_area + box2_area - intersection
        if union <= 0:
            return 0.0
        return intersection / union

# Module-wide default instance. Building it runs ObjectSegmenter.__init__,
# which loads both models, so this work happens as soon as the module is imported.
segmenter = ObjectSegmenter(device=None)