from operator import itemgetter
import pprint
from typing import Dict, List

from langchain.agents import (AgentExecutor, AgentType, OpenAIFunctionsAgent, tool)
from langchain.chains.conversation.memory import ConversationBufferWindowMemory
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import StrOutputParser, SystemMessage, HumanMessage, AIMessage
from langchain.callbacks import get_openai_callback, FileCallbackHandler
from langchain.schema.agent import AgentActionMessageLog, AgentFinish
from langchain.utils.openai_functions import convert_pydantic_to_openai_function
from langchain.agents.format_scratchpad import format_to_openai_functions
from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
from langchain.tools.render import format_tool_to_openai_function
from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
from langchain.schema.runnable import RunnableConfig
import logging, os, json
from collections import defaultdict
from pydantic import BaseModel, Field

from cyanite import free_text_search
from langfuse.callback import CallbackHandler

# os.getenv returns a string (or None), so compare against the string "true"
# rather than the boolean True, which would never match.
if os.getenv("USE_LANGFUSE", "").lower() == "true":
    handler = CallbackHandler(
        os.getenv("LANGFUSE_PUBLIC"),
        os.getenv("LANGFUSE_PRIVATE"),
        "https://cloud.langfuse.com",
    )
else:
    handler = None

system_message = \
"""You are an agent which recommends songs based on music styles provided by the user.
- A music style could be a combination of instruments, genres or sounds.
- Use get_music_style_description to generate a description of the user's music style.
- The styles might contain pop-culture references (artists, movies, TV shows, etc.). You should include them when generating descriptions.
- Comment on the description of the style and wish the user to enjoy the recommended songs (they will have already received them).
- Do not mention any songs or artists, nor give a list of songs.
Write short responses with a respectful and friendly tone.
"""

describe_music_style_message = \
"""You receive a music style and your goal is to describe it further with genres, instruments and sounds.
If it contains pop-culture references (like TV shows, films, artists, famous people, etc.) you should replace them with music styles that resemble them.
You should return the new music style as a set of words separated by commas.
You always give short answers, with at most 20 words.
"""

MEMORY_KEY = "history"

prompt = ChatPromptTemplate.from_messages([
    ("system", system_message),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
    MessagesPlaceholder(variable_name=MEMORY_KEY),
    ("human", "{input}"),
])

conversation_memories = defaultdict(
    lambda: ConversationBufferWindowMemory(
        memory_key=MEMORY_KEY,
        return_messages=True,
        output_key="output",
        k=4,
    )
)

# Global dicts to store the tracks and the conversation costs
music_styles_to_tracks = {}
conversation_costs = defaultdict(float)


@tool
def get_music_style_description(music_style: str) -> str:
    "A tool which describes a music style and returns a description of it"
    description = describe_music_style(music_style)
    tracks = free_text_search(description, 5)
    logging.warning(f"""
    music_style = {music_style}
    music_style_description = {description}
    tracks = {pprint.pformat(tracks)}""")
    # We store the tracks in a global dict so that we can access them later
    music_styles_to_tracks[description] = tracks
    # We return only the description; the tracks are retrieved later from the global dict
    return description


def describe_music_style(music_style: str) -> str:
    "Describes a music style using a low-temperature LLM call"
    llm_describe = ChatOpenAI(temperature=0.0)
    prompt_describe = ChatPromptTemplate.from_messages([
        ("system", describe_music_style_message),
        ("human", "{music_style}"),
    ])
    runnable = prompt_describe | llm_describe | StrOutputParser()
    return runnable.invoke(
        {"music_style": music_style},
        # RunnableConfig(verbose=True, recursion_limit=1)
    )


# We instantiate the Chat Model and bind the tool to it.
llm = ChatOpenAI(temperature=0.7, request_timeout=30, max_retries=1)
llm_with_tools = llm.bind(
    functions=[format_tool_to_openai_function(get_music_style_description)]
)


def get_agent_executor_from_user_id(user_id) -> AgentExecutor:
    "Returns an agent executor for a given user_id"
    memory = conversation_memories[user_id]
    logging.warning(memory)
    agent = (
        {
            "input": lambda x: x["input"],
            "agent_scratchpad": lambda x: format_to_openai_functions(x["intermediate_steps"]),
        }
        | RunnablePassthrough.assign(
            history=RunnableLambda(memory.load_memory_variables) | itemgetter(MEMORY_KEY)
        )
        | prompt
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )
    logging.error(memory)
    return AgentExecutor(
        agent=agent,
        tools=[get_music_style_description],
        memory=memory,
        callbacks=[handler] if handler else [],
        return_intermediate_steps=True,
        max_execution_time=30,
        handle_parsing_errors=True,
        verbose=True,
    )


def get_tracks_from_intermediate_steps(intermediate_steps: List) -> List:
    "Given a list of intermediate steps, returns the tracks from the last get_music_style_description action"
    if len(intermediate_steps) == 0:
        return []
    else:
        print("INTERMEDIATE STEPS")
        pprint.pprint(intermediate_steps)
        print("===================")
        # Each intermediate step is an (action, observation) pair; the observation
        # of get_music_style_description is the description used as the dict key.
        for action_message, observation in intermediate_steps[::-1]:
            if action_message.tool == "get_music_style_description":
                tracks = music_styles_to_tracks[observation]
                return tracks
    # If none of the actions is get_music_style_description, return an empty list
    return []


def llm_inference(message, history, user_id) -> Dict:
    """This function is called by the API and returns the conversation response,
    along with the appropriate tracks and the costs of the conversation so far."""
    # It first creates an agent executor with the previous conversation memory of the given user_id
    agent_executor = get_agent_executor_from_user_id(user_id)
    with get_openai_callback() as cb:
        # We get the Agent response
        answer = agent_executor({"input": message})
        # We keep track of the costs
        conversation_costs[user_id] += cb.total_cost
        total_conversation_costs = conversation_costs[user_id]
        # We get the tracks from the intermediate steps, if any
        tracks = get_tracks_from_intermediate_steps(answer["intermediate_steps"])
        logging.warning(f"step = ${cb.total_cost} total = ${total_conversation_costs}")
        logging.warning(music_styles_to_tracks)
    return {
        "output": answer["output"],
        "tracks": tracks,
        "cost": total_conversation_costs,
    }
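

# --- Usage sketch (not part of the original module, for illustration only) ---
# A minimal local smoke test of llm_inference, assuming OPENAI_API_KEY is set and
# the cyanite free_text_search backend is reachable. The user id and message below
# are hypothetical placeholder values.
if __name__ == "__main__":
    import uuid

    demo_user_id = str(uuid.uuid4())  # hypothetical per-session user identifier
    result = llm_inference(
        "lo-fi hip hop with rainy-day piano",  # example music style
        history=[],
        user_id=demo_user_id,
    )
    print(result["output"])
    print(result["tracks"])
    print(f"cost so far: ${result['cost']:.4f}")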