LennardZuendorf committed on
Commit a597c76
1 Parent(s): f301e04

feat/fixing: correcting bug, updating documentation (final?)

README.md CHANGED
@@ -21,7 +21,7 @@ This is the UI showcase for my thesis about the interpretability of LLM based ch
 
 ### 🔗 Links:
 
-**[Github Repository](https://github.com/LennardZuendorf/thesis-webapp)**
+**[GitHub Repository](https://github.com/LennardZuendorf/thesis-webapp)**
 
 **[Huggingface Spaces Showcase](https://huggingface.co/spaces/lennardzuendorf/thesis-webapp-docker)**
 
@@ -86,7 +86,7 @@ See code for in detailed credits, work is strongly based on:
 
 #### SHAP
 - [Github](https://github.com/shap/shap)
-- [Inital Paper](https://arxiv.org/abs/1705.07874)
+- [Initial Paper](https://arxiv.org/abs/1705.07874)
 
 #### Custom Component (/components/iframe/)
 
backend/controller.py CHANGED
@@ -43,6 +43,7 @@ def explained_chat(
     # message, history, system_prompt, knowledge
     # )
     prompt = model.format_prompt(message, history, system_prompt, knowledge)
+    print(f"Formatted prompt: {prompt}")
 
     # generating an answer using the methods chat function
     answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
@@ -73,10 +74,10 @@ def interference(
     # if a model is selected, grab the model instance
     if model_selection.lower() == "mistral":
         model = mistral
-        print("Indentified model as Mistral")
+        print("Identified model as Mistral")
     else:
         model = godel
-        print("Indentified model as GODEL")
+        print("Identified model as GODEL")
 
     # if a XAI approach is selected, grab the XAI module instance
     # and call the explained chat function
components/iframe/README.md CHANGED
@@ -1,7 +1,7 @@
 # gradio iFrame
 
 This is a custom gradio component used to display the shap package text plot. Which is interactive HTML and needs a custom wrapper.
-See custom component examples at offical [docu](https://www.gradio.app/guides/custom-components-in-five-minutes)
+See custom component examples at official [docu](https://www.gradio.app/guides/custom-components-in-five-minutes)
 
 # Credit
 CREDIT: based mostly of Gradio template component, HTML
@@ -14,4 +14,4 @@ see: https://www.gradio.app/docs/html
 - backend/iframe.py - updating component to accept custom height/width and added new example
 - demo/app.py - slightly changed demo file for better dev experience
 - frontend/index.svelte - slightly changed to accept custom height/width
-- frontend/HTML.svelte - updated to use iFrame and added custom function to programmtically set heigth values
+- frontend/HTML.svelte - updated to use iFrame and added custom function to programmatically set height values
explanation/attention.py CHANGED
@@ -11,6 +11,8 @@ from .markup import markup_text
 # and marked text based on attention
 def chat_explained(model, prompt):
 
+    print(f"Running explained chat with prompt {prompt}.")
+
     # get encoded input
     input_ids = model.TOKENIZER(
         prompt, return_tensors="pt", add_special_tokens=True
@@ -56,6 +58,7 @@ def chat_explained(model, prompt):
         " Visualization doesn't support an interactive graphic.</h4></div>"
     )
     # creating marked text using markup_text function and attention
+    print(f"Creating marked text with {input_text}.")
     marked_text = markup_text(input_text, averaged_attention, variant="visualizer")
 
     # returning response, graphic and marked text array
explanation/interpret_captum.py CHANGED
@@ -47,7 +47,7 @@ def chat_explained(model, prompt):
     # getting response text, graphic placeholder and marked text object
     response_text = fmt.format_output_text(attribution_result.output_tokens)
     graphic = """<div style='text-align: center; font-family:arial;'><h4>
-        Intepretation with Captum doesn't support an interactive graphic.</h4></div>
+        Interpretation with Captum doesn't support an interactive graphic.</h4></div>
         """
     # create the explanation marked text array
     marked_text = markup_text(input_tokens, values, variant="captum")
explanation/markup.py CHANGED
@@ -21,7 +21,7 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     elif variant == "visualizer":
         text_values = fmt.flatten_attention(text_values)
 
-    # Determine the minimum and maximum values
+    # determine the minimum and maximum values
     min_val, max_val = np.min(text_values), np.max(text_values)
 
     # separate the threshold calculation for negative and positive values
@@ -69,7 +69,7 @@ def color_codes():
     return {
         # -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
         # 0: white (assuming default light mode)
-        # +1 to +5 light pink to strng magenta
+        # +1 to +5 light pink to strong magenta
         "-5": "#008bfb",
         "-4": "#68a1fd",
         "-3": "#96b7fe",
explanation/plotting.py CHANGED
@@ -7,24 +7,24 @@ import matplotlib.pyplot as plt
 
 def plot_seq(seq_values: list, method: str = ""):
 
-    # Separate the tokens and their corresponding importance values
+    # separate the tokens and their corresponding importance values
     tokens, importance = zip(*seq_values)
 
-    # Convert importance values to numpy array for conditional coloring
+    # convert importance values to numpy array for conditional coloring
     importance = np.array(importance)
 
-    # Determine the colors based on the sign of the importance values
+    # determine the colors based on the sign of the importance values
     colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
 
-    # Create a bar plot
+    # create a bar plot
     plt.figure(figsize=(len(tokens) * 0.9, np.max(importance)))
     x_positions = range(len(tokens))  # Positions for the bars
 
-    # Creating vertical bar plot
+    # creating vertical bar plot
     bar_width = 0.8
     plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
 
-    # Annotating each bar with its value
+    # annotating each bar with its value
     padding = 0.1  # Padding for text annotation
     for x, (y, color) in enumerate(zip(importance, colors)):
         sign = "+" if y > 0 else ""
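The hunk above touches the full bar-plot routine. For reference, a minimal sketch of the input it expects, under the assumption (implied by `zip(*seq_values)`) that `seq_values` is a list of `(token, importance)` pairs:

```python
# hypothetical input: (token, importance) pairs, as implied by zip(*seq_values)
seq_values = [("money", 0.42), ("buys", -0.13), ("happiness", 0.27)]

# the first unpacking step in plot_seq
tokens, importance = zip(*seq_values)
print(tokens)      # ('money', 'buys', 'happiness')
print(importance)  # (0.42, -0.13, 0.27)
```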
main.py CHANGED
@@ -26,7 +26,7 @@ css = """
 .examples {text-align: start;}
 .seperatedRow {border-top: 1rem solid;}",
 """
-# custom js to force lightmode in custom environments
+# custom js to force light mode in custom environments
 if os.environ["HOSTING"].lower() != "spaces":
     js = """
     function () {
@@ -52,6 +52,12 @@ def load_md(path):
 
 # function to display the system prompt info
 def system_prompt_info(sys_prompt_txt):
+    if sys_prompt_txt == "":
+        sys_prompt_txt = """
+        You are a helpful, respectful and honest assistant.
+        Always answer as helpfully as possible, while being safe.
+        """
+
     # display the system prompt using the Gradio Info component
     gr.Info(f"The system prompt was set to:\n {sys_prompt_txt}")
 
@@ -71,7 +77,7 @@ def model_info(model_radio):
 
 
 # ui interface based on Gradio Blocks
-# see https://www.gradio.app/docs/interface)
+# see https://www.gradio.app/docs/interface
 with gr.Blocks(
     css=css,
     js=js,
@@ -171,11 +177,11 @@
             show_copy_button=True,
             avatar_images=("./public/human.jpg", "./public/bot.jpg"),
         )
-        # extenable components for extra knowledge
+        # extendable components for extra knowledge
        with gr.Accordion(label="Additional Knowledge", open=False):
             gr.Markdown("""
                 *Hint:* Add extra knowledge to see GODEL work the best.
-                Knowledge doesn't work mith Mistral and will be ignored.
+                Knowledge doesn't work with Mistral and will be ignored.
                 """)
             # textbox to enter the knowledge
             knowledge_input = gr.Textbox(
@@ -217,8 +223,8 @@
                 "Does money buy happiness?",
                 "",
                 (
-                    "Respond from the perspective of a billionaire enjoying"
-                    " life in Dubai"
+                    "Respond from the perspective of billionaire heir"
+                    " living his best life with his father's money."
                 ),
                 "Mistral",
                 "None",
@@ -227,8 +233,8 @@
                 "Does money buy happiness?",
                 "",
                 (
-                    "Respond from the perspective of a billionaire enjoying"
-                    " life in Dubai"
+                    "Respond from the perspective of billionaire heir"
+                    " living his best life with his father's money."
                 ),
                 "Mistral",
                 "SHAP",
@@ -251,14 +257,36 @@
             [
                 "Does money buy happiness?",
                 (
-                    "Black holes are created when a massive star's core"
-                    " collapses after a supernova, forming an object with"
-                    " gravity so intense that even light cannot escape."
+                    "Some studies have found a correlation between income"
+                    " and happiness, but this relationship often has"
+                    " diminishing returns. From a psychological standpoint,"
+                    " it's not just having money, but how it is used that"
+                    " influences happiness."
                 ),
                 "",
                 "GODEL",
                 "SHAP",
             ],
+            [
+                "Does money buy happiness?",
+                (
+                    "Some studies have found a correlation between income"
+                    " and happiness, but this relationship often has"
+                    " diminishing returns. From a psychological standpoint,"
+                    " it's not just having money, but how it is used that"
+                    " influences happiness."
+                ),
+                "",
+                "GODEL",
+                "Attention",
+            ],
+            [
+                "Does money buy happiness?",
+                "",
+                "",
+                "GODEL",
+                "Attention",
+            ],
         ],
        inputs=[
            user_prompt,
@@ -332,7 +360,7 @@
         # load about.md markdown
         gr.Markdown(value=load_md("public/about.md"))
         with gr.Accordion(label="Credits, Data Protection, License"):
-            # load credits and dataprotection markdown
+            # load credits and data protection markdown
             gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 # mount function for fastAPI Application
model/godel.py CHANGED
@@ -6,7 +6,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
 # internal imports
 from utils import modelling as mdl
 
-# global model and tokenizer instance (created on inital build)
+# global model and tokenizer instance (created on initial build)
 TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 
model/mistral.py CHANGED
@@ -9,7 +9,8 @@ import gradio as gr
 from utils import modelling as mdl
 from utils import formatting as fmt
 
-# global model and tokenizer instance (created on inital build)
+# global model and tokenizer instance (created on initial build)
+# determine if GPU is available and load model accordingly
 device = mdl.get_device()
 if device == torch.device("cuda"):
     n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()
@@ -17,13 +18,15 @@ if device == torch.device("cuda"):
     MODEL = AutoModelForCausalLM.from_pretrained(
         "mistralai/Mistral-7B-Instruct-v0.2",
         quantization_config=bnb_config,
-        device_map="auto",  # dispatch efficiently the model on the available ressources
+        device_map="auto",
         max_memory={i: max_memory for i in range(n_gpus)},
     )
 
+# otherwise, load model on CPU
 else:
     MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
     MODEL.to(device)
+# load tokenizer
 TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
 
 # default model config
@@ -48,12 +51,13 @@ def set_config(config_dict: dict):
     CONFIG.update(**config_dict)
 
 
-# advanced formatting function that takes into a account a conversation history
-# CREDIT: adapated from the Mistral AI Instruct chat template
+# advanced formatting function that takes into account a conversation history
+# CREDIT: adapted from the Mistral AI Instruct chat template
 # see https://github.com/chujiezheng/chat_templates/
 def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
     prompt = ""
 
+    # send information to the ui if knowledge is not empty
     if knowledge != "":
         gr.Info("""
             Mistral doesn't support additional knowledge, it's gonna be ignored.
@@ -94,7 +98,7 @@ def format_answer(answer: str):
 
     # checking if proper history got returned
     if len(segments) > 1:
-        # return text after the last ['/INST'] - reponse to last message
+        # return text after the last ['/INST'] - response to last message
         formatted_answer = segments[-1].strip()
     else:
         # return warning and full answer if not enough [/INST] tokens found
@@ -108,7 +112,11 @@ def format_answer(answer: str):
     return formatted_answer
 
 
+# response class calling the model and returning the model output message
+# CREDIT: Copied from official interference example on Huggingface
+# see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
 def respond(prompt: str):
+    # setting config to default
     set_config({})
 
     # tokenizing inputs and configuring model
@@ -117,6 +125,9 @@ def respond(prompt: str):
     # generating text with tokenized input, returning output
     output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
     output_text = TOKENIZER.batch_decode(output_ids)
+
+    # formatting output text with special function
     output_text = fmt.format_output_text(output_text)
 
+    # returning the model output string
     return format_answer(output_text)
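As a side note on `format_answer`, touched in the hunk above: its comments describe taking the text after the last `[/INST]` marker as the response to the most recent message. A minimal sketch of that idea, using a made-up decoded string (the repository's exact splitting and cleanup may differ):

```python
# hypothetical decoded output; real Mistral output also carries <s>/</s> tokens
raw = "[INST] Hi [/INST] Hello! [INST] Does money buy happiness? [/INST] Not by itself."

segments = raw.split("[/INST]")
if len(segments) > 1:
    # text after the last [/INST] is the answer to the latest message
    formatted_answer = segments[-1].strip()
else:
    # fall back to the full string if no [/INST] markers were found
    formatted_answer = raw.strip()

print(formatted_answer)  # "Not by itself."
```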
utils/formatting.py CHANGED
@@ -100,7 +100,7 @@ def avg_attention(attention_values, model: str):
 
     # removing the last dimension and transposing to get the correct shape
     attention = attention[:, :, :, 0]
-    attention = attention.transpose
+    attention = attention.transpose()
 
     # return the averaged attention values
     return np.mean(attention, axis=1)
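The `utils/formatting.py` change above fixes a subtle bug: without parentheses, `attention.transpose` only references the NumPy method instead of calling it, so nothing gets transposed. A quick illustration:

```python
import numpy as np

attention = np.zeros((2, 3))

# bug: without parentheses this only binds the method object, nothing is transposed
broken = attention.transpose
print(callable(broken))   # True -- a method object, not an array

# fix: calling the method actually returns the transposed array
fixed = attention.transpose()
print(fixed.shape)        # (3, 2)
```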
utils/modelling.py CHANGED
@@ -45,7 +45,7 @@ def prompt_limiter(
 
     # if token count small enough, adding history bit by bit
     if pre_count < 800:
-        # setting the count to the precount
+        # setting the count to the pre-count
         count = pre_count
         # reversing the history to prioritize recent conversations
         history.reverse()
@@ -76,6 +76,7 @@ def token_counter(tokenizer, text: str):
     return len(tokens[0])
 
 
+# function to determine the device to use
 def get_device():
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -85,7 +86,9 @@ def get_device():
     return device
 
 
-# setting device based on available hardware
+# function to set device config
+# CREDIT: Adapted from captum llama 2 example
+# see https://captum.ai/tutorials/Llama2_LLM_Attribution
 def gpu_loading_config(max_memory: str = "15000MB"):
     n_gpus = torch.cuda.device_count()
 
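The new comments credit the Captum Llama 2 attribution tutorial for `gpu_loading_config`. Based only on how `model/mistral.py` consumes its return values (`n_gpus`, `max_memory`, `bnb_config`), a hypothetical sketch of such a helper could use a 4-bit `BitsAndBytesConfig`; the quantization settings below are assumptions, not the repository's actual values:

```python
import torch
from transformers import BitsAndBytesConfig

# hypothetical sketch of a gpu_loading_config-style helper;
# the 4-bit settings below are assumed, not copied from the repository
def gpu_loading_config(max_memory: str = "15000MB"):
    n_gpus = torch.cuda.device_count()
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return n_gpus, max_memory, bnb_config
```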