# Scraped file-view page header (not code), preserved as comments:
# Repository: LegalAI-DS - file: ai_providers.py
# Author: Hassankhwileh
# Commit: b6737fb (verified) - "Update ai_providers.py"
# Size: 2.37 kB
# ai_providers.py
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
import os
from dotenv import load_dotenv
import time
class AIProvider(ABC):
    """Abstract interface shared by every AI backend in this module.

    Concrete subclasses must override both ``get_completion`` and
    ``get_config``; until they do, the class cannot be instantiated.
    """

    @abstractmethod
    def get_completion(self, prompt: str, **kwargs) -> str:
        """Return the model's text response to ``prompt``.

        Keyword arguments carry backend-specific generation options.
        """

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return a dictionary describing this backend's configuration."""
class GroqProvider(AIProvider):
    """AIProvider backed by the Groq chat-completions API.

    The ``groq`` SDK is imported lazily (on the first completion request), so
    this module can be imported and ``get_config`` used without it installed.
    """

    # Shared defaults so get_completion and get_config cannot drift apart.
    # NOTE(review): Groq's published ID for this model appears to be
    # "deepseek-r1-distill-llama-70b" (no "deepseek-ai/" prefix) - confirm the
    # ID below is accepted by the API.
    DEFAULT_MODEL = "deepseek-ai/deepseek-r1-distill-llama-70b"
    DEFAULT_TEMPERATURE = 0.7
    DEFAULT_MAX_TOKENS = 4000
    DEFAULT_TOP_P = 1.0

    def __init__(
        self,
        api_key: Optional[str],
        model: str = DEFAULT_MODEL,
        max_retries: int = 3,
        retry_delay: float = 1.0,
    ):
        """Store connection settings; no network access happens here.

        Args:
            api_key: Groq API key. May be ``None`` when the factory found no
                ``GROQ_API_KEY``; it is handed to the groq client as-is.
            model: Model ID sent with every request.
            max_retries: Total number of attempts per completion (>= 1).
            retry_delay: Base delay in seconds; attempt ``n`` that fails is
                followed by a ``retry_delay * n`` second sleep.

        Raises:
            ValueError: if ``max_retries`` is less than 1 (the retry loop
                would otherwise never run and silently return ``None``).
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")
        self.api_key = api_key
        self.model = model
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        self._client = None  # built on first use by _get_client
        self._client_key = None  # api_key the cached client was built with

    def _get_client(self):
        """Return a cached Groq client for the current ``api_key``.

        Reusing one client keeps a single connection pool across calls and
        retry attempts instead of opening a new, never-closed client per
        request. The client is rebuilt if ``api_key`` was reassigned.
        """
        client = getattr(self, "_client", None)
        if client is None or getattr(self, "_client_key", None) != self.api_key:
            from groq import Groq  # lazy: only completions need the SDK
            client = Groq(api_key=self.api_key)
            self._client = client
            self._client_key = self.api_key
        return client

    def get_completion(self, prompt: str, **kwargs) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Recognised keyword arguments are ``temperature``, ``max_tokens``,
        ``top_p`` and ``stop``; any other keyword is ignored.

        Any exception from the request is retried, up to ``self.max_retries``
        attempts in total, with a linearly growing sleep between attempts. The
        exception from the final attempt is re-raised unchanged.
        NOTE(review): the groq SDK client may also retry internally - confirm
        so the attempt counts do not multiply.
        """
        client = self._get_client()
        params = {
            "temperature": kwargs.get("temperature", self.DEFAULT_TEMPERATURE),
            "max_tokens": kwargs.get("max_tokens", self.DEFAULT_MAX_TOKENS),
            "top_p": kwargs.get("top_p", self.DEFAULT_TOP_P),
            "stop": kwargs.get("stop", None),
        }
        messages = [{"role": "user", "content": prompt}]
        for attempt in range(1, self.max_retries + 1):
            try:
                completion = client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    **params,
                )
                return completion.choices[0].message.content
            except Exception:
                # Every failure is treated as transient, including ones a
                # retry cannot fix (e.g. a bad key); give up on the last try.
                if attempt == self.max_retries:
                    raise
                time.sleep(self.retry_delay * attempt)
        # Only reachable if max_retries was set below 1 after construction.
        raise ValueError("max_retries must be at least 1")

    def get_config(self) -> Dict[str, Any]:
        """Return this provider's settings wrapped in a ``config_list``.

        NOTE(review): the ``{"config_list": [...]}`` shape looks like an
        AutoGen-style ``llm_config`` - confirm against the consumer. The
        returned dict contains the API key in plain text.
        """
        return {
            "config_list": [{
                "model": self.model,
                "api_key": self.api_key,
                "temperature": self.DEFAULT_TEMPERATURE,
                "max_tokens": self.DEFAULT_MAX_TOKENS,
            }]
        }
class AIProviderFactory:
    """Build AIProvider instances by name; only ``"groq"`` is supported."""

    @staticmethod
    def create_provider(provider_type: str, api_key: Optional[str] = None, model: Optional[str] = None) -> AIProvider:
        """Create a provider of the requested type.

        Args:
            provider_type: Provider name, compared case-insensitively after
                stripping surrounding whitespace. Only "groq" is accepted.
            api_key: API key. When falsy, ``.env`` is loaded with
                ``load_dotenv()`` and ``GROQ_API_KEY`` is read from the
                environment instead (which may itself be unset, giving None).
            model: Model ID. When falsy, GroqProvider's default model is used.

        Returns:
            A configured GroqProvider.

        Raises:
            ValueError: if ``provider_type`` is not "groq".
        """
        normalized = provider_type.strip().lower() if isinstance(provider_type, str) else provider_type
        # Validate first so an invalid request has no side effects (no .env load).
        if normalized != "groq":
            raise ValueError("Only Groq provider is supported.")
        if not api_key:
            load_dotenv()
            api_key = os.getenv("GROQ_API_KEY")
        # Omit model when not given so GroqProvider's own default applies,
        # rather than duplicating the model ID literal here.
        extra = {"model": model} if model else {}
        return GroqProvider(api_key=api_key, **extra)