Spaces:
Sleeping
Sleeping
import openai | |
import time | |
import base64 | |
import copy | |
import streamlit as st | |
import folium | |
from folium.plugins import FloatImage | |
from streamlit_folium import folium_static | |
from langchain.schema import AIMessage | |
#---------------- | |
# util functions | |
#---------------- | |
def apply_html(html: str, **kwargs) -> None:
    """Render raw HTML (e.g. ``<style>`` blocks) via ``st.markdown`` with HTML enabled.

    Args:
        html: Markup passed verbatim as the markdown body.
        **kwargs: Extra keyword arguments forwarded to ``st.markdown``.
    """
    render = st.markdown
    render(html, unsafe_allow_html=True, **kwargs)
def stream_words(words: str,
                 prefix: str = "",
                 suffix: str = "",
                 sleep_time: float = 0.02) -> None:
    """Reveal *words* one character at a time in a single placeholder.

    After each character, the ``st.empty()`` placeholder is rewritten with
    ``"{prefix} {text_so_far} {suffix}"`` (HTML allowed). The function then
    sleeps *sleep_time* seconds before the next update.
    """
    placeholder = st.empty()
    shown = ""
    for char in words:
        shown = shown + char
        placeholder.markdown(f"{prefix} {shown} {suffix}", unsafe_allow_html=True)
        time.sleep(sleep_time)
def add_time_unit(time: float) -> str:
    """Format a duration given in seconds with a human-readable unit.

    Durations under one minute are shown as-is in seconds. Durations under one
    hour are shown in minutes, and anything longer in hours; both of these use
    two decimals.

    Args:
        time: Duration in seconds. The name shadows the ``time`` module inside
            this function only; it is kept for keyword-argument callers.

    Returns:
        A string such as ``"42 sec"``, ``"1.50 mins"`` or ``"2.00 hours"``.
    """
    if time < 60:
        return f"{time} sec"
    # One hour is 3600 s. The previous bound (< 3599) sent 3599 <= t < 3600
    # to the hours branch, where it rendered as "1.00 hours".
    if time < 3600:
        return f"{time/60:.2f} mins"
    return f"{time/3600:.2f} hours"
def find_node_id_by_name(data_list, name) -> int:
    """Return the position of the first entry whose ``"name"`` equals *name*.

    Entries are read with ``.get("name")``, so entries without that key compare
    as ``None``. Returns -1 when no entry matches.
    """
    hits = (pos for pos, entry in enumerate(data_list) if entry.get("name") == name)
    return next(hits, -1)
def mark_destination(tour_list, m) -> None:
    """Add destination icons, a start/end marker and the legend to map *m*.

    Args:
        tour_list: Sequence of destination dicts. Each is read for ``"name"``
            and ``"latlng"`` (indexed as ``[0]``/``[1]``). The first entry is
            used as the anchor for the start/end marker.
        m: The ``folium.Map`` to draw on; markers are added to it in place.

    Reads ``static/<name with spaces replaced by dashes>.png`` for each
    destination, plus ``static/star_emoji.png`` and ``static/legend.png``.
    """
    # Icon (width, height) overrides; every other destination uses 50x50.
    # NOTE(review): the original if/elif chain listed "Ryoanji Temple" twice
    # (50x35, then 50x40). Only the first branch could ever run, so 50x35 is kept.
    # The 50x40 size was probably meant for another destination — confirm.
    icon_sizes = {
        "Ryoanji Temple": (50, 35),
        "Kyoto Geishinkan": (40, 40),
        "Nijo-jo Castle": (45, 35),
    }
    default_size = (50, 50)

    for destination in tour_list:
        name = destination["name"]
        image_name = f"static/{name.replace(' ', '-')}.png"
        folium.Marker(
            location=[destination["latlng"][0], destination["latlng"][1]],
            tooltip=name,
            icon=folium.features.CustomIcon(
                icon_image=image_name,
                icon_size=icon_sizes.get(name, default_size),
                icon_anchor=(30, 30),
                popup_anchor=(3, 3),
            ),
        ).add_to(m)

    # Start/end indicator, slightly offset from the first destination (the
    # original comment identifies it as Kyoto Station). Skipped for an empty
    # tour, where tour_list[0] would raise IndexError.
    # z_index_offset keeps the star above the other icons:
    # https://github.com/python-visualization/folium/issues/1281
    if tour_list:
        start = tour_list[0]
        folium.Marker(
            location=[start["latlng"][0] + 0.003, start["latlng"][1] - 0.003],
            tooltip="start/end point",
            icon=folium.features.CustomIcon(
                icon_image="static/star_emoji.png",
                icon_size=(30, 30),
                icon_anchor=(30, 30),
                popup_anchor=(3, 3),
            ),
            z_index_offset=10000,
        ).add_to(m)

    # Legend embedded as a base64 data URI in the bottom-left corner.
    # Ref: https://python-visualization.github.io/folium/latest/user_guide/plugins/float_image.html
    with open("static/legend.png", "rb") as lf:
        b64_content = base64.b64encode(lf.read()).decode("utf-8")
    FloatImage(f"data:image/png;base64,{b64_content}", bottom=1, left=1).add_to(m)
