Akj2023 commited on
Commit
d137f7e
·
1 Parent(s): 14ef9cf

Improve UI | Add chains for itinerary with map JSON

Browse files
agents/agent.py CHANGED
@@ -1,9 +1,14 @@
1
  import openai
2
  import logging
3
  import time
4
- from templates.validation import ValidationTemplate # Adjust the import path as necessary
 
 
5
  from langchain.chat_models import ChatOpenAI
6
  from langchain.chains import LLMChain, SequentialChain
 
 
 
7
 
8
  logging.basicConfig(level=logging.INFO)
9
 
@@ -25,10 +30,16 @@ class Agent:
25
 
26
  # Initialize ChatOpenAI with the provided OpenAI API key and model details
27
  self.chat_model = ChatOpenAI(model=model, temperature=temperature, openai_api_key=self._openai_key)
28
- # Initialize the ValidationTemplate
 
29
  self.validation_prompt = ValidationTemplate()
 
 
 
 
30
  # Setup the validation chain using the LLMChain and SequentialChain
31
  self.validation_chain = self._set_up_validation_chain(debug)
 
32
 
33
  def _set_up_validation_chain(self, debug=True):
34
  # Make validation agent chain using LLMChain
@@ -49,6 +60,38 @@ class Agent:
49
  )
50
 
51
  return overall_chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def validate_travel(self, query):
54
  self.logger.info("Validating query: %s", query)
@@ -60,7 +103,6 @@ class Agent:
60
  )
61
  )
62
 
63
-
64
  # Call the validation chain with the query and format instructions
65
  validation_result = self.validation_chain.run(
66
  {
@@ -69,9 +111,126 @@ class Agent:
69
  }
70
  )
71
 
72
- # Extract the result from the validation output
73
- validation_output = validation_result["validation_output"].dict()
 
 
 
 
 
 
 
 
 
 
74
  t2 = time.time()
75
  self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)
76
 
77
- return validation_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import openai
2
  import logging
3
  import time
4
+ from templates.validation import ValidationTemplate
5
+ from templates.itinerary import ItineraryTemplate
6
+ from templates.mapping import MappingTemplate
7
  from langchain.chat_models import ChatOpenAI
8
  from langchain.chains import LLMChain, SequentialChain
9
+ from datetime import datetime
10
+ from datetime import date
11
+ from dateutil.relativedelta import relativedelta
12
 
13
  logging.basicConfig(level=logging.INFO)
14
 
 
30
 
31
  # Initialize ChatOpenAI with the provided OpenAI API key and model details
32
  self.chat_model = ChatOpenAI(model=model, temperature=temperature, openai_api_key=self._openai_key)
33
+
34
+ # Initialize the ValidationTemplate and ItineraryTemplate
35
  self.validation_prompt = ValidationTemplate()
36
+ self.itinerary_prompt = ItineraryTemplate()
37
+ self.mapping_prompt = MappingTemplate()
38
+
39
+
40
  # Setup the validation chain using the LLMChain and SequentialChain
41
  self.validation_chain = self._set_up_validation_chain(debug)
42
+ self.itinerary_chain = self._set_up_itinerary_chain(debug)
43
 
44
  def _set_up_validation_chain(self, debug=True):
45
  # Make validation agent chain using LLMChain
 
60
  )
61
 
62
  return overall_chain
63
+
64
+ def _set_up_itinerary_chain(self, debug=True):
65
+
66
+ # set up LLMChain to get the itinerary as a string
67
+ itinerary_agent = LLMChain(
68
+ llm=self.chat_model,
69
+ prompt=self.itinerary_prompt.chat_prompt,
70
+ verbose=debug,
71
+ output_key="itinerary_suggestion",
72
+ )
73
+
74
+ # set up LLMChain to extract the waypoints as a JSON object
75
+ mapping_agent = LLMChain(
76
+ llm=self.chat_model,
77
+ prompt=self.mapping_prompt.chat_prompt,
78
+ output_parser=self.mapping_prompt.parser,
79
+ verbose=debug,
80
+ output_key="mapping_list",
81
+ )
82
+
83
+ # overall chain allows us to call the travel_agent and parser in
84
+ # sequence, with labelled outputs.
85
+ overall_chain = SequentialChain(
86
+ chains=[itinerary_agent, mapping_agent],
87
+ input_variables=["start_location", "end_location", "start_date", "end_date",
88
+ "attractions", "budget", "transportation", "accommodation",
89
+ "schedule", "format_instructions"],
90
+ output_variables=["itinerary_suggestion","mapping_list"],
91
+ verbose=debug,
92
+ )
93
+
94
+ return overall_chain
95
 
