"""

Multi-LLM Handler with failover support

Uses Groq, Gemini, and OpenAI with automatic failover for reliability

"""

import asyncio
import re
import time
from typing import Optional, Dict, Any
import requests
import google.generativeai as genai
import openai
from dotenv import load_dotenv
from config.config import get_provider_configs

load_dotenv()
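
# NOTE: The shape below is an assumption inferred from how the configs are
# consumed in this module; get_provider_configs() is expected to return
# {
#     "groq":   [{"name": ..., "api_key": ..., "model": ...}, ...],
#     "gemini": [{"name": ..., "api_key": ..., "model": ...}, ...],
#     "openai": [{"name": ..., "api_key": ..., "model": ...}, ...],
# }
# where each list is ordered by preference and may be empty.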

class MultiLLMHandler:
    """Multi-LLM handler with automatic failover across providers."""
    
    def __init__(self):
        """Initialize the multi-LLM handler with all available providers."""
        self.providers = get_provider_configs()
        self.current_provider = None
        self.current_config = None

        # Initialize the first available provider (prefer Gemini/OpenAI for general RAG)
        self._initialize_provider()

        print(f"✅ Initialized Multi-LLM Handler with {self.provider.upper()}: {self.model_name}")

    def _initialize_provider(self):
        """Initialize the first available provider."""
        # Prefer Gemini first for general text tasks
        if self.providers["gemini"]:
            self.current_provider = "gemini"
            self.current_config = self.providers["gemini"][0]
            genai.configure(api_key=self.current_config["api_key"])
        # Then OpenAI
        elif self.providers["openai"]:
            self.current_provider = "openai"
            self.current_config = self.providers["openai"][0]
            openai.api_key = self.current_config["api_key"]
        # Finally Groq
        elif self.providers["groq"]:
            self.current_provider = "groq"
            self.current_config = self.providers["groq"][0]
        else:
            raise ValueError("No LLM providers available with valid API keys")

    @property
    def provider(self):
        """Get current provider name."""
        return self.current_provider

    @property
    def model_name(self):
        """Get current model name."""
        return self.current_config["model"] if self.current_config else "unknown"

    async def _call_groq(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call Groq API."""
        headers = {
            "Authorization": f"Bearer {self.current_config['api_key']}",
            "Content-Type": "application/json"
        }
        
        data = {
            "model": self.current_config["model"],
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        # Hide reasoning tokens (e.g., <think>) for Qwen reasoning models.
        # Groq's chat completions API uses reasoning_format="hidden" to strip
        # reasoning content from the response.
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name:
                data["reasoning_format"] = "hidden"
        except Exception:
            # Be resilient if the config shape changes
            pass
        
        response = requests.post(
            "https://api.groq.com/openai/v1/chat/completions", 
            headers=headers, 
            json=data, 
            timeout=30
        )
        response.raise_for_status()
        
        result = response.json()
        text = result["choices"][0]["message"]["content"].strip()
        # Safety net: strip any <think>...</think> blocks if present
        try:
            model_name = (self.current_config.get("model") or "").lower()
            if "qwen" in model_name and "<think>" in text.lower():
                text = re.sub(r"<think>[\s\S]*?</think>", "", text, flags=re.IGNORECASE).strip()
        except Exception:
            pass
        return text

    async def _call_gemini(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call Gemini API."""
        model = genai.GenerativeModel(self.current_config["model"])
        
        generation_config = genai.types.GenerationConfig(
            temperature=temperature,
            max_output_tokens=max_tokens
        )
        
        response = await asyncio.to_thread(
            model.generate_content,
            prompt,
            generation_config=generation_config
        )
        return response.text.strip()

    async def _call_openai(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Call the OpenAI chat completions API (openai>=1.0 client)."""
        client = openai.OpenAI(api_key=self.current_config["api_key"])
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model=self.current_config["model"],
            messages=[{"role": "user", "content": prompt}],
            temperature=temperature,
            max_tokens=max_tokens
        )
        return response.choices[0].message.content.strip()

    async def _try_with_failover(self, prompt: str, temperature: float, max_tokens: int) -> str:
        """Try to generate text with automatic failover."""
        # Get all available providers in order
        provider_order = []
        # Prefer Gemini -> OpenAI -> Groq for general text
        if self.providers["gemini"]:
            provider_order.extend([("gemini", config) for config in self.providers["gemini"]])
        if self.providers["openai"]:
            provider_order.extend([("openai", config) for config in self.providers["openai"]])
        if self.providers["groq"]:
            provider_order.extend([("groq", config) for config in self.providers["groq"]])
        
        last_error = None
        
        for provider_name, config in provider_order:
            try:
                # Set current provider
                old_provider = self.current_provider
                old_config = self.current_config
                
                self.current_provider = provider_name
                self.current_config = config
                
                # Configure API if needed
                if provider_name == "gemini":
                    genai.configure(api_key=config["api_key"])
                elif provider_name == "openai":
                    openai.api_key = config["api_key"]
                
                # Try the API call
                if provider_name == "groq":
                    return await self._call_groq(prompt, temperature, max_tokens)
                elif provider_name == "gemini":
                    return await self._call_gemini(prompt, temperature, max_tokens)
                elif provider_name == "openai":
                    return await self._call_openai(prompt, temperature, max_tokens)
                    
            except Exception as e:
                print(f"⚠️ {provider_name.upper()} ({config['name']}) failed: {str(e)}")
                last_error = e
                
                # Restore previous provider
                self.current_provider = old_provider
                self.current_config = old_config
                continue
        
        # If all providers failed
        raise RuntimeError(f"All LLM providers failed. Last error: {last_error}")

    async def generate_text(self,
                            prompt: Optional[str] = None,
                            system_prompt: Optional[str] = None,
                            user_prompt: Optional[str] = None,
                            temperature: Optional[float] = 0.4,
                            max_tokens: Optional[int] = 1200) -> str:
        """Generate text using multi-LLM with failover."""
        # Handle both single prompt and system/user prompt formats
        if prompt:
            final_prompt = prompt
        elif system_prompt and user_prompt:
            final_prompt = f"{system_prompt}\n\n{user_prompt}"
        elif user_prompt:
            final_prompt = user_prompt
        else:
            raise ValueError("Must provide either 'prompt' or 'user_prompt'")
        
        return await self._try_with_failover(
            final_prompt, 
            temperature or 0.4, 
            max_tokens or 1200
        )

    async def generate_simple(self,
                              prompt: str,
                              temperature: Optional[float] = 0.4,
                              max_tokens: Optional[int] = 1200) -> str:
        """Simple text generation (alias for generate_text for compatibility)."""
        return await self.generate_text(prompt=prompt, temperature=temperature, max_tokens=max_tokens)

    def get_provider_info(self) -> Dict[str, Any]:
        """Get information about the current provider."""
        return {
            "provider": self.current_provider,
            "model": self.model_name,
            "config_name": self.current_config["name"] if self.current_config else "none",
            "available_providers": {
                "groq": len(self.providers["groq"]),
                "gemini": len(self.providers["gemini"]),
                "openai": len(self.providers["openai"])
            }
        }

    async def test_connection(self) -> bool:
        """Test the connection to the current LLM provider."""
        try:
            test_prompt = "Say 'Hello' if you can read this."
            response = await self.generate_simple(test_prompt, temperature=0.1, max_tokens=10)
            return "hello" in response.lower()
        except Exception as e:
            print(f"❌ Connection test failed: {str(e)}")
            return False

# Create a global instance
llm_handler = MultiLLMHandler()
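
# Minimal usage sketch (an assumption-laden demo, not part of the handler's
# API: it assumes at least one provider key is configured via the .env file
# consumed by config.get_provider_configs):
if __name__ == "__main__":
    async def _demo():
        # Show which provider/model the handler selected at startup
        print(llm_handler.get_provider_info())
        if await llm_handler.test_connection():
            answer = await llm_handler.generate_text(
                prompt="Summarize what failover means in one sentence.",
                temperature=0.2,
                max_tokens=100,
            )
            print(answer)

    asyncio.run(_demo())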