Spaces:
Sleeping
Sleeping
# templates | |
from typing import Dict, List, Tuple

import numpy as np
import streamlit as st
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AIMessage

import utils.util_app as util_app
from models.prompts.generate_explanation import Template4GenerateExplanation
from models.prompts.identify_question import Template4IdentifyQuestion
class StreamingChatCallbackHandler(BaseCallbackHandler):
    """LangChain callback that renders LLM output incrementally in Streamlit.

    Tokens are accumulated in ``self.text`` and re-rendered as markdown into a
    single ``st.empty()`` placeholder, so the message appears to grow in place.
    """

    def __init__(self):
        super().__init__()
        # The placeholder is created when output actually starts, not here,
        # so that it is placed in the page at the point of streaming.
        self.container = None
        self.text = ""

    def _ensure_container(self):
        """Create the Streamlit placeholder if no start event created it yet."""
        if self.container is None:
            self.container = st.empty()
        return self.container

    def on_llm_start(self, *args, **kwargs):
        """Open a fresh placeholder and reset the accumulated text."""
        self.container = st.empty()
        self.text = ""

    def on_llm_new_token(self, token: str, *args, **kwargs):
        """Append ``token`` to the running text and re-render it."""
        self.text += token
        self._ensure_container().markdown(
            body=self.text,
            unsafe_allow_html=False,
        )

    def on_llm_end(self, response, *args, **kwargs):
        """Render the final text taken from ``response.generations[0][0].text``.

        ``response`` is the result object passed by LangChain (it must expose
        ``generations``); it is not a plain string.
        """
        self._ensure_container().markdown(
            body=response.generations[0][0].text,
            unsafe_allow_html=False,
        )
class RouteExplainer():
    """Produces natural-language explanations comparing an actual route with a
    counterfactual (CF) route.

    It combines an LLM (wrapped by two prompt templates), a CF-route generator
    and an edge classifier, all supplied by the caller.
    """
    # Prompt templates are instantiated once at class level; the LLM-bound
    # chains are built per instance in __init__.
    template_identify_question = Template4IdentifyQuestion()
    template_generate_explanation = Template4GenerateExplanation()

    def __init__(self,
                 llm,
                 cf_generator,
                 classifier) -> None:
        """Store the collaborators and bind both prompt templates to ``llm``.

        Parameters
        ----------
        llm:
            Model object handed to each template's ``sandwiches`` method.
        cf_generator:
            Callable used to generate CF routes; must expose ``problem``.
        classifier:
            Callable used to label route edges; must expose ``problem`` equal
            to ``cf_generator.problem``.
        """
        same_problem = cf_generator.problem == classifier.problem
        assert same_problem, "Problem type of cf_generator and predictor should coincide!"

        # collaborators
        self.coord_dim = 2
        self.problem = cf_generator.problem
        self.cf_generator = cf_generator
        self.classifier = classifier

        # route placeholders (not filled in by this constructor)
        self.actual_route = None
        self.cf_route = None

        # templates
        extractor_template = self.template_identify_question
        self.question_extractor = extractor_template.sandwiches(llm)
        explanation_template = self.template_generate_explanation
        self.explanation_generator = explanation_template.sandwiches(llm)
