juanluisrto committed on
Commit
8e786b4
1 Parent(s): 034eb7d

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +24 -104
  2. cyanite.py +74 -0
  3. langhcain_agent.py +191 -0
app.py CHANGED
@@ -1,121 +1,41 @@
1
- import os, json, random, logging
2
- from typing import List
3
- from dotenv import load_dotenv
4
-
5
- from langchain.agents import AgentType, initialize_agent
6
- from langchain.chat_models import ChatOpenAI
7
- from langchain.tools import Tool
8
-
9
- from langchain.schema import SystemMessage
10
- from langchain.agents import OpenAIFunctionsAgent
11
- from langchain.prompts import MessagesPlaceholder
12
- from langchain.agents import AgentExecutor, OpenAIFunctionsAgent
13
- from langchain.chains.conversation.memory import ConversationBufferMemory
14
- from langchain.chat_models import ChatOpenAI
15
- from langchain.agents import tool, AgentExecutor, OpenAIFunctionsAgent, AgentType, Agent
16
- from langchain.schema import SystemMessage
17
- from langchain.prompts import MessagesPlaceholder
18
- from langchain.chains.conversation.memory import ConversationBufferMemory
19
- from langchain.chat_models import ChatOpenAI
20
-
21
- from langchain.prompts import ChatPromptTemplate
22
- from langchain.schema import StrOutputParser
23
-
24
  import gradio as gr
 
 
 
25
 
26
- load_dotenv()
27
-
28
-
29
-
30
- llm = ChatOpenAI(temperature=0)
31
-
32
-
33
-
34
- from typing import List, Dict
35
-
36
- @tool
37
- def describe_popculture_references(references: List) -> Dict:
38
- "A tool used to describe pop-culture references as music styles"
39
- prompt = ChatPromptTemplate.from_messages([
40
- ("system", """You receive a list of pop-culture references (like TV-Shows, films, artists, famous people, etc).
41
- For each reference, write a few words separated by commas which captures the essence of it. Use music styles, sounds and instruments.
42
- Return a dict with the references as keys and music styles as values.
43
- """),
44
- ("human", "{references_list}"),
45
- ])
46
- runnable = prompt | llm | StrOutputParser()
47
- return runnable.invoke({"references_list" : references})
48
-
49
-
50
- @tool
51
- def extract_popculture_references(input_style: str) -> List:
52
- "A tool used to extract pop-culture references from a piece of text"
53
- prompt = ChatPromptTemplate.from_messages([
54
- ("system", """You detect elements of the pop-culture (like TV-Shows, films, artists, famous people, etc) in the human's input message.
55
- Return a list with these elements only. If there are none, return an empty list.
56
- """),
57
- ("human", "{input_style}"),
58
- ])
59
- runnable = prompt | llm | StrOutputParser()
60
- output = runnable.invoke({"input_style" : input_style})
61
- return output
62
-
63
- @tool
64
- def call_music_recommendation_api(input : str) -> List[str]:
65
- """
66
- Calls the music recommendation API
67
- """
68
- print("Calling music recommendation API: ", input)
69
- return {"songs" : [input]}
70
-
71
- tools = [describe_popculture_references, extract_popculture_references, call_music_recommendation_api]
72
-
73
-
74
 
75
- system_message = SystemMessage(content =
76
- """You are an agent which recommends songs based on the style a user gives.
77
- You follow the following conversation protocol:
78
- - You start the conversation by asking the user what style of music they like
79
- - The user responds with a style of music
80
- - If there are pop culture references like a movie, a TV show, an artist, a famous person, extract them AND then describe them as music styles.
81
- - Ask the user if he is ok with the new generated style
82
- - If the user agrees, call the music recommendation API with this style.
83
- """)
84
 
85
 
86
- MEMORY_KEY = "chat_history"
87
- prompt = OpenAIFunctionsAgent.create_prompt(
88
- system_message=system_message,
89
- extra_prompt_messages=[MessagesPlaceholder(variable_name=MEMORY_KEY)]
90
- )
91
 
