from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TypedDict

from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import StandardCallbackDynamicParams


class PromptManagementClient(TypedDict):
    """Resolved prompt returned by a prompt-management integration."""

    prompt_id: str
    prompt_template: List[AllMessageValues]
    prompt_template_model: Optional[str]
    prompt_template_optional_params: Optional[Dict[str, Any]]
    completed_messages: Optional[List[AllMessageValues]]


class PromptManagementBase(ABC):
    """Shared logic for prompt-management integrations.

    Subclasses fetch a stored prompt template, merge it with the caller's
    messages, and optionally override the model and call parameters.
    """

    @property
    @abstractmethod
    def integration_name(self) -> str:
        """Integration name, also used as the ``<integration_name>/<model>`` prefix."""
        pass

    @abstractmethod
    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        """Return True if this integration should handle the given prompt_id."""
        pass

    @abstractmethod
    def _compile_prompt_helper(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Fetch the prompt for ``prompt_id`` and render it with ``prompt_variables``."""
        pass

    def merge_messages(
        self,
        prompt_template: List[AllMessageValues],
        client_messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        return prompt_template + client_messages

    def compile_prompt(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        client_messages: List[AllMessageValues],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        """Resolve the stored prompt and append the client's messages to it."""
        compiled_prompt_client = self._compile_prompt_helper(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            dynamic_callback_params=dynamic_callback_params,
        )

        try:
            messages = compiled_prompt_client["prompt_template"] + client_messages
        except Exception as e:
            raise ValueError(
                f"Error compiling prompt: {e}. Prompt id={prompt_id}, prompt_variables={prompt_variables}, client_messages={client_messages}, dynamic_callback_params={dynamic_callback_params}"
            )

        compiled_prompt_client["completed_messages"] = messages
        return compiled_prompt_client

    def _get_model_from_prompt(
        self, prompt_management_client: PromptManagementClient, model: str
    ) -> str:
        """Prefer the model pinned on the stored prompt; otherwise strip the integration prefix."""
        if prompt_management_client["prompt_template_model"] is not None:
            return prompt_management_client["prompt_template_model"]
        else:
            return model.replace("{}/".format(self.integration_name), "")

    def get_chat_completion_prompt(
        self,
        model: str,
        messages: List[AllMessageValues],
        non_default_params: dict,
        prompt_id: Optional[str],
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> Tuple[str, List[AllMessageValues], dict]:
        """Return the (model, messages, non_default_params) to use for the request."""
        if prompt_id is None:
            raise ValueError("prompt_id is required for Prompt Management Base class")
        if not self.should_run_prompt_management(
            prompt_id=prompt_id, dynamic_callback_params=dynamic_callback_params
        ):
            return model, messages, non_default_params

        prompt_template = self.compile_prompt(
            prompt_id=prompt_id,
            prompt_variables=prompt_variables,
            client_messages=messages,
            dynamic_callback_params=dynamic_callback_params,
        )

        completed_messages = prompt_template["completed_messages"] or messages

        prompt_template_optional_params = (
            prompt_template["prompt_template_optional_params"] or {}
        )

        updated_non_default_params = {
            **non_default_params,
            **prompt_template_optional_params,
        }

        model = self._get_model_from_prompt(
            prompt_management_client=prompt_template, model=model
        )

        return model, completed_messages, updated_non_default_params
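

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library): a minimal in-memory subclass
# showing how PromptManagementBase is meant to be implemented. The class name
# `ExampleInMemoryPromptManager`, its prompt store, and the str.format-based
# templating are hypothetical; a real integration would fetch and render
# prompts from its own backend inside `_compile_prompt_helper`.
# ---------------------------------------------------------------------------
class ExampleInMemoryPromptManager(PromptManagementBase):
    # Hypothetical prompt store keyed by prompt_id.
    _prompts: Dict[str, List[AllMessageValues]] = {
        "greet": [{"role": "system", "content": "Greet the user named {name}."}],
    }

    @property
    def integration_name(self) -> str:
        return "example_in_memory"

    def should_run_prompt_management(
        self,
        prompt_id: str,
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> bool:
        return prompt_id in self._prompts

    def _compile_prompt_helper(
        self,
        prompt_id: str,
        prompt_variables: Optional[dict],
        dynamic_callback_params: StandardCallbackDynamicParams,
    ) -> PromptManagementClient:
        variables = prompt_variables or {}
        # Render `{placeholders}` in each message's content with the supplied variables.
        template: List[AllMessageValues] = [
            {**message, "content": str(message.get("content", "")).format(**variables)}  # type: ignore[misc]
            for message in self._prompts[prompt_id]
        ]
        return PromptManagementClient(
            prompt_id=prompt_id,
            prompt_template=template,
            prompt_template_model=None,
            prompt_template_optional_params=None,
            completed_messages=None,
        )


# Expected usage (hypothetical values):
#   manager = ExampleInMemoryPromptManager()
#   model, msgs, params = manager.get_chat_completion_prompt(
#       model="example_in_memory/gpt-4o",
#       messages=[{"role": "user", "content": "Hi"}],
#       non_default_params={},
#       prompt_id="greet",
#       prompt_variables={"name": "Ada"},
#       dynamic_callback_params={},
#   )
#   # model == "gpt-4o"; msgs == rendered system prompt + the user's message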