Spaces:
Runtime error
Runtime error
Improve UI | Add chains for itinerary with map JSON
Browse files- agents/agent.py +165 -6
- app.py +55 -12
- templates/itinerary.py +66 -0
- templates/mapping.py +77 -0
- tests/agent/test_agent.py +32 -3
agents/agent.py
CHANGED
@@ -1,9 +1,14 @@
|
|
1 |
import openai
|
2 |
import logging
|
3 |
import time
|
4 |
-
from templates.validation import ValidationTemplate
|
|
|
|
|
5 |
from langchain.chat_models import ChatOpenAI
|
6 |
from langchain.chains import LLMChain, SequentialChain
|
|
|
|
|
|
|
7 |
|
8 |
logging.basicConfig(level=logging.INFO)
|
9 |
|
@@ -25,10 +30,16 @@ class Agent:
|
|
25 |
|
26 |
# Initialize ChatOpenAI with the provided OpenAI API key and model details
|
27 |
self.chat_model = ChatOpenAI(model=model, temperature=temperature, openai_api_key=self._openai_key)
|
28 |
-
|
|
|
29 |
self.validation_prompt = ValidationTemplate()
|
|
|
|
|
|
|
|
|
30 |
# Setup the validation chain using the LLMChain and SequentialChain
|
31 |
self.validation_chain = self._set_up_validation_chain(debug)
|
|
|
32 |
|
33 |
def _set_up_validation_chain(self, debug=True):
|
34 |
# Make validation agent chain using LLMChain
|
@@ -49,6 +60,38 @@ class Agent:
|
|
49 |
)
|
50 |
|
51 |
return overall_chain
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
def validate_travel(self, query):
|
54 |
self.logger.info("Validating query: %s", query)
|
@@ -60,7 +103,6 @@ class Agent:
|
|
60 |
)
|
61 |
)
|
62 |
|
63 |
-
|
64 |
# Call the validation chain with the query and format instructions
|
65 |
validation_result = self.validation_chain.run(
|
66 |
{
|
@@ -69,9 +111,126 @@ class Agent:
|
|
69 |
}
|
70 |
)
|
71 |
|
72 |
-
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
t2 = time.time()
|
75 |
self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)
|
76 |
|
77 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import openai
|
2 |
import logging
|
3 |
import time
|
4 |
+
from templates.validation import ValidationTemplate
|
5 |
+
from templates.itinerary import ItineraryTemplate
|
6 |
+
from templates.mapping import MappingTemplate
|
7 |
from langchain.chat_models import ChatOpenAI
|
8 |
from langchain.chains import LLMChain, SequentialChain
|
9 |
+
from datetime import datetime
|
10 |
+
from datetime import date
|
11 |
+
from dateutil.relativedelta import relativedelta
|
12 |
|
13 |
logging.basicConfig(level=logging.INFO)
|
14 |
|
|
|
30 |
|
31 |
# Initialize ChatOpenAI with the provided OpenAI API key and model details
|
32 |
self.chat_model = ChatOpenAI(model=model, temperature=temperature, openai_api_key=self._openai_key)
|
33 |
+
|
34 |
+
# Initialize the ValidationTemplate and ItineraryTemplate
|
35 |
self.validation_prompt = ValidationTemplate()
|
36 |
+
self.itinerary_prompt = ItineraryTemplate()
|
37 |
+
self.mapping_prompt = MappingTemplate()
|
38 |
+
|
39 |
+
|
40 |
# Setup the validation chain using the LLMChain and SequentialChain
|
41 |
self.validation_chain = self._set_up_validation_chain(debug)
|
42 |
+
self.itinerary_chain = self._set_up_itinerary_chain(debug)
|
43 |
|
44 |
def _set_up_validation_chain(self, debug=True):
|
45 |
# Make validation agent chain using LLMChain
|
|
|
60 |
)
|
61 |
|
62 |
return overall_chain
|
63 |
+
|
64 |
+
def _set_up_itinerary_chain(self, debug=True):
|
65 |
+
|
66 |
+
# set up LLMChain to get the itinerary as a string
|
67 |
+
itinerary_agent = LLMChain(
|
68 |
+
llm=self.chat_model,
|
69 |
+
prompt=self.itinerary_prompt.chat_prompt,
|
70 |
+
verbose=debug,
|
71 |
+
output_key="itinerary_suggestion",
|
72 |
+
)
|
73 |
+
|
74 |
+
# set up LLMChain to extract the waypoints as a JSON object
|
75 |
+
mapping_agent = LLMChain(
|
76 |
+
llm=self.chat_model,
|
77 |
+
prompt=self.mapping_prompt.chat_prompt,
|
78 |
+
output_parser=self.mapping_prompt.parser,
|
79 |
+
verbose=debug,
|
80 |
+
output_key="mapping_list",
|
81 |
+
)
|
82 |
+
|
83 |
+
# overall chain allows us to call the travel_agent and parser in
|
84 |
+
# sequence, with labelled outputs.
|
85 |
+
overall_chain = SequentialChain(
|
86 |
+
chains=[itinerary_agent, mapping_agent],
|
87 |
+
input_variables=["start_location", "end_location", "start_date", "end_date",
|
88 |
+
"attractions", "budget", "transportation", "accommodation",
|
89 |
+
"schedule", "format_instructions"],
|
90 |
+
output_variables=["itinerary_suggestion","mapping_list"],
|
91 |
+
verbose=debug,
|
92 |
+
)
|
93 |
+
|
94 |
+
return overall_chain
|
95 |
|
96 |
def validate_travel(self, query):
|
97 |
self.logger.info("Validating query: %s", query)
|
|
|
103 |
)
|
104 |
)
|
105 |
|
|
|
106 |
# Call the validation chain with the query and format instructions
|
107 |
validation_result = self.validation_chain.run(
|
108 |
{
|
|
|
111 |
}
|
112 |
)
|
113 |
|
114 |
+
self.logger.info("Datatype of validation_result: %s", type(validation_result))
|
115 |
+
|
116 |
+
# Convert the validation result into a dictionary if it's not one already
|
117 |
+
if isinstance(validation_result, dict):
|
118 |
+
validation_dict = validation_result
|
119 |
+
else: # assuming validation_result is an instance of the Validation class
|
120 |
+
validation_dict = validation_result.dict()
|
121 |
+
|
122 |
+
# Log the datatype and content of the validation output
|
123 |
+
self.logger.info("Datatype of validation_dict: %s", type(validation_dict))
|
124 |
+
self.logger.info("Content of validation_dict: %s", validation_dict)
|
125 |
+
|
126 |
t2 = time.time()
|
127 |
self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)
|
128 |
|
129 |
+
return validation_dict
|
130 |
+
|
131 |
+
|
132 |
+
def calculate_duration(self, start_date, end_date):
|
133 |
+
if not isinstance(start_date, date) or not isinstance(end_date, date):
|
134 |
+
raise ValueError("start_date and end_date must be datetime.date objects")
|
135 |
+
|
136 |
+
if end_date < start_date:
|
137 |
+
raise ValueError("End date must be after or equal to start date")
|
138 |
+
|
139 |
+
# Calculate the duration using relativedelta
|
140 |
+
delta = relativedelta(end_date, start_date)
|
141 |
+
|
142 |
+
years = delta.years
|
143 |
+
months = delta.months
|
144 |
+
days = delta.days + 1 # We'll calculate weeks from days
|
145 |
+
|
146 |
+
duration_parts = []
|
147 |
+
|
148 |
+
if years > 0:
|
149 |
+
duration_parts.append(f"{years} year{'s' if years > 1 else ''}")
|
150 |
+
|
151 |
+
if months > 0:
|
152 |
+
duration_parts.append(f"{months} month{'s' if months > 1 else ''}")
|
153 |
+
|
154 |
+
weeks = days // 7
|
155 |
+
days = days % 7
|
156 |
+
|
157 |
+
if weeks > 0:
|
158 |
+
duration_parts.append(f"{weeks} week{'s' if weeks > 1 else ''}")
|
159 |
+
|
160 |
+
if days > 0:
|
161 |
+
duration_parts.append(f"{days} day{'s' if days > 1 else ''}")
|
162 |
+
|
163 |
+
return ', '.join(duration_parts)
|
164 |
+
|
165 |
+
|
166 |
+
def generate_itinerary(self, user_details):
|
167 |
+
self.logger.info("Generating itinerary for user details: %s", user_details)
|
168 |
+
|
169 |
+
# Validate the user details dictionary keys match the expected input variables
|
170 |
+
expected_keys = ["start_location", "end_location", "start_date", "end_date",
|
171 |
+
"attractions", "budget", "transportation", "accommodation",
|
172 |
+
"schedule"]
|
173 |
+
for key in expected_keys:
|
174 |
+
if key not in user_details:
|
175 |
+
self.logger.error("Missing '%s' in user details.", key)
|
176 |
+
return None # or handle the missing key appropriately
|
177 |
+
|
178 |
+
try:
|
179 |
+
# Calculate trip duration
|
180 |
+
trip_duration = self.calculate_duration(user_details['start_date'], user_details['end_date'])
|
181 |
+
# Construct the query phrase
|
182 |
+
query_phrase = "{} trip from {} to {}".format(trip_duration, user_details['start_location'], user_details['end_location'])
|
183 |
+
except KeyError as e:
|
184 |
+
self.logger.error("Missing key in user details: %s", e)
|
185 |
+
return None # or handle the missing key appropriately
|
186 |
+
except ValueError as e:
|
187 |
+
self.logger.error(e)
|
188 |
+
return None # or handle the error appropriately
|
189 |
+
|
190 |
+
t1 = time.time()
|
191 |
+
|
192 |
+
self.logger.info("Calling itinerary chain to validate user query")
|
193 |
+
validation_dict = self.validate_travel(query_phrase)
|
194 |
+
is_plan_valid = validation_dict["plan_is_valid"]
|
195 |
+
|
196 |
+
|
197 |
+
if is_plan_valid.lower() == "no":
|
198 |
+
self.logger.warning("User request was not valid!")
|
199 |
+
print("\n######\n Travel plan is not valid \n######\n")
|
200 |
+
print(validation_result["updated_request"])
|
201 |
+
|
202 |
+
# Create a dictionary with variable names as keys
|
203 |
+
result_dict = {
|
204 |
+
"itinerary_suggestion": None,
|
205 |
+
"list_of_places": None,
|
206 |
+
"validation_dict": validation_dict
|
207 |
+
}
|
208 |
+
|
209 |
+
return result_dict
|
210 |
+
|
211 |
+
self.logger.info("User query is valid. Calling itinerary chain on user details")
|
212 |
+
|
213 |
+
itinerary_details = user_details.copy()
|
214 |
+
itinerary_details["format_instructions"] = self.mapping_prompt.parser.get_format_instructions()
|
215 |
+
|
216 |
+
# Call the itinerary chain with the itinerary details
|
217 |
+
itinerary_result = self.itinerary_chain(itinerary_details)
|
218 |
+
|
219 |
+
itinerary_suggestion = itinerary_result["itinerary_suggestion"]
|
220 |
+
list_of_places = itinerary_result["mapping_list"].dict()
|
221 |
+
|
222 |
+
# Log the datatype and content of the list_of_places output
|
223 |
+
self.logger.info("Datatype of validation_dict: %s", type(list_of_places))
|
224 |
+
self.logger.info("Content of validation_dict: %s", list_of_places)
|
225 |
+
|
226 |
+
t2 = time.time()
|
227 |
+
self.logger.debug("Time to generate itinerary: %.2f seconds", t2 - t1)
|
228 |
+
|
229 |
+
# Create a dictionary with variable names as keys
|
230 |
+
result_dict = {
|
231 |
+
"itinerary_suggestion": itinerary_suggestion,
|
232 |
+
"list_of_places": list_of_places,
|
233 |
+
"validation_dict": validation_dict
|
234 |
+
}
|
235 |
+
|
236 |
+
return result_dict
|
app.py
CHANGED
@@ -87,16 +87,59 @@ with st.sidebar:
|
|
87 |
# Main page layout
|
88 |
st.header('Your Itinerary')
|
89 |
if submit:
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
else:
|
102 |
-
st.error('Please fill out
|
|
|
87 |
# Main page layout
|
88 |
st.header('Your Itinerary')
|
89 |
if submit:
|
90 |
+
missing_fields = []
|
91 |
+
|
92 |
+
if not start_location:
|
93 |
+
missing_fields.append("Start Location")
|
94 |
+
|
95 |
+
if not end_location:
|
96 |
+
missing_fields.append("End Location")
|
97 |
+
|
98 |
+
if not attractions:
|
99 |
+
missing_fields.append("Attractions")
|
100 |
+
|
101 |
+
if not start_date:
|
102 |
+
missing_fields.append("Start Date")
|
103 |
+
|
104 |
+
if not end_date:
|
105 |
+
missing_fields.append("End Date")
|
106 |
+
|
107 |
+
if not accommodation:
|
108 |
+
missing_fields.append("Accommodation")
|
109 |
+
|
110 |
+
if not schedule:
|
111 |
+
missing_fields.append("Schedule")
|
112 |
+
|
113 |
+
if not transportation:
|
114 |
+
missing_fields.append("Transportation")
|
115 |
+
|
116 |
+
if not missing_fields:
|
117 |
+
user_details = {
|
118 |
+
'start_location': start_location,
|
119 |
+
'end_location': end_location,
|
120 |
+
'start_date': start_date,
|
121 |
+
'end_date': end_date,
|
122 |
+
'attractions': attractions,
|
123 |
+
'budget': budget,
|
124 |
+
'transportation': transportation,
|
125 |
+
'accommodation': accommodation,
|
126 |
+
'schedule': schedule
|
127 |
+
}
|
128 |
+
|
129 |
+
# Display user details in the Streamlit console
|
130 |
+
st.write("Start Location:", user_details['start_location'])
|
131 |
+
st.write("End Location:", user_details['end_location'])
|
132 |
+
st.write("Start Date:", user_details['start_date'])
|
133 |
+
st.write("End Date:", user_details['end_date'])
|
134 |
+
st.write("Attractions:", user_details['attractions'])
|
135 |
+
st.write("Budget:", user_details['budget'])
|
136 |
+
st.write("Transportation:", user_details['transportation'])
|
137 |
+
st.write("Accommodation:", user_details['accommodation'])
|
138 |
+
st.write("Schedule:", user_details['schedule'])
|
139 |
+
|
140 |
+
note = """Here is your personalized travel itinerary covering all the major locations you want to visit.
|
141 |
+
This map gives you a general overview of your travel route from start to finish.
|
142 |
+
For daily plans, please review the sequence of waypoints to ensure the best experience,
|
143 |
+
as the route may vary based on daily activities and traffic conditions."""
|
144 |
else:
|
145 |
+
st.error(f'Please fill out the following required fields: {", ".join(missing_fields)}')
|
templates/itinerary.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import (
|
2 |
+
ChatPromptTemplate,
|
3 |
+
SystemMessagePromptTemplate,
|
4 |
+
HumanMessagePromptTemplate,
|
5 |
+
)
|
6 |
+
|
7 |
+
class ItineraryTemplate:
|
8 |
+
def __init__(self):
|
9 |
+
self.system_template = """
|
10 |
+
You are a sophisticated AI travel guide. Your job is to create engaging and practical travel plans for users.
|
11 |
+
|
12 |
+
The user's travel details and preferences will be presented in a structured format, starting with four hashtags. Your task is to convert these details into a detailed travel itinerary that includes waypoints, activities, and suggestions tailored to the user's interests and constraints.
|
13 |
+
|
14 |
+
Focus on providing specific addresses or locations for each suggested activity.
|
15 |
+
|
16 |
+
Take into account the user's available time, budget, preferred transportation mode, accommodation type, and desired daily schedule intensity. Aim to construct an itinerary that is enjoyable, feasible, and aligns with the user's expectations.
|
17 |
+
|
18 |
+
Present the itinerary in a bulleted list format with clear indications of the start and end points, as well as the recommended mode of transit between locations.
|
19 |
+
|
20 |
+
If the user has not specified certain details, use your judgment to select appropriate options, ensuring to provide specific addresses. Your response should be the itinerary list exclusively.
|
21 |
+
"""
|
22 |
+
|
23 |
+
self.human_template = """
|
24 |
+
#### User's Travel Details ####
|
25 |
+
- Starting Location: {start_location}
|
26 |
+
- Destination: {end_location}
|
27 |
+
- Travel Dates: {start_date} to {end_date}
|
28 |
+
- Attractions of Interest: {attractions}
|
29 |
+
- Travel Budget Range: {budget}
|
30 |
+
- Preferred Transportation: {transportation}
|
31 |
+
- Accommodation Type: {accommodation}
|
32 |
+
- Desired Daily Schedule Intensity: {schedule}
|
33 |
+
"""
|
34 |
+
|
35 |
+
# Assuming SystemMessagePromptTemplate and HumanMessagePromptTemplate are classes
|
36 |
+
# defined elsewhere in your codebase that format these messages.
|
37 |
+
self.system_message_prompt = SystemMessagePromptTemplate.from_template(
|
38 |
+
self.system_template
|
39 |
+
)
|
40 |
+
self.human_message_prompt = HumanMessagePromptTemplate.from_template(
|
41 |
+
self.human_template,
|
42 |
+
input_variables=["start_location", "end_location", "start_date", "end_date",
|
43 |
+
"attractions", "budget", "transportation", "accommodation",
|
44 |
+
"schedule"]
|
45 |
+
)
|
46 |
+
|
47 |
+
self.chat_prompt = ChatPromptTemplate.from_messages(
|
48 |
+
[self.system_message_prompt, self.human_message_prompt]
|
49 |
+
)
|
50 |
+
|
51 |
+
def generate_prompt(self, user_details):
|
52 |
+
"""
|
53 |
+
Fill in the human template with the actual details from the user's input.
|
54 |
+
"""
|
55 |
+
filled_human_template = self.human_message_prompt.format(
|
56 |
+
start_location=user_details['start_location'],
|
57 |
+
end_location=user_details['end_location'],
|
58 |
+
start_date=user_details['start_date'].strftime('%Y-%m-%d'),
|
59 |
+
end_date=user_details['end_date'].strftime('%Y-%m-%d'),
|
60 |
+
attractions=", ".join(user_details['attractions']),
|
61 |
+
budget=f"{user_details['budget'][0]} to {user_details['budget'][1]}",
|
62 |
+
transportation=user_details['transportation'],
|
63 |
+
accommodation=user_details['accommodation'],
|
64 |
+
schedule=user_details['schedule']
|
65 |
+
)
|
66 |
+
return f"{self.system_message_prompt}\n{filled_human_template}"
|
templates/mapping.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import (
|
2 |
+
ChatPromptTemplate,
|
3 |
+
SystemMessagePromptTemplate,
|
4 |
+
HumanMessagePromptTemplate,
|
5 |
+
)
|
6 |
+
from langchain.output_parsers import PydanticOutputParser
|
7 |
+
from pydantic import BaseModel, Field
|
8 |
+
from typing import List
|
9 |
+
|
10 |
+
class Path(BaseModel):
|
11 |
+
start_location: str = Field(description="start location of trip")
|
12 |
+
end_location: str = Field(description="end location of trip")
|
13 |
+
waypoints: List[str] = Field(description="list of waypoints")
|
14 |
+
transit: str = Field(description="mode of transportation")
|
15 |
+
|
16 |
+
class MappingTemplate(object):
|
17 |
+
|
18 |
+
MAX_WAYPOINTS = 20 # This can be adjusted or made configurable.
|
19 |
+
|
20 |
+
def __init__(self):
|
21 |
+
self.system_template = """
|
22 |
+
You an agent who converts detailed travel plans into a simple list of locations.
|
23 |
+
|
24 |
+
The itinerary will be denoted by four hashtags. Convert it into
|
25 |
+
list of places that they should visit. Try to include the specific address of each location.
|
26 |
+
|
27 |
+
Your output should always contain the start and end point of the trip, and may also include a list
|
28 |
+
of waypoints. It should also include a mode of transit. The number of waypoints cannot exceed 20.
|
29 |
+
If you can't infer the mode of transit, make a best guess given the trip location.
|
30 |
+
|
31 |
+
For example:
|
32 |
+
|
33 |
+
####
|
34 |
+
Itinerary for a 2-day driving trip within London:
|
35 |
+
- Day 1:
|
36 |
+
- Start at Buckingham Palace (The Mall, London SW1A 1AA)
|
37 |
+
- Visit the Tower of London (Tower Hill, London EC3N 4AB)
|
38 |
+
- Explore the British Museum (Great Russell St, Bloomsbury, London WC1B 3DG)
|
39 |
+
- Enjoy shopping at Oxford Street (Oxford St, London W1C 1JN)
|
40 |
+
- End the day at Covent Garden (Covent Garden, London WC2E 8RF)
|
41 |
+
- Day 2:
|
42 |
+
- Start at Westminster Abbey (20 Deans Yd, Westminster, London SW1P 3PA)
|
43 |
+
- Visit the Churchill War Rooms (Clive Steps, King Charles St, London SW1A 2AQ)
|
44 |
+
- Explore the Natural History Museum (Cromwell Rd, Kensington, London SW7 5BD)
|
45 |
+
- End the trip at the Tower Bridge (Tower Bridge Rd, London SE1 2UP)
|
46 |
+
#####
|
47 |
+
|
48 |
+
Output:
|
49 |
+
Start: Buckingham Palace, The Mall, London SW1A 1AA
|
50 |
+
End: Tower Bridge, Tower Bridge Rd, London SE1 2UP
|
51 |
+
Waypoints: ["Tower of London, Tower Hill, London EC3N 4AB", "British Museum, Great Russell St, Bloomsbury, London WC1B 3DG", "Oxford St, London W1C 1JN", "Covent Garden, London WC2E 8RF","Westminster, London SW1A 0AA", "St. James's Park, London", "Natural History Museum, Cromwell Rd, Kensington, London SW7 5BD"]
|
52 |
+
Transit: driving
|
53 |
+
|
54 |
+
Transit can be only one of the following options: "driving", "train", "bus" or "flight".
|
55 |
+
|
56 |
+
{format_instructions}
|
57 |
+
"""
|
58 |
+
|
59 |
+
self.human_template = """
|
60 |
+
####{itinerary_suggestion}####
|
61 |
+
"""
|
62 |
+
|
63 |
+
self.parser = PydanticOutputParser(pydantic_object=Path)
|
64 |
+
|
65 |
+
self.system_message_prompt = SystemMessagePromptTemplate.from_template(
|
66 |
+
self.system_template,
|
67 |
+
partial_variables={
|
68 |
+
"format_instructions": self.parser.get_format_instructions()
|
69 |
+
},
|
70 |
+
)
|
71 |
+
self.human_message_prompt = HumanMessagePromptTemplate.from_template(
|
72 |
+
self.human_template, input_variables=["itinerary_suggestion"]
|
73 |
+
)
|
74 |
+
|
75 |
+
self.chat_prompt = ChatPromptTemplate.from_messages(
|
76 |
+
[self.system_message_prompt, self.human_message_prompt]
|
77 |
+
)
|
tests/agent/test_agent.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import unittest
|
2 |
-
from config import load_secrets
|
3 |
-
from agents.agent import Agent
|
|
|
4 |
|
5 |
DEBUG = True
|
6 |
|
@@ -23,7 +24,7 @@ class TestAgentMethods(unittest.TestCase):
|
|
23 |
debug=self.debug,
|
24 |
)
|
25 |
|
26 |
-
|
27 |
def test_validation_chain(self):
|
28 |
validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
|
29 |
|
@@ -60,6 +61,34 @@ class TestAgentMethods(unittest.TestCase):
|
|
60 |
q3_out = q3_res["validation_output"].dict()
|
61 |
self.assertEqual(q3_out["plan_is_valid"], "yes")
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
|
64 |
if __name__ == "__main__":
|
65 |
unittest.main()
|
|
|
1 |
import unittest
|
2 |
+
from config import load_secrets
|
3 |
+
from agents.agent import Agent
|
4 |
+
import datetime
|
5 |
|
6 |
DEBUG = True
|
7 |
|
|
|
24 |
debug=self.debug,
|
25 |
)
|
26 |
|
27 |
+
@unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
|
28 |
def test_validation_chain(self):
|
29 |
validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
|
30 |
|
|
|
61 |
q3_out = q3_res["validation_output"].dict()
|
62 |
self.assertEqual(q3_out["plan_is_valid"], "yes")
|
63 |
|
64 |
+
|
65 |
+
def test_generate_itinerary(self):
|
66 |
+
|
67 |
+
user_details = {
|
68 |
+
"start_location": "Berkeley, CA",
|
69 |
+
"end_location": "Seattle, WA",
|
70 |
+
"start_date": datetime.date(2023, 12, 10),
|
71 |
+
"end_date": datetime.date(2023, 12, 15),
|
72 |
+
"attractions": ["museums", "parks"],
|
73 |
+
"budget": "1500-3000 USD",
|
74 |
+
"transportation": "rental car, public Transport",
|
75 |
+
"accommodation": "hotels",
|
76 |
+
"schedule": "relaxed"
|
77 |
+
}
|
78 |
+
|
79 |
+
# Call the generate_itinerary method with the user_details
|
80 |
+
itinerary_result = self.agent.generate_itinerary(user_details)
|
81 |
+
|
82 |
+
itinerary_suggestion = itinerary_result["itinerary_suggestion"]
|
83 |
+
list_of_places = itinerary_result["list_of_places"]
|
84 |
+
validation_dict = itinerary_result["validation_dict"]
|
85 |
+
|
86 |
+
print("\nItinerary Suggestion Returned:\n", itinerary_suggestion)
|
87 |
+
print("\nList of Places Returned:\n", list_of_places)
|
88 |
+
|
89 |
+
# Assert that the itinerary contains expected keys or values.
|
90 |
+
# This depends on what `generate_itinerary` returns. For example:
|
91 |
+
self.assertIsNotNone(itinerary_suggestion)
|
92 |
|
93 |
if __name__ == "__main__":
|
94 |
unittest.main()
|