File size: 5,511 Bytes
d1009a4
 
70aa23a
d1009a4
 
 
70aa23a
faa9fb4
d1009a4
 
faa9fb4
 
 
 
 
d1009a4
 
70aa23a
d1009a4
 
 
 
 
 
70aa23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
faa9fb4
70aa23a
 
faa9fb4
 
 
 
 
 
 
 
 
 
70aa23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# import datetime
import os
import time
from functools import lru_cache
from typing import Union, Optional, List
from freeplay import Freeplay, RecordPayload, ResponseInfo, CallInfo
from freeplay.resources.prompts import FormattedPrompt
from langchain_core.messages import BaseMessage, ToolMessage

FREEPLAY_PROJECT_ID = os.getenv("FREEPLAY_PROJECT_ID")
_role_map = {
    'human': 'user',
    'ai': 'assistant',
    'tool': 'tool',
}

# NOTE: no caching is applied here (an @lru_cache(maxsize=1) decorator is left
# disabled), so every call constructs a brand-new client.
def _get_fp_client():
    """Build a Freeplay SDK client from environment configuration.

    Reads ``FREEPLAY_API_KEY`` and ``FREEPLAY_URL`` at call time and passes
    them to the ``Freeplay`` constructor.
    """
    api_key = os.getenv("FREEPLAY_API_KEY")
    base_url = os.getenv("FREEPLAY_URL")
    return Freeplay(freeplay_api_key=api_key, api_base=base_url)


