Commit 43cce2a
Parent(s): 7e6f74e
feat: implementing everything for release version 1.0.0

Changed files:
- .dockerignore +5 -1
- README.md +6 -8
- backend/controller.py +32 -34
- explanation/interpret.py +83 -46
- explanation/visualize.py +98 -9
- main.py +82 -73
- model/godel.py +33 -8
- model/mistral.py +0 -71
- public/credits_dataprotection_license.md +11 -15
- pyproject.toml +2 -0
- railway.json +0 -13
- utils/__init__.py +0 -0
- utils/formatting.py +53 -0
- utils/modelling.py +69 -0
.dockerignore CHANGED
@@ -3,6 +3,10 @@ Compose.yaml
 Dockerfile-Base
 Dockerfile-Light
 entrypoint.sh
-
+.gitignore
+.github
+.git
+.pre-commit-config.yaml
+start-venv.sh
 /components/
 /components/*
README.md CHANGED
@@ -17,28 +17,26 @@ app_port: 8080
 
 ## 🔗 Links:
 
-**[Github Repository](https://github.com/LennardZuendorf/thesis)**
+**[Github Repository](https://github.com/LennardZuendorf/thesis-webapp)**
+**[Huggingface Spaces Showcase](https://huggingface.co/spaces/lennardzuendorf/thesis-webapp-docker)**
 
 ## 🏗️ Tech Stack:
 
-**Language and Framework:** Python
+**Language and Framework:** Python
 
-**Notable Packages:** 🤗 Transformers, Gradio, SHAP, BERTViz
+**Notable Packages:** 🤗 Transformers, FastAPI, Gradio, SHAP, BERTViz
 
 ## 👨‍💻 Author and Credits:
 
-
 **Author:** [@LennardZuendorf](https://github.com/LennardZuendorf)
 
 **Thesis Supervisor**: [Prof. Dr. Simbeck](https://www.htw-berlin.de/hochschule/personen/person/?eid=9862)
 <br> Second Corrector: [Prof. Dr. Hochstein](https://www.htw-berlin.de/hochschule/personen/person/?eid=10628)
 
-
-
-- Mistral AI
+See code for detailed credits; work is based on
+
+- GODEL:
 - SHAP:
 - BERTViz:
 
-
 This Project was part of my studies of Business Computing at University of Applied Science for Technology and Business Berlin (HTW Berlin).
backend/controller.py CHANGED
@@ -5,17 +5,18 @@
 import gradio as gr
 
 # internal imports
-from model import mistral, godel
+from model import godel
 from explanation import interpret, visualize
 
 
 # main interference function that calls chat functions depending on selections
+# TODO: Limit maximum tokens/model input
 def interference(
-    prompt,
-    history,
-    model_selection,
-    system_prompt,
-    xai_selection,
+    prompt: str,
+    history: list,
+    knowledge: str,
+    system_prompt: str,
+    xai_selection: str,
 ):
     # if no system prompt is given, use a default one
     if system_prompt == "":
@@ -24,20 +25,7 @@ def interference(
     Always answer as helpfully as possible, while being safe.
     """
 
-    #
-    match model_selection.lower():
-        case "mistral":
-            model = mistral
-        case "godel":
-            model = godel
-        case _:
-            # use Gradio warning to display error message
-            gr.Warning(
-                f'There was an error in the selected model. It is "{model_selection}"'
-            )
-            raise RuntimeError("There was an error in the selected model.")
-
-    # additionally, if the XAI approach is selected, grab the XAI instance
+    # if a XAI approach is selected, grab the XAI instance
     if xai_selection in ("SHAP", "Visualizer"):
         match xai_selection.lower():
             case "shap":
@@ -46,33 +34,39 @@ def interference(
                 xai = visualize
             case _:
                 # use Gradio warning to display error message
-                gr.Warning(
-                    f"""
+                gr.Warning(f"""
                     There was an error in the selected XAI Approach.
                     It is "{xai_selection}"
-                    """
-                )
+                    """)
                 raise RuntimeError("There was an error in the selected XAI approach.")
 
         # call the explained chat function
        prompt_output, history_output, xai_graphic, xai_plot = explained_chat(
-            model=model,
+            model=godel,
            xai=xai,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
+            knowledge=knowledge,
        )
    # if no (or invalid) XAI approach is selected call the vanilla chat function
    else:
        # call the vanilla chat function
        prompt_output, history_output = vanilla_chat(
-            model=model,
+            model=godel,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
+            knowledge=knowledge,
        )
        # set XAI outputs to disclaimer html/none
-        xai_graphic, xai_plot =
+        xai_graphic, xai_plot = (
+            """
+            <div style="text-align: center"><h4>Without Selected XAI Approach,
+            no graphic will be displayed</h4></div>
+            """,
+            None,
+        )
 
    # return the outputs
    return prompt_output, history_output, xai_graphic, xai_plot
@@ -80,27 +74,31 @@ def interference(
 
 # simple chat function that calls the model
 # formats prompts, calls for an answer and returns updated conversation history
-def vanilla_chat(
+def vanilla_chat(
+    model, message: str, history: list, system_prompt: str, knowledge: str = ""
+):
    # formatting the prompt using the model's format_prompt function
-    prompt = model.format_prompt(message, history, system_prompt)
+    prompt = model.format_prompt(message, history, system_prompt, knowledge)
    # generating an answer using the model's respond function
    answer = model.respond(prompt)
 
    # updating the chat history with the new answer
-    history.append((
+    history.append((message, answer))
 
    # returning the updated history
    return "", history
 
 
-def explained_chat(
+def explained_chat(
+    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
+):
    # formatting the prompt using the model's format_prompt function
-    prompt = model.format_prompt(message, history, system_prompt)
+    prompt = model.format_prompt(message, history, system_prompt, knowledge)
 
    # generating an answer using the xai methods explain and respond function
    answer, xai_graphic, xai_plot = xai.chat_explained(model, prompt)
    # updating the chat history with the new answer
-    history.append((
+    history.append((message, answer))
 
    # returning the updated history, xai graphic and xai plot elements
-    return "",
+    return "", history, xai_graphic, xai_plot
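With the model selection removed, `interference` can be smoke-tested outside the Gradio UI. A minimal sketch, assuming the package layout above is on the import path and the model weights are downloadable; `history` uses the `(user_message, bot_answer)` tuple shape that `gr.Chatbot` expects, and the script itself is hypothetical (not part of this commit):

```python
# hypothetical smoke test for backend.controller.interference
from backend.controller import interference

history = []
_, history, xai_graphic, xai_plot = interference(
    "How does a black hole form in space?",  # prompt
    history,                                 # running (user, bot) tuples
    "",                                      # knowledge: optional [KNOWLEDGE] context
    "",                                      # system_prompt: "" selects the default
    "None",                                  # xai_selection: "None", "SHAP" or "Visualizer"
)

print(history[-1][1])  # latest model answer
print(xai_plot)        # None when no XAI approach is selected
```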
explanation/interpret.py CHANGED
@@ -3,40 +3,60 @@
 import seaborn as sns
 import matplotlib.pyplot as plt
 import numpy as np
-import
+from shap import models, maskers, plots, PartitionExplainer
+import torch
+
+# internal imports
+from utils import formatting as fmt
+
+# global variables
+TEACHER_FORCING = None
+TEXT_MASKER = None
 
 
 # main explain function that returns a chat with explanations
 def chat_explained(model, prompt):
+    model.set_config()
+
     # create the shap explainer
-    shap_explainer =
+    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
     # get the shap values for the prompt
-    shap_values = shap_explainer(prompt)
+    shap_values = shap_explainer([prompt])
 
     # create the explanation graphic and plot
     graphic = create_graphic(shap_values)
     plot = create_plot(shap_values)
 
     # create the response text
-    response_text = format_output_text(shap_values.output_names)
+    response_text = fmt.format_output_text(shap_values.output_names)
     return response_text, graphic, plot
 
 
+def wrap_shap(model):
+    global TEXT_MASKER, TEACHER_FORCING
+
+    # set the device to cuda if gpu is available
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    # updating the model settings again
+    model.set_config()
+
+    # (re)initialize the shap models and masker
+    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
+    TEACHER_FORCING = models.TeacherForcing(
+        text_generation,
+        model.TOKENIZER,
+        device=str(device),
+        similarity_model=model.MODEL,
+        similarity_tokenizer=model.TOKENIZER,
+    )
+    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
 
 
 # graphic plotting function that creates a html graphic (as string) for the explanation
 def create_graphic(shap_values):
     # create the html graphic using shap text plot function
-    graphic_html =
+    graphic_html = plots.text(shap_values, display=False)
 
     # return the html graphic as string
     return str(graphic_html)
@@ -44,42 +64,59 @@ def create_graphic(shap_values):
 
 # plotting function that creates a heatmap style explanation plot
 def create_plot(shap_values):
+    values = shap_values.values[0]
+    output_names = shap_values.output_names
+    input_names = shap_values.data[0]
+
+    # Transpose the values for horizontal input names
+    transposed_values = np.transpose(values)
+
+    # Set seaborn style to dark
+    sns.set(style="dark")
+
+    fig, ax = plt.subplots()
+
+    # Making background transparent
+    ax.set_alpha(0)
+    fig.patch.set_alpha(0)
+
+    # Setting figure size
+    fig.set_size_inches(
+        max(transposed_values.shape[1] * 2, 10),
+        max(transposed_values.shape[0] / 1.5, 5),
     )
 
+    # Plotting the heatmap with Seaborn's color palette
+    im = ax.imshow(
+        transposed_values,
+        vmax=transposed_values.max(),
+        vmin=-transposed_values.min(),
+        cmap=sns.color_palette("vlag_r", as_cmap=True),
+        aspect="auto",
     )
+
+    # Creating colorbar
+    cbar = ax.figure.colorbar(im, ax=ax)
+    cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
+    cbar.ax.yaxis.set_tick_params(color="white")
+    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
+
+    # Setting ticks and labels with white color for visibility
+    ax.set_xticks(np.arange(len(input_names)), labels=input_names)
+    ax.set_yticks(np.arange(len(output_names)), labels=output_names)
+    plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
+    plt.setp(ax.get_yticklabels(), color="white")
+
+    # Adjusting tick labels
+    ax.tick_params(
+        top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
     )
 
+    # Adding text annotations - not used for readability
+    # for i in range(transposed_values.shape[0]):
+    #     for j in range(transposed_values.shape[1]):
+    #         val = transposed_values[i, j]
+    #         color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
+    #         ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
 
     return plt
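For context, the `plots.text(..., display=False)` call used in `create_graphic` returns the SHAP text plot as an HTML string instead of rendering it. A self-contained sketch of the same pattern with a stock classifier pipeline (independent of this repo's custom shap fork; the small sentiment model is just a stand-in to keep the demo light):

```python
import shap
from transformers import pipeline

# any transformers text pipeline works here
classifier = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
)

# shap selects a partition-based explainer for transformers pipelines
explainer = shap.Explainer(classifier)
shap_values = explainer(["What a great movie!"])

# same call as create_graphic: returns the plot as an HTML string
html = shap.plots.text(shap_values, display=False)
print(html[:80])
```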
explanation/visualize.py CHANGED
@@ -2,20 +2,109 @@
 
 # external imports
 from bertviz import head_view
+import matplotlib.pyplot as plt
+import seaborn as sns
+import numpy as np
+
+# internal imports
+from utils import formatting as fmt
 
 
 # plotting function that plots the attention values in a heatmap
 def chat_explained(model, prompt):
-    inputs = model.TOKENIZER(prompt, return_tensors="pt")
-    out = model.MODEL(**inputs, output_attentions=True)
+    model.set_config()
+
+    # get encoded input and output vectors
+    encoder_input_ids = model.TOKENIZER(
+        prompt, return_tensors="pt", add_special_tokens=True
+    ).input_ids
+    decoder_input_ids = model.MODEL.generate(encoder_input_ids, output_attentions=True)
+    encoder_text = fmt.format_tokens(
+        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
+    )
+    decoder_text = fmt.format_tokens(
+        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
+    )
 
+    # get attention values for the input and output vectors
+    attention_output = model.MODEL(
+        input_ids=encoder_input_ids,
+        decoder_input_ids=decoder_input_ids,
+        output_attentions=True,
+    )
 
+    # create the response text, graphic and plot
+    response_text = fmt.format_output_text(decoder_text)
+    graphic = create_graphic(attention_output, (encoder_text, decoder_text))
+    plot = create_plot(attention_output, (encoder_text, decoder_text))
     return response_text, graphic, plot
+
+
+# creating a html graphic using BERTViz
+def create_graphic(attention_output, enc_dec_texts: tuple):
+
+    # calls the head_view function of BERTViz to return html graphic
+    hview = head_view(
+        encoder_attention=attention_output.encoder_attentions,
+        decoder_attention=attention_output.decoder_attentions,
+        cross_attention=attention_output.cross_attentions,
+        encoder_tokens=enc_dec_texts[0],
+        decoder_tokens=enc_dec_texts[1],
+        html_action="return",
+    )
+
+    return str(hview.data)
+
+
+# creating an attention heatmap plot using seaborn
+def create_plot(attention_output, enc_dec_texts: tuple):
+    # get the averaged attention weights
+    attention = attention_output.cross_attentions[0][0].detach().numpy()
+    averaged_attention_weights = np.mean(attention, axis=0)
+
+    # get the encoder and decoder tokens
+    encoder_tokens = enc_dec_texts[0]
+    decoder_tokens = enc_dec_texts[1]
+
+    # set seaborn style to dark and initialize figure and axis
+    sns.set(style="dark")
+    fig, ax = plt.subplots()
+
+    # Making background transparent
+    ax.set_alpha(0)
+    fig.patch.set_alpha(0)
+
+    # Setting figure size
+    fig.set_size_inches(
+        max(averaged_attention_weights.shape[1] * 2, 10),
+        max(averaged_attention_weights.shape[0] / 1.5, 5),
+    )
+
+    # Plotting the heatmap with seaborn's color palette
+    im = ax.imshow(
+        averaged_attention_weights,
+        vmax=averaged_attention_weights.max(),
+        vmin=-averaged_attention_weights.min(),
+        cmap=sns.color_palette("rocket", as_cmap=True),
+        aspect="auto",
+    )
+
+    # Creating colorbar
+    cbar = ax.figure.colorbar(im, ax=ax)
+    cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
+    cbar.ax.yaxis.set_tick_params(color="white")
+    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
+
+    # Setting ticks and labels with white color for visibility
+    ax.set_xticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
+    ax.set_yticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
+    plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
+    plt.setp(ax.get_yticklabels(), color="white")
+
+    # Adjusting tick labels
+    ax.tick_params(
+        top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
+    )
+
+    return plt
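`create_graphic` follows BERTViz's documented encoder-decoder usage: `head_view` with `html_action="return"` hands back an HTML object rather than rendering into a notebook. A standalone sketch of the same attention-gathering pattern, assuming the GODEL checkpoint used above (a second forward pass is needed because `generate` alone does not return the cross attentions in a convenient form):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from bertviz import head_view

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

enc_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
dec_ids = model.generate(enc_ids, max_new_tokens=20)

# re-run a forward pass to collect encoder, decoder and cross attentions
out = model(input_ids=enc_ids, decoder_input_ids=dec_ids, output_attentions=True)

html = head_view(
    encoder_attention=out.encoder_attentions,
    decoder_attention=out.decoder_attentions,
    cross_attention=out.cross_attentions,
    encoder_tokens=tokenizer.convert_ids_to_tokens(enc_ids[0]),
    decoder_tokens=tokenizer.convert_ids_to_tokens(dec_ids[0]),
    html_action="return",
)
print(str(html.data)[:80])
```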
main.py CHANGED
@@ -7,8 +7,9 @@ import gradio as gr
 # internal imports
 from backend.controller import interference
 
-# Global Variables
+# Global Variables and css
 app = FastAPI()
+css = "body {text-align: start !important;}"
 
 
 # different functions to provide frontend abilities
@@ -36,37 +37,33 @@ def xai_info(xai_radio)
     gr.Info("No XAI method was selected.")
 
 
-# function to display the model info
-def model_info(model_radio):
-    # display the model using the Gradio Info component
-    gr.Info(f"The model was set to:\n {model_radio}")
-
-
 # ui interface based on Gradio Blocks (see documentation:
 # https://www.gradio.app/docs/interface)
-with gr.Blocks() as ui:
+with gr.Blocks(
+    css="text-align: start !important",
+    title="Thesis Webapp Showcase",
+    head="<head>",
+) as ui:
     # header row with markdown based text
     with gr.Row():
         # markdown component to display the header
-        gr.Markdown(
-            """
+        gr.Markdown("""
+        # Thesis Demo - AI Chat Application with GODEL
+        ## XAI powered by SHAP and BERTVIZ
         ### Select between tabs below for the different views.
-            """
-        )
+        """)
     # ChatBot tab used to chat with the AI chatbot
     with gr.Tab("AI ChatBot"):
         with gr.Row():
             # markdown component to display the header of the current tab
-            gr.Markdown(
-                """
+            gr.Markdown("""
             ### ChatBot Demo
             Chat with the AI ChatBot using the textbox below.
             Manipulate the settings in the row above,
             including the selection of the model,
             the system prompt and the XAI method.
+
+            """)
         # row with columns for the different settings
         with gr.Row(equal_height=True):
             # column that takes up 3/5 of the row
@@ -80,22 +77,12 @@ with gr.Blocks() as ui:
                     " answer as helpfully as possible, while being safe."
                 ),
             )
-            with gr.Column(scale=1):
-                # checkbox group to select the model
-                model = gr.Radio(
-                    ["Mistral", "GODEL"],
-                    label="Model Selection",
-                    info="Select Model to use for chat.",
-                    value="Mistral",
-                    interactive=True,
-                    show_label=True,
-                )
             with gr.Column(scale=1):
                 # checkbox group to select the xai method
-                xai = gr.Radio(
+                xai_selection = gr.Radio(
                     ["None", "SHAP", "Visualizer"],
                     label="XAI Settings",
+                    info="Select a XAI Implementation to use.",
                     value="None",
                     interactive=True,
                     show_label=True,
@@ -103,11 +90,10 @@ with gr.Blocks() as ui:
 
     # calling info functions on inputs for different settings
     system_prompt.submit(system_prompt_info, [system_prompt])
-    model.input(model_info, [model])
-    xai.input(xai_info, [xai])
+    xai_selection.input(xai_info, [xai_selection])
 
     # row with chatbot ui displaying "conversation" with the model
-    with gr.Row():
+    with gr.Row(equal_height=True):
         # out of the box chatbot component
         # see documentation: https://www.gradio.app/docs/chatbot
         chatbot = gr.Chatbot(
@@ -115,10 +101,28 @@ with gr.Blocks() as ui:
         show_copy_button=True,
         avatar_images=("./public/human.jpg", "./public/bot.jpg"),
     )
+    # rows with input textboxes
+    with gr.Row():
+        # textbox to enter the knowledge
+        with gr.Accordion(label="Additional Knowledge", open=False):
+            knowledge_input = gr.Textbox(
+                value="",
+                label="Knowledge",
+                max_lines=5,
+                info="Add additional context knowledge.",
+                show_label=True,
+            )
     with gr.Row():
         # textbox to enter the user prompt
-        user_prompt = gr.Textbox(
+        user_prompt = gr.Textbox(
+            label="Input Message",
+            max_lines=5,
+            info="""
+            Ask the ChatBot a question.
+            Hint: More complicated questions give better explanation insights!
+            """,
+            show_label=True,
+        )
     # row with columns for buttons to submit and clear content
     with gr.Row():
         with gr.Column(scale=1):
@@ -127,79 +131,84 @@ with gr.Blocks() as ui:
             clear_btn = gr.ClearButton([user_prompt, chatbot])
         with gr.Column(scale=1):
             submit_btn = gr.Button("Submit", variant="primary")
+    with gr.Row():
+        gr.Examples(
+            label="Example Questions",
+            examples=[
+                [
+                    "How does a black hole form in space?",
+                    (
+                        "Black holes are created when a massive star's core"
+                        " collapses after a supernova, forming an object with"
+                        " gravity so intense that even light cannot escape."
+                    ),
+                ],
+                [
+                    (
+                        "Explain the importance of the Rosetta Stone in"
+                        " understanding ancient languages."
+                    ),
+                    (
+                        "The Rosetta Stone, an ancient Egyptian artifact, was key"
+                        " in decoding hieroglyphs, featuring the same text in three"
+                        " scripts: hieroglyphs, Demotic, and Greek."
+                    ),
+                ],
+            ],
+            inputs=[user_prompt, knowledge_input],
+        )
 
     # explanations tab used to provide explanations for a specific conversation
     with gr.Tab("Explanations"):
         # row with markdown component to display the header of the current tab
         with gr.Row():
-            gr.Markdown(
-                """
+            gr.Markdown("""
             ### Get Explanations for Conversations
             Using your selected XAI method, you can get explanations for
             the conversation you had with the AI ChatBot. The explanations are
             based on the last message you sent to the AI ChatBot (see text)
-                """
-            )
-        # row that displays the settings used to create the current model output
-        ## each textbox statically displays the current values
-        with gr.Row():
-            with gr.Column():
-                gr.Textbox(
-                    value=xai,
-                    label="Used XAI Variant",
-                    show_label=True,
-                    interactive=True,
-                )
-            with gr.Column():
-                gr.Textbox(
-                    value=model, label="Used Model", show_label=True, interactive=True
-                )
-            with gr.Column():
-                gr.Textbox(
-                    value=system_prompt,
-                    label="Used System Prompt",
-                    show_label=True,
-                    interactive=True,
-                )
+            """)
         # row that displays the generated explanation of the model (if applicable)
-        with gr.Row():
-            # wraps the explanation html in an iframe to display it
+        with gr.Row(variant="panel"):
+            # wraps the explanation html in an iframe to display it interactively
             xai_interactive = gr.HTML(
                 label="Interactive Explanation",
+                value=(
+                    '<div style="text-align: center"><h4>No Graphic to Display'
+                    " (Yet)</h4></div>"
+                ),
                 show_label=True,
-                value="<div><h1>No Graphic to Display</h1></div>",
             )
         # row and accordion to display an explanation plot (if applicable)
         with gr.Row():
             with gr.Accordion("Token Explanation Plot", open=False):
+                gr.Markdown("""
+                #### Plotted Values
+                Values have been excluded for readability. See colorbar for value indication.
+                """)
                 # plot component that takes a matplotlib figure as input
-                xai_plot = gr.Plot(
-                    label="Token Level Explanation",
-                    show_label=True,
-                    every=5,
-                )
+                xai_plot = gr.Plot(label="Token Level Explanation", scale=3)
 
     # functions to trigger the controller
-    ## takes information for the chat and the model and xai selection
+    ## takes information for the chat and the xai selection
     ## returns prompt, history and xai data
     ## see backend/controller.py for more information
     submit_btn.click(
         interference,
-        [user_prompt, chatbot, model, system_prompt, xai],
+        [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
         [user_prompt, chatbot, xai_interactive, xai_plot],
     )
     # function triggered by the enter key
     user_prompt.submit(
         interference,
-        [user_prompt, chatbot, model, system_prompt, xai],
+        [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
         [user_prompt, chatbot, xai_interactive, xai_plot],
     )
 
     # final row to show legal information
     ## - credits, data protection and link to the License
-    with gr.Row():
-        gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
+    with gr.Tab(label="Credits, Data Protection and License"):
+        gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 # mount function for fastAPI Application
 app = gr.mount_gradio_app(app, ui, path="/")
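Since the Gradio UI is mounted into a FastAPI app rather than launched directly, the webapp is served through an ASGI server. A sketch of a local run (the port matches the `app_port: 8080` front matter in the README; the helper script itself is hypothetical, not part of this commit):

```python
# run_local.py - hypothetical local-run helper
import uvicorn

if __name__ == "__main__":
    # serve the FastAPI app defined in main.py, with the Gradio UI mounted at "/"
    uvicorn.run("main:app", host="0.0.0.0", port=8080)
```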
model/godel.py CHANGED
@@ -1,30 +1,55 @@
 # GODEL model module for chat interaction and model instance control
+
+# external imports
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
+# internal imports
+from utils import modelling as mdl
+
 # model and tokenizer instance
 TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
-
+CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
+
+
+# TODO: Make config variable
+def set_config(config: dict = None):
+    if config is None:
+        config = {}
+
+    MODEL.config.max_new_tokens = 50
+    MODEL.config.min_length = 8
+    MODEL.config.top_p = 0.9
+    MODEL.config.do_sample = True
 
 
 # formatting function to format input for the model
 # CREDIT: Adapted from official inference example on Huggingface
 ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
 def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
+    # user input prompt initialization
     prompt = ""
 
+    # limits the prompt elements to the maximum token count
+    message, history, system_prompt, knowledge = mdl.prompt_limiter(
+        TOKENIZER, message, history, system_prompt, knowledge
+    )
+
     # adds knowledge text if not empty
     if knowledge != "":
         knowledge = "[KNOWLEDGE] " + knowledge
 
+    # adds conversation history to the prompt
+    for conversation in history:
+        prompt += f"EOS {conversation[0]} EOS {conversation[1]}"
 
+    # adds the message to the prompt
+    prompt += f" {message}"
+    # combines the entire prompt
+    full_prompt = f"{system_prompt} [CONTEXT] {prompt} {knowledge}"
 
+    # returns the formatted prompt
+    return full_prompt
 
 
 # response function calling the model and returning the model output message
@@ -32,7 +57,7 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
 ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
 def respond(prompt):
     input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
+    outputs = MODEL.generate(input_ids, **CONFIG)
     output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
 
     return output
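`format_prompt` follows the prompt layout from the official GODEL example: instruction, then `[CONTEXT]` with dialog turns separated by `EOS`, then an optional `[KNOWLEDGE]` block. A usage sketch; the exact output assumes `prompt_limiter` leaves these short inputs untouched, and `respond` requires the model weights to be available locally:

```python
from model import godel

history = [("Hi!", "Hello, how can I help you today?")]
prompt = godel.format_prompt(
    message="What is a black hole?",
    history=history,
    system_prompt="Instruction: given a dialog context, respond helpfully.",
    knowledge="Black holes form when massive stars collapse.",
)

# -> "Instruction: ... [CONTEXT] EOS Hi! EOS Hello, how can I help you today?
#     What is a black hole? [KNOWLEDGE] Black holes form when massive stars collapse."
print(prompt)
print(godel.respond(prompt))
```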
model/mistral.py DELETED
@@ -1,71 +0,0 @@
-# Mistral 7B model module for chat interaction and model instance control
-
-# external imports
-from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
-import torch
-import gradio as gr
-
-# global variables for model and tokenizer, config
-MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
-TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-MISTRAL_CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
-
-MISTRAL_CONFIG.update(
-    **{
-        "temperature": 0.7,
-        "max_new_tokens": 50,
-        "top_p": 0.9,
-        "repetition_penalty": 1.2,
-        "do_sample": True,
-        "seed": 42,
-    }
-)
-
-
-# function to format the prompt to include chat history, message
-# CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
-## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
-
-
-def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
-    prompt = ""
-    if knowledge != "":
-        gr.Warning(
-            """Mistral does not support
-            additionally knowledge!"""
-        )
-
-    # if no history, use system prompt and example message
-    if len(history) == 0:
-        prompt = f"""<s>[INST] {system_prompt} [/INST] How can I help you today? </s>
-        [INST] {message} [/INST]"""
-    else:
-        # takes the very first exchange and the system prompt as base
-        for user_prompt, bot_response in history[0]:
-            prompt = (
-                f"<s>[INST] {system_prompt} {user_prompt} [/INST] {bot_response}</s>"
-            )
-
-        # takes all the following conversations and adds them as context
-        prompt += "".join(
-            f"[INST] {user_prompt} [/INST] {bot_response}</s>"
-            for user_prompt, bot_response in history[1:]
-        )
-    return prompt
-
-
-# generation function returning the model response based on the input
-# CREDIT: adapted from official Mistral AI 7B Instruct documentation on Huggingface
-## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
-def respond(prompt):
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-    # tokenizing inputs and configuring model
-    input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")
-    model_input = input_ids.to(device)
-    MODEL.to(device)
-
-    # generating text with tokenized input, returning output
-    output_ids = MODEL.generate(model_input, generation_config=MISTRAL_CONFIG)
-    output_text = TOKENIZER.batch_decode(output_ids)
-    return output_text[0]
public/credits_dataprotection_license.md CHANGED
@@ -9,19 +9,14 @@
 For full credits, please refer to the [thesis print]()
 
 ### Models
+This implementation is built on GODEL by Microsoft, Inc.
 
+##### GODEL
+GODEL is an open source model by Microsoft. See [official paper](https://arxiv.org/abs/2206.11309) for more information.
 
+- the version used in this project is GODEL Large, see [huggingface model hub](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq?text=Hey+my+name+is+Thomas%21+How+are+you%3F)
+- the model is a generative seq2seq transformer fine-tuned for goal-directed dialog
+- it supports context and knowledge base inputs
 
-##### Mistral
-Mistral is an open source model by Mistral AI. See [official paper](https://arxiv.org/pdf/2310.06825.pdf) for more information.
-
-- the version used in this project is Mistral Instruct, see [huggingface model hub](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
-- the model is fine-tuned for instruction following by Mistral AI
 
 ### Libraries
 This project uses a number of open source libraries, only the most important ones are listed below.
@@ -29,7 +24,7 @@ This project uses a number of open source libraries, only the most important ones are listed below.
 ##### Shap
 This application uses a custom version of the shap library, which is available at [GitHub](https://github.com/shap/shap).
 
-- please refer to the [shap
+- please refer to the [thesis-custom-shap](https://github.com/LennardZuendorf/thesis-custom-shap) repository for more information about the changes made to the library, specifically the README and CHANGES files
 - the shap library and the used partition SHAP explainer are based on work by Lundberg et al. (2017), see [official paper](https://arxiv.org/pdf/1705.07874.pdf) for more information
 
 ##### BertViz
@@ -40,10 +35,11 @@ This application uses a slightly customized version of the bertviz library, which is available at [GitHub](https://github.com/jessevig/bertviz).
 
 
 # Data Protection
-This is a non-commercial project, which does not collect any personal data. The only data collected is the data you enter into the application. This data is only used to generate the explanations and is not stored anywhere.
+This is a non-commercial research project, which does not collect any personal data. The only data collected is the data you enter into the application. This data is only used to generate the explanations and is not stored anywhere.
 
-If you use the "flag" feature, the data you enter will be stored in *publicly available* csv file.
+> However, the application may be hosted with an external service (e.g. Huggingface Spaces), which may collect data.
 
+Please refer to the data protection policies of the respective service for more information. If you use the "flag" feature, the data you enter will be stored in a *publicly available* csv file.
 
 # License
pyproject.toml CHANGED
@@ -1,6 +1,8 @@
+# configuration for formatting & linting tools
 [tool.black]
 line-length = 88
 include = '\.pyi?$'
+preview = true
 exclude = '''
 /(
   \.eggs
railway.json DELETED
@@ -1,13 +0,0 @@
-{
-    "$schema": "https://railway.app/railway.schema.json",
-    "build": {
-        "builder": "DOCKERFILE",
-        "dockerfilePath": "Dockerfile"
-    },
-    "deploy": {
-        "numReplicas": 1,
-        "sleepApplication": false,
-        "restartPolicyType": "ON_FAILURE",
-        "restartPolicyMaxRetries": 10
-    }
-}
utils/__init__.py ADDED
File without changes
utils/formatting.py ADDED
@@ -0,0 +1,53 @@
+# formatting util module providing formatting functions for the model input and output
+
+# external imports
+import re
+
+
+# function to format the model response nicely
+def format_output_text(output: list):
+    # remove special tokens from list
+    formatted_output = format_tokens(output)
+
+    # start string with first list item if it is not empty
+    if formatted_output[0] != "":
+        output_str = formatted_output[0]
+    else:
+        # alternatively start with second list item
+        output_str = formatted_output[1]
+
+    # add all other list items with a space in between
+    for txt in formatted_output[1:]:
+        # check if the token is a punctuation mark
+        if txt in [".", ",", "!", "?"]:
+            # add punctuation mark without space
+            output_str += txt
+        # add token with space if not empty
+        elif txt != "":
+            output_str += " " + txt
+
+    # return the combined string with multiple spaces removed
+    return re.sub(" +", " ", output_str)
+
+
+# format the tokens by removing special tokens and special characters
+def format_tokens(tokens: list):
+    # define special tokens to remove and initialize empty list
+    special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
+    updated_tokens = []
+
+    # loop through tokens
+    for t in tokens:
+        # remove special token from start of token if found
+        if t.startswith("▁"):
+            t = t.lstrip("▁")
+
+        # loop through special tokens and remove them if found
+        for s in special_tokens:
+            t = t.replace(s, "")
+
+        # add token to list
+        updated_tokens.append(t)
+
+    # return the list of tokens
+    return updated_tokens
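These helpers strip tokenizer artifacts (the `▁`/`Ġ` subword prefixes and `[SEP]`-style special tokens) before joining tokens into display text. A quick sketch of the expected behavior:

```python
from utils import formatting as fmt

# raw output as a tokenizer's convert_ids_to_tokens would produce it
raw_tokens = ["▁Hello", ",", "▁world", "!", "[SEP]"]

print(fmt.format_tokens(raw_tokens))       # ['Hello', ',', 'world', '!', '']
print(fmt.format_output_text(raw_tokens))  # Hello, world!
```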
utils/modelling.py ADDED
@@ -0,0 +1,69 @@
+# module for modelling utilities
+
+# external imports
+import gradio as gr
+
+
+def prompt_limiter(
+    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
+):
+    # initializing the prompt history empty
+    prompt_history = []
+    # getting the token count for the message, system prompt, and knowledge
+    pre_count = (
+        token_counter(tokenizer, message)
+        + token_counter(tokenizer, system_prompt)
+        + token_counter(tokenizer, knowledge)
+    )
+
+    # validating the token count
+    # check if token count already too high
+    if pre_count > 1024:
+
+        # check if token count too high even without knowledge
+        if (
+            token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
+            > 1024
+        ):
+
+            # show warning and raise error
+            gr.Warning("Message and system prompt are too long. Please shorten them.")
+            raise RuntimeError(
+                "Message and system prompt are too long. Please shorten them."
+            )
+
+        # show warning and remove knowledge
+        gr.Warning("Knowledge is too long. It has been removed to keep model running.")
+        return message, prompt_history, system_prompt, ""
+
+    # if token count small enough, add history
+    if pre_count < 800:
+        # setting the count to the precount
+        count = pre_count
+        # reversing the history to prioritize recent conversations
+        history.reverse()
+
+        # iterating through the history
+        for conversation in history:
+
+            # checking the token count with the current conversation
+            count += token_counter(tokenizer, conversation[0]) + token_counter(
+                tokenizer, conversation[1]
+            )
+
+            # add conversation or break loop depending on token count
+            if count < 1024:
+                prompt_history.append(conversation)
+            else:
+                break
+
+    # return the message, prompt history, system prompt, and knowledge
+    return message, prompt_history, system_prompt, knowledge
+
+
+# token counter function using the model tokenizer
+def token_counter(tokenizer, text: str):
+    # tokenize the text
+    tokens = tokenizer(text, return_tensors="pt").input_ids
+    # return the token count
+    return len(tokens[0])
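`prompt_limiter` enforces a roughly 1024-token budget: knowledge is dropped first, then older conversation turns (the history is reversed in place, so the returned list is most-recent-first). A usage sketch, assuming the GODEL tokenizer from this commit:

```python
from transformers import AutoTokenizer
from utils import modelling as mdl

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

history = [("Hi!", "Hello, how can I help?")] * 3
message, kept_history, system_prompt, knowledge = mdl.prompt_limiter(
    tokenizer,
    "What is a black hole?",
    history,
    "Instruction: be helpful.",
    knowledge="",
)

# kept_history holds the most recent turns that still fit the token budget
print(len(kept_history), mdl.token_counter(tokenizer, message))
```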