peterproofpath committed
Commit c71b705 · verified · 1 parent: 336e328

Upload 2 files

Files changed (2)
  1. handler.py +281 -0
  2. requirements.txt +17 -0
handler.py ADDED
@@ -0,0 +1,281 @@
+ """
+ SigLIP 2 Custom Inference Handler for Hugging Face Inference Endpoints
+ Model: google/siglip2-so400m-patch14-384 (Best balance of performance/quality)
+
+ For ProofPath video assessment - identifies objects, tools, and actions in video frames.
+ """
+
+ from typing import Dict, List, Any, Union
+ import torch
+ import numpy as np
+ import base64
+ import io
+ from PIL import Image
+
+
+ class EndpointHandler:
+     def __init__(self, path: str = ""):
+         """
+         Initialize SigLIP 2 model for image/frame classification and embedding.
+
+         Args:
+             path: Path to the model directory (provided by HF Inference Endpoints)
+         """
+         from transformers import AutoProcessor, AutoModel
+
+         # Use the model path provided by the endpoint, or default to HF hub
+         model_id = path if path else "google/siglip2-so400m-patch14-384"
+
+         # Determine device
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+         # Load processor and model
+         self.processor = AutoProcessor.from_pretrained(model_id)
+         self.model = AutoModel.from_pretrained(
+             model_id,
+             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+             device_map="auto" if torch.cuda.is_available() else None,
+             attn_implementation="sdpa"  # Use scaled dot product attention
+         )
+
+         if not torch.cuda.is_available():
+             self.model = self.model.to(self.device)
+
+         self.model.eval()
+
+     def _decode_image(self, image_data: Any) -> Image.Image:
+         """
+         Decode image from various input formats.
+
+         Supports:
+         - Base64 encoded image
+         - URL to image
+         - PIL Image
+         - Raw bytes
+         """
+         import requests
+
+         if isinstance(image_data, Image.Image):
+             return image_data
+         elif isinstance(image_data, str):
+             if image_data.startswith(('http://', 'https://')):
+                 # URL
+                 response = requests.get(image_data, stream=True)
+                 return Image.open(response.raw).convert('RGB')
+             elif image_data.startswith('data:'):
+                 # Data URL
+                 header, encoded = image_data.split(',', 1)
+                 image_bytes = base64.b64decode(encoded)
+                 return Image.open(io.BytesIO(image_bytes)).convert('RGB')
+             else:
+                 # Assume base64
+                 image_bytes = base64.b64decode(image_data)
+                 return Image.open(io.BytesIO(image_bytes)).convert('RGB')
+         elif isinstance(image_data, bytes):
+             return Image.open(io.BytesIO(image_data)).convert('RGB')
+         else:
+             raise ValueError(f"Unsupported image input type: {type(image_data)}")
+
+     def _process_batch(
+         self,
+         images: List[Image.Image],
+         texts: List[str] = None
+     ) -> Dict[str, torch.Tensor]:
+         """Process a batch of images and optional texts."""
+         if texts:
+             # SigLIP 2 requires specific padding for text
+             inputs = self.processor(
+                 images=images,
+                 text=texts,
+                 padding="max_length",
+                 max_length=64,
+                 return_tensors="pt"
+             )
+         else:
+             inputs = self.processor(
+                 images=images,
+                 return_tensors="pt"
+             )
+
+         return {k: v.to(self.model.device) for k, v in inputs.items()}
+
+     def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
+         """
+         Process image(s) for classification or embedding extraction.
+
+         Expected input formats:
+
+         1. Zero-shot classification:
+             {
+                 "inputs": <image_data>,  # single image or list of images
+                 "parameters": {
+                     "candidate_labels": ["label1", "label2", ...],
+                     "hypothesis_template": "This is a photo of {}."  # Optional
+                 }
+             }
+
+         2. Image embedding only:
+             {
+                 "inputs": <image_data>,
+                 "parameters": {
+                     "mode": "embedding"
+                 }
+             }
+
+         3. Image-text similarity:
+             {
+                 "inputs": {
+                     "images": [<image1>, <image2>, ...],
+                     "texts": ["text1", "text2", ...]
+                 },
+                 "parameters": {
+                     "mode": "similarity"
+                 }
+             }
+
+         Returns for classification:
+             {
+                 "labels": ["label1", "label2"],
+                 "scores": [0.85, 0.12],
+                 "predictions": [{"label": "label1", "score": 0.85}, ...]
+             }
+
+         Returns for embedding:
+             {
+                 "image_embeddings": [[...], ...],
+                 "embedding_shape": [batch, hidden_dim]
+             }
+
+         Returns for similarity:
+             {
+                 "similarity_matrix": [[...], ...],
+                 "shape": [num_images, num_texts]
+             }
+         """
+         inputs = data.get("inputs")
+         if inputs is None:
+             inputs = data.get("image") or data.get("images")
+         if inputs is None:
+             raise ValueError("No input provided. Use 'inputs', 'image', or 'images' key.")
+
+         params = data.get("parameters", {})
+         mode = params.get("mode", "classification")
+
+         try:
+             # Handle different modes
+             if mode == "embedding":
+                 return self._extract_embeddings(inputs)
+             elif mode == "similarity":
+                 return self._compute_similarity(inputs, params)
+             else:
+                 # Default: zero-shot classification
+                 return self._classify(inputs, params)
+
+         except Exception as e:
+             return {"error": str(e), "error_type": type(e).__name__}
+
+     def _classify(self, inputs: Any, params: Dict) -> Dict[str, Any]:
+         """Zero-shot image classification."""
+         candidate_labels = params.get("candidate_labels", [])
+         if not candidate_labels:
+             raise ValueError("candidate_labels required for classification mode")
+
+         hypothesis_template = params.get("hypothesis_template", "This is a photo of {}.")
+
+         # Decode image(s)
+         if isinstance(inputs, list):
+             images = [self._decode_image(img) for img in inputs]
+         else:
+             images = [self._decode_image(inputs)]
+
+         # Create text prompts from labels
+         texts = [hypothesis_template.format(label) for label in candidate_labels]
+
+         results = []
+         for image in images:
+             # Process single image with all candidate labels
+             processed = self._process_batch([image] * len(texts), texts)
+
+             with torch.no_grad():
+                 outputs = self.model(**processed)
+
+             # SigLIP uses sigmoid, not softmax
+             logits_per_image = outputs.logits_per_image
+             probs = torch.sigmoid(logits_per_image[0])  # Shape: [num_labels]
+
+             # Sort by probability
+             sorted_indices = probs.argsort(descending=True)
+
+             predictions = []
+             for idx in sorted_indices:
+                 predictions.append({
+                     "label": candidate_labels[idx.item()],
+                     "score": float(probs[idx].item())
+                 })
+
+             results.append({
+                 "labels": [p["label"] for p in predictions],
+                 "scores": [p["score"] for p in predictions],
+                 "predictions": predictions
+             })
+
+         # Return single result if single input
+         if len(results) == 1:
+             return results[0]
+         return {"results": results}
+
+     def _extract_embeddings(self, inputs: Any) -> Dict[str, Any]:
+         """Extract image embeddings only."""
+         # Decode image(s)
+         if isinstance(inputs, list):
+             images = [self._decode_image(img) for img in inputs]
+         else:
+             images = [self._decode_image(inputs)]
+
+         processed = self.processor(images=images, return_tensors="pt")
+         processed = {k: v.to(self.model.device) for k, v in processed.items()}
+
+         with torch.no_grad():
+             # Get vision features directly
+             vision_outputs = self.model.get_image_features(**processed)
+
+         embeddings = vision_outputs.cpu().numpy().tolist()
+
+         return {
+             "image_embeddings": embeddings,
+             "embedding_shape": list(vision_outputs.shape)
+         }
+
+     def _compute_similarity(self, inputs: Dict, params: Dict) -> Dict[str, Any]:
+         """Compute image-text similarity matrix."""
+         images_data = inputs.get("images", [])
+         texts = inputs.get("texts", [])
+
+         if not images_data or not texts:
+             raise ValueError("Both 'images' and 'texts' required for similarity mode")
+
+         # Decode images
+         images = [self._decode_image(img) for img in images_data]
+
+         # Process with padding for SigLIP 2
+         processed = self.processor(
+             images=images,
+             text=texts,
+             padding="max_length",
+             max_length=64,
+             return_tensors="pt"
+         )
+         processed = {k: v.to(self.model.device) for k, v in processed.items()}
+
+         with torch.no_grad():
+             outputs = self.model(**processed)
+
+         # Get similarity matrix
+         similarity = outputs.logits_per_image  # [num_images, num_texts]
+         probs = torch.sigmoid(similarity)
+
+         return {
+             "similarity_matrix": probs.cpu().numpy().tolist(),
+             "shape": list(probs.shape),
+             "logits": similarity.cpu().numpy().tolist()
+         }
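
For reference, a client call against a deployed endpoint could look like the sketch below. ENDPOINT_URL, HF_TOKEN, and frame_0001.jpg are placeholders, not values from this commit; the payload follows the zero-shot classification format documented in __call__ above.

# Hedged client-side sketch: ENDPOINT_URL and HF_TOKEN are placeholders for a
# deployed Inference Endpoint URL and an access token with permission to call it.
import base64
import requests

ENDPOINT_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"  # placeholder
HF_TOKEN = "hf_..."  # placeholder

# Encode a single video frame as base64; the handler's _decode_image() accepts this.
with open("frame_0001.jpg", "rb") as f:
    frame_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": frame_b64,
    "parameters": {
        "candidate_labels": ["a power drill", "a hammer", "a tape measure"],
        "hypothesis_template": "This is a photo of {}.",  # optional
    },
}

response = requests.post(
    ENDPOINT_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
    json=payload,
    timeout=60,
)
print(response.json())  # {"labels": [...], "scores": [...], "predictions": [...]}

The "embedding" and "similarity" modes use the same POST shape, with "parameters": {"mode": "embedding"} or {"mode": "similarity"} as described in the handler docstring.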
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ # SigLIP 2 Inference Endpoint Requirements
+ # Note: transformers and torch are pre-installed in HF Inference containers
+
+ # For latest SigLIP 2 support (may need bleeding edge)
+ transformers>=4.45.0
+ torch>=2.0.0
+
+ # Video decoding
+ torchcodec>=0.1.0
+
+ # Standard deps (usually pre-installed)
+ numpy>=1.24.0
+ einops>=0.7.0
+ timm>=0.9.0
+
+ # For efficient attention
+ accelerate>=0.25.0
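
After installing these requirements, a minimal local smoke test of the handler might look like the sketch below. It assumes handler.py is importable from the working directory and that frame.jpg is a placeholder image on disk; the first run downloads the model weights from the Hub.

# Hedged local smoke test: handler.py and frame.jpg are assumptions, not part of this commit.
from handler import EndpointHandler

handler = EndpointHandler()  # with no path, loads google/siglip2-so400m-patch14-384 from the Hub

with open("frame.jpg", "rb") as f:
    image_bytes = f.read()  # raw bytes are accepted by _decode_image()

result = handler({
    "inputs": image_bytes,
    "parameters": {"candidate_labels": ["a screwdriver", "a saw", "safety gloves"]},
})
print(result.get("predictions", result))  # ranked labels, or the error payload if something failed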