Spaces:
No application file
No application file
File size: 5,511 Bytes
d1009a4 70aa23a d1009a4 70aa23a faa9fb4 d1009a4 faa9fb4 d1009a4 70aa23a d1009a4 70aa23a faa9fb4 70aa23a faa9fb4 70aa23a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# import datetime
import os
import time
from functools import lru_cache
from typing import Union, Optional, List
from freeplay import Freeplay, RecordPayload, ResponseInfo, CallInfo
from freeplay.resources.prompts import FormattedPrompt
from langchain_core.messages import BaseMessage, ToolMessage
# Freeplay project id used for the prompt lookups in this module; read once at
# import time and left as None when the environment variable is unset.
FREEPLAY_PROJECT_ID = os.environ.get("FREEPLAY_PROJECT_ID")

# Maps a message ``type`` ('human' / 'ai' / 'tool') to the chat role name put
# into the recorded message dicts.
_role_map = dict(human='user', ai='assistant', tool='tool')
# @lru_cache(maxsize=1)
def _get_fp_client():
    """Build a new ``Freeplay`` client configured from the environment.

    Reads ``FREEPLAY_API_KEY`` and ``FREEPLAY_URL`` with ``os.getenv``; an
    unset variable is passed through to the constructor as ``None``. A new
    client object is constructed on every call.
    """
    api_key = os.getenv("FREEPLAY_API_KEY")
    base_url = os.getenv("FREEPLAY_URL")
    return Freeplay(freeplay_api_key=api_key, api_base=base_url)
