File size: 3,864 Bytes
796780d
 
41e99e7
796780d
 
 
 
bf7dfcc
41e99e7
bf7dfcc
305c627
 
 
 
 
 
bf7dfcc
796780d
bf7dfcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305c627
8ba5658
796780d
 
 
 
 
 
 
 
 
 
 
 
 
bf7dfcc
 
 
 
 
 
 
 
 
796780d
bf7dfcc
 
 
 
 
796780d
bf7dfcc
 
 
 
 
 
796780d
 
 
 
bf7dfcc
 
 
 
 
796780d
bf7dfcc
 
796780d
bf7dfcc
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import base64
import contextlib
import io
from typing import Any, Dict, List, Union

import numpy as np
import torch
from huggingface_hub import InferenceEndpoint
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

class EndpointHandler:
    """Custom Inference Endpoint handler that serves SAM2 segmentation masks.

    A ``SAM2ImagePredictor`` is created once at start-up. Each request decodes
    an image, optionally applies point prompts, and returns the predicted
    masks and scores as JSON-serialisable nested lists.

    NOTE: this is deliberately a plain class. ``huggingface_hub.InferenceEndpoint``
    is a dataclass describing a deployed endpoint (name, status, url, ...);
    subclassing it without populating those fields made ``repr()``/``==`` on
    the handler raise ``AttributeError`` and inherited unrelated methods.
    """

    DEFAULT_MODEL_ID = "facebook/sam2-hiera-small"

    def __init__(self, model_dir=None, *, predictor=None, model_id=DEFAULT_MODEL_ID):
        """Create the handler and its SAM2 predictor.

        Args:
            model_dir (str, optional): Path to the model directory supplied by
                the serving toolkit. Currently unused: weights are always
                loaded by ``model_id``. Defaults to None.
            predictor (optional): Pre-built object exposing ``set_image`` and
                ``predict``. When given, no model is loaded — use this to
                plug in a stub for local testing instead of editing code.
            model_id (str, optional): Hub id passed to
                ``SAM2ImagePredictor.from_pretrained``.
        """
        if predictor is not None:
            self.predictor = predictor
            return

        load_kwargs = {}
        if not torch.cuda.is_available():
            # sam2's build_sam2 defaults to device="cuda"; without this a
            # CPU-only host cannot load the model, making the CPU inference
            # path in __call__ unreachable.
            load_kwargs["device"] = "cpu"
        self.predictor = SAM2ImagePredictor.from_pretrained(model_id, **load_kwargs)

    def _load_image(self, image_data: Union[str, bytes, Image.Image]) -> Image.Image:
        """Decode request image data into a PIL image.

        Args:
            image_data: Encoded image bytes, a base64 string of those bytes,
                or an already-decoded ``PIL.Image.Image`` (returned unchanged).

        Returns:
            The decoded image; its mode is not changed here.

        Raises:
            ValueError: If the data cannot be base64-decoded or decoded as an
                image.
        """
        if isinstance(image_data, Image.Image):
            return image_data
        try:
            if isinstance(image_data, str):
                image_data = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(image_data))
            # Image.open only parses the header; force the full decode here
            # so truncated/corrupt data is reported as a ValueError now.
            image.load()
            return image
        except Exception as e:
            raise ValueError(f"Failed to load image: {str(e)}") from e

    def __call__(self, image_bytes: Union[bytes, str, Dict[str, Any]]) -> Dict[str, Any]:
        """Segment one image and return masks plus scores.

        Args:
            image_bytes: Either the encoded image itself (bytes or a base64
                string), or a dict holding the image under ``"image"``
                (falling back to ``"inputs"``) plus optional
                ``"point_coords"`` / ``"point_labels"`` prompts, which are
                forwarded unchanged to ``predictor.predict``.

        Returns:
            ``{"masks": [...], "scores": [...], "status": "success"}`` with
            nested lists, or ``{"error": "No masks generated",
            "status": "error"}`` when the predictor returns ``None`` masks.

        Raises:
            ValueError: If no image is present, the image cannot be decoded,
                or ``point_coords`` is given without ``point_labels``.
        """
        point_coords = None
        point_labels = None
        if isinstance(image_bytes, dict):
            point_coords = image_bytes.get("point_coords")
            point_labels = image_bytes.get("point_labels")
            payload = image_bytes.get("image", image_bytes.get("inputs"))
            if payload is None:
                raise ValueError("Request must include an 'image' (or 'inputs') field")
            image_bytes = payload

        if point_coords is not None and point_labels is None:
            raise ValueError("point_labels must be provided together with point_coords")

        image = self._load_image(image_bytes)
        if image.mode != "RGB":
            image = image.convert("RGB")
        image_array = np.array(image)

        # bfloat16 autocast only when a GPU is present; otherwise a no-op
        # context so both paths share one copy of the inference code.
        precision = (
            torch.autocast("cuda", dtype=torch.bfloat16)
            if torch.cuda.is_available()
            else contextlib.nullcontext()
        )
        # NOTE(review): set_image/predict appear to keep per-image state on the
        # shared predictor; confirm requests are not served concurrently.
        with torch.inference_mode(), precision:
            self.predictor.set_image(image_array)
            masks, scores, _ = self.predictor.predict(
                point_coords=point_coords,
                point_labels=point_labels,
            )

        if masks is not None:
            return {
                "masks": [mask.tolist() for mask in masks],
                "scores": scores.tolist() if scores is not None else None,
                "status": "success",
            }
        return {"error": "No masks generated", "status": "error"}