File size: 6,191 Bytes
ab7e98d
 
 
 
 
fc0299f
ab7e98d
 
 
 
 
 
 
fc0299f
1c091c0
fc0299f
 
 
 
ab7e98d
 
 
 
 
 
 
 
 
 
 
1c091c0
ab7e98d
1c091c0
fc0299f
ab7e98d
 
fc0299f
2d2ab61
ab7e98d
 
1c091c0
ab7e98d
 
 
 
 
1c091c0
ab7e98d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c091c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc0299f
 
 
 
2d2ab61
fc0299f
 
1c091c0
fc0299f
2d2ab61
1c091c0
2d2ab61
 
 
1c091c0
 
 
 
 
 
 
 
 
fc0299f
 
1c091c0
fc0299f
2d2ab61
 
fc0299f
2d2ab61
 
 
 
 
 
fc0299f
 
 
 
1c091c0
fc0299f
 
 
 
 
 
 
2d2ab61
ab7e98d
1c091c0
ab7e98d
 
 
1c091c0
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
LLM Provider Interface for Flare
"""
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any
import httpx
from openai import AsyncOpenAI
from utils import log

class LLMInterface(ABC):
    """Common contract shared by every LLM provider implementation."""

    def __init__(self, settings: Dict[str, Any] = None):
        """Keep the provider settings and the values derived from them.

        Args:
            settings: Provider configuration. A missing or empty value is
                replaced by a fresh empty dict.
        """
        if not settings:
            settings = {}
        self.settings = settings
        # Both values are optional in the settings and fall back to empties.
        self.internal_prompt = settings.get("internal_prompt", "")
        self.parameter_collection_config = settings.get("parameter_collection_config", {})

    @abstractmethod
    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Return the provider's reply; concrete providers must implement this."""

    @abstractmethod
    async def startup(self, project_config: Dict) -> bool:
        """Prepare the provider for a project; concrete providers must implement this."""

