# standard modules
import os
import pickle
import datetime
from PIL import Image
from typing import List, Union
# useful modules ("pip install" is required)
import numpy as np
import streamlit as st
import pandas as pd
import googlemaps
import langchain
from langchain.globals import set_verbose
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage
# our defined modules
import utils.util_app as util_app
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator
from models.classifiers.general_classifier import GeneralClassifier
from models.route_explainer import RouteExplainer
# ---- general settings ----
SEED = 1234  # forwarded to the chat model via model_kwargs={"seed": SEED}

# Every tour artefact shares one path stem; each derived file adds a suffix.
TOUR_NAME = "static/kyoto_tour"
TOUR_PATH = f"{TOUR_NAME}.csv"                  # destinations, opening hours, stay, remarks
TOUR_LATLNG_PATH = f"{TOUR_NAME}_latlng.csv"    # same table plus cached lat/lng columns
TOUR_DISTMAT_PATH = f"{TOUR_NAME}_distmat.pkl"  # cached travel-time matrix (seconds)

EXPANDED = False  # initial open/closed state of the destination expanders
DEBUG = True      # enables LangChain debug/verbose output below
# App logo, decoded once at import time into a NumPy pixel array.
ROUTE_EXPLAINER_ICON = np.array(Image.open("static/route_explainer_icon.png"))

# for debug: turn on LangChain's global debug and verbose output
if DEBUG:
    # Use the `langchain.globals` setter, matching `set_verbose`, instead of
    # assigning the legacy `langchain.debug` module attribute directly.
    from langchain.globals import set_debug

    set_debug(True)
    set_verbose(True)
def _clock_to_seconds(clock):
    """Convert an ``"HH:MM"`` clock string to seconds since midnight.

    Only the first two ``:``-separated fields (hours, minutes) are used.
    """
    return sum(unit * int(part) for unit, part in zip((3600, 60), clock.split(":")))


def _load_tour_dataframe():
    """Return the tour table with ``lat``/``lng`` columns.

    Reads the cached ``TOUR_LATLNG_PATH`` when it exists. Otherwise it
    geocodes every destination of ``TOUR_PATH`` with the Google Maps API key
    stored in the session and writes the result to the cache.

    Raises:
        FileNotFoundError: the cache is missing and no API key was entered.
    """
    if os.path.isfile(TOUR_LATLNG_PATH):
        return pd.read_csv(TOUR_LATLNG_PATH)

    df_tour = pd.read_csv(TOUR_PATH)
    googleapi_key = st.session_state.googleapi_key
    if not googleapi_key:
        raise FileNotFoundError(
            f"{TOUR_LATLNG_PATH} not found and no Google Maps API key was entered; "
            f"cannot geocode the destinations listed in {TOUR_PATH}."
        )

    gmaps = googlemaps.Client(key=googleapi_key)
    lat_list, lng_list = [], []
    for destination in df_tour["destination"]:
        location = gmaps.geocode(destination)[0]["geometry"]["location"]
        lat_list.append(location["lat"])
        lng_list.append(location["lng"])
    df_tour["lat"] = lat_list
    df_tour["lng"] = lng_list
    # index=False: otherwise the row index is saved as an extra unnamed column
    # that comes back as "Unnamed: 0" when the cache is read.
    df_tour.to_csv(TOUR_LATLNG_PATH, index=False)
    return df_tour


def _load_distance_matrix(df_tour):
    """Return the pairwise driving-time matrix (seconds) between destinations.

    Reads the pickled ``TOUR_DISTMAT_PATH`` cache when it exists. Otherwise it
    queries the Google Maps distance-matrix API once per ordered pair of
    distinct destinations and pickles the resulting array.

    Raises:
        FileNotFoundError: the cache is missing and no API key was entered.
    """
    if os.path.isfile(TOUR_DISTMAT_PATH):
        # NOTE: pickle is only acceptable because this is the app's own cache;
        # never point TOUR_DISTMAT_PATH at a file from an untrusted source.
        with open(TOUR_DISTMAT_PATH, "rb") as f:
            return pickle.load(f)

    googleapi_key = st.session_state.googleapi_key
    if not googleapi_key:
        raise FileNotFoundError(
            f"{TOUR_DISTMAT_PATH} not found and no Google Maps API key was entered; "
            "cannot query driving times between destinations."
        )

    gmaps = googlemaps.Client(key=googleapi_key)
    destinations = list(df_tour["destination"])
    distmat = []
    for origin in destinations:
        distrow = []
        for dest in destinations:
            if origin == dest:
                distrow.append(0)
            else:
                dist_result = gmaps.distance_matrix(origin, dest, mode="driving")
                # duration value is in seconds
                distrow.append(dist_result["rows"][0]["elements"][0]["duration"]["value"])
        distmat.append(distrow)
    distmat = np.array(distmat)
    with open(TOUR_DISTMAT_PATH, "wb") as f:
        pickle.dump(distmat, f)
    return distmat