96
  def validate_travel(self, query):
97
  self.logger.info("Validating query: %s", query)
 
103
  )
104
  )
105
 
 
106
  # Call the validation chain with the query and format instructions
107
  validation_result = self.validation_chain.run(
108
  {
 
111
  }
112
  )
113
 
114
+ self.logger.info("Datatype of validation_result: %s", type(validation_result))
115
+
116
+ # Convert the validation result into a dictionary if it's not one already
117
+ if isinstance(validation_result, dict):
118
+ validation_dict = validation_result
119
+ else: # assuming validation_result is an instance of the Validation class
120
+ validation_dict = validation_result.dict()
121
+
122
+ # Log the datatype and content of the validation output
123
+ self.logger.info("Datatype of validation_dict: %s", type(validation_dict))
124
+ self.logger.info("Content of validation_dict: %s", validation_dict)
125
+
126
  t2 = time.time()
127
  self.logger.debug("Time to validate request: %.2f seconds", t2 - t1)
128
 
129
+ return validation_dict
130
+
131
+
132
+ def calculate_duration(self, start_date, end_date):
133
+ if not isinstance(start_date, date) or not isinstance(end_date, date):
134
+ raise ValueError("start_date and end_date must be datetime.date objects")
135
+
136
+ if end_date < start_date:
137
+ raise ValueError("End date must be after or equal to start date")
138
+
139
+ # Calculate the duration using relativedelta
140
+ delta = relativedelta(end_date, start_date)
141
+
142
+ years = delta.years
143
+ months = delta.months
144
+ days = delta.days + 1 # We'll calculate weeks from days
145
+
146
+ duration_parts = []
147
+
148
+ if years > 0:
149
+ duration_parts.append(f"{years} year{'s' if years > 1 else ''}")
150
+
151
+ if months > 0:
152
+ duration_parts.append(f"{months} month{'s' if months > 1 else ''}")
153
+
154
+ weeks = days // 7
155
+ days = days % 7
156
+
157
+ if weeks > 0:
158
+ duration_parts.append(f"{weeks} week{'s' if weeks > 1 else ''}")
159
+
160
+ if days > 0:
161
+ duration_parts.append(f"{days} day{'s' if days > 1 else ''}")
162
+
163
+ return ', '.join(duration_parts)
164
+
165
+
166
+ def generate_itinerary(self, user_details):
167
+ self.logger.info("Generating itinerary for user details: %s", user_details)
168
+
169
+ # Validate the user details dictionary keys match the expected input variables
170
+ expected_keys = ["start_location", "end_location", "start_date", "end_date",
171
+ "attractions", "budget", "transportation", "accommodation",
172
+ "schedule"]
173
+ for key in expected_keys:
174
+ if key not in user_details:
175
+ self.logger.error("Missing '%s' in user details.", key)
176
+ return None # or handle the missing key appropriately
177
+
178
+ try:
179
+ # Calculate trip duration
180
+ trip_duration = self.calculate_duration(user_details['start_date'], user_details['end_date'])
181
+ # Construct the query phrase
182
+ query_phrase = "{} trip from {} to {}".format(trip_duration, user_details['start_location'], user_details['end_location'])
183
+ except KeyError as e:
184
+ self.logger.error("Missing key in user details: %s", e)
185
+ return None # or handle the missing key appropriately
186
+ except ValueError as e:
187
+ self.logger.error(e)
188
+ return None # or handle the error appropriately
189
+
190
+ t1 = time.time()
191
+
192
+ self.logger.info("Calling itinerary chain to validate user query")
193
+ validation_dict = self.validate_travel(query_phrase)
194
+ is_plan_valid = validation_dict["plan_is_valid"]
195
+
196
+
197
+ if is_plan_valid.lower() == "no":
198
+ self.logger.warning("User request was not valid!")
199
+ print("\n######\n Travel plan is not valid \n######\n")
200
+ print(validation_result["updated_request"])
201
+
202
+ # Create a dictionary with variable names as keys
203
+ result_dict = {
204
+ "itinerary_suggestion": None,
205
+ "list_of_places": None,
206
+ "validation_dict": validation_dict
207
+ }
208
+
209
+ return result_dict
210
+
211
+ self.logger.info("User query is valid. Calling itinerary chain on user details")
212
+
213
+ itinerary_details = user_details.copy()
214
+ itinerary_details["format_instructions"] = self.mapping_prompt.parser.get_format_instructions()
215
+
216
+ # Call the itinerary chain with the itinerary details
217
+ itinerary_result = self.itinerary_chain(itinerary_details)
218
+
219
+ itinerary_suggestion = itinerary_result["itinerary_suggestion"]
220
+ list_of_places = itinerary_result["mapping_list"].dict()
221
+
222
+ # Log the datatype and content of the list_of_places output
223
+ self.logger.info("Datatype of validation_dict: %s", type(list_of_places))
224
+ self.logger.info("Content of validation_dict: %s", list_of_places)
225
+
226
+ t2 = time.time()
227
+ self.logger.debug("Time to generate itinerary: %.2f seconds", t2 - t1)
228
+
229
+ # Create a dictionary with variable names as keys
230
+ result_dict = {
231
+ "itinerary_suggestion": itinerary_suggestion,
232
+ "list_of_places": list_of_places,
233
+ "validation_dict": validation_dict
234
+ }
235
+
236
+ return result_dict
app.py CHANGED
@@ -87,16 +87,59 @@ with st.sidebar:
87
  # Main page layout
