# standard modules import os import pickle import datetime from PIL import Image from typing import List, Union # useful modules ("pip install" is required) import numpy as np import streamlit as st import pandas as pd import googlemaps import langchain from langchain.globals import set_verbose from langchain.chat_models import ChatOpenAI from langchain.schema import HumanMessage, AIMessage # our defined modules import utils.util_app as util_app from models.solvers.general_solver import GeneralSolver from models.cf_generator import CFTourGenerator from models.classifiers.general_classifier import GeneralClassifier from models.route_explainer import RouteExplainer # general setting SEED = 1234 TOUR_NAME = "static/kyoto_tour" TOUR_PATH = TOUR_NAME + ".csv" TOUR_LATLNG_PATH = TOUR_NAME + "_latlng.csv" TOUR_DISTMAT_PATH = TOUR_NAME + "_distmat.pkl" EXPANDED = False DEBUG = True ROUTE_EXPLAINER_ICON = np.array(Image.open("static/route_explainer_icon.png")) # for debug if DEBUG: langchain.debug = True set_verbose(True) def load_tour_list(): # get lat/lng if os.path.isfile(TOUR_LATLNG_PATH): df_tour = pd.read_csv(TOUR_LATLNG_PATH) else: df_tour = pd.read_csv(TOUR_PATH) if googleapi_key := st.session_state.googleapi_key: gmaps = googlemaps.Client(key=googleapi_key) lat_list =[]; lng_list = [] for destination in df_tour["destination"]: geo_result = gmaps.geocode(destination) lat_list.append(geo_result[0]["geometry"]["location"]["lat"]) lng_list.append(geo_result[0]["geometry"]["location"]["lng"]) # add lat/lng df_tour["lat"] = lat_list df_tour["lng"] = lng_list df_tour.to_csv(TOUR_LATLNG_PATH) # get the central point st.session_state.lat_mean = np.mean(df_tour["lat"]) st.session_state.lng_mean = np.mean(df_tour["lng"]) st.session_state.sw = df_tour[["lat", "lng"]].min().tolist() st.session_state.ne = df_tour[["lat", "lng"]].max().tolist() st.session_state.df_tour = df_tour # get the distance matrix if os.path.isfile(TOUR_DISTMAT_PATH): with open(TOUR_DISTMAT_PATH, "rb") as f: distmat = pickle.load(f) else: if googleapi_key := st.session_state.googleapi_key: gmaps = googlemaps.Client(key=googleapi_key) distmat = [] for origin in df_tour["destination"]: distrow = [] for dest in df_tour["destination"]: if origin != dest: dist_result = gmaps.distance_matrix(origin, dest, mode="driving") distrow.append(dist_result["rows"][0]["elements"][0]["duration"]["value"]) # unit: seconds else: distrow.append(0) distmat.append(distrow) distmat = np.array(distmat) with open(TOUR_DISTMAT_PATH, "wb") as f: pickle.dump(distmat, f) # input features def convert_clock2seconds(clock): return sum([a*b for a, b in zip([3600, 60], map(int, clock.split(':')))]) time_windows = [] for i in range(len(df_tour)): time_windows.append([convert_clock2seconds(df_tour["open"][i]), convert_clock2seconds(df_tour["close"][i])]) time_windows = np.array(time_windows) time_windows -= time_windows[0, 0] node_feats = { "time_window": time_windows.clip(0), "service_time": df_tour["stay_duration (h)"].to_numpy() * 3600 } st.session_state.node_feats = node_feats st.session_state.dist_matrix = distmat st.session_state.node_info = { "open": df_tour["open"], "close": df_tour["close"], "stay": df_tour["stay_duration (h)"] } # tour list if os.path.isfile(TOUR_DISTMAT_PATH) & os.path.isfile(TOUR_LATLNG_PATH): st.session_state.tour_list = [] for i in range(len(df_tour)): st.session_state.tour_list.append({ "name": df_tour["destination"][i], "latlng": (df_tour["lat"][i], df_tour["lng"][i]), "description": f"Hours: {df_tour['open'][i]} - {df_tour['close'][i]}
Duration of stay: {df_tour['stay_duration (h)'][i]}h
Remarks: {df_tour['remarks'][i]}
" }) def solve_vrp() -> None: if ("node_feats" in st.session_state) and ("dist_matrix" in st.session_state): solver = GeneralSolver("tsptw", "ortools", scaling=False) classifier = GeneralClassifier("tsptw", "gt(ortools)") routes = solver.solve(node_feats=st.session_state.node_feats, dist_matrix=st.session_state.dist_matrix) inputs = classifier.get_inputs(routes, 0, st.session_state.node_feats, st.session_state.dist_matrix) labels = classifier(inputs) st.session_state.routes = routes.copy() st.session_state.labels = labels.copy() st.session_state.generated_actual_route = True #---------- # LLM #---------- def load_route_explainer(llm_type: str) -> None: if st.session_state.openai_key: # define llm llm = ChatOpenAI(model=llm_type, temperature=0, streaming=True, model_kwargs={"seed": SEED}) # model_kwargs={"stop": ["\n\n", "Human"]} # define RouteExplainer cf_generator = CFTourGenerator(cf_solver=GeneralSolver("tsptw", "ortools", scaling=False)) classifier = GeneralClassifier("tsptw", "gt(ortools)") st.session_state.route_explainer = RouteExplainer(llm=llm, cf_generator=cf_generator, classifier=classifier) #---------- # UI #---------- # css settings st.set_page_config(layout="wide") util_app.apply_responsible_map_css() util_app.apply_centerize_icon_css() util_app.apply_red_code_css() util_app.apply_remove_sidebar_topspace() #------------------ # side bar setting #------------------ with st.sidebar: #------- # Title #------- icon_col, name_col = st.columns((1,10)) with icon_col: util_app.apply_html('RouteExplainer') with name_col: st.title("RouteExplainer") #---------- # API keys #---------- st.subheader("API keys") openai_key_col1, openai_key_col2 = st.columns((1,10)) with openai_key_col1: util_app.apply_html(' OpenAI API ') with openai_key_col2: openai_key = st.text_input(label="API keys", key="openai_key", placeholder="OpenAI API key", type="password", label_visibility="collapsed") changed_key = openai_key == os.environ.get('OPENAI_API_KEY') os.environ['OPENAI_API_KEY'] = openai_key google_key_col1, google_key_col2 = st.columns((1, 10)) with google_key_col1: util_app.apply_html(' GoogleMap API ') with google_key_col2: st.text_input(label="GoogleMap API key", key="googleapi_key", placeholder="NOT required in this demo", type="password", label_visibility="collapsed") #---------------- # Foundation LLM #---------------- st.subheader("Foundation LLM") llm_type = st.selectbox("LLM", ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo"], key="llm_type", label_visibility="collapsed") #----------- # Tour plan #----------- st.subheader("Tour plan") col1, col2 = st.columns((2, 1)) with col1: # Comming soon: "Taipei Tour (for PAKDD2024)" tour_plan = st.selectbox("Tour plan", ["Kyoto Tour"], key="tour_type", label_visibility="collapsed") with col2: st.button("Generate", on_click=solve_vrp, use_container_width=True) # list destinations load_tour_list() with st.container(): if "routes" in st.session_state: # rearranage destinations in the route order if a route was derivied # re-ordered destinations reordered_tour_list = [st.session_state.tour_list[i] for i in st.session_state.routes[0][:-1]] if "routes" in st.session_state else st.session_state.tour_list arr_time = datetime.datetime.strptime(st.session_state.node_info["open"][0], "%H:%M") for step in range(len(reordered_tour_list)): curr = reordered_tour_list[step] next = reordered_tour_list[step+1] if step != len(reordered_tour_list) - 1 else reordered_tour_list[0] curr_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, curr["name"]) next_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, next["name"]) open_time = datetime.datetime.strptime(st.session_state.node_info["open"][curr_node_id], "%H:%M") # destination info dep_time = max(arr_time, open_time) + datetime.timedelta(hours=st.session_state.node_info["stay"][curr_node_id]) dep_time_str = dep_time.strftime("%H:%M") arr_time_str = arr_time.strftime("%H:%M") arr_dep = f"Arr {arr_time_str} - Dep {dep_time_str}" if step != 0 else f"⭐ Dep {dep_time_str}" with st.expander(f"{arr_dep} | {curr['name']}", expanded=EXPANDED): st.write(curr["description"], unsafe_allow_html=True) # travel time travel_time = st.session_state.dist_matrix[curr_node_id][next_node_id].item() col1, col2, col3 = st.columns(3) with col1: st.markdown(f"
{util_app.add_time_unit(travel_time)}
", unsafe_allow_html=True) with col2: st.markdown("
|
", unsafe_allow_html=True) st.write("") arr_time = dep_time + datetime.timedelta(seconds=travel_time) # return to the origin destination = reordered_tour_list[0] arr_time_str = arr_time.strftime("%H:%M") with st.expander(f"⭐ Arr {arr_time_str} | {destination['name']}", expanded=EXPANDED): st.write(destination["description"], unsafe_allow_html=True) else: # just list destinations for destination in st.session_state.tour_list: with st.expander(destination['name'], expanded=EXPANDED): st.write(destination["description"], unsafe_allow_html=True) #---------------------- # state initialization #---------------------- if "count" not in st.session_state: st.session_state.count = 0 if "chat_history" not in st.session_state: st.session_state.chat_history = [] if "generated_actual_route" not in st.session_state: st.session_state.generated_actual_route = False if "generated_cf_route" not in st.session_state: st.session_state.generated_cf_route = False if "curr_route" not in st.session_state: st.session_state.curr_route = "Actual Route" # once the CF route is selected, this will be "Current Route" if "flag_example" not in st.session_state: st.session_state.flag_example = False if "selected_example" not in st.session_state: st.session_state.selected_example = None if "close_chat" not in st.session_state: st.session_state.close_chat = False if "route_explainer" not in st.session_state or llm_type != st.session_state.curr_llm_type or changed_key: load_route_explainer(llm_type) st.session_state.curr_llm_type = llm_type #-------------------------------- # The following is the main page #-------------------------------- #---------- # Greeding #---------- if "routes" not in st.session_state: util_app.apply_html('
OpenAI API
') greeding = "Hi, I'm RouteExplainer :)
Choose a tour and hit the Generate button to generate your initial route!" if st.session_state.count == 0: util_app.stream_words(greeding, prefix="

