LennardZuendorf committed on
Commit
69b34c4
1 Parent(s): 1a96c54

fix: cleanup build config again, reverting reverted changes

Browse files
Dockerfile CHANGED
@@ -3,8 +3,8 @@
3
  # complete build based on clean python (slower)
4
  #FROM python:3.11.6
5
 
6
- # build based on python with dependencies (quicker) - for dev
7
- FROM thesis:0.2.0-base
8
 
9
  # install dependencies and copy files into image folder
10
  COPY requirements.txt .
@@ -16,8 +16,8 @@ COPY . .
16
  RUN ls --recursive .
17
 
18
  # setting config and run command
19
- CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
20
 
21
  # build and run commands:
22
- ## docker build -t thesis:0.2.0-full -f Dockerfile .
23
- ## docker run -d --name thesis -p 8080:8080 thesis:0.2.0
 
3
  # complete build based on clean python (slower)
4
  #FROM python:3.11.6
5
 
6
+ # build based on thesis base with dependencies (quicker) - for dev
7
+ FROM thesis-base:0.1.1
8
 
9
  # install dependencies and copy files into image folder
10
  COPY requirements.txt .
 
16
  RUN ls --recursive .
17
 
18
  # setting config and run command
19
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8080"]
20
 
21
  # build and run commands:
22
+ ## docker build -t thesis:0.1.4 -f Dockerfile .
23
+ ## docker run -d --name thesis -p 8080:8080 thesis:0.1.4
Dockerfile-Base CHANGED
@@ -2,11 +2,11 @@
2
  # because all dependencies are already installed, the next webapp build using this base image is much quicker
3
 
4
  # using newest python as a base image
5
- FROM python:3.11.6
6
 
7
  # install dependencies based on requirements
8
  COPY requirements.txt ./
9
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
10
 
11
  # build and run commands
12
- ## docker build -t thesis:0.1.6-base -f Dockerfile-Base .
 
2
  # because all dependencies are already installed, the next webapp build using this base image is much quicker
3
 
4
  # using newest python as a base image
5
+ FROM thesis-base:0.1.1
6
 
7
  # install dependencies based on requirements
8
  COPY requirements.txt ./
9
  RUN pip install --no-cache-dir --upgrade -r requirements.txt
10
 
11
  # build and run commands
12
+ ## docker build -t thesis-base:1.0.1 -f Dockerfile-Base .
README.md CHANGED
@@ -3,12 +3,12 @@ title: Thesis
3
  emoji: 🎓
4
  colorFrom: red
5
  colorTo: yellow
6
- sdk: docker
7
  sdk_version: 4.7.1
8
  app_file: main.py
9
  pinned: true
10
  license: mit
11
- app_port: 7860
12
  ---
13
 
14
  # Bachelor Thesis
 
3
  emoji: 🎓
4
  colorFrom: red
5
  colorTo: yellow
6
+ sdk: gradio
7
  sdk_version: 4.7.1
8
  app_file: main.py
9
  pinned: true
10
  license: mit
11
+ app_port: 8080
12
  ---
13
 
14
  # Bachelor Thesis
backend/controller.py CHANGED
@@ -10,7 +10,6 @@ from explanation import interpret, visualize
10
 
11
 
12
  # main interference function that calls chat functions depending on selections
