# ────────────────────────────── utils/router.py ──────────────────────────────
import json
import os
from typing import Any, Dict, Optional

import httpx

from ..logger import get_logger
from .rotator import robust_post_json, APIKeyRotator

logger = get_logger("ROUTER", __name__)

# Default model names (can be overridden via env)
GEMINI_SMALL = os.getenv("GEMINI_SMALL", "gemini-2.5-flash-lite")
GEMINI_MED   = os.getenv("GEMINI_MED",   "gemini-2.5-flash")
GEMINI_PRO   = os.getenv("GEMINI_PRO",   "gemini-2.5-pro")

# NVIDIA model hierarchy (can be overridden via env)
NVIDIA_SMALL = os.getenv("NVIDIA_SMALL", "meta/llama-3.1-8b-instruct")         # Llama model for easy complexity tasks
NVIDIA_MEDIUM = os.getenv("NVIDIA_MEDIUM", "qwen/qwen3-next-80b-a3b-thinking") # Qwen model for reasoning tasks
NVIDIA_LARGE = os.getenv("NVIDIA_LARGE", "openai/gpt-oss-120b")                # GPT-OSS model for hard/long context tasks
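
# Any tier can be overridden at deploy time, e.g. (illustrative; values shown
# are just the defaults above):
#   export GEMINI_PRO=gemini-2.5-pro
#   export NVIDIA_LARGE=openai/gpt-oss-120b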

def select_model(question: str, context: str) -> Dict[str, Any]:
    """
    Four-tier model selection:
    - Simple tasks (immediate execution) -> Llama (NVIDIA small)
    - Reasoning tasks (analysis, decision-making, JSON parsing) -> Qwen (NVIDIA medium)
    - Hard/long-context tasks (complex synthesis, long-form generation) -> GPT-OSS (NVIDIA large)
    - Very complex tasks (research, comprehensive analysis) -> Gemini Pro
    """
    qlen = len(question.split())
    clen = len(context.split())
    q = question.lower()

    # Very hard tasks - require Gemini Pro (research, comprehensive analysis)
    very_hard_keywords = ("prove", "derivation", "complexity", "algorithm", "optimize", "theorem", "rigorous", "step-by-step", "policy critique", "ambiguity", "counterfactual", "comprehensive", "detailed analysis", "synthesis", "evaluation", "research", "investigation", "comprehensive study")

    # Hard/long-context tasks - require NVIDIA Large (GPT-OSS)
    hard_keywords = ("analyze", "explain", "compare", "evaluate", "summarize", "extract", "classify", "identify", "describe", "discuss", "synthesis", "consolidate", "process", "generate", "create", "develop", "build", "construct")

    # Reasoning tasks - require Qwen (thinking/reasoning)
    reasoning_keywords = ("reasoning", "context", "enhance", "select", "decide", "choose", "determine", "assess", "judge", "consider", "think", "reason", "logic", "inference", "deduction", "analysis", "interpretation")

    # Tiers are checked from most to least demanding; the first match wins.
    is_very_hard = (
        any(k in q for k in very_hard_keywords)
        or qlen > 120
        or clen > 4000
        or "detailed" in q
    )

    is_hard = (
        any(k in q for k in hard_keywords)
        or qlen > 50
        or clen > 1500
    )

    is_reasoning = (
        any(k in q for k in reasoning_keywords)
        or qlen > 20
        or clen > 800
    )

    if is_very_hard:
        # Use Gemini Pro for very complex tasks requiring advanced reasoning
        return {"provider": "gemini", "model": GEMINI_PRO}
    elif is_hard:
        # Use NVIDIA Large (GPT-OSS) for hard/long context tasks
        return {"provider": "nvidia_large", "model": NVIDIA_LARGE}
    elif is_reasoning:
        # Use Qwen for reasoning tasks requiring thinking
        return {"provider": "qwen", "model": NVIDIA_MEDIUM}
    else:
        # Everything else (short "what/how/when/who" lookups and other simple
        # tasks) goes to NVIDIA small (Llama) for immediate execution.
        return {"provider": "nvidia", "model": NVIDIA_SMALL}


async def generate_answer_with_model(selection: Dict[str, Any], system_prompt: str, user_prompt: str,
                                     gemini_rotator: APIKeyRotator, nvidia_rotator: APIKeyRotator,
                                     user_id: Optional[str] = None, context: str = "") -> str:
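    """
    Dispatch a chat completion to the provider chosen by select_model() and
    degrade gracefully on failure.

    Fallback chain (as implemented below):
      gemini PRO/MED -> nvidia_large -> nvidia small
      gemini SMALL   -> nvidia small
      qwen, nvidia_large, nvidia_coder -> nvidia small
      nvidia small   -> static error message
    """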
    provider = selection["provider"]
    model = selection["model"]
    
    # Track model usage for analytics
    try:
        from utils.analytics import get_analytics_tracker
        tracker = get_analytics_tracker()
        if tracker and user_id:
            await tracker.track_model_usage(
                user_id=user_id,
                model_name=model,
                provider=provider,
                context=context or "api_call",
                metadata={"system_prompt_length": len(system_prompt), "user_prompt_length": len(user_prompt)}
            )
    except Exception as e:
        logger.debug(f"[ROUTER] Analytics tracking failed: {e}")

    if provider == "gemini":
        # Try Gemini first
        try:
            key = gemini_rotator.get_key() or ""
            url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}"
            payload = {
                "contents": [
                    {"role": "user", "parts": [{"text": f"{system_prompt}\n\n{user_prompt}"}]}
                ],
                "generationConfig": {"temperature": 0.2}
            }
            headers = {"Content-Type": "application/json"}
            data = await robust_post_json(url, headers, payload, gemini_rotator)
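            # Expected Gemini response shape (abridged; this is what the
            # indexing below assumes):
            #   {"candidates": [{"content": {"parts": [{"text": "..."}]}}]}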
            
            content = data["candidates"][0]["content"]["parts"][0]["text"]
            if not content or content.strip() == "":
                logger.warning(f"Empty content from Gemini model: {data}")
                raise Exception("Empty content from Gemini")
            return content
        except Exception as e:
            logger.warning(f"Gemini model {model} failed: {e}. Attempting fallback...")
            
            # Fallback logic: GEMINI_PRO/MED → NVIDIA_LARGE, GEMINI_SMALL → NVIDIA_SMALL
            if model in [GEMINI_PRO, GEMINI_MED]:
                logger.info(f"Falling back from {model} to NVIDIA_LARGE")
                fallback_selection = {"provider": "nvidia_large", "model": NVIDIA_LARGE}
                return await generate_answer_with_model(fallback_selection, system_prompt, user_prompt, gemini_rotator, nvidia_rotator, user_id, context)
            elif model == GEMINI_SMALL:
                logger.info(f"Falling back from {model} to NVIDIA_SMALL")
                fallback_selection = {"provider": "nvidia", "model": NVIDIA_SMALL}
                return await generate_answer_with_model(fallback_selection, system_prompt, user_prompt, gemini_rotator, nvidia_rotator, user_id, context)
            else:
                logger.error(f"No fallback defined for Gemini model: {model}")
                return "The model request failed and no fallback is available."

    elif provider == "nvidia":
        # Try NVIDIA small model first
        try:
            key = nvidia_rotator.get_key() or ""
            url = "https://integrate.api.nvidia.com/v1/chat/completions"
            payload = {
                "model": model,
                "temperature": 0.2,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
            }
            headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
            
            logger.info(f"[ROUTER] NVIDIA API call - Model: {model}, Key present: {bool(key)}")
            logger.info(f"[ROUTER] System prompt length: {len(system_prompt)}, User prompt length: {len(user_prompt)}")
            
            data = await robust_post_json(url, headers, payload, nvidia_rotator)
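            # Expected OpenAI-style response shape (abridged):
            #   {"choices": [{"message": {"content": "..."}}]}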
            
            logger.info(f"[ROUTER] NVIDIA API response type: {type(data)}, keys: {list(data.keys()) if isinstance(data, dict) else 'Not a dict'}")
            content = data["choices"][0]["message"]["content"]
            if not content or content.strip() == "":
                logger.warning(f"Empty content from NVIDIA model: {data}")
                raise Exception("Empty content from NVIDIA")
            return content
        except Exception as e:
            logger.warning(f"NVIDIA model {model} failed: {e}. Attempting fallback...")
            
            # Fallback: NVIDIA_SMALL → Try a different NVIDIA model or basic response
            if model == NVIDIA_SMALL:
                logger.info(f"Falling back from {model} to basic response")
                return "I'm experiencing technical difficulties with the AI model. Please try again later."
            else:
                logger.error(f"No fallback defined for NVIDIA model: {model}")
                return "The model request failed and no fallback is available."

    elif provider == "qwen":
        # Use Qwen for reasoning tasks with fallback
        try:
            return await qwen_chat_completion(system_prompt, user_prompt, nvidia_rotator, user_id, context)
        except Exception as e:
            logger.warning(f"Qwen model failed: {e}. Attempting fallback...")
            # Fallback: Qwen → NVIDIA_SMALL
            logger.info("Falling back from Qwen to NVIDIA_SMALL")
            fallback_selection = {"provider": "nvidia", "model": NVIDIA_SMALL}
            return await generate_answer_with_model(fallback_selection, system_prompt, user_prompt, gemini_rotator, nvidia_rotator, user_id, context)
    elif provider == "nvidia_large":
        # Use NVIDIA Large (GPT-OSS) for hard/long context tasks with fallback
        try:
            return await nvidia_large_chat_completion(system_prompt, user_prompt, nvidia_rotator, user_id, context)
        except Exception as e:
            logger.warning(f"NVIDIA_LARGE model failed: {e}. Attempting fallback...")
            # Fallback: NVIDIA_LARGE → NVIDIA_SMALL
            logger.info("Falling back from NVIDIA_LARGE to NVIDIA_SMALL")
            fallback_selection = {"provider": "nvidia", "model": NVIDIA_SMALL}
            return await generate_answer_with_model(fallback_selection, system_prompt, user_prompt, gemini_rotator, nvidia_rotator, user_id, context)
    elif provider == "nvidia_coder":
        # Use NVIDIA Coder for code generation tasks with fallback
        try:
            from helpers.coder import nvidia_coder_completion
            return await nvidia_coder_completion(system_prompt, user_prompt, nvidia_rotator, user_id, context)
        except Exception as e:
            logger.warning(f"NVIDIA_CODER model failed: {e}. Attempting fallback...")
            # Fallback: NVIDIA_CODER → NVIDIA_SMALL
            logger.info("Falling back from NVIDIA_CODER to NVIDIA_SMALL")
            fallback_selection = {"provider": "nvidia", "model": NVIDIA_SMALL}
            return await generate_answer_with_model(fallback_selection, system_prompt, user_prompt, gemini_rotator, nvidia_rotator, user_id, context)

    return "Unsupported provider."


async def qwen_chat_completion(system_prompt: str, user_prompt: str, nvidia_rotator: APIKeyRotator, user_id: Optional[str] = None, context: str = "") -> str:
    """
    Qwen chat completion with thinking mode enabled.
    Uses the NVIDIA API rotator for key management.
    """
    # Track model usage for analytics
    try:
        from utils.analytics import get_analytics_tracker
        tracker = get_analytics_tracker()
        if tracker and user_id:
            await tracker.track_model_usage(
                user_id=user_id,
                model_name=NVIDIA_MEDIUM,
                provider="nvidia",
                context=context or "qwen_completion",
                metadata={"system_prompt_length": len(system_prompt), "user_prompt_length": len(user_prompt)}
            )
    except Exception as e:
        logger.debug(f"[ROUTER] Analytics tracking failed: {e}")

    key = nvidia_rotator.get_key() or ""
    url = "https://integrate.api.nvidia.com/v1/chat/completions"
    
    payload = {
        "model": NVIDIA_MEDIUM,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        "temperature": 0.6,
        "top_p": 0.7,
        "max_tokens": 8192,
        "stream": True
    }
    
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
    
    logger.info(f"[QWEN] API call - Model: {NVIDIA_MEDIUM}, Key present: {bool(key)}")
    logger.info(f"[QWEN] System prompt length: {len(system_prompt)}, User prompt length: {len(user_prompt)}")
    
    try:
        # NOTE: client.post() buffers the full response body; aiter_lines() below
        # then replays the buffered SSE stream. True incremental streaming would
        # use client.stream(), but buffering keeps the one-shot retry simple.
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(url, headers=headers, json=payload)
            
            if response.status_code in (401, 403, 429) or (500 <= response.status_code < 600):
                logger.warning(f"HTTP {response.status_code} from Qwen provider. Rotating key and retrying")
                nvidia_rotator.rotate()
                # Retry once with new key
                key = nvidia_rotator.get_key() or ""
                headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
                response = await client.post(url, headers=headers, json=payload)
            
            response.raise_for_status()
            
            # Handle streaming response
            content = ""
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    
                    try:
                        chunk_data = json.loads(data)
                        if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
                            delta = chunk_data["choices"][0].get("delta", {})
                            
                            # Handle reasoning content (thinking)
                            reasoning = delta.get("reasoning_content")
                            if reasoning:
                                logger.debug(f"[QWEN] Reasoning: {reasoning}")
                            
                            # Handle regular content
                            chunk_content = delta.get("content")
                            if chunk_content:
                                content += chunk_content
                    except json.JSONDecodeError:
                        continue
            
            if not content or content.strip() == "":
                logger.warning("Empty content from Qwen model")
                raise RuntimeError("Empty content from Qwen")
            
            return content.strip()
            
    except Exception as e:
        logger.warning(f"Qwen API error: {e}")
        # Re-raise so the caller's documented fallback (Qwen -> NVIDIA_SMALL) engages.
        raise


async def nvidia_large_chat_completion(system_prompt: str, user_prompt: str, nvidia_rotator: APIKeyRotator, user_id: Optional[str] = None, context: str = "") -> str:
    """
    NVIDIA Large (GPT-OSS) chat completion for hard/long context tasks.
    Uses the NVIDIA API rotator for key management.
    """
    # Track model usage for analytics
    try:
        from utils.analytics import get_analytics_tracker
        tracker = get_analytics_tracker()
        if tracker and user_id:
            await tracker.track_model_usage(
                user_id=user_id,
                model_name=NVIDIA_LARGE,
                provider="nvidia_large",
                context=context or "nvidia_large_completion",
                metadata={"system_prompt_length": len(system_prompt), "user_prompt_length": len(user_prompt)}
            )
    except Exception as e:
        logger.debug(f"[ROUTER] Analytics tracking failed: {e}")

    key = nvidia_rotator.get_key() or ""
    url = "https://integrate.api.nvidia.com/v1/chat/completions"
    
    payload = {
        "model": NVIDIA_LARGE,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        "temperature": 1.0,
        "top_p": 1.0,
        "max_tokens": 4096,
        "stream": True
    }
    
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
    
    logger.info(f"[NVIDIA_LARGE] API call - Model: {NVIDIA_LARGE}, Key present: {bool(key)}")
    logger.info(f"[NVIDIA_LARGE] System prompt length: {len(system_prompt)}, User prompt length: {len(user_prompt)}")
    
    try:
        # NOTE: client.post() buffers the full response body; aiter_lines() below
        # then replays the buffered SSE stream. True incremental streaming would
        # use client.stream(), but buffering keeps the one-shot retry simple.
        async with httpx.AsyncClient(timeout=60) as client:
            response = await client.post(url, headers=headers, json=payload)
            
            if response.status_code in (401, 403, 429) or (500 <= response.status_code < 600):
                logger.warning(f"HTTP {response.status_code} from NVIDIA Large provider. Rotating key and retrying")
                nvidia_rotator.rotate()
                # Retry once with new key
                key = nvidia_rotator.get_key() or ""
                headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"}
                response = await client.post(url, headers=headers, json=payload)
            
            response.raise_for_status()
            
            # Handle streaming response
            content = ""
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    data = line[6:]  # Remove "data: " prefix
                    if data.strip() == "[DONE]":
                        break
                    
                    try:
                        chunk_data = json.loads(data)
                        if "choices" in chunk_data and len(chunk_data["choices"]) > 0:
                            delta = chunk_data["choices"][0].get("delta", {})
                            
                            # Handle reasoning content (thinking)
                            reasoning = delta.get("reasoning_content")
                            if reasoning:
                                logger.debug(f"[NVIDIA_LARGE] Reasoning: {reasoning}")
                            
                            # Handle regular content
                            chunk_content = delta.get("content")
                            if chunk_content:
                                content += chunk_content
                    except json.JSONDecodeError:
                        continue
            
            if not content or content.strip() == "":
                logger.warning("Empty content from NVIDIA Large model")
                raise RuntimeError("Empty content from NVIDIA Large")
            
            return content.strip()
            
    except Exception as e:
        logger.warning(f"NVIDIA Large API error: {e}")
        # Re-raise so the caller's documented fallback (NVIDIA_LARGE -> NVIDIA_SMALL) engages.
        raise