route-explainer / app.py
daisuke.kikuta
first commit
719d0db
raw
history blame
20.6 kB
# standard modules
import os
import pickle
import datetime
from PIL import Image
from typing import List, Union
# useful modules ("pip install" is required)
import numpy as np
import streamlit as st
import pandas as pd
import googlemaps
import langchain
from langchain.globals import set_verbose
from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage, AIMessage
# our defined modules
import utils.util_app as util_app
from models.solvers.general_solver import GeneralSolver
from models.cf_generator import CFTourGenerator
from models.classifiers.general_classifier import GeneralClassifier
from models.route_explainer import RouteExplainer
# general settings
SEED = 1234  # passed to the chat model as its "seed" model kwarg
TOUR_NAME = "static/kyoto_tour"
TOUR_PATH = f"{TOUR_NAME}.csv"  # raw tour table
TOUR_LATLNG_PATH = f"{TOUR_NAME}_latlng.csv"  # tour table with lat/lng columns
TOUR_DISTMAT_PATH = f"{TOUR_NAME}_distmat.pkl"  # pickled travel-time matrix
EXPANDED = False  # initial state of the destination expanders
DEBUG = True
ROUTE_EXPLAINER_ICON = np.array(Image.open("static/route_explainer_icon.png"))

# for debug: turn on LangChain's verbose and debug output
if DEBUG:
    set_verbose(True)
    langchain.debug = True
def load_tour_list():
    """Load the tour data and publish the derived inputs in ``st.session_state``.

    Coordinates are read from ``TOUR_LATLNG_PATH`` and travel times from
    ``TOUR_DISTMAT_PATH`` when those cache files exist. A missing cache is
    rebuilt through the Google Maps client, using the key stored in
    ``st.session_state.googleapi_key``, and written back to disk.

    Session-state entries written: ``lat_mean``, ``lng_mean``, ``sw``, ``ne``,
    ``df_tour``, ``node_feats``, ``dist_matrix``, ``node_info`` and, once both
    cache files exist, ``tour_list``.

    Raises:
        RuntimeError: If coordinates or travel times are unavailable because
            the cache is missing and no Google Maps API key was entered.
    """
    # get lat/lng
    if os.path.isfile(TOUR_LATLNG_PATH):
        df_tour = pd.read_csv(TOUR_LATLNG_PATH)
    else:
        df_tour = pd.read_csv(TOUR_PATH)
        if googleapi_key := st.session_state.googleapi_key:
            gmaps = googlemaps.Client(key=googleapi_key)
            locations = [
                gmaps.geocode(destination)[0]["geometry"]["location"]
                for destination in df_tour["destination"]
            ]
            # add lat/lng
            df_tour["lat"] = [location["lat"] for location in locations]
            df_tour["lng"] = [location["lng"] for location in locations]
            df_tour.to_csv(TOUR_LATLNG_PATH)
    if not {"lat", "lng"}.issubset(df_tour.columns):
        raise RuntimeError(
            f"No coordinates available: '{TOUR_LATLNG_PATH}' is missing and "
            "no GoogleMap API key was provided to geocode the destinations."
        )

    # get the central point and the bounding box
    st.session_state.lat_mean = np.mean(df_tour["lat"])
    st.session_state.lng_mean = np.mean(df_tour["lng"])
    st.session_state.sw = df_tour[["lat", "lng"]].min().tolist()
    st.session_state.ne = df_tour[["lat", "lng"]].max().tolist()
    st.session_state.df_tour = df_tour

    # get the distance matrix
    if os.path.isfile(TOUR_DISTMAT_PATH):
        # NOTE: pickle.load can execute arbitrary code; this file must only
        # ever be a cache produced by this app, never user-supplied data.
        with open(TOUR_DISTMAT_PATH, "rb") as f:
            distmat = pickle.load(f)
    elif googleapi_key := st.session_state.googleapi_key:
        gmaps = googlemaps.Client(key=googleapi_key)
        distmat = []
        for origin in df_tour["destination"]:
            distrow = []
            for dest in df_tour["destination"]:
                if origin == dest:
                    distrow.append(0)
                    continue
                dist_result = gmaps.distance_matrix(origin, dest, mode="driving")
                distrow.append(dist_result["rows"][0]["elements"][0]["duration"]["value"])  # unit: seconds
            distmat.append(distrow)
        distmat = np.array(distmat)
        with open(TOUR_DISTMAT_PATH, "wb") as f:
            pickle.dump(distmat, f)
    else:
        raise RuntimeError(
            f"No travel times available: '{TOUR_DISTMAT_PATH}' is missing and "
            "no GoogleMap API key was provided to query the distance matrix."
        )

    # input features
    def convert_clock2seconds(clock):
        # "H:M" -> seconds; any field after the minutes is ignored by zip
        return sum(unit * int(part) for unit, part in zip((3600, 60), clock.split(':')))

    time_windows = np.array([
        [convert_clock2seconds(opening), convert_clock2seconds(closing)]
        for opening, closing in zip(df_tour["open"], df_tour["close"])
    ])
    # express every window relative to the opening time of the first row
    time_windows -= time_windows[0, 0]
    node_feats = {
        "time_window": time_windows.clip(0),
        "service_time": df_tour["stay_duration (h)"].to_numpy() * 3600  # hours -> seconds
    }
    st.session_state.node_feats = node_feats
    st.session_state.dist_matrix = distmat
    st.session_state.node_info = {
        "open": df_tour["open"],
        "close": df_tour["close"],
        "stay": df_tour["stay_duration (h)"]
    }

    # tour list (only published once both cache files exist on disk)
    if os.path.isfile(TOUR_DISTMAT_PATH) and os.path.isfile(TOUR_LATLNG_PATH):
        st.session_state.tour_list = [
            {
                "name": df_tour["destination"][i],
                "latlng": (df_tour["lat"][i], df_tour["lng"][i]),
                "description": f"<font color='silver'>Hours: {df_tour['open'][i]} - {df_tour['close'][i]}<br>Duration of stay: {df_tour['stay_duration (h)'][i]}h<br>Remarks: {df_tour['remarks'][i]}</font>"
            }
            for i in range(len(df_tour))
        ]
