File size: 2,642 Bytes
9041389
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations
import openai
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
import os
from typing import List, Optional, ClassVar
import enum

from llm_handler.llm_interface import LLMInterface, DefaultEnumMeta


class ChatModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    """Supported OpenAI chat-completion models.

    Member values are the exact model-name strings the OpenAI API expects.
    NOTE(review): DefaultEnumMeta presumably makes one member the default
    when the enum is used without an explicit value -- confirm in
    llm_handler.llm_interface.
    """
    GPT_3_5 = 'gpt-3.5-turbo-1106'
    GPT_4 = 'gpt-4'
    GPT_4_TURBO = 'gpt-4-1106-preview'
    GPT_4_O = 'gpt-4o'


class EmbeddingModelVersion(enum.Enum, metaclass=DefaultEnumMeta):
    """Supported OpenAI text-embedding models.

    Member values are the exact model-name strings the OpenAI API expects.
    NOTE(review): DefaultEnumMeta presumably makes one member the default
    when the enum is used without an explicit value -- confirm in
    llm_handler.llm_interface.
    """
    SMALL_3 = 'text-embedding-3-small'
    ADA_002 = 'text-embedding-ada-002'
    LARGE = 'text-embedding-3-large'


class OpenAIHandler(LLMInterface):
    """LLMInterface implementation backed by the OpenAI API.

    Wraps a single `openai.Client` instance and exposes chat completions
    and text embeddings that each return exactly one result or raise.
    """

    # Environment variable consulted when no API key is passed explicitly.
    _ENV_KEY_NAME: ClassVar[str] = 'OPENAI_API_KEY'
    _client: openai.Client

    def __init__(self, openai_api_key: Optional[str] = None):
        """Create a handler.

        Args:
            openai_api_key: API key to use. Falls back to the
                OPENAI_API_KEY environment variable when omitted.

        Raises:
            ValueError: if no API key is supplied or found in the
                environment.
        """
        _openai_api_key = openai_api_key or os.environ.get(self._ENV_KEY_NAME)
        if not _openai_api_key:
            raise ValueError(f'{self._ENV_KEY_NAME} not set')
        # Pass the key directly to the client rather than mutating the
        # module-level `openai.api_key` global: the global assignment was
        # redundant for this client and leaked the credential to any other
        # code sharing the `openai` module in this process.
        self._client = openai.Client(api_key=_openai_api_key)

    def get_chat_completion(  # type: ignore
            self,
            messages: List[ChatCompletionMessageParam],
            model: ChatModelVersion = ChatModelVersion.GPT_4_O,
            temperature: float = 0.2,
            **kwargs) -> str:
        """Return the content of a single chat completion.

        Args:
            messages: Conversation so far, in OpenAI message format.
            model: Chat model to query.
            temperature: Sampling temperature.
            **kwargs: Extra arguments forwarded to
                `chat.completions.create` (e.g. `max_tokens`).

        Returns:
            The message content of the single returned choice.

        Raises:
            ValueError: if the API returns anything other than exactly one
                choice that finished with reason 'stop' and has content.
        """
        response = self._client.chat.completions.create(model=model.value,
                                                        messages=messages,
                                                        temperature=temperature,
                                                        **kwargs)
        # Validate every choice: each must have completed normally ('stop')
        # and carry non-empty content; anything else is treated as an error.
        responses: List[str] = []
        for choice in response.choices:
            if choice.finish_reason != 'stop' or not choice.message.content:
                raise ValueError(f'Choice did not complete correctly: {choice}')
            responses.append(choice.message.content)
        # The caller's contract is a single string, so reject n != 1 results.
        if len(responses) != 1:
            raise ValueError(f'Expected one response, got {len(responses)}: {responses}')
        return responses[0]

    def get_text_embedding(  # type: ignore
            self, input: str, model: EmbeddingModelVersion) -> List[float]:
        """Return the embedding vector for a single input string.

        Args:
            input: Text to embed. (Name shadows the builtin but is kept
                for backward compatibility with keyword callers.)
            model: Embedding model to use.

        Returns:
            The embedding as a list of floats.

        Raises:
            ValueError: if the API returns zero or more than one embedding.
        """
        response = self._client.embeddings.create(model=model.value,
                                                  encoding_format='float',
                                                  input=input)
        if not response.data:
            raise ValueError(f'No embedding in response: {response}')
        elif len(response.data) != 1:
            raise ValueError(f'More than one embedding in response: {response}')
        return response.data[0].embedding