88
  st.header('Your Itinerary')
89
  if submit:
90
- if start_location and end_location and attractions and start_date and end_date:
91
- # The function to generate the itinerary would go here.
92
- # The following lines are placeholders to show the captured inputs.
93
- st.write('From:', start_location)
94
- st.write('To:', end_location)
95
- st.write('Travel Dates:', st.session_state['start_date'], 'to', st.session_state['end_date'])
96
- st.write('Attractions:', attractions)
97
- st.write('Budget:', budget)
98
- st.write('Transportation:', transportation)
99
- st.write('Accommodation:', accommodation)
100
- st.write('Daily Schedule:', schedule)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  else:
102
- st.error('Please fill out all required fields.')
 
87
  # Main page layout
88
  st.header('Your Itinerary')
89
  if submit:
90
+ missing_fields = []
91
+
92
+ if not start_location:
93
+ missing_fields.append("Start Location")
94
+
95
+ if not end_location:
96
+ missing_fields.append("End Location")
97
+
98
+ if not attractions:
99
+ missing_fields.append("Attractions")
100
+
101
+ if not start_date:
102
+ missing_fields.append("Start Date")
103
+
104
+ if not end_date:
105
+ missing_fields.append("End Date")
106
+
107
+ if not accommodation:
108
+ missing_fields.append("Accommodation")
109
+
110
+ if not schedule:
111
+ missing_fields.append("Schedule")
112
+
113
+ if not transportation:
114
+ missing_fields.append("Transportation")
115
+
116
+ if not missing_fields:
117
+ user_details = {
118
+ 'start_location': start_location,
119
+ 'end_location': end_location,
120
+ 'start_date': start_date,
121
+ 'end_date': end_date,
122
+ 'attractions': attractions,
123
+ 'budget': budget,
124
+ 'transportation': transportation,
125
+ 'accommodation': accommodation,
126
+ 'schedule': schedule
127
+ }
128
+
129
+ # Display user details in the Streamlit console
130
+ st.write("Start Location:", user_details['start_location'])
131
+ st.write("End Location:", user_details['end_location'])
132
+ st.write("Start Date:", user_details['start_date'])
133
+ st.write("End Date:", user_details['end_date'])
134
+ st.write("Attractions:", user_details['attractions'])
135
+ st.write("Budget:", user_details['budget'])
136
+ st.write("Transportation:", user_details['transportation'])
137
+ st.write("Accommodation:", user_details['accommodation'])
138
+ st.write("Schedule:", user_details['schedule'])
139
+
140
+ note = """Here is your personalized travel itinerary covering all the major locations you want to visit.
141
+ This map gives you a general overview of your travel route from start to finish.
142
+ For daily plans, please review the sequence of waypoints to ensure the best experience,
143
+ as the route may vary based on daily activities and traffic conditions."""
144
  else:
