route-explainer / utils /util_app.py
daisuke.kikuta
first commit
719d0db
raw
history blame
11.8 kB
import openai
import time
import base64
import copy
import streamlit as st
import folium
from folium.plugins import FloatImage
from streamlit_folium import folium_static
from langchain.schema import AIMessage
#----------------
# util functions
#----------------
def apply_html(html: str, **kwargs) -> None:
    """Render a raw HTML snippet (e.g. a ``<style>`` block) via ``st.markdown``.

    HTML rendering is always enabled (``unsafe_allow_html=True``); any extra
    keyword arguments are forwarded to ``st.markdown`` unchanged.
    """
    render = st.markdown
    render(html, unsafe_allow_html=True, **kwargs)
def stream_words(words: str,
                 prefix: str = "",
                 suffix: str = "",
                 sleep_time: float = 0.02) -> None:
    """Type out *words* one element at a time in a single Streamlit placeholder.

    After each element is appended, the placeholder is redrawn as
    ``"{prefix} {shown} {suffix}"`` (HTML enabled), then the function sleeps
    for *sleep_time* seconds.
    """
    placeholder = st.empty()
    shown = ""
    for piece in words:
        shown = shown + piece
        placeholder.markdown(f"{prefix} {shown} {suffix}", unsafe_allow_html=True)
        time.sleep(sleep_time)
def add_time_unit(time: float) -> str:
    """Format a duration given in seconds as a human-readable string.

    Returns:
        ``"<time> sec"`` for durations under one minute, ``"<m> mins"``
        (two decimals) for durations under one hour, and ``"<h> hours"``
        (two decimals) otherwise.
    """
    # NOTE: the parameter name shadows the ``time`` module inside this
    # function; it is kept so keyword callers (``time=...``) still work.
    if time < 60:
        return f"{time} sec"
    # Hour boundary is 3600 s; the previous ``< 3599`` sent 3599-3600 s to "hours".
    if time < 3600:
        return f"{time/60:.2f} mins"
    return f"{time/3600:.2f} hours"
def find_node_id_by_name(data_list, name) -> int:
    """Return the index of the first entry whose ``"name"`` equals *name*.

    Entries must support ``.get`` (e.g. dicts). Returns ``-1`` when no entry
    matches.
    """
    matches = (idx for idx, entry in enumerate(data_list) if entry.get("name") == name)
    return next(matches, -1)
def mark_destination(tour_list, m) -> None:
    """Add destination icons, a start/end star marker, and the legend to *m*.

    Args:
        tour_list: sequence of destination dicts, each with ``"name"`` and
            ``"latlng"`` (read as ``latlng[0]``, ``latlng[1]``). The first
            entry is marked as the start/end point.
        m: the ``folium.Map`` to draw on.

    Each destination icon is loaded from
    ``static/<name with spaces replaced by '-'>.png``.
    """
    # Per-destination (width, height) icon sizes; anything else uses 50x50.
    # NOTE(review): the original if/elif chain tested "Ryoanji Temple" twice, so
    # its second branch (50x40) could never run. It was probably meant for a
    # different destination -- confirm the intended name and add it here.
    icon_sizes = {
        "Ryoanji Temple": (50, 35),
        "Kyoto Geishinkan": (40, 40),
        "Nijo-jo Castle": (45, 35),
    }
    for destination in tour_list:
        name = destination["name"]
        folium.Marker(
            location=[destination["latlng"][0], destination["latlng"][1]],
            tooltip=name,
            icon=folium.features.CustomIcon(
                icon_image=f"static/{name.replace(' ', '-')}.png",
                icon_size=icon_sizes.get(name, (50, 50)),
                icon_anchor=(30, 30),
                popup_anchor=(3, 3),
            ),
        ).add_to(m)

    # Add a star marking the start/end point (tour_list[0], Kyoto Station per the
    # original comment), shifted slightly -- presumably so it does not sit exactly
    # on that destination's icon. A large z_index_offset keeps it on top.
    # z_index_offset: https://github.com/python-visualization/folium/issues/1281
    if tour_list:  # an empty tour previously raised IndexError here
        start = tour_list[0]
        folium.Marker(
            location=[start["latlng"][0] + 0.003, start["latlng"][1] - 0.003],
            tooltip="start/end point",
            icon=folium.features.CustomIcon(
                icon_image="static/star_emoji.png",
                icon_size=(30, 30),
                icon_anchor=(30, 30),
                popup_anchor=(3, 3),
            ),
            z_index_offset=10000,
        ).add_to(m)

    # Add the legend: embed the PNG inline as a base64 data URI and place it
    # near the bottom-left of the map.
    # Ref: https://python-visualization.github.io/folium/latest/user_guide/plugins/float_image.html
    with open("static/legend.png", "rb") as legend_file:
        b64_content = base64.b64encode(legend_file.read()).decode("utf-8")
    FloatImage(f"data:image/png;base64,{b64_content}", bottom=1, left=1).add_to(m)
