metisllm-dashboard / llm_handler /openai_handler.py
Gateston Johns
first real commit
9041389
from __future__ import annotations
import openai
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
import os
from typing import List, Optional, ClassVar
import enum
from llm_handler.llm_interface import LLMInterface, DefaultEnumMeta
class ChatModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    """Chat model identifiers used by ``OpenAIHandler.get_chat_completion``.

    Each member's ``value`` is the model name string passed as ``model=`` to
    the OpenAI chat completions API.
    """
    # NOTE(review): DefaultEnumMeta is defined elsewhere; if it derives a
    # default from member order, reordering these members changes that
    # default — confirm before reordering or inserting members.
    GPT_3_5 = 'gpt-3.5-turbo-1106'
    GPT_4 = 'gpt-4'
    GPT_4_TURBO = 'gpt-4-1106-preview'
    GPT_4_O = 'gpt-4o'
class EmbeddingModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    """Embedding model identifiers used by ``OpenAIHandler.get_text_embedding``.

    Each member's ``value`` is the model name string passed as ``model=`` to
    the OpenAI embeddings API.
    """
    # NOTE(review): DefaultEnumMeta is defined elsewhere; if it derives a
    # default from member order, reordering these members changes that
    # default — confirm before reordering or inserting members.
    SMALL_3 = 'text-embedding-3-small'
    ADA_002 = 'text-embedding-ada-002'
    LARGE = 'text-embedding-3-large'
class OpenAIHandler(LLMInterface):
    """``LLMInterface`` implementation backed by the OpenAI Python client.

    The API key comes from the ``openai_api_key`` constructor argument or,
    when that is omitted/empty, from the ``OPENAI_API_KEY`` environment
    variable.
    """

    # Environment variable consulted when no key is passed to the constructor.
    _ENV_KEY_NAME: ClassVar[str] = 'OPENAI_API_KEY'
    # Client instance used for every request issued by this handler.
    _client: openai.Client

    def __init__(self, openai_api_key: Optional[str] = None):
        """Create the handler and its underlying OpenAI client.

        Args:
            openai_api_key: Explicit API key. When omitted (or empty), the
                value of the ``OPENAI_API_KEY`` environment variable is used.

        Raises:
            ValueError: If no key is available from either source.
        """
        _openai_api_key = openai_api_key or os.environ.get(self._ENV_KEY_NAME)
        if not _openai_api_key:
            raise ValueError(f'{self._ENV_KEY_NAME} not set')
        # Preserved for backward compatibility: other code may rely on the
        # module-level ``openai`` client having this key configured.
        openai.api_key = _openai_api_key
        # Pass the key explicitly. ``openai.Client()`` resolves its key from
        # its ``api_key`` argument or the OPENAI_API_KEY environment variable,
        # not from ``openai.api_key``; without this, an explicitly supplied
        # key is ignored (and construction fails if the env var is unset).
        self._client = openai.Client(api_key=_openai_api_key)

    def get_chat_completion(  # type: ignore
            self,
            messages: List[ChatCompletionMessageParam],
            model: ChatModelVersion = ChatModelVersion.GPT_4_O,
            temperature: float = 0.2,
            **kwargs) -> str:
        """Request a chat completion and return the text of its single choice.

        Args:
            messages: Conversation messages sent to the chat completions API.
            model: Chat model to use; its ``value`` is sent as the model name.
            temperature: Sampling temperature forwarded to the API.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``chat.completions.create``.

        Returns:
            The message content of the only returned choice.

        Raises:
            ValueError: If any choice has a ``finish_reason`` other than
                ``'stop'`` or empty content, or if the response does not
                contain exactly one choice.
        """
        response = self._client.chat.completions.create(model=model.value,
                                                        messages=messages,
                                                        temperature=temperature,
                                                        **kwargs)
        responses: List[str] = []
        for choice in response.choices:
            # Only fully completed, non-empty choices are accepted; anything
            # else is an error rather than being returned as partial text.
            if choice.finish_reason != 'stop' or not choice.message.content:
                raise ValueError(f'Choice did not complete correctly: {choice}')
            responses.append(choice.message.content)
        # Guards against both an empty ``choices`` list and callers passing
        # e.g. ``n > 1`` through ``kwargs``.
        if len(responses) != 1:
            raise ValueError(f'Expected one response, got {len(responses)}: {responses}')
        return responses[0]

    def get_text_embedding(  # type: ignore
            self, input: str, model: EmbeddingModelVersion) -> List[float]:
        """Embed ``input`` with the given model and return its vector.

        Args:
            input: Text to embed. (Shadows the builtin, but the name is part
                of the public keyword interface, so it is kept.)
            model: Embedding model to use; its ``value`` is sent as the
                model name.

        Returns:
            The embedding of ``input``, requested with
            ``encoding_format='float'``.

        Raises:
            ValueError: If the response contains no embedding or more than one.
        """
        response = self._client.embeddings.create(model=model.value,
                                                  encoding_format='float',
                                                  input=input)
        if not response.data:
            raise ValueError(f'No embedding in response: {response}')
        if len(response.data) != 1:
            raise ValueError(f'More than one embedding in response: {response}')
        return response.data[0].embedding