145
+ st.error(f'Please fill out the following required fields: {", ".join(missing_fields)}')
templates/itinerary.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import (
2
+ ChatPromptTemplate,
3
+ SystemMessagePromptTemplate,
4
+ HumanMessagePromptTemplate,
5
+ )
6
+
7
+ class ItineraryTemplate:
8
+ def __init__(self):
9
+ self.system_template = """
10
+ You are a sophisticated AI travel guide. Your job is to create engaging and practical travel plans for users.
11
+
12
+ The user's travel details and preferences will be presented in a structured format, starting with four hashtags. Your task is to convert these details into a detailed travel itinerary that includes waypoints, activities, and suggestions tailored to the user's interests and constraints.
13
+
14
+ Focus on providing specific addresses or locations for each suggested activity.
15
+
16
+ Take into account the user's available time, budget, preferred transportation mode, accommodation type, and desired daily schedule intensity. Aim to construct an itinerary that is enjoyable, feasible, and aligns with the user's expectations.
17
+
18
+ Present the itinerary in a bulleted list format with clear indications of the start and end points, as well as the recommended mode of transit between locations.
19
+
20
+ If the user has not specified certain details, use your judgment to select appropriate options, ensuring to provide specific addresses. Your response should be the itinerary list exclusively.
21
+ """
22
+
23
+ self.human_template = """
24
+ #### User's Travel Details ####
25
+ - Starting Location: {start_location}
26
+ - Destination: {end_location}
27
+ - Travel Dates: {start_date} to {end_date}
28
+ - Attractions of Interest: {attractions}
29
+ - Travel Budget Range: {budget}
30
+ - Preferred Transportation: {transportation}
31
+ - Accommodation Type: {accommodation}
32
+ - Desired Daily Schedule Intensity: {schedule}
33
+ """
34
+
35
+ # Assuming SystemMessagePromptTemplate and HumanMessagePromptTemplate are classes
36
+ # defined elsewhere in your codebase that format these messages.
37
+ self.system_message_prompt = SystemMessagePromptTemplate.from_template(
38
+ self.system_template
39
+ )
40
+ self.human_message_prompt = HumanMessagePromptTemplate.from_template(
41
+ self.human_template,
42
+ input_variables=["start_location", "end_location", "start_date", "end_date",
43
+ "attractions", "budget", "transportation", "accommodation",
44
+ "schedule"]
45
+ )
46
+
47
+ self.chat_prompt = ChatPromptTemplate.from_messages(
48
+ [self.system_message_prompt, self.human_message_prompt]
49
+ )
50
+
51
+ def generate_prompt(self, user_details):
52
+ """
53
+ Fill in the human template with the actual details from the user's input.
54
+ """
55
+ filled_human_template = self.human_message_prompt.format(
56
+ start_location=user_details['start_location'],
57
+ end_location=user_details['end_location'],
58
+ start_date=user_details['start_date'].strftime('%Y-%m-%d'),
59
+ end_date=user_details['end_date'].strftime('%Y-%m-%d'),
60
+ attractions=", ".join(user_details['attractions']),
61
+ budget=f"{user_details['budget'][0]} to {user_details['budget'][1]}",
62
+ transportation=user_details['transportation'],
63
+ accommodation=user_details['accommodation'],
64
+ schedule=user_details['schedule']
65
+ )
66
+ return f"{self.system_message_prompt}\n{filled_human_template}"
templates/mapping.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.prompts import (
2
+ ChatPromptTemplate,
3
+ SystemMessagePromptTemplate,
4
+ HumanMessagePromptTemplate,
5
+ )
6
+ from langchain.output_parsers import PydanticOutputParser
7
+ from pydantic import BaseModel, Field
8
+ from typing import List
9
+
10
+ class Path(BaseModel):
11
+ start_location: str = Field(description="start location of trip")
12
+ end_location: str = Field(description="end location of trip")
13
+ waypoints: List[str] = Field(description="list of waypoints")
14
+ transit: str = Field(description="mode of transportation")
15
+
16
+ class MappingTemplate(object):
17
+
18
+ MAX_WAYPOINTS = 20 # This can be adjusted or made configurable.
19
+
20
+ def __init__(self):
21
+ self.system_template = """
22
+ You an agent who converts detailed travel plans into a simple list of locations.
23
+
24
+ The itinerary will be denoted by four hashtags. Convert it into
25
+ list of places that they should visit. Try to include the specific address of each location.
26
+
27
+ Your output should always contain the start and end point of the trip, and may also include a list
28
+ of waypoints. It should also include a mode of transit. The number of waypoints cannot exceed 20.
29
+ If you can't infer the mode of transit, make a best guess given the trip location.
30
+
31
+ For example:
32
+
33
+ ####
34
+ Itinerary for a 2-day driving trip within London:
35
+ - Day 1:
36
+ - Start at Buckingham Palace (The Mall, London SW1A 1AA)
37
+ - Visit the Tower of London (Tower Hill, London EC3N 4AB)
38
+ - Explore the British Museum (Great Russell St, Bloomsbury, London WC1B 3DG)
39
+ - Enjoy shopping at Oxford Street (Oxford St, London W1C 1JN)
40
+ - End the day at Covent Garden (Covent Garden, London WC2E 8RF)
41
+ - Day 2:
42
+ - Start at Westminster Abbey (20 Deans Yd, Westminster, London SW1P 3PA)
43
+ - Visit the Churchill War Rooms (Clive Steps, King Charles St, London SW1A 2AQ)
44
+ - Explore the Natural History Museum (Cromwell Rd, Kensington, London SW7 5BD)
45
+ - End the trip at the Tower Bridge (Tower Bridge Rd, London SE1 2UP)
46
+ #####
47
+
48
+ Output:
49
+ Start: Buckingham Palace, The Mall, London SW1A 1AA
50
+ End: Tower Bridge, Tower Bridge Rd, London SE1 2UP
51
+ Waypoints: ["Tower of London, Tower Hill, London EC3N 4AB", "British Museum, Great Russell St, Bloomsbury, London WC1B 3DG", "Oxford St, London W1C 1JN", "Covent Garden, London WC2E 8RF","Westminster, London SW1A 0AA", "St. James's Park, London", "Natural History Museum, Cromwell Rd, Kensington, London SW7 5BD"]
52
+ Transit: driving
53
+
54
+ Transit can be only one of the following options: "driving", "train", "bus" or "flight".
55
+
56
+ {format_instructions}
57
+ """
58
+
59
+ self.human_template = """
60
+ ####{itinerary_suggestion}####
61
+ """
62
+
63
+ self.parser = PydanticOutputParser(pydantic_object=Path)
64
+
65
+ self.system_message_prompt = SystemMessagePromptTemplate.from_template(
66
+ self.system_template,
67
+ partial_variables={
68
+ "format_instructions": self.parser.get_format_instructions()
69
+ },
70
+ )
71
+ self.human_message_prompt = HumanMessagePromptTemplate.from_template(
72
+ self.human_template, input_variables=["itinerary_suggestion"]
73
+ )
74
+
75
+ self.chat_prompt = ChatPromptTemplate.from_messages(
76
+ [self.system_message_prompt, self.human_message_prompt]
77
+ )
tests/agent/test_agent.py CHANGED
@@ -1,6 +1,7 @@
1
  import unittest