def initialize_map() -> folium.Map:
    """Build a fresh map with all tour destinations marked.

    The map is centred on ``lat_mean``/``lng_mean`` from session state, uses
    the "Cartodb Positron" tiles, and is fitted to the ``sw``/``ne`` bounds.
    Markers and the legend come from ``mark_destination(tour_list, ...)``.
    """
    state = st.session_state
    center = [state.lat_mean, state.lng_mean]
    fmap = folium.Map(location=center, tiles="Cartodb Positron")
    fmap.fit_bounds([state.sw, state.ne])
    mark_destination(state.tour_list, fmap)
    return fmap
def vis_route(routes, labels, m, ex_step, route_type, ant_path=True) -> None:
    """Draw every vehicle's route stored under ``st.session_state[routes]`` on *m*.

    Args:
        routes: session-state key of the routes; each route is a sequence of
            indices into ``st.session_state.tour_list``.
        labels: per-vehicle, per-edge labels. ``0`` colours the edge as
            "Route length priority"; any other value as "Time window priority".
        m: folium map to draw on.
        ex_step: edge index to highlight in every route (callers in this file
            pass ``-1`` when nothing should be highlighted).
        route_type: ``"actual"`` highlights in red ("Actual edge"); any other
            value highlights in blue ("CF edge").
        ant_path: if True, overlay an ``AntPath`` on each edge; otherwise
            decorate the edge polyline with repeated ``"> "`` text.

    Does nothing unless both ``"tour_list"`` and *routes* are in session state.
    """
    state = st.session_state
    if "tour_list" not in state or routes not in state:
        return
    tour_list = state.tour_list
    for vehicle, route in enumerate(state[routes]):
        # Walk consecutive (origin, destination) pairs, i.e. the route's edges.
        for step, (origin_id, dest_id) in enumerate(zip(route, route[1:])):
            if step == ex_step:
                if route_type == "actual":
                    color, popup = "red", "Actual edge"
                else:
                    color, popup = "blue", "CF edge"
            elif labels[vehicle][step] == 0:
                color, popup = "#2ca02c", "Route length priority"
            else:
                color, popup = "#9467bd", "Time window priority"

            start_latlng = tour_list[origin_id]["latlng"]
            end_latlng = tour_list[dest_id]["latlng"]
            edge_line = folium.PolyLine([start_latlng, end_latlng], color=color).add_to(m)
            if not ant_path:
                folium.plugins.PolyLineTextPath(
                    edge_line, "> ", offset=11, repeat=True,
                    attributes={"fill": color, "font-size": 30},
                ).add_to(m)
            else:
                folium.plugins.AntPath(
                    [start_latlng, end_latlng], tooltip=popup, color=color,
                ).add_to(m)
def visualize_actual_route(m: folium.Map) -> None:
    """Render *m* in the middle of a three-column layout.

    The layout is (0.2, 1, 0.1). The map sits under a red-divider subheader
    showing ``st.session_state.curr_route``.
    """
    _, center_col, _ = st.columns((0.2, 1, 0.1))
    with center_col:
        st.subheader(f"{st.session_state.curr_route}", divider="red")
        folium_static(m)
def visualize_cf_route(m: folium.Map) -> None:
    """Render *m* in the middle of a three-column layout.

    The layout is (0.1, 1, 0.2). The map sits under a blue-divider
    "CF route" subheader.
    """
    _, center_col, _ = st.columns((0.1, 1, 0.2))
    with center_col:
        st.subheader("CF route", divider="blue")
        folium_static(m)
def select_actual_route():
    """Record that the user kept their current route and post refreshed maps.

    A confirmation ``AIMessage`` is appended to ``chat_history``. Two fresh maps
    are then built: the first is drawn with AntPath edges, the second with
    arrow-text edges. Routes are drawn on them only when ``labels`` exists in
    session state, highlighting edge ``cf_step - 1`` if a CF route was
    generated. The maps are appended to ``chat_history`` as
    ``(map, None, map, None)``, and ``close_chat`` is set to False.
    """
    state = st.session_state
    if state.curr_route == "Actual Route":
        route_name = "the actual route"
    else:
        route_name = "your current route"
    msg = f"You chose to stay {route_name}. Feel free to ask another why and why-not question for your current route!"
    state.chat_history.append(AIMessage(content=msg))

    ant_map = initialize_map()
    arrow_map = initialize_map()
    if "labels" in state:
        highlight = state.cf_step-1 if state.generated_cf_route else -1
        for target, animated in ((ant_map, True), (arrow_map, False)):
            vis_route("routes", state.labels, target, highlight, "actual", ant_path=animated)
    state.chat_history.append((ant_map, None, arrow_map, None))
    state.close_chat = False