class FreeplayClient:
    """Convenience wrapper around a ``Freeplay`` SDK client.

    Holds the SDK client, the current Freeplay session, a per-instance cache
    of fetched prompt templates, and the most recently bound prompt together
    with its variables, so ``record_session`` can be called without passing
    them again.
    """

    def __init__(self, fp_client: Optional["Freeplay"] = None):
        """Wrap ``fp_client``; build one from the environment when omitted."""
        self.fp_client = fp_client or _get_fp_client()
        self.session = None
        self.session_id = None
        # cache variables for recording
        self._prompt_cache = {}
        self._prompt_vars = None
        self._formatted_prompt = None

    def create_session(self) -> "FreeplayClient":
        """Create a Freeplay session, remember it, and return ``self``."""
        self.session = self.fp_client.sessions.create()
        self.session_id = self.session.session_id
        return self

    @staticmethod
    def get_fp_client():
        """Return a ``Freeplay`` client built by ``_get_fp_client()``."""
        return _get_fp_client()

    # retrieve and format your prompt
    def get_formatted_prompt(
        self,
        template: str,
        environment: str = "latest",
        variables: Optional[dict] = None,
        history: Optional[List["BaseMessage"]] = None,
    ):
        """Get a formatted prompt from Freeplay.

        Delegates to ``fp_client.prompts.get_formatted`` for the project
        named by ``FREEPLAY_PROJECT_ID``.

        Args:
            template: The prompt template name.
            environment: The environment name.
            variables: Values to format the template with; an empty dict is
                used when omitted.
            history: Optional message history forwarded to Freeplay.

        Returns:
            Whatever ``prompts.get_formatted`` returns.
        """
        # A fresh dict per call: a ``{}`` default in the signature would be
        # one object shared by every call.
        if variables is None:
            variables = {}
        return self.fp_client.prompts.get_formatted(
            project_id=FREEPLAY_PROJECT_ID,
            template_name=template,
            environment=environment,
            variables=variables,
            history=history,
        )

    def get_prompt(
        self,
        template: str,
        environment: str = "latest",
    ):
        """Get an unformatted prompt template from Freeplay.

        Results are cached on this instance per ``(template, environment)``,
        so a repeated lookup does not call Freeplay again.

        Args:
            template: The prompt template name.
            environment: The environment name.

        Returns:
            The template object returned by ``fp_client.prompts.get``.
        """
        key = self._make_cache_key(template, environment)
        if key in self._prompt_cache:
            return self._prompt_cache[key]
        template_prompt = self.fp_client.prompts.get(
            project_id=FREEPLAY_PROJECT_ID,
            template_name=template,
            environment=environment,
        )
        self._prompt_cache[key] = template_prompt
        return template_prompt

    def _make_cache_key(self, template: str, environment: str) -> tuple:
        """Create a cache key for the prompt cache.

        Args:
            template: The prompt template name.
            environment: The environment name.

        Returns:
            tuple: ``(template, environment)``
        """
        return (template, environment)

    def record_session(
        self,
        state,
        end: Optional[float] = None,
        formatted_prompt: Optional["FormattedPrompt"] = None,
        prompt_vars: Optional[dict] = None,
    ):
        """Record one LLM call with Freeplay.

        Args:
            state: Mapping read for ``'messages'`` and ``'start_time'``, and
                for ``'freeplay_session_id'`` when this instance has no
                session yet. ``'messages'`` holds either dicts (sent
                unchanged) or message objects exposing ``type`` and
                ``content`` (assumed to be LangChain messages).
            end: End timestamp of the call. Defaults to ``time.time()``
                taken when this method runs.
            formatted_prompt: Prompt to record against; falls back to the one
                stored by ``get_prompt_by_persona``.
            prompt_vars: Prompt inputs to record; falls back to the variables
                stored by ``get_prompt_by_persona``.

        Raises:
            ValueError: If no formatted prompt was passed or stored.
            KeyError: If a message object's ``type`` is missing from
                ``_role_map``, or ``state`` lacks a required key.
        """
        # Resolve the default at call time; ``time.time()`` in the signature
        # would be evaluated once, when the function is defined.
        if end is None:
            end = time.time()

        prompt_vars = prompt_vars or self._prompt_vars
        formatted_prompt = formatted_prompt or self._formatted_prompt
        if formatted_prompt is None:
            # Fail before any network round trip instead of with an
            # AttributeError on ``None.prompt_info`` further down.
            raise ValueError(
                "record_session requires a formatted prompt: pass "
                "formatted_prompt or call get_prompt_by_persona first"
            )

        # convert messages to Freeplay format
        messages = state['messages']
        if messages and isinstance(messages[0], dict):
            # if it's a dict leave it alone and just send it on
            all_messages = messages
        else:
            # otherwise assume it's langchain messages and parse them for
            # freeplay; messages with empty content are dropped
            all_messages = [
                {
                    'role': _role_map[m.type],
                    'content': m.content,
                    **({'tool_call_id': m.tool_call_id} if isinstance(m, ToolMessage) else {}),
                }
                for m in messages
                if m.content
            ]

        # fix session if we landed here and it's missing
        if not self.session:
            self.session = self.fp_client.sessions.restore_session(
                session_id=state['freeplay_session_id']
            )
            self.session_id = self.session.session_id

        # record your LLM call with Freeplay
        prompt_info = formatted_prompt.prompt_info
        payload = RecordPayload(
            all_messages=all_messages,
            inputs=prompt_vars,
            session_info=self.session,
            prompt_info=prompt_info,
            call_info=CallInfo.from_prompt_info(
                prompt_info, start_time=state['start_time'], end_time=end
            ),
            # The call is always reported as complete; no finish reason is
            # inspected here.
            response_info=ResponseInfo(is_complete=True),
        )
        self.fp_client.recordings.create(payload)

    def get_prompt_by_persona(
        self,
        persona: str,
        variables: Optional[dict] = None,
        history: Optional[List["BaseMessage"]] = None,
    ):
        """Bind and format the prompt template that matches ``persona``.

        A persona containing ``'casual'`` (case-insensitive) selects
        ``'casual_fan_prompt'``; otherwise one containing ``'super'`` selects
        ``'super_fan_prompt'``. The variables and the formatted prompt are
        stored on the instance for a later ``record_session`` call.

        Args:
            persona: Persona label used to pick the template.
            variables: Values bound into the template; an empty dict is used
                when omitted.
            history: Optional message history bound into the template.

        Returns:
            The formatted prompt produced by ``bind(...).format()``.

        Raises:
            ValueError: If ``persona`` matches neither known persona.
        """
        persona_key = persona.lower()
        if 'casual' in persona_key:
            template = 'casual_fan_prompt'
        elif 'super' in persona_key:
            template = 'super_fan_prompt'
        else:
            raise ValueError(f"Unknown persona: {persona}")

        # A fresh dict per call: this object is kept on the instance, so a
        # shared ``{}`` default would leak state between calls and instances.
        if variables is None:
            variables = {}

        prompt = self.get_prompt(template=template, environment='latest')
        formatted_prompt = prompt.bind(variables=variables, history=history).format()
        self._prompt_vars = variables
        self._formatted_prompt = formatted_prompt
        return formatted_prompt
|