LennardZuendorf commited on
Commit
58a02af
1 Parent(s): 0367ac5

fix: updating explanation component

Browse files
Files changed (3) hide show
  1. backend/controller.py +3 -3
  2. explanation/markup.py +26 -22
  3. main.py +8 -6
backend/controller.py CHANGED
@@ -6,7 +6,7 @@ import gradio as gr
6
 
7
  # internal imports
8
  from model import godel
9
- from explanation import interpret, visualize
10
 
11
 
12
  # main interference function that that calls chat functions depending on selections
@@ -28,9 +28,9 @@ def interference(
28
  if xai_selection in ("SHAP", "Visualizer"):
29
  match xai_selection.lower():
30
  case "shap":
31
- xai = interpret
32
  case "visualizer":
33
- xai = visualize
34
  case _:
35
  # use Gradio warning to display error message
36
  gr.Warning(f"""
 
6
 
7
  # internal imports
8
  from model import godel
9
+ from explanation import interpret_shap as sint, visualize as viz
10
 
11
 
12
  # main interference function that that calls chat functions depending on selections
 
28
  if xai_selection in ("SHAP", "Visualizer"):
29
  match xai_selection.lower():
30
  case "shap":
31
+ xai = sint
32
  case "visualizer":
33
+ xai = viz
34
  case _:
35
  # use Gradio warning to display error message
36
  gr.Warning(f"""
explanation/markup.py CHANGED
@@ -9,7 +9,7 @@ from utils import formatting as fmt
9
 
10
 
11
  def markup_text(input_text: list, text_values: ndarray, variant: str):
12
- buckets = 10
13
 
14
  # Flatten the explanations values
15
  if variant == "shap":
@@ -21,39 +21,43 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
21
 
22
  # Separate the threshold calculation for negative and positive values
23
  if variant == "visualizer":
24
- thresholds = np.linspace(min_val, max_val, num=buckets, endpoint=False)[1:]
 
 
25
  else:
26
- neg_thresholds = np.linspace(min_val, 0, num=buckets // 2 + 1, endpoint=False)[
27
- 1:
28
- ]
29
- pos_thresholds = np.linspace(0, max_val, num=buckets // 2 + 1)[1:]
30
- thresholds = np.concatenate([neg_thresholds, pos_thresholds])
31
 
32
  marked_text = []
33
 
34
  # Function to determine the bucket for a given value
35
  for text, value in zip(input_text, text_values):
36
- bucket = 0
37
- for i, threshold in enumerate(thresholds, start=1):
38
- if value > threshold:
39
  bucket = i
40
  marked_text.append((text, str(bucket)))
41
 
 
 
42
  return marked_text
43
 
44
 
45
  def color_codes():
46
  return {
47
- # 1-5: Strong Light Red to Lighter Red
48
- "1": "#FF6666", # Strong Light Red
49
- "2": "#FF8080", # Slightly Lighter Red
50
- "3": "#FF9999", # Intermediate Light Red
51
- "4": "#FFB3B3", # Light Red
52
- "5": "#FFCCCC", # Very Light Red
53
- # 6-10: Light Green to Strong Light Green
54
- "6": "#B3FFB3", # Light Green
55
- "7": "#99FF99", # Slightly Stronger Green
56
- "8": "#80FF80", # Intermediate Green
57
- "9": "#66FF66", # Strong Green
58
- "10": "#4DFF4D", # Very Strong Green
59
  }
 
9
 
10
 
11
  def markup_text(input_text: list, text_values: ndarray, variant: str):
12
+ bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
13
 
14
  # Flatten the explanations values
15
  if variant == "shap":
 
21
 
22
  # Separate the threshold calculation for negative and positive values
23
  if variant == "visualizer":
24
+ neg_thresholds = np.linspace(
25
+ 0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
26
+ )[1:]
27
  else:
28
+ neg_thresholds = np.linspace(
29
+ min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
30
+ )[1:]
31
+ pos_thresholds = np.linspace(0, max_val, num=(len(bucket_tags) - 1) // 2 + 1)[1:]
32
+ thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])
33
 
34
  marked_text = []
35
 
36
  # Function to determine the bucket for a given value
37
  for text, value in zip(input_text, text_values):
38
+ bucket = "-5"
39
+ for i, threshold in zip(bucket_tags, thresholds):
40
+ if value >= threshold:
41
  bucket = i
42
  marked_text.append((text, str(bucket)))
43
 
44
+ print(thresholds)
45
+ print(marked_text)
46
  return marked_text
47
 
48
 
49
  def color_codes():
50
  return {
51
+ # 1-5: Strong Light Sky Blue to Lighter Sky Blue
52
+ "-5": "#3251a8", # Strong Light Sky Blue
53
+ "-4": "#5A7FB2", # Slightly Lighter Sky Blue
54
+ "-3": "#8198BC", # Intermediate Sky Blue
55
+ "-2": "#A8B1C6", # Light Sky Blue
56
+ "-1": "#E6F0FF", # Very Light Sky Blue
57
+ "0": "#FFFFFF", # White
58
+ "+1": "#FFE6F0", # Lighter Pink
59
+ "+2": "#DF8CA3", # Slightly Stronger Pink
60
+ "+3": "#D7708E", # Intermediate Pink
61
+ "+4": "#CF5480", # Deep Pink
62
+ "+5": "#A83273", # Strong Magenta
63
  }
main.py CHANGED
@@ -75,7 +75,7 @@ with gr.Blocks(
75
 
76
  """)
77
  # row with columns for the different settings
78
- with gr.Row(equal_height=True, variant="compact"):
79
  # column that takes up 3/5 of the row
80
  with gr.Column(scale=3):
81
  # textbox to enter the system prompt
@@ -108,13 +108,15 @@ with gr.Blocks(
108
  # accordion to display the normalized input explanation
109
  with gr.Accordion(label="Input Explanation", open=False):
110
  gr.Markdown("""
111
- #### Input Explanation
112
- The input explanation shows the explanation for the last message
113
- you sent to the AI ChatBot. The explanation is based on the
114
- XAI method you selected.
115
  """)
116
  xai_text = gr.HighlightedText(
117
- color_map=coloring, label="Input Explanation", show_legend=True
 
 
 
118
  )
119
  # out of the box chatbot component
120
  # see documentation: https://www.gradio.app/docs/chatbot
 
75
 
76
  """)
77
  # row with columns for the different settings
78
+ with gr.Row(equal_height=True):
79
  # column that takes up 3/5 of the row
80
  with gr.Column(scale=3):
81
  # textbox to enter the system prompt
 
108
  # accordion to display the normalized input explanation
109
  with gr.Accordion(label="Input Explanation", open=False):
110
  gr.Markdown("""
111
+ The explanations are based on 10 buckets that range between the
112
+ lowest negative value (1 to 5) and the highest positive attribution value (6 to 10).
113
+ **The legend show the color for each bucket.**
 
114
  """)
115
  xai_text = gr.HighlightedText(
116
+ color_map=coloring,
117
+ label="Input Explanation",
118
+ show_legend=True,
119
+ show_label=False,
120
  )
121
  # out of the box chatbot component
122
  # see documentation: https://www.gradio.app/docs/chatbot