def select_cf_route():
    """Adopt the counterfactual (CF) route as the user's current route.

    Steps:
    - Append a confirmation ``AIMessage`` to ``chat_history``.
    - Deep-copy ``cf_routes``/``cf_labels`` over ``routes``/``labels`` and set
      ``curr_route`` to "Current Route".
    - Redraw the new routes on two fresh maps (AntPath edges and arrow-text
      edges), highlighting edge ``cf_step - 1`` if a CF route was generated.
    - Append the maps to ``chat_history`` as ``(map, None, map, None)`` and
      set ``close_chat`` to False.
    """
    state = st.session_state
    msg = "You chose to replace your current route with the CF route. Feel free to ask a why and why-not question for the CF route!"
    state.chat_history.append(AIMessage(content=msg))
    # replace the actual route & labels with the CF ones
    state.routes = copy.deepcopy(state.cf_routes)
    state.labels = copy.deepcopy(state.cf_labels)
    state.curr_route = "Current Route"
    m = initialize_map()
    m_ = initialize_map()
    # No existence check is needed: `labels` (what vis_route reads) was just
    # assigned, and a missing `cf_labels` would already have raised above. The
    # old `"cf_labels" in st.session_state` guard was therefore always true.
    cf_step = state.cf_step-1 if state.generated_cf_route else -1
    vis_route("routes", state.labels, m, cf_step, "actual")
    vis_route("routes", state.labels, m_, cf_step, "actual", ant_path=False)
    state.chat_history.append((m, None, m_, None))
    state.close_chat = False
# ref: https://stackoverflow.com/questions/76522693/how-to-check-the-validity-of-the-openai-key-from-python
def validate_openai_api_key(api_key):
    """Return True if *api_key* is accepted by the OpenAI API, else False.

    The check calls ``client.models.list()`` (a network request) and treats
    ``openai.AuthenticationError`` as an invalid key.

    An empty or ``None`` key is rejected immediately. Given ``api_key=None``,
    ``openai.OpenAI`` falls back to the ``OPENAI_API_KEY`` environment
    variable, so without this check the app could validate its own server key
    instead of the user's (or raise if that variable is unset).

    Other errors (e.g. network failures) propagate to the caller.
    """
    if not api_key:
        return False
    client = openai.OpenAI(api_key=api_key)
    try:
        client.models.list()
    except openai.AuthenticationError:
        return False
    return True
#-----
# CSS
#-----
# Makes visible Streamlit iframes (title containing the word "st.iframe", height
# not "0") full-width and square. This presumably targets the embedded folium
# maps. ("RESPONSIBLE" likely means "responsive"; the name is kept because
# other modules may import it.)
RESPONSIBLE_MAP = """\
<style>
[title~="st.iframe"]:not([height="0"]) {
width: 100%;
aspect-ratio: 1;
height: auto;
}
</style>
"""
def apply_responsible_map_css() -> None:
    """Inject the ``RESPONSIBLE_MAP`` stylesheet into the current page."""
    stylesheet = RESPONSIBLE_MAP
    apply_html(stylesheet)
# Vertically centres the contents of horizontal blocks (st.columns rows) and
# horizontally centres buttons. "INCON" is a typo for "ICON"; the name is kept
# for compatibility. Unlike the other snippets, this string starts with a
# newline (no trailing backslash), which is harmless.
CENTERIZE_INCON = """
<style>
[data-testid="stHorizontalBlock"] {
align-items: center;
}
.stButton {
text-align:center
}
</style>
"""
def apply_centerize_icon_css() -> None:
    """Inject the ``CENTERIZE_INCON`` stylesheet into the current page."""
    stylesheet = CENTERIZE_INCON
    apply_html(stylesheet)
# Renders every <code> element in red (#cc3333).
RED_CODE = """\
<style>
code {
color: #cc3333;
}
</style>
"""
def apply_red_code_css() -> None:
    """Inject the ``RED_CODE`` stylesheet into the current page."""
    stylesheet = RED_CODE
    apply_html(stylesheet)
# Pulls one element up by 75px and pushes another down by 75px. Going by the
# name, this removes the sidebar's top spacing.
# NOTE(review): the selectors look like auto-generated Emotion class names,
# which typically change between Streamlit releases -- verify after upgrading.
REMOVE_SIDEBAR_TOPSPACE = """\
<style>
.st-emotion-cache-6qob1r.eczjsme3 {
margin-top: -75px;
}
.st-emotion-cache-1b9x38r.eczjsme2 {
margin-top: 75px
}
</style>
"""
def apply_remove_sidebar_topspace() -> None:
    """Inject the ``REMOVE_SIDEBAR_TOPSPACE`` stylesheet into the current page."""
    stylesheet = REMOVE_SIDEBAR_TOPSPACE
    apply_html(stylesheet)
