|
import json
import logging
import os
import re
import time
from typing import Any, Dict, List, Optional

from openai import OpenAI

from modals.inputs import LLMConfig


class LLM:
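    """Minimal wrapper around an OpenAI-compatible chat completions endpoint with JSON parsing and retries."""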
|
    def __init__(self, max_retries: int = 3, retry_delay: float = 1.0):
        self.max_retries = max_retries
        self.retry_delay = retry_delay
|
    def parse_json(self, response: str) -> Dict[str, Any]:
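        """Extract and parse the first fenced ```json object or array found in `response`."""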
|
        match = re.search(r'```json\s*(\{.*?\}|\[.*?\])\s*```', response, re.DOTALL)
        if match:
            json_str = match.group(1)
            try:
                return json.loads(json_str)
            except json.JSONDecodeError as e:
                logging.error(f"Failed to parse JSON from LLM response: {e}")
                raise
        raise ValueError("No fenced ```json block found in LLM response")
|
    def step(self, messages: Optional[List[Dict[str, str]]] = None, llm_config: Optional[LLMConfig] = None) -> str:
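        """Send `messages` to the configured model and return the reply text, retrying on failure."""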
|
        if llm_config is None:
            raise ValueError("llm_config is required")
        messages = messages or []
        llm = OpenAI(
            api_key=llm_config.api_key,
            base_url=llm_config.base_url,
        )
        for attempt in range(self.max_retries):
            try:
                response = llm.chat.completions.create(
                    model=llm_config.model,
                    messages=messages,
                    temperature=0.2,
                )
                return response.choices[0].message.content
            except Exception as e:
                logging.error(f"Error in LLM step (attempt {attempt + 1}/{self.max_retries}): {e}")
                if attempt < self.max_retries - 1:
                    # Exponential backoff before the next attempt.
                    time.sleep(self.retry_delay * (2 ** attempt))
                else:
                    raise
|
|
if __name__ == "__main__":
    # Read the API key from the environment instead of hard-coding a secret in source;
    # the GEMINI_API_KEY variable name is just a placeholder choice.
    llm_config = LLMConfig(
        api_key=os.environ.get("GEMINI_API_KEY", ""),
        base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
        model="gemini-2.0-flash",
    )
|
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": "Tell me a fun fact about space."},
    ]
|
    llm = LLM()
    response = llm.step(messages, llm_config)
    print(response)
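
    # Illustrative extension (not part of the original script): ask the model for a
    # JSON answer and parse it with LLM.parse_json. This assumes the model wraps its
    # output in a ```json fenced block, which is what parse_json's regex expects;
    # json_messages and json_response are names introduced only for this example.
    json_messages = [
        {"role": "system", "content": "Answer with a single JSON object inside a ```json fenced code block."},
        {"role": "user", "content": "Give me one fun fact about space as {\"fact\": \"...\"}."},
    ]
    json_response = llm.step(json_messages, llm_config)
    try:
        print(llm.parse_json(json_response))
    except ValueError:
        logging.warning("Model reply did not contain a parseable fenced JSON block")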
|
|