# ai_providers.py
# Captured from a hosted "Spaces" page (status shown: Running).
from abc import ABC, abstractmethod | |
from typing import Dict, Any, Optional | |
import os | |
from dotenv import load_dotenv | |
import time | |
class AIProvider(ABC):
    """Abstract interface for a text-completion backend.

    Both methods are marked ``@abstractmethod``, so ``AIProvider`` itself and
    any subclass that does not implement both of them cannot be instantiated.
    """

    @abstractmethod
    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send *prompt* to the backend and return the generated text.

        Args:
            prompt: The user prompt to complete.
            **kwargs: Provider-specific generation options.

        Returns:
            The completion text.
        """

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return a configuration mapping describing this provider."""
class GroqProvider(AIProvider):
    """AIProvider backed by the Groq chat-completions API.

    Generation defaults live in class attributes so that ``get_completion``
    and ``get_config`` always report the same values.
    """

    # NOTE(review): Groq's model catalogue may list this model without the
    # "deepseek-ai/" prefix — verify the id against the Groq models endpoint.
    DEFAULT_MODEL = "deepseek-ai/deepseek-r1-distill-llama-70b"
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_MAX_TOKENS = 4000
    DEFAULT_TOP_P = 1.0
    # Total number of attempts (first try included) and the base back-off
    # delay in seconds; attempt k (0-based) sleeps RETRY_DELAY * (k + 1).
    MAX_RETRIES = 3
    RETRY_DELAY = 1

    def __init__(self, api_key: str, model: str = DEFAULT_MODEL):
        """Store the credentials and model id used for every request.

        Args:
            api_key: Groq API key passed to the ``groq.Groq`` client.
            model: Model id sent with each chat-completion request.
        """
        self.api_key = api_key
        self.model = model

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Args:
            prompt: Content of the single ``"user"`` message.
            **kwargs: Only ``temperature``, ``max_tokens``, ``top_p`` and
                ``stop`` are read; any other keyword is ignored.

        Returns:
            ``choices[0].message.content`` of the completion response.

        Raises:
            Exception: whatever the last failed attempt raised, once
                ``MAX_RETRIES`` attempts are exhausted. Errors from importing
                ``groq`` or building the client happen before the retry loop
                and are raised immediately.
        """
        # Local import: the groq package is only needed once a completion is
        # actually requested.
        from groq import Groq

        client = Groq(api_key=self.api_key)
        params = {
            "temperature": kwargs.get("temperature", self.DEFAULT_TEMPERATURE),
            "max_tokens": kwargs.get("max_tokens", self.DEFAULT_MAX_TOKENS),
            "top_p": kwargs.get("top_p", self.DEFAULT_TOP_P),
            "stop": kwargs.get("stop", None),
        }
        messages = [{"role": "user", "content": prompt}]

        # Always make at least one attempt, so the method can never fall
        # through and implicitly return None.
        attempts = max(1, self.MAX_RETRIES)
        for attempt in range(attempts):
            try:
                completion = client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    **params,
                )
                return completion.choices[0].message.content
            except Exception:
                if attempt == attempts - 1:
                    raise  # bare raise keeps the original traceback intact
                # Linear back-off between attempts.
                time.sleep(self.RETRY_DELAY * (attempt + 1))

    def get_config(self) -> Dict[str, Any]:
        """Return a ``{"config_list": [...]}`` mapping for this provider.

        The single entry repeats the same default temperature and max_tokens
        that ``get_completion`` uses. It includes the API key in plain text,
        so do not log the result. Presumably it is consumed as an
        AutoGen-style ``llm_config`` — confirm against callers.
        """
        return {
            "config_list": [
                {
                    "model": self.model,
                    "api_key": self.api_key,
                    "temperature": self.DEFAULT_TEMPERATURE,
                    "max_tokens": self.DEFAULT_MAX_TOKENS,
                }
            ]
        }
class AIProviderFactory:
    """Create AIProvider instances by provider name (only "groq" today)."""

    @staticmethod
    def create_provider(
        provider_type: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
    ) -> AIProvider:
        """Build the provider named *provider_type*.

        ``@staticmethod`` lets this be called either on the class or on an
        instance. Without it, calling on an instance binds the instance to
        ``provider_type``.

        Args:
            provider_type: Provider name; must be exactly ``"groq"``.
            api_key: Explicit API key. When falsy, ``load_dotenv()`` is called
                and the key is read from the ``GROQ_API_KEY`` environment
                variable.
            model: Model id. When falsy, ``GroqProvider``'s own default is used.

        Returns:
            A ``GroqProvider`` instance.

        Raises:
            ValueError: If *provider_type* is not ``"groq"``.
        """
        # Reject unsupported providers before any environment side effects.
        if provider_type != "groq":
            raise ValueError("Only Groq provider is supported.")

        if not api_key:
            load_dotenv()
            api_key = os.getenv("GROQ_API_KEY")
        # NOTE(review): api_key may still be None here if GROQ_API_KEY is not
        # set; the failure then surfaces only when a completion is requested.

        if model:
            return GroqProvider(api_key=api_key, model=model)
        # Defer to GroqProvider's default model rather than repeating it here.
        return GroqProvider(api_key=api_key)