def initialize_map() -> folium.Map:
    """Build a fresh base map for the current tour from session state.

    The map is centred on ``lat_mean``/``lng_mean`` using the "Cartodb Positron"
    tiles and fitted to the ``sw``/``ne`` bounds. All ``tour_list``
    destinations are marked via ``mark_destination``.
    """
    state = st.session_state
    center = [state.lat_mean, state.lng_mean]
    base_map = folium.Map(location=center, tiles="Cartodb Positron")
    base_map.fit_bounds([state.sw, state.ne])
    mark_destination(state.tour_list, base_map)
    return base_map
def vis_route(routes, labels, m, ex_step, route_type, ant_path=True) -> None:
    """Draw every edge of the routes stored in ``st.session_state[routes]`` on *m*.

    Args:
        routes: Session-state key holding the routes, one per vehicle. Each
            route is a sequence of indices into ``st.session_state.tour_list``.
        labels: ``labels[j][i]`` labels edge *i* of vehicle *j*. A value of
            ``0`` is drawn green as "Route length priority"; any other value
            is drawn purple as "Time window priority".
        m: The ``folium.Map`` to draw on; it is mutated in place.
        ex_step: Edge index to highlight. Callers in this file pass ``-1`` to
            highlight nothing.
        route_type: ``"actual"`` highlights that edge in red ("Actual edge");
            any other value highlights it in blue ("CF edge").
        ant_path: If true, overlay an animated ``AntPath`` carrying the edge's
            tooltip. Otherwise overlay repeated ``"> "`` arrow text on the line.

    Does nothing if ``tour_list`` or *routes* is missing from session state.
    """
    if "tour_list" not in st.session_state or routes not in st.session_state:
        return
    tour_list = st.session_state.tour_list
    for j, route in enumerate(st.session_state[routes]):  # vehicle loop
        # Consecutive (origin, dest) node pairs are the route's edges.
        for i, (origin_id, dest_id) in enumerate(zip(route, route[1:])):  # edge loop
            if i == ex_step:
                if route_type == "actual":
                    color, popup = "red", "Actual edge"
                else:
                    color, popup = "blue", "CF edge"
            elif labels[j][i] == 0:
                color, popup = "#2ca02c", "Route length priority"
            else:
                color, popup = "#9467bd", "Time window priority"

            origin_latlng = tour_list[origin_id]["latlng"]
            dest_latlng = tour_list[dest_id]["latlng"]
            # add_to() returns the PolyLine itself; PolyLineTextPath needs it below.
            line = folium.PolyLine([origin_latlng, dest_latlng], color=color).add_to(m)
            if ant_path:
                folium.plugins.AntPath(
                    [origin_latlng, dest_latlng], tooltip=popup, color=color
                ).add_to(m)
            else:
                folium.plugins.PolyLineTextPath(
                    line, "> ", offset=11, repeat=True,
                    attributes={"fill": color, "font-size": 30},
                ).add_to(m)
def visualize_actual_route(m: folium.Map) -> None:
    """Show map *m* in the centre column under a red-divider subheader.

    The subheader text is ``st.session_state.curr_route``.
    """
    _, center_col, _ = st.columns((0.2, 1, 0.1))
    with center_col:
        st.subheader(f"{st.session_state.curr_route}", divider="red")
        folium_static(m)
def visualize_cf_route(m: folium.Map) -> None:
    """Show map *m* in the centre column under a blue "CF route" subheader."""
    _, center_col, _ = st.columns((0.1, 1, 0.2))
    with center_col:
        st.subheader("CF route", divider="blue")
        folium_static(m)
