Spaces:
Sleeping
Sleeping
| import json | |
| import openai | |
| from gradio import ChatMessage | |
class SantaAgent:
    """Tool-calling chat agent that role-plays Santa's Christmas chores.

    The model is offered the function tools in ``self.tools``. Each tool is
    implemented by the method of the same name on this class, and every
    implementation simply returns a canned string (or canned JSON).
    """

    # Chat-completions model used by run_santa_agent; override on an
    # instance or subclass to use a different model.
    model = "gpt-4o-mini"

    def __init__(self, system_prompt: str):
        """Create the agent.

        Args:
            system_prompt: System message placed first in every conversation.
        """
        self.system_prompt = system_prompt
        self.client = openai.OpenAI()
        self.tools = self._build_tools()

    @staticmethod
    def _build_tools() -> list:
        """Return the OpenAI function-tool schemas offered to the model.

        Every declared parameter is a required string. Tools without
        parameters carry no ``"parameters"`` key at all.
        """

        def function_tool(name, description, **properties):
            # ``properties`` maps parameter name -> description; keyword order
            # is preserved and defines both schema order and call order.
            function = {"name": name, "description": description}
            if properties:
                function["parameters"] = {
                    "type": "object",
                    "properties": {
                        prop: {"type": "string", "description": desc}
                        for prop, desc in properties.items()
                    },
                    "required": list(properties),
                }
            return {"type": "function", "function": function}

        return [
            function_tool(
                "buy_item",
                "Buy an item from the store.",
                item="The item to buy from the store.",
            ),
            function_tool(
                "give_present",
                "Give a present to a person.",
                person="The person to give the present to.",
                item="The item to give to the person.",
            ),
            function_tool(
                "make_naughty_nice_list",
                "Make a list of children that have been naughty and nice. This function cannot make other lists.",
            ),
            function_tool(
                "check_naughty_nice_list",
                "Check which children have been naughty and nice. This is the only information in the list.",
            ),
            function_tool("cut_paper", "Cut wrapping paper to wrap a present."),
            function_tool("find_end_of_tape", "Find the end of the tape to wrap a present."),
            function_tool("wrap_present", "Wrap a present."),
            function_tool(
                "label_present",
                "Label a present with the recipient's name.",
                recipient="The name of the recipient.",
            ),
            function_tool("retrieve_letters", "Retrieve letters from children."),
            function_tool(
                "check_temperature",
                "Use this tool to check the temperature of an object.",
                object="The object to check the temperature of.",
            ),
            function_tool("dunk_cookie", "Dunk a cookie in milk."),
            function_tool(
                "drink",
                "Drink an item.",
                item="The item to drink.",
            ),
            function_tool(
                "put_route_into_maps",
                "Put a route into Google Maps.",
                addr1="First Address to Visit.",
                addr2="Second Address to Visit.",
                addr3="Third Address to Visit.",
            ),
            function_tool("stop", "Use this tool if you are finished and want to stop."),
        ]

    def buy_item(self, item: str):
        """Buy an item from the store."""
        return f"Bought {item} from the store."

    def give_present(self, person: str, item: str):
        """Give a present to a person."""
        return f"Gave {item} to {person}."

    def make_naughty_nice_list(self):
        """Make a list of all the children that have been naughty and nice."""
        return "Made a list."

    def check_naughty_nice_list(self):
        """Return a JSON list of children, each tagged "naughty" or "nice"."""
        return json.dumps({
            "children": [
                {"name": "Alice", "status": "nice"},
                {"name": "Bob", "status": "naughty"},
                {"name": "John", "status": "nice"},
                {"name": "Jane", "status": "nice"},
            ]
        })

    def cut_paper(self):
        """Cut wrapping paper to wrap a present."""
        return "Cut the wrapping paper."

    def find_end_of_tape(self):
        """Find the end of the tape to wrap a present."""
        return "Found the end of the tape."

    def wrap_present(self):
        """Wrap a present."""
        return "Wrapped the present."

    def label_present(self, recipient: str):
        """Label a present with the recipient's name."""
        return f"Labeled the present for {recipient}."

    # NOTE: ``object`` shadows the builtin, but it matches the tool schema's
    # parameter name and may be passed by keyword, so it is kept.
    def check_temperature(self, object: str):
        """Check the temperature of the object."""
        return f"The temperature of the {object} is just right."

    def dunk_cookie(self):
        """Dunk a cookie in milk."""
        return "Dunked a cookie in milk."

    def drink(self, item: str):
        """Drink an item."""
        return f"Drank {item}."

    def retrieve_letters(self):
        """Return a JSON list of children's letters with sender name and address."""
        return json.dumps({
            "letters": [
                {"text": "Dear Santa, I would like a Bike for Christmas.", "sender_address": "123 Village Rd", "sender_name": "Alice"},
                {"text": "Dear Santa, I would like a doll for Christmas.", "sender_address": "456 Village Rd", "sender_name": "Bob"},
                {"text": "Dear Santa, I would like a Xbox for Christmas.", "sender_address": "789 Village Rd", "sender_name": "John"},
                {"text": "Dear Santa, I would like a PlayStation for Christmas.", "sender_address": "101112 Village Rd", "sender_name": "Jane"},
            ]
        })

    def put_route_into_maps(self, addr1: str, addr2: str, addr3: str):
        """Return the three addresses as a JSON route, in visiting order."""
        return json.dumps({
            'route': [addr1, addr2, addr3]
        })

    def stop(self):
        """Use this tool if you are finished and want to stop."""
        return "STOP"

    def mock_run_santa_agent(self):
        """Return a fixed two-message conversation without calling the API."""
        messages = [
            {"role": "user", "content": "Hi there"},
            {"role": "assistant", "content": "Bye bye"},
        ]
        gradio_messages = [
            ChatMessage(role="user", content="Hi there"),
            ChatMessage(role="assistant", content="Bye bye"),
        ]
        return messages, gradio_messages

    def _execute_tool_call(self, name: str, raw_arguments: str) -> tuple:
        """Execute one model-requested tool call without raising on bad input.

        Args:
            name: Tool name chosen by the model.
            raw_arguments: JSON-encoded argument object sent by the model.
                Empty or ``None`` is treated as ``{}``.

        Returns:
            ``(label, output)``.
            - ``label`` renders the call for the chat UI, e.g.
              ``"give_present(Alice, Bike)"``, or just the name for
              parameterless tools and failed calls.
            - ``output`` is the string returned to the model.
            Unknown tools, unparsable JSON and missing or ill-typed arguments
            yield an ``"Error: ..."`` output, so the model can correct itself.
        """
        # Only names advertised in self.tools are dispatchable, so the model
        # cannot reach arbitrary attributes such as run_santa_agent.
        parameters = {
            tool["function"]["name"]: list(
                tool["function"].get("parameters", {}).get("properties", {})
            )
            for tool in self.tools
        }
        if name not in parameters:
            return name, f"Error: unknown tool {name!r}."
        param_names = parameters[name]
        try:
            arguments = json.loads(raw_arguments) if raw_arguments else {}
            # Positional order follows the schema, independent of the order
            # in which the model happened to emit the JSON keys.
            values = [arguments[param] for param in param_names]
        except json.JSONDecodeError as err:
            return name, f"Error: invalid JSON arguments for {name}: {err}"
        except (KeyError, TypeError):
            return name, f"Error: {name} requires arguments {param_names}."
        label = f"{name}({', '.join(str(v) for v in values)})" if param_names else name
        return label, getattr(self, name)(*values)

    def run_santa_agent(self, user_prompt: str, *, max_messages: int = 10):
        """Run the tool-calling loop for one user prompt.

        Each round calls the chat-completions API, then executes any tool
        calls the model made and feeds their outputs back. The loop ends when:
        - the model calls ``stop``;
        - the model answers twice in a row without calling a tool (the second
          such reply is not recorded); or
        - the transcript holds more than ``max_messages`` entries. This is
          checked after each complete round.

        Args:
            user_prompt: The user's request.
            max_messages: Transcript-length limit described above. The
                default of 10 matches the previous hard-coded limit.

        Returns:
            ``(messages, gradio_messages)``.
            - ``messages``: the OpenAI-format transcript.
            - ``gradio_messages``: the conversation as gradio ``ChatMessage``
              objects for display, including tool calls and outputs.
        """
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        gradio_messages = [
            ChatMessage(role="system", content=self.system_prompt),
            ChatMessage(role="user", content=user_prompt),
        ]
        consecutive_non_tool_replies = 0
        while True:
            response = self.client.chat.completions.create(
                messages=messages,
                model=self.model,
                tools=self.tools,
                tool_choice="auto",
            )
            message = response.choices[0].message
            tool_calls = message.tool_calls or []

            # Limit plain-text chatter: a second consecutive reply without
            # tool calls ends the run before that reply is recorded.
            if tool_calls:
                consecutive_non_tool_replies = 0
            else:
                consecutive_non_tool_replies += 1
                if consecutive_non_tool_replies >= 2:
                    break

            messages.append(message.to_dict())
            if message.content is not None:
                gradio_messages.append(ChatMessage(role="assistant", content=message.content))

            should_stop = False
            for tool_call in tool_calls:
                name = tool_call.function.name
                label, output = self._execute_tool_call(name, tool_call.function.arguments)
                # Answer every tool_call_id, including failed calls, so the
                # next request carries a complete tool-call/response pairing.
                messages.append({"role": "tool", "content": output, "tool_call_id": tool_call.id})
                if name == "stop":
                    # The stop call is not shown in the UI; other calls in
                    # the same batch still are.
                    should_stop = True
                    continue
                gradio_messages.append(ChatMessage(role="assistant", content=label, metadata={"title": f"🔧 Tool Call: {name}"}))
                gradio_messages.append(ChatMessage(role="assistant", content=output, metadata={"title": f"🔧 Tool Output: {name}"}))

            if should_stop or len(messages) > max_messages:
                break
        return messages, gradio_messages