92
- memory = ConversationBufferMemory(memory_key=MEMORY_KEY, return_messages=True)
93
 
 
 
 
 
94
 
95
- agent = OpenAIFunctionsAgent(
96
- llm=llm,
97
- tools=tools,
98
- prompt=prompt,
99
- agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION
100
- )
101
 
102
- agent_executor = AgentExecutor(agent=agent, tools=tools, memory=memory, verbose=True)
103
 
104
 
105
- def inference(message, history):
106
- # return agent_executor.run(message)
107
- for chunk in agent_executor.stream(message):
108
- yield chunk["output"]
109
-
110
 
111
- gr.ChatInterface(
112
- inference,
113
- chatbot=gr.Chatbot(height=400),
 
114
  textbox=gr.Textbox(placeholder="Ask me for music recommendations!", container=False, scale=7),
115
  description="This AI makes song recommendations based on your music style.",
 
116
  title="Persona Music song recommender",
117
- examples=["Recommend me something in Quentin Tarantino reggae style", "Give me songs with calm and relaxing vibes", "I want to listen to something like the movie Inception", "I want music that sounds like Lebron James eating soup"],
118
  retry_btn="Retry",
119
  clear_btn="Clear",
120
- undo_btn = None,
121
- ).queue().launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import logging
3
+ import uuid
4
+ from dotenv import load_dotenv
5
 