", suffix="

", sleep_time=0.02) else: util_app.apply_html(f"

{greeding}

") #-------------- # chat history #-------------- def find_last_map(lst: List[Union[str, tuple]]) -> int: for i in range(len(lst) - 1, -1, -1): if isinstance(lst[i], tuple): return i return None last_map_idx = find_last_map(st.session_state.chat_history) for i, msg in enumerate(st.session_state.chat_history): if isinstance(msg, tuple): # if the history type is a tuple of maps map1, map2 = (0, 1) if i == last_map_idx else (2, 3) actual_route, cf_route = st.columns(2) if msg[map1] is not None: with actual_route: util_app.visualize_actual_route(msg[map1]) if msg[map2] is not None: with cf_route: util_app.visualize_cf_route(msg[map2]) else: # if the history type is string if isinstance(msg, AIMessage): st.chat_message(msg.type, avatar=ROUTE_EXPLAINER_ICON).write(msg.content) else: st.chat_message(msg.type).write(msg.content) # examples if "cf_routes" not in st.session_state and st.session_state.flag_example: def pickup_example(example: str): st.session_state.selected_example = example examples = [ "Why do we visit Ginkaku-ji Temple from Fushimi-Inari Shrine and why not Kiyomizu-dera Temple?", "What if we visit Kinkaku-ji directly from Kyoto Geishinkan, instead of Nijo-jo Castle?", "Why was the edge from Kinkaku-ji to Kiyomizu-dera selected and why not the edge from Kinkaku-ji to Hanamikoji Dori?" ] col1, col2, col3 = st.columns(3) with col1: st.button(examples[0], use_container_width=True, on_click=pickup_example, args=(examples[0], )) with col2: st.button(examples[1], use_container_width=True, on_click=pickup_example, args=(examples[1], )) with col3: st.button(examples[2], use_container_width=True, on_click=pickup_example, args=(examples[2], )) #---------- # chat box #---------- def answer(prompt: str): st.session_state.chat_history.append(HumanMessage(content=prompt)) st.chat_message("user").write(prompt) if os.environ.get('OPENAI_API_KEY') == "": error_msg = "An OpenAI API key has not been set yet :( Please enter a valid key in the side bar!" with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON): st.write(error_msg) st.session_state.chat_history.append(AIMessage(content=error_msg)) elif util_app.validate_openai_api_key(os.environ.get('OPENAI_API_KEY')): with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON): explanation = st.session_state.route_explainer.generate_explanation(tour_list=st.session_state.tour_list, whynot_question=prompt, actual_routes=st.session_state.routes, actual_labels=st.session_state.labels, node_feats=st.session_state.node_feats, dist_matrix=st.session_state.dist_matrix) if len(explanation) > 0: st.session_state.chat_history.append(AIMessage(content=explanation)) if st.session_state.generated_cf_route: st.rerun() else: error_msg = "The input OpenAI API key appears to be invalid :( Please enter a valid key again in the side bar!" with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON): st.write(error_msg) st.session_state.chat_history.append(AIMessage(content=error_msg)) if st.session_state.selected_example is not None: example = st.session_state.selected_example st.session_state.selected_example = None answer(example) else: if "routes" in st.session_state and not st.session_state.close_chat: if prompt := st.chat_input(placeholder="Ask a why-not question", key="chat_input"): answer(prompt) #--------------------- # route visualization #--------------------- if "tour_list" in st.session_state: # if tour info is loaded # first message if st.session_state.generated_actual_route: # when an actual route is generated with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON): msg = "Here is your initial route. Please ask me a why and why-not question for a specfic edge!" util_app.stream_words(msg, sleep_time=0.01) st.session_state.flag_example = True st.session_state.chat_history.append(AIMessage(content=msg)) # visualize the actual & CF routes actual_route, cf_route = st.columns(2) m = None; m2 = None; m_ = None; m2_ = None if st.session_state.generated_actual_route or st.session_state.generated_cf_route: m = util_app.initialize_map() # overwrite m m_ = util_app.initialize_map() if "labels" in st.session_state: cf_step = st.session_state.cf_step-1 if st.session_state.generated_cf_route else -1 util_app.vis_route("routes", st.session_state.labels, m, cf_step, "actual") util_app.vis_route("routes", st.session_state.labels, m_, cf_step, "actual", ant_path=False) with actual_route: util_app.visualize_actual_route(m) if st.session_state.generated_cf_route: m2 = util_app.initialize_map() # overwrite m2 m2_ = util_app.initialize_map() if "cf_labels" in st.session_state: util_app.vis_route("cf_routes", st.session_state.cf_labels, m2, st.session_state.cf_step-1, "cf") util_app.vis_route("cf_routes", st.session_state.cf_labels, m2_, st.session_state.cf_step-1, "cf", ant_path=False) with cf_route: util_app.visualize_cf_route(m2) # update states related to maps if m is not None: st.session_state.chat_history.append((m, m2, m_, m2_)) # route selection button if len(st.session_state.chat_history) > 0: last_msg = st.session_state.chat_history[-1] if isinstance(last_msg, tuple): if (last_msg[0] is not None) and (last_msg[1] is not None): col1, col2 = st.columns(2) with col1: st.button("Stay this route", on_click=util_app.select_actual_route) with col2: st.button("Replace with this route", on_click=util_app.select_cf_route) util_app.change_hover_color("button", "Replace with this route", "#1e90ff") # for displaying examples if st.session_state.generated_actual_route: st.session_state.generated_actual_route = False st.rerun() st.session_state.generated_actual_route = False st.session_state.generated_cf_route = False # update session count st.session_state.count += 1 js = f""" """ st.components.v1.html(js, height=0, width=0)