2
- from config import load_secrets # Update this path to match the actual location
3
- from agents.agent import Agent # Update this path to match the actual location
 
4
 
5
  DEBUG = True
6
 
@@ -23,7 +24,7 @@ class TestAgentMethods(unittest.TestCase):
23
  debug=self.debug,
24
  )
25
 
26
- # @unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
27
  def test_validation_chain(self):
28
  validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
29
 
@@ -60,6 +61,34 @@ class TestAgentMethods(unittest.TestCase):
60
  q3_out = q3_res["validation_output"].dict()
61
  self.assertEqual(q3_out["plan_is_valid"], "yes")
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  if __name__ == "__main__":
65
  unittest.main()
 
1
  import unittest
2
+ from config import load_secrets
3
+ from agents.agent import Agent
4
+ import datetime
5
 
6
  DEBUG = True
7
 
 
24
  debug=self.debug,
25
  )
26
 
27
+ @unittest.skipIf(DEBUG, "Skipping this test while debugging other tests")
28
  def test_validation_chain(self):
29
  validation_chain = self.agent._set_up_validation_chain(debug=self.debug)
30
 
 
61
  q3_out = q3_res["validation_output"].dict()
62
  self.assertEqual(q3_out["plan_is_valid"], "yes")
63
 
64
+
65
+ def test_generate_itinerary(self):
66
+
67
+ user_details = {
68
+ "start_location": "Berkeley, CA",
69
+ "end_location": "Seattle, WA",
70
+ "start_date": datetime.date(2023, 12, 10),
71
+ "end_date": datetime.date(2023, 12, 15),
72
+ "attractions": ["museums", "parks"],
73
+ "budget": "1500-3000 USD",
74
+ "transportation": "rental car, public Transport",
75
+ "accommodation": "hotels",
76
+ "schedule": "relaxed"
77
+ }
78
+
79
+ # Call the generate_itinerary method with the user_details
80
+ itinerary_result = self.agent.generate_itinerary(user_details)
81
+
82
+ itinerary_suggestion = itinerary_result["itinerary_suggestion"]
83
+ list_of_places = itinerary_result["list_of_places"]
84
+ validation_dict = itinerary_result["validation_dict"]
85
+
86
+ print("\nItinerary Suggestion Returned:\n", itinerary_suggestion)
87
+ print("\nList of Places Returned:\n", list_of_places)
88
+
89
+ # Assert that the itinerary contains expected keys or values.
90
+ # This depends on what `generate_itinerary` returns. For example:
91
+ self.assertIsNotNone(itinerary_suggestion)
92
 
93
  if __name__ == "__main__":
94
  unittest.main()