"""Prompt-enhancement service on top of an OpenAI-compatible chat endpoint.

Wraps a Groq (OpenAI-compatible) API with automatic model fallback: when a
model returns unparseable JSON or hits a rate limit, the next model in
``FALLBACK_MODELS`` is tried until every model has been exhausted.
"""

import json
import logging
import time
from typing import Any, Dict, Optional

import gradio as gr
from openai import OpenAI

from prompts import PROMPT_ANALYZER_TEMPLATE

logger = logging.getLogger(__name__)

# Models tried in order; we rotate through this list on parse or rate-limit
# failures. NOTE(review): several of these Groq model IDs are deprecated —
# verify the list against the provider's current catalogue.
FALLBACK_MODELS = [
    "mixtral-8x7b-32768",
    "llama-3.1-70b-versatile",
    "llama-3.1-8b-instant",
    "llama3-70b-8192",
    "llama3-8b-8192",
]


def create_error_response(user_prompt: str, user_directive: str = "") -> Dict[str, Any]:
    """Build a fallback payload for when every model failed to return valid JSON.

    NOTE(review): the original module called this function without defining or
    importing it anywhere (a guaranteed NameError on the failure path). This
    minimal implementation keeps the module runnable — confirm the expected
    schema against whatever UI layer consumes the result.

    Args:
        user_prompt: The prompt the caller was trying to enhance.
        user_directive: Optional extra instruction supplied by the user.

    Returns:
        A dict flagged as an error, echoing the inputs back to the caller.
    """
    return {
        "error": True,
        "message": "All models failed to generate a valid JSON response.",
        "original_prompt": user_prompt,
        "user_directive": user_directive,
    }


class ModelManager:
    """Tracks which fallback model is currently active and rotates on failure."""

    def __init__(self) -> None:
        self.current_model_index = 0
        # One attempt per available model.
        self.max_retries = len(FALLBACK_MODELS)

    @property
    def current_model(self) -> str:
        """Name of the model to use for the next request."""
        return FALLBACK_MODELS[self.current_model_index]

    def next_model(self) -> str:
        """Advance to the next model (wrapping around) and return its name."""
        self.current_model_index = (self.current_model_index + 1) % len(FALLBACK_MODELS)
        logger.info("Switching to model: %s", self.current_model)
        return self.current_model


class PromptEnhancementAPI:
    """Thin chat-completions client with JSON validation and model fallback."""

    def __init__(self, api_key: str, base_url: Optional[str] = None) -> None:
        """Create the underlying OpenAI-compatible client (Groq by default)."""
        self.client = OpenAI(
            api_key=api_key,
            base_url=base_url or "https://api.groq.com/openai/v1",
        )
        self.model_manager = ModelManager()

    def _try_parse_json(self, content: str, retries: int = 0) -> Dict[str, Any]:
        """Parse *content* as a JSON object, rotating models on failure.

        Args:
            content: Raw model output expected to be a JSON object.
            retries: How many attempts have already been made; once all models
                have been tried the error is logged as final.

        Returns:
            The parsed JSON object.

        Raises:
            json.JSONDecodeError: If the content is not valid JSON.
            ValueError: If the content parses but is not a JSON object (dict).
        """
        try:
            result = json.loads(content.strip().lstrip("\n"))
            if not isinstance(result, dict):
                raise ValueError("Response is not a valid JSON object")
            return result
        except (json.JSONDecodeError, ValueError):
            if retries < self.model_manager.max_retries - 1:
                logger.warning(
                    "JSON parsing failed with model %s. Switching models...",
                    self.model_manager.current_model,
                )
                self.model_manager.next_model()
                # Bare raise preserves the original traceback for the caller's
                # retry loop (the original used `raise e`, losing context).
                raise
            logger.exception("JSON parsing failed with all models")
            raise

    def generate_enhancement(
        self,
        system_prompt: str,
        user_prompt: str,
        user_directive: str = "",
        state: Optional[Dict] = None,
    ) -> Dict[str, Any]:
        """Request a JSON enhancement, retrying across fallback models.

        Args:
            system_prompt: Fully formatted system prompt.
            user_prompt: The prompt being enhanced.
            user_directive: Optional extra instruction appended as a user turn.
            state: Optional previous result, replayed as an assistant turn so
                the model can build on it.

        Returns:
            The parsed JSON response, or an error payload from
            ``create_error_response`` if every model failed.

        Raises:
            gr.Error: On non-rate-limit API failures, surfaced to the UI.
        """
        retries = 0
        last_error: Optional[Exception] = None
        while retries < self.model_manager.max_retries:
            try:
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ]
                if user_directive:
                    messages.append(
                        {"role": "user", "content": f"User directive: {user_directive}"}
                    )
                if state:
                    messages.append(
                        {"role": "assistant", "content": json.dumps(state)}
                    )

                response = self.client.chat.completions.create(
                    model=self.model_manager.current_model,
                    messages=messages,
                    temperature=0.7,
                    max_tokens=4000,
                    response_format={"type": "json_object"},
                )
                return self._try_parse_json(
                    response.choices[0].message.content, retries
                )
            except (json.JSONDecodeError, ValueError) as e:
                # _try_parse_json already rotated the model before re-raising.
                last_error = e
                retries += 1
                if retries < self.model_manager.max_retries:
                    logger.warning(
                        "Attempt %d failed. Switching models and retrying...",
                        retries,
                    )
                    time.sleep(1)  # Brief pause before retry
                    continue
                break
            except Exception as e:
                logger.error("API error: %s", e)
                if "rate limit" in str(e).lower():
                    if retries < self.model_manager.max_retries - 1:
                        self.model_manager.next_model()
                        retries += 1
                        time.sleep(1)
                        continue
                raise gr.Error(f"API request failed: {str(e)}")

        logger.error("All models failed to generate valid JSON: %s", last_error)
        return create_error_response(user_prompt, user_directive)


class PromptEnhancementSystem:
    """Stateful session wrapper: formats prompts and tracks enhancement history."""

    def __init__(self, api_key: str, base_url: Optional[str] = None) -> None:
        self.api = PromptEnhancementAPI(api_key, base_url)
        self.current_state: Optional[Dict[str, Any]] = None
        self.history: list = []

    def start_session(self, prompt: str, user_directive: str = "") -> Dict[str, Any]:
        """Begin a new enhancement session for *prompt*, resetting history."""
        formatted_system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=prompt,
            user_directive=user_directive,
        )
        result = self.api.generate_enhancement(
            system_prompt=formatted_system_prompt,
            user_prompt=prompt,
            user_directive=user_directive,
        )
        self.current_state = result
        self.history = [result]
        return result

    def apply_enhancement(self, choice: str, user_directive: str = "") -> Dict[str, Any]:
        """Refine *choice* further, carrying the previous result as context."""
        formatted_system_prompt = PROMPT_ANALYZER_TEMPLATE.format(
            input_prompt=choice,
            user_directive=user_directive,
        )
        result = self.api.generate_enhancement(
            system_prompt=formatted_system_prompt,
            user_prompt=choice,
            user_directive=user_directive,
            state=self.current_state,
        )
        self.current_state = result
        self.history.append(result)
        return result