NeerajCodz committed
Commit ca1fd98 · 1 Parent(s): 3bfb250

feat: add multi-model LLM router with providers
backend/app/models/__init__.py ADDED
@@ -0,0 +1,57 @@
+"""Models module - LLM providers, routing, and ensemble capabilities."""
+
+from app.models.router import (
+    SmartModelRouter,
+    RoutingStrategy,
+    RoutingConfig,
+    CostTracker,
+    ModelScore,
+)
+from app.models.ensemble import (
+    ModelEnsemble,
+    AggregationStrategy,
+    EnsembleResult,
+)
+from app.models.providers import (
+    # Base
+    BaseProvider,
+    ProviderError,
+    RateLimitError,
+    ModelNotFoundError,
+    CompletionResponse,
+    ModelInfo,
+    TokenUsage,
+    # Providers
+    OpenAIProvider,
+    AnthropicProvider,
+    GoogleProvider,
+    GroqProvider,
+)
+from app.models.providers.base import TaskType
+
+__all__ = [
+    # Router
+    "SmartModelRouter",
+    "RoutingStrategy",
+    "RoutingConfig",
+    "CostTracker",
+    "ModelScore",
+    "TaskType",
+    # Ensemble
+    "ModelEnsemble",
+    "AggregationStrategy",
+    "EnsembleResult",
+    # Base
+    "BaseProvider",
+    "ProviderError",
+    "RateLimitError",
+    "ModelNotFoundError",
+    "CompletionResponse",
+    "ModelInfo",
+    "TokenUsage",
+    # Providers
+    "OpenAIProvider",
+    "AnthropicProvider",
+    "GoogleProvider",
+    "GroqProvider",
+]
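
For downstream code, the package root is the whole import surface; a minimal sketch of consuming the re-exports above:

    from app.models import (
        SmartModelRouter,
        ModelEnsemble,
        AggregationStrategy,
        TaskType,
    )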
backend/app/models/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (1.03 kB).
 
backend/app/models/__pycache__/ensemble.cpython-314.pyc ADDED
Binary file (23.7 kB).
 
backend/app/models/__pycache__/router.cpython-314.pyc ADDED
Binary file (27.7 kB).
 
backend/app/models/ensemble.py ADDED
@@ -0,0 +1,505 @@
+"""Model ensemble for running multiple models and aggregating results."""
+
+import asyncio
+import logging
+from dataclasses import dataclass, field
+from enum import Enum
+from typing import Any
+
+from app.models.providers.base import (
+    BaseProvider,
+    CompletionResponse,
+    ProviderError,
+    TokenUsage,
+)
+from app.models.router import SmartModelRouter
+
+logger = logging.getLogger(__name__)
+
+
+class AggregationStrategy(str, Enum):
+    """Strategy for aggregating ensemble results."""
+
+    MAJORITY_VOTE = "majority_vote"  # Use most common response
+    CONFIDENCE_WEIGHTED = "confidence_weighted"  # Weight by model confidence
+    FIRST_SUCCESS = "first_success"  # Use first successful response
+    BEST_QUALITY = "best_quality"  # Use response from highest quality model
+    CONCATENATE = "concatenate"  # Combine all responses
+    CONSENSUS = "consensus"  # Only return if models agree
+
+
+@dataclass
+class EnsembleResult:
+    """Result from an ensemble run."""
+
+    content: str
+    responses: list[CompletionResponse]
+    agreement_score: float  # 0-1, how much models agreed
+    strategy: AggregationStrategy
+    selected_model: str | None = None
+    total_cost: float = 0.0
+    total_tokens: TokenUsage = field(default_factory=TokenUsage)
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert to dictionary."""
+        return {
+            "content": self.content,
+            "responses": [r.to_dict() for r in self.responses],
+            "agreement_score": self.agreement_score,
+            "strategy": self.strategy.value,
+            "selected_model": self.selected_model,
+            "total_cost": self.total_cost,
+            "total_tokens": {
+                "prompt": self.total_tokens.prompt_tokens,
+                "completion": self.total_tokens.completion_tokens,
+                "total": self.total_tokens.total_tokens,
+            },
+            "metadata": self.metadata,
+        }
+
+
+class ModelEnsemble:
+    """Run multiple models and aggregate their results."""
+
+    # Model quality tiers for weighted voting
+    MODEL_QUALITY_TIERS: dict[str, float] = {
+        # Tier 1: Highest quality
+        "claude-3-opus-20240229": 1.0,
+        "gpt-4o": 0.98,
+        "claude-3-5-sonnet-20241022": 0.97,
+        "gemini-1.5-pro": 0.95,
+        # Tier 2: High quality
+        "gpt-4-turbo": 0.90,
+        "gpt-4": 0.88,
+        "claude-3-sonnet-20240229": 0.85,
+        "llama-3.3-70b-versatile": 0.83,
+        # Tier 3: Good quality
+        "gpt-4o-mini": 0.75,
+        "claude-3-5-haiku-20241022": 0.73,
+        "gemini-1.5-flash": 0.70,
+        "mixtral-8x7b-32768": 0.68,
+        # Tier 4: Fast/cheap
+        "claude-3-haiku-20240307": 0.60,
+        "llama-3.1-8b-instant": 0.55,
+        "gpt-3.5-turbo": 0.50,
+    }
+
+    def __init__(
+        self,
+        router: SmartModelRouter,
+        default_models: list[str] | None = None,
+        default_strategy: AggregationStrategy = AggregationStrategy.CONFIDENCE_WEIGHTED,
+        timeout: float = 60.0,
+    ):
+        """Initialize the ensemble.
+
+        Args:
+            router: SmartModelRouter instance for accessing providers
+            default_models: Default models to use in ensemble
+            default_strategy: Default aggregation strategy
+            timeout: Timeout for each model completion
+        """
+        self.router = router
+        self.default_models = default_models or []
+        self.default_strategy = default_strategy
+        self.timeout = timeout
+
+    async def run(
+        self,
+        messages: list[dict[str, Any]],
+        models: list[str] | None = None,
+        strategy: AggregationStrategy | None = None,
+        min_responses: int = 1,
+        **kwargs: Any,
+    ) -> EnsembleResult:
+        """Run multiple models and aggregate results.
+
+        Args:
+            messages: List of message dicts
+            models: List of model IDs to use (uses defaults if not specified)
+            strategy: Aggregation strategy (uses default if not specified)
+            min_responses: Minimum number of successful responses required
+            **kwargs: Additional completion parameters
+
+        Returns:
+            EnsembleResult with aggregated content and metadata
+
+        Raises:
+            ProviderError: If not enough models respond successfully
+        """
+        models_to_use = models or self.default_models
+        strategy = strategy or self.default_strategy
+
+        if not models_to_use:
+            # Use top 3 available models
+            available = self.router.get_available_models()
+            models_to_use = [m.id for m in available[:3]]
+
+        if not models_to_use:
+            raise ProviderError("No models available for ensemble", "ensemble")
+
+        # Run all models concurrently
+        tasks = []
+        for model_id in models_to_use:
+            provider = self.router.get_provider_for_model(model_id)
+            if provider:
+                task = self._run_model(provider, model_id, messages, **kwargs)
+                tasks.append((model_id, task))
+
+        if not tasks:
+            raise ProviderError("No valid models for ensemble", "ensemble")
+
+        # Gather results
+        responses: list[CompletionResponse] = []
+        errors: list[tuple[str, Exception]] = []
+
+        results = await asyncio.gather(
+            *[t[1] for t in tasks],
+            return_exceptions=True,
+        )
+
+        for (model_id, _), result in zip(tasks, results):
+            if isinstance(result, Exception):
+                logger.warning(f"Model {model_id} failed: {result}")
+                errors.append((model_id, result))
+            elif result is not None:
+                responses.append(result)
+
+        if len(responses) < min_responses:
+            raise ProviderError(
+                f"Only {len(responses)} models responded, need {min_responses}. "
+                f"Errors: {[str(e) for _, e in errors]}",
+                "ensemble",
+            )
+
+        # Aggregate results
+        result = self._aggregate(responses, strategy)
+
+        return result
+
+    async def _run_model(
+        self,
+        provider: BaseProvider,
+        model_id: str,
+        messages: list[dict[str, Any]],
+        **kwargs: Any,
+    ) -> CompletionResponse | None:
+        """Run a single model with timeout."""
+        try:
+            return await asyncio.wait_for(
+                provider.complete(messages, model_id, **kwargs),
+                timeout=self.timeout,
+            )
+        except asyncio.TimeoutError:
+            logger.warning(f"Model {model_id} timed out")
+            return None
+        except Exception as e:
+            logger.warning(f"Model {model_id} error: {e}")
+            raise
+
+    def _aggregate(
+        self,
+        responses: list[CompletionResponse],
+        strategy: AggregationStrategy,
+    ) -> EnsembleResult:
+        """Aggregate responses based on strategy."""
+        if not responses:
+            raise ProviderError("No responses to aggregate", "ensemble")
+
+        # Calculate total cost and tokens
+        total_cost = sum(r.cost for r in responses)
+        total_tokens = TokenUsage()
+        for r in responses:
+            total_tokens = total_tokens + r.usage
+
+        # Calculate agreement score
+        agreement_score = self._calculate_agreement(responses)
+
+        # Select content based on strategy
+        if strategy == AggregationStrategy.FIRST_SUCCESS:
+            content, selected_model = self._first_success(responses)
+        elif strategy == AggregationStrategy.MAJORITY_VOTE:
+            content, selected_model = self._majority_vote(responses)
+        elif strategy == AggregationStrategy.CONFIDENCE_WEIGHTED:
+            content, selected_model = self._confidence_weighted(responses)
+        elif strategy == AggregationStrategy.BEST_QUALITY:
+            content, selected_model = self._best_quality(responses)
+        elif strategy == AggregationStrategy.CONCATENATE:
+            content, selected_model = self._concatenate(responses)
+        elif strategy == AggregationStrategy.CONSENSUS:
+            content, selected_model = self._consensus(responses, agreement_score)
+        else:
+            content, selected_model = self._first_success(responses)
+
+        return EnsembleResult(
+            content=content,
+            responses=responses,
+            agreement_score=agreement_score,
+            strategy=strategy,
+            selected_model=selected_model,
+            total_cost=total_cost,
+            total_tokens=total_tokens,
+            metadata={
+                "num_responses": len(responses),
+                "models_used": [r.model for r in responses],
+            },
+        )
+
+    def _calculate_agreement(self, responses: list[CompletionResponse]) -> float:
+        """Calculate agreement score between responses.
+
+        Uses simple similarity based on common words/tokens.
+        """
+        if len(responses) < 2:
+            return 1.0
+
+        # Tokenize responses (simple word-based)
+        response_tokens = []
+        for r in responses:
+            words = set(r.content.lower().split())
+            response_tokens.append(words)
+
+        # Calculate pairwise Jaccard similarity
+        similarities = []
+        for i in range(len(response_tokens)):
+            for j in range(i + 1, len(response_tokens)):
+                set_i = response_tokens[i]
+                set_j = response_tokens[j]
+
+                if not set_i and not set_j:
+                    similarities.append(1.0)
+                elif not set_i or not set_j:
+                    similarities.append(0.0)
+                else:
+                    intersection = len(set_i & set_j)
+                    union = len(set_i | set_j)
+                    similarities.append(intersection / union)
+
+        return sum(similarities) / len(similarities) if similarities else 1.0
+
+    def _first_success(
+        self, responses: list[CompletionResponse]
+    ) -> tuple[str, str | None]:
+        """Return the first successful response."""
+        r = responses[0]
+        return r.content, r.model
+
+    def _majority_vote(
+        self, responses: list[CompletionResponse]
+    ) -> tuple[str, str | None]:
+        """Return the most common response (by content similarity)."""
+        if len(responses) == 1:
+            return responses[0].content, responses[0].model
+
+        # Find response most similar to others
+        best_idx = 0
+        best_score = 0.0
+
+        for i, r in enumerate(responses):
+            score = 0.0
+            words_i = set(r.content.lower().split())
+
+            for j, other in enumerate(responses):
+                if i != j:
+                    words_j = set(other.content.lower().split())
+                    if words_i and words_j:
+                        intersection = len(words_i & words_j)
+                        union = len(words_i | words_j)
+                        score += intersection / union
+
+            if score > best_score:
+                best_score = score
+                best_idx = i
+
+        return responses[best_idx].content, responses[best_idx].model
+
+    def _confidence_weighted(
+        self, responses: list[CompletionResponse]
+    ) -> tuple[str, str | None]:
+        """Weight responses by model quality/confidence."""
+        if len(responses) == 1:
+            return responses[0].content, responses[0].model
+
+        # Score each response by model quality
+        scored = []
+        for r in responses:
+            quality = self.MODEL_QUALITY_TIERS.get(r.model, 0.5)
+            scored.append((quality, r))
+
+        # Sort by quality
+        scored.sort(key=lambda x: x[0], reverse=True)
+
+        # Return highest quality response
+        best = scored[0][1]
+        return best.content, best.model
+
+    def _best_quality(
+        self, responses: list[CompletionResponse]
+    ) -> tuple[str, str | None]:
+        """Return response from highest quality model."""
+        best_quality = 0.0
+        best_response = responses[0]
+
+        for r in responses:
+            quality = self.MODEL_QUALITY_TIERS.get(r.model, 0.5)
+            if quality > best_quality:
+                best_quality = quality
+                best_response = r
+
+        return best_response.content, best_response.model
+
+    def _concatenate(
+        self, responses: list[CompletionResponse]
+    ) -> tuple[str, str | None]:
+        """Concatenate all responses."""
+        parts = []
+        models = []
+
+        for r in responses:
+            parts.append(f"[{r.model}]:\n{r.content}")
+            models.append(r.model)
+
+        content = "\n\n---\n\n".join(parts)
+        return content, None  # No single model selected
+
+    def _consensus(
+        self,
+        responses: list[CompletionResponse],
+        agreement_score: float,
+    ) -> tuple[str, str | None]:
+        """Return result only if models agree (high agreement score)."""
+        if agreement_score < 0.5:
+            # Low agreement, return best quality with warning
+            content, model = self._best_quality(responses)
+            return f"[LOW CONSENSUS - {agreement_score:.2f}]\n{content}", model
+
+        # Good agreement, return majority vote
+        return self._majority_vote(responses)
+
+    async def compare(
+        self,
+        messages: list[dict[str, Any]],
+        models: list[str] | None = None,
+        **kwargs: Any,
+    ) -> dict[str, Any]:
+        """Compare responses from multiple models side-by-side.
+
+        Args:
+            messages: List of message dicts
+            models: List of model IDs to compare
+            **kwargs: Additional completion parameters
+
+        Returns:
+            Dictionary with comparison data
+        """
+        result = await self.run(
+            messages,
+            models,
+            strategy=AggregationStrategy.CONCATENATE,
+            **kwargs,
+        )
+
+        # Build comparison
+        comparison = {
+            "responses": [],
+            "agreement_score": result.agreement_score,
+            "total_cost": result.total_cost,
+            "total_tokens": {
+                "prompt": result.total_tokens.prompt_tokens,
+                "completion": result.total_tokens.completion_tokens,
+                "total": result.total_tokens.total_tokens,
+            },
+        }
+
+        for r in result.responses:
+            comparison["responses"].append({
+                "model": r.model,
+                "provider": r.provider,
+                "content": r.content,
+                "cost": r.cost,
+                "latency_ms": r.latency_ms,
+                "tokens": {
+                    "prompt": r.usage.prompt_tokens,
+                    "completion": r.usage.completion_tokens,
+                },
+                "quality_tier": self.MODEL_QUALITY_TIERS.get(r.model, 0.5),
+            })
+
+        return comparison
+
+    async def debate(
+        self,
+        messages: list[dict[str, Any]],
+        models: list[str] | None = None,
+        rounds: int = 2,
+        **kwargs: Any,
+    ) -> EnsembleResult:
+        """Run a debate between models where they can respond to each other.
+
+        Args:
+            messages: Initial messages
+            models: Models to participate in debate
+            rounds: Number of debate rounds
+            **kwargs: Additional completion parameters
+
+        Returns:
+            Final ensemble result with debate history
+        """
+        models_to_use = models or self.default_models[:2]  # Default to 2 models
+
+        if len(models_to_use) < 2:
+            raise ProviderError("Debate requires at least 2 models", "ensemble")
+
+        all_responses: list[CompletionResponse] = []
+        debate_history: list[dict[str, Any]] = []
+        current_messages = messages.copy()
+
+        for round_num in range(rounds):
+            round_responses = []
+
+            for model_id in models_to_use:
+                provider = self.router.get_provider_for_model(model_id)
+                if not provider:
+                    continue
+
+                try:
+                    response = await asyncio.wait_for(
+                        provider.complete(current_messages, model_id, **kwargs),
+                        timeout=self.timeout,
+                    )
+                    round_responses.append(response)
+                    all_responses.append(response)
+
+                    debate_history.append({
+                        "round": round_num + 1,
+                        "model": model_id,
+                        "content": response.content,
+                    })
+
+                except Exception as e:
+                    logger.warning(f"Model {model_id} failed in round {round_num + 1}: {e}")
+
+            # Add responses to messages for next round
+            if round_responses and round_num < rounds - 1:
+                for r in round_responses:
+                    current_messages.append({
+                        "role": "assistant",
+                        "content": f"[{r.model}]: {r.content}",
+                    })
+
+                # Ask for follow-up
+                current_messages.append({
+                    "role": "user",
+                    "content": "Consider the other perspectives and refine your answer.",
+                })
+
+        # Aggregate final round responses
+        final_responses = all_responses[-len(models_to_use):]
+        result = self._aggregate(final_responses, AggregationStrategy.CONFIDENCE_WEIGHTED)
+
+        # Add debate history to metadata
+        result.metadata["debate_history"] = debate_history
+        result.metadata["total_rounds"] = rounds
+
+        return result
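
A short usage sketch of the ensemble above. SmartModelRouter's constructor lives in router.py, which is not part of this diff, so `router` here is assumed to be an already-initialized instance; the model IDs come from the quality-tier table.

    import asyncio

    from app.models import ModelEnsemble, AggregationStrategy

    async def ask(router) -> None:
        ensemble = ModelEnsemble(
            router,
            default_models=[
                "gpt-4o-mini",
                "claude-3-5-haiku-20241022",
                "gemini-1.5-flash",
            ],
        )
        result = await ensemble.run(
            messages=[{"role": "user", "content": "Is 2**31 - 1 prime?"}],
            strategy=AggregationStrategy.CONSENSUS,
            min_responses=2,  # tolerate one provider failing or timing out
        )
        print(result.selected_model, result.agreement_score, result.total_cost)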
backend/app/models/providers/__init__.py ADDED
@@ -0,0 +1,31 @@
+"""LLM Providers - Multiple provider implementations for model routing."""
+
+from app.models.providers.base import (
+    BaseProvider,
+    ProviderError,
+    RateLimitError,
+    ModelNotFoundError,
+    CompletionResponse,
+    ModelInfo,
+    TokenUsage,
+)
+from app.models.providers.openai import OpenAIProvider
+from app.models.providers.anthropic import AnthropicProvider
+from app.models.providers.google import GoogleProvider
+from app.models.providers.groq import GroqProvider
+
+__all__ = [
+    # Base
+    "BaseProvider",
+    "ProviderError",
+    "RateLimitError",
+    "ModelNotFoundError",
+    "CompletionResponse",
+    "ModelInfo",
+    "TokenUsage",
+    # Providers
+    "OpenAIProvider",
+    "AnthropicProvider",
+    "GoogleProvider",
+    "GroqProvider",
+]
backend/app/models/providers/__pycache__/__init__.cpython-314.pyc ADDED
Binary file (824 Bytes).
 
backend/app/models/providers/__pycache__/anthropic.cpython-314.pyc ADDED
Binary file (17.7 kB).
 
backend/app/models/providers/__pycache__/base.cpython-314.pyc ADDED
Binary file (22.1 kB).
 
backend/app/models/providers/__pycache__/google.cpython-314.pyc ADDED
Binary file (18 kB).
 
backend/app/models/providers/__pycache__/groq.cpython-314.pyc ADDED
Binary file (14.4 kB).
 
backend/app/models/providers/__pycache__/openai.cpython-314.pyc ADDED
Binary file (14.9 kB).
 
backend/app/models/providers/anthropic.py ADDED
@@ -0,0 +1,413 @@
+"""Anthropic provider implementation."""
+
+import json
+import time
+from typing import Any, AsyncIterator
+
+import httpx
+
+from app.models.providers.base import (
+    AuthenticationError,
+    BaseProvider,
+    CompletionResponse,
+    ModelInfo,
+    ModelNotFoundError,
+    ProviderError,
+    RateLimitError,
+    TokenUsage,
+)
+
+
+class AnthropicProvider(BaseProvider):
+    """Anthropic API provider supporting Claude models."""
+
+    PROVIDER_NAME = "anthropic"
+    DEFAULT_BASE_URL = "https://api.anthropic.com/v1"
+    API_VERSION = "2023-06-01"
+
+    # Model definitions with pricing (per 1K tokens)
+    MODELS = {
+        "claude-3-opus-20240229": ModelInfo(
+            id="claude-3-opus-20240229",
+            name="Claude 3 Opus",
+            provider="anthropic",
+            context_window=200000,
+            max_output_tokens=4096,
+            supports_functions=True,
+            supports_vision=True,
+            supports_streaming=True,
+            cost_per_1k_input=0.015,
+            cost_per_1k_output=0.075,
+        ),
+        "claude-3-sonnet-20240229": ModelInfo(
+            id="claude-3-sonnet-20240229",
+            name="Claude 3 Sonnet",
+            provider="anthropic",
+            context_window=200000,
+            max_output_tokens=4096,
+            supports_functions=True,
+            supports_vision=True,
+            supports_streaming=True,
+            cost_per_1k_input=0.003,
+            cost_per_1k_output=0.015,
+        ),
+        "claude-3-5-sonnet-20241022": ModelInfo(
+            id="claude-3-5-sonnet-20241022",
+            name="Claude 3.5 Sonnet",
+            provider="anthropic",
+            context_window=200000,
+            max_output_tokens=8192,
+            supports_functions=True,
+            supports_vision=True,
+            supports_streaming=True,
+            cost_per_1k_input=0.003,
+            cost_per_1k_output=0.015,
+        ),
+        "claude-3-haiku-20240307": ModelInfo(
+            id="claude-3-haiku-20240307",
+            name="Claude 3 Haiku",
+            provider="anthropic",
+            context_window=200000,
+            max_output_tokens=4096,
+            supports_functions=True,
+            supports_vision=True,
+            supports_streaming=True,
+            cost_per_1k_input=0.00025,
+            cost_per_1k_output=0.00125,
+        ),
+        "claude-3-5-haiku-20241022": ModelInfo(
+            id="claude-3-5-haiku-20241022",
+            name="Claude 3.5 Haiku",
+            provider="anthropic",
+            context_window=200000,
+            max_output_tokens=8192,
+            supports_functions=True,
+            supports_vision=True,
+            supports_streaming=True,
+            cost_per_1k_input=0.001,
+            cost_per_1k_output=0.005,
+        ),
+    }
+
+    # Aliases for convenience
+    MODEL_ALIASES = {
+        "claude-3-opus": "claude-3-opus-20240229",
+        "claude-3-sonnet": "claude-3-sonnet-20240229",
+        "claude-3.5-sonnet": "claude-3-5-sonnet-20241022",
+        "claude-3-haiku": "claude-3-haiku-20240307",
+        "claude-3.5-haiku": "claude-3-5-haiku-20241022",
+    }
+
+    def __init__(
+        self,
+        api_key: str,
+        base_url: str | None = None,
+        timeout: float = 60.0,
+        max_retries: int = 3,
+        rate_limit_rpm: int = 60,
+    ):
+        super().__init__(
+            api_key=api_key,
+            base_url=base_url or self.DEFAULT_BASE_URL,
+            timeout=timeout,
+            max_retries=max_retries,
+            rate_limit_rpm=rate_limit_rpm,
+        )
+        self._client: httpx.AsyncClient | None = None
+
+    async def initialize(self) -> None:
+        """Initialize the HTTP client."""
+        self._client = httpx.AsyncClient(
+            base_url=self.base_url,
+            headers={
+                "x-api-key": self.api_key,
+                "anthropic-version": self.API_VERSION,
+                "Content-Type": "application/json",
+            },
+            timeout=self.timeout,
+        )
+
+    async def shutdown(self) -> None:
+        """Close the HTTP client."""
+        if self._client:
+            await self._client.aclose()
+            self._client = None
+
+    async def _ensure_client(self) -> httpx.AsyncClient:
+        """Ensure client is initialized."""
+        if not self._client:
+            await self.initialize()
+        return self._client  # type: ignore
+
+    def _resolve_model(self, model: str) -> str:
+        """Resolve model alias to full model ID."""
+        return self.MODEL_ALIASES.get(model, model)
+
+    def get_models(self) -> list[ModelInfo]:
+        """Get available Anthropic models."""
+        return list(self.MODELS.values())
+
+    def _convert_messages(
+        self, messages: list[dict[str, Any]]
+    ) -> tuple[str | None, list[dict[str, Any]]]:
+        """Convert OpenAI-style messages to Anthropic format.
+
+        Returns:
+            Tuple of (system_message, converted_messages)
+        """
+        system_message: str | None = None
+        converted: list[dict[str, Any]] = []
+
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+
+            if role == "system":
+                system_message = content
+            elif role == "assistant":
+                converted.append({"role": "assistant", "content": content})
+            elif role == "user":
+                converted.append({"role": "user", "content": content})
+            elif role == "function":
+                # Convert function result to user message
+                converted.append({
+                    "role": "user",
+                    "content": f"Function result for {msg.get('name', 'function')}: {content}",
+                })
+            elif role == "tool":
+                # Convert tool result
+                converted.append({
+                    "role": "user",
+                    "content": [{
+                        "type": "tool_result",
+                        "tool_use_id": msg.get("tool_call_id", ""),
+                        "content": content,
+                    }],
+                })
+
+        return system_message, converted
+
+    def _convert_tools(
+        self, tools: list[dict[str, Any]] | None
+    ) -> list[dict[str, Any]] | None:
+        """Convert OpenAI-style tools to Anthropic format."""
+        if not tools:
+            return None
+
+        converted = []
+        for tool in tools:
+            if tool.get("type") == "function":
+                func = tool["function"]
+                converted.append({
+                    "name": func["name"],
+                    "description": func.get("description", ""),
+                    "input_schema": func.get("parameters", {"type": "object", "properties": {}}),
+                })
+        return converted if converted else None
+
+    async def complete(
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        function_call: str | dict[str, str] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> CompletionResponse:
+        """Generate a completion using Anthropic API."""
+        await self._acquire_rate_limit()
+
+        model = self._resolve_model(model)
+        model_info = self.get_model_info(model)
+        if not model_info:
+            raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+        client = await self._ensure_client()
+
+        # Convert messages
+        system_message, converted_messages = self._convert_messages(messages)
+
+        # Build request payload
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": converted_messages,
+            "max_tokens": max_tokens or model_info.max_output_tokens,
+        }
+
+        if system_message:
+            payload["system"] = system_message
+
+        if temperature is not None:
+            payload["temperature"] = temperature
+
+        if stop:
+            payload["stop_sequences"] = stop
+
+        # Convert tools (prefer tools over functions)
+        anthropic_tools = self._convert_tools(tools)
+        if not anthropic_tools and functions:
+            # Convert legacy functions format
+            anthropic_tools = [
+                {
+                    "name": f["name"],
+                    "description": f.get("description", ""),
+                    "input_schema": f.get("parameters", {"type": "object", "properties": {}}),
+                }
+                for f in functions
+            ]
+
+        if anthropic_tools:
+            payload["tools"] = anthropic_tools
+
+            # Handle tool choice (only valid when tools are present;
+            # the API rejects tool_choice without tools)
+            if tool_choice == "auto" or tool_choice is None:
+                payload["tool_choice"] = {"type": "auto"}
+            elif tool_choice == "required":
+                payload["tool_choice"] = {"type": "any"}
+            elif isinstance(tool_choice, dict) and "function" in tool_choice:
+                payload["tool_choice"] = {"type": "tool", "name": tool_choice["function"]["name"]}
+
+        start_time = time.time()
+
+        try:
+            response = await self._retry_with_backoff(
+                self._make_request, client, payload
+            )
+        except httpx.HTTPStatusError as e:
+            self._handle_http_error(e)
+
+        latency_ms = (time.time() - start_time) * 1000
+
+        # Parse response
+        content_blocks = response.get("content", [])
+        usage_data = response.get("usage", {})
+
+        # Extract text content and tool uses
+        text_content = ""
+        tool_calls = []
+
+        for block in content_blocks:
+            if block["type"] == "text":
+                text_content += block["text"]
+            elif block["type"] == "tool_use":
+                tool_calls.append({
+                    "id": block["id"],
+                    "type": "function",
+                    "function": {
+                        "name": block["name"],
+                        "arguments": json.dumps(block["input"]),
+                    },
+                })
+
+        usage = TokenUsage(
+            prompt_tokens=usage_data.get("input_tokens", 0),
+            completion_tokens=usage_data.get("output_tokens", 0),
+            total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0),
+        )
+
+        cost = self.calculate_cost(model, usage)
+        self._track_usage(usage, cost)
+
+        return CompletionResponse(
+            content=text_content,
+            model=response.get("model", model),
+            provider=self.PROVIDER_NAME,
+            usage=usage,
+            finish_reason=response.get("stop_reason"),
+            function_call=None,
+            tool_calls=tool_calls if tool_calls else None,
+            raw_response=response,
+            latency_ms=latency_ms,
+            cost=cost,
+        )
+
+    async def _make_request(
+        self, client: httpx.AsyncClient, payload: dict[str, Any]
+    ) -> dict[str, Any]:
+        """Make the API request."""
+        response = await client.post("/messages", json=payload)
+        response.raise_for_status()
+        return response.json()
+
+    def _handle_http_error(self, error: httpx.HTTPStatusError) -> None:
+        """Handle HTTP errors from Anthropic."""
+        status = error.response.status_code
+        try:
+            body = error.response.json()
+            message = body.get("error", {}).get("message", str(error))
+        except Exception:
+            message = str(error)
+
+        if status == 401:
+            raise AuthenticationError(self.PROVIDER_NAME, message)
+        elif status == 429:
+            retry_after = error.response.headers.get("retry-after")
+            raise RateLimitError(
+                self.PROVIDER_NAME,
+                retry_after=float(retry_after) if retry_after else None,
+                message=message,
+            )
+        elif status == 404:
+            raise ModelNotFoundError(self.PROVIDER_NAME, "unknown")
+        else:
+            raise ProviderError(message, self.PROVIDER_NAME, status)
+
+    async def stream(
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """Stream a completion from Anthropic."""
+        await self._acquire_rate_limit()
+
+        model = self._resolve_model(model)
+        model_info = self.get_model_info(model)
+        if not model_info:
+            raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+        client = await self._ensure_client()
+
+        system_message, converted_messages = self._convert_messages(messages)
+
+        payload: dict[str, Any] = {
+            "model": model,
+            "messages": converted_messages,
+            "max_tokens": max_tokens or model_info.max_output_tokens,
+            "stream": True,
+        }
+
+        if system_message:
+            payload["system"] = system_message
+
+        if temperature is not None:
+            payload["temperature"] = temperature
+
+        try:
+            async with client.stream("POST", "/messages", json=payload) as response:
+                response.raise_for_status()
+
+                async for line in response.aiter_lines():
+                    if line.startswith("data: "):
+                        data = line[6:]
+
+                        try:
+                            event = json.loads(data)
+                            event_type = event.get("type")
+
+                            if event_type == "content_block_delta":
+                                delta = event.get("delta", {})
+                                if delta.get("type") == "text_delta":
+                                    yield delta.get("text", "")
+
+                        except json.JSONDecodeError:
+                            continue
+
+        except httpx.HTTPStatusError as e:
+            self._handle_http_error(e)
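
Two behaviors of this provider worth seeing concretely: _convert_messages hoists the system turn into Anthropic's top-level system field, and _resolve_model maps short aliases to pinned snapshots. A sketch with a placeholder key (both helpers are synchronous, so no event loop is needed):

    from app.models.providers import AnthropicProvider

    provider = AnthropicProvider(api_key="sk-ant-...")  # placeholder key
    system, converted = provider._convert_messages([
        {"role": "system", "content": "You are terse."},
        {"role": "user", "content": "Hi"},
    ])
    # system == "You are terse."
    # converted == [{"role": "user", "content": "Hi"}]
    assert provider._resolve_model("claude-3.5-sonnet") == "claude-3-5-sonnet-20241022"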
backend/app/models/providers/base.py ADDED
@@ -0,0 +1,374 @@
+"""Base provider abstract class and common types."""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from datetime import datetime
+from enum import Enum
+from typing import Any, AsyncIterator, Callable
+import asyncio
+import time
+
+
+class ProviderError(Exception):
+    """Base exception for provider errors."""
+
+    def __init__(self, message: str, provider: str, status_code: int | None = None):
+        self.message = message
+        self.provider = provider
+        self.status_code = status_code
+        super().__init__(f"[{provider}] {message}")
+
+
+class RateLimitError(ProviderError):
+    """Rate limit exceeded error."""
+
+    def __init__(
+        self,
+        provider: str,
+        retry_after: float | None = None,
+        message: str = "Rate limit exceeded",
+    ):
+        self.retry_after = retry_after
+        super().__init__(message, provider, status_code=429)
+
+
+class ModelNotFoundError(ProviderError):
+    """Model not found or not available error."""
+
+    def __init__(self, provider: str, model: str):
+        super().__init__(f"Model '{model}' not found", provider, status_code=404)
+
+
+class AuthenticationError(ProviderError):
+    """Authentication failed error."""
+
+    def __init__(self, provider: str, message: str = "Authentication failed"):
+        super().__init__(message, provider, status_code=401)
+
+
+@dataclass
+class TokenUsage:
+    """Token usage tracking."""
+
+    prompt_tokens: int = 0
+    completion_tokens: int = 0
+    total_tokens: int = 0
+
+    def __add__(self, other: "TokenUsage") -> "TokenUsage":
+        return TokenUsage(
+            prompt_tokens=self.prompt_tokens + other.prompt_tokens,
+            completion_tokens=self.completion_tokens + other.completion_tokens,
+            total_tokens=self.total_tokens + other.total_tokens,
+        )
+
+
+@dataclass
+class CompletionResponse:
+    """Standardized completion response across providers."""
+
+    content: str
+    model: str
+    provider: str
+    usage: TokenUsage
+    finish_reason: str | None = None
+    function_call: dict[str, Any] | None = None
+    tool_calls: list[dict[str, Any]] | None = None
+    raw_response: dict[str, Any] | None = None
+    latency_ms: float = 0.0
+    cost: float = 0.0
+    timestamp: datetime = field(default_factory=datetime.utcnow)
+
+    def to_dict(self) -> dict[str, Any]:
+        """Convert response to dictionary."""
+        return {
+            "content": self.content,
+            "model": self.model,
+            "provider": self.provider,
+            "usage": {
+                "prompt_tokens": self.usage.prompt_tokens,
+                "completion_tokens": self.usage.completion_tokens,
+                "total_tokens": self.usage.total_tokens,
+            },
+            "finish_reason": self.finish_reason,
+            "function_call": self.function_call,
+            "tool_calls": self.tool_calls,
+            "latency_ms": self.latency_ms,
+            "cost": self.cost,
+            "timestamp": self.timestamp.isoformat(),
+        }
+
+
+@dataclass
+class ModelInfo:
+    """Model information and capabilities."""
+
+    id: str
+    name: str
+    provider: str
+    context_window: int
+    max_output_tokens: int
+    supports_functions: bool = False
+    supports_vision: bool = False
+    supports_streaming: bool = True
+    cost_per_1k_input: float = 0.0
+    cost_per_1k_output: float = 0.0
+
+    @property
+    def cost_per_million_input(self) -> float:
+        """Cost per million input tokens."""
+        return self.cost_per_1k_input * 1000
+
+    @property
+    def cost_per_million_output(self) -> float:
+        """Cost per million output tokens."""
+        return self.cost_per_1k_output * 1000
+
+
+class TaskType(str, Enum):
+    """Types of tasks for model routing."""
+
+    GENERAL = "general"
+    CODE = "code"
+    REASONING = "reasoning"
+    EXTRACTION = "extraction"
+    SUMMARIZATION = "summarization"
+    CLASSIFICATION = "classification"
+    CREATIVE = "creative"
+    FAST = "fast"
+
+
+@dataclass
+class RateLimitState:
+    """Rate limiter state."""
+
+    tokens: float
+    last_update: float
+    max_tokens: float
+    refill_rate: float  # tokens per second
+
+
+class BaseProvider(ABC):
+    """Abstract base class for LLM providers."""
+
+    PROVIDER_NAME: str = "base"
+
+    def __init__(
+        self,
+        api_key: str,
+        base_url: str | None = None,
+        timeout: float = 60.0,
+        max_retries: int = 3,
+        rate_limit_rpm: int = 60,
+    ):
+        self.api_key = api_key
+        self.base_url = base_url
+        self.timeout = timeout
+        self.max_retries = max_retries
+
+        # Rate limiting (token bucket)
+        self._rate_limit = RateLimitState(
+            tokens=rate_limit_rpm,
+            last_update=time.time(),
+            max_tokens=rate_limit_rpm,
+            refill_rate=rate_limit_rpm / 60.0,
+        )
+        self._rate_limit_lock = asyncio.Lock()
+
+        # Usage tracking
+        self._total_usage = TokenUsage()
+        self._total_cost: float = 0.0
+        self._request_count: int = 0
+
+    @abstractmethod
+    async def complete(
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        functions: list[dict[str, Any]] | None = None,
+        function_call: str | dict[str, str] | None = None,
+        tools: list[dict[str, Any]] | None = None,
+        tool_choice: str | dict[str, Any] | None = None,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> CompletionResponse:
+        """Generate a completion from the model.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content'
+            model: Model identifier
+            temperature: Sampling temperature (0-2)
+            max_tokens: Maximum tokens to generate
+            functions: Function definitions for function calling
+            function_call: Function call mode or specific function
+            tools: Tool definitions (newer format)
+            tool_choice: Tool choice mode or specific tool
+            stop: Stop sequences
+            **kwargs: Additional provider-specific parameters
+
+        Returns:
+            CompletionResponse with generated content and metadata
+        """
+        ...
+
+    @abstractmethod
+    async def stream(
+        self,
+        messages: list[dict[str, Any]],
+        model: str,
+        temperature: float = 0.7,
+        max_tokens: int | None = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[str]:
+        """Stream a completion from the model.
+
+        Args:
+            messages: List of message dicts
+            model: Model identifier
+            temperature: Sampling temperature
+            max_tokens: Maximum tokens to generate
+            **kwargs: Additional parameters
+
+        Yields:
+            Content chunks as they arrive
+        """
+        ...
+
+    @abstractmethod
+    def get_models(self) -> list[ModelInfo]:
+        """Get list of available models from this provider.
+
+        Returns:
+            List of ModelInfo objects
+        """
+        ...
+
+    def get_model_info(self, model_id: str) -> ModelInfo | None:
+        """Get info for a specific model.
+
+        Args:
+            model_id: Model identifier
+
+        Returns:
+            ModelInfo or None if not found
+        """
+        for model in self.get_models():
+            if model.id == model_id:
+                return model
+        return None
+
+    def calculate_cost(self, model: str, usage: TokenUsage) -> float:
+        """Calculate cost for a completion.
+
+        Args:
+            model: Model identifier
+            usage: Token usage
+
+        Returns:
+            Cost in USD
+        """
+        model_info = self.get_model_info(model)
+        if not model_info:
+            return 0.0
+
+        input_cost = (usage.prompt_tokens / 1000) * model_info.cost_per_1k_input
+        output_cost = (usage.completion_tokens / 1000) * model_info.cost_per_1k_output
+        return input_cost + output_cost
+
+    async def _acquire_rate_limit(self) -> None:
+        """Acquire a token from the rate limiter."""
+        async with self._rate_limit_lock:
+            now = time.time()
+            elapsed = now - self._rate_limit.last_update
+
+            # Refill tokens
+            self._rate_limit.tokens = min(
+                self._rate_limit.max_tokens,
+                self._rate_limit.tokens + elapsed * self._rate_limit.refill_rate,
+            )
+            self._rate_limit.last_update = now
+
+            if self._rate_limit.tokens < 1:
+                # Calculate wait time
+                wait_time = (1 - self._rate_limit.tokens) / self._rate_limit.refill_rate
+                await asyncio.sleep(wait_time)
+                self._rate_limit.tokens = 0
+            else:
+                self._rate_limit.tokens -= 1
+
+    def _track_usage(self, usage: TokenUsage, cost: float) -> None:
+        """Track usage and cost."""
+        self._total_usage = self._total_usage + usage
+        self._total_cost += cost
+        self._request_count += 1
+
+    @property
+    def total_usage(self) -> TokenUsage:
+        """Get total token usage."""
+        return self._total_usage
+
+    @property
+    def total_cost(self) -> float:
+        """Get total cost in USD."""
+        return self._total_cost
+
+    @property
+    def request_count(self) -> int:
+        """Get total request count."""
+        return self._request_count
+
+    def reset_tracking(self) -> None:
+        """Reset usage tracking."""
+        self._total_usage = TokenUsage()
+        self._total_cost = 0.0
+        self._request_count = 0
+
+    async def _retry_with_backoff(
+        self,
+        func: Callable,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Retry a function with exponential backoff.
+
+        Args:
+            func: Async function to retry
+            *args: Positional arguments
+            **kwargs: Keyword arguments
+
+        Returns:
+            Function result
+
+        Raises:
+            Last exception if all retries fail
+        """
+        last_exception: Exception | None = None
+
+        for attempt in range(self.max_retries):
+            try:
+                return await func(*args, **kwargs)
+            except RateLimitError as e:
+                last_exception = e
+                wait_time = e.retry_after or (2**attempt)
+                await asyncio.sleep(wait_time)
+            except ProviderError as e:
+                # Don't retry auth or not found errors
+                if e.status_code in (401, 403, 404):
+                    raise
+                last_exception = e
+                await asyncio.sleep(2**attempt)
+
+        if last_exception:
+            raise last_exception
+
+    async def initialize(self) -> None:
+        """Initialize the provider (optional setup)."""
+        pass
+
+    async def shutdown(self) -> None:
+        """Cleanup resources."""
+        pass
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(requests={self._request_count}, cost=${self._total_cost:.4f})"
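
calculate_cost is linear in token counts: (prompt_tokens / 1000) * cost_per_1k_input + (completion_tokens / 1000) * cost_per_1k_output. A worked sketch against the Claude 3.5 Sonnet pricing declared in anthropic.py above ($0.003 input / $0.015 output per 1K tokens):

    from app.models.providers import AnthropicProvider, TokenUsage

    provider = AnthropicProvider(api_key="sk-ant-...")  # placeholder key
    usage = TokenUsage(prompt_tokens=2000, completion_tokens=500, total_tokens=2500)
    # (2000 / 1000) * 0.003 + (500 / 1000) * 0.015 = 0.006 + 0.0075
    cost = provider.calculate_cost("claude-3-5-sonnet-20241022", usage)
    assert abs(cost - 0.0135) < 1e-9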
backend/app/models/providers/google.py ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Google AI provider implementation (Gemini models)."""
2
+
3
+ import json
4
+ import time
5
+ from typing import Any, AsyncIterator
6
+
7
+ import httpx
8
+
9
+ from app.models.providers.base import (
10
+ AuthenticationError,
11
+ BaseProvider,
12
+ CompletionResponse,
13
+ ModelInfo,
14
+ ModelNotFoundError,
15
+ ProviderError,
16
+ RateLimitError,
17
+ TokenUsage,
18
+ )
19
+
20
+
21
+ class GoogleProvider(BaseProvider):
22
+ """Google AI API provider supporting Gemini models."""
23
+
24
+ PROVIDER_NAME = "google"
25
+ DEFAULT_BASE_URL = "https://generativelanguage.googleapis.com/v1beta"
26
+
27
+ # Model definitions with pricing (per 1K tokens)
28
+ MODELS = {
29
+ "gemini-1.5-pro": ModelInfo(
30
+ id="gemini-1.5-pro",
31
+ name="Gemini 1.5 Pro",
32
+ provider="google",
33
+ context_window=2097152,
34
+ max_output_tokens=8192,
35
+ supports_functions=True,
36
+ supports_vision=True,
37
+ supports_streaming=True,
38
+ cost_per_1k_input=0.00125,
39
+ cost_per_1k_output=0.005,
40
+ ),
41
+ "gemini-1.5-flash": ModelInfo(
42
+ id="gemini-1.5-flash",
43
+ name="Gemini 1.5 Flash",
44
+ provider="google",
45
+ context_window=1048576,
46
+ max_output_tokens=8192,
47
+ supports_functions=True,
48
+ supports_vision=True,
49
+ supports_streaming=True,
50
+ cost_per_1k_input=0.000075,
51
+ cost_per_1k_output=0.0003,
52
+ ),
53
+ "gemini-2.0-flash-exp": ModelInfo(
54
+ id="gemini-2.0-flash-exp",
55
+ name="Gemini 2.0 Flash (Experimental)",
56
+ provider="google",
57
+ context_window=1048576,
58
+ max_output_tokens=8192,
59
+ supports_functions=True,
60
+ supports_vision=True,
61
+ supports_streaming=True,
62
+ cost_per_1k_input=0.0,
63
+ cost_per_1k_output=0.0,
64
+ ),
65
+ "gemini-pro": ModelInfo(
66
+ id="gemini-pro",
67
+ name="Gemini Pro",
68
+ provider="google",
69
+ context_window=32760,
70
+ max_output_tokens=8192,
71
+ supports_functions=True,
72
+ supports_vision=False,
73
+ supports_streaming=True,
74
+ cost_per_1k_input=0.0005,
75
+ cost_per_1k_output=0.0015,
76
+ ),
77
+ }
78
+
79
+ # Aliases
80
+ MODEL_ALIASES = {
81
+ "gemini-flash": "gemini-1.5-flash",
82
+ "gemini-1.5": "gemini-1.5-pro",
83
+ }
84
+
85
+ def __init__(
86
+ self,
87
+ api_key: str,
88
+ base_url: str | None = None,
89
+ timeout: float = 60.0,
90
+ max_retries: int = 3,
91
+ rate_limit_rpm: int = 60,
92
+ ):
93
+ super().__init__(
94
+ api_key=api_key,
95
+ base_url=base_url or self.DEFAULT_BASE_URL,
96
+ timeout=timeout,
97
+ max_retries=max_retries,
98
+ rate_limit_rpm=rate_limit_rpm,
99
+ )
100
+ self._client: httpx.AsyncClient | None = None
101
+
102
+ async def initialize(self) -> None:
103
+ """Initialize the HTTP client."""
104
+ self._client = httpx.AsyncClient(
105
+ base_url=self.base_url,
106
+ headers={"Content-Type": "application/json"},
107
+ timeout=self.timeout,
108
+ )
109
+
110
+ async def shutdown(self) -> None:
111
+ """Close the HTTP client."""
112
+ if self._client:
113
+ await self._client.aclose()
114
+ self._client = None
115
+
116
+ async def _ensure_client(self) -> httpx.AsyncClient:
117
+ """Ensure client is initialized."""
118
+ if not self._client:
119
+ await self.initialize()
120
+ return self._client # type: ignore
121
+
122
+ def _resolve_model(self, model: str) -> str:
123
+ """Resolve model alias to full model ID."""
124
+ return self.MODEL_ALIASES.get(model, model)
125
+
126
+ def get_models(self) -> list[ModelInfo]:
127
+ """Get available Google AI models."""
128
+ return list(self.MODELS.values())
129
+
130
+ def _convert_messages(
131
+ self, messages: list[dict[str, Any]]
132
+ ) -> tuple[str | None, list[dict[str, Any]]]:
133
+ """Convert OpenAI-style messages to Gemini format.
134
+
135
+ Returns:
136
+ Tuple of (system_instruction, contents)
137
+ """
138
+ system_instruction: str | None = None
139
+ contents: list[dict[str, Any]] = []
140
+
141
+ for msg in messages:
142
+ role = msg["role"]
143
+ content = msg["content"]
144
+
145
+ if role == "system":
146
+ system_instruction = content
147
+ elif role == "assistant":
148
+ contents.append({
149
+ "role": "model",
150
+ "parts": [{"text": content}] if isinstance(content, str) else content,
151
+ })
152
+ elif role == "user":
153
+ contents.append({
154
+ "role": "user",
155
+ "parts": [{"text": content}] if isinstance(content, str) else content,
156
+ })
157
+ elif role == "function":
158
+ # Function response
159
+ contents.append({
160
+ "role": "function",
161
+ "parts": [{
162
+ "functionResponse": {
163
+ "name": msg.get("name", "function"),
164
+ "response": {"result": content},
165
+ }
166
+ }],
167
+ })
168
+ elif role == "tool":
169
+ # Tool response
170
+ contents.append({
171
+ "role": "function",
172
+ "parts": [{
173
+ "functionResponse": {
174
+ "name": msg.get("tool_call_id", "tool"),
175
+ "response": {"result": content},
176
+ }
177
+ }],
178
+ })
179
+
180
+ return system_instruction, contents
181
+
182
+ def _convert_tools(
183
+ self, tools: list[dict[str, Any]] | None
184
+ ) -> list[dict[str, Any]] | None:
185
+ """Convert OpenAI-style tools to Gemini format."""
186
+ if not tools:
187
+ return None
188
+
189
+ function_declarations = []
190
+ for tool in tools:
191
+ if tool.get("type") == "function":
192
+ func = tool["function"]
193
+ function_declarations.append({
194
+ "name": func["name"],
195
+ "description": func.get("description", ""),
196
+ "parameters": func.get("parameters", {"type": "object", "properties": {}}),
197
+ })
198
+
199
+ return [{"functionDeclarations": function_declarations}] if function_declarations else None
200
+
201
+ async def complete(
202
+ self,
203
+ messages: list[dict[str, Any]],
204
+ model: str,
205
+ temperature: float = 0.7,
206
+ max_tokens: int | None = None,
207
+ functions: list[dict[str, Any]] | None = None,
208
+ function_call: str | dict[str, str] | None = None,
209
+ tools: list[dict[str, Any]] | None = None,
210
+ tool_choice: str | dict[str, Any] | None = None,
211
+ stop: list[str] | None = None,
212
+ **kwargs: Any,
213
+ ) -> CompletionResponse:
214
+ """Generate a completion using Google AI API."""
215
+ await self._acquire_rate_limit()
216
+
217
+ model = self._resolve_model(model)
218
+ model_info = self.get_model_info(model)
219
+ if not model_info:
220
+ raise ModelNotFoundError(self.PROVIDER_NAME, model)
221
+
222
+ client = await self._ensure_client()
223
+
224
+ # Convert messages
225
+ system_instruction, contents = self._convert_messages(messages)
226
+
227
+ # Build request payload
228
+ payload: dict[str, Any] = {
229
+ "contents": contents,
230
+ "generationConfig": {
231
+ "temperature": temperature,
232
+ },
233
+ }
234
+
235
+ if max_tokens:
236
+ payload["generationConfig"]["maxOutputTokens"] = max_tokens
237
+
238
+ if stop:
239
+ payload["generationConfig"]["stopSequences"] = stop
240
+
241
+ if system_instruction:
242
+ payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
243
+
244
+ # Convert tools
245
+ gemini_tools = self._convert_tools(tools)
246
+         if not gemini_tools and functions:
+             gemini_tools = [{
+                 "functionDeclarations": [
+                     {
+                         "name": f["name"],
+                         "description": f.get("description", ""),
+                         "parameters": f.get("parameters", {"type": "object", "properties": {}}),
+                     }
+                     for f in functions
+                 ]
+             }]
+
+         if gemini_tools:
+             payload["tools"] = gemini_tools
+
+         start_time = time.time()
+
+         url = f"/models/{model}:generateContent?key={self.api_key}"
+
+         try:
+             response = await self._retry_with_backoff(
+                 self._make_request, client, url, payload
+             )
+         except httpx.HTTPStatusError as e:
+             self._handle_http_error(e)
+
+         latency_ms = (time.time() - start_time) * 1000
+
+         # Parse response
+         candidates = response.get("candidates", [])
+         if not candidates:
+             raise ProviderError("No candidates in response", self.PROVIDER_NAME)
+
+         candidate = candidates[0]
+         content_parts = candidate.get("content", {}).get("parts", [])
+
+         # Extract text content and function calls
+         text_content = ""
+         tool_calls = []
+
+         for part in content_parts:
+             if "text" in part:
+                 text_content += part["text"]
+             elif "functionCall" in part:
+                 fc = part["functionCall"]
+                 tool_calls.append({
+                     "id": f"call_{fc['name']}",
+                     "type": "function",
+                     "function": {
+                         "name": fc["name"],
+                         "arguments": json.dumps(fc.get("args", {})),
+                     },
+                 })
+
+         # Parse usage
+         usage_data = response.get("usageMetadata", {})
+         usage = TokenUsage(
+             prompt_tokens=usage_data.get("promptTokenCount", 0),
+             completion_tokens=usage_data.get("candidatesTokenCount", 0),
+             total_tokens=usage_data.get("totalTokenCount", 0),
+         )
+
+         cost = self.calculate_cost(model, usage)
+         self._track_usage(usage, cost)
+
+         # Map finish reason
+         finish_reason_map = {
+             "STOP": "stop",
+             "MAX_TOKENS": "length",
+             "SAFETY": "content_filter",
+             "RECITATION": "content_filter",
+         }
+         finish_reason = finish_reason_map.get(
+             candidate.get("finishReason", ""), candidate.get("finishReason")
+         )
+
+         return CompletionResponse(
+             content=text_content,
+             model=model,
+             provider=self.PROVIDER_NAME,
+             usage=usage,
+             finish_reason=finish_reason,
+             function_call=None,
+             tool_calls=tool_calls if tool_calls else None,
+             raw_response=response,
+             latency_ms=latency_ms,
+             cost=cost,
+         )
+
+     async def _make_request(
+         self, client: httpx.AsyncClient, url: str, payload: dict[str, Any]
+     ) -> dict[str, Any]:
+         """Make the API request."""
+         response = await client.post(url, json=payload)
+         response.raise_for_status()
+         return response.json()
+
+     def _handle_http_error(self, error: httpx.HTTPStatusError) -> None:
+         """Handle HTTP errors from Google AI."""
+         status = error.response.status_code
+         try:
+             body = error.response.json()
+             message = body.get("error", {}).get("message", str(error))
+         except Exception:
+             message = str(error)
+
+         if status in (401, 403):
+             raise AuthenticationError(self.PROVIDER_NAME, message)
+         elif status == 429:
+             retry_after = error.response.headers.get("retry-after")
+             raise RateLimitError(
+                 self.PROVIDER_NAME,
+                 retry_after=float(retry_after) if retry_after else None,
+                 message=message,
+             )
+         elif status == 404:
+             raise ModelNotFoundError(self.PROVIDER_NAME, "unknown")
+         else:
+             raise ProviderError(message, self.PROVIDER_NAME, status)
+
+     async def stream(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[str]:
+         """Stream a completion from Google AI."""
+         await self._acquire_rate_limit()
+
+         model = self._resolve_model(model)
+         model_info = self.get_model_info(model)
+         if not model_info:
+             raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+         client = await self._ensure_client()
+
+         system_instruction, contents = self._convert_messages(messages)
+
+         payload: dict[str, Any] = {
+             "contents": contents,
+             "generationConfig": {
+                 "temperature": temperature,
+             },
+         }
+
+         if max_tokens:
+             payload["generationConfig"]["maxOutputTokens"] = max_tokens
+
+         if system_instruction:
+             payload["systemInstruction"] = {"parts": [{"text": system_instruction}]}
+
+         url = f"/models/{model}:streamGenerateContent?key={self.api_key}&alt=sse"
+
+         try:
+             async with client.stream("POST", url, json=payload) as response:
+                 response.raise_for_status()
+
+                 async for line in response.aiter_lines():
+                     if line.startswith("data: "):
+                         data = line[6:]
+
+                         try:
+                             chunk = json.loads(data)
+                             candidates = chunk.get("candidates", [])
+                             if candidates:
+                                 parts = candidates[0].get("content", {}).get("parts", [])
+                                 for part in parts:
+                                     if "text" in part:
+                                         yield part["text"]
+                         except json.JSONDecodeError:
+                             continue
+
+         except httpx.HTTPStatusError as e:
+             self._handle_http_error(e)
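Usage note: a minimal sketch of driving the streaming path above (placeholder API key; "gemini-1.5-flash" matches the model IDs used elsewhere in this commit):

import asyncio

from app.models.providers import GoogleProvider


async def demo_google_stream() -> None:
    provider = GoogleProvider(api_key="YOUR_GOOGLE_API_KEY")  # placeholder key
    await provider.initialize()
    try:
        # stream() yields text chunks parsed from the SSE "data:" lines
        async for chunk in provider.stream(
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            model="gemini-1.5-flash",
        ):
            print(chunk, end="", flush=True)
    finally:
        await provider.shutdown()


asyncio.run(demo_google_stream())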
backend/app/models/providers/groq.py ADDED
@@ -0,0 +1,361 @@
+ """Groq provider implementation (fast inference)."""
+
+ import json
+ import time
+ from typing import Any, AsyncIterator
+
+ import httpx
+
+ from app.models.providers.base import (
+     AuthenticationError,
+     BaseProvider,
+     CompletionResponse,
+     ModelInfo,
+     ModelNotFoundError,
+     ProviderError,
+     RateLimitError,
+     TokenUsage,
+ )
+
+
+ class GroqProvider(BaseProvider):
+     """Groq API provider for fast LLM inference."""
+
+     PROVIDER_NAME = "groq"
+     DEFAULT_BASE_URL = "https://api.groq.com/openai/v1"
+
+     # Model definitions with pricing (per 1K tokens)
+     MODELS = {
+         "llama-3.3-70b-versatile": ModelInfo(
+             id="llama-3.3-70b-versatile",
+             name="Llama 3.3 70B Versatile",
+             provider="groq",
+             context_window=128000,
+             max_output_tokens=32768,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00059,
+             cost_per_1k_output=0.00079,
+         ),
+         "llama-3.1-70b-versatile": ModelInfo(
+             id="llama-3.1-70b-versatile",
+             name="Llama 3.1 70B Versatile",
+             provider="groq",
+             context_window=128000,
+             max_output_tokens=32768,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00059,
+             cost_per_1k_output=0.00079,
+         ),
+         "llama-3.1-8b-instant": ModelInfo(
+             id="llama-3.1-8b-instant",
+             name="Llama 3.1 8B Instant",
+             provider="groq",
+             context_window=128000,
+             max_output_tokens=8000,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00005,
+             cost_per_1k_output=0.00008,
+         ),
+         "llama3-70b-8192": ModelInfo(
+             id="llama3-70b-8192",
+             name="Llama 3 70B",
+             provider="groq",
+             context_window=8192,
+             max_output_tokens=8192,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00059,
+             cost_per_1k_output=0.00079,
+         ),
+         "llama3-8b-8192": ModelInfo(
+             id="llama3-8b-8192",
+             name="Llama 3 8B",
+             provider="groq",
+             context_window=8192,
+             max_output_tokens=8192,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00005,
+             cost_per_1k_output=0.00008,
+         ),
+         "mixtral-8x7b-32768": ModelInfo(
+             id="mixtral-8x7b-32768",
+             name="Mixtral 8x7B",
+             provider="groq",
+             context_window=32768,
+             max_output_tokens=32768,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00024,
+             cost_per_1k_output=0.00024,
+         ),
+         "gemma2-9b-it": ModelInfo(
+             id="gemma2-9b-it",
+             name="Gemma 2 9B IT",
+             provider="groq",
+             context_window=8192,
+             max_output_tokens=8192,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.00020,
+             cost_per_1k_output=0.00020,
+         ),
+     }
+
+     # Aliases for convenience
+     MODEL_ALIASES = {
+         "llama3": "llama3-70b-8192",
+         "llama3-70b": "llama3-70b-8192",
+         "llama3-8b": "llama3-8b-8192",
+         "llama-3.1": "llama-3.1-70b-versatile",
+         "llama-3.3": "llama-3.3-70b-versatile",
+         "mixtral": "mixtral-8x7b-32768",
+         "gemma2": "gemma2-9b-it",
+     }
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str | None = None,
+         timeout: float = 60.0,
+         max_retries: int = 3,
+         rate_limit_rpm: int = 30,  # Groq has stricter limits
+     ):
+         super().__init__(
+             api_key=api_key,
+             base_url=base_url or self.DEFAULT_BASE_URL,
+             timeout=timeout,
+             max_retries=max_retries,
+             rate_limit_rpm=rate_limit_rpm,
+         )
+         self._client: httpx.AsyncClient | None = None
+
+     async def initialize(self) -> None:
+         """Initialize the HTTP client."""
+         self._client = httpx.AsyncClient(
+             base_url=self.base_url,
+             headers={
+                 "Authorization": f"Bearer {self.api_key}",
+                 "Content-Type": "application/json",
+             },
+             timeout=self.timeout,
+         )
+
+     async def shutdown(self) -> None:
+         """Close the HTTP client."""
+         if self._client:
+             await self._client.aclose()
+             self._client = None
+
+     async def _ensure_client(self) -> httpx.AsyncClient:
+         """Ensure client is initialized."""
+         if not self._client:
+             await self.initialize()
+         return self._client  # type: ignore
+
+     def _resolve_model(self, model: str) -> str:
+         """Resolve model alias to full model ID."""
+         return self.MODEL_ALIASES.get(model, model)
+
+     def get_models(self) -> list[ModelInfo]:
+         """Get available Groq models."""
+         return list(self.MODELS.values())
+
+     async def complete(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         functions: list[dict[str, Any]] | None = None,
+         function_call: str | dict[str, str] | None = None,
+         tools: list[dict[str, Any]] | None = None,
+         tool_choice: str | dict[str, Any] | None = None,
+         stop: list[str] | None = None,
+         **kwargs: Any,
+     ) -> CompletionResponse:
+         """Generate a completion using Groq API (OpenAI-compatible)."""
+         await self._acquire_rate_limit()
+
+         model = self._resolve_model(model)
+         model_info = self.get_model_info(model)
+         if not model_info:
+             raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+         client = await self._ensure_client()
+
+         # Build request payload (OpenAI-compatible format)
+         payload: dict[str, Any] = {
+             "model": model,
+             "messages": messages,
+             "temperature": temperature,
+         }
+
+         if max_tokens:
+             payload["max_tokens"] = max_tokens
+         if stop:
+             payload["stop"] = stop
+
+         # Function calling
+         if functions and model_info.supports_functions:
+             payload["functions"] = functions
+             if function_call:
+                 payload["function_call"] = function_call
+
+         # Tools (newer format)
+         if tools and model_info.supports_functions:
+             payload["tools"] = tools
+             if tool_choice:
+                 payload["tool_choice"] = tool_choice
+
+         # Additional params
+         for key in ["top_p", "presence_penalty", "frequency_penalty"]:
+             if key in kwargs:
+                 payload[key] = kwargs[key]
+
+         start_time = time.time()
+
+         try:
+             response = await self._retry_with_backoff(
+                 self._make_request, client, payload
+             )
+         except httpx.HTTPStatusError as e:
+             self._handle_http_error(e)
+
+         latency_ms = (time.time() - start_time) * 1000
+
+         # Parse response (OpenAI-compatible)
+         choice = response["choices"][0]
+         message = choice["message"]
+         usage_data = response.get("usage", {})
+
+         usage = TokenUsage(
+             prompt_tokens=usage_data.get("prompt_tokens", 0),
+             completion_tokens=usage_data.get("completion_tokens", 0),
+             total_tokens=usage_data.get("total_tokens", 0),
+         )
+
+         cost = self.calculate_cost(model, usage)
+         self._track_usage(usage, cost)
+
+         # Extract function/tool calls
+         func_call = message.get("function_call")
+         tool_calls_raw = message.get("tool_calls")
+
+         tool_calls = None
+         if tool_calls_raw:
+             tool_calls = [
+                 {
+                     "id": tc["id"],
+                     "type": tc["type"],
+                     "function": {
+                         "name": tc["function"]["name"],
+                         "arguments": tc["function"]["arguments"],
+                     },
+                 }
+                 for tc in tool_calls_raw
+             ]
+
+         return CompletionResponse(
+             content=message.get("content") or "",
+             model=response.get("model", model),
+             provider=self.PROVIDER_NAME,
+             usage=usage,
+             finish_reason=choice.get("finish_reason"),
+             function_call=func_call,
+             tool_calls=tool_calls,
+             raw_response=response,
+             latency_ms=latency_ms,
+             cost=cost,
+         )
+
+     async def _make_request(
+         self, client: httpx.AsyncClient, payload: dict[str, Any]
+     ) -> dict[str, Any]:
+         """Make the API request."""
+         response = await client.post("/chat/completions", json=payload)
+         response.raise_for_status()
+         return response.json()
+
+     def _handle_http_error(self, error: httpx.HTTPStatusError) -> None:
+         """Handle HTTP errors from Groq."""
+         status = error.response.status_code
+         try:
+             body = error.response.json()
+             message = body.get("error", {}).get("message", str(error))
+         except Exception:
+             message = str(error)
+
+         if status == 401:
+             raise AuthenticationError(self.PROVIDER_NAME, message)
+         elif status == 429:
+             retry_after = error.response.headers.get("retry-after")
+             raise RateLimitError(
+                 self.PROVIDER_NAME,
+                 retry_after=float(retry_after) if retry_after else None,
+                 message=message,
+             )
+         elif status == 404:
+             raise ModelNotFoundError(self.PROVIDER_NAME, "unknown")
+         else:
+             raise ProviderError(message, self.PROVIDER_NAME, status)
+
+     async def stream(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[str]:
+         """Stream a completion from Groq."""
+         await self._acquire_rate_limit()
+
+         model = self._resolve_model(model)
+         model_info = self.get_model_info(model)
+         if not model_info:
+             raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+         client = await self._ensure_client()
+
+         payload: dict[str, Any] = {
+             "model": model,
+             "messages": messages,
+             "temperature": temperature,
+             "stream": True,
+         }
+
+         if max_tokens:
+             payload["max_tokens"] = max_tokens
+
+         try:
+             async with client.stream("POST", "/chat/completions", json=payload) as response:
+                 response.raise_for_status()
+
+                 async for line in response.aiter_lines():
+                     if line.startswith("data: "):
+                         data = line[6:]
+                         if data == "[DONE]":
+                             break
+
+                         try:
+                             chunk = json.loads(data)
+                             delta = chunk["choices"][0].get("delta", {})
+                             content = delta.get("content")
+                             if content:
+                                 yield content
+                         except json.JSONDecodeError:
+                             continue
+
+         except httpx.HTTPStatusError as e:
+             self._handle_http_error(e)
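Usage note: a minimal sketch of a non-streaming Groq call. The key is a placeholder; "llama3-8b" resolves through MODEL_ALIASES above, and complete() lazily creates the HTTP client via _ensure_client():

import asyncio

from app.models.providers import GroqProvider


async def demo_groq() -> None:
    provider = GroqProvider(api_key="YOUR_GROQ_API_KEY")  # placeholder key
    try:
        resp = await provider.complete(
            messages=[{"role": "user", "content": "Summarize Groq in one line."}],
            model="llama3-8b",  # alias for "llama3-8b-8192"
            max_tokens=64,
        )
        print(resp.content)
        print(f"{resp.usage.total_tokens} tokens, ${resp.cost:.6f}, {resp.latency_ms:.0f} ms")
    finally:
        await provider.shutdown()


asyncio.run(demo_groq())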
backend/app/models/providers/openai.py ADDED
@@ -0,0 +1,353 @@
+ """OpenAI provider implementation."""
+
+ import json
+ import time
+ from typing import Any, AsyncIterator
+
+ import httpx
+
+ from app.models.providers.base import (
+     AuthenticationError,
+     BaseProvider,
+     CompletionResponse,
+     ModelInfo,
+     ModelNotFoundError,
+     ProviderError,
+     RateLimitError,
+     TokenUsage,
+ )
+
+
+ class OpenAIProvider(BaseProvider):
+     """OpenAI API provider supporting GPT models."""
+
+     PROVIDER_NAME = "openai"
+     DEFAULT_BASE_URL = "https://api.openai.com/v1"
+
+     # Model definitions with pricing (per 1K tokens)
+     MODELS = {
+         "gpt-4o": ModelInfo(
+             id="gpt-4o",
+             name="GPT-4o",
+             provider="openai",
+             context_window=128000,
+             max_output_tokens=16384,
+             supports_functions=True,
+             supports_vision=True,
+             supports_streaming=True,
+             cost_per_1k_input=0.005,
+             cost_per_1k_output=0.015,
+         ),
+         "gpt-4o-mini": ModelInfo(
+             id="gpt-4o-mini",
+             name="GPT-4o Mini",
+             provider="openai",
+             context_window=128000,
+             max_output_tokens=16384,
+             supports_functions=True,
+             supports_vision=True,
+             supports_streaming=True,
+             cost_per_1k_input=0.00015,
+             cost_per_1k_output=0.0006,
+         ),
+         "gpt-4-turbo": ModelInfo(
+             id="gpt-4-turbo",
+             name="GPT-4 Turbo",
+             provider="openai",
+             context_window=128000,
+             max_output_tokens=4096,
+             supports_functions=True,
+             supports_vision=True,
+             supports_streaming=True,
+             cost_per_1k_input=0.01,
+             cost_per_1k_output=0.03,
+         ),
+         "gpt-4": ModelInfo(
+             id="gpt-4",
+             name="GPT-4",
+             provider="openai",
+             context_window=8192,
+             max_output_tokens=4096,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.03,
+             cost_per_1k_output=0.06,
+         ),
+         "gpt-3.5-turbo": ModelInfo(
+             id="gpt-3.5-turbo",
+             name="GPT-3.5 Turbo",
+             provider="openai",
+             context_window=16385,
+             max_output_tokens=4096,
+             supports_functions=True,
+             supports_vision=False,
+             supports_streaming=True,
+             cost_per_1k_input=0.0005,
+             cost_per_1k_output=0.0015,
+         ),
+     }
+
+     def __init__(
+         self,
+         api_key: str,
+         base_url: str | None = None,
+         organization: str | None = None,
+         timeout: float = 60.0,
+         max_retries: int = 3,
+         rate_limit_rpm: int = 60,
+     ):
+         super().__init__(
+             api_key=api_key,
+             base_url=base_url or self.DEFAULT_BASE_URL,
+             timeout=timeout,
+             max_retries=max_retries,
+             rate_limit_rpm=rate_limit_rpm,
+         )
+         self.organization = organization
+         self._client: httpx.AsyncClient | None = None
+
+     async def initialize(self) -> None:
+         """Initialize the HTTP client."""
+         headers = {
+             "Authorization": f"Bearer {self.api_key}",
+             "Content-Type": "application/json",
+         }
+         if self.organization:
+             headers["OpenAI-Organization"] = self.organization
+
+         self._client = httpx.AsyncClient(
+             base_url=self.base_url,
+             headers=headers,
+             timeout=self.timeout,
+         )
+
+     async def shutdown(self) -> None:
+         """Close the HTTP client."""
+         if self._client:
+             await self._client.aclose()
+             self._client = None
+
+     async def _ensure_client(self) -> httpx.AsyncClient:
+         """Ensure client is initialized."""
+         if not self._client:
+             await self.initialize()
+         return self._client  # type: ignore
+
+     def get_models(self) -> list[ModelInfo]:
+         """Get available OpenAI models."""
+         return list(self.MODELS.values())
+
+     async def complete(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         functions: list[dict[str, Any]] | None = None,
+         function_call: str | dict[str, str] | None = None,
+         tools: list[dict[str, Any]] | None = None,
+         tool_choice: str | dict[str, Any] | None = None,
+         stop: list[str] | None = None,
+         **kwargs: Any,
+     ) -> CompletionResponse:
+         """Generate a completion using OpenAI API."""
+         await self._acquire_rate_limit()
+
+         model_info = self.get_model_info(model)
+         if not model_info:
+             raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+         client = await self._ensure_client()
+
+         # Build request payload
+         payload: dict[str, Any] = {
+             "model": model,
+             "messages": messages,
+             "temperature": temperature,
+         }
+
+         if max_tokens:
+             payload["max_tokens"] = max_tokens
+         if stop:
+             payload["stop"] = stop
+
+         # Function calling (legacy format)
+         if functions and model_info.supports_functions:
+             payload["functions"] = functions
+             if function_call:
+                 payload["function_call"] = function_call
+
+         # Tools (newer format)
+         if tools and model_info.supports_functions:
+             payload["tools"] = tools
+             if tool_choice:
+                 payload["tool_choice"] = tool_choice
+
+         # Additional kwargs
+         for key in ["top_p", "presence_penalty", "frequency_penalty", "logit_bias", "user"]:
+             if key in kwargs:
+                 payload[key] = kwargs[key]
+
+         start_time = time.time()
+
+         try:
+             response = await self._retry_with_backoff(
+                 self._make_request, client, payload
+             )
+         except httpx.HTTPStatusError as e:
+             self._handle_http_error(e)
+
+         latency_ms = (time.time() - start_time) * 1000
+
+         # Parse response
+         choice = response["choices"][0]
+         message = choice["message"]
+         usage_data = response.get("usage", {})
+
+         usage = TokenUsage(
+             prompt_tokens=usage_data.get("prompt_tokens", 0),
+             completion_tokens=usage_data.get("completion_tokens", 0),
+             total_tokens=usage_data.get("total_tokens", 0),
+         )
+
+         cost = self.calculate_cost(model, usage)
+         self._track_usage(usage, cost)
+
+         # Extract function call / tool calls
+         func_call = message.get("function_call")
+         tool_calls_raw = message.get("tool_calls")
+
+         tool_calls = None
+         if tool_calls_raw:
+             tool_calls = [
+                 {
+                     "id": tc["id"],
+                     "type": tc["type"],
+                     "function": {
+                         "name": tc["function"]["name"],
+                         "arguments": tc["function"]["arguments"],
+                     },
+                 }
+                 for tc in tool_calls_raw
+             ]
+
+         return CompletionResponse(
+             content=message.get("content") or "",
+             model=response.get("model", model),
+             provider=self.PROVIDER_NAME,
+             usage=usage,
+             finish_reason=choice.get("finish_reason"),
+             function_call=func_call,
+             tool_calls=tool_calls,
+             raw_response=response,
+             latency_ms=latency_ms,
+             cost=cost,
+         )
+
+     async def _make_request(
+         self, client: httpx.AsyncClient, payload: dict[str, Any]
+     ) -> dict[str, Any]:
+         """Make the API request."""
+         response = await client.post("/chat/completions", json=payload)
+         response.raise_for_status()
+         return response.json()
+
+     def _handle_http_error(self, error: httpx.HTTPStatusError) -> None:
+         """Handle HTTP errors from OpenAI."""
+         status = error.response.status_code
+         try:
+             body = error.response.json()
+             message = body.get("error", {}).get("message", str(error))
+         except Exception:
+             message = str(error)
+
+         if status == 401:
+             raise AuthenticationError(self.PROVIDER_NAME, message)
+         elif status == 429:
+             retry_after = error.response.headers.get("retry-after")
+             raise RateLimitError(
+                 self.PROVIDER_NAME,
+                 retry_after=float(retry_after) if retry_after else None,
+                 message=message,
+             )
+         elif status == 404:
+             raise ModelNotFoundError(self.PROVIDER_NAME, "unknown")
+         else:
+             raise ProviderError(message, self.PROVIDER_NAME, status)
+
+     async def stream(
+         self,
+         messages: list[dict[str, Any]],
+         model: str,
+         temperature: float = 0.7,
+         max_tokens: int | None = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[str]:
+         """Stream a completion from OpenAI."""
+         await self._acquire_rate_limit()
+
+         model_info = self.get_model_info(model)
+         if not model_info:
+             raise ModelNotFoundError(self.PROVIDER_NAME, model)
+
+         client = await self._ensure_client()
+
+         payload: dict[str, Any] = {
+             "model": model,
+             "messages": messages,
+             "temperature": temperature,
+             "stream": True,
+         }
+
+         if max_tokens:
+             payload["max_tokens"] = max_tokens
+
+         try:
+             async with client.stream("POST", "/chat/completions", json=payload) as response:
+                 response.raise_for_status()
+
+                 async for line in response.aiter_lines():
+                     if line.startswith("data: "):
+                         data = line[6:]
+                         if data == "[DONE]":
+                             break
+
+                         try:
+                             chunk = json.loads(data)
+                             delta = chunk["choices"][0].get("delta", {})
+                             content = delta.get("content")
+                             if content:
+                                 yield content
+                         except json.JSONDecodeError:
+                             continue
+
+         except httpx.HTTPStatusError as e:
+             self._handle_http_error(e)
+
+     async def create_embedding(
+         self,
+         text: str | list[str],
+         model: str = "text-embedding-3-small",
+     ) -> list[list[float]]:
+         """Create embeddings for text.
+
+         Args:
+             text: Text or list of texts to embed
+             model: Embedding model to use
+
+         Returns:
+             List of embedding vectors
+         """
+         client = await self._ensure_client()
+
+         payload = {
+             "model": model,
+             "input": text if isinstance(text, list) else [text],
+         }
+
+         response = await client.post("/embeddings", json=payload)
+         response.raise_for_status()
+
+         data = response.json()
+         return [item["embedding"] for item in data["data"]]
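Usage note: a minimal sketch of create_embedding above (placeholder key; text-embedding-3-small is the method's default and returns 1536-dimensional vectors):

import asyncio

from app.models.providers import OpenAIProvider


async def demo_embeddings() -> None:
    provider = OpenAIProvider(api_key="YOUR_OPENAI_API_KEY")  # placeholder key
    try:
        vectors = await provider.create_embedding(["first document", "second document"])
        print(len(vectors), "vectors of dimension", len(vectors[0]))
    finally:
        await provider.shutdown()


asyncio.run(demo_embeddings())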
backend/app/models/router.py ADDED
@@ -0,0 +1,526 @@
+ """Smart model router for intelligent model selection and fallback."""
+
+ import asyncio
+ import logging
+ from dataclasses import dataclass, field
+ from datetime import datetime
+ from enum import Enum
+ from typing import Any
+
+ from pydantic import SecretStr
+
+ from app.models.providers.base import (
+     BaseProvider,
+     CompletionResponse,
+     ModelInfo,
+     ProviderError,
+     RateLimitError,
+     TaskType,
+     TokenUsage,
+ )
+ from app.models.providers.openai import OpenAIProvider
+ from app.models.providers.anthropic import AnthropicProvider
+ from app.models.providers.google import GoogleProvider
+ from app.models.providers.groq import GroqProvider
+
+ logger = logging.getLogger(__name__)
+
+
+ class RoutingStrategy(str, Enum):
+     """Model routing strategies."""
+
+     BEST_QUALITY = "best_quality"  # Use highest quality model
+     BEST_SPEED = "best_speed"  # Use fastest model
+     BEST_VALUE = "best_value"  # Balance quality/cost
+     LOWEST_COST = "lowest_cost"  # Use cheapest model
+     ROUND_ROBIN = "round_robin"  # Rotate between models
+
+
+ @dataclass
+ class ModelScore:
+     """Scoring for model routing decisions."""
+
+     model_id: str
+     provider: str
+     quality_score: float = 0.0  # 0-1, higher is better
+     speed_score: float = 0.0  # 0-1, higher is faster
+     cost_score: float = 0.0  # 0-1, higher is cheaper
+     overall_score: float = 0.0
+
+
+ @dataclass
+ class RoutingConfig:
+     """Configuration for model routing."""
+
+     default_strategy: RoutingStrategy = RoutingStrategy.BEST_VALUE
+     max_fallback_attempts: int = 3
+     fallback_delay_seconds: float = 1.0
+     enable_caching: bool = True
+     cache_ttl_seconds: int = 300
+
+     # Task-specific model preferences
+     task_preferences: dict[TaskType, list[str]] = field(default_factory=lambda: {
+         TaskType.GENERAL: ["gpt-4o", "claude-3-5-sonnet-20241022", "gemini-1.5-pro"],
+         TaskType.CODE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-1.5-pro"],
+         TaskType.REASONING: ["claude-3-opus-20240229", "gpt-4o", "gemini-1.5-pro"],
+         TaskType.EXTRACTION: ["gpt-4o-mini", "claude-3-haiku-20240307", "gemini-1.5-flash"],
+         TaskType.SUMMARIZATION: ["gpt-4o-mini", "claude-3-5-haiku-20241022", "gemini-1.5-flash"],
+         TaskType.CLASSIFICATION: ["gpt-4o-mini", "claude-3-haiku-20240307", "llama-3.1-8b-instant"],
+         TaskType.CREATIVE: ["claude-3-5-sonnet-20241022", "gpt-4o", "gemini-1.5-pro"],
+         TaskType.FAST: ["llama-3.1-8b-instant", "gemini-1.5-flash", "gpt-4o-mini"],
+     })
+
+
+ @dataclass
+ class CostTracker:
+     """Track costs across providers and models."""
+
+     total_cost: float = 0.0
+     cost_by_provider: dict[str, float] = field(default_factory=dict)
+     cost_by_model: dict[str, float] = field(default_factory=dict)
+     request_count: int = 0
+     total_tokens: TokenUsage = field(default_factory=TokenUsage)
+     start_time: datetime = field(default_factory=datetime.utcnow)
+
+     def track(self, response: CompletionResponse) -> None:
+         """Track a completion response."""
+         self.total_cost += response.cost
+         self.request_count += 1
+         self.total_tokens = self.total_tokens + response.usage
+
+         # By provider
+         self.cost_by_provider[response.provider] = (
+             self.cost_by_provider.get(response.provider, 0.0) + response.cost
+         )
+
+         # By model
+         self.cost_by_model[response.model] = (
+             self.cost_by_model.get(response.model, 0.0) + response.cost
+         )
+
+     def get_summary(self) -> dict[str, Any]:
+         """Get cost summary."""
+         return {
+             "total_cost_usd": self.total_cost,
+             "request_count": self.request_count,
+             "total_tokens": {
+                 "prompt": self.total_tokens.prompt_tokens,
+                 "completion": self.total_tokens.completion_tokens,
+                 "total": self.total_tokens.total_tokens,
+             },
+             "cost_by_provider": self.cost_by_provider,
+             "cost_by_model": self.cost_by_model,
+             "avg_cost_per_request": (
+                 self.total_cost / self.request_count if self.request_count > 0 else 0
+             ),
+             "tracking_since": self.start_time.isoformat(),
+         }
+
+     def reset(self) -> None:
+         """Reset cost tracking."""
+         self.total_cost = 0.0
+         self.cost_by_provider = {}
+         self.cost_by_model = {}
+         self.request_count = 0
+         self.total_tokens = TokenUsage()
+         self.start_time = datetime.utcnow()
+
+
+ class SmartModelRouter:
+     """Intelligent model router with fallback and cost tracking."""
+
+     # Model quality rankings (subjective, based on benchmarks)
+     MODEL_QUALITY_SCORES: dict[str, float] = {
+         # OpenAI
+         "gpt-4o": 0.95,
+         "gpt-4-turbo": 0.92,
+         "gpt-4": 0.90,
+         "gpt-4o-mini": 0.80,
+         "gpt-3.5-turbo": 0.70,
+         # Anthropic
+         "claude-3-opus-20240229": 0.97,
+         "claude-3-5-sonnet-20241022": 0.94,
+         "claude-3-sonnet-20240229": 0.88,
+         "claude-3-5-haiku-20241022": 0.82,
+         "claude-3-haiku-20240307": 0.75,
+         # Google
+         "gemini-1.5-pro": 0.91,
+         "gemini-2.0-flash-exp": 0.88,
+         "gemini-1.5-flash": 0.78,
+         "gemini-pro": 0.75,
+         # Groq
+         "llama-3.3-70b-versatile": 0.85,
+         "llama-3.1-70b-versatile": 0.84,
+         "llama3-70b-8192": 0.82,
+         "mixtral-8x7b-32768": 0.78,
+         "llama-3.1-8b-instant": 0.65,
+         "llama3-8b-8192": 0.60,
+         "gemma2-9b-it": 0.62,
+     }
+
+     # Model speed rankings (relative, based on typical latency)
+     MODEL_SPEED_SCORES: dict[str, float] = {
+         # Groq is fastest
+         "llama-3.1-8b-instant": 0.98,
+         "llama3-8b-8192": 0.97,
+         "gemma2-9b-it": 0.96,
+         "mixtral-8x7b-32768": 0.94,
+         "llama3-70b-8192": 0.92,
+         "llama-3.1-70b-versatile": 0.91,
+         "llama-3.3-70b-versatile": 0.90,
+         # Google Flash is fast
+         "gemini-1.5-flash": 0.88,
+         "gemini-2.0-flash-exp": 0.87,
+         # Mini models
+         "gpt-4o-mini": 0.85,
+         "claude-3-haiku-20240307": 0.84,
+         "claude-3-5-haiku-20241022": 0.83,
+         "gpt-3.5-turbo": 0.82,
+         # Pro models
+         "gemini-pro": 0.75,
+         "gemini-1.5-pro": 0.70,
+         "gpt-4o": 0.68,
+         "claude-3-5-sonnet-20241022": 0.65,
+         "claude-3-sonnet-20240229": 0.62,
+         "gpt-4-turbo": 0.55,
+         "gpt-4": 0.50,
+         "claude-3-opus-20240229": 0.40,
+     }
+
+     def __init__(
+         self,
+         openai_api_key: str | SecretStr | None = None,
+         anthropic_api_key: str | SecretStr | None = None,
+         google_api_key: str | SecretStr | None = None,
+         groq_api_key: str | SecretStr | None = None,
+         config: RoutingConfig | None = None,
+     ):
+         self.config = config or RoutingConfig()
+         self.providers: dict[str, BaseProvider] = {}
+         self.cost_tracker = CostTracker()
+         self._initialized = False
+         self._round_robin_index = 0
+
+         # Store API keys (handle SecretStr)
+         self._api_keys = {
+             "openai": self._get_key_value(openai_api_key),
+             "anthropic": self._get_key_value(anthropic_api_key),
+             "google": self._get_key_value(google_api_key),
+             "groq": self._get_key_value(groq_api_key),
+         }
+
+     @staticmethod
+     def _get_key_value(key: str | SecretStr | None) -> str | None:
+         """Extract string value from SecretStr if needed."""
+         if key is None:
+             return None
+         if isinstance(key, SecretStr):
+             return key.get_secret_value()
+         return key
+
+     async def initialize(self) -> None:
+         """Initialize all configured providers."""
+         if self._initialized:
+             return
+
+         # Initialize providers based on available API keys
+         if self._api_keys["openai"]:
+             provider = OpenAIProvider(api_key=self._api_keys["openai"])
+             await provider.initialize()
+             self.providers["openai"] = provider
+             logger.info("Initialized OpenAI provider")
+
+         if self._api_keys["anthropic"]:
+             provider = AnthropicProvider(api_key=self._api_keys["anthropic"])
+             await provider.initialize()
+             self.providers["anthropic"] = provider
+             logger.info("Initialized Anthropic provider")
+
+         if self._api_keys["google"]:
+             provider = GoogleProvider(api_key=self._api_keys["google"])
+             await provider.initialize()
+             self.providers["google"] = provider
+             logger.info("Initialized Google provider")
+
+         if self._api_keys["groq"]:
+             provider = GroqProvider(api_key=self._api_keys["groq"])
+             await provider.initialize()
+             self.providers["groq"] = provider
+             logger.info("Initialized Groq provider")
+
+         if not self.providers:
+             logger.warning("No LLM providers configured")
+
+         self._initialized = True
+
+     async def shutdown(self) -> None:
+         """Shutdown all providers."""
+         for provider in self.providers.values():
+             await provider.shutdown()
+         self.providers.clear()
+         self._initialized = False
+
+     def get_available_models(self) -> list[ModelInfo]:
+         """Get all available models across providers."""
+         models = []
+         for provider in self.providers.values():
+             models.extend(provider.get_models())
+         return models
+
+     def get_provider_for_model(self, model: str) -> BaseProvider | None:
+         """Get the provider for a specific model."""
+         for provider in self.providers.values():
+             if provider.get_model_info(model):
+                 return provider
+
+             # Check aliases (Anthropic, Google, and Groq define MODEL_ALIASES)
+             if hasattr(provider, "MODEL_ALIASES"):
+                 if model in provider.MODEL_ALIASES:  # type: ignore
+                     return provider
+
+         return None
+
+     def _score_model(
+         self,
+         model_info: ModelInfo,
+         strategy: RoutingStrategy,
+     ) -> ModelScore:
+         """Score a model based on routing strategy."""
+         model_id = model_info.id
+
+         quality = self.MODEL_QUALITY_SCORES.get(model_id, 0.5)
+         speed = self.MODEL_SPEED_SCORES.get(model_id, 0.5)
+
+         # Calculate cost score (inverse of cost, normalized)
+         max_cost = 0.1  # $0.10 per 1K tokens as reference
+         avg_cost = (model_info.cost_per_1k_input + model_info.cost_per_1k_output) / 2
+         cost_score = 1.0 - min(avg_cost / max_cost, 1.0)
+
+         # Calculate overall score based on strategy
+         if strategy == RoutingStrategy.BEST_QUALITY:
+             overall = quality * 0.8 + speed * 0.1 + cost_score * 0.1
+         elif strategy == RoutingStrategy.BEST_SPEED:
+             overall = quality * 0.1 + speed * 0.8 + cost_score * 0.1
+         elif strategy == RoutingStrategy.LOWEST_COST:
+             overall = quality * 0.1 + speed * 0.1 + cost_score * 0.8
+         else:  # BEST_VALUE
+             overall = quality * 0.4 + speed * 0.3 + cost_score * 0.3
+
+         return ModelScore(
+             model_id=model_id,
+             provider=model_info.provider,
+             quality_score=quality,
+             speed_score=speed,
+             cost_score=cost_score,
+             overall_score=overall,
+         )
+
+     def route(
+         self,
+         task_type: TaskType = TaskType.GENERAL,
+         strategy: RoutingStrategy | None = None,
+         required_features: list[str] | None = None,
+     ) -> tuple[str, BaseProvider] | None:
+         """Route to the best model for the task.
+
+         Args:
+             task_type: Type of task to perform
+             strategy: Routing strategy (uses default if not specified)
+             required_features: Required model features (e.g., 'functions', 'vision')
+
+         Returns:
+             Tuple of (model_id, provider) or None if no suitable model found
+         """
+         if not self.providers:
+             return None
+
+         strategy = strategy or self.config.default_strategy
+
+         # Handle round robin specially
+         if strategy == RoutingStrategy.ROUND_ROBIN:
+             models = self.get_available_models()
+             if not models:
+                 return None
+
+             # Filter by features if needed
+             if required_features:
+                 models = self._filter_by_features(models, required_features)
+
+             if not models:
+                 return None
+
+             model = models[self._round_robin_index % len(models)]
+             self._round_robin_index += 1
+             provider = self.get_provider_for_model(model.id)
+             return (model.id, provider) if provider else None
+
+         # Get task preferences
+         preferred_models = self.config.task_preferences.get(task_type, [])
+
+         # Check preferred models first
+         for model_id in preferred_models:
+             provider = self.get_provider_for_model(model_id)
+             if provider:
+                 model_info = provider.get_model_info(model_id)
+                 if model_info and self._meets_requirements(model_info, required_features):
+                     return (model_id, provider)
+
+         # Score all available models
+         scored_models: list[tuple[ModelScore, BaseProvider]] = []
+         for provider in self.providers.values():
+             for model_info in provider.get_models():
+                 if self._meets_requirements(model_info, required_features):
+                     score = self._score_model(model_info, strategy)
+                     scored_models.append((score, provider))
+
+         if not scored_models:
+             return None
+
+         # Sort by overall score
+         scored_models.sort(key=lambda x: x[0].overall_score, reverse=True)
+         best_score, best_provider = scored_models[0]
+
+         return (best_score.model_id, best_provider)
+
+     def _meets_requirements(
+         self,
+         model_info: ModelInfo,
+         required_features: list[str] | None,
+     ) -> bool:
+         """Check if model meets required features."""
+         if not required_features:
+             return True
+
+         for feature in required_features:
+             if feature == "functions" and not model_info.supports_functions:
+                 return False
+             if feature == "vision" and not model_info.supports_vision:
+                 return False
+             if feature == "streaming" and not model_info.supports_streaming:
+                 return False
+
+         return True
+
+     def _filter_by_features(
+         self,
+         models: list[ModelInfo],
+         required_features: list[str],
+     ) -> list[ModelInfo]:
+         """Filter models by required features."""
+         return [m for m in models if self._meets_requirements(m, required_features)]
+
+     async def complete(
+         self,
+         messages: list[dict[str, Any]],
+         model: str | None = None,
+         task_type: TaskType = TaskType.GENERAL,
+         strategy: RoutingStrategy | None = None,
+         required_features: list[str] | None = None,
+         fallback: bool = True,
+         **kwargs: Any,
+     ) -> CompletionResponse:
+         """Generate a completion with automatic routing and fallback.
+
+         Args:
+             messages: List of message dicts
+             model: Specific model to use (overrides routing)
+             task_type: Type of task for routing
+             strategy: Routing strategy
+             required_features: Required model features
+             fallback: Enable fallback on failure
+             **kwargs: Additional completion parameters
+
+         Returns:
+             CompletionResponse from the model
+
+         Raises:
+             ProviderError: If all models fail
+         """
+         if not self._initialized:
+             await self.initialize()
+
+         # Determine model(s) to try
+         models_to_try: list[tuple[str, BaseProvider]] = []
+
+         if model:
+             # Specific model requested
+             provider = self.get_provider_for_model(model)
+             if provider:
+                 models_to_try.append((model, provider))
+             else:
+                 raise ProviderError(f"Model {model} not found", "router")
+         else:
+             # Use routing
+             route_result = self.route(task_type, strategy, required_features)
+             if route_result:
+                 models_to_try.append(route_result)
+
+         # Add fallback models
+         if fallback and len(models_to_try) < self.config.max_fallback_attempts:
+             # Get additional models for fallback
+             preferred = self.config.task_preferences.get(task_type, [])
+             for fallback_model in preferred:
+                 if len(models_to_try) >= self.config.max_fallback_attempts:
+                     break
+
+                 provider = self.get_provider_for_model(fallback_model)
+                 if provider and (fallback_model, provider) not in models_to_try:
+                     models_to_try.append((fallback_model, provider))
+
+         if not models_to_try:
+             raise ProviderError("No suitable models available", "router")
+
+         # Try models in order
+         last_error: Exception | None = None
+
+         for i, (model_id, provider) in enumerate(models_to_try):
+             try:
+                 logger.info(f"Attempting completion with {provider.PROVIDER_NAME}/{model_id}")
+                 response = await provider.complete(messages, model_id, **kwargs)
+
+                 # Track cost
+                 self.cost_tracker.track(response)
+
+                 return response
+
+             except RateLimitError as e:
+                 logger.warning(f"Rate limited by {provider.PROVIDER_NAME}: {e}")
+                 last_error = e
+                 if i < len(models_to_try) - 1:
+                     await asyncio.sleep(self.config.fallback_delay_seconds)
+
+             except ProviderError as e:
+                 logger.warning(f"Provider error from {provider.PROVIDER_NAME}: {e}")
+                 last_error = e
+                 if i < len(models_to_try) - 1:
+                     await asyncio.sleep(self.config.fallback_delay_seconds)
+
+             except Exception as e:
+                 logger.error(f"Unexpected error from {provider.PROVIDER_NAME}: {e}")
+                 last_error = e
+
+         # All models failed
+         raise ProviderError(
+             f"All models failed. Last error: {last_error}",
+             "router",
+         )
+
+     def get_cost_summary(self) -> dict[str, Any]:
+         """Get cost tracking summary."""
+         return self.cost_tracker.get_summary()
+
+     def reset_cost_tracking(self) -> None:
+         """Reset cost tracking."""
+         self.cost_tracker.reset()
+
+     @property
+     def available_providers(self) -> list[str]:
+         """List of initialized provider names."""
+         return list(self.providers.keys())
+
+     def __repr__(self) -> str:
+         return (
+             f"SmartModelRouter(providers={list(self.providers.keys())}, "
+             f"requests={self.cost_tracker.request_count}, "
+             f"cost=${self.cost_tracker.total_cost:.4f})"
+         )
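Usage note: a minimal end-to-end sketch of the router (placeholder keys; providers without a key are simply skipped at initialize(), and complete() falls back through the task-preference list on failure):

import asyncio

from app.models import RoutingStrategy, SmartModelRouter, TaskType


async def demo_router() -> None:
    router = SmartModelRouter(
        openai_api_key="YOUR_OPENAI_API_KEY",  # placeholder keys
        groq_api_key="YOUR_GROQ_API_KEY",
    )
    await router.initialize()
    try:
        resp = await router.complete(
            messages=[{"role": "user", "content": "Classify the sentiment: 'great product!'"}],
            task_type=TaskType.CLASSIFICATION,
            strategy=RoutingStrategy.BEST_SPEED,
        )
        print(resp.provider, resp.model, resp.content)
        print(router.get_cost_summary())
    finally:
        await router.shutdown()


asyncio.run(demo_router())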