def solve_vrp() -> None:
    """Solve the tour and store the resulting routes and edge labels.

    Does nothing unless both ``node_feats`` and ``dist_matrix`` are present in
    the session state. Otherwise it writes copies of the solver's routes and
    the classifier's labels to ``st.session_state.routes`` / ``.labels`` and
    sets ``generated_actual_route`` to True.
    """
    state = st.session_state
    if "node_feats" not in state or "dist_matrix" not in state:
        return

    solver = GeneralSolver("tsptw", "ortools", scaling=False)
    classifier = GeneralClassifier("tsptw", "gt(ortools)")

    node_feats = state.node_feats
    dist_matrix = state.dist_matrix
    routes = solver.solve(node_feats=node_feats, dist_matrix=dist_matrix)
    classifier_inputs = classifier.get_inputs(routes, 0, node_feats, dist_matrix)
    labels = classifier(classifier_inputs)

    state.routes = routes.copy()
    state.labels = labels.copy()
    state.generated_actual_route = True
#----------
# LLM
#----------
def load_route_explainer(llm_type: str) -> None:
    """Build a ``RouteExplainer`` and store it in the session state.

    Nothing is created while ``st.session_state.openai_key`` is empty.

    Args:
        llm_type: Model name handed to ``ChatOpenAI``.
    """
    if not st.session_state.openai_key:
        return

    # define llm
    # (unused alternative kept from the original: model_kwargs={"stop": ["\n\n", "Human"]})
    chat_model = ChatOpenAI(
        model=llm_type,
        temperature=0,
        streaming=True,
        model_kwargs={"seed": SEED},
    )

    # define RouteExplainer
    cf_solver = GeneralSolver("tsptw", "ortools", scaling=False)
    counterfactual_generator = CFTourGenerator(cf_solver=cf_solver)
    edge_classifier = GeneralClassifier("tsptw", "gt(ortools)")
    st.session_state.route_explainer = RouteExplainer(
        llm=chat_model,
        cf_generator=counterfactual_generator,
        classifier=edge_classifier,
    )