#---------------- | |
# whole pipeline | |
#---------------- | |
def generate_explanation(self,
                         tour_list,
                         whynot_question: str,
                         actual_routes: list,
                         actual_labels: list,
                         node_feats: dict,
                         dist_matrix: np.array) -> str:
    """Run the whole pipeline: parse the question, build a CF route, explain it.

    Steps, in order:
      1. Ask ``question_extractor`` to interpret ``whynot_question`` against a
         textual description of the route; its result is read via the keys
         ``summary``, ``intent``, ``success``, ``cf_step`` and ``cf_visit``.
      2. Stream ``summary`` and ``intent`` to the UI and record the same text
         in ``st.session_state.chat_history``.
      3. Check that the requested CF edge respects the time window.
      4. Generate the CF route and classify its edges; results and a few UI
         flags are written to ``st.session_state``.
      5. Build the comparison text and let ``explanation_generator`` stream
         the final explanation.

    Returns
    -------
    ``""`` when the question could not be interpreted, the infeasibility
    reason when the CF edge is rejected, otherwise whatever
    ``explanation_generator.invoke`` returns.
    """
    #--------------------------------
    # define why & why-not questions
    #--------------------------------
    route_info_text = self.get_route_info_text(tour_list, actual_routes)
    inputs = self.question_extractor.invoke({
        "whynot_question": whynot_question,
        "route_info": route_info_text
    })
    # Use one message for both the streamed text and the stored history so
    # the two never diverge (the history used to lose the separating space).
    question_message = inputs["summary"] + " " + inputs["intent"]
    util_app.stream_words(question_message)
    st.session_state.chat_history.append(AIMessage(content=question_message))
    if not inputs["success"]:
        return ""

    cf_step = inputs["cf_step"]
    # Node ids are presented to the LLM 1-based (see get_route_info_text);
    # convert back to a 0-based index.
    cf_next_node_id = inputs["cf_visit"] - 1

    #----------------------
    # validate the CF edge
    #----------------------
    is_cf_edge_feasible, reason = self.validate_cf_edge(node_feats,
                                                        dist_matrix,
                                                        actual_routes[0],
                                                        cf_step,
                                                        cf_next_node_id)
    # exception
    if not is_cf_edge_feasible:
        util_app.stream_words(reason)
        return reason

    #---------------------
    # generate a cf route
    #---------------------
    cf_routes = self.cf_generator(actual_routes,
                                  vehicle_id=0,
                                  cf_step=cf_step,
                                  cf_next_node_id=cf_next_node_id,
                                  node_feats=node_feats,
                                  dist_matrix=dist_matrix)
    st.session_state.generated_cf_route = True
    st.session_state.close_chat = True
    st.session_state.cf_step = cf_step

    #--------------------------------------
    # classify the intentions of each edge
    #--------------------------------------
    # NOTE(review): the positional 0 is presumably the vehicle id, matching
    # vehicle_id=0 above -- confirm against classifier.get_inputs.
    cf_labels = self.classifier(self.classifier.get_inputs(cf_routes,
                                                           0,
                                                           node_feats,
                                                           dist_matrix))
    st.session_state.cf_routes = cf_routes
    st.session_state.cf_labels = cf_labels

    #------------------------------------
    # generate a contrastive explanation
    #------------------------------------
    comparison_results = self.get_comparison_results(question_summary=inputs["summary"],
                                                     tour_list=tour_list,
                                                     actual_routes=actual_routes,
                                                     actual_labels=actual_labels,
                                                     cf_routes=cf_routes,
                                                     cf_labels=cf_labels,
                                                     cf_step=cf_step)
    explanation = self.explanation_generator.invoke({
        "comparison_results": comparison_results,
        "intent": inputs["intent"]
    }, config={"callbacks": [StreamingChatCallbackHandler()]})
    return explanation
#------------------------- | |
# for extracting inputs
#------------------------- | |
def get_route_info_text(self, tour_list, routes) -> str:
    """Describe the nodes and the first route as plain text for the LLM prompt.

    The result has two newline-terminated lines:

    * ``Nodes(node id, name): (1, A), (2, B), ...`` -- ids are 1-based
      positions in ``tour_list``.
    * ``Route: A > (step 1) > B > (step 2) > C`` -- names looked up from the
      0-based node ids in ``routes[0]``.

    Parameters
    ----------
    tour_list:
        Sequence of mappings, each providing a ``'name'`` entry.
    routes:
        Sequence of routes; only ``routes[0]`` is described.
    """
    # nodes
    node_pairs = ", ".join(
        f"({node_no}, {destination['name']})"
        for node_no, destination in enumerate(tour_list, start=1)
    )

    # routes: the first name stands alone, every later one is prefixed with
    # the step that leads to it (no stray closing parenthesis after the name).
    names = [tour_list[node_id]['name'] for node_id in routes[0]]
    route_text = "".join(
        name if step == 0 else f" > (step {step}) > {name}"
        for step, name in enumerate(names)
    )

    return f"Nodes(node id, name): {node_pairs}\nRoute: {route_text}\n"
