personamusic / langhcain_agent.py
juanluisrto's picture
Upload 3 files
8e786b4
raw
history blame contribute delete
No virus
7.13 kB
from operator import itemgetter
import pprint
from typing import Dict, List
from langchain.agents import (AgentExecutor, AgentType, OpenAIFunctionsAgent,
tool)
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import StrOutputParser, SystemMessage, HumanMessage, AIMessage
from langchain.callbacks import get_openai_callback, FileCallbackHandler
from langchain.schema.agent import AgentActionMessageLog, AgentFinish
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.tools.render import format_tool_to_openai_function
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.schema.runnable import RunnableConfig
import logging, os, json
from collections import defaultdict
from pydantic import BaseModel, Field
from cyanite import free_text_search
from langfuse.callback import CallbackHandler
def _env_flag_enabled(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy string.

    ``os.getenv`` always returns a ``str`` (or ``None``), never a bool, so flags
    must be parsed from text. Accepted (case-insensitive): "1", "true", "yes", "on".
    """
    return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}


# Langfuse tracing is opt-in via USE_LANGFUSE. The previous check
# `os.getenv("USE_LANGFUSE") == True` compared a string to a bool and could
# never succeed, so tracing was silently always disabled.
if _env_flag_enabled("USE_LANGFUSE"):
    handler = CallbackHandler(
        os.getenv("LANGFUSE_PUBLIC"),
        os.getenv("LANGFUSE_PRIVATE"),
        "https://cloud.langfuse.com",
    )
else:
    # Falsy placeholder so callers can test `if handler` before attaching it.
    handler = []
# System prompt for the conversational agent: it describes the agent's role,
# when to use the get_music_style_description tool, and forbids naming songs
# or artists directly.
system_message = """You are an agent which recommends songs based on music styles provided by the user.
- A music style could be a combination of instruments, genres or sounds.
- Use get_music_style_description to generate a description of the user's music style.
- The styles might contain pop-culture references (artists, movies, TV-Shows, etc) You should include them when generating descriptions.
- Comment on the description of the style and wish the user to enjoy the recommended songs (he will have received them).
- Do not mention any songs or artists, nor give a list of songs.
Write short responses with a respectful and friendly tone.
"""

# System prompt for the helper LLM that rewrites a music style as a short,
# comma-separated list of genres/instruments/sounds (at most 20 words).
describe_music_style_message = """You receive a music style and your goal is to describe it further with genres, instruments and sounds.
If it contains pop-culture references (like TV-Shows, films, artists, famous people, etc) you should replace them with music styles that resemble them.
You should return the new music style as a set of words separated by commas.
You always give short answers, with at most 20 words.
"""
# Key under which the conversation memory exposes past messages; it must match
# the MessagesPlaceholder below.
MEMORY_KEY = "history"

# Agent prompt. Message order matters for OpenAI function calling: the chat
# history and the current user input must come BEFORE the agent scratchpad, so
# this turn's function_call / function-result messages follow the request that
# triggered them. With the scratchpad first (the previous order), the model saw
# the user's request again after the tool result and could re-invoke the tool.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_message),
    MessagesPlaceholder(variable_name=MEMORY_KEY),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])
def _new_conversation_memory():
    """Build an empty windowed (k=4) chat memory for a single user."""
    return ConversationBufferWindowMemory(
        memory_key=MEMORY_KEY,
        return_messages=True,
        output_key="output",
        k=4,
    )


# Per-user conversation memory (keyed by user_id); a fresh memory is created
# the first time a user_id is looked up.
conversation_memories = defaultdict(_new_conversation_memory)

# Process-wide state shared across requests:
#   music_styles_to_tracks: style description -> tracks found for it
#   conversation_costs:     user_id -> accumulated OpenAI cost (starts at 0.0)
music_styles_to_tracks = {}
conversation_costs = defaultdict(float)
@tool
def get_music_style_description(music_style: str) -> str:
    "A tool which describes a music style and returns a description of it"
    # NOTE: @tool takes the docstring above as the tool description sent to the
    # model, so its exact wording is runtime behavior — keep it unchanged.
    style_description = describe_music_style(music_style)
    found_tracks = free_text_search(style_description, 5)
    logging.warning(
        f"""
music_style = {music_style}
music_style_description = {style_description}
tracks = {pprint.pformat(found_tracks)}"""
    )
    # Tracks are stored server-side under the description. The description is
    # also this tool's return value (the agent's observation), so it serves as
    # the lookup key when the tracks are retrieved after the agent finishes.
    music_styles_to_tracks[style_description] = found_tracks
    # Only the description is handed back to the model.
    return style_description
# Prompt | LLM | parser chain used by describe_music_style; built on first use
# and then reused, instead of constructing a new client and template per call.
_describe_runnable = None


def _get_describe_runnable():
    """Return the shared style-description chain, building it on first call."""
    global _describe_runnable
    if _describe_runnable is None:
        # temperature=0.0 keeps descriptions as repeatable as possible.
        llm_describe = ChatOpenAI(temperature=0.0)
        prompt_describe = ChatPromptTemplate.from_messages([
            ("system", describe_music_style_message),
            ("human", "{music_style}"),
        ])
        _describe_runnable = prompt_describe | llm_describe | StrOutputParser()
    return _describe_runnable


def describe_music_style(music_style: str) -> str:
    """Rewrite a free-text music style as a short comma-separated description.

    Plain helper (not an agent tool) called by get_music_style_description.

    Args:
        music_style: the user's style request, possibly containing pop-culture
            references.

    Returns:
        The LLM's description as a string, per describe_music_style_message.
    """
    return _get_describe_runnable().invoke({"music_style": music_style})
# Chat model that drives the agent. The music-style tool is exposed to it as an
# OpenAI function, letting the model decide when to call it.
llm = ChatOpenAI(temperature=0.7, request_timeout=30, max_retries=1)
_music_style_function = format_tool_to_openai_function(get_music_style_description)
llm_with_tools = llm.bind(functions=[_music_style_function])
def get_agent_executor_from_user_id(user_id) -> AgentExecutor:
    """Return an agent executor wired to the conversation memory of ``user_id``.

    A new executor is built on every call, but its memory object comes from the
    module-level ``conversation_memories`` mapping. History therefore persists
    across calls for the same user; the first lookup creates an empty memory.
    """
    memory = conversation_memories[user_id]
    # The memory contains user conversation content: log it at DEBUG only,
    # not at WARNING/ERROR on every request.
    logging.debug("conversation memory for user %s: %s", user_id, memory)

    agent = (
        {
            "input": lambda x: x["input"],
            # This turn's tool calls/results, rendered as OpenAI function messages.
            "agent_scratchpad": lambda x: format_to_openai_functions(x["intermediate_steps"]),
        }
        # Load the windowed chat history into the prompt's MEMORY_KEY placeholder.
        | RunnablePassthrough.assign(
            **{MEMORY_KEY: RunnableLambda(memory.load_memory_variables) | itemgetter(MEMORY_KEY)}
        )
        | prompt
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )
    return AgentExecutor(
        agent=agent,
        tools=[get_music_style_description],
        memory=memory,
        callbacks=[handler] if handler else [],
        # Intermediate steps are read by get_tracks_from_intermediate_steps.
        return_intermediate_steps=True,
        max_execution_time=30,
        handle_parsing_errors=True,
        verbose=True,
    )
def get_tracks_from_intermediate_steps(intermediate_steps: List) -> List:
    """Return the tracks found by the most recent get_music_style_description call.

    Args:
        intermediate_steps: the agent's ``(action, observation)`` pairs in the
            order they ran. For get_music_style_description the observation is
            the style description string, which is also the key into
            ``music_styles_to_tracks``.

    Returns:
        The stored tracks for the latest such call, or ``[]`` when there are no
        steps, no call to that tool, or no tracks recorded for its description.
    """
    if not intermediate_steps:
        return []
    # Debug output goes through logging instead of print() to stdout.
    logging.debug("intermediate steps:\n%s", pprint.pformat(intermediate_steps))
    # Walk newest-to-oldest so the last call to the tool wins.
    for action, observation in reversed(intermediate_steps):
        if action.tool == "get_music_style_description":
            # A missing entry yields no tracks instead of a KeyError crash.
            return music_styles_to_tracks.get(observation, [])
    # None of the actions was get_music_style_description.
    return []
def llm_inference(message, history, user_id) -> Dict:
    """Answer one user message and report recommended tracks and running cost.

    Entry point for the API layer. ``history`` is accepted but not used here:
    the conversation history comes from the per-user memory attached to the
    agent executor.

    Returns:
        dict with "output" (the agent's reply), "tracks" (tracks from the latest
        get_music_style_description call in this turn, or []) and "cost" (total
        OpenAI cost accumulated for ``user_id`` so far).
    """
    executor = get_agent_executor_from_user_id(user_id)

    with get_openai_callback() as usage:
        result = executor({"input": message})
        # Add this turn's spend to the user's running total.
        conversation_costs[user_id] += usage.total_cost
        running_cost = conversation_costs[user_id]

    recommended = get_tracks_from_intermediate_steps(result["intermediate_steps"])
    logging.warning(f"step = ${usage.total_cost} total = ${running_cost}")
    logging.warning(music_styles_to_tracks)

    response = {
        "output": result["output"],
        "tracks": recommended,
        "cost": running_cost,
    }
    return response