#----------
# UI
#----------
# css settings
st.set_page_config(layout="wide")
# apply the app's CSS tweaks in a fixed order
for apply_css in (
    util_app.apply_responsible_map_css,
    util_app.apply_centerize_icon_css,
    util_app.apply_red_code_css,
    util_app.apply_remove_sidebar_topspace,
):
    apply_css()
#------------------
# side bar setting
#------------------
with st.sidebar:
    #-------
    # Title
    #-------
    icon_col, name_col = st.columns((1,10))
    with icon_col:
        util_app.apply_html('<a href="https://ntt-dkiku.github.io/xai-vrp" target="_blank"><img src="./app/static/route_explainer_icon.png" alt="RouteExplainer" width="30" height="30" style="margin-top: 20px;"></a>')
    with name_col:
        st.title("RouteExplainer")

    #----------
    # API keys
    #----------
    st.subheader("API keys")
    openai_key_col1, openai_key_col2 = st.columns((1,10))
    with openai_key_col1:
        util_app.apply_html('<a href="https://openai.com/blog/openai-api" target="_blank"> <img src="./app/static/openai_logo.png" alt="OpenAI API" width="30" height="30"> </a>')
    with openai_key_col2:
        openai_key = st.text_input(label="API keys",
                                   key="openai_key",
                                   placeholder="OpenAI API key",
                                   type="password",
                                   label_visibility="collapsed")
    # True only when the entered key differs from the one already exported;
    # the comparison must happen before the environment variable is overwritten.
    # `changed_key` later triggers a rebuild of the route explainer.
    changed_key = openai_key != os.environ.get('OPENAI_API_KEY')
    os.environ['OPENAI_API_KEY'] = openai_key

    google_key_col1, google_key_col2 = st.columns((1, 10))
    with google_key_col1:
        util_app.apply_html('<a href="https://developers.google.com/maps?hl=en" target="_blank"> <img src="./app/static/googlemap_logo.png" alt="GoogleMap API" width="30" height="30"> </a>')
    with google_key_col2:
        st.text_input(label="GoogleMap API key",
                      key="googleapi_key",
                      placeholder="NOT required in this demo",
                      type="password",
                      label_visibility="collapsed")

    #----------------
    # Foundation LLM
    #----------------
    st.subheader("Foundation LLM")
    llm_type = st.selectbox("LLM", ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo"], key="llm_type", label_visibility="collapsed")

    #-----------
    # Tour plan
    #-----------
    st.subheader("Tour plan")
    col1, col2 = st.columns((2, 1))
    with col1:
        # Coming soon: "Taipei Tour (for PAKDD2024)"
        tour_plan = st.selectbox("Tour plan", ["Kyoto Tour"], key="tour_type", label_visibility="collapsed")
    with col2:
        st.button("Generate", on_click=solve_vrp, use_container_width=True)

    # list destinations
    load_tour_list()
    with st.container():
        if "routes" in st.session_state:  # rearrange destinations in the route order if a route was derived
            # Re-ordered destinations. The route's final node is dropped here;
            # the return to the origin is rendered separately after the loop.
            reordered_tour_list = [st.session_state.tour_list[i] for i in st.session_state.routes[0][:-1]]
            num_stops = len(reordered_tour_list)
            arr_time = datetime.datetime.strptime(st.session_state.node_info["open"][0], "%H:%M")
            for step, curr in enumerate(reordered_tour_list):
                # the stop after the last one wraps around to the first
                next_stop = reordered_tour_list[(step + 1) % num_stops]
                curr_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, curr["name"])
                next_node_id = util_app.find_node_id_by_name(st.session_state.tour_list, next_stop["name"])
                open_time = datetime.datetime.strptime(st.session_state.node_info["open"][curr_node_id], "%H:%M")

                # destination info: the stay starts at arrival or opening, whichever is later
                # float() keeps timedelta working when the column holds NumPy integers
                stay_hours = float(st.session_state.node_info["stay"][curr_node_id])
                dep_time = max(arr_time, open_time) + datetime.timedelta(hours=stay_hours)
                dep_time_str = dep_time.strftime("%H:%M")
                arr_time_str = arr_time.strftime("%H:%M")
                arr_dep = f"Arr {arr_time_str} - Dep {dep_time_str}" if step != 0 else f"⭐ Dep {dep_time_str}"
                with st.expander(f"{arr_dep} | {curr['name']}", expanded=EXPANDED):
                    st.write(curr["description"], unsafe_allow_html=True)

                # travel time to the next stop (the matrix is built from durations in seconds)
                travel_time = st.session_state.dist_matrix[curr_node_id][next_node_id].item()
                time_col, bar_col, _ = st.columns(3)
                with time_col:
                    st.markdown(f"<center>{util_app.add_time_unit(travel_time)}</center>", unsafe_allow_html=True)
                with bar_col:
                    st.markdown("<center>|</center>", unsafe_allow_html=True)
                # NOTE(review): spacer assumed to sit below the column row — confirm placement
                st.write("")
                arr_time = dep_time + datetime.timedelta(seconds=travel_time)

            # return to the origin
            destination = reordered_tour_list[0]
            arr_time_str = arr_time.strftime("%H:%M")
            with st.expander(f"⭐ Arr {arr_time_str} | {destination['name']}", expanded=EXPANDED):
                st.write(destination["description"], unsafe_allow_html=True)
        else:  # just list destinations
            for destination in st.session_state.tour_list:
                with st.expander(destination['name'], expanded=EXPANDED):
                    st.write(destination["description"], unsafe_allow_html=True)
