LennardZuendorf committed on
Commit
229e14c
1 Parent(s): 11174d4

feat/fix: fixing code issues, adding plotting functions

.gitignore CHANGED
@@ -2,3 +2,4 @@
 __pycache__/
 /start-venv.sh
 /components/iframe/dist/
+.venv
backend/controller.py CHANGED
@@ -10,7 +10,7 @@ from model import mistral
 from explanation import (
     interpret_shap as shap_int,
     interpret_captum as cpt_int,
-    visualize as viz,
+    visualize_att as viz,
 )
 
 
@@ -33,10 +33,10 @@ def interference(
 
     if model_selection.lower() == "mistral":
         model = mistral
-        print("Indetified model as Mistral")
+        print("Identified model as Mistral")
     else:
         model = godel
-        print("Indetified model as GODEL")
+        print("Identified model as GODEL")
 
     # if an XAI approach is selected, grab the XAI module instance
     if xai_selection in ("SHAP", "Attention"):
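
Note: the aliased modules share one entry point; both deleted explanation modules below expose chat_explained(model, prompt). A minimal dispatch sketch (the controller's actual dispatch code is not part of this diff, so the lines below are assumptions):

    # sketch only: pick the aliased explanation module and call its
    # shared chat_explained(model, prompt) entry point
    if xai_selection == "SHAP":
        xai = shap_int
    else:  # "Attention"
        xai = viz
    response_text, graphic, marked_text = xai.chat_explained(model, prompt)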
explanation/interpret_captum.py CHANGED
@@ -1,40 +0,0 @@
-# external imports
-from captum.attr import LLMAttribution, TextTokenInput, KernelShap
-import torch
-
-# internal imports
-from utils import formatting as fmt
-from .markup import markup_text
-
-
-# main explain function that returns a chat with explanations
-def chat_explained(model, prompt):
-    model.set_config({})
-
-    # creating llm attribution class with KernelSHAP and Mistal Model, Tokenizer
-    llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
-
-    # generation attribution
-    attribution_input = TextTokenInput(prompt, model.TOKENIZER)
-    attribution_result = llm_attribution.attribute(
-        attribution_input, gen_args=model.CONFIG.to_dict()
-    )
-
-    # extracting values and input tokens
-    values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
-    input_tokens = fmt.format_tokens(attribution_result.input_tokens)
-
-    # raising error if mismatch occurs
-    if len(attribution_result.input_tokens) != len(values):
-        raise RuntimeError("values and input len mismatch")
-
-    # getting response text, graphic placeholder and marked text object
-    response_text = fmt.format_output_text(attribution_result.output_tokens)
-    graphic = (
-        "<div style='text-align: center; font-family:arial;'><h4>Attention"
-        "Intepretation with Captum doesn't support an interactive graphic.</h4></div>"
-    )
-    marked_text = markup_text(input_tokens, values, variant="captum")
-
-    # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
explanation/interpret_shap.py CHANGED
@@ -1,72 +0,0 @@
-# interpret module that implements the interpretability method
-
-# external imports
-from shap import models, maskers, plots, PartitionExplainer
-import torch
-
-# internal imports
-from utils import formatting as fmt
-from .markup import markup_text
-
-# global variables
-TEACHER_FORCING = None
-TEXT_MASKER = None
-
-
-# main explain function that returns a chat with explanations
-def chat_explained(model, prompt):
-    model.set_config({})
-
-    # create the shap explainer
-    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
-
-    # get the shap values for the prompt
-    shap_values = shap_explainer([prompt])
-
-    # create the explanation graphic and marked text array
-    graphic = create_graphic(shap_values)
-    marked_text = markup_text(
-        shap_values.data[0], shap_values.values[0], variant="shap"
-    )
-
-    # create the response text
-    response_text = fmt.format_output_text(shap_values.output_names)
-
-    # return response, graphic and marked_text array
-    return response_text, graphic, marked_text
-
-
-# function used to wrap the model with a shap model
-def wrap_shap(model):
-    # calling global variants
-    global TEXT_MASKER, TEACHER_FORCING
-
-    # set the device to cuda if gpu is available
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # updating the model settings
-    model.set_config()
-
-    # (re)initialize the shap models and masker
-    # creating a shap text_generation model
-    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
-    # wrapping the text generation model in a teacher forcing model
-    TEACHER_FORCING = models.TeacherForcing(
-        text_generation,
-        model.TOKENIZER,
-        device=str(device),
-        similarity_model=model.MODEL,
-        similarity_tokenizer=model.TOKENIZER,
-    )
-    # setting the text masker as an empty string
-    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
-
-
-# graphic plotting function that creates a html graphic (as string) for the explanation
-def create_graphic(shap_values):
-
-    # create the html graphic using shap text plot function
-    graphic_html = plots.text(shap_values, display=False)
-
-    # return the html graphic as string to display in iFrame
-    return str(graphic_html)
explanation/markup.py CHANGED
@@ -66,16 +66,16 @@ def color_codes():
     return {
         # -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
        # 0: white (assuming default light mode)
-        # +1 to +5 light pink to string magenta
-        "-5": "#3251a8",  # Strong Light Sky Blue
-        "-4": "#5A7FB2",  # Slightly Lighter Sky Blue
-        "-3": "#8198BC",  # Intermediate Sky Blue
-        "-2": "#A8B1C6",  # Light Sky Blue
-        "-1": "#E6F0FF",  # Very Light Sky Blue
-        "0": "#FFFFFF",  # White
-        "+1": "#FFE6F0",  # Lighter Pink
-        "+2": "#DF8CA3",  # Slightly Stronger Pink
-        "+3": "#D7708E",  # Intermediate Pink
-        "+4": "#CF5480",  # Deep Pink
-        "+5": "#A83273",  # Strong Magenta
+        # +1 to +5: light pink to strong magenta
+        "-5": "#008bfb",
+        "-4": "#68a1fd",
+        "-3": "#96b7fe",
+        "-2": "#bcceff",
+        "-1": "#dee6ff",
+        "0": "#ffffff",
+        "1": "#ffd9d9",
+        "2": "#ffb3b5",
+        "3": "#ff8b92",
+        "4": "#ff5c71",
+        "5": "#ff0051",
     }
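
For reference, a minimal sketch of how the new -5..5 keys could be consumed, assuming attribution values normalized to [-1, 1]; value_to_color is a hypothetical helper, not part of markup.py:

    # hypothetical helper (not in this commit): map a normalized
    # attribution value in [-1, 1] to one of the eleven color buckets
    def value_to_color(value: float, colors: dict) -> str:
        bucket = max(-5, min(5, round(value * 5)))  # clamp to -5..5
        return colors[str(bucket)]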
explanation/plotting.py ADDED
File without changes
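
plotting.py is added empty, so the plotting functions named in the commit message are still to come. A speculative sketch of what one might look like; the name plot_seq_attribution and its signature are assumptions, not committed code:

    # speculative sketch: bar chart of per-token attribution values
    import matplotlib.pyplot as plt

    def plot_seq_attribution(tokens: list, values: list):
        fig, ax = plt.subplots(figsize=(10, 3))
        ax.bar(range(len(tokens)), values)  # one bar per input token
        ax.set_xticks(range(len(tokens)))
        ax.set_xticklabels(tokens, rotation=45, ha="right")
        ax.set_ylabel("attribution")
        fig.tight_layout()
        return fig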
explanation/visualize.py DELETED
@@ -1,52 +0,0 @@
-# visualization module that creates an attention visualization
-
-
-# internal imports
-from utils import formatting as fmt
-from .markup import markup_text
-
-
-# chat function that returns an answer
-# and marked text based on attention
-def chat_explained(model, prompt):
-
-    # get encoded input
-    encoder_input_ids = model.TOKENIZER(
-        prompt, return_tensors="pt", add_special_tokens=True
-    ).input_ids
-    # generate output together with attentions of the model
-    decoder_input_ids = model.MODEL.generate(
-        encoder_input_ids, output_attentions=True, **model.CONFIG
-    )
-
-    # get input and output text as list of strings
-    encoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
-    )
-    decoder_text = fmt.format_tokens(
-        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
-    )
-
-    # get attention values for the input and output vectors
-    # using already generated input and output
-    attention_output = model.MODEL(
-        input_ids=encoder_input_ids,
-        decoder_input_ids=decoder_input_ids,
-        output_attentions=True,
-    )
-
-    # averaging attention across layers
-    averaged_attention = fmt.avg_attention(attention_output)
-
-    # format response text for clean output
-    response_text = fmt.format_output_text(decoder_text)
-    # setting placeholder for iFrame graphic
-    graphic = (
-        "<div style='text-align: center; font-family:arial;'><h4>Attention"
-        " Visualization doesn't support an interactive graphic.</h4></div>"
-    )
-    # creating marked text using markup_text function and attention
-    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
-
-    # returning response, graphic and marked text array
-    return response_text, graphic, marked_text
explanation/visualize_att.py ADDED
File without changes
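
visualize_att.py is likewise added empty, yet controller.py above now imports it as viz, so it will presumably expose the same chat_explained(model, prompt) interface as the deleted visualize.py. A placeholder stub (an assumption, not committed code):

    # assumed stub mirroring the interface of the deleted visualize.py
    def chat_explained(model, prompt):
        raise NotImplementedError("attention visualization not yet ported")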
model/mistral.py CHANGED
@@ -41,13 +41,11 @@ CONFIG.update(**{
 
 
 # function to (re) set config
-def set_config(config: dict):
+def set_config(config_dict: dict):
 
-    # if config dict is given, update it
-    if config != {}:
-        CONFIG.update(**dict)
-    else:
-        CONFIG.update(**{
+    # if config dict is not given, set to default
+    if config_dict == {}:
+        config_dict = {
             "temperature": 0.7,
             "max_new_tokens": 50,
             "max_length": 50,
@@ -55,7 +53,9 @@ def set_config(config: dict):
             "repetition_penalty": 1.2,
             "do_sample": True,
             "seed": 42,
-        })
+        }
+
+    CONFIG.update(**config_dict)
 
 
 # advanced formatting function that takes into account a conversation history
@@ -77,9 +77,9 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
         """
     else:
         # takes the very first exchange and the system prompt as base
-        prompt = (
-            f"<s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>"
-        )
+        prompt = f"""
+        <s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
+        """
 
     # adds conversation history to the prompt
     for conversation in history[1:]:
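
A short usage sketch of the reworked set_config, assuming the corrected CONFIG.update(**config_dict) call above:

    # usage sketch: an empty dict restores the defaults,
    # a non-empty dict is applied to CONFIG as-is
    set_config({})                    # reset to default generation settings
    set_config({"temperature": 0.2})  # override a single setting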
requirements.txt CHANGED
@@ -2,7 +2,7 @@ gradio~=4.7.1
 transformers~=4.35.2
 torch~=2.1.1
 shap
-captum
+captum @ git+https://github.com/LennardZuendorf/thesis-captum.git
 bertviz~=1.4.0
 accelerate~=0.24.1
 bitsandbytes
@@ -13,9 +13,7 @@ uvicorn~=0.24.0
 tinydb~=4.8.0
 black~=23.12.0
 pylint~=3.0.0
-seaborn~=0.13.0
 numpy
 matplotlib
 pre-commit
-ipython
 gradio-iframe~=0.0.10