# route-explainer / models/route_explainer.py
# daisuke.kikuta
# first commit
# 719d0db
# templates
import numpy as np
import streamlit as st
from typing import Dict, List
from models.prompts.identify_question import Template4IdentifyQuestion
from models.prompts.generate_explanation import Template4GenerateExplanation
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AIMessage
import utils.util_app as util_app
class StreamingChatCallbackHandler(BaseCallbackHandler):
    """Callback handler that mirrors streamed LLM output into a Streamlit placeholder.

    A fresh ``st.empty()`` placeholder is allocated when generation starts, is
    re-rendered with the accumulated text on every new token, and is finally
    overwritten with the text of the first generation of the response.
    """

    def __init__(self):
        super().__init__()
        # Defined up front so the token/end hooks never hit a missing attribute,
        # even if the start hook was not fired for this handler.
        self.container = None
        self.text = ""

    def _ensure_container(self):
        """Allocate the Streamlit placeholder on first use."""
        if self.container is None:
            self.container = st.empty()
        return self.container

    def on_llm_start(self, *args, **kwargs):
        """Start a new answer: new placeholder, empty text buffer."""
        self.container = st.empty()
        self.text = ""

    def on_llm_new_token(self, token: str, *args, **kwargs):
        """Append ``token`` to the buffer and re-render the whole buffer."""
        self.text += token
        self._ensure_container().markdown(
            body=self.text,
            unsafe_allow_html=False,
        )

    def on_llm_end(self, response, *args, **kwargs):
        """Render the final text.

        ``response`` is not a plain string: the code reads
        ``response.generations[0][0].text`` from it.
        """
        self._ensure_container().markdown(
            body=response.generations[0][0].text,
            unsafe_allow_html=False,
        )