#----------------------
# state initialization
#----------------------
# Defaults for the session flags; a value is only written when the key is
# still absent, so values from earlier reruns are preserved.
session_defaults = {
    "count": 0,
    "chat_history": [],
    "generated_actual_route": False,
    "generated_cf_route": False,
    "curr_route": "Actual Route",  # once the CF route is selected, this will be "Current Route"
    "flag_example": False,
    "selected_example": None,
    "close_chat": False,
}
for state_key, default_value in session_defaults.items():
    if state_key not in st.session_state:
        st.session_state[state_key] = default_value

# (Re)build the explainer when none exists yet, the model choice changed, or
# the sidebar flagged the API key via `changed_key`.
needs_route_explainer = (
    "route_explainer" not in st.session_state
    or llm_type != st.session_state.curr_llm_type
    or changed_key
)
if needs_route_explainer:
    load_route_explainer(llm_type)
    st.session_state.curr_llm_type = llm_type
#--------------------------------
# The following is the main page
#--------------------------------
#----------
# Greeting
#----------
# Shown only until a route has been generated.
if "routes" not in st.session_state:
    util_app.apply_html('<center> <img src="./app/static/route_explainer_icon.png" alt="OpenAI API" width="120" height="120"> </center>')
    greeting = "Hi, I'm RouteExplainer :)<br>Choose a tour and hit the <code>Generate</code> button to generate your initial route!"
    is_first_run = st.session_state.count == 0
    if is_first_run:
        # stream the greeting word by word on the very first script run
        util_app.stream_words(greeting, prefix="<center><h4>", suffix="</h4></center>", sleep_time=0.02)
    else:
        util_app.apply_html(f"<center><h4>{greeting}</h4></center>")