6
+ load_dotenv(override=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ from langhcain_agent import llm_inference
 
 
 
 
 
 
 
 
9
 
10
 
 
 
 
 
 
11
 
12
def predict_interface(message, history=None, user_id = None):
    """Chat callback: send one user message to the agent and return its reply text.

    Args:
        message: Text entered by the user.
        history: Chat history handed over by the UI; forwarded unchanged to
            ``llm_inference``.
        user_id: Identifier forwarded unchanged to ``llm_inference``.

    Returns:
        The ``'output'`` entry of the dict returned by ``llm_inference``.
    """
    result = llm_inference(message, history, user_id)
    # Both the full agent result and the caller's id are logged at error level.
    for logged_value in (result, user_id):
        logging.error(logged_value)
    return result['output']
18
 
 
 
 
 
 
 
19
 
 
20
 
21
 
22
# The session id textbox receives a callable so a fresh UUID is generated each
# time the page loads. Calling uuid4() once at import time would give every
# visitor the same id, and the agent keys its conversation memory on this id.
session_id = gr.Textbox(value=lambda: str(uuid.uuid4()), type="text", label="session_id")

example_sentences = [
    "Recommend me something in Quentin Tarantino reggae style",
    "Give me songs with calm and relaxing vibes",
    "I want to listen to something like the movie Inception",
    "I want music that sounds like Lebron James eating soup",
]
# Every example row supplies one value per input: the message and a session id.
examples = [[example, f"user_{i}"] for i, example in enumerate(example_sentences)]

chat = gr.ChatInterface(
    predict_interface,
    additional_inputs=[session_id],
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Ask me for music recommendations!", container=False, scale=7),
    description="This AI makes song recommendations based on your music style.",
    examples=examples,
    title="Persona Music song recommender",
    retry_btn="Retry",
    clear_btn="Clear",
    undo_btn=None,
)

chat.queue().launch()
cyanite.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import json
4
+ import requests
5
+
6
+ CYANITE_API_URL = "https://api.cyanite.ai/graphql"
7
+ CYANITE_ACCESS_TOKEN = os.getenv("CYANITE_ACCESS_TOKEN")
8
+
9
def free_text_search(search_text, num_tracks=5, timeout=30):
    """Run a Cyanite free-text search and return the matching tracks.

    Args:
        search_text: Free-text description sent as the ``searchText`` variable.
        num_tracks: Number of tracks requested (the ``first`` argument of the
            query).
        timeout: Seconds to wait for the HTTP request. Without a timeout
            ``requests`` would wait indefinitely on an unresponsive server.

    Returns:
        The non-empty list produced by ``extract_songs_from_response``.

    Raises:
        Exception: If the HTTP status is not 200, if no songs are found, or if
            ``extract_songs_from_response`` rejects the payload.
        requests.exceptions.RequestException: On network failure or timeout.
    """
    import time  # local import: this module does not import ``time`` at top level

    headers = {
        "Authorization": f"Bearer {CYANITE_ACCESS_TOKEN}",
        "Content-Type": "application/json",
    }

    query = '''
    query FreeTextSearch($searchText: String!, $numTracks: Int!) {
      freeTextSearch(
        first: $numTracks
        target: { library: {} }
        searchText: $searchText
      ) {
        ... on FreeTextSearchError {
          message
          code
        }
        ... on FreeTextSearchConnection {
          edges {
            cursor
            node {
              id
              title
            }
          }
        }
      }
    }
    '''

    variables = {
        "searchText": search_text,
        "numTracks": num_tracks,
    }

    start_time = time.time()
    response = requests.post(
        CYANITE_API_URL,
        headers=headers,
        json={'query': query, 'variables': variables},
        timeout=timeout,
    )
    time_taken = time.time() - start_time
    logging.warning(f"Cyanite API: Time taken: {time_taken} seconds")

    if response.status_code != 200:
        raise Exception(f"Query failed with status code {response.status_code}")

    songs = extract_songs_from_response(response.json())
    if not songs:
        raise Exception("No songs found")
    return songs
65
+
66
def extract_songs_from_response(response_json):
    """Extract ``id``/``title`` pairs from a ``freeTextSearch`` GraphQL response.

    Args:
        response_json: Decoded JSON body of the search response.

    Returns:
        A list of ``{"id": ..., "title": ...}`` dicts, or ``None`` when the
        result has no edges.

    Raises:
        Exception: ``"Invalid response format"`` when an expected key is
            missing or a level of the payload is not a mapping (for example
            ``"data": null``). When the result carries a ``message`` instead
            of ``edges``, the exception text quotes that message.
    """
    try:
        search_result = response_json['data']['freeTextSearch']
        if 'edges' in search_result:
            edges = search_result['edges']
            if not edges:
                return None  # No songs found
            return [
                {"id": edge["node"]["id"], "title": edge["node"]["title"]}
                for edge in edges
            ]
        # No edges: the result is expected to be an error object with a message.
        error_message = search_result['message']
    except (KeyError, TypeError) as err:
        # TypeError covers null/non-mapping levels, which KeyError alone misses.
        raise Exception("Invalid response format") from err
    raise Exception(f"Cyanite search failed: {error_message}")
langhcain_agent.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from operator import itemgetter
2
+ import pprint
3
+ from typing import Dict, List
4
+
5
+ from langchain.agents import (AgentExecutor, AgentType, OpenAIFunctionsAgent,
6
+ tool)
7
+ from langchain.chains.conversation.memory import ConversationBufferWindowMemory
8
+ from langchain.chat_models import ChatOpenAI
9
+ from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
10
+ from langchain.schema import StrOutputParser, SystemMessage, HumanMessage, AIMessage
11
+ from langchain.callbacks import get_openai_callback, FileCallbackHandler
12
+ from langchain.schema.agent import AgentActionMessageLog, AgentFinish
13
+ from langchain.utils.openai_functions import convert_pydantic_to_openai_function
14
+ from langchain.agents.format_scratchpad import format_to_openai_functions
15
+ from langchain.schema.runnable import RunnablePassthrough, RunnableLambda
16
+ from langchain.tools.render import format_tool_to_openai_function
17
+ from langchain.agents.output_parsers import OpenAIFunctionsAgentOutputParser
18
+ from langchain.schema.runnable import RunnableConfig
19
+
20
+
21
+
22
+ import logging, os, json
23
+ from collections import defaultdict
24
+
25
+ from pydantic import BaseModel, Field
26
+
27
+ from cyanite import free_text_search
28
+
29
+ from langfuse.callback import CallbackHandler
30
+
31
def is_env_flag_enabled(name: str) -> bool:
    """Return True when environment variable *name* holds a truthy string.

    Accepted values (case-insensitive, surrounding whitespace ignored) are
    ``1``, ``true``, ``yes`` and ``on``. An unset variable counts as disabled.
    """
    return os.getenv(name, "").strip().lower() in {"1", "true", "yes", "on"}


# os.getenv returns a string (or None), so comparing it with the boolean True
# can never succeed; the flag text is parsed instead.
if is_env_flag_enabled("USE_LANGFUSE"):
    handler = CallbackHandler(os.getenv("LANGFUSE_PUBLIC"), os.getenv("LANGFUSE_PRIVATE"), "https://cloud.langfuse.com")
else:
    # Falsy placeholder meaning "no tracing handler configured".
    handler = []
35
+
36
+
37
+
38
+ system_message = \
39
+ """You are an agent which recommends songs based on music styles provided by the user.
40
+ - A music style could be a combination of instruments, genres or sounds.
41
+ - Use get_music_style_description to generate a description of the user's music style.
42
+ - The styles might contain pop-culture references (artists, movies, TV-Shows, etc) You should include them when generating descriptions.
43
+ - Comment on the description of the style and wish the user to enjoy the recommended songs (he will have received them).
44
+ - Do not mention any songs or artists, nor give a list of songs.
45
+ Write short responses with a respectful and friendly tone.
46
+ """
47
+
48
+
49
+
50
+ describe_music_style_message = \
51
+ """You receive a music style and your goal is to describe it further with genres, instruments and sounds.
52
+ If it contains pop-culture references (like TV-Shows, films, artists, famous people, etc) you should replace them with music styles that resemble them.
53
+ You should return the new music style as a set of words separated by commas.
54
+ You always give short answers, with at most 20 words.
55
+ """
56
+
57
+
58
+ MEMORY_KEY = "history"
59
+
60
# Message order: system instructions, earlier conversation turns, the current
# human input, and finally the agent scratchpad. The scratchpad holds the
# function calls and results produced while answering the current input, so it
# must come after that input rather than before the history.
prompt = ChatPromptTemplate.from_messages([
    ("system", system_message),
    MessagesPlaceholder(variable_name=MEMORY_KEY),
    ("human", "{input}"),
    MessagesPlaceholder(variable_name="agent_scratchpad"),
])
66
+
67
+ conversation_memories = defaultdict(
68
+ lambda : ConversationBufferWindowMemory(memory_key=MEMORY_KEY, return_messages=True, output_key="output", k = 4)
69
+ )
70
+
71
+ #global dicts to store the tracks and the conversation costs
72
+ music_styles_to_tracks = {}
73
+ conversation_costs = defaultdict(float)
74
+
75
+
76
@tool
def get_music_style_description(music_style: str) -> str:
    "A tool which describes a music style and returns a description of it"
    # NOTE: the docstring above is kept verbatim; presumably @tool uses it as
    # the tool description, so its text is part of runtime behavior.

    # Expand the raw style into a description, then search tracks with it.
    style_description = describe_music_style(music_style)
    found_tracks = free_text_search(style_description, 5)

    # Tracks are cached in the module-level dict, keyed by the description, so
    # they can be looked up later from the value this tool returns.
    music_styles_to_tracks[style_description] = found_tracks

    logging.warning(f"""
music_style = {music_style}
music_style_description = {style_description}
tracks = {pprint.pformat(found_tracks)}""")

    # Only the description is returned; the tracks stay in the cache.
    return style_description
91
+
92
def describe_music_style(music_style: str) -> str:
    """Ask a chat model to elaborate on a music style.

    The model is instructed by ``describe_music_style_message`` and receives
    ``music_style`` as the human message.

    Args:
        music_style: The style text to describe.

    Returns:
        The model's reply as a plain string.
    """
    describe_prompt = ChatPromptTemplate.from_messages([
        ("system", describe_music_style_message),
        ("human", "{music_style}"),
    ])
    # Temperature 0.0 is used for this description step.
    describe_chain = describe_prompt | ChatOpenAI(temperature=0.0) | StrOutputParser()
    return describe_chain.invoke({"music_style": music_style})
103
+
104
+
105
+ # We instantiate the Chat Model and bind the tool to it.
106
+ llm = ChatOpenAI(temperature=0.7, request_timeout = 30, max_retries = 1)
107
+ llm_with_tools = llm.bind(
108
+ functions=[
109
+ format_tool_to_openai_function(get_music_style_description)
110
+ ]
111
+ )
112
+
113
def get_agent_executor_from_user_id(user_id) -> AgentExecutor:
    """Build an ``AgentExecutor`` wired to the conversation memory of ``user_id``.

    The memory is taken from ``conversation_memories`` (created on first
    access) and is used both to fill the ``history`` prompt variable and as
    the executor's memory.

    Args:
        user_id: Key identifying the conversation.

    Returns:
        A new ``AgentExecutor`` exposing the ``get_music_style_description`` tool.
    """
    user_memory = conversation_memories[user_id]
    logging.warning(user_memory)

    # Step 1: pick the user input and render intermediate steps as messages.
    step_inputs = {
        "input": lambda x: x["input"],
        "agent_scratchpad": lambda x: format_to_openai_functions(x['intermediate_steps']),
    }
    # Step 2: add the stored conversation under the "history" key.
    load_history = RunnableLambda(user_memory.load_memory_variables) | itemgetter(MEMORY_KEY)
    agent_chain = (
        step_inputs
        | RunnablePassthrough.assign(history=load_history)
        | prompt
        | llm_with_tools
        | OpenAIFunctionsAgentOutputParser()
    )

    logging.error(user_memory)

    # The tracing handler is attached only when one is configured (truthy).
    executor_callbacks = [handler] if handler else []
    return AgentExecutor(
        agent=agent_chain,
        tools=[get_music_style_description],
        memory=user_memory,
        callbacks=executor_callbacks,
        return_intermediate_steps=True,
        max_execution_time=30,
        handle_parsing_errors=True,
        verbose=True,
    )
143
+
144
+
145
+
146
def get_tracks_from_intermediate_steps(intermediate_steps : List) -> List:
    """Return the tracks cached for the last ``get_music_style_description`` call.

    Args:
        intermediate_steps: ``(action, observation)`` pairs. ``action`` must
            expose a ``tool`` attribute; ``observation`` is the value used as
            the key into ``music_styles_to_tracks``.

    Returns:
        The cached track list for the most recent matching step, or an empty
        list when there are no steps, no step used the tool, or nothing is
        cached for its observation.
    """
    if not intermediate_steps:
        return []

    print("INTERMEDIATE STEPS")
    pprint.pprint(intermediate_steps)
    print("===================")

    # Walk backwards so the most recent tool call wins.
    for action_message, observation in reversed(intermediate_steps):
        if action_message.tool == 'get_music_style_description':
            # .get avoids a KeyError when the observation was never cached.
            return music_styles_to_tracks.get(observation, [])

    # None of the actions is get_music_style_description.
    return []
161
+
162
+
163
def llm_inference(message, history, user_id) -> Dict:
    """Answer one user message and report the tracks and running cost.

    Args:
        message: The user's message, passed to the agent as ``input``.
        history: Not used by this function; the conversation context comes
            from the memory stored for ``user_id``.
        user_id: Key selecting the conversation memory and the cost counter.

    Returns:
        A dict with ``"output"`` (the agent's reply), ``"tracks"`` (tracks
        found via the intermediate steps, possibly empty) and ``"cost"``
        (the accumulated cost recorded for ``user_id``).
    """
    # The executor is rebuilt per call around the user's stored memory.
    executor = get_agent_executor_from_user_id(user_id)

    with get_openai_callback() as usage:
        agent_result = executor({"input": message})

        # Add this call's cost to the per-user running total.
        conversation_costs[user_id] += usage.total_cost
        running_total = conversation_costs[user_id]

    recommended_tracks = get_tracks_from_intermediate_steps(agent_result['intermediate_steps'])

    logging.warning(f"step = ${usage.total_cost} total = ${running_total}")
    logging.warning(music_styles_to_tracks)

    return {
        "output": agent_result['output'],
        "tracks": recommended_tracks,
        "cost": running_total,
    }
189
+
190
+
191
+