LennardZuendorf committed on
Commit
2492536
1 Parent(s): 3f2ed3d

chore: updating documentation

.gitignore CHANGED
@@ -2,4 +2,3 @@
2
  __pycache__/
3
  /start-venv.sh
4
  /components/iframe/dist/
5
- /components/
 
2
  __pycache__/
3
  /start-venv.sh
4
  /components/iframe/dist/
 
README.md CHANGED
@@ -80,6 +80,7 @@ This project is licensed under the MIT License, see [LICENSE](LICENSE.md) for mo
80
  - University: HTW Berlin
81
 
82
  See the code for detailed credits; this work is strongly based on:
 
83
  #### GODEL
84
  - [HGF Model Page](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq?text=Hey+my+name+is+Mariama%21+How+are+you%3F)
85
  - [Paper on HGF](https://huggingface.co/papers/2206.11309)
@@ -88,3 +89,7 @@ See code for in detailed credits, work is strongly based on:
88
  #### SHAP
89
  - [Github](https://github.com/shap/shap)
90
  - [Initial Paper](https://arxiv.org/abs/1705.07874)
 
80
  - University: HTW Berlin
81
 
82
  See the code for detailed credits; this work is strongly based on:
83
+
84
  #### GODEL
85
  - [HGF Model Page](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq?text=Hey+my+name+is+Mariama%21+How+are+you%3F)
86
  - [Paper on HGF](https://huggingface.co/papers/2206.11309)
 
89
  #### SHAP
90
  - [Github](https://github.com/shap/shap)
91
  - [Initial Paper](https://arxiv.org/abs/1705.07874)
92
+
93
+ #### Custom Component (/components/iframe/)
94
+
95
+ Based on a Gradio template component; see the individual README for the full changelog.
__init__.py CHANGED
@@ -1,2 +1 @@
1
- # empty init file for the package
2
- # for fastapi to recognize the module
 
1
+ # empty init file for the module
 
backend/__init__.py CHANGED
@@ -1,2 +1 @@
1
- # empty init file for the package
2
- # for fastapi to recognize the module
 
1
+ # empty init file for the module
 
backend/controller.py CHANGED
@@ -1,15 +1,16 @@
1
  # controller for the application that calls the model and explanation functions
2
- # and returns the updated conversation history
3
 
4
  # external imports
5
  import gradio as gr
6
 
7
  # internal imports
8
  from model import godel
9
- from explanation import interpret_shap as sint, visualize as viz
10
 
11
 
12
  # main interference function that calls chat functions depending on selections
 
13
  def interference(
14
  prompt: str,
15
  history: list,
@@ -17,18 +18,19 @@ def interference(
17
  system_prompt: str,
18
  xai_selection: str,
19
  ):
20
- # if no system prompt is given, use a default one
21
- if system_prompt == "":
22
  system_prompt = """
23
  You are a helpful, respectful and honest assistant.
24
  Always answer as helpfully as possible, while being safe.
25
  """
26
 
27
- # if a XAI approach is selected, grab the XAI instance
28
  if xai_selection in ("SHAP", "Attention"):
 
29
  match xai_selection.lower():
30
  case "shap":
31
- xai = sint
32
  case "attention":
33
  xai = viz
34
  case _:
@@ -37,9 +39,10 @@ def interference(
37
  There was an error in the selected XAI Approach.
38
  It is "{xai_selection}"
39
  """)
 
40
  raise RuntimeError("There was an error in the selected XAI approach.")
41
 
42
- # call the explained chat function
43
  prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
44
  model=godel,
45
  xai=xai,
@@ -48,7 +51,7 @@ def interference(
48
  system_prompt=system_prompt,
49
  knowledge=knowledge,
50
  )
51
- # if no (or invalid) XAI approach is selected call the vanilla chat function
52
  else:
53
  # call the vanilla chat function
54
  prompt_output, history_output = vanilla_chat(
@@ -78,12 +81,12 @@ def vanilla_chat(
78
  ):
79
  # formatting the prompt using the model's format_prompt function
80
  prompt = model.format_prompt(message, history, system_prompt, knowledge)
 
81
  # generating an answer using the model's respond function
82
  answer = model.respond(prompt)
83
 
84
  # updating the chat history with the new answer
85
  history.append((message, answer))
86
-
87
  # returning the updated history
88
  return "", history
89
 
@@ -94,7 +97,7 @@ def explained_chat(
94
  # formatting the prompt using the model's format_prompt function
95
  prompt = model.format_prompt(message, history, system_prompt, knowledge)
96
 
97
- # generating an answer using the xai methods explain and respond function
98
  answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
99
 
100
  # updating the chat history with the new answer
 
1
  # controller for the application that calls the model and explanation functions
2
+ # returns the updated conversation history and extra elements
3
 
4
  # external imports
5
  import gradio as gr
6
 
7
  # internal imports
8
  from model import godel
9
+ from explanation import interpret_shap as shap_int, visualize as viz
10
 
11
 
12
  # main interference function that calls chat functions depending on selections
13
+ # called on every chat submit
14
  def interference(
15
  prompt: str,
16
  history: list,
 
18
  system_prompt: str,
19
  xai_selection: str,
20
  ):
21
+ # if no proper system prompt is given, use a default one
22
+ if system_prompt in ('', ' '):
23
  system_prompt = """
24
  You are a helpful, respectful and honest assistant.
25
  Always answer as helpfully as possible, while being safe.
26
  """
27
 
28
+ # if an XAI approach is selected, grab the XAI module instance
29
  if xai_selection in ("SHAP", "Attention"):
30
+ # matching selection
31
  match xai_selection.lower():
32
  case "shap":
33
+ xai = shap_int
34
  case "attention":
35
  xai = viz
36
  case _:
 
39
  There was an error in the selected XAI Approach.
40
  It is "{xai_selection}"
41
  """)
42
+ # raise runtime exception
43
  raise RuntimeError("There was an error in the selected XAI approach.")
44
 
45
+ # call the explained chat function with the model instance
46
  prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
47
  model=godel,
48
  xai=xai,
 
51
  system_prompt=system_prompt,
52
  knowledge=knowledge,
53
  )
54
+ # if no XAI approach is selected, call the vanilla chat function
55
  else:
56
  # call the vanilla chat function
57
  prompt_output, history_output = vanilla_chat(
 
81
  ):
82
  # formatting the prompt using the model's format_prompt function
83
  prompt = model.format_prompt(message, history, system_prompt, knowledge)
84
+
85
  # generating an answer using the model's respond function
86
  answer = model.respond(prompt)
87
 
88
  # updating the chat history with the new answer
89
  history.append((message, answer))
 
90
  # returning the updated history
91
  return "", history
92
 
 
97
  # formatting the prompt using the model's format_prompt function
98
  prompt = model.format_prompt(message, history, system_prompt, knowledge)
99
 
100
+ # generating an answer using the method's chat_explained function
101
  answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
102
 
103
  # updating the chat history with the new answer
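For orientation, a sketch of how this entry point is invoked (hypothetical values; the `knowledge` keyword is inferred from the `explained_chat(knowledge=...)` call, since the hunk elides that part of the signature):

```python
# hypothetical call of the controller entry point with SHAP selected;
# the four return values follow the explained_chat branch shown above
prompt_out, history_out, xai_graphic, xai_markup = interference(
    prompt="Why is the sky blue?",
    history=[],
    knowledge="",  # assumed parameter name, inferred from explained_chat(knowledge=...)
    system_prompt="",  # empty string triggers the default system prompt
    xai_selection="SHAP",
)
```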
components/iframe/README.md CHANGED
@@ -1,51 +1,17 @@
1
- # gradio_iframe
2
- A custom gradio component to embed an iframe in a gradio interface. This component is based on the [HTML]() component.
3
- It's currently still a work in progress.
4
 
5
- ## Usage
 
6
 
7
- The usage is similar to the HTML component. You can pass valid html and it will be rendered in the interface as an iframe, meaning you can embed any website or webapp that supports iframes.
8
- Also, JavaScript should run normal. You can even pass an iframe inside an iframe (see below!), i.e. a youtube or spotify embed.
 
9
 
10
- The size will adjust to the size of the iframe (onload), **this is gonna be a bit delayed**. The width is default at 100%.
11
- You can also set the height and width manually.
12
 
13
- ### Example
14
-
15
- ```python
16
- import gradio as gr
17
- from gradio_iframe import iFrame
18
-
19
- gr.Interface(
20
- iFrame(
21
- label="iFrame Example",
22
- value=("""
23
- <iframe width="560"
24
- height="315"
25
- src="https://www.youtube.com/embed/dQw4w9WgXcQ?si=QfHLpHZsI98oZT1G"
26
- title="YouTube video player"
27
- frameborder="0"
28
- allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
29
- allowfullscreen>
30
- </iframe>"""),
31
- show_label=True)
32
- )
33
- ```
34
-
35
- ## Roadmap
36
-
37
- - [ ] Add manual hand over of other iFrame options.
38
- - [ ] Explore switch between src and srcdoc through variable.
39
-
40
- ## Known Issues
41
-
42
- **There are many reason why it's not a good idea to embed websites in an iframe.**
43
- See [this](https://blog.bitsrc.io/4-security-concerns-with-iframes-every-web-developer-should-know-24c73e6a33e4), or just google "iframe security concerns" for more information. Also, iFrames will use additional computing power and memory, which can slow down the interface.
44
-
45
- Also, this component is still a work in progress and not fully tested. Use at your own risk.
46
-
47
- ### Other Issues
48
-
49
- - Height sometimes does not grow according to the inner component.
50
- - The component is not completely responsive yet and struggles with variable heigth.
51
- - ...
 
1
+ # gradio iFrame
 
 
2
 
3
+ This is a custom Gradio component used to display the SHAP package's text plot, which is interactive HTML and therefore needs a custom wrapper.
4
+ See the official [custom components guide](https://www.gradio.app/guides/custom-components-in-five-minutes) for examples.
5
 
6
+ ## Credit
7
+ CREDIT: based mostly on the Gradio template component HTML,
8
+ see: https://www.gradio.app/docs/html
9
 
10
+ ## Changes
11
+ **Additions/changes are marked. Everything else can be considered the work of others (the Gradio team).**
12
 
13
+ #### Changed Files/Contributions
14
+ - backend/iframe.py - updated the component to accept custom height/width and added a new example
15
+ - demo/app.py - slightly changed the demo file for a better dev experience
16
+ - frontend/index.svelte - slightly changed to accept custom height/width
17
+ - frontend/HTML.svelte - updated to use an iFrame and added a custom function to programmatically set height values
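For reference, a minimal usage sketch (adapted from the example that was removed above; the `height` keyword follows the backend/iframe.py change listed here):

```python
import gradio as gr
from gradio_iframe import iFrame

# minimal sketch: render interactive HTML (e.g. a SHAP text plot) inside the iFrame
with gr.Blocks() as demo:
    iFrame(
        label="iFrame Example",
        value="<h1>Hello from inside the iFrame</h1>",
        height="315px",  # custom height, per the backend/iframe.py change
        show_label=True,
    )

demo.launch()
```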
 
components/iframe/backend/gradio_iframe/iframe.py CHANGED
@@ -62,10 +62,12 @@ class iFrame(Component):
62
  value=value,
63
  )
64
 
 
65
  self.height = height
66
  self.width = width
67
 
68
  def example_inputs(self) -> Any:
 
69
  return """<iframe width="560" height="315" src="https://www.youtube.com/embed/dQw4w9WgXcQ?si=QfHLpHZsI98oZT1G" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>"""
70
 
71
  def preprocess(self, payload: str | None) -> str | None:
 
62
  value=value,
63
  )
64
 
65
+ # storing the custom height and width values on the component
66
  self.height = height
67
  self.width = width
68
 
69
  def example_inputs(self) -> Any:
70
+ # setting a custom example
71
  return """<iframe width="560" height="315" src="https://www.youtube.com/embed/dQw4w9WgXcQ?si=QfHLpHZsI98oZT1G" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>"""
72
 
73
  def preprocess(self, payload: str | None) -> str | None:
components/iframe/frontend/Index.svelte CHANGED
@@ -1,3 +1,5 @@
 
 
1
  <script lang="ts">
2
  import type { Gradio } from "@gradio/utils";
3
  import HTML from "./shared/HTML.svelte";
@@ -10,6 +12,7 @@
10
  export let elem_classes: string[] = [];
11
  export let visible = true;
12
  export let value = "";
 
13
  export let height: string;
14
  export let width: string = "100%";
15
  export let loading_status: LoadingStatus;
 
1
+ <!-- index component that wraps the custom iFrame ("HTML") -->
2
+
3
  <script lang="ts">
4
  import type { Gradio } from "@gradio/utils";
5
  import HTML from "./shared/HTML.svelte";
 
12
  export let elem_classes: string[] = [];
13
  export let visible = true;
14
  export let value = "";
15
+ // updated to take a custom height
16
  export let height: string;
17
  export let width: string = "100%";
18
  export let loading_status: LoadingStatus;
components/iframe/frontend/shared/HTML.svelte CHANGED
@@ -1,3 +1,5 @@
 
 
1
  <script lang="ts">
2
  import { createEventDispatcher } from "svelte";
3
  export let elem_classes: string[] = [];
@@ -5,6 +7,7 @@
5
  export let visible = true;
6
  export let min_height = false;
7
 
 
8
  export let height = "100%";
9
  export let width = "100%";
10
 
@@ -12,10 +15,14 @@
12
 
13
  let iframeElement;
14
 
 
15
  const onLoad = () => {
16
  try {
 
17
  const iframeDocument = iframeElement.contentDocument || iframeElement.contentWindow.document;
 
18
  if (height === "100%") {
 
19
  const height = iframeDocument.documentElement.scrollHeight;
20
  iframeElement.style.height = `${height}px`;
21
  }
@@ -33,6 +40,7 @@
33
  class:hide={!visible}
34
  class:height={height}
35
  >
 
36
  <iframe
37
  bind:this={iframeElement}
38
  title="iframe component"
 
1
+ <!-- HTML component that implements the custom iFrame -->
2
+
3
  <script lang="ts">
4
  import { createEventDispatcher } from "svelte";
5
  export let elem_classes: string[] = [];
 
7
  export let visible = true;
8
  export let min_height = false;
9
 
10
+ // default height and width settings
11
  export let height = "100%";
12
  export let width = "100%";
13
 
 
15
 
16
  let iframeElement;
17
 
18
+ // custom function to update the iFrame height once the HTML has loaded
19
  const onLoad = () => {
20
  try {
21
+ // get the iFrame document
22
  const iframeDocument = iframeElement.contentDocument || iframeElement.contentWindow.document;
23
+ // if no custom height is set, derive it from the content
24
  if (height === "100%") {
25
+ // grab the height from the iFrame document
26
  const height = iframeDocument.documentElement.scrollHeight;
27
  iframeElement.style.height = `${height}px`;
28
  }
 
40
  class:hide={!visible}
41
  class:height={height}
42
  >
43
+ <!-- updated to use an iframe instead of HTML, passing the string value via srcdoc -->
44
  <iframe
45
  bind:this={iframeElement}
46
  title="iframe component"
explanation/__init__.py CHANGED
@@ -1,2 +1 @@
1
- # empty init file for the package
2
- # for fastapi to recognize the module
 
1
+ # empty init file for the module
 
explanation/interpret_shap.py CHANGED
@@ -1,4 +1,5 @@
1
  # interpret module that implements the interpretability method
 
2
  # external imports
3
  from shap import models, maskers, plots, PartitionExplainer
4
  import torch
@@ -14,14 +15,15 @@ TEXT_MASKER = None
14
 
15
  # main explain function that returns a chat with explanations
16
  def chat_explained(model, prompt):
17
- model.set_config()
18
 
19
  # create the shap explainer
20
  shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
 
21
  # get the shap values for the prompt
22
  shap_values = shap_explainer([prompt])
23
 
24
- # create the explanation graphic and plot
25
  graphic = create_graphic(shap_values)
26
  marked_text = markup_text(
27
  shap_values.data[0], shap_values.values[0], variant="shap"
@@ -29,20 +31,26 @@ def chat_explained(model, prompt):
29
 
30
  # create the response text
31
  response_text = fmt.format_output_text(shap_values.output_names)
 
 
32
  return response_text, graphic, marked_text
33
 
34
 
 
35
  def wrap_shap(model):
 
36
  global TEXT_MASKER, TEACHER_FORCING
37
 
38
  # set the device to cuda if gpu is available
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
 
41
- # updating the model settings again
42
  model.set_config()
43
 
44
  # (re)initialize the shap models and masker
 
45
  text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
 
46
  TEACHER_FORCING = models.TeacherForcing(
47
  text_generation,
48
  model.TOKENIZER,
@@ -50,13 +58,15 @@ def wrap_shap(model):
50
  similarity_model=model.MODEL,
51
  similarity_tokenizer=model.TOKENIZER,
52
  )
 
53
  TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
54
 
55
 
56
  # graphic plotting function that creates a html graphic (as string) for the explanation
57
  def create_graphic(shap_values):
 
58
  # create the html graphic using shap text plot function
59
  graphic_html = plots.text(shap_values, display=False)
60
 
61
- # return the html graphic as string
62
  return str(graphic_html)
 
1
  # interpret module that implements the interpretability method
2
+
3
  # external imports
4
  from shap import models, maskers, plots, PartitionExplainer
5
  import torch
 
15
 
16
  # main explain function that returns a chat with explanations
17
  def chat_explained(model, prompt):
18
+ model.set_config({})
19
 
20
  # create the shap explainer
21
  shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
22
+
23
  # get the shap values for the prompt
24
  shap_values = shap_explainer([prompt])
25
 
26
+ # create the explanation graphic and marked text array
27
  graphic = create_graphic(shap_values)
28
  marked_text = markup_text(
29
  shap_values.data[0], shap_values.values[0], variant="shap"
 
31
 
32
  # create the response text
33
  response_text = fmt.format_output_text(shap_values.output_names)
34
+
35
+ # return response, graphic and marked_text array
36
  return response_text, graphic, marked_text
37
 
38
 
39
+ # function used to wrap the model with a shap model
40
  def wrap_shap(model):
41
+ # referencing the global variables
42
  global TEXT_MASKER, TEACHER_FORCING
43
 
44
  # set the device to cuda if gpu is available
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
 
47
+ # updating the model settings
48
  model.set_config()
49
 
50
  # (re)initialize the shap models and masker
51
+ # creating a shap text_generation model
52
  text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
53
+ # wrapping the text generation model in a teacher forcing model
54
  TEACHER_FORCING = models.TeacherForcing(
55
  text_generation,
56
  model.TOKENIZER,
 
58
  similarity_model=model.MODEL,
59
  similarity_tokenizer=model.TOKENIZER,
60
  )
61
+ # initializing the text masker with a whitespace mask token
62
  TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
63
 
64
 
65
  # graphic plotting function that creates a html graphic (as string) for the explanation
66
  def create_graphic(shap_values):
67
+
68
  # create the html graphic using shap text plot function
69
  graphic_html = plots.text(shap_values, display=False)
70
 
71
+ # return the html graphic as a string to display in the iFrame
72
  return str(graphic_html)
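To tie the pieces together, a sketch of how the controller consumes this module (mirroring the imports and calls in backend/controller.py above):

```python
# minimal sketch, assuming the godel model module and a formatted prompt
from model import godel
from explanation import interpret_shap as shap_int

prompt = godel.format_prompt(
    "Why is the sky blue?", [], "You are a helpful assistant.", ""
)
response_text, graphic_html, marked_text = shap_int.chat_explained(godel, prompt)
# graphic_html: HTML string for the iFrame; marked_text: list of (token, bucket) tuples
```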
explanation/markup.py CHANGED
@@ -1,4 +1,4 @@
1
- # markup module that provides marked up text and a plot for the explanations
2
 
3
  # external imports
4
  import numpy as np
@@ -8,10 +8,12 @@ from numpy import ndarray
8
  from utils import formatting as fmt
9
 
10
 
 
11
  def markup_text(input_text: list, text_values: ndarray, variant: str):
 
12
  bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
13
 
14
- # Flatten the values depending on the source
15
  # attention is averaged, SHAP summed up
16
  if variant == "shap":
17
  text_values = np.transpose(text_values)
@@ -22,34 +24,49 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
22
  # Determine the minimum and maximum values
23
  min_val, max_val = np.min(text_values), np.max(text_values)
24
 
25
- # Separate the threshold calculation for negative and positive values
 
26
  if variant == "visualizer":
27
  neg_thresholds = np.linspace(
28
  0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
29
  )[1:]
 
30
  else:
31
  neg_thresholds = np.linspace(
32
  min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
33
  )[1:]
 
34
  pos_thresholds = np.linspace(0, max_val, num=(len(bucket_tags) - 1) // 2 + 1)[1:]
 
35
  thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])
36
 
 
37
  marked_text = []
38
 
39
- # Function to determine the bucket for a given value
40
  for text, value in zip(input_text, text_values):
 
41
  bucket = "-5"
 
 
42
  for i, threshold in zip(bucket_tags, thresholds):
 
43
  if value >= threshold:
44
  bucket = i
 
45
  marked_text.append((text, str(bucket)))
46
 
 
47
  return marked_text
48
 
49
 
 
 
50
  def color_codes():
51
  return {
52
- # 1-5: Strong Light Sky Blue to Lighter Sky Blue
 
 
53
  "-5": "#3251a8", # Strong Light Sky Blue
54
  "-4": "#5A7FB2", # Slightly Lighter Sky Blue
55
  "-3": "#8198BC", # Intermediate Sky Blue
 
1
+ # markup module that provides marked up text as an array
2
 
3
  # external imports
4
  import numpy as np
 
8
  from utils import formatting as fmt
9
 
10
 
11
+ # main function that assigns each text snippet to a bucket
12
  def markup_text(input_text: list, text_values: ndarray, variant: str):
13
+ # naming of the 11 buckets
14
  bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
15
 
16
+ # flatten the values depending on the source
17
  # attention is averaged, SHAP summed up
18
  if variant == "shap":
19
  text_values = np.transpose(text_values)
 
24
  # Determine the minimum and maximum values
25
  min_val, max_val = np.min(text_values), np.max(text_values)
26
 
27
+ # separate the threshold calculation for negative and positive values
28
+ # for the visualizer, negative thresholds are all 0 since attention is always positive
29
  if variant == "visualizer":
30
  neg_thresholds = np.linspace(
31
  0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
32
  )[1:]
33
+ # standard config for 5 negative buckets
34
  else:
35
  neg_thresholds = np.linspace(
36
  min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
37
  )[1:]
38
+ # creating positive thresholds between 0 and the max value
39
  pos_thresholds = np.linspace(0, max_val, num=(len(bucket_tags) - 1) // 2 + 1)[1:]
40
+ # combining thresholds
41
  thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])
42
 
43
+ # init empty marked text list
44
  marked_text = []
45
 
46
+ # looping over each text snippet and attribution value
47
  for text, value in zip(input_text, text_values):
48
+ # setting the initial bucket to the lowest
49
  bucket = "-5"
50
+
51
+ # looping over all buckets and their thresholds
52
  for i, threshold in zip(bucket_tags, thresholds):
53
+ # updating assigned bucket if value is above threshold
54
  if value >= threshold:
55
  bucket = i
56
+ # finally adding text and bucket assignment to list of tuples
57
  marked_text.append((text, str(bucket)))
58
 
59
+ # returning list of marked text snippets as list of tuples
60
  return marked_text
61
 
62
 
63
+ # function that defines color codes
64
+ # colors follow the SHAP-style coloring for consistency
65
  def color_codes():
66
  return {
67
+ # -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
68
+ # 0: white (assuming default light mode)
69
+ # +1 to +5: light pink to strong magenta
70
  "-5": "#3251a8", # Strong Light Sky Blue
71
  "-4": "#5A7FB2", # Slightly Lighter Sky Blue
72
  "-3": "#8198BC", # Intermediate Sky Blue
explanation/visualize.py CHANGED
@@ -1,21 +1,26 @@
1
- # visualization module that creates an attention visualization using BERTViz
2
 
3
 
4
  # internal imports
5
  from utils import formatting as fmt
 
6
  from .markup import markup_text
7
 
8
 
9
- # plotting function that plots the attention values in a heatmap
 
10
  def chat_explained(model, prompt):
11
 
12
- model.set_config()
13
-
14
- # get encoded input and output vectors
15
  encoder_input_ids = model.TOKENIZER(
16
  prompt, return_tensors="pt", add_special_tokens=True
17
  ).input_ids
18
- decoder_input_ids = model.MODEL.generate(encoder_input_ids, output_attentions=True)
 
 
 
 
 
19
  encoder_text = fmt.format_tokens(
20
  model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
21
  )
@@ -24,20 +29,25 @@ def chat_explained(model, prompt):
24
  )
25
 
26
  # get attention values for the input and output vectors
 
27
  attention_output = model.MODEL(
28
  input_ids=encoder_input_ids,
29
  decoder_input_ids=decoder_input_ids,
30
  output_attentions=True,
31
  )
32
 
 
33
  averaged_attention = fmt.avg_attention(attention_output)
34
 
35
- # create the response text and marked text for ui
36
  response_text = fmt.format_output_text(decoder_text)
 
37
  graphic = (
38
  "<div style='text-align: center; font-family:arial;'><h4>Attention"
39
  " Visualization doesn't support an interactive graphic.</h4></div>"
40
  )
 
41
  marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
42
 
 
43
  return response_text, graphic, marked_text
 
1
+ # visualization module that creates an attention visualization
2
 
3
 
4
  # internal imports
5
  from utils import formatting as fmt
6
+ from model.godel import CONFIG
7
  from .markup import markup_text
8
 
9
 
10
+ # chat function that returns an answer
11
+ # and marked text based on attention
12
  def chat_explained(model, prompt):
13
 
14
+ # get encoded input
 
 
15
  encoder_input_ids = model.TOKENIZER(
16
  prompt, return_tensors="pt", add_special_tokens=True
17
  ).input_ids
18
+ # generate output together with the model's attentions
19
+ decoder_input_ids = model.MODEL.generate(
20
+ encoder_input_ids, output_attentions=True, **CONFIG
21
+ )
22
+
23
+ # get the input and output text as lists of strings
24
  encoder_text = fmt.format_tokens(
25
  model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
26
  )
 
29
  )
30
 
31
  # get attention values for the input and output vectors
32
+ # using already generated input and output
33
  attention_output = model.MODEL(
34
  input_ids=encoder_input_ids,
35
  decoder_input_ids=decoder_input_ids,
36
  output_attentions=True,
37
  )
38
 
39
+ # averaging the attention values across heads
40
  averaged_attention = fmt.avg_attention(attention_output)
41
 
42
+ # format response text for clean output
43
  response_text = fmt.format_output_text(decoder_text)
44
+ # setting placeholder for iFrame graphic
45
  graphic = (
46
  "<div style='text-align: center; font-family:arial;'><h4>Attention"
47
  " Visualization doesn't support an interactive graphic.</h4></div>"
48
  )
49
+ # creating marked text using markup_text function and attention
50
  marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
51
 
52
+ # returning response, graphic and marked text array
53
  return response_text, graphic, marked_text
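As a shape sanity check for this attention path, a small sketch of the reduction that fmt.avg_attention performs (the dimension names are my reading of the Transformers output, not something this diff states):

```python
import numpy as np

# hypothetical attention tensor for one layer and one batch element:
# (heads, target_tokens, target_tokens)
attention = np.random.rand(16, 7, 7)

# avg_attention picks decoder_attentions[0][0] and then averages over axis 0,
# i.e. over the heads, leaving one weight matrix per token pair
averaged = np.mean(attention, axis=0)
print(averaged.shape)  # (7, 7)
```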
main.py CHANGED
@@ -14,13 +14,21 @@ from gradio_iframe import iFrame
14
  from backend.controller import interference
15
  from explanation.markup import color_codes
16
 
17
- # Global Variables and css
 
 
18
  app = FastAPI()
 
 
 
 
19
  css = """
20
  .examples {text-align: start;}
21
  .seperatedRow {border-top: 1rem solid;}",
22
  """
23
- js = """
 
 
24
  function () {
25
  gradioURL = window.location.href
26
  if (!gradioURL.endsWith('?__theme=light')) {
@@ -28,7 +36,8 @@ js = """
28
  }
29
  }
30
  """
31
- coloring = color_codes()
 
32
 
33
 
34
  # different functions to provide frontend abilities
@@ -56,8 +65,8 @@ def xai_info(xai_radio):
56
  gr.Info("No XAI method was selected.")
57
 
58
 
59
- # ui interface based on Gradio Blocks (see documentation:
60
- # https://www.gradio.app/docs/interface)
61
  with gr.Blocks(
62
  css=css,
63
  js=js,
@@ -88,6 +97,7 @@ with gr.Blocks(
88
  """)
89
  # row with columns for the different settings
90
  with gr.Row(equal_height=True):
 
91
  with gr.Accordion(label="Application Settings", open=False):
92
  # column that takes up 3/4 of the row
93
  with gr.Column(scale=3):
@@ -95,6 +105,7 @@ with gr.Blocks(
95
  system_prompt = gr.Textbox(
96
  label="System Prompt",
97
  info="Set the models system prompt, dictating how it answers.",
 
98
  placeholder=(
99
  "You are a helpful, respectful and honest assistant. Always"
100
  " answer as helpfully as possible, while being safe."
@@ -105,26 +116,29 @@ with gr.Blocks(
105
  # checkbox group to select the xai method
106
  xai_selection = gr.Radio(
107
  ["None", "SHAP", "Attention"],
108
- label="XAI Settings",
109
- info="Select a XAI Implementation to use.",
110
  value="None",
111
  interactive=True,
112
  show_label=True,
113
  )
114
 
115
- # calling info functions on inputs for different settings
116
  system_prompt.submit(system_prompt_info, [system_prompt])
117
  xai_selection.input(xai_info, [xai_selection])
118
 
119
  # row with chatbot ui displaying "conversation" with the model
120
  with gr.Row(equal_height=True):
 
121
  with gr.Group(elem_classes="border: 1px solid black;"):
122
  # accordion to display the normalized input explanation
123
  with gr.Accordion(label="Input Explanation", open=False):
124
  gr.Markdown("""
125
  The explanations are based on 11 buckets that range from the
126
  lowest negative attribution value (-5 to -1), through 0, to the highest positive attribution value (+1 to +5).
127
- **The legend show the color for each bucket.**
 
 
128
  """)
129
  xai_text = gr.HighlightedText(
130
  color_map=coloring,
@@ -132,15 +146,19 @@ with gr.Blocks(
132
  show_legend=True,
133
  show_label=False,
134
  )
135
- # out of the box chatbot component
136
  # see documentation: https://www.gradio.app/docs/chatbot
137
  chatbot = gr.Chatbot(
138
  layout="panel",
139
  show_copy_button=True,
140
  avatar_images=("./public/human.jpg", "./public/bot.jpg"),
141
  )
142
- # textbox to enter the knowledge
143
  with gr.Accordion(label="Additional Knowledge", open=False):
 
 
 
 
144
  knowledge_input = gr.Textbox(
145
  value="",
146
  label="Knowledge",
@@ -149,24 +167,31 @@ with gr.Blocks(
149
  show_label=True,
150
  )
151
  # textbox to enter the user prompt
 
 
 
 
152
  user_prompt = gr.Textbox(
153
  label="Input Message",
154
  max_lines=5,
155
  info="""
156
  Ask the ChatBot a question.
157
- Hint: More complicated question give better explanation insights!
158
  """,
159
  show_label=True,
160
  )
161
  # row with columns for buttons to submit and clear content
162
  with gr.Row(elem_classes=""):
163
- with gr.Column(scale=1):
164
  # out of the box clear button which clears the given components (see
165
- # documentation: https://www.gradio.app/docs/clearbutton)
166
  clear_btn = gr.ClearButton([user_prompt, chatbot])
167
- with gr.Column(scale=1):
 
168
  submit_btn = gr.Button("Submit", variant="primary")
 
169
  with gr.Row(elem_classes="examples"):
 
 
170
  gr.Examples(
171
  label="Example Questions",
172
  examples=[
@@ -235,18 +260,21 @@ with gr.Blocks(
235
  # final row to show legal information
236
  ## - credits, data protection and link to the License
237
  with gr.Tab(label="About"):
 
238
  gr.Markdown(value=load_md("public/about.md"))
239
  with gr.Accordion(label="Credits, Data Protection, License"):
 
240
  gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
241
 
242
  # mount function for fastAPI Application
243
  app = gr.mount_gradio_app(app, ui, path="/")
244
 
245
- # launch function using uvicorn to launch the fastAPI application
246
  if __name__ == "__main__":
247
 
248
  # use standard gradio launch option for hgf spaces
249
  if os.environ["HOSTING"].lower() == "spaces":
 
250
  ui.launch(auth=("htw", "berlin@123"))
251
 
252
  # otherwise run the application on port 8080 in reload mode
 
14
  from backend.controller import interference
15
  from explanation.markup import color_codes
16
 
17
+
18
+ # global variables and js/css
19
+ # creating FastAPI app and getting color codes
20
  app = FastAPI()
21
+ coloring = color_codes()
22
+
23
+
24
+ # defining custom css and js for certain environments
25
  css = """
26
  .examples {text-align: start;}
27
  .seperatedRow {border-top: 1rem solid;}",
28
  """
29
+ # custom js to force light mode outside of HGF spaces
30
+ if os.environ["HOSTING"].lower() != "spaces":
31
+ js = """
32
  function () {
33
  gradioURL = window.location.href
34
  if (!gradioURL.endsWith('?__theme=light')) {
 
36
  }
37
  }
38
  """
39
+ else:
40
+ js = ""
41
 
42
 
43
  # different functions to provide frontend abilities
 
65
  gr.Info("No XAI method was selected.")
66
 
67
 
68
+ # ui interface based on Gradio Blocks
69
+ # see https://www.gradio.app/docs/interface
70
  with gr.Blocks(
71
  css=css,
72
  js=js,
 
97
  """)
98
  # row with columns for the different settings
99
  with gr.Row(equal_height=True):
100
+ # accordion that expands when clicked
101
  with gr.Accordion(label="Application Settings", open=False):
102
  # column that takes up 3/4 of the row
103
  with gr.Column(scale=3):
 
105
  system_prompt = gr.Textbox(
106
  label="System Prompt",
107
  info="Set the models system prompt, dictating how it answers.",
108
+ # default system prompt is set to this in the backend
109
  placeholder=(
110
  "You are a helpful, respectful and honest assistant. Always"
111
  " answer as helpfully as possible, while being safe."
 
116
  # checkbox group to select the xai method
117
  xai_selection = gr.Radio(
118
  ["None", "SHAP", "Attention"],
119
+ label="Interpretability Settings",
120
+ info="Select a Interpretability Implementation to use.",
121
  value="None",
122
  interactive=True,
123
  show_label=True,
124
  )
125
 
126
+ # calling info functions on inputs/submits for different settings
127
  system_prompt.submit(system_prompt_info, [system_prompt])
128
  xai_selection.input(xai_info, [xai_selection])
129
 
130
  # row with chatbot ui displaying "conversation" with the model
131
  with gr.Row(equal_height=True):
132
+ # group to display components closely together
133
  with gr.Group(elem_classes="border: 1px solid black;"):
134
  # accordion to display the normalized input explanation
135
  with gr.Accordion(label="Input Explanation", open=False):
136
  gr.Markdown("""
137
  The explanations are based on 11 buckets that range from the
138
  lowest negative attribution value (-5 to -1), through 0, to the highest positive attribution value (+1 to +5).
139
+ **The legend shows the color for each bucket.**
140
+
141
+ *HINT*: This works best in light mode.
142
  """)
143
  xai_text = gr.HighlightedText(
144
  color_map=coloring,
 
146
  show_legend=True,
147
  show_label=False,
148
  )
149
+ # out of the box chatbot component with avatar images
150
  # see documentation: https://www.gradio.app/docs/chatbot
151
  chatbot = gr.Chatbot(
152
  layout="panel",
153
  show_copy_button=True,
154
  avatar_images=("./public/human.jpg", "./public/bot.jpg"),
155
  )
156
+ # expandable component for extra knowledge
157
  with gr.Accordion(label="Additional Knowledge", open=False):
158
+ gr.Markdown(
159
+ "*Hint:* Add extra knowledge to see GODEL work the best."
160
+ )
161
+ # textbox to enter the knowledge
162
  knowledge_input = gr.Textbox(
163
  value="",
164
  label="Knowledge",
 
167
  show_label=True,
168
  )
169
  # textbox to enter the user prompt
170
+ gr.Markdown(
171
+ "*Hint:* More complicated question give better explanation"
172
+ " insights!"
173
+ )
174
  user_prompt = gr.Textbox(
175
  label="Input Message",
176
  max_lines=5,
177
  info="""
178
  Ask the ChatBot a question.
 
179
  """,
180
  show_label=True,
181
  )
182
  # row with columns for buttons to submit and clear content
183
  with gr.Row(elem_classes=""):
184
+ with gr.Column():
185
  # out of the box clear button which clears the given components
186
+ # see: https://www.gradio.app/docs/clearbutton
187
  clear_btn = gr.ClearButton([user_prompt, chatbot])
188
+ with gr.Column():
189
+ # submit button that calls the backend functions on click
190
  submit_btn = gr.Button("Submit", variant="primary")
191
+ # row with content examples that get autofilled on click
192
  with gr.Row(elem_classes="examples"):
193
+ # examples util component
194
+ # see: https://www.gradio.app/docs/examples
195
  gr.Examples(
196
  label="Example Questions",
197
  examples=[
 
260
  # final row to show legal information
261
  ## - credits, data protection and link to the License
262
  with gr.Tab(label="About"):
263
+ # load about.md markdown
264
  gr.Markdown(value=load_md("public/about.md"))
265
  with gr.Accordion(label="Credits, Data Protection, License"):
266
+ # load credits and data protection markdown
267
  gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
268
 
269
  # mount function for fastAPI Application
270
  app = gr.mount_gradio_app(app, ui, path="/")
271
 
272
+ # entry point that launches the application
273
  if __name__ == "__main__":
274
 
275
  # use standard gradio launch option for hgf spaces
276
  if os.environ["HOSTING"].lower() == "spaces":
277
+ # set password to deny public access
278
  ui.launch(auth=("htw", "berlin@123"))
279
 
280
  # otherwise run the application on port 8080 in reload mode
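The local launch branch is cut off in this diff; based on the "port 8080 in reload mode" comment, it presumably looks something like the following (an assumption, not the committed code):

```python
import uvicorn

# assumed local launch call, matching the "port 8080 in reload mode" comment
uvicorn.run("main:app", host="0.0.0.0", port=8080, reload=True)
```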
model/__init__.py CHANGED
@@ -1,2 +1 @@
1
- # empty init file for the package
2
- # for fastapi to recognize the module
 
1
+ # empty init file for the module
 
model/godel.py CHANGED
@@ -6,21 +6,28 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  # internal imports
7
  from utils import modelling as mdl
8
 
9
- # model and tokenizer instance
10
  TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
11
  MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 
 
12
  CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
13
 
14
 
15
- # TODO: Make config variable
16
- def set_config(config: dict = None):
17
- if config is None:
18
- config = {}
19
 
20
- MODEL.config.max_new_tokens = 50
21
- MODEL.config.min_length = 8
22
- MODEL.config.top_p = 0.9
23
- MODEL.config.do_sample = True
 
 
 
 
 
 
24
 
25
 
26
  # formatting class to formatting input for the model
@@ -56,8 +63,12 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
56
  # CREDIT: Copied from the official inference example on Huggingface
57
  ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
58
  def respond(prompt):
 
59
  input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
 
 
60
  outputs = MODEL.generate(input_ids, **CONFIG)
61
  output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
62
 
 
63
  return output
 
6
  # internal imports
7
  from utils import modelling as mdl
8
 
9
+ # global model and tokenizer instance (created on initial build)
10
  TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
11
  MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
12
+
13
+ # default model config
14
  CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
15
 
16
 
17
+ # function to (re)set the config
18
+ def set_config(config: dict):
19
+ global CONFIG
 
20
 
21
+ # if a non-empty config dict is given, replace the default
22
+ if config != {}:
23
+ CONFIG = config
24
+ else:
25
+ # hard setting model config to default
26
+ # needed for shap
27
+ MODEL.config.max_new_tokens = 50
28
+ MODEL.config.min_length = 8
29
+ MODEL.config.top_p = 0.9
30
+ MODEL.config.do_sample = True
31
 
32
 
33
  # formatting class to formatting input for the model
 
63
  # CREDIT: Copied from the official inference example on Huggingface
64
  ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
65
  def respond(prompt):
66
+ # tokenizing input string
67
  input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
68
+
69
+ # generating using config and decoding output
70
  outputs = MODEL.generate(input_ids, **CONFIG)
71
  output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
72
 
73
+ # returns the model output string
74
  return output
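A short sketch of how the reworked set_config is meant to be used (both branches follow the code above; the prompt string is illustrative):

```python
from model import godel

# empty dict: hard-resets MODEL.config to the defaults (the SHAP path needs this)
godel.set_config({})

# non-empty dict: replaces the global CONFIG that MODEL.generate(**CONFIG) consumes
godel.set_config({"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True})

# respond() tokenizes the prompt and generates with the current CONFIG
answer = godel.respond("Instruction: be helpful. [CONTEXT] Hello! EOS")  # illustrative prompt
```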
utils/__init__.py CHANGED
@@ -0,0 +1 @@
 
 
1
+ # empty init file for the module
utils/formatting.py CHANGED
@@ -7,8 +7,10 @@ from numpy import ndarray
7
 
8
 
9
  # function to format the model response nicely
 
10
  def format_output_text(output: list):
11
- # remove special tokens from list
 
12
  formatted_output = format_tokens(output)
13
 
14
  # start string with first list item if it is not empty
@@ -34,8 +36,10 @@ def format_output_text(output: list):
34
 
35
  # format the tokens by removing special tokens and special characters
36
  def format_tokens(tokens: list):
37
- # define special tokens to remove and initialize empty list
38
  special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
 
 
39
  updated_tokens = []
40
 
41
  # loop through tokens
@@ -44,7 +48,7 @@ def format_tokens(tokens: list):
44
  if t.startswith("▁"):
45
  t = t.lstrip("▁")
46
 
47
- # loop through special tokens and remove them if found
48
  for s in special_tokens:
49
  t = t.replace(s, "")
50
 
@@ -55,15 +59,17 @@ def format_tokens(tokens: list):
55
  return updated_tokens
56
 
57
 
58
- # function to flatten values into a 2d list by averaging the explanation values
59
  def flatten_attribution(values: ndarray, axis: int = 0):
60
  return np.sum(values, axis=axis)
61
 
62
 
 
63
  def flatten_attention(values: ndarray, axis: int = 0):
64
  return np.mean(values, axis=axis)
65
 
66
 
 
67
  def avg_attention(attention_values):
68
  attention = attention_values.decoder_attentions[0][0].detach().numpy()
69
  return np.mean(attention, axis=0)
 
7
 
8
 
9
  # function to format the model response nicely
10
+ # takes a list of strings and returns a combined string
11
  def format_output_text(output: list):
12
+
13
+ # remove special tokens from the list using format_tokens
14
  formatted_output = format_tokens(output)
15
 
16
  # start string with first list item if it is not empty
 
36
 
37
  # format the tokens by removing special tokens and special characters
38
  def format_tokens(tokens: list):
39
+ # define special tokens to remove
40
  special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
41
+
42
+ # initialize empty list
43
  updated_tokens = []
44
 
45
  # loop through tokens
 
48
  if t.startswith("▁"):
49
  t = t.lstrip("▁")
50
 
51
+ # loop through special tokens list and remove from current token if matched
52
  for s in special_tokens:
53
  t = t.replace(s, "")
54
 
 
59
  return updated_tokens
60
 
61
 
62
+ # function to flatten shap values into a 2d list by summing them up
63
  def flatten_attribution(values: ndarray, axis: int = 0):
64
  return np.sum(values, axis=axis)
65
 
66
 
67
+ # function to flatten values into a 2d list by averaging the attention values
68
  def flatten_attention(values: ndarray, axis: int = 0):
69
  return np.mean(values, axis=axis)
70
 
71
 
72
+ # function to get averaged decoder attention from attention values
73
  def avg_attention(attention_values):
74
  attention = attention_values.decoder_attentions[0][0].detach().numpy()
75
  return np.mean(attention, axis=0)
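For a quick feel of the token cleanup, an illustrative call (results paraphrased in the comments, since part of the joining logic is elided in this diff):

```python
from utils import formatting as fmt

# illustrative token list with a sentencepiece marker and a special token
tokens = ["▁Hello", "▁world", "[SEP]"]

cleaned = fmt.format_tokens(tokens)        # markers and special tokens stripped
sentence = fmt.format_output_text(tokens)  # cleaned tokens combined into one string
```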
utils/modelling.py CHANGED
@@ -1,26 +1,28 @@
1
- # module for modelling utilities
2
 
3
  # external imports
4
  import gradio as gr
5
 
6
 
 
 
7
  def prompt_limiter(
8
  tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
9
  ):
10
- # initializing the prompt history empty
11
  prompt_history = []
12
- # getting the token count for the message, system prompt, and knowledge
13
  pre_count = (
14
  token_counter(tokenizer, message)
15
  + token_counter(tokenizer, system_prompt)
16
  + token_counter(tokenizer, knowledge)
17
  )
18
 
19
- # validating the token count
20
- # check if token count already too high
21
  if pre_count > 1024:
22
 
23
- # check if token count too high even without knowledge
24
  if (
25
  token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
26
  > 1024
@@ -32,11 +34,14 @@ def prompt_limiter(
32
  "Message and system prompt are too long. Please shorten them."
33
  )
34
 
35
- # show warning and remove knowledge
36
- gr.Warning("Knowledge is too long. It has been removed to keep model running.")
 
 
 
37
  return message, prompt_history, system_prompt, ""
38
 
39
- # if token count small enough, add history
40
  if pre_count < 800:
41
  # setting the count to the precount
42
  count = pre_count
@@ -46,7 +51,7 @@ def prompt_limiter(
46
  # iterating through the history
47
  for conversation in history:
48
 
49
- # checking the token count with the current conversation
50
  count += token_counter(tokenizer, conversation[0]) + token_counter(
51
  tokenizer, conversation[1]
52
  )
@@ -57,7 +62,7 @@ def prompt_limiter(
57
  else:
58
  break
59
 
60
- # return the message, prompt history, system prompt, and knowledge
61
  return message, prompt_history, system_prompt, knowledge
62
 
63
 
 
1
+ # modelling util module providing helper functions for model handling
2
 
3
  # external imports
4
  import gradio as gr
5
 
6
 
7
+ # function that limits the prompt to constrain model runtime
8
+ # tries to keep as much as possible, always keeping at least the message and system prompt
9
  def prompt_limiter(
10
  tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
11
  ):
12
+ # initializing the new prompt history as empty
13
  prompt_history = []
14
+ # getting the current token count for the message, system prompt, and knowledge
15
  pre_count = (
16
  token_counter(tokenizer, message)
17
  + token_counter(tokenizer, system_prompt)
18
  + token_counter(tokenizer, knowledge)
19
  )
20
 
21
+ # validating the token count against threshold of 1024
22
+ # check if token count already too high without history
23
  if pre_count > 1024:
24
 
25
+ # check if token count too high even without knowledge and history
26
  if (
27
  token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
28
  > 1024
 
34
  "Message and system prompt are too long. Please shorten them."
35
  )
36
 
37
+ # show warning and return with empty history and empty knowledge
38
+ gr.Warning("""
39
+ Input too long.
40
+ Knowledge and conversation history have been removed to keep model running.
41
+ """)
42
  return message, prompt_history, system_prompt, ""
43
 
44
+ # if token count small enough, adding history bit by bit
45
  if pre_count < 800:
46
  # setting the count to the precount
47
  count = pre_count
 
51
  # iterating through the history
52
  for conversation in history:
53
 
54
+ # checking the token count with the current conversation
55
  count += token_counter(tokenizer, conversation[0]) + token_counter(
56
  tokenizer, conversation[1]
57
  )
 
62
  else:
63
  break
64
 
65
+ # return the message, adapted history, system prompt, and knowledge
66
  return message, prompt_history, system_prompt, knowledge
67
 
68
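Finally, a sketch of how prompt_limiter is called (token_counter is the module's own helper, referenced above but not shown in this diff):

```python
from transformers import AutoTokenizer
from utils import modelling as mdl

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

# trims history (and, if needed, knowledge) to stay under the 1024-token threshold
message, history, system_prompt, knowledge = mdl.prompt_limiter(
    tokenizer,
    message="Why is the sky blue?",
    history=[("Hi!", "Hello, how can I help?")],
    system_prompt="You are a helpful assistant.",
    knowledge="",
)
```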