#--------------
# chat history
#--------------
def find_last_map(lst: List[Union[str, tuple]]) -> Union[int, None]:
    """Return the index of the last tuple entry in ``lst``.

    Args:
        lst: Chat history whose map entries are tuples and whose other
            entries are non-tuple objects.

    Returns:
        The position of the right-most tuple, or ``None`` when ``lst``
        contains no tuple (including when it is empty).
    """
    for index in range(len(lst) - 1, -1, -1):
        if isinstance(lst[index], tuple):
            return index
    return None
# Replay the stored conversation. Tuple entries hold four maps as appended by
# the route-visualization step; the most recent tuple shows slots 0/1, older
# ones show slots 2/3 (the maps drawn with ant_path=False).
last_map_idx = find_last_map(st.session_state.chat_history)
for position, entry in enumerate(st.session_state.chat_history):
    if not isinstance(entry, tuple):
        # plain chat message: only AI messages get the custom avatar
        if isinstance(entry, AIMessage):
            bubble = st.chat_message(entry.type, avatar=ROUTE_EXPLAINER_ICON)
        else:
            bubble = st.chat_message(entry.type)
        bubble.write(entry.content)
        continue

    actual_slot, cf_slot = (0, 1) if position == last_map_idx else (2, 3)
    actual_route, cf_route = st.columns(2)
    if entry[actual_slot] is not None:
        with actual_route:
            util_app.visualize_actual_route(entry[actual_slot])
    if entry[cf_slot] is not None:
        with cf_route:
            util_app.visualize_cf_route(entry[cf_slot])
# examples
# Offer canned questions until a counterfactual route exists.
if "cf_routes" not in st.session_state and st.session_state.flag_example:
    def pickup_example(example: str):
        """Remember the clicked example so the chat box answers it on the next run."""
        st.session_state.selected_example = example

    examples = [
        "Why do we visit Ginkaku-ji Temple from Fushimi-Inari Shrine and why not Kiyomizu-dera Temple?",
        "What if we visit Kinkaku-ji directly from Kyoto Geishinkan, instead of Nijo-jo Castle?",
        "Why was the edge from Kinkaku-ji to Kiyomizu-dera selected and why not the edge from Kinkaku-ji to Hanamikoji Dori?"
    ]
    # one button per example, each in its own column
    for example_col, question in zip(st.columns(3), examples):
        with example_col:
            st.button(question,
                      use_container_width=True,
                      on_click=pickup_example,
                      args=(question, ))
#----------
# chat box
#----------
def answer(prompt: str):
    """Record the user's question and reply with an explanation or an error.

    The question is appended to the chat history and echoed first. If the
    exported OpenAI key is empty or fails validation, an error message is
    shown and stored instead of an explanation.

    Args:
        prompt: The why-not question typed or picked by the user.
    """
    st.session_state.chat_history.append(HumanMessage(content=prompt))
    st.chat_message("user").write(prompt)

    def reply_with_error(error_msg: str) -> None:
        # show the message as the assistant and keep it in the history
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            st.write(error_msg)
        st.session_state.chat_history.append(AIMessage(content=error_msg))

    api_key = os.environ.get('OPENAI_API_KEY')
    if api_key == "":
        reply_with_error("An OpenAI API key has not been set yet :( Please enter a valid key in the side bar!")
        return
    if not util_app.validate_openai_api_key(api_key):
        reply_with_error("The input OpenAI API key appears to be invalid :( Please enter a valid key again in the side bar!")
        return

    with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
        explanation = st.session_state.route_explainer.generate_explanation(
            tour_list=st.session_state.tour_list,
            whynot_question=prompt,
            actual_routes=st.session_state.routes,
            actual_labels=st.session_state.labels,
            node_feats=st.session_state.node_feats,
            dist_matrix=st.session_state.dist_matrix,
        )
        if len(explanation) > 0:
            st.session_state.chat_history.append(AIMessage(content=explanation))
        # NOTE(review): rerun check assumed to be a sibling of the length check — confirm nesting
        if st.session_state.generated_cf_route:
            st.rerun()