#----
# JS
#----
# SET_WORDS_TO_CHATBOX = """\
# <script>
# function insertText() {{
# var chatInput = parent.document.querySelector('textarea[data-testid="stChatInput"]');
# var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set;
# nativeInputValueSetter.call(chatInput, "{words}");
# var event = new Event('input', {{bubbles: true}});
# chatInput.dispatchEvent(event);
# }}
# insertText();
# </script>
# """
# def set_chat_input(words:str) -> None:
# st.components.v1.html(SET_WORDS_TO_CHATBOX.format(words=words), height=0)
# Ref. https://discuss.streamlit.io/t/issues-with-background-colour-for-buttons/38723/8
# Ref. https://github.com/streamlit/streamlit/issues/6605
# str.format template (literal JS braces are doubled). Placeholders:
# {widget_type} (a CSS selector), {widget_label}, {hover_color},
# {background_color}. All are substituted inside single-quoted JS strings.
# Note: the script reads elements[0] unconditionally, so it throws in the
# browser if the selector matches nothing.
CHANGE_HOVER_COLOR = """\
<script>
var hide_me_list = window.parent.document.querySelectorAll('iframe');
for (let i = 0; i < hide_me_list.length; i++) {{
if (hide_me_list[i].height == 0) {{
hide_me_list[i].parentNode.style.height = 0;
hide_me_list[i].parentNode.style.marginBottom = '-1rem';
}};
}};
if (window.matchMedia('(prefers-color-scheme: dark)').matches) {{
var border = 'rgb(250,250,250,.2)';
}} else {{
var border = 'rgb(49,51,63,.2)';
}}
var elements = window.parent.document.querySelectorAll('{widget_type}');
var fontColor = window.getComputedStyle(elements[0]).color;
for (var i = 0; i < elements.length; ++i) {{
if (elements[i].innerText == '{widget_label}') {{
elements[i].style.color = fontColor;
elements[i].style.background = '{background_color}';
elements[i].onmouseover = function() {{
this.style.color = '{hover_color}';
this.style.borderColor = '{hover_color}';
}};
elements[i].onmouseout = function() {{
this.style.color = fontColor;
this.style.borderColor = border;
this.style.background = '{background_color}';
}};
elements[i].onclick = function() {{
this.style.color = 'white';
this.style.borderColor = '{hover_color}';
this.style.background = '{hover_color}';
}};
elements[i].onfocus = function() {{
this.style.boxShadow = '{hover_color} 0px 0px 0px 0.2rem';
this.style.borderColor = '{hover_color}';
this.style.color = '{hover_color}';
}};
elements[i].onblur = function() {{
this.style.boxShadow = 'none';
this.style.color = fontColor;
this.style.borderColor = border;
}};
}}
}}
</script>
"""


def _escape_js_single_quoted(text) -> str:
    """Escape *text* so it can sit inside a single-quoted JS string literal."""
    # Backslash first so later escapes are not doubled. "</" is broken up so a
    # value containing "</script>" cannot terminate the surrounding <script>.
    return (
        str(text)
        .replace("\\", "\\\\")
        .replace("'", "\\'")
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("</", "<\\/")
    )


def change_hover_color(widget_type: str,
                       widget_label: str,
                       hover_color: str,
                       background_color: str = ""):
    """Recolour every *widget_type* element whose text equals *widget_label*.

    Injects ``CHANGE_HOVER_COLOR`` through a zero-height
    ``st.components.v1.html`` iframe. The script runs against the parent page:
    - It collapses the containers of zero-height iframes.
    - It sets the matching elements' background to *background_color*.
    - It installs mouse/click/focus handlers that switch text, border and
      shadow colours to *hover_color*.

    Args:
        widget_type: CSS selector passed to ``querySelectorAll``.
        widget_label: exact ``innerText`` of the elements to restyle.
        hover_color: CSS colour used on hover, click and focus.
        background_color: CSS background; ``""`` leaves the default.

    All arguments are escaped before interpolation. Values containing quotes,
    backslashes, newlines or ``</`` therefore no longer break the script.
    """
    st.components.v1.html(CHANGE_HOVER_COLOR.format(widget_type=_escape_js_single_quoted(widget_type),
                                                     widget_label=_escape_js_single_quoted(widget_label),
                                                     hover_color=_escape_js_single_quoted(hover_color),
                                                     background_color=_escape_js_single_quoted(background_color)),
                          height=0)