def _build_node_features(df_tour):
    """Build the solver's node features from the opening hours and stay durations.

    Time windows are ``[open, close]`` in seconds. They are shifted so that the
    first destination's opening time is t=0, and negative values are clipped
    to 0. Service time is the stay duration converted from hours to seconds.
    """
    time_windows = np.array([
        [_clock_to_seconds(open_clock), _clock_to_seconds(close_clock)]
        for open_clock, close_clock in zip(df_tour["open"], df_tour["close"])
    ])
    time_windows -= time_windows[0, 0]
    return {
        "time_window": time_windows.clip(0),
        "service_time": df_tour["stay_duration (h)"].to_numpy() * 3600,
    }


def _build_tour_list(df_tour):
    """Return one ``{"name", "latlng", "description"}`` dict per destination.

    ``description`` is HTML (``<br>`` line breaks), because it is rendered with
    ``unsafe_allow_html=True``.
    """
    return [
        {
            "name": df_tour["destination"][i],
            "latlng": (df_tour["lat"][i], df_tour["lng"][i]),
            "description": (
                f"Hours: {df_tour['open'][i]} - {df_tour['close'][i]}<br>"
                f"Duration of stay: {df_tour['stay_duration (h)'][i]}h<br>"
                f"Remarks: {df_tour['remarks'][i]}"
            ),
        }
        for i in df_tour.index
    ]


def load_tour_list():
    """Load the tour and its derived solver inputs into ``st.session_state``.

    Sets ``lat_mean``, ``lng_mean``, ``sw``/``ne`` (min/max lat-lng corners),
    ``df_tour``, ``node_feats``, ``dist_matrix`` and ``node_info``. It also
    sets ``tour_list`` once both cache files exist.

    Raises:
        FileNotFoundError: a cache file is missing and no Google Maps API key
            was entered to regenerate it.
    """
    df_tour = _load_tour_dataframe()

    # map centre and bounding-box corners
    st.session_state.lat_mean = np.mean(df_tour["lat"])
    st.session_state.lng_mean = np.mean(df_tour["lng"])
    st.session_state.sw = df_tour[["lat", "lng"]].min().tolist()
    st.session_state.ne = df_tour[["lat", "lng"]].max().tolist()
    st.session_state.df_tour = df_tour

    distmat = _load_distance_matrix(df_tour)

    st.session_state.node_feats = _build_node_features(df_tour)
    st.session_state.dist_matrix = distmat
    st.session_state.node_info = {
        "open": df_tour["open"],
        "close": df_tour["close"],
        "stay": df_tour["stay_duration (h)"],
    }

    if os.path.isfile(TOUR_DISTMAT_PATH) and os.path.isfile(TOUR_LATLNG_PATH):
        st.session_state.tour_list = _build_tour_list(df_tour)
def solve_vrp() -> None:
    """Solve the loaded tour as a TSPTW and store the result in the session.

    Writes ``routes``, ``labels`` (the classifier's output for those routes)
    and ``generated_actual_route``. Does nothing until ``node_feats`` and
    ``dist_matrix`` are present in ``st.session_state``.
    """
    state = st.session_state
    if "node_feats" not in state or "dist_matrix" not in state:
        return

    node_feats = state.node_feats
    dist_matrix = state.dist_matrix
    solver = GeneralSolver("tsptw", "ortools", scaling=False)
    classifier = GeneralClassifier("tsptw", "gt(ortools)")

    routes = solver.solve(node_feats=node_feats, dist_matrix=dist_matrix)
    # NOTE(review): 0 presumably selects the first route/vehicle — confirm
    # against GeneralClassifier.get_inputs.
    inputs = classifier.get_inputs(routes, 0, node_feats, dist_matrix)
    labels = classifier(inputs)

    state.routes = routes.copy()
    state.labels = labels.copy()
    state.generated_actual_route = True
