paint / api.py
baconnier's picture
Upload 10 files
05f2374 verified
import json
import logging
from openai import OpenAI
from typing import Dict, Any, Optional
import gradio as gr
from prompts import PROMPT_ANALYZER_TEMPLATE
import time
# Module-level logger, named after the importing module path.
logger = logging.getLogger(__name__)

# Model identifiers sent as the `model=` argument of the chat-completion call.
# ModelManager starts at the first entry and advances one position at a time,
# wrapping around at the end, so list order is the fallback order.
FALLBACK_MODELS = [
    "mixtral-8x7b-32768",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama3-70b-8192",
    "llama3-8b-8192",
]
class ModelManager:
    """Cursor over ``FALLBACK_MODELS`` that cycles to the next entry on demand.

    Attributes:
        current_model_index: Position of the active model in ``FALLBACK_MODELS``.
        max_retries: Number of entries in ``FALLBACK_MODELS`` at construction
            time; callers use it as their attempt budget.
    """

    def __init__(self):
        """Start on the first fallback model."""
        self.current_model_index = 0
        self.max_retries = len(FALLBACK_MODELS)

    @property
    def current_model(self) -> str:
        """Return the identifier of the model currently selected."""
        position = self.current_model_index
        return FALLBACK_MODELS[position]

    def next_model(self) -> str:
        """Advance to the following model (wrapping around) and return it."""
        advanced = self.current_model_index + 1
        self.current_model_index = advanced % len(FALLBACK_MODELS)
        selected = self.current_model
        logger.info(f"Switching to model: {selected}")
        return selected
class PromptEnhancementAPI:
    """Chat-completion client that falls back across models until one returns a JSON object.

    The client is an ``OpenAI`` instance pointed at ``base_url`` (default
    ``https://api.groq.com/openai/v1``). Model selection is delegated to a
    ``ModelManager``; its position is kept between calls.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Create the API client and the model cursor.

        Args:
            api_key: Key handed to the ``OpenAI`` client.
            base_url: Endpoint override; a falsy value selects the default URL.
        """
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url or "https://api.groq.com/openai/v1",
        )
        self.model_manager = ModelManager()

    @staticmethod
    def _decode_json_object(content: Optional[str]) -> Dict[str, Any]:
        """Decode *content* and require the result to be a JSON object.

        Raises:
            ValueError: If *content* is not a string (e.g. ``None``), is not
                valid JSON (``json.JSONDecodeError`` is a ``ValueError``), or
                decodes to something other than a dict.
        """
        # A missing/None message body must count as a parse failure so the
        # caller's model-fallback path handles it, instead of surfacing as an
        # AttributeError from ``.strip()``.
        if not isinstance(content, str):
            raise ValueError("Response content is empty or not text")
        result = json.loads(content.strip())
        if not isinstance(result, dict):
            raise ValueError("Response is not a valid JSON object")
        return result

    @staticmethod
    def _build_messages(
        system_prompt: str,
        user_prompt: str,
        user_directive: str = "",
        state: Optional[Dict] = None,
    ) -> list:
        """Assemble the chat message list for one request.

        The order is: system prompt, user prompt, then the optional directive
        (as a user message) and the optional prior state (JSON-encoded, as an
        assistant message).
        """
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        if user_directive:
            messages.append({"role": "user", "content": f"User directive: {user_directive}"})
        if state:
            messages.append({"role": "assistant", "content": json.dumps(state)})
        return messages

    def _try_parse_json(self, content: str, retries: int = 0) -> Dict[str, Any]:
        """Parse a model reply, switching models when it is not a JSON object.

        Args:
            content: Raw message text returned by the model.
            retries: Number of attempts already used by the caller.

        Returns:
            The decoded JSON object.

        Raises:
            json.JSONDecodeError, ValueError: The reply could not be decoded to
                a dict. When attempts remain, the model cursor has already been
                advanced before the exception is re-raised.
        """
        try:
            return self._decode_json_object(content)
        except (json.JSONDecodeError, ValueError) as e:
            if retries < self.model_manager.max_retries - 1:
                logger.warning(f"JSON parsing failed with model {self.model_manager.current_model}. Switching models...")
                self.model_manager.next_model()
                raise
            logger.error(f"JSON parsing failed with all models: {str(e)}")
            raise

    def generate_enhancement(self, system_prompt: str, user_prompt: str, user_directive: str = "", state: Optional[Dict] = None) -> Dict[str, Any]:
        """Request a JSON-object completion, retrying with fallback models.

        Args:
            system_prompt: Content of the system message.
            user_prompt: Content of the first user message.
            user_directive: Optional extra instruction, sent as a second user
                message when non-empty.
            state: Optional previous result, sent JSON-encoded as an assistant
                message when truthy.

        Returns:
            The decoded JSON object, or the value of ``create_error_response``
            when every attempt produced unparseable output.

        Raises:
            gr.Error: On any non-parse failure, except a rate-limit error that
                still has a fallback attempt available.
        """
        retries = 0
        last_error = None
        max_retries = self.model_manager.max_retries

        while retries < max_retries:
            try:
                messages = self._build_messages(system_prompt, user_prompt, user_directive, state)
                response = self.client.chat.completions.create(
                    model=self.model_manager.current_model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=4000,
                    response_format={"type": "json_object"},
                )
                return self._try_parse_json(response.choices[0].message.content, retries)
            except (json.JSONDecodeError, ValueError) as e:
                # _try_parse_json has already moved the cursor to the next
                # model when another attempt is available.
                last_error = e
                retries += 1
                if retries < max_retries:
                    logger.warning(f"Attempt {retries} failed. Switching models and retrying...")
                    time.sleep(1)  # Brief pause before retry
                    continue
                break
            except Exception as e:
                logger.error(f"API error: {str(e)}")
                # Rate limits are the only API failure retried on another model.
                if "rate limit" in str(e).lower() and retries < max_retries - 1:
                    self.model_manager.next_model()
                    retries += 1
                    time.sleep(1)
                    continue
                # Chain the cause so the original traceback is not lost.
                raise gr.Error(f"API request failed: {str(e)}") from e

        logger.error(f"All models failed to generate valid JSON: {str(last_error)}")
        # NOTE(review): create_error_response is neither defined nor imported in
        # the visible part of this module — confirm it exists, otherwise this
        # line raises NameError.
        return create_error_response(user_prompt, user_directive)