#-------------------------- | |
# for validating a CF edge | |
#-------------------------- | |
def validate_cf_edge(self,
                     node_feats: Dict[str, np.array],
                     dist_matrix: np.array,
                     route: List[int],
                     cf_step: int,
                     cf_visit: int) -> Tuple[bool, str]:
    """Check whether the counterfactual edge can meet the target's close time.

    The route is replayed from its first node up to step ``cf_step - 1``
    (travel + service time, waiting for a node to open when arriving early),
    then the arrival time at ``cf_visit`` is compared with its close time.

    Parameters
    ----------
    node_feats:
        Must contain ``"time_window"`` (per node: ``[open, close]``) and
        ``"service_time"`` (per node).
    dist_matrix:
        Pairwise travel times, indexed ``[from][to]``.
    route:
        Actual route as 0-based node ids.
    cf_step:
        1-based step whose edge is replaced; the edge starts at
        ``route[cf_step - 1]``.
    cf_visit:
        0-based id of the node the CF edge goes to.

    Returns
    -------
    ``(is_feasible, message)``.
    """
    # Reject out-of-range requests explicitly: negative or zero values would
    # otherwise wrap around through negative indexing and silently evaluate
    # a different edge, and too-large steps would raise IndexError.
    if not 1 <= cf_step <= len(route):
        return False, f"Oops, your CF edge is infeasible because step {cf_step} does not exist in the actual route."
    if not 0 <= cf_visit < len(dist_matrix):
        return False, "Oops, your CF edge is infeasible because the specified destination does not exist."

    time_windows = node_feats["time_window"]
    service_times = node_feats["service_time"]

    # calc current time
    curr_time = time_windows[route[0]][0]  # start point's open time
    for curr_node_id, next_node_id in zip(route[:cf_step - 1], route[1:cf_step]):
        curr_time += service_times[curr_node_id] + dist_matrix[curr_node_id][next_node_id]
        curr_time = max(curr_time, time_windows[next_node_id][0])  # waiting

    # validate the cf edge
    curr_node_id = route[cf_step - 1]
    next_node_close_time = time_windows[cf_visit][1]
    arrival_time = curr_time + service_times[curr_node_id] + dist_matrix[curr_node_id][cf_visit]
    if next_node_close_time < arrival_time:
        exceed_time = arrival_time - next_node_close_time
        return False, f"Oops, your CF edge is infeasible because it does not meet the destination's close time by {util_app.add_time_unit(exceed_time)}."
    return True, "The CF edge is feasible!"
#------------------------------- | |
# for generating an explanation | |
#------------------------------- | |
def get_comparison_results(self,
                           tour_list,
                           question_summary,
                           actual_routes: List[List[int]],
                           actual_labels: List[List[int]],
                           cf_routes: List[List[int]],
                           cf_labels: List[List[int]],
                           cf_step: int) -> str:
    """Assemble the text block that compares the actual and the CF route.

    The sections are, in order: the question summary, the actual route, the
    CF route, the difference between the two, and the destination notes.
    Only the first route / label list of each argument is used, and the
    explained edge is addressed by the 0-based index ``cf_step - 1``.
    """
    ex_step = cf_step - 1
    actual_route, actual_label = actual_routes[0], actual_labels[0]
    cf_route, cf_label = cf_routes[0], cf_labels[0]

    # List elements are evaluated top to bottom, which keeps the helper calls
    # in the same order as the sections appear in the text.
    sections = [
        "Question:\n", question_summary, "\n",
        "Actual route:\n",
        self.get_route_info(tour_list, actual_route, actual_label, ex_step, "actual"),
        self.get_representative_values(actual_route, actual_label, ex_step, "actual"),
        "CF route:\n",
        self.get_route_info(tour_list, cf_route, cf_label, ex_step, "CF"),
        self.get_representative_values(cf_route, cf_label, ex_step, "CF"),
        "Difference between two routes:\n",
        self.get_diff(ex_step, actual_route, cf_route),
        "Planed desination information:\n",
        self.get_node_info(),
    ]
    return "".join(sections)
