Spaces:
Runtime error
Runtime error
LennardZuendorf
commited on
Commit
•
c28c597
1
Parent(s):
f5ebee7
fix: fixing linting issues, updating gh action
Browse files- .github/workflows/hgf-sync-main.yml +30 -1
- backend/controller.py +7 -9
- explanation/interpret_shap.py +1 -4
- explanation/visualize.py +0 -5
- main.py +1 -4
- utils/formatting.py +2 -0
.github/workflows/hgf-sync-main.yml
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
# workflow that syncs the main branch to the Hugging Face Hub (Huggingface Spaces)
|
2 |
-
#
|
|
|
|
|
3 |
|
4 |
name: HGF Hub Sync (Main)
|
5 |
# runs on pushes to the main branch and manually triggered workflows
|
@@ -9,8 +11,35 @@ on:
|
|
9 |
|
10 |
workflow_dispatch:
|
11 |
|
|
|
|
|
|
|
12 |
# jobs to run
|
13 |
jobs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
# sync job
|
15 |
sync-to-hub:
|
16 |
runs-on: ubuntu-latest
|
|
|
1 |
# workflow that syncs the main branch to the Hugging Face Hub (Huggingface Spaces)
|
2 |
+
# only syncs if the build and lint is also fine
|
3 |
+
# CREDIT: Adapted from Hugging Face, Inc.
|
4 |
+
## see https://huggingface.co/docs/hub/spaces-github-actions
|
5 |
|
6 |
name: HGF Hub Sync (Main)
|
7 |
# runs on pushes to the main branch and manually triggered workflows
|
|
|
11 |
|
12 |
workflow_dispatch:
|
13 |
|
14 |
+
permissions:
|
15 |
+
contents: read
|
16 |
+
|
17 |
# jobs to run
|
18 |
jobs:
|
19 |
+
|
20 |
+
# build job that installs dependencies and lints
|
21 |
+
build:
|
22 |
+
runs-on: ubuntu-latest
|
23 |
+
steps:
|
24 |
+
# checkout the repository
|
25 |
+
- uses: actions/checkout@v3
|
26 |
+
# set up python 3.10
|
27 |
+
- name: Set up Python 3.10
|
28 |
+
uses: actions/setup-python@v3
|
29 |
+
with:
|
30 |
+
python-version: "3.10"
|
31 |
+
# install dependencies from requirements.txt
|
32 |
+
- name: Install dependencies
|
33 |
+
run: |
|
34 |
+
python -m pip install --upgrade pip
|
35 |
+
pip install pylint black
|
36 |
+
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
|
37 |
+
# lint and format all files with black (as defined in pyproject.toml)
|
38 |
+
- name: Lint & Fix with Black
|
39 |
+
run: black .
|
40 |
+
# lint with pylint
|
41 |
+
- name: Lint with Pylint
|
42 |
+
run: pylint .
|
43 |
# sync job
|
44 |
sync-to-hub:
|
45 |
runs-on: ubuntu-latest
|
backend/controller.py
CHANGED
@@ -40,15 +40,13 @@ def interference(
|
|
40 |
raise RuntimeError("There was an error in the selected XAI approach.")
|
41 |
|
42 |
# call the explained chat function
|
43 |
-
prompt_output, history_output, xai_graphic, xai_markup = (
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
knowledge=knowledge,
|
51 |
-
)
|
52 |
)
|
53 |
# if no (or invalid) XAI approach is selected call the vanilla chat function
|
54 |
else:
|
|
|
40 |
raise RuntimeError("There was an error in the selected XAI approach.")
|
41 |
|
42 |
# call the explained chat function
|
43 |
+
prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
|
44 |
+
model=godel,
|
45 |
+
xai=xai,
|
46 |
+
message=prompt,
|
47 |
+
history=history,
|
48 |
+
system_prompt=system_prompt,
|
49 |
+
knowledge=knowledge,
|
|
|
|
|
50 |
)
|
51 |
# if no (or invalid) XAI approach is selected call the vanilla chat function
|
52 |
else:
|
explanation/interpret_shap.py
CHANGED
@@ -1,8 +1,5 @@
|
|
1 |
# interpret module that implements the interpretability method
|
2 |
# external imports
|
3 |
-
import seaborn as sns
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import numpy as np
|
6 |
from shap import models, maskers, plots, PartitionExplainer
|
7 |
import torch
|
8 |
|
@@ -62,4 +59,4 @@ def create_graphic(shap_values):
|
|
62 |
graphic_html = plots.text(shap_values, display=False)
|
63 |
|
64 |
# return the html graphic as string
|
65 |
-
return str(graphic_html)
|
|
|
1 |
# interpret module that implements the interpretability method
|
2 |
# external imports
|
|
|
|
|
|
|
3 |
from shap import models, maskers, plots, PartitionExplainer
|
4 |
import torch
|
5 |
|
|
|
59 |
graphic_html = plots.text(shap_values, display=False)
|
60 |
|
61 |
# return the html graphic as string
|
62 |
+
return str(graphic_html)
|
explanation/visualize.py
CHANGED
@@ -1,9 +1,5 @@
|
|
1 |
# visualization module that creates an attention visualization using BERTViz
|
2 |
|
3 |
-
# external imports
|
4 |
-
import matplotlib.pyplot as plt
|
5 |
-
import seaborn as sns
|
6 |
-
import numpy as np
|
7 |
|
8 |
# internal imports
|
9 |
from utils import formatting as fmt
|
@@ -41,4 +37,3 @@ def chat_explained(model, prompt):
|
|
41 |
marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
|
42 |
|
43 |
return response_text, "", marked_text
|
44 |
-
|
|
|
1 |
# visualization module that creates an attention visualization using BERTViz
|
2 |
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# internal imports
|
5 |
from utils import formatting as fmt
|
|
|
37 |
marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
|
38 |
|
39 |
return response_text, "", marked_text
|
|
main.py
CHANGED
@@ -180,10 +180,7 @@ with gr.Blocks(
|
|
180 |
" scripts: hieroglyphs, Demotic, and Greek."
|
181 |
),
|
182 |
],
|
183 |
-
[
|
184 |
-
"Does money buy happiness?",
|
185 |
-
""
|
186 |
-
],
|
187 |
],
|
188 |
inputs=[user_prompt, knowledge_input],
|
189 |
)
|
|
|
180 |
" scripts: hieroglyphs, Demotic, and Greek."
|
181 |
),
|
182 |
],
|
183 |
+
["Does money buy happiness?", ""],
|
|
|
|
|
|
|
184 |
],
|
185 |
inputs=[user_prompt, knowledge_input],
|
186 |
)
|
utils/formatting.py
CHANGED
@@ -69,9 +69,11 @@ def format_tokens(tokens: list):
|
|
69 |
def flatten_attribution(values: ndarray, axis: int = 0):
|
70 |
return np.sum(values, axis=axis)
|
71 |
|
|
|
72 |
def flatten_attention(values: ndarray, axis: int = 0):
|
73 |
return np.mean(values, axis=axis)
|
74 |
|
|
|
75 |
def avg_attention(attention_values):
|
76 |
attention = attention_values.cross_attentions[0][0].detach().numpy()
|
77 |
return np.mean(attention, axis=0)
|
|
|
69 |
def flatten_attribution(values: ndarray, axis: int = 0):
|
70 |
return np.sum(values, axis=axis)
|
71 |
|
72 |
+
|
73 |
def flatten_attention(values: ndarray, axis: int = 0):
|
74 |
return np.mean(values, axis=axis)
|
75 |
|
76 |
+
|
77 |
def avg_attention(attention_values):
|
78 |
attention = attention_values.cross_attentions[0][0].detach().numpy()
|
79 |
return np.mean(attention, axis=0)
|