class FreeplayClient:
    """Wrapper around the Freeplay SDK for fetching prompts and recording calls.

    Holds an optional Freeplay session, plus the most recently formatted
    prompt and its variables (cached by ``get_prompt_by_persona``). This lets
    ``record_session`` be called without repeating them.
    """

    def __init__(self, fp_client: Optional[Freeplay] = None):
        """Create the wrapper.

        Args:
            fp_client: Existing Freeplay client to reuse. When omitted (or
                falsy), a new one is built from environment variables.
        """
        self.fp_client = fp_client or _get_fp_client()
        self.session = None
        self.session_id = None
        # Per-instance cache of unformatted templates keyed by (template, environment).
        self._prompt_cache = {}
        # Last prompt variables / formatted prompt; used as record_session() defaults.
        self._prompt_vars = None
        self._formatted_prompt = None

    def create_session(self):
        """Start a new Freeplay session, store it and its id, and return ``self``."""
        self.session = self.fp_client.sessions.create()
        self.session_id = self.session.session_id
        return self

    @staticmethod
    def get_fp_client():
        """Return a new Freeplay client built from environment variables."""
        return _get_fp_client()

    def get_formatted_prompt(
        self,
        template: str,
        environment: str = "latest",
        variables: Optional[dict] = None,
        history: Optional[List[BaseMessage]] = None,
    ):
        """Retrieve a prompt template from Freeplay and format it.

        This call does not use the per-instance template cache. It also does
        not update the values that ``record_session`` falls back to.

        Args:
            template: Name of the prompt template in the Freeplay project.
            environment: Freeplay environment to fetch from.
            variables: Template variables; ``None`` is treated as ``{}``.
            history: Optional prior messages, passed through to Freeplay.

        Returns:
            The result of ``fp_client.prompts.get_formatted``.
        """
        # Build a fresh dict per call instead of sharing a mutable default.
        if variables is None:
            variables = {}
        return self.fp_client.prompts.get_formatted(
            project_id=FREEPLAY_PROJECT_ID,
            template_name=template,
            environment=environment,
            variables=variables,
            history=history,
        )

    def get_prompt(
        self,
        template: str,
        environment: str = "latest",
    ):
        """Get an unformatted prompt template from Freeplay, memoized per instance.

        Args:
            template: Name of the prompt template in the Freeplay project.
            environment: Freeplay environment to fetch from.

        Returns:
            The template object from ``fp_client.prompts.get``. It is served
            from the per-instance cache on repeat calls.
        """
        key = self._make_cache_key(template, environment)
        if key in self._prompt_cache:
            return self._prompt_cache[key]
        template_prompt = self.fp_client.prompts.get(
            project_id=FREEPLAY_PROJECT_ID,
            template_name=template,
            environment=environment,
        )
        self._prompt_cache[key] = template_prompt
        return template_prompt

    def _make_cache_key(self, template: str, environment: str) -> tuple:
        """Create the prompt-cache key.

        Args:
            template: The prompt template name.
            environment: The environment name.

        Returns:
            The tuple ``(template, environment)``.
        """
        return (template, environment)

    @staticmethod
    def _to_freeplay_messages(messages) -> list:
        """Convert state messages into the dicts Freeplay records.

        If the first element is a dict, the list is returned unchanged.
        Otherwise the elements are treated as LangChain messages. Messages
        with empty content are skipped, roles are translated via
        ``_role_map``, and ``tool_call_id`` is added for ``ToolMessage``
        entries.
        """
        if messages and isinstance(messages[0], dict):
            return messages
        converted = []
        for message in messages:
            if not message.content:
                continue
            entry = {'role': _role_map[message.type], 'content': message.content}
            if isinstance(message, ToolMessage):
                entry['tool_call_id'] = message.tool_call_id
            converted.append(entry)
        return converted

    def record_session(
        self,
        state,
        end: Optional[float] = None,
        formatted_prompt: Optional[FormattedPrompt] = None,
        prompt_vars: Optional[dict] = None,
    ):
        """Record one LLM call (messages, inputs, prompt, timing) with Freeplay.

        Args:
            state: Mapping read for ``'messages'`` and ``'start_time'``.
                ``'freeplay_session_id'`` is also read when this client has
                no session yet.
            end: End timestamp passed to ``CallInfo``. Defaults to
                ``time.time()`` evaluated when this method is called.
            formatted_prompt: Prompt the call is attributed to. Defaults to
                the one cached by the last ``get_prompt_by_persona`` call.
            prompt_vars: Prompt input variables. Defaults to the cached ones
                (also used when an empty dict is passed).

        Raises:
            ValueError: If no formatted prompt was passed or cached.
        """
        # Evaluate the default here: a `time.time()` default in the signature
        # is computed once at import, so every recording would share that time.
        if end is None:
            end = time.time()
        prompt_vars = prompt_vars or self._prompt_vars
        formatted_prompt = formatted_prompt or self._formatted_prompt
        # Fail clearly, before any session restore, instead of an opaque
        # AttributeError on `None.prompt_info` further down.
        if formatted_prompt is None:
            raise ValueError(
                "No formatted prompt to record; pass formatted_prompt or call "
                "get_prompt_by_persona() first."
            )

        all_messages = self._to_freeplay_messages(state['messages'])

        # No session on this instance: restore it from the id carried in state.
        if not self.session:
            self.session = self.fp_client.sessions.restore_session(
                session_id=state['freeplay_session_id']
            )
            self.session_id = self.session.session_id

        prompt_info = formatted_prompt.prompt_info
        payload = RecordPayload(
            all_messages=all_messages,
            inputs=prompt_vars,
            session_info=self.session,
            prompt_info=prompt_info,
            call_info=CallInfo.from_prompt_info(
                prompt_info, start_time=state['start_time'], end_time=end
            ),
            # Completion status is not inspected; every call is recorded as complete.
            response_info=ResponseInfo(is_complete=True),
        )
        self.fp_client.recordings.create(payload)

    def get_prompt_by_persona(self,
                              persona: str,
                              variables: Optional[dict] = None,
                              history: Optional[List[BaseMessage]] = None,
                              environment: str = "latest"):
        """Fetch and format the fan prompt matching a persona description.

        A persona containing "casual" (case-insensitive) selects
        ``casual_fan_prompt``. Otherwise one containing "super" selects
        ``super_fan_prompt``. The formatted prompt and its variables are
        cached for later ``record_session`` calls.

        Args:
            persona: Free-text persona description.
            variables: Template variables; ``None`` is treated as ``{}``.
            history: Optional prior messages bound into the prompt.
            environment: Freeplay environment to fetch the template from.

        Returns:
            The formatted prompt.

        Raises:
            ValueError: If the persona matches neither template.
        """
        if variables is None:
            variables = {}
        lowered = persona.lower()
        # "casual" is checked first, so a persona mentioning both is casual.
        if 'casual' in lowered:
            template = 'casual_fan_prompt'
        elif 'super' in lowered:
            template = 'super_fan_prompt'
        else:
            raise ValueError(f"Unknown persona: {persona}")

        prompt = self.get_prompt(template=template, environment=environment)
        formatted_prompt = prompt.bind(variables=variables, history=history).format()
        # Cache for record_session() defaults.
        self._prompt_vars = variables
        self._formatted_prompt = formatted_prompt
        return formatted_prompt