def get_route_info(self,
                   tour_list,
                   route: List[int],
                   label: List[int],
                   ex_step: int,
                   type: str) -> str:
    """Render a route as ``- route: A > (label) > B > ... > Z`` plus newline.

    Each edge is annotated with its label name (``route_len`` for label 0,
    ``time_window`` otherwise). The edge at index ``ex_step`` is highlighted
    as ``({type} edge: ...)``: for ``type == "actual"`` it shows the edge's own
    label name, for any other type it shows ``user_preference``.

    Parameters
    ----------
    tour_list:
        Sequence of mappings providing a ``'name'`` entry per node.
    route:
        0-based node ids in visiting order.
    label:
        One label per edge; ``label[i]`` belongs to the edge leaving
        ``route[i]``.
    ex_step:
        0-based index of the edge being explained.
    type:
        ``"actual"`` or the name of the alternative (e.g. ``"CF"``).
    """
    def get_labelname(label_number):
        return "route_len" if label_number == 0 else "time_window"

    last_index = len(route) - 1
    parts = ["- route: "]
    for i, node_id in enumerate(route):
        name = tour_list[node_id]['name']
        if i == last_index:
            # The final node has no outgoing edge; it closes the line.
            parts.append(f"{name}\n")
        elif i == ex_step:
            # Use the plain label string (a set literal here used to leak
            # braces and quotes into the text).
            edge_label = get_labelname(label[i]) if type == "actual" else "user_preference"
            parts.append(f"{name} > ({type} edge: {edge_label}) > ")
        else:
            parts.append(f"{name} > ({get_labelname(label[i])}) > ")
    return "".join(parts)
def get_representative_values(self, route, labels, ex_step, type) -> str:
    """Summarise a route as four newline-terminated bullet lines.

    The lines report the travel time of the edge at ``ex_step``, the total
    route length, the names of nodes missing from the route, and the share of
    ``time_window`` / ``route_len`` labels from ``ex_step`` onwards.

    ``type`` is only used in the wording of the last line.
    """
    time_window_ratio = self.get_intention_ratio(1, labels, ex_step) * 100
    route_len_ratio = self.get_intention_ratio(0, labels, ex_step) * 100
    # NOTE(review): the // 60 presumably converts seconds to minutes -- confirm
    # the unit of dist_matrix.
    short_term = self.get_immediate_state(route, ex_step) // 60
    long_term = self.get_route_length(route) // 60
    missed_names = self.get_infeasible_node_name(route)
    # The final "\n" keeps the following section header on its own line, like
    # the other text-building helpers of this class.
    return (
        f"- short-term effect (immediate travel time): {short_term} minutes\n"
        f"- long-term effect (total travel time): {long_term} minutes\n"
        f"- missed nodes: {missed_names}\n"
        f"- edge-intention ratio after the {type} edge: "
        f"time_window {time_window_ratio: .1f}%, route_len {route_len_ratio: .1f}%\n"
    )
def get_immediate_state(self, route, ex_step) -> float:
    """Return the ``dist_matrix`` entry for the edge leaving ``route[ex_step]``.

    The matrix is read from ``st.session_state.dist_matrix``; the value is a
    number taken from it (not a string). ``ex_step`` must not be the last
    index of ``route`` because the edge's end node is ``route[ex_step + 1]``.
    """
    from_node = route[ex_step]
    to_node = route[ex_step + 1]
    return st.session_state.dist_matrix[from_node][to_node]
def get_route_length(self, route) -> float:
    """Sum the ``st.session_state.dist_matrix`` entries along ``route``.

    Consecutive node pairs are added in visiting order; a route with fewer
    than two nodes yields ``0.0``.
    """
    total = 0.0
    for src, dst in zip(route, route[1:]):
        total += st.session_state.dist_matrix[src][dst]
    return total