# A clicked example takes priority over the chat input and is consumed once.
if st.session_state.selected_example is not None:
    pending_example = st.session_state.selected_example
    st.session_state.selected_example = None
    answer(pending_example)
elif "routes" in st.session_state and not st.session_state.close_chat:
    # the input box only exists once a route was generated and the chat is open
    prompt = st.chat_input(placeholder="Ask a why-not question", key="chat_input")
    if prompt:
        answer(prompt)
#---------------------
# route visualization
#---------------------
if "tour_list" in st.session_state:  # tour info has been loaded
    # first message, shown right after an actual route was generated
    if st.session_state.generated_actual_route:
        msg = "Here is your initial route. Please ask me a why and why-not question for a specfic edge!"
        with st.chat_message("assistant", avatar=ROUTE_EXPLAINER_ICON):
            util_app.stream_words(msg, sleep_time=0.01)
        st.session_state.flag_example = True
        st.session_state.chat_history.append(AIMessage(content=msg))

    # visualize the actual & CF routes side by side
    actual_route, cf_route = st.columns(2)
    # m / m2: maps shown now; m_ / m2_: second copies drawn with ant_path=False
    m = m2 = m_ = m2_ = None

    if st.session_state.generated_actual_route or st.session_state.generated_cf_route:
        m = util_app.initialize_map()
        m_ = util_app.initialize_map()
        if "labels" in st.session_state:
            if st.session_state.generated_cf_route:
                cf_step = st.session_state.cf_step - 1
            else:
                cf_step = -1
            util_app.vis_route("routes", st.session_state.labels, m, cf_step, "actual")
            util_app.vis_route("routes", st.session_state.labels, m_, cf_step, "actual", ant_path=False)
        with actual_route:
            util_app.visualize_actual_route(m)

    if st.session_state.generated_cf_route:
        m2 = util_app.initialize_map()
        m2_ = util_app.initialize_map()
        if "cf_labels" in st.session_state:
            cf_index = st.session_state.cf_step - 1
            util_app.vis_route("cf_routes", st.session_state.cf_labels, m2, cf_index, "cf")
            util_app.vis_route("cf_routes", st.session_state.cf_labels, m2_, cf_index, "cf", ant_path=False)
        with cf_route:
            util_app.visualize_cf_route(m2)

    # update states related to maps: keep all four so the history can replay them
    if m is not None:
        st.session_state.chat_history.append((m, m2, m_, m2_))
# route selection button
# Offer the keep/replace choice only when the newest history entry is a map
# tuple that carries both an actual and a counterfactual map.
history = st.session_state.chat_history
if len(history) > 0:
    last_msg = history[-1]
    has_both_maps = (
        isinstance(last_msg, tuple)
        and last_msg[0] is not None
        and last_msg[1] is not None
    )
    if has_both_maps:
        keep_col, replace_col = st.columns(2)
        with keep_col:
            st.button("Stay this route", on_click=util_app.select_actual_route)
        with replace_col:
            st.button("Replace with this route", on_click=util_app.select_cf_route)
        util_app.change_hover_color("button", "Replace with this route", "#1e90ff")
# for displaying examples
# for displaying examples: rerun once right after the initial route was drawn
if st.session_state.generated_actual_route:
    st.session_state.generated_actual_route = False
    st.rerun()

# clear the one-shot generation flags for the next run
for one_shot_flag in ("generated_actual_route", "generated_cf_route"):
    st.session_state[one_shot_flag] = False

# update session count
st.session_state.count += 1

# Scroll the main section to the bottom; the history length is passed as an
# argument (named dummy_var_to_force_repeat_execution in the script).
js = f"""
<script>
function scroll(dummy_var_to_force_repeat_execution){{
var textAreas = parent.document.querySelectorAll('section.main');
for (let index = 0; index < textAreas.length; index++) {{
textAreas[index].scrollTop = textAreas[index].scrollHeight;
}}
}}
scroll({len(st.session_state.chat_history)})
</script>
"""
st.components.v1.html(js, height=0, width=0)