Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import openai | |
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam | |
import os | |
from typing import List, Optional, ClassVar | |
import enum | |
from llm_handler.llm_interface import LLMInterface, DefaultEnumMeta | |
class ChatModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    """Identifiers of the OpenAI chat-completion models this handler supports.

    NOTE(review): the DefaultEnumMeta metaclass comes from
    llm_handler.llm_interface — presumably it supplies a default member when
    none is specified; confirm against that module.
    """
    GPT_3_5 = 'gpt-3.5-turbo-1106'
    GPT_4 = 'gpt-4'
    GPT_4_TURBO = 'gpt-4-1106-preview'
    GPT_4_O = 'gpt-4o'
class EmbeddingModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    """Identifiers of the OpenAI text-embedding models this handler supports.

    NOTE(review): the DefaultEnumMeta metaclass comes from
    llm_handler.llm_interface — presumably it supplies a default member when
    none is specified; confirm against that module.
    """
    SMALL_3 = 'text-embedding-3-small'
    ADA_002 = 'text-embedding-ada-002'
    LARGE = 'text-embedding-3-large'
class OpenAIHandler(LLMInterface):
    """LLMInterface implementation backed by the OpenAI v1 Python client."""

    # Environment variable consulted when no key is passed to __init__.
    _ENV_KEY_NAME: ClassVar[str] = 'OPENAI_API_KEY'
    _client: openai.Client

    def __init__(self, openai_api_key: Optional[str] = None):
        """Create a handler.

        Args:
            openai_api_key: API key to use; falls back to the OPENAI_API_KEY
                environment variable when omitted.

        Raises:
            ValueError: if no key is supplied and the env var is unset.
        """
        _openai_api_key = openai_api_key or os.environ.get(self._ENV_KEY_NAME)
        if not _openai_api_key:
            raise ValueError(f'{self._ENV_KEY_NAME} not set')
        # Pass the key to the client explicitly. The previous code assigned
        # the module global `openai.api_key` and then called `openai.Client()`
        # with no arguments — but a fresh v1 Client does not read that global
        # (it falls back to the OPENAI_API_KEY env var), so an explicitly
        # supplied key was silently ignored whenever the env var was unset.
        self._client = openai.Client(api_key=_openai_api_key)

    def get_chat_completion(  # type: ignore
            self,
            messages: List[ChatCompletionMessageParam],
            model: ChatModelVersion = ChatModelVersion.GPT_4_O,
            temperature: float = 0.2,
            **kwargs) -> str:
        """Return the content of the single completion choice for `messages`.

        Args:
            messages: conversation in OpenAI chat-message format.
            model: chat model to query.
            temperature: sampling temperature (low default for determinism).
            **kwargs: passed through to `chat.completions.create`.

        Raises:
            ValueError: if any choice did not finish with 'stop', has empty
                content, or the API returned a number of choices other than 1.
        """
        response = self._client.chat.completions.create(model=model.value,
                                                        messages=messages,
                                                        temperature=temperature,
                                                        **kwargs)
        responses: List[str] = []
        for choice in response.choices:
            # Treat truncation (length/filter) or empty content as an error
            # rather than returning a partial answer.
            if choice.finish_reason != 'stop' or not choice.message.content:
                raise ValueError(f'Choice did not complete correctly: {choice}')
            responses.append(choice.message.content)
        if len(responses) != 1:
            raise ValueError(f'Expected one response, got {len(responses)}: {responses}')
        return responses[0]

    def get_text_embedding(  # type: ignore
            self, input: str, model: EmbeddingModelVersion) -> List[float]:
        """Return the embedding vector for `input` using `model`.

        Args:
            input: text to embed. (Name shadows the builtin but is kept for
                interface/keyword-argument compatibility.)
            model: embedding model to use.

        Raises:
            ValueError: if the response contains zero or multiple embeddings.
        """
        response = self._client.embeddings.create(model=model.value,
                                                  encoding_format='float',
                                                  input=input)
        if not response.data:
            raise ValueError(f'No embedding in response: {response}')
        elif len(response.data) != 1:
            raise ValueError(f'More than one embedding in response: {response}')
        return response.data[0].embedding