LennardZuendorf committed
Commit • 1f063be • 1 Parent(s): a597c76

fix: another set of attention fixes

Files changed:
- backend/controller.py  +4 -5
- explanation/attention.py  +4 -2
- explanation/markup.py  +19 -11
- main.py  +18 -8
- model/godel.py  +1 -1
- utils/modelling.py  +1 -4
backend/controller.py
CHANGED
@@ -43,7 +43,6 @@ def explained_chat(
     # message, history, system_prompt, knowledge
     # )
     prompt = model.format_prompt(message, history, system_prompt, knowledge)
-    print(f"Formatted prompt: {prompt}")
 
     # generating an answer using the methods chat function
     answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
@@ -66,10 +65,10 @@ def interference(
 ):
     # if no proper system prompt is given, use a default one
     if system_prompt in ("", " "):
-        system_prompt = """
-            You are a helpful, respectful and honest assistant.
-            Always answer as helpfully as possible, while being safe.
-            """
+        system_prompt = (
+            "You are a helpful, respectful and honest assistant."
+            "Always answer as helpfully as possible, while being safe."
+        )
 
     # if a model is selected, grab the model instance
     if model_selection.lower() == "mistral":
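A note on the new default prompt: inside the parentheses, Python merges adjacent string literals at parse time with no separator, so the two sentences above join without a space. A minimal standalone sketch of that behavior:

# Adjacent string literals are concatenated with no implicit space.
system_prompt = (
    "You are a helpful, respectful and honest assistant."
    "Always answer as helpfully as possible, while being safe."
)
print(system_prompt)
# You are a helpful, respectful and honest assistant.Always answer as helpfully as possible, while being safe.

A trailing space inside the first literal (or joining the sentences explicitly) would keep them apart.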
explanation/attention.py
CHANGED
@@ -37,6 +37,9 @@ def chat_explained(model, prompt):
         attention_output = mdl.format_mistral_attention(attention_output)
         averaged_attention = fmt.avg_attention(attention_output, model="mistral")
 
+        response_text = fmt.format_output_text(output_text)
+        response_text = mistral.format_answer(response_text)
+
     # otherwise use attention visualization for godel
     else:
         # get attention values for the input and output vectors
@@ -49,9 +52,8 @@ def chat_explained(model, prompt):
 
         # averaging attention across layers
         averaged_attention = fmt.avg_attention(attention_output, model="godel")
+        response_text = fmt.format_output_text(output_text)
 
-    # format response text for clean output
-    response_text = fmt.format_output_text(output_text)
     # setting placeholder for iFrame graphic
     graphic = (
         "<div style='text-align: center; font-family:arial;'><h4>Attention"
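With this change each branch formats its own response text: the Mistral branch additionally runs mistral.format_answer, instead of both branches sharing the single formatting step the second hunk removes from below the if/else. The helpers themselves (fmt.avg_attention, fmt.format_output_text) live elsewhere in the repo and are not part of this diff; purely for orientation, a generic sketch of averaging attention across layers and heads, assuming an array shaped (layers, heads, seq_len, seq_len):

import numpy as np

def avg_attention_sketch(attentions: np.ndarray) -> np.ndarray:
    # collapse the layer and head axes, then average over the query axis,
    # leaving one attention weight per input token
    per_position = attentions.mean(axis=(0, 1))  # (seq_len, seq_len)
    return per_position.mean(axis=0)             # (seq_len,)

weights = avg_attention_sketch(np.random.rand(4, 8, 12, 12))
print(weights.shape)  # (12,)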
explanation/markup.py
CHANGED
@@ -10,6 +10,8 @@ from utils import formatting as fmt
 
 # main function that assigns each text snipped a marked bucket
 def markup_text(input_text: list, text_values: ndarray, variant: str):
+    print(f"Marking up text {input_text} and {text_values} for {variant}.")
+
     # naming of the 11 buckets
     bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
 
@@ -21,6 +23,12 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     elif variant == "visualizer":
         text_values = fmt.flatten_attention(text_values)
 
+    if text_values.size != len(input_text):
+        raise ValueError(
+            "Length of input text and attribution values do not match. "
+            f"Text: {len(input_text)}, Attributions: {len(text_values)}"
+        )
+
     # determine the minimum and maximum values
     min_val, max_val = np.min(text_values), np.max(text_values)
 
@@ -47,17 +55,17 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     for text, value in zip(input_text, text_values):
 
         # validating text and skipping empty text/special tokens
-        if text not in fmt.SPECIAL_TOKENS:
-            # setting initial bucket at lowest
-            bucket = "-5"
-
-            # looping over all bucket and their threshold
-            for i, threshold in zip(bucket_tags, thresholds):
-                # updating assigned bucket if value is above threshold
-                if value >= threshold:
-                    bucket = i
-            # finally adding text and bucket assignment to list of tuples
-            marked_text.append((text, str(bucket)))
+        # if text not in fmt.SPECIAL_TOKENS:
+        # setting initial bucket at lowest
+        bucket = "-5"
+
+        # looping over all bucket and their threshold
+        for i, threshold in zip(bucket_tags, thresholds):
+            # updating assigned bucket if value is above threshold
+            if value >= threshold:
+                bucket = i
+        # finally adding text and bucket assignment to list of tuples
+        marked_text.append((text, str(bucket)))
 
     # returning list of marked text snippets as list of tuples
     return marked_text
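The restored loop depends on a thresholds sequence computed between the min/max step and the loop, in lines this diff does not touch. A self-contained sketch of the bucketing scheme, assuming evenly spaced thresholds over [min, max]; the bucketize_sketch name and the linspace construction are illustrative, not the repo's code:

import numpy as np

def bucketize_sketch(tokens: list, values: np.ndarray) -> list:
    bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
    min_val, max_val = np.min(values), np.max(values)
    # one lower threshold per bucket, assumed evenly spaced over [min, max]
    thresholds = np.linspace(min_val, max_val, len(bucket_tags))
    marked = []
    for text, value in zip(tokens, values):
        bucket = "-5"
        # a value lands in the highest bucket whose threshold it reaches
        for tag, threshold in zip(bucket_tags, thresholds):
            if value >= threshold:
                bucket = tag
        marked.append((text, bucket))
    return marked

print(bucketize_sketch(["a", "b", "c"], np.array([0.0, 0.5, 1.0])))
# [('a', '-5'), ('b', '0'), ('c', '+5')]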
main.py
CHANGED
@@ -216,36 +216,46 @@ with gr.Blocks(
         gr.Examples(
             label="Example Questions",
             examples=[
-                ["Does money buy happiness?", "", "", "Mistral", "…"],
-                ["Does money buy happiness?", "", "", "Mistral", "…"],
-                ["Does money buy happiness?", "", "", "Mistral", "…"],
+                ["Does money buy happiness?", "None", "", "Mistral", ""],
+                ["Does money buy happiness?", "SHAP", "", "Mistral", ""],
+                ["Does money buy happiness?", "Attention", "", "Mistral", ""],
                 [
                     "Does money buy happiness?",
-                    "",
+                    "None",
                     (
                         "Respond from the perspective of billionaire heir"
                         " living his best life with his father's money."
                     ),
                     "Mistral",
-                    "…",
+                    "",
                 ],
                 [
                     "Does money buy happiness?",
+                    "SHAP",
+                    (
+                        "Respond from the perspective of billionaire heir"
+                        " living his best life with his father's money."
+                    ),
+                    "Mistral",
                     "",
+                ],
+                [
+                    "Does money buy happiness?",
+                    "Attention",
                     (
                         "Respond from the perspective of billionaire heir"
                         " living his best life with his father's money."
                     ),
                     "Mistral",
-                    "…",
+                    "",
                 ],
             ],
             inputs=[
                 user_prompt,
-                knowledge_input,
                 system_prompt,
-                model_selection,
                 xai_selection,
+                model_selection,
+                knowledge_input,
             ],
         )
         with gr.Accordion("GODEL Model Examples", open=False):
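gr.Examples fills components positionally: the first field of each example row goes to the first component in inputs, the second to the second, and so on, which is why the example rows and the inputs list are reordered together here. A stripped-down sketch with hypothetical components (not the app's actual ones):

import gradio as gr

with gr.Blocks() as demo:
    question = gr.Textbox(label="Question")
    xai = gr.Radio(["None", "SHAP", "Attention"], label="XAI Method")
    # row[0] fills `question`, row[1] fills `xai`
    gr.Examples(
        examples=[["Does money buy happiness?", "SHAP"]],
        inputs=[question, xai],
    )

demo.launch()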
model/godel.py
CHANGED
@@ -14,7 +14,7 @@ MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2se
 # model config definition
 CONFIG = GenerationConfig.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 base_config_dict = {
-    "max_new_tokens": …,
+    "max_new_tokens": 64,
     "min_length": 8,
     "top_p": 0.9,
     "do_sample": True,
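base_config_dict collects the GODEL generation settings; raising max_new_tokens to 64 allows longer answers. A minimal sketch of driving generate() from such a dict through GenerationConfig (the repo's actual wiring may differ):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, GenerationConfig

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
model = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

# same shape of dict as base_config_dict above
config = GenerationConfig(
    max_new_tokens=64,
    min_length=8,
    top_p=0.9,
    do_sample=True,
)

inputs = tokenizer("Does money buy happiness?", return_tensors="pt")
outputs = model.generate(**inputs, generation_config=config)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))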
utils/modelling.py
CHANGED
@@ -82,23 +82,20 @@ def get_device():
         device = torch.device("cuda")
     else:
         device = torch.device("cpu")
-
     return device
 
 
 # function to set device config
-# CREDIT:
+# CREDIT: Copied from captum llama 2 example
 # see https://captum.ai/tutorials/Llama2_LLM_Attribution
 def gpu_loading_config(max_memory: str = "15000MB"):
     n_gpus = torch.cuda.device_count()
-
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_use_double_quant=True,
         bnb_4bit_quant_type="nf4",
         bnb_4bit_compute_dtype=torch.bfloat16,
     )
-
     return n_gpus, max_memory, bnb_config
 
 
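gpu_loading_config only assembles loading parameters; its call site is not part of this diff. A sketch of how the returned triple typically feeds from_pretrained when loading a 4-bit quantized checkpoint (assumed usage; the checkpoint id is an assumption based on the app's Mistral option):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

n_gpus = torch.cuda.device_count()
max_memory = "15000MB"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.1",  # assumed checkpoint
    quantization_config=bnb_config,
    device_map="auto",  # spread layers across available devices
    max_memory={i: max_memory for i in range(n_gpus)},
)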