13
- # TODO: Limit maximum tokens/model input
14
  def interference(
15
  prompt: str,
16
  history: list,
 
10
 
11
 
12
  # main interference function that calls chat functions depending on selections
 
13
  def interference(
14
  prompt: str,
15
  history: list,
entrypoint.sh DELETED
@@ -1,8 +0,0 @@
1
- #!/bin/bash
2
- # entrypoint script for the docker container to run at start
3
-
4
- # installing all the dependencies
5
- pip install --no-cache-dir --upgrade -r requirements.txt
6
-
7
- # running the fastapi app
8
- uvicorn main:app --host 0.0.0.0 --port 8080
 
 
 
 
 
 
 
 
 
explanation/interpret.py CHANGED
@@ -62,35 +62,29 @@ def create_graphic(shap_values):
62
  return str(graphic_html)
63
 
64
 
65
- # plotting function that creates a heatmap style explanation plot
 
 
66
  def create_plot(shap_values):
67
  values = shap_values.values[0]
68
  output_names = shap_values.output_names
69
  input_names = shap_values.data[0]
70
 
71
- # Transpose the values for horizontal input names
72
- transposed_values = np.transpose(values)
73
-
74
  # Set seaborn style to dark
75
- sns.set(style="dark")
76
-
77
  fig, ax = plt.subplots()
78
 
79
- # Making background transparent
80
- ax.set_alpha(0)
81
- fig.patch.set_alpha(0)
82
-
83
  # Setting figure size
84
  fig.set_size_inches(
85
- max(transposed_values.shape[1] * 2, 10),
86
- max(transposed_values.shape[0] / 1.5, 5),
87
  )
88
 
89
  # Plotting the heatmap with Seaborn's color palette
90
  im = ax.imshow(
91
- transposed_values,
92
- vmax=transposed_values.max(),
93
- vmin=-transposed_values.min(),
94
  cmap=sns.color_palette("vlag_r", as_cmap=True),
95
  aspect="auto",
96
  )
@@ -98,25 +92,25 @@ def create_plot(shap_values):
98
  # Creating colorbar
99
  cbar = ax.figure.colorbar(im, ax=ax)
100
  cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
101
- cbar.ax.yaxis.set_tick_params(color="white")
102
- plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
103
 
104
  # Setting ticks and labels with white color for visibility
105
- ax.set_xticks(np.arange(len(input_names)), labels=input_names)
106
- ax.set_yticks(np.arange(len(output_names)), labels=output_names)
107
- plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
108
- plt.setp(ax.get_yticklabels(), color="white")
109
 
110
  # Adjusting tick labels
111
  ax.tick_params(
112
  top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
113
  )
114
 
115
- # Adding text annotations - not used for readability
116
- # for i in range(transposed_values.shape[0]):
117
- # for j in range(transposed_values.shape[1]):
118
- # val = transposed_values[i, j]
119
- # color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
120
- # ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
121
 
122
  return plt
 
62
  return str(graphic_html)
63
 
64
 
65
+ # creating an attention heatmap plot using matplotlib/seaborn
66
+ # CREDIT: adapted from official Matplotlib documentation
67
+ ## see https://matplotlib.org/stable/
68
  def create_plot(shap_values):
69
  values = shap_values.values[0]
70
  output_names = shap_values.output_names
71
  input_names = shap_values.data[0]
72
 
 
 
 
73
  # Set seaborn style to white
74
+ sns.set(style="white")
 
75
  fig, ax = plt.subplots()
76
 
 
 
 
 
77
  # Setting figure size
78
  fig.set_size_inches(
79
+ max(values.shape[1] * 2, 10),
80
+ max(values.shape[0] * 1, 5),
81
  )
82
 
83
  # Plotting the heatmap with Seaborn's color palette
84
  im = ax.imshow(
85
+ values,
86
+ vmax=values.max(),
87
+ vmin=values.min(),
88
  cmap=sns.color_palette("vlag_r", as_cmap=True),
89
  aspect="auto",
90
  )
 
92
  # Creating colorbar
93
  cbar = ax.figure.colorbar(im, ax=ax)
94
  cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
95
+ cbar.ax.yaxis.set_tick_params(color="black")
96
+ plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
97
 
98
  # Setting ticks and labels with black color for visibility
99
+ ax.set_yticks(np.arange(len(input_names)), labels=input_names)
100
+ ax.set_xticks(np.arange(len(output_names)), labels=output_names)
101
+ plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
102
+ plt.setp(ax.get_yticklabels(), color="black")
103
 
104
  # Adjusting tick labels
105
  ax.tick_params(
106
  top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
107
  )
108
 
109
+ # Adding text annotations with appropriate contrast
110
+ for i in range(values.shape[0]):
111
+ for j in range(values.shape[1]):
112
+ val = values[i, j]
113
+ color = "white" if im.norm(values.max()) / 2 > im.norm(val) else "black"
114
+ ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
115
 
116
  return plt
explanation/visualize.py CHANGED
@@ -57,28 +57,27 @@ def create_graphic(attention_output, enc_dec_texts: tuple):
57
  return str(hview.data)
58
 
59
 
60
- # creating an attention heatmap plot using seaborn
 
 
61
  def create_plot(attention_output, enc_dec_texts: tuple):
62
  # get the averaged attention weights
63
  attention = attention_output.cross_attentions[0][0].detach().numpy()
64
  averaged_attention_weights = np.mean(attention, axis=0)
 
65
 
66
- # get the encoder and decoder tokens
67
  encoder_tokens = enc_dec_texts[0]
68
  decoder_tokens = enc_dec_texts[1]
69
 
70
  # set seaborn style to dark and initialize figure and axis
71
- sns.set(style="dark")
72
  fig, ax = plt.subplots()
73
 
74
- # Making background transparent
75
- ax.set_alpha(0)
76
- fig.patch.set_alpha(0)
77
-
78
  # Setting figure size
79
  fig.set_size_inches(
80
  max(averaged_attention_weights.shape[1] * 2, 10),
81
- max(averaged_attention_weights.shape[0] / 1.5, 5),
82
  )
83
 
84
  # Plotting the heatmap with seaborn's color palette
@@ -92,19 +91,27 @@ def create_plot(attention_output, enc_dec_texts: tuple):
92
 
93
  # Creating colorbar
94
  cbar = ax.figure.colorbar(im, ax=ax)
95
- cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
96
- cbar.ax.yaxis.set_tick_params(color="white")
97
- plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
98
-
99
- # Setting ticks and labels with white color for visibility
100
- ax.set_xticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
101
- ax.set_yticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
102
- plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
103
- plt.setp(ax.get_yticklabels(), color="white")
104
-
105
- # Adjusting tick labels
106
- ax.tick_params(
107
- top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
108
- )
109
-
 
 
 
 
 
 
 
 
110
  return plt
 
57
  return str(hview.data)
58
 
59
 
60
+ # creating an attention heatmap plot using matplotlib/seaborn
61
+ # CREDIT: adapted from official Matplotlib documentation
62
+ ## see https://matplotlib.org/stable/
63
  def create_plot(attention_output, enc_dec_texts: tuple):
64
  # get the averaged attention weights
65
  attention = attention_output.cross_attentions[0][0].detach().numpy()
66
  averaged_attention_weights = np.mean(attention, axis=0)
67
+ averaged_attention_weights = np.transpose(averaged_attention_weights)
68
 
69
+ # get the encoder and decoder tokens in text form
70
  encoder_tokens = enc_dec_texts[0]
71
  decoder_tokens = enc_dec_texts[1]
72
 
73
  # set seaborn style to white and initialize figure and axis
74
+ sns.set(style="white")
75
  fig, ax = plt.subplots()
76
 
 
 
 
 
77
  # Setting figure size
78
  fig.set_size_inches(
79
  max(averaged_attention_weights.shape[1] * 2, 10),
80
+ max(averaged_attention_weights.shape[0] * 1, 5),
81
  )
82
 
83
  # Plotting the heatmap with seaborn's color palette
 
91
 
92
  # Creating colorbar
93
  cbar = ax.figure.colorbar(im, ax=ax)
94
+ cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
95
+ cbar.ax.yaxis.set_tick_params(color="black")
96
+ plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")
97
+
98
+ # Setting ticks and labels with black color for visibility
99
+ ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
100
+ ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
101
+ ax.set_title("Attention Weights by Token")
102
+ plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
103
+ plt.setp(ax.get_yticklabels(), color="black")
104
+
105
+ # Adding text annotations with appropriate contrast
106
+ for i in range(averaged_attention_weights.shape[0]):
107
+ for j in range(averaged_attention_weights.shape[1]):
108
+ val = averaged_attention_weights[i, j]
109
+ color = (
110
+ "white"
111
+ if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
112
+ else "black"
113
+ )
114
+ ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
115
+
116
+ # return the plot
117
  return plt
main.py CHANGED
@@ -1,12 +1,19 @@
1
  # main application file initializing the gradio based ui and calling other
 
 
 
 
2
  # external imports
3
  from fastapi import FastAPI
4
  import markdown
5
  import gradio as gr
 
 
6
 
7
  # internal imports
8
  from backend.controller import interference
9
 
 
10
  # Global Variables and css
11
  app = FastAPI()
12
  css = "body {text-align: start !important;}"
@@ -187,7 +194,7 @@ with gr.Blocks(
187
  Values have been excluded for readability. See colorbar for value indication.
188
  """)
189
  # plot component that takes a matplotlib figure as input
190
- xai_plot = gr.Plot(label="Token Level Explanation", scale=3)
191
 
192
  # functions to trigger the controller
193
  ## takes information for the chat and the xai selection
@@ -207,16 +214,21 @@ with gr.Blocks(
207
 
208
  # final row to show legal information
209
  ## - credits, data protection and link to the License
210
- with gr.Tab(label="Credits, Data Protection and License"):
211
- gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 
212
 
213
  # mount function for fastAPI Application
214
  app = gr.mount_gradio_app(app, ui, path="/")
215
 
216
  # launch function using uvicorn to launch the fastAPI application
217
  if __name__ == "__main__":
218
- from uvicorn import run
219
 
220
- # run the application on port 8080 in reload mode
 
 
 
 
221
  ## for local development, uses Docker for Prod deployment
222
  run("main:app", port=8080, reload=True)
 
1
  # main application file initializing the gradio based ui and calling other
2
+
3
+ # standard imports
4
+ import os
5
+
6
  # external imports
7
  from fastapi import FastAPI
8
  import markdown
9
  import gradio as gr
10
+ from uvicorn import run
11
+
12
 
13
  # internal imports
14
  from backend.controller import interference
15
 
16
+
17
  # Global Variables and css
18
  app = FastAPI()
19
  css = "body {text-align: start !important;}"
 
194
  Values have been excluded for readability. See colorbar for value indication.
195
  """)
196
  # plot component that takes a matplotlib figure as input
197
+ xai_plot = gr.Plot(label="Token Level Explanation")
198
 
199
  # functions to trigger the controller
200
  ## takes information for the chat and the xai selection
 
214
 
215
  # final row to show legal information
216
  ## - credits, data protection and link to the License
217
+ with gr.Tab(label="About"):
218
+ gr.Markdown(value=load_md("public/about.md"))
219
+ with gr.Accordion(label="Credits, Data Protection, License"):
220
+ gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
221
 
222
  # mount function for fastAPI Application
223
  app = gr.mount_gradio_app(app, ui, path="/")
224
 
225
  # launch function using uvicorn to launch the fastAPI application
226
  if __name__ == "__main__":
 
227
 
228
+ # use standard gradio launch option for hgf spaces
229
+ if os.environ["HOSTING"].lower() == "spaces":
230
+ ui.launch(auth=("htw", "berlin@123"))
231
+
232
+ # otherwise run the application on port 8080 in reload mode
233
  ## for local development, uses Docker for Prod deployment
234
  run("main:app", port=8080, reload=True)