#----------
# LLM
#----------
def load_route_explainer(llm_type: str) -> None:
    """Build a RouteExplainer around the selected OpenAI chat model.

    The explainer is stored in ``st.session_state.route_explainer``. Nothing
    happens while the OpenAI key field is empty.

    Args:
        llm_type: model name passed to ``ChatOpenAI(model=...)``.
    """
    if not st.session_state.openai_key:
        return

    # temperature 0 plus a fixed seed keeps answers as repeatable as possible
    llm = ChatOpenAI(
        model=llm_type,
        temperature=0,
        streaming=True,
        model_kwargs={"seed": SEED},
    )
    cf_solver = GeneralSolver("tsptw", "ortools", scaling=False)
    cf_generator = CFTourGenerator(cf_solver=cf_solver)
    classifier = GeneralClassifier("tsptw", "gt(ortools)")
    st.session_state.route_explainer = RouteExplainer(
        llm=llm,
        cf_generator=cf_generator,
        classifier=classifier,
    )
#----------
# UI
#----------
# css settings: page config stays ahead of every other st.* call, then the
# util_app style helpers are applied in a fixed order.
st.set_page_config(layout="wide")
for _apply_style in (
    util_app.apply_responsible_map_css,
    util_app.apply_centerize_icon_css,
    util_app.apply_red_code_css,
    util_app.apply_remove_sidebar_topspace,
):
    _apply_style()
#------------------
# side bar setting
#------------------
with st.sidebar:
#-------
# Title
#-------
icon_col, name_col = st.columns((1,10))
with icon_col:
util_app.apply_html('')
with name_col:
st.title("RouteExplainer")
#----------
# API keys
#----------
st.subheader("API keys")
openai_key_col1, openai_key_col2 = st.columns((1,10))
with openai_key_col1:
util_app.apply_html(' ')
with openai_key_col2:
openai_key = st.text_input(label="API keys",
key="openai_key",
placeholder="OpenAI API key",
type="password",
label_visibility="collapsed")
changed_key = openai_key == os.environ.get('OPENAI_API_KEY')
os.environ['OPENAI_API_KEY'] = openai_key
google_key_col1, google_key_col2 = st.columns((1, 10))
with google_key_col1:
util_app.apply_html(' ')
with google_key_col2:
st.text_input(label="GoogleMap API key",
key="googleapi_key",
placeholder="NOT required in this demo",
type="password",
label_visibility="collapsed")
#----------------
# Foundation LLM
#----------------
st.subheader("Foundation LLM")
llm_type = st.selectbox("LLM", ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo"], key="llm_type", label_visibility="collapsed")
#-----------
# Tour plan
#-----------
st.subheader("Tour plan")
col1, col2 = st.columns((2, 1))
with col1:
# Coming soon: "Taipei Tour (for PAKDD2024)"
tour_plan = st.selectbox("Tour plan", ["Kyoto Tour"], key="tour_type", label_visibility="collapsed")
with col2:
st.button("Generate", on_click=solve_vrp, use_container_width=True)
# list destinations
load_tour_list()
with st.container():
if "routes" in st.session_state: # rearranage destinations in the route order if a route was derivied
# re-ordered destinations
reordered_tour_list = [st.session_state.tour_list[i] for i in st.session_state.routes[0][:-1]] if "routes" in st.session_state else st.session_state.tour_list
arr_time = datetime.datetime.strptime(st.session_state.node_info["open"][0], "%H:%M")
for step in range(len(reordered_tour_list)):
curr = reordered_tour_list[step]
next = reordered_tour_list[step+1] if step != len(reordered_tour_list) - 1 else reordered_tour_list[0]
curr_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, curr["name"])
next_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, next["name"])
open_time = datetime.datetime.strptime(st.session_state.node_info["open"][curr_node_id], "%H:%M")
# destination info
dep_time = max(arr_time, open_time) + datetime.timedelta(hours=st.session_state.node_info["stay"][curr_node_id])
dep_time_str = dep_time.strftime("%H:%M")
arr_time_str = arr_time.strftime("%H:%M")
arr_dep = f"Arr {arr_time_str} - Dep {dep_time_str}" if step != 0 else f"⭐ Dep {dep_time_str}"
with st.expander(f"{arr_dep} | {curr['name']}", expanded=EXPANDED):
st.write(curr["description"], unsafe_allow_html=True)
# travel time
travel_time = st.session_state.dist_matrix[curr_node_id][next_node_id].item()
col1, col2, col3 = st.columns(3)
with col1:
st.markdown(f"
Generate
button to generate your initial route!"
if st.session_state.count == 0:
util_app.stream_words(greeding, prefix="