File size: 19,417 Bytes
3fc1e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
import aiohttp
import asyncio
import json
import re
from typing import Dict, Any, List, Union, Optional, AsyncGenerator
import time

# Module-level cache of discovered project IDs, keyed by API key:
# {api_key: project_id}. DirectVertexClient.discover_project_id() consults it
# before sending a discovery request and stores every newly found ID in it.
PROJECT_ID_CACHE: Dict[str, str] = {}


class DirectVertexClient:
    """

    A client that connects to Vertex AI using direct URLs instead of the SDK.

    Mimics the interface of genai.Client for seamless integration.

    """
    
    def __init__(self, api_key: str):
        self.api_key = api_key
        self.project_id: Optional[str] = None
        self.base_url = "https://aiplatform.googleapis.com/v1"
        self.session: Optional[aiohttp.ClientSession] = None
        # Mimic the model_name attribute that might be accessed
        self.model_name = "direct_vertex_client"
        
        # Create nested structure to mimic genai.Client interface
        self.aio = self._AioNamespace(self)
    
    class _AioNamespace:
        """Stand-in for ``client.aio``; exposes the ``models`` namespace."""

        def __init__(self, parent):
            # Both namespaces point back at the same owning client.
            self.models = self._ModelsNamespace(parent)
            self.parent = parent

        class _ModelsNamespace:
            """Stand-in for ``client.aio.models``; forwards calls to the owning client."""

            def __init__(self, parent):
                self.parent = parent

            async def generate_content(self, model: str, contents: Any, config: Dict[str, Any]) -> Any:
                """Run one non-streaming generation request and return its result."""
                client = self.parent
                response = await client._generate_content(model, contents, config, stream=False)
                return response

            async def generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]):
                """Return the async generator that yields streamed response chunks.

                Declared ``async`` on purpose: callers ``await`` this method to
                obtain the generator and then iterate it with ``async for``,
                which matches the SDK's call shape.
                """
                chunk_stream = self.parent._generate_content_stream(model, contents, config)
                return chunk_stream
    
    async def _ensure_session(self):
        """Ensure aiohttp session is created"""
        if self.session is None:
            self.session = aiohttp.ClientSession()
    
    async def close(self):
        """Clean up resources"""
        if self.session:
            await self.session.close()
            self.session = None
    
    async def discover_project_id(self) -> None:
        """Discover project ID by triggering an intentional error"""
        # Check cache first
        if self.api_key in PROJECT_ID_CACHE:
            self.project_id = PROJECT_ID_CACHE[self.api_key]
            print(f"INFO: Using cached project ID: {self.project_id}")
            return
        
        await self._ensure_session()
        
        # Use a non-existent model to trigger error
        error_url = f"{self.base_url}/publishers/google/models/gemini-2.7-pro-preview-05-06:streamGenerateContent?key={self.api_key}"
        
        try:
            # Send minimal request to trigger error
            payload = {
                "contents": [{"role": "user", "parts": [{"text": "test"}]}]
            }
            
            async with self.session.post(error_url, json=payload) as response:
                response_text = await response.text()
                
                try:
                    # Try to parse as JSON first
                    error_data = json.loads(response_text)
                    
                    # Handle array response format
                    if isinstance(error_data, list) and len(error_data) > 0:
                        error_data = error_data[0]
                    
                    if "error" in error_data:
                        error_message = error_data["error"].get("message", "")
                        # Extract project ID from error message
                        # Pattern: "projects/39982734461/locations/..."
                        match = re.search(r'projects/(\d+)/locations/', error_message)
                        if match:
                            self.project_id = match.group(1)
                            PROJECT_ID_CACHE[self.api_key] = self.project_id
                            print(f"INFO: Discovered project ID: {self.project_id}")
                            return
                except json.JSONDecodeError:
                    # If not JSON, try to find project ID in raw text
                    match = re.search(r'projects/(\d+)/locations/', response_text)
                    if match:
                        self.project_id = match.group(1)
                        PROJECT_ID_CACHE[self.api_key] = self.project_id
                        print(f"INFO: Discovered project ID from raw response: {self.project_id}")
                        return
                
                raise Exception(f"Failed to discover project ID. Status: {response.status}, Response: {response_text[:500]}")
                
        except Exception as e:
            print(f"ERROR: Failed to discover project ID: {e}")
            raise
    
    def _convert_contents(self, contents: Any) -> List[Dict[str, Any]]:
        """Convert SDK Content objects to REST API format"""
        if isinstance(contents, list):
            return [self._convert_content_item(item) for item in contents]
        else:
            return [self._convert_content_item(contents)]
    
    def _convert_content_item(self, content: Any) -> Dict[str, Any]:
        """Convert a single content item to REST API format"""
        if isinstance(content, dict):
            return content
        
        # Handle SDK Content objects
        result = {}
        if hasattr(content, 'role'):
            result['role'] = content.role
        if hasattr(content, 'parts'):
            result['parts'] = []
            for part in content.parts:
                if isinstance(part, dict):
                    result['parts'].append(part)
                elif hasattr(part, 'text'):
                    result['parts'].append({'text': part.text})
                elif hasattr(part, 'inline_data'):
                    result['parts'].append({
                        'inline_data': {
                            'mime_type': part.inline_data.mime_type,
                            'data': part.inline_data.data
                        }
                    })
        return result
    
    def _convert_safety_settings(self, safety_settings: Any) -> List[Dict[str, str]]:
        """Convert SDK SafetySetting objects to REST API format"""
        if not safety_settings:
            return []
        
        result = []
        for setting in safety_settings:
            if isinstance(setting, dict):
                result.append(setting)
            elif hasattr(setting, 'category') and hasattr(setting, 'threshold'):
                # Convert SDK SafetySetting to dict
                result.append({
                    'category': setting.category,
                    'threshold': setting.threshold
                })
        return result
    
    def _convert_tools(self, tools: Any) -> List[Dict[str, Any]]:
        """Convert SDK Tool objects to REST API format"""
        if not tools:
            return []
        
        result = []
        for tool in tools:
            if isinstance(tool, dict):
                result.append(tool)
            else:
                # Convert SDK Tool object to dict
                result.append(self._convert_tool_item(tool))
        return result
    
    def _convert_tool_item(self, tool: Any) -> Dict[str, Any]:
        """Convert a single tool item to REST API format"""
        if isinstance(tool, dict):
            return tool
        
        tool_dict = {}
        
        # Convert all non-private attributes
        if hasattr(tool, '__dict__'):
            for attr_name, attr_value in tool.__dict__.items():
                if not attr_name.startswith('_'):
                    # Convert attribute names from snake_case to camelCase for REST API
                    rest_api_name = self._to_camel_case(attr_name)
                    
                    # Special handling for known types
                    if attr_name == 'google_search' and attr_value is not None:
                        tool_dict[rest_api_name] = {}  # GoogleSearch is empty object in REST
                    elif attr_name == 'function_declarations' and attr_value is not None:
                        tool_dict[rest_api_name] = attr_value
                    elif attr_value is not None:
                        # Recursively convert any other SDK objects
                        tool_dict[rest_api_name] = self._convert_sdk_object(attr_value)
        
        return tool_dict
    
    def _to_camel_case(self, snake_str: str) -> str:
        """Convert snake_case to camelCase"""
        components = snake_str.split('_')
        return components[0] + ''.join(x.title() for x in components[1:])
    
    def _convert_sdk_object(self, obj: Any) -> Any:
        """Generic SDK object converter"""
        if isinstance(obj, (str, int, float, bool, type(None))):
            return obj
        elif isinstance(obj, dict):
            return {k: self._convert_sdk_object(v) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._convert_sdk_object(item) for item in obj]
        elif hasattr(obj, '__dict__'):
            # Convert SDK object to dict
            result = {}
            for key, value in obj.__dict__.items():
                if not key.startswith('_'):
                    result[self._to_camel_case(key)] = self._convert_sdk_object(value)
            return result
        else:
            return obj
    
    async def _generate_content(self, model: str, contents: Any, config: Dict[str, Any], stream: bool = False) -> Any:
        """Internal method for content generation"""
        if not self.project_id:
            raise ValueError("Project ID not discovered. Call discover_project_id() first.")
        
        await self._ensure_session()
        
        # Build URL
        endpoint = "streamGenerateContent" if stream else "generateContent"
        url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:{endpoint}?key={self.api_key}"
        
        # Convert contents to REST API format
        payload = {
            "contents": self._convert_contents(contents)
        }
        
        # Extract specific config sections
        if "system_instruction" in config:
            # System instruction should be a content object
            if isinstance(config["system_instruction"], dict):
                payload["systemInstruction"] = config["system_instruction"]
            else:
                payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
        
        if "safety_settings" in config:
            payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
        
        if "tools" in config:
            payload["tools"] = self._convert_tools(config["tools"])
        
        # All other config goes under generationConfig
        generation_config = {}
        for key, value in config.items():
            if key not in ["system_instruction", "safety_settings", "tools"]:
                generation_config[key] = value
        
        if generation_config:
            payload["generationConfig"] = generation_config
        
        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    error_data = await response.json()
                    error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status}")
                    raise Exception(f"Vertex AI API error: {error_msg}")
                
                # Get the JSON response
                response_data = await response.json()
                
                # Convert dict to object with attributes for compatibility
                return self._dict_to_obj(response_data)
                
        except Exception as e:
            print(f"ERROR: Direct Vertex API call failed: {e}")
            raise
    
    def _dict_to_obj(self, data):
        """Convert a dict to an object with attributes"""
        if isinstance(data, dict):
            # Create a simple object that allows attribute access
            class AttrDict:
                def __init__(self, d):
                    for key, value in d.items():
                        setattr(self, key, self._convert_value(value))
                
                def _convert_value(self, value):
                    if isinstance(value, dict):
                        return AttrDict(value)
                    elif isinstance(value, list):
                        return [self._convert_value(item) for item in value]
                    else:
                        return value
            
            return AttrDict(data)
        elif isinstance(data, list):
            return [self._dict_to_obj(item) for item in data]
        else:
            return data
    
    async def _generate_content_stream(self, model: str, contents: Any, config: Dict[str, Any]) -> AsyncGenerator:
        """Internal method for streaming content generation"""
        if not self.project_id:
            raise ValueError("Project ID not discovered. Call discover_project_id() first.")
        
        await self._ensure_session()
        
        # Build URL for streaming
        url = f"{self.base_url}/projects/{self.project_id}/locations/global/publishers/google/models/{model}:streamGenerateContent?key={self.api_key}"
        
        # Convert contents to REST API format
        payload = {
            "contents": self._convert_contents(contents)
        }
        
        # Extract specific config sections
        if "system_instruction" in config:
            # System instruction should be a content object
            if isinstance(config["system_instruction"], dict):
                payload["systemInstruction"] = config["system_instruction"]
            else:
                payload["systemInstruction"] = self._convert_content_item(config["system_instruction"])
        
        if "safety_settings" in config:
            payload["safetySettings"] = self._convert_safety_settings(config["safety_settings"])
        
        if "tools" in config:
            payload["tools"] = self._convert_tools(config["tools"])
        
        # All other config goes under generationConfig
        generation_config = {}
        for key, value in config.items():
            if key not in ["system_instruction", "safety_settings", "tools"]:
                generation_config[key] = value
        
        if generation_config:
            payload["generationConfig"] = generation_config
        
        try:
            async with self.session.post(url, json=payload) as response:
                if response.status != 200:
                    error_data = await response.json()
                    # Handle array response format
                    if isinstance(error_data, list) and len(error_data) > 0:
                        error_data = error_data[0]
                    error_msg = error_data.get("error", {}).get("message", f"HTTP {response.status}") if isinstance(error_data, dict) else str(error_data)
                    raise Exception(f"Vertex AI API error: {error_msg}")
                
                # The Vertex AI streaming endpoint returns JSON array elements
                # We need to parse these as they arrive
                buffer = ""
                
                async for chunk in response.content.iter_any():
                    decoded_chunk = chunk.decode('utf-8')
                    buffer += decoded_chunk
                    
                    # Try to extract complete JSON objects from the buffer
                    while True:
                        # Skip whitespace and array brackets
                        buffer = buffer.lstrip()
                        if buffer.startswith('['):
                            buffer = buffer[1:].lstrip()
                            continue
                        if buffer.startswith(']'):
                            # End of array
                            return
                        
                        # Skip comma and whitespace between objects
                        if buffer.startswith(','):
                            buffer = buffer[1:].lstrip()
                            continue
                        
                        # Look for a complete JSON object
                        if buffer.startswith('{'):
                            # Find the matching closing brace
                            brace_count = 0
                            in_string = False
                            escape_next = False
                            
                            for i, char in enumerate(buffer):
                                if escape_next:
                                    escape_next = False
                                    continue
                                
                                if char == '\\' and in_string:
                                    escape_next = True
                                    continue
                                
                                if char == '"' and not in_string:
                                    in_string = True
                                elif char == '"' and in_string:
                                    in_string = False
                                elif char == '{' and not in_string:
                                    brace_count += 1
                                elif char == '}' and not in_string:
                                    brace_count -= 1
                                    
                                    if brace_count == 0:
                                        # Found complete object
                                        obj_str = buffer[:i+1]
                                        buffer = buffer[i+1:]
                                        
                                        try:
                                            chunk_data = json.loads(obj_str)
                                            converted_obj = self._dict_to_obj(chunk_data)
                                            yield converted_obj
                                        except json.JSONDecodeError as e:
                                            print(f"ERROR: DirectVertexClient - Failed to parse JSON: {e}")
                                        
                                        break
                            else:
                                # No complete object found, need more data
                                break
                        else:
                            # No more objects to process in current buffer
                            break
                            
        except Exception as e:
            print(f"ERROR: Direct Vertex streaming API call failed: {e}")
            raise