def get_infeasible_nodes(self, route) -> int:
    """Return ``len(route)`` minus ``len(dist_matrix) - 1``.

    The value is 0 when the route has exactly ``len(dist_matrix) - 1`` entries
    and negative when it is shorter, so it is the negated count of missing
    entries rather than a positive count.
    """
    full_route_size = len(st.session_state.dist_matrix) - 1
    return len(route) - full_route_size
def get_infeasible_node_name(self, route) -> str:
    """Name the nodes that do not appear in ``route``.

    Returns ``"none"`` when the route has ``len(dist_matrix) - 1`` entries;
    otherwise the names (from ``st.session_state.tour_list``) of all node ids
    absent from the route, in ascending id order, joined by commas.
    """
    total_nodes = len(st.session_state.dist_matrix)
    if len(route) == total_nodes - 1:
        return "none"
    absent_ids = [candidate for candidate in range(total_nodes) if candidate not in route]
    absent_names = [st.session_state.tour_list[absent]["name"] for absent in absent_ids]
    return ",".join(absent_names)
def get_intention_ratio(self,
                        intention: int,
                        labels: List[int],
                        ex_step: int) -> float:
    """Return the fraction of ``labels[ex_step:]`` equal to ``intention``.

    Returns ``0.0`` when there are no labels from ``ex_step`` onwards, instead
    of dividing by zero.
    """
    remaining = np.asarray(labels)[ex_step:]
    if len(remaining) == 0:
        return 0.0
    return float(np.sum(remaining == intention) / len(remaining))
def get_diff(self, ex_step, actual_route, cf_route) -> str:
    """Describe how the actual route differs from the CF route.

    Three bullet lines are produced: the difference in travel time of the
    edge at ``ex_step`` (short-term), the difference in total route length
    (long-term), and the difference in the number of visited nodes. All
    differences are computed as *actual minus CF*.
    """
    def get_str(effect: float):
        # effect > 0 means the actual route's value is larger than the CF's.
        verb = "increases" if effect > 0 else "reduces"
        return f"The actual route {verb} it by {util_app.add_time_unit(abs(effect))}"

    def get_str2(num_nodes: int, num_missed_nodes):
        # num_nodes is len(actual_route) - len(cf_route): positive means the
        # actual route contains more nodes than the CF route.
        if num_nodes > 0:
            num_nodes_str = f"The actual route visits {abs(num_nodes)} more nodes"
        elif num_nodes == 0:
            if num_missed_nodes == 0:
                num_nodes_str = "Both routes missed no node,"
            else:
                num_nodes_str = f"Both routes missed the same number of nodes ({abs(num_missed_nodes)} node(s))"
        else:
            num_nodes_str = f"The actual route visits {abs(num_nodes)} less nodes"
        return num_nodes_str

    # short/long-term effects
    short_effect = self.get_immediate_state(actual_route, ex_step) - self.get_immediate_state(cf_route, ex_step)
    long_effect = self.get_route_length(actual_route) - self.get_route_length(cf_route)
    short_effect_str = get_str(short_effect)
    long_effect_str = get_str(long_effect)

    # missed nodes
    # get_infeasible_nodes returns len(route) minus a constant, so this
    # difference equals len(actual_route) - len(cf_route).
    actual_missed = self.get_infeasible_nodes(actual_route)
    node_count_diff = actual_missed - self.get_infeasible_nodes(cf_route)
    missed_nodes_str = get_str2(node_count_diff, actual_missed)

    return f"- short-term effect: {short_effect_str}\n - long-term effect: {long_effect_str}\n- missed nodes: {missed_nodes_str}\n"
def get_node_info(self) -> str:
    """List every destination with its remarks, one ``- name: remarks`` line each.

    Rows are read from ``st.session_state.df_tour`` using the labels
    ``0 .. len(df_tour) - 1`` on its ``'destination'`` and ``'remarks'``
    columns.
    """
    df_tour = st.session_state.df_tour
    lines = [
        f"- {df_tour['destination'][row]}: {df_tour['remarks'][row]}\n"
        for row in range(len(df_tour))
    ]
    return "".join(lines)