class RouteExplainer():
    """Generate a contrastive explanation for a user's why-not question about a route.

    The pipeline (see ``generate_explanation``) is:
      1. extract the counterfactual (CF) step / destination from the question,
      2. check that the CF edge respects the destination's time window,
      3. generate a CF route with ``cf_generator``,
      4. label the edges of the CF route with ``classifier``,
      5. build a textual comparison of both routes and let the LLM explain it.

    Only the first route of every route list (vehicle 0) is used.
    """
    template_identify_question = Template4IdentifyQuestion()
    template_generate_explanation = Template4GenerateExplanation()

    def __init__(self,
                 llm,
                 cf_generator,
                 classifier) -> None:
        """Store the collaborators and build the two LLM chains.

        Parameters
        ----------
        llm:
            Language model handed to both templates' ``sandwiches`` method.
        cf_generator:
            Callable that produces counterfactual routes; must expose ``problem``.
        classifier:
            Callable edge classifier; must expose ``problem`` and ``get_inputs``.

        Raises
        ------
        AssertionError
            If ``cf_generator.problem`` and ``classifier.problem`` differ.
        """
        assert cf_generator.problem == classifier.problem, "Problem type of cf_generator and predictor should coincide!"
        self.coord_dim = 2
        self.problem = cf_generator.problem
        self.cf_generator = cf_generator
        self.classifier = classifier
        self.actual_route = None
        self.cf_route = None

        # templates
        self.question_extractor = self.template_identify_question.sandwiches(llm)
        self.explanation_generator = self.template_generate_explanation.sandwiches(llm)

    #----------------
    # whole pipeline
    #----------------
    def generate_explanation(self,
                             tour_list,
                             whynot_question: str,
                             actual_routes: list,
                             actual_labels: list,
                             node_feats: dict,
                             dist_matrix: np.ndarray) -> str:
        """Run the whole pipeline for one why-not question.

        Parameters
        ----------
        tour_list:
            Sequence of destinations; each item is indexed with ``["name"]``.
        whynot_question:
            The user's question in natural language.
        actual_routes, actual_labels:
            Routes and per-edge labels of the actual solution (index 0 is used).
        node_feats:
            Node features; ``"time_window"`` and ``"service_time"`` are read.
        dist_matrix:
            Pairwise travel-time matrix indexed as ``dist_matrix[i][j]``.

        Returns
        -------
        ``""`` if the question could not be parsed, the infeasibility reason if
        the CF edge is infeasible, otherwise the output of the explanation chain.

        Side effects: streams text to the page and writes ``chat_history``,
        ``generated_cf_route``, ``close_chat``, ``cf_step``, ``cf_routes`` and
        ``cf_labels`` in ``st.session_state``.
        """
        #--------------------------------
        # define why & why-not questions
        #--------------------------------
        route_info_text = self.get_route_info_text(tour_list, actual_routes)
        inputs = self.question_extractor.invoke({
            "whynot_question": whynot_question,
            "route_info": route_info_text
        })
        # The stored message must match what was streamed to the user
        # (the history entry used to lack the separating space).
        question_text = inputs["summary"] + " " + inputs["intent"]
        util_app.stream_words(question_text)
        st.session_state.chat_history.append(AIMessage(content=question_text))
        if not inputs["success"]:
            return ""

        # Node ids shown to the LLM are 1-based (see get_route_info_text),
        # internal ids are 0-based.
        cf_step = inputs["cf_step"]
        cf_next_node_id = inputs["cf_visit"] - 1

        #----------------------
        # validate the CF edge
        #----------------------
        is_cf_edge_feasible, reason = self.validate_cf_edge(node_feats,
                                                            dist_matrix,
                                                            actual_routes[0],
                                                            cf_step,
                                                            cf_next_node_id)
        # exception
        if not is_cf_edge_feasible:
            util_app.stream_words(reason)
            return reason

        #---------------------
        # generate a cf route
        #---------------------
        cf_routes = self.cf_generator(actual_routes,
                                      vehicle_id=0,
                                      cf_step=cf_step,
                                      cf_next_node_id=cf_next_node_id,
                                      node_feats=node_feats,
                                      dist_matrix=dist_matrix)
        st.session_state.generated_cf_route = True
        st.session_state.close_chat = True
        st.session_state.cf_step = cf_step

        #--------------------------------------
        # classify the intentions of each edge
        #--------------------------------------
        cf_labels = self.classifier(self.classifier.get_inputs(cf_routes,
                                                               0,
                                                               node_feats,
                                                               dist_matrix))
        st.session_state.cf_routes = cf_routes
        st.session_state.cf_labels = cf_labels

        #------------------------------------
        # generate a contrastive explanation
        #------------------------------------
        comparison_results = self.get_comparison_results(question_summary=inputs["summary"],
                                                         tour_list=tour_list,
                                                         actual_routes=actual_routes,
                                                         actual_labels=actual_labels,
                                                         cf_routes=cf_routes,
                                                         cf_labels=cf_labels,
                                                         cf_step=cf_step)
        explanation = self.explanation_generator.invoke({
            "comparison_results": comparison_results,
            "intent": inputs["intent"]
        }, config={"callbacks": [StreamingChatCallbackHandler()]})
        return explanation

    #------------------------
    # for extracting inputs
    #------------------------
    def get_route_info_text(self, tour_list, routes) -> str:
        """Describe the nodes and the first route as text for the question extractor.

        Output format::

            Nodes(node id, name): (1, A), (2, B), ...
            Route: A > (step 1) > B > (step 2) > C

        Node ids are printed 1-based.
        """
        node_entries = ", ".join(
            f"({i+1}, {destination['name']})" for i, destination in enumerate(tour_list)
        )
        route_info = "Nodes(node id, name): " + node_entries + "\n"

        # routes
        segments = []
        for i, node_id in enumerate(routes[0]):
            name = tour_list[node_id]['name']
            if i == 0:
                segments.append(f"{name}")
            else:
                # A stray ")" used to follow every step's name here.
                segments.append(f"> (step {i}) > {name}")
        route_info += "Route: " + " ".join(segments) + "\n"
        return route_info

    #--------------------------
    # for validating a CF edge
    #--------------------------
    def validate_cf_edge(self,
                         node_feats: Dict[str, np.ndarray],
                         dist_matrix: np.ndarray,
                         route: List[int],
                         cf_step: int,
                         cf_visit: int) -> tuple:
        """Check whether the CF edge reaches its destination before it closes.

        The route is replayed up to (but excluding) step ``cf_step`` to get the
        departure time, then the arrival time at ``cf_visit`` is compared with
        that node's closing time (``node_feats["time_window"][cf_visit][1]``).

        Parameters
        ----------
        route:
            The actual route as 0-based node ids.
        cf_step:
            1-based step at which the CF edge replaces the actual edge; the edge
            departs from ``route[cf_step-1]``.
        cf_visit:
            0-based id of the CF edge's destination.

        Returns
        -------
        ``(is_feasible, message)`` -- a bool and a user-facing explanation.
        """
        # Reject out-of-range inputs explicitly: negative indices would
        # otherwise silently wrap around and validate the wrong edge.
        if not 1 <= cf_step <= len(route):
            return False, f"Oops, your CF edge is infeasible because step {cf_step} does not exist in the actual route."
        if not 0 <= cf_visit < len(node_feats["time_window"]):
            # +1: callers in this class pass the user's 1-based node id minus one
            return False, f"Oops, your CF edge is infeasible because destination {cf_visit + 1} does not exist."

        # calc current time
        curr_time = node_feats["time_window"][route[0]][0] # start point's open time
        for step in range(1, cf_step):
            curr_node_id = route[step-1]
            next_node_id = route[step]
            curr_time += node_feats["service_time"][curr_node_id] + dist_matrix[curr_node_id][next_node_id]
            curr_time = max(curr_time, node_feats["time_window"][next_node_id][0]) # waiting

        # validate the cf edge
        curr_node_id = route[cf_step-1]
        next_node_id = cf_visit
        next_node_close_time = node_feats["time_window"][next_node_id][1]
        arrival_time = curr_time + node_feats["service_time"][curr_node_id] + dist_matrix[curr_node_id][next_node_id]
        if next_node_close_time < arrival_time:
            exceed_time = (arrival_time - next_node_close_time)
            return False, f"Oops, your CF edge is infeasible because it does not meet the destination's close time by {util_app.add_time_unit(exceed_time)}."
        return True, "The CF edge is feasible!"

    #-------------------------------
    # for generating an explanation
    #-------------------------------
    def get_comparison_results(self,
                               tour_list,
                               question_summary,
                               actual_routes: List[List[int]],
                               actual_labels: List[List[int]],
                               cf_routes: List[List[int]],
                               cf_labels: List[List[int]],
                               cf_step: int) -> str:
        """Assemble the text block comparing the actual and the CF route.

        Sections, in order: question, actual route, CF route, difference between
        the two routes, and the per-destination remarks. ``cf_step`` is 1-based;
        the helpers receive the 0-based edge index ``cf_step - 1``.
        """
        ex_step = cf_step - 1
        sections = [
            "Question:\n" + question_summary + "\n",
            # get_representative_values() has no trailing newline, so one is
            # added here; otherwise the next header is glued to the last bullet.
            "Actual route:\n"
            + self.get_route_info(tour_list, actual_routes[0], actual_labels[0], ex_step, "actual")
            + self.get_representative_values(actual_routes[0], actual_labels[0], ex_step, "actual")
            + "\n",
            "CF route:\n"
            + self.get_route_info(tour_list, cf_routes[0], cf_labels[0], ex_step, "CF")
            + self.get_representative_values(cf_routes[0], cf_labels[0], ex_step, "CF")
            + "\n",
            "Difference between two routes:\n" + self.get_diff(ex_step, actual_routes[0], cf_routes[0]),
            "Planed desination information:\n" + self.get_node_info(),
        ]
        return "".join(sections)

    def get_route_info(self,
                       tour_list,
                       route: List[int],
                       label: List[int],
                       ex_step: int,
                       type: str) -> str:
        """Render a route as ``A > (label) > B > ... > Z`` followed by a newline.

        The edge leaving position ``ex_step`` is highlighted as
        ``({type} edge: <label>)``; for ``type == "actual"`` the label is the
        classified one, for any other type it is ``user_preference``.
        Label ``0`` is printed as ``route_len``, every other label as
        ``time_window``.
        """
        def get_labelname(label_number):
            return "route_len" if label_number == 0 else "time_window"

        last = len(route) - 1
        parts = ["- route: "]
        for i, node_id in enumerate(route):
            name = tour_list[node_id]['name']
            if i == last:
                parts.append(f"{name}\n")
            elif i == ex_step:
                if type == "actual":
                    # plain string -- this used to be wrapped in a set literal,
                    # which rendered as "{'route_len'}" in the prompt
                    edge_label = get_labelname(label[i])
                else:
                    edge_label = "user_preference"
                parts.append(f"{name} > ({type} edge: {edge_label}) > ")
            else:
                parts.append(f"{name} > ({get_labelname(label[i])}) > ")
        return "".join(parts)

    def get_representative_values(self, route, labels, ex_step, type) -> str:
        """Summarise a route as four bullet lines (no trailing newline).

        Bullets: travel time of the edge at ``ex_step``, total travel time,
        names of missed nodes, and the share of each label from ``ex_step`` on.
        Matrix entries are floor-divided by 60 and labelled "minutes"
        (presumably the matrix is in seconds -- confirm against the data source).
        """
        time_window_ratio = self.get_intention_ratio(1, labels, ex_step) * 100
        route_len_ratio = self.get_intention_ratio(0, labels, ex_step) * 100
        return (
            f"- short-term effect (immediate travel time): {self.get_immediate_state(route, ex_step)//60} minutes\n"
            f"- long-term effect (total travel time): {self.get_route_length(route)//60} minutes\n"
            f"- missed nodes: {self.get_infeasible_node_name(route)}\n"
            f"- edge-intention ratio after the {type} edge: time_window {time_window_ratio: .1f}%, route_len {route_len_ratio: .1f}%"
        )

    def get_immediate_state(self, route, ex_step) -> float:
        """Return the matrix entry for the edge ``route[ex_step] -> route[ex_step+1]``.

        Reads ``st.session_state.dist_matrix``.
        """
        return st.session_state.dist_matrix[route[ex_step]][route[ex_step+1]]

    def get_route_length(self, route) -> float:
        """Sum the ``st.session_state.dist_matrix`` entries along consecutive route nodes."""
        dist_matrix = st.session_state.dist_matrix
        route_length = 0.0
        for src, dst in zip(route, route[1:]):
            route_length += dist_matrix[src][dst]
        return route_length

    def get_infeasible_nodes(self, route) -> int:
        """Return ``len(route) - (number of nodes - 1)``.

        NOTE(review): the result is zero for a route of ``number of nodes - 1``
        entries and negative for shorter routes, i.e. it looks like the
        *negated* count of missed nodes. Kept as is because callers may rely
        on it; ``get_diff`` only uses it through ``== 0`` and ``abs``.
        """
        return len(route) - (len(st.session_state.dist_matrix) - 1)

    def get_infeasible_node_name(self, route) -> str:
        """Return the comma-joined names of nodes absent from ``route``.

        Returns ``"none"`` when the route has exactly ``number of nodes - 1``
        entries. Names are taken from ``st.session_state.tour_list`` in
        ascending node-id order.
        """
        num_nodes = len(st.session_state.dist_matrix)
        if len(route) == num_nodes - 1:
            return "none"
        missed_ids = np.setdiff1d(np.arange(num_nodes), route)
        return ",".join([st.session_state.tour_list[node_id]["name"] for node_id in missed_ids])

    def get_intention_ratio(self,
                            intention: int,
                            labels: List[int],
                            ex_step: int) -> float:
        """Return the fraction of ``labels[ex_step:]`` equal to ``intention``.

        Returns ``0.0`` when no labels remain from ``ex_step`` on, instead of
        dividing by zero.
        """
        remaining = np.array(labels)[ex_step:]
        if len(remaining) == 0:
            return 0.0
        return np.sum(remaining == intention) / len(remaining)

    def get_diff(self, ex_step, actual_route, cf_route) -> str:
        """Describe how the actual route differs from the CF route.

        Three bullets are produced: difference in the travel time of the edge
        at ``ex_step``, difference in total travel time, and difference in the
        number of visited nodes.
        """
        def get_str(effect: float):
            long_effect_str = "The actual route increases it by" if effect > 0 else "The actual route reduces it by"
            # separating space added, as in validate_cf_edge's message
            long_effect_str += " " + util_app.add_time_unit(abs(effect))
            return long_effect_str

        def get_str2(num_nodes: int, num_missed_nodes):
            # num_nodes = len(actual_route) - len(cf_route): a positive value
            # means the actual route contains more nodes. The two wordings
            # used to be attached to the opposite signs.
            if num_nodes > 0:
                num_nodes_str = f"The actual route visits {abs(num_nodes)} more nodes"
            elif num_nodes == 0:
                if num_missed_nodes == 0:
                    num_nodes_str = "Both routes missed no node,"
                else:
                    num_nodes_str = f"Both routes missed the same number of nodes ({abs(num_missed_nodes)} node(s))"
            else:
                num_nodes_str = f"The actual route visits {abs(num_nodes)} less nodes"
            return num_nodes_str

        # short/long-term effects
        short_effect = self.get_immediate_state(actual_route, ex_step) - self.get_immediate_state(cf_route, ex_step)
        long_effect = self.get_route_length(actual_route) - self.get_route_length(cf_route)
        short_effect_str = get_str(short_effect)
        long_effect_str = get_str(long_effect)

        # missed nodes
        visited_diff = len(actual_route) - len(cf_route)
        missed_nodes_str = get_str2(visited_diff, self.get_infeasible_nodes(actual_route))

        return f"- short-term effect: {short_effect_str}\n - long-term effect: {long_effect_str}\n- missed nodes: {missed_nodes_str}\n"

    def get_node_info(self) -> str:
        """List ``- <destination>: <remarks>`` for every row of ``st.session_state.df_tour``."""
        df_tour = st.session_state.df_tour
        return "".join(
            f"- {destination}: {remarks}\n"
            for destination, remarks in zip(df_tour['destination'], df_tour['remarks'])
        )