def select_actual_route():
    """Keep the current route and post a confirmation with redrawn maps to the chat.

    Appends an ``AIMessage`` and then a ``(map, None, map, None)`` tuple to
    ``chat_history``. The first map uses AntPath edges and the second uses
    arrow-text edges. Finally resets ``close_chat`` to False.
    """
    state = st.session_state
    if state.curr_route == "Actual Route":
        route_name = "the actual route"
    else:
        route_name = "your current route"
    message = f"You chose to stay {route_name}. Feel free to ask another why and why-not question for your current route!"
    state.chat_history.append(AIMessage(content=message))

    animated_map = initialize_map()
    arrow_map = initialize_map()
    if "labels" in state:
        # cf_step appears to be 1-based; vis_route takes a 0-based edge index,
        # with -1 meaning "highlight nothing".
        if state.generated_cf_route:
            highlight = state.cf_step - 1
        else:
            highlight = -1
        vis_route("routes", state.labels, animated_map, highlight, "actual")
        vis_route("routes", state.labels, arrow_map, highlight, "actual", ant_path=False)
    state.chat_history.append((animated_map, None, arrow_map, None))
    state.close_chat = False
def select_cf_route():
    """Adopt the counterfactual (CF) route as the current route and post it to the chat.

    Appends a confirmation ``AIMessage``, then deep-copies ``cf_routes`` and
    ``cf_labels`` over ``routes`` and ``labels``. It renames the route to
    "Current Route" and appends a ``(map, None, map, None)`` tuple of two
    redrawn maps (AntPath and arrow-text variants). Finally resets
    ``close_chat`` to False.
    """
    state = st.session_state
    state.chat_history.append(AIMessage(content="You chose to replace your current route with the CF route. Feel free to ask a why and why-not question for the CF route!"))

    # Deep copies so later edits to routes/labels do not alias the CF data.
    state.routes = copy.deepcopy(state.cf_routes)
    state.labels = copy.deepcopy(state.cf_labels)
    state.curr_route = "Current Route"

    animated_map, arrow_map = initialize_map(), initialize_map()
    # NOTE(review): cf_labels was already read above, so this check appears to
    # always be true by this point — confirm whether "labels" was intended.
    if "cf_labels" in state:
        if state.generated_cf_route:
            highlight = state.cf_step - 1
        else:
            highlight = -1
        vis_route("routes", state.labels, animated_map, highlight, "actual")
        vis_route("routes", state.labels, arrow_map, highlight, "actual", ant_path=False)
    state.chat_history.append((animated_map, None, arrow_map, None))
    state.close_chat = False
# ref: https://stackoverflow.com/questions/76522693/how-to-check-the-validity-of-the-openai-key-from-python
def validate_openai_api_key(api_key):
    """Return True if *api_key* is accepted by the OpenAI API, else False.

    Probes the key by calling ``client.models.list()``, which performs a
    request. Only ``openai.AuthenticationError`` is mapped to False. Any
    other exception (including one raised while constructing the client)
    propagates to the caller.
    """
    client = openai.OpenAI(api_key=api_key)
    try:
        client.models.list()
        return True
    except openai.AuthenticationError:
        return False
#----- | |
# CSS | |
#----- | |
# Makes every non-zero-height Streamlit iframe (e.g. the folium maps) full
# width and square. ("RESPONSIBLE" presumably means "responsive"; the name is
# kept for existing callers.)
RESPONSIBLE_MAP = """\
<style>
[title~="st.iframe"]:not([height="0"]) {
width: 100%;
aspect-ratio: 1;
height: auto;
}
</style>
"""
def apply_responsible_map_css() -> None:
    """Inject ``RESPONSIBLE_MAP`` into the page."""
    apply_html(html=RESPONSIBLE_MAP)
# Vertically centres the contents of horizontal column blocks and
# horizontally centres buttons. (Constant name kept as-is, typo included,
# for existing callers.)
CENTERIZE_INCON = """
<style>
[data-testid="stHorizontalBlock"] {
align-items: center;
}
.stButton {
text-align:center
}
</style>
"""
def apply_centerize_icon_css() -> None:
    """Inject ``CENTERIZE_INCON`` into the page."""
    apply_html(html=CENTERIZE_INCON)
# Renders inline <code> elements in red (#cc3333).
RED_CODE = """\
<style>
code {
color: #cc3333;
}
</style>
"""
def apply_red_code_css() -> None:
    """Inject ``RED_CODE`` into the page."""
    apply_html(html=RED_CODE)
# Shifts two sidebar containers to reduce the space at the top of the sidebar.
# NOTE(review): these hashed class names look build-generated and may change
# between Streamlit versions — re-check after upgrades.
REMOVE_SIDEBAR_TOPSPACE = """\
<style>
.st-emotion-cache-6qob1r.eczjsme3 {
margin-top: -75px;
}
.st-emotion-cache-1b9x38r.eczjsme2 {
margin-top: 75px
}
</style>
"""
def apply_remove_sidebar_topspace() -> None:
    """Inject ``REMOVE_SIDEBAR_TOPSPACE`` into the page."""
    apply_html(html=REMOVE_SIDEBAR_TOPSPACE)
#---- | |
# JS | |
#---- | |
# SET_WORDS_TO_CHATBOX = """\ | |
# <script> | |
# function insertText() {{ | |
# var chatInput = parent.document.querySelector('textarea[data-testid="stChatInput"]'); | |
# var nativeInputValueSetter = Object.getOwnPropertyDescriptor(window.HTMLTextAreaElement.prototype, "value").set; | |
# nativeInputValueSetter.call(chatInput, "{words}"); | |
# var event = new Event('input', {{bubbles: true}}); | |
# chatInput.dispatchEvent(event); | |
# }} | |
# insertText(); | |
# </script> | |
# """ | |
# def set_chat_input(words:str) -> None: | |
# st.components.v1.html(SET_WORDS_TO_CHATBOX.format(words=words), height=0) | |
# Ref. https://discuss.streamlit.io/t/issues-with-background-colour-for-buttons/38723/8
# Ref. https://github.com/streamlit/streamlit/issues/6605
# JavaScript template injected by change_hover_color(). It is filled in with
# str.format, so literal JS braces are doubled. It runs inside a zero-height
# component iframe and reaches the app page through window.parent.document.
CHANGE_HOVER_COLOR = """\
<script>
// Collapse zero-height component iframes so they do not leave gaps.
var hide_me_list = window.parent.document.querySelectorAll('iframe');
for (let i = 0; i < hide_me_list.length; i++) {{
    if (hide_me_list[i].height == 0) {{
        hide_me_list[i].parentNode.style.height = 0;
        hide_me_list[i].parentNode.style.marginBottom = '-1rem';
    }};
}};
if (window.matchMedia('(prefers-color-scheme: dark)').matches) {{
    var border = 'rgb(250,250,250,.2)';
}} else {{
    var border = 'rgb(49,51,63,.2)';
}}
var elements = window.parent.document.querySelectorAll('{widget_type}');
// Without a match, elements[0] is undefined and getComputedStyle would throw.
if (elements.length > 0) {{
    var fontColor = window.getComputedStyle(elements[0]).color;
    for (var i = 0; i < elements.length; ++i) {{
        if (elements[i].innerText == '{widget_label}') {{
            elements[i].style.color = fontColor;
            elements[i].style.background = '{background_color}';
            elements[i].onmouseover = function() {{
                this.style.color = '{hover_color}';
                this.style.borderColor = '{hover_color}';
            }};
            elements[i].onmouseout = function() {{
                this.style.color = fontColor;
                this.style.borderColor = border;
                this.style.background = '{background_color}';
            }};
            elements[i].onclick = function() {{
                this.style.color = 'white';
                this.style.borderColor = '{hover_color}';
                this.style.background = '{hover_color}';
            }};
            elements[i].onfocus = function() {{
                this.style.boxShadow = '{hover_color} 0px 0px 0px 0.2rem';
                this.style.borderColor = '{hover_color}';
                this.style.color = '{hover_color}';
            }};
            elements[i].onblur = function() {{
                this.style.boxShadow = 'none';
                this.style.color = fontColor;
                this.style.borderColor = border;
            }};
        }}
    }}
}}
</script>
"""
def change_hover_color(widget_type: str,
                       widget_label: str,
                       hover_color: str,
                       background_color: str = ""):
    """Restyle the page widgets whose ``innerText`` equals *widget_label*.

    Renders ``CHANGE_HOVER_COLOR`` through a zero-height
    ``st.components.v1.html`` iframe. The script sets the matching widgets'
    colors and their mouseover/mouseout/click/focus/blur handlers.

    Args:
        widget_type: CSS selector passed to ``querySelectorAll`` (e.g. ``"button"``).
        widget_label: Exact ``innerText`` a widget must have to be restyled.
        hover_color: CSS color used on hover, click and focus.
        background_color: CSS background applied at rest; ``""`` clears the
            inline background.
    """
    import json

    def _js_literal(value) -> str:
        # Escape *value* for use between single quotes in the JS template. This
        # stops labels or selectors containing ', \, " or newlines from breaking
        # the script. json.dumps escapes backslashes, double quotes and control
        # characters. Single quotes and "</" (which would end the <script>) are
        # escaped by hand.
        return json.dumps(str(value))[1:-1].replace("'", "\\'").replace("</", "<\\/")

    script = CHANGE_HOVER_COLOR.format(widget_type=_js_literal(widget_type),
                                       widget_label=_js_literal(widget_label),
                                       hover_color=_js_literal(hover_color),
                                       background_color=_js_literal(background_color))
    st.components.v1.html(script, height=0)