class SparkLLM(LLMInterface):
    """Spark LLM integration.

    Sends HTTP POST requests to ``{spark_endpoint}/startup`` and
    ``{spark_endpoint}/generate``, authenticating with a bearer token.
    """

    # Request timeouts in seconds. Kept as class attributes so a subclass or
    # a deployment can override them without touching the request code.
    GENERATE_TIMEOUT = 60
    STARTUP_TIMEOUT = 10

    # Response fields inspected for the answer text, in priority order.
    _ANSWER_FIELDS = ("model_answer", "assistant", "text")

    def __init__(self, spark_endpoint: str, spark_token: str, provider_variant: str = "cloud", settings: Optional[Dict[str, Any]] = None):
        """Store connection details for the Spark service.

        Args:
            spark_endpoint: Base URL of the Spark service; trailing slashes
                are removed so paths can be appended safely.
            spark_token: Token sent as the bearer credential (and as
                ``cloud_token`` in the startup body).
            provider_variant: Value sent as ``work_mode`` on startup.
            settings: Optional provider settings passed to the base class.
        """
        super().__init__(settings)
        self.spark_endpoint = spark_endpoint.rstrip("/")
        self.spark_token = spark_token
        self.provider_variant = provider_variant
        log(f"πŸ”Œ SparkLLM initialized with endpoint: {self.spark_endpoint}")

    def _auth_headers(self) -> Dict[str, str]:
        """Build the headers shared by every request to Spark."""
        return {
            "Authorization": f"Bearer {self.spark_token}",
            "Content-Type": "application/json"
        }

    @classmethod
    def _extract_answer(cls, data: Dict[str, Any]) -> str:
        """Return the first non-blank answer string found in ``data``.

        Fields that are missing, ``None``, not strings, or whitespace-only
        are skipped so that a JSON ``null`` cannot crash on ``.strip()`` and
        a blank higher-priority field does not hide a usable lower one.
        Returns an empty string when no field holds any text.
        """
        for field in cls._ANSWER_FIELDS:
            value = data.get(field)
            if isinstance(value, str):
                text = value.strip()
                if text:
                    return text
        return ""

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from Spark LLM.

        Args:
            system_prompt: Sent as ``system_prompt`` in the request body.
            user_input: Sent as ``user_input`` in the request body.
            context: Sent as ``context`` in the request body.

        Returns:
            The stripped answer text, or an empty string if the response
            contains none of the known answer fields.

        Raises:
            Exception: Any error from the request, a non-success HTTP
                status, or response decoding is logged and re-raised.
        """
        payload = {
            "system_prompt": system_prompt,
            "user_input": user_input,
            "context": context
        }

        try:
            async with httpx.AsyncClient(timeout=self.GENERATE_TIMEOUT) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/generate",
                    json=payload,
                    headers=self._auth_headers()
                )
                response.raise_for_status()
                return self._extract_answer(response.json())
        except Exception as e:
            log(f"❌ Spark error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """Send startup request to Spark.

        Args:
            project_config: Project configuration; the fields read from it
                are listed in the request body below.

        Returns:
            ``True`` when Spark answers with a status below 400, ``False``
            on an error status or on any exception (errors are logged, never
            raised, so startup failure is reported through the return value).
        """
        # Extract required fields from project config
        body = {
            "work_mode": self.provider_variant,
            "cloud_token": self.spark_token,
            "project_name": project_config.get("name"),
            "project_version": project_config.get("version_id"),
            "repo_id": project_config.get("repo_id"),
            "generation_config": project_config.get("generation_config", {}),
            "use_fine_tune": project_config.get("use_fine_tune", False),
            "fine_tune_zip": project_config.get("fine_tune_zip", "")
        }

        try:
            async with httpx.AsyncClient(timeout=self.STARTUP_TIMEOUT) as client:
                response = await client.post(
                    f"{self.spark_endpoint}/startup",
                    json=body,
                    headers=self._auth_headers()
                )

                if response.status_code >= 400:
                    log(f"❌ Spark startup failed: {response.status_code} - {response.text}")
                    return False

                log(f"βœ… Spark acknowledged startup ({response.status_code})")
                return True
        except Exception as e:
            log(f"⚠️ Spark startup error: {e}")
            return False

class GPT4oLLM(LLMInterface):
    """OpenAI GPT integration built on the async OpenAI client."""

    # Provider aliases -> OpenAI model identifiers. "gpt4o" previously
    # resolved to "gpt-4", which is a different model than the alias names.
    _MODEL_ALIASES = {
        "gpt4o": "gpt-4o",
        "gpt4o-mini": "gpt-4o-mini"
    }

    def __init__(self, api_key: str, model: str = "gpt-4o-mini", settings: Optional[Dict[str, Any]] = None):
        """Create the OpenAI client and read generation settings.

        Args:
            api_key: OpenAI API key passed to ``AsyncOpenAI``.
            model: Provider alias or model name; aliases are translated by
                ``_map_model_name``, other values are used unchanged.
            settings: Optional provider settings. ``temperature`` (default
                0.7) and ``max_tokens`` (default 4096) are read from it.
        """
        super().__init__(settings)
        self.api_key = api_key
        self.model = self._map_model_name(model)
        self.client = AsyncOpenAI(api_key=api_key)

        # Extract model-specific settings; self.settings is always a dict
        # after the base-class initializer, so no None check is needed.
        self.temperature = self.settings.get("temperature", 0.7)
        self.max_tokens = self.settings.get("max_tokens", 4096)

        log(f"βœ… Initialized GPT LLM with model: {self.model}")

    def _map_model_name(self, model: str) -> str:
        """Map provider name to actual model name.

        Unknown names are returned unchanged so callers may pass a real
        model identifier directly.
        """
        return self._MODEL_ALIASES.get(model, model)

    async def generate(self, system_prompt: str, user_input: str, context: List[Dict]) -> str:
        """Generate response from OpenAI.

        Args:
            system_prompt: Content of the leading ``system`` message.
            user_input: Content of the final ``user`` message.
            context: Earlier messages; each entry contributes its ``role``
                (default ``"user"``) and ``content`` (default ``""``).
                ``None`` is treated as no prior messages.

        Returns:
            The stripped text of the first choice, or an empty string when
            the message carries no text content.

        Raises:
            Exception: Any error while building or sending the request is
                logged and re-raised.
        """
        try:
            # Build messages: system prompt, prior context, then the new input
            messages = [{"role": "system", "content": system_prompt}]
            messages.extend(
                {
                    "role": msg.get("role", "user"),
                    "content": msg.get("content", "")
                }
                for msg in (context or ())
            )
            messages.append({"role": "user", "content": user_input})

            # Call OpenAI
            response = await self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )

            # content may be None, which would break .strip(); fall back to "".
            content = response.choices[0].message.content
            return (content or "").strip()
        except Exception as e:
            log(f"❌ OpenAI error: {e}")
            raise

    async def startup(self, project_config: Dict) -> bool:
        """GPT doesn't need startup, always return True"""
        log("βœ… GPT provider ready (no startup needed)")
        return True