class PromptEnhancementSystem:
    """Session wrapper around ``PromptEnhancementAPI`` that remembers results.

    Attributes:
        api: The ``PromptEnhancementAPI`` used for every request.
        current_state: Result of the most recent request, or ``None``.
        history: Results collected since the last ``start_session`` call.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """Build the underlying API object and start with empty session state."""
        self.api = PromptEnhancementAPI(api_key, base_url)
        self.current_state = None
        self.history = []

    @staticmethod
    def _render_system_prompt(text: str, directive: str) -> str:
        """Fill ``PROMPT_ANALYZER_TEMPLATE`` with the prompt text and directive."""
        return PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=text,
            user_directive=directive,
        )

    def start_session(self, prompt: str, user_directive: str = "") -> Dict[str, Any]:
        """Run the first enhancement for *prompt* and reset the history to it."""
        system_text = self._render_system_prompt(prompt, user_directive)
        first_result = self.api.generate_enhancement(
            system_prompt=system_text,
            user_prompt=prompt,
            user_directive=user_directive,
        )
        # A new session discards whatever history existed before.
        self.current_state = first_result
        self.history = [first_result]
        return first_result

    def apply_enhancement(self, choice: str, user_directive: str = "") -> Dict[str, Any]:
        """Run a follow-up enhancement for *choice*, passing the previous result as state."""
        system_text = self._render_system_prompt(choice, user_directive)
        follow_up = self.api.generate_enhancement(
            system_prompt=system_text,
            user_prompt=choice,
            user_directive=user_directive,
            state=self.current_state,
        )
        self.current_state = follow_up
        self.history.append(follow_up)
        return follow_up