LennardZuendorf committed on
Commit
43cce2a
•
1 Parent(s): 7e6f74e

feat: implementing everything for release version 1.0.0

.dockerignore CHANGED
@@ -3,6 +3,10 @@ Compose.yaml
3
  Dockerfile-Base
4
  Dockerfile-Light
5
  entrypoint.sh
6
- railway.json
 
 
 
 
7
  /components/
8
  /components/*
 
3
  Dockerfile-Base
4
  Dockerfile-Light
5
  entrypoint.sh
6
+ .gitignore
7
+ .github
8
+ .git
9
+ .pre-commit-config.yaml
10
+ start-venv.sh
11
  /components/
12
  /components/*
README.md CHANGED
@@ -17,28 +17,26 @@ app_port: 8080
17
 
18
 ## 🔗 Links:
19
 
20
- **[Github Repository](https://github.com/LennardZuendorf/thesis)**
 
21
 
22
  ## ๐Ÿ—๏ธ Tech Stack:
23
 
24
- **Language and Framework:** Python, JupyterNotebook
25
 
26
- **Noteable Packages:** 🤗 Transformers, Gradio, SHAP, BERTViz, Shapash
27
 
28
 ## 👨‍💻 Author and Credits:
29
 
30
-
31
  **Author:** [@LennardZuendorf](https://github.com/LennardZuendorf)
32
 
33
  **Thesis Supervisor**: [Prof. Dr. Simbeck](https://www.htw-berlin.de/hochschule/personen/person/?eid=9862)
34
  <br> Second Corrector: [Prof. Dr. Hochstein](https://www.htw-berlin.de/hochschule/personen/person/?eid=10628)
35
 
 
36
 
37
- See code for in detail credits, work is based on
38
-
39
- - Mistral AI
40
  - SHAP:
41
  - BERTViz:
42
 
43
-
44
 This project was part of my studies in Business Computing at the University of Applied Sciences for Technology and Business Berlin (HTW Berlin).
 
17
 
18
 ## 🔗 Links:
19
 
20
+ **[Github Repository](https://github.com/LennardZuendorf/thesis-webapp)**
21
+ **[Huggingface Spaces Showcase](https://huggingface.co/spaces/lennardzuendorf/thesis-webapp-docker)**
22
 
23
  ## ๐Ÿ—๏ธ Tech Stack:
24
 
25
+ **Language and Framework:** Python
26
 
27
+ **Notable Packages:** 🤗 Transformers, FastAPI, Gradio, SHAP, BERTViz
28
 
29
 ## 👨‍💻 Author and Credits:
30
 
 
31
  **Author:** [@LennardZuendorf](https://github.com/LennardZuendorf)
32
 
33
  **Thesis Supervisor**: [Prof. Dr. Simbeck](https://www.htw-berlin.de/hochschule/personen/person/?eid=9862)
34
  <br> Second Corrector: [Prof. Dr. Hochstein](https://www.htw-berlin.de/hochschule/personen/person/?eid=10628)
35
 
36
+ See code for detailed credits; this work is based on
37
 
38
+ - GODEL:
 
 
39
  - SHAP:
40
  - BERTViz:
41
 
 
42
 This project was part of my studies in Business Computing at the University of Applied Sciences for Technology and Business Berlin (HTW Berlin).
backend/controller.py CHANGED
@@ -5,17 +5,18 @@
5
  import gradio as gr
6
 
7
  # internal imports
8
- from model import mistral, godel
9
  from explanation import interpret, visualize
10
 
11
 
12
 # main interference function that calls chat functions depending on selections
 
13
  def interference(
14
- prompt,
15
- history,
16
- system_prompt,
17
- model_selection,
18
- xai_selection,
19
  ):
20
  # if no system prompt is given, use a default one
21
  if system_prompt == "":
@@ -24,20 +25,7 @@ def interference(
24
  Always answer as helpfully as possible, while being safe.
25
  """
26
 
27
- # grabs the model instance depending on the selection
28
- match model_selection.lower():
29
- case "mistral":
30
- model = mistral
31
- case "godel":
32
- model = godel
33
- case _:
34
- # use Gradio warning to display error message
35
- gr.Warning(
36
- f'There was an error in the selected model. It is "{model_selection}"'
37
- )
38
- raise RuntimeError("There was an error in the selected model.")
39
-
40
- # additionally, if the XAI approach is selected, grab the XAI instance
41
  if xai_selection in ("SHAP", "Visualizer"):
42
  match xai_selection.lower():
43
  case "shap":
@@ -46,33 +34,39 @@ def interference(
46
  xai = visualize
47
  case _:
48
  # use Gradio warning to display error message
49
- gr.Warning(
50
- f"""
51
  There was an error in the selected XAI Approach.
52
  It is "{xai_selection}"
53
- """
54
- )
55
  raise RuntimeError("There was an error in the selected XAI approach.")
56
 
57
  # call the explained chat function
58
  prompt_output, history_output, xai_graphic, xai_plot = explained_chat(
59
- model=model,
60
  xai=xai,
61
  message=prompt,
62
  history=history,
63
  system_prompt=system_prompt,
 
64
  )
65
  # if no (or invalid) XAI approach is selected call the vanilla chat function
66
  else:
67
  # call the vanilla chat function
68
  prompt_output, history_output = vanilla_chat(
69
- model=model,
70
  message=prompt,
71
  history=history,
72
  system_prompt=system_prompt,
 
73
  )
74
  # set XAI outputs to disclaimer html/none
75
- xai_graphic, xai_plot = "<div><h1>No Graphic to Display</h1></div>", None
 
 
 
 
 
 
76
 
77
  # return the outputs
78
  return prompt_output, history_output, xai_graphic, xai_plot
@@ -80,27 +74,31 @@ def interference(
80
 
81
  # simple chat function that calls the model
82
  # formats prompts, calls for an answer and returns updated conversation history
83
- def vanilla_chat(model, message: str, history: list, system_prompt: str):
 
 
84
  # formatting the prompt using the model's format_prompt function
85
- prompt = model.format_prompt(message, history, system_prompt)
86
  # generating an answer using the model's respond function
87
  answer = model.respond(prompt)
88
 
89
  # updating the chat history with the new answer
90
- history.append((prompt, answer))
91
 
92
  # returning the updated history
93
  return "", history
94
 
95
 
96
- def explained_chat(model, xai, message: str, history: list, system_prompt: str):
 
 
97
  # formatting the prompt using the model's format_prompt function
98
- prompt = model.format_prompt(message, history, system_prompt)
99
 
100
  # generating an answer using the xai methods explain and respond function
101
  answer, xai_graphic, xai_plot = xai.chat_explained(model, prompt)
102
  # updating the chat history with the new answer
103
- history.append((prompt, answer))
104
 
105
  # returning the updated history, xai graphic and xai plot elements
106
- return "", [["", ""]], xai_graphic, xai_plot
 
5
  import gradio as gr
6
 
7
  # internal imports
8
+ from model import godel
9
  from explanation import interpret, visualize
10
 
11
 
12
 # main interference function that calls chat functions depending on selections
13
+ # TODO: Limit maximum tokens/model input
14
  def interference(
15
+ prompt: str,
16
+ history: list,
17
+ knowledge: str,
18
+ system_prompt: str,
19
+ xai_selection: str,
20
  ):
21
  # if no system prompt is given, use a default one
22
  if system_prompt == "":
 
25
  Always answer as helpfully as possible, while being safe.
26
  """
27
 
28
+ # if an XAI approach is selected, grab the XAI instance
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  if xai_selection in ("SHAP", "Visualizer"):
30
  match xai_selection.lower():
31
  case "shap":
 
34
  xai = visualize
35
  case _:
36
  # use Gradio warning to display error message
37
+ gr.Warning(f"""
 
38
  There was an error in the selected XAI Approach.
39
  It is "{xai_selection}"
40
+ """)
 
41
  raise RuntimeError("There was an error in the selected XAI approach.")
42
 
43
  # call the explained chat function
44
  prompt_output, history_output, xai_graphic, xai_plot = explained_chat(
45
+ model=godel,
46
  xai=xai,
47
  message=prompt,
48
  history=history,
49
  system_prompt=system_prompt,
50
+ knowledge=knowledge,
51
  )
52
  # if no (or invalid) XAI approach is selected call the vanilla chat function
53
  else:
54
  # call the vanilla chat function
55
  prompt_output, history_output = vanilla_chat(
56
+ model=godel,
57
  message=prompt,
58
  history=history,
59
  system_prompt=system_prompt,
60
+ knowledge=knowledge,
61
  )
62
  # set XAI outputs to disclaimer html/none
63
+ xai_graphic, xai_plot = (
64
+ """
65
+ <div style="text-align: center"><h4>Without Selected XAI Approach,
66
+ no graphic will be displayed</h4></div>
67
+ """,
68
+ None,
69
+ )
70
 
71
  # return the outputs
72
  return prompt_output, history_output, xai_graphic, xai_plot
 
74
 
75
  # simple chat function that calls the model
76
  # formats prompts, calls for an answer and returns updated conversation history
77
+ def vanilla_chat(
78
+ model, message: str, history: list, system_prompt: str, knowledge: str = ""
79
+ ):
80
  # formatting the prompt using the model's format_prompt function
81
+ prompt = model.format_prompt(message, history, system_prompt, knowledge)
82
  # generating an answer using the model's respond function
83
  answer = model.respond(prompt)
84
 
85
  # updating the chat history with the new answer
86
+ history.append((message, answer))
87
 
88
  # returning the updated history
89
  return "", history
90
 
91
 
92
+ def explained_chat(
93
+ model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
94
+ ):
95
  # formatting the prompt using the model's format_prompt function
96
+ prompt = model.format_prompt(message, history, system_prompt, knowledge)
97
 
98
  # generating an answer using the xai methods explain and respond function
99
  answer, xai_graphic, xai_plot = xai.chat_explained(model, prompt)
100
  # updating the chat history with the new answer
101
+ history.append((message, answer))
102
 
103
  # returning the updated history, xai graphic and xai plot elements
104
+ return "", history, xai_graphic, xai_plot
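For orientation, here is a minimal sketch of how the reworked `interference` entry point is called after this change; only the parameter names come from the diff above, the argument values are illustrative.

```python
# hypothetical call into the updated controller (values are illustrative)
from backend.controller import interference

prompt_out, history_out, xai_graphic, xai_plot = interference(
    prompt="How does a black hole form in space?",
    history=[],                # list of (user, bot) tuples from the Gradio chatbot
    knowledge="",              # optional [KNOWLEDGE] context passed through to GODEL
    system_prompt="",          # empty string falls back to the built-in default prompt
    xai_selection="SHAP",      # "None", "SHAP" or "Visualizer"
)
```

With `xai_selection="None"`, the same call routes through `vanilla_chat` and returns the placeholder HTML and `None` for the plot.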
explanation/interpret.py CHANGED
@@ -3,40 +3,60 @@
3
  import seaborn as sns
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- import shap
 
 
 
 
 
 
 
 
7
 
8
 
9
  # main explain function that returns a chat with explanations
10
  def chat_explained(model, prompt):
 
 
11
  # create the shap explainer
12
- shap_explainer = shap.PartitionExplainer(model.MODEL, model.TOKENIZER)
13
  # get the shap values for the prompt
14
- shap_values = shap_explainer(prompt)
15
 
16
  # create the explanation graphic and plot
17
  graphic = create_graphic(shap_values)
18
  plot = create_plot(shap_values)
19
 
20
  # create the response text
21
- response_text = format_output_text(shap_values.output_names)
22
  return response_text, graphic, plot
23
 
24
 
25
- # output text formatting function that turns the list into a string
26
- def format_output_text(output):
27
- # start string with first list item
28
- output_str = output[0]
29
- # add all other list items with a space in between
30
- for txt in output[1:]:
31
- output_str += " " + txt
32
- # return the output string
33
- return output_str
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  # graphic plotting function that creates a html graphic (as string) for the explanation
37
  def create_graphic(shap_values):
38
  # create the html graphic using shap text plot function
39
- graphic_html = shap.plots.text(shap_values, display=False)
40
 
41
  # return the html graphic as string
42
  return str(graphic_html)
@@ -44,42 +64,59 @@ def create_graphic(shap_values):
44
 
45
  # plotting function that creates a heatmap style explanation plot
46
  def create_plot(shap_values):
47
- # setup color palette for heatmap
48
- color_palette = sns.color_palette("coolwarm", as_cmap=True)
49
-
50
- # extract values, text from shap_values
51
- values = shap_values[0]
52
- input_text = shap_values.data[0]
53
- output_text = shap_values.output_names
54
-
55
- # Set the seaborn style for better aesthetics
56
- sns.set(style="darkgrid")
57
- plt.figure(figsize=(20, 10))
58
-
59
- # create the heatmap with horizontal shape
60
- sns.heatmap(
61
- values,
62
- cmap=color_palette,
63
- center=0,
64
- annot=False,
65
- cbar_kws={"fraction": 0.02},
 
66
  )
67
 
68
- # adjusting labels and ticks
69
- plt.xticks(
70
- ticks=np.arange(len(output_text)) + 0.5,
71
- labels=output_text,
72
- rotation=90,
 
 
73
  )
74
- plt.yticks(
75
- ticks=np.arange(len(input_text)) + 0.5,
76
- labels=input_text,
77
- rotation=0,
 
 
 
 
 
 
 
 
 
 
 
 
78
  )
79
 
80
- # set axis labels
81
- plt.xlabel("Output Tokens")
82
- plt.ylabel("Input Tokens")
83
- plt.title("Token-wise SHAP Values")
 
 
84
 
85
  return plt
 
3
  import seaborn as sns
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
+ from shap import models, maskers, plots, PartitionExplainer
7
+ import torch
8
+
9
+ # internal imports
10
+ from utils import formatting as fmt
11
+
12
+ # global variables
13
+ TEACHER_FORCING = None
14
+ TEXT_MASKER = None
15
 
16
 
17
  # main explain function that returns a chat with explanations
18
  def chat_explained(model, prompt):
19
+ model.set_config()
20
+
21
  # create the shap explainer
22
+ shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
23
  # get the shap values for the prompt
24
+ shap_values = shap_explainer([prompt])
25
 
26
  # create the explanation graphic and plot
27
  graphic = create_graphic(shap_values)
28
  plot = create_plot(shap_values)
29
 
30
  # create the response text
31
+ response_text = fmt.format_output_text(shap_values.output_names)
32
  return response_text, graphic, plot
33
 
34
 
35
+ def wrap_shap(model):
36
+ global TEXT_MASKER, TEACHER_FORCING
37
+
38
+ # set the device to cuda if gpu is available
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+
41
+ # updating the model settings again
42
+ model.set_config()
43
+
44
+ # (re)initialize the shap models and masker
45
+ text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
46
+ TEACHER_FORCING = models.TeacherForcing(
47
+ text_generation,
48
+ model.TOKENIZER,
49
+ device=str(device),
50
+ similarity_model=model.MODEL,
51
+ similarity_tokenizer=model.TOKENIZER,
52
+ )
53
+ TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
54
 
55
 
56
  # graphic plotting function that creates a html graphic (as string) for the explanation
57
  def create_graphic(shap_values):
58
  # create the html graphic using shap text plot function
59
+ graphic_html = plots.text(shap_values, display=False)
60
 
61
  # return the html graphic as string
62
  return str(graphic_html)
 
64
 
65
  # plotting function that creates a heatmap style explanation plot
66
  def create_plot(shap_values):
67
+ values = shap_values.values[0]
68
+ output_names = shap_values.output_names
69
+ input_names = shap_values.data[0]
70
+
71
+ # Transpose the values for horizontal input names
72
+ transposed_values = np.transpose(values)
73
+
74
+ # Set seaborn style to dark
75
+ sns.set(style="dark")
76
+
77
+ fig, ax = plt.subplots()
78
+
79
+ # Making background transparent
80
+ ax.set_alpha(0)
81
+ fig.patch.set_alpha(0)
82
+
83
+ # Setting figure size
84
+ fig.set_size_inches(
85
+ max(transposed_values.shape[1] * 2, 10),
86
+ max(transposed_values.shape[0] / 1.5, 5),
87
  )
88
 
89
+ # Plotting the heatmap with Seaborn's color palette
90
+ im = ax.imshow(
91
+ transposed_values,
92
+ vmax=transposed_values.max(),
93
+ vmin=-transposed_values.min(),
94
+ cmap=sns.color_palette("vlag_r", as_cmap=True),
95
+ aspect="auto",
96
  )
97
+
98
+ # Creating colorbar
99
+ cbar = ax.figure.colorbar(im, ax=ax)
100
+ cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
101
+ cbar.ax.yaxis.set_tick_params(color="white")
102
+ plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
103
+
104
+ # Setting ticks and labels with white color for visibility
105
+ ax.set_xticks(np.arange(len(input_names)), labels=input_names)
106
+ ax.set_yticks(np.arange(len(output_names)), labels=output_names)
107
+ plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
108
+ plt.setp(ax.get_yticklabels(), color="white")
109
+
110
+ # Adjusting tick labels
111
+ ax.tick_params(
112
+ top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
113
  )
114
 
115
+ # Adding text annotations - not used for readability
116
+ # for i in range(transposed_values.shape[0]):
117
+ # for j in range(transposed_values.shape[1]):
118
+ # val = transposed_values[i, j]
119
+ # color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
120
+ # ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)
121
 
122
  return plt
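As a rough standalone sketch of the SHAP path added above (import paths follow this commit's layout; note the project ships a customized shap fork, so upstream shap may behave differently):

```python
# sketch only: building the explainer the same way chat_explained does above
from shap import PartitionExplainer
from model import godel
from utils import formatting as fmt

godel.set_config()
explainer = PartitionExplainer(godel.MODEL, godel.TOKENIZER)  # model + tokenizer as masker
shap_values = explainer(["Instruction: respond helpfully. [CONTEXT] How do black holes form?"])

# shap_values.output_names holds the generated tokens,
# shap_values.values the token-level attributions used by create_plot
response_text = fmt.format_output_text(shap_values.output_names)
```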
explanation/visualize.py CHANGED
@@ -2,20 +2,109 @@
2
 
3
  # external imports
4
  from bertviz import head_view
 
 
 
 
 
 
5
 
6
 
7
  # plotting function that plots the attention values in a heatmap
8
  def chat_explained(model, prompt):
9
- inputs = model.TOKENIZER(prompt, return_tensors="pt")
10
- out = model.MODEL(**inputs, output_attentions=True)
11
 
12
- attention = out["attentions"] # Retrieve attention from model outputs
13
- tokens = model.TOKENIZER.convert_ids_to_tokens(
14
- inputs["input_ids"][0]
15
- ) # Convert input ids to token strings
 
 
 
 
 
 
 
 
 
16
 
17
- graphic = head_view(attention, tokens)
18
- response_text = out[0]
19
- plot = None
 
 
 
20
 
 
 
 
 
21
  return response_text, graphic, plot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  # external imports
4
  from bertviz import head_view
5
+ import matplotlib.pyplot as plt
6
+ import seaborn as sns
7
+ import numpy as np
8
+
9
+ # internal imports
10
+ from utils import formatting as fmt
11
 
12
 
13
  # plotting function that plots the attention values in a heatmap
14
  def chat_explained(model, prompt):
 
 
15
 
16
+ model.set_config()
17
+
18
+ # get encoded input and output vectors
19
+ encoder_input_ids = model.TOKENIZER(
20
+ prompt, return_tensors="pt", add_special_tokens=True
21
+ ).input_ids
22
+ decoder_input_ids = model.MODEL.generate(encoder_input_ids, output_attentions=True)
23
+ encoder_text = fmt.format_tokens(
24
+ model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
25
+ )
26
+ decoder_text = fmt.format_tokens(
27
+ model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
28
+ )
29
 
30
+ # get attention values for the input and output vectors
31
+ attention_output = model.MODEL(
32
+ input_ids=encoder_input_ids,
33
+ decoder_input_ids=decoder_input_ids,
34
+ output_attentions=True,
35
+ )
36
 
37
+ # create the response text, graphic and plot
38
+ response_text = fmt.format_output_text(decoder_text)
39
+ graphic = create_graphic(attention_output, (encoder_text, decoder_text))
40
+ plot = create_plot(attention_output, (encoder_text, decoder_text))
41
  return response_text, graphic, plot
42
+
43
+
44
+ # creating a html graphic using BERTViz
45
+ def create_graphic(attention_output, enc_dec_texts: tuple):
46
+
47
+ # calls the head_view function of BERTViz to return html graphic
48
+ hview = head_view(
49
+ encoder_attention=attention_output.encoder_attentions,
50
+ decoder_attention=attention_output.decoder_attentions,
51
+ cross_attention=attention_output.cross_attentions,
52
+ encoder_tokens=enc_dec_texts[0],
53
+ decoder_tokens=enc_dec_texts[1],
54
+ html_action="return",
55
+ )
56
+
57
+ return str(hview.data)
58
+
59
+
60
+ # creating an attention heatmap plot using seaborn
61
+ def create_plot(attention_output, enc_dec_texts: tuple):
62
+ # get the averaged attention weights
63
+ attention = attention_output.cross_attentions[0][0].detach().numpy()
64
+ averaged_attention_weights = np.mean(attention, axis=0)
65
+
66
+ # get the encoder and decoder tokens
67
+ encoder_tokens = enc_dec_texts[0]
68
+ decoder_tokens = enc_dec_texts[1]
69
+
70
+ # set seaborn style to dark and initialize figure and axis
71
+ sns.set(style="dark")
72
+ fig, ax = plt.subplots()
73
+
74
+ # Making background transparent
75
+ ax.set_alpha(0)
76
+ fig.patch.set_alpha(0)
77
+
78
+ # Setting figure size
79
+ fig.set_size_inches(
80
+ max(averaged_attention_weights.shape[1] * 2, 10),
81
+ max(averaged_attention_weights.shape[0] / 1.5, 5),
82
+ )
83
+
84
+ # Plotting the heatmap with seaborn's color palette
85
+ im = ax.imshow(
86
+ averaged_attention_weights,
87
+ vmax=averaged_attention_weights.max(),
88
+ vmin=-averaged_attention_weights.min(),
89
+ cmap=sns.color_palette("rocket", as_cmap=True),
90
+ aspect="auto",
91
+ )
92
+
93
+ # Creating colorbar
94
+ cbar = ax.figure.colorbar(im, ax=ax)
95
+ cbar.ax.set_ylabel("Token Attribution", rotation=-90, va="bottom")
96
+ cbar.ax.yaxis.set_tick_params(color="white")
97
+ plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="white")
98
+
99
+ # Setting ticks and labels with white color for visibility
100
+ ax.set_xticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
101
+ ax.set_yticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
102
+ plt.setp(ax.get_xticklabels(), color="white", rotation=45, ha="right")
103
+ plt.setp(ax.get_yticklabels(), color="white")
104
+
105
+ # Adjusting tick labels
106
+ ax.tick_params(
107
+ top=True, bottom=False, labeltop=False, labelbottom=True, color="white"
108
+ )
109
+
110
+ return plt
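The heatmap in `create_plot` works on head-averaged cross-attention; a compact sketch of that step (tensor shapes noted as assumptions for a seq2seq model):

```python
# sketch: average the attention heads of the first cross-attention layer
import numpy as np

def average_first_layer_cross_attention(attention_output):
    # cross_attentions: tuple over layers of tensors shaped [batch, heads, tgt_len, src_len]
    attention = attention_output.cross_attentions[0][0].detach().numpy()  # [heads, tgt, src]
    return np.mean(attention, axis=0)  # [tgt_len, src_len], one weight per token pair
```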
main.py CHANGED
@@ -7,8 +7,9 @@ import gradio as gr
7
  # internal imports
8
  from backend.controller import interference
9
 
10
- # Global Variables
11
  app = FastAPI()
 
12
 
13
 
14
  # different functions to provide frontend abilities
@@ -36,37 +37,33 @@ def xai_info(xai_radio):
36
  gr.Info("No XAI method was selected.")
37
 
38
 
39
- # function to display the model info
40
- def model_info(model_radio):
41
- # display the model using the Gradio Info component
42
- gr.Info(f"The model was set to:\n {model_radio}")
43
-
44
-
45
  # ui interface based on Gradio Blocks (see documentation:
46
  # https://www.gradio.app/docs/interface)
47
- with gr.Blocks() as ui:
 
 
 
 
48
  # header row with markdown based text
49
  with gr.Row():
50
  # markdown component to display the header
51
- gr.Markdown(
52
- """
53
- # Thesis Demo - AI Chat Application with XAI
54
  ### Select between tabs below for the different views.
55
- """
56
- )
57
  # ChatBot tab used to chat with the AI chatbot
58
  with gr.Tab("AI ChatBot"):
59
  with gr.Row():
60
  # markdown component to display the header of the current tab
61
- gr.Markdown(
62
- """
63
  ### ChatBot Demo
64
  Chat with the AI ChatBot using the textbox below.
65
  Manipulate the settings in the row above,
66
  including the selection of the model,
67
  the system prompt and the XAI method.
68
- """
69
- )
70
  # row with columns for the different settings
71
  with gr.Row(equal_height=True):
72
  # column that takes up 3/5 of the row
@@ -80,22 +77,12 @@ with gr.Blocks() as ui:
80
  " answer as helpfully as possible, while being safe."
81
  ),
82
  )
83
- with gr.Column(scale=1):
84
- # checkbox group to select the model
85
- model = gr.Radio(
86
- ["Mistral", "GODEL"],
87
- label="Model Selection",
88
- info="Select Model to use for chat.",
89
- value="Mistral",
90
- interactive=True,
91
- show_label=True,
92
- )
93
  with gr.Column(scale=1):
94
  # checkbox group to select the xai method
95
- xai = gr.Radio(
96
  ["None", "SHAP", "Visualizer"],
97
  label="XAI Settings",
98
- info="XAI Functionalities to use.",
99
  value="None",
100
  interactive=True,
101
  show_label=True,
@@ -103,11 +90,10 @@ with gr.Blocks() as ui:
103
 
104
  # calling info functions on inputs for different settings
105
  system_prompt.submit(system_prompt_info, [system_prompt])
106
- model.input(model_info, [model])
107
- xai.input(xai_info, [xai])
108
 
109
  # row with chatbot ui displaying "conversation" with the model
110
- with gr.Row():
111
  # out of the box chatbot component
112
  # see documentation: https://www.gradio.app/docs/chatbot
113
  chatbot = gr.Chatbot(
@@ -115,10 +101,28 @@ with gr.Blocks() as ui:
115
  show_copy_button=True,
116
  avatar_images=("./public/human.jpg", "./public/bot.jpg"),
117
  )
118
- # row with input textbox
 
 
 
 
 
 
 
 
 
 
119
  with gr.Row():
120
  # textbox to enter the user prompt
121
- user_prompt = gr.Textbox(label="Input Message")
 
 
 
 
 
 
 
 
122
  # row with columns for buttons to submit and clear content
123
  with gr.Row():
124
  with gr.Column(scale=1):
@@ -127,79 +131,84 @@ with gr.Blocks() as ui:
127
  clear_btn = gr.ClearButton([user_prompt, chatbot])
128
  with gr.Column(scale=1):
129
  submit_btn = gr.Button("Submit", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # explanations tab used to provide explanations for a specific conversation
132
  with gr.Tab("Explanations"):
133
  # row with markdown component to display the header of the current tab
134
  with gr.Row():
135
- gr.Markdown(
136
- """
137
  ### Get Explanations for Conversations
138
  Using your selected XAI method, you can get explanations for
139
  the conversation you had with the AI ChatBot. The explanations are
140
  based on the last message you sent to the AI ChatBot (see text)
141
- """
142
- )
143
- # row that displays the settings used to create the current model output
144
- ## each textbox statically displays the current values
145
- with gr.Row():
146
- with gr.Column():
147
- gr.Textbox(
148
- value=xai,
149
- label="Used XAI Variant",
150
- show_label=True,
151
- interactive=True,
152
- )
153
- with gr.Column():
154
- gr.Textbox(
155
- value=model, label="Used Model", show_label=True, interactive=True
156
- )
157
- with gr.Column():
158
- gr.Textbox(
159
- value=system_prompt,
160
- label="Used System Prompt",
161
- show_label=True,
162
- interactive=True,
163
- )
164
  # row that displays the generated explanation of the model (if applicable)
165
- with gr.Row():
166
- # wraps the explanation html in an iframe to display it
167
  xai_interactive = gr.HTML(
168
  label="Interactive Explanation",
 
 
 
 
169
  show_label=True,
170
- value="<div><h1>No Graphic to Display</h1></div>",
171
  )
172
  # row and accordion to display an explanation plot (if applicable)
173
  with gr.Row():
174
  with gr.Accordion("Token Explanation Plot", open=False):
 
 
 
 
175
  # plot component that takes a matplotlib figure as input
176
- xai_plot = gr.Plot(
177
- label="Token Level Explanation",
178
- show_label=True,
179
- every=5,
180
- )
181
 
182
  # functions to trigger the controller
183
- ## takes information for the chat and the model, xai selection
184
  ## returns prompt, history and xai data
185
  ## see backend/controller.py for more information
186
  submit_btn.click(
187
  interference,
188
- [user_prompt, chatbot, system_prompt, model, xai],
189
  [user_prompt, chatbot, xai_interactive, xai_plot],
190
  )
191
  # function triggered by the enter key
192
  user_prompt.submit(
193
  interference,
194
- [user_prompt, chatbot, system_prompt, model, xai],
195
  [user_prompt, chatbot, xai_interactive, xai_plot],
196
  )
197
 
198
  # final row to show legal information
199
  ## - credits, data protection and link to the License
200
- with gr.Row():
201
- with gr.Accordion("Credits, Data Protection and License", open=False):
202
- gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
203
 
204
  # mount function for fastAPI Application
205
  app = gr.mount_gradio_app(app, ui, path="/")
 
7
  # internal imports
8
  from backend.controller import interference
9
 
10
+ # Global Variables and css
11
  app = FastAPI()
12
+ css = "body {text-align: start !important;}"
13
 
14
 
15
  # different functions to provide frontend abilities
 
37
  gr.Info("No XAI method was selected.")
38
 
39
 
 
 
 
 
 
 
40
  # ui interface based on Gradio Blocks (see documentation:
41
  # https://www.gradio.app/docs/interface)
42
+ with gr.Blocks(
43
+ css="text-align: start !important",
44
+ title="Thesis Webapp Showcase",
45
+ head="<head>",
46
+ ) as ui:
47
  # header row with markdown based text
48
  with gr.Row():
49
  # markdown component to display the header
50
+ gr.Markdown("""
51
+ # Thesis Demo - AI Chat Application with GODEL
52
+ ## XAI powered by SHAP and BERTViz
53
  ### Select between tabs below for the different views.
54
+ """)
 
55
  # ChatBot tab used to chat with the AI chatbot
56
  with gr.Tab("AI ChatBot"):
57
  with gr.Row():
58
  # markdown component to display the header of the current tab
59
+ gr.Markdown("""
 
60
  ### ChatBot Demo
61
  Chat with the AI ChatBot using the textbox below.
62
  Manipulate the settings in the row above,
63
  including the selection of the model,
64
  the system prompt and the XAI method.
65
+
66
+ """)
67
  # row with columns for the different settings
68
  with gr.Row(equal_height=True):
69
  # column that takes up 3/5 of the row
 
77
  " answer as helpfully as possible, while being safe."
78
  ),
79
  )
 
 
 
 
 
 
 
 
 
 
80
  with gr.Column(scale=1):
81
  # checkbox group to select the xai method
82
+ xai_selection = gr.Radio(
83
  ["None", "SHAP", "Visualizer"],
84
  label="XAI Settings",
85
+ info="Select an XAI Implementation to use.",
86
  value="None",
87
  interactive=True,
88
  show_label=True,
 
90
 
91
  # calling info functions on inputs for different settings
92
  system_prompt.submit(system_prompt_info, [system_prompt])
93
+ xai_selection.input(xai_info, [xai_selection])
 
94
 
95
  # row with chatbot ui displaying "conversation" with the model
96
+ with gr.Row(equal_height=True):
97
  # out of the box chatbot component
98
  # see documentation: https://www.gradio.app/docs/chatbot
99
  chatbot = gr.Chatbot(
 
101
  show_copy_button=True,
102
  avatar_images=("./public/human.jpg", "./public/bot.jpg"),
103
  )
104
+ # rows with input textboxes
105
+ with gr.Row():
106
+ # textbox to enter the knowledge
107
+ with gr.Accordion(label="Additional Knowledge", open=False):
108
+ knowledge_input = gr.Textbox(
109
+ value="",
110
+ label="Knowledge",
111
+ max_lines=5,
112
+ info="Add additional context knowledge.",
113
+ show_label=True,
114
+ )
115
  with gr.Row():
116
  # textbox to enter the user prompt
117
+ user_prompt = gr.Textbox(
118
+ label="Input Message",
119
+ max_lines=5,
120
+ info="""
121
+ Ask the ChatBot a question.
122
+ Hint: More complicated questions give better explanation insights!
123
+ """,
124
+ show_label=True,
125
+ )
126
  # row with columns for buttons to submit and clear content
127
  with gr.Row():
128
  with gr.Column(scale=1):
 
131
  clear_btn = gr.ClearButton([user_prompt, chatbot])
132
  with gr.Column(scale=1):
133
  submit_btn = gr.Button("Submit", variant="primary")
134
+ with gr.Row():
135
+ gr.Examples(
136
+ label="Example Questions",
137
+ examples=[
138
+ [
139
+ "How does a black hole form in space?",
140
+ (
141
+ "Black holes are created when a massive star's core"
142
+ " collapses after a supernova, forming an object with"
143
+ " gravity so intense that even light cannot escape."
144
+ ),
145
+ ],
146
+ [
147
+ (
148
+ "Explain the importance of the Rosetta Stone in"
149
+ " understanding ancient languages."
150
+ ),
151
+ (
152
+ "The Rosetta Stone, an ancient Egyptian artifact, was key"
153
+ " in decoding hieroglyphs, featuring the same text in three"
154
+ " scripts: hieroglyphs, Demotic, and Greek."
155
+ ),
156
+ ],
157
+ ],
158
+ inputs=[user_prompt, knowledge_input],
159
+ )
160
 
161
  # explanations tab used to provide explanations for a specific conversation
162
  with gr.Tab("Explanations"):
163
  # row with markdown component to display the header of the current tab
164
  with gr.Row():
165
+ gr.Markdown("""
 
166
  ### Get Explanations for Conversations
167
  Using your selected XAI method, you can get explanations for
168
  the conversation you had with the AI ChatBot. The explanations are
169
  based on the last message you sent to the AI ChatBot (see text)
170
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  # row that displays the generated explanation of the model (if applicable)
172
+ with gr.Row(variant="panel"):
173
+ # wraps the explanation html in an iframe to display it interactively
174
  xai_interactive = gr.HTML(
175
  label="Interactive Explanation",
176
+ value=(
177
+ '<div style="text-align: center"><h4>No Graphic to Display'
178
+ " (Yet)</h4></div>"
179
+ ),
180
  show_label=True,
 
181
  )
182
  # row and accordion to display an explanation plot (if applicable)
183
  with gr.Row():
184
  with gr.Accordion("Token Explanation Plot", open=False):
185
+ gr.Markdown("""
186
+ #### Plotted Values
187
+ Values have been excluded for readability. See colorbar for value indication.
188
+ """)
189
  # plot component that takes a matplotlib figure as input
190
+ xai_plot = gr.Plot(label="Token Level Explanation", scale=3)
 
 
 
 
191
 
192
  # functions to trigger the controller
193
+ ## takes information for the chat and the xai selection
194
  ## returns prompt, history and xai data
195
  ## see backend/controller.py for more information
196
  submit_btn.click(
197
  interference,
198
+ [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
199
  [user_prompt, chatbot, xai_interactive, xai_plot],
200
  )
201
  # function triggered by the enter key
202
  user_prompt.submit(
203
  interference,
204
+ [user_prompt, chatbot, knowledge_input, system_prompt, xai_selection],
205
  [user_prompt, chatbot, xai_interactive, xai_plot],
206
  )
207
 
208
  # final row to show legal information
209
  ## - credits, data protection and link to the License
210
+ with gr.Tab(label="Credits, Data Protection and License"):
211
+ gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
212
 
213
  # mount function for fastAPI Application
214
  app = gr.mount_gradio_app(app, ui, path="/")
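Since the Gradio UI is mounted onto the FastAPI `app`, the Space presumably serves it through an ASGI server; a minimal sketch follows (host and entry point are assumptions, only the port matches the `app_port: 8080` front matter in the README):

```python
# sketch: serving the mounted Gradio UI with uvicorn (host/entry point assumed)
import uvicorn
from main import app

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
```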
model/godel.py CHANGED
@@ -1,30 +1,55 @@
1
  # GODEL model module for chat interaction and model instance control
 
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
 
 
 
 
4
  # model and tokenizer instance
5
  TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
6
  MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
7
- GODEL_CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
 # formatting function to format input for the model
11
  # CREDIT: Adapted from official interference example on Huggingface
12
  ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
13
  def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
 
14
  prompt = ""
15
 
 
 
 
 
 
16
  # adds knowledge text if not empty
17
  if knowledge != "":
18
  knowledge = "[KNOWLEDGE] " + knowledge
19
 
20
- history.append([message])
21
- for user_prompt, bot_response in history:
22
- prompt += f"EOS {user_prompt} EOS {bot_response}"
23
 
24
- prompt = f"{system_prompt} [CONTEXT] {prompt} {knowledge}"
 
 
 
25
 
26
- # returns the full combined prompt for the model
27
- return prompt
28
 
29
 
30
 # response function calling the model and returning the model output message
@@ -32,7 +57,7 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
32
  ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
33
  def respond(prompt):
34
  input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
35
- outputs = MODEL.generate(input_ids, **GODEL_CONFIG)
36
  output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
37
 
38
  return output
 
1
  # GODEL model module for chat interaction and model instance control
2
+
3
+ # external imports
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
 
6
+ # internal imports
7
+ from utils import modelling as mdl
8
+
9
  # model and tokenizer instance
10
  TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
11
  MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
12
+ CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
13
+
14
+
15
+ # TODO: Make config variable
16
+ def set_config(config: dict = None):
17
+ if config is None:
18
+ config = {}
19
+
20
+ MODEL.config.max_new_tokens = 50
21
+ MODEL.config.min_length = 8
22
+ MODEL.config.top_p = 0.9
23
+ MODEL.config.do_sample = True
24
 
25
 
26
 # formatting function to format input for the model
27
  # CREDIT: Adapted from official interference example on Huggingface
28
  ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
29
  def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
30
+ # user input prompt initialization
31
  prompt = ""
32
 
33
+ # limits the prompt elements to the maximum token count
34
+ message, history, system_prompt, knowledge = mdl.prompt_limiter(
35
+ TOKENIZER, message, history, system_prompt, knowledge
36
+ )
37
+
38
  # adds knowledge text if not empty
39
  if knowledge != "":
40
  knowledge = "[KNOWLEDGE] " + knowledge
41
 
42
+ # adds conversation history to the prompt
43
+ for conversation in history:
44
+ prompt += f"EOS {conversation[0]} EOS {conversation[1]}"
45
 
46
+ # adds the message to the prompt
47
+ prompt += f" {message}"
48
+ # combines the entire prompt
49
+ full_prompt = f"{system_prompt} [CONTEXT] {prompt} {knowledge}"
50
 
51
+ # returns the formatted prompt
52
+ return full_prompt
53
 
54
 
55
 # response function calling the model and returning the model output message
 
57
  ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
58
  def respond(prompt):
59
  input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
60
+ outputs = MODEL.generate(input_ids, **CONFIG)
61
  output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
62
 
63
  return output
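For reference, a sketch of the prompt string that the reworked `format_prompt` assembles (the example turns are invented; importing the module downloads the GODEL weights):

```python
# sketch: GODEL prompt layout after this change (example content invented)
from model import godel

full_prompt = godel.format_prompt(
    message="And how large can they grow?",
    history=[("What is a black hole?", "A region of spacetime with extreme gravity.")],
    system_prompt="Instruction: given a dialog context, respond helpfully and safely.",
    knowledge="Black holes range from stellar-mass to supermassive.",
)
# roughly: "<system prompt> [CONTEXT] EOS <user> EOS <bot> <message> [KNOWLEDGE] <knowledge>"
print(full_prompt)
```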
model/mistral.py DELETED
@@ -1,71 +0,0 @@
1
- # Mistral 7B model module for chat interaction and model instance control
2
-
3
- # external imports
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
5
- import torch
6
- import gradio as gr
7
-
8
- # global variables for model and tokenizer, config
9
- MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
10
- TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
11
- MISTRAL_CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
12
-
13
- MISTRAL_CONFIG.update(
14
- **{
15
- "temperature": 0.7,
16
- "max_new_tokens": 50,
17
- "top_p": 0.9,
18
- "repetition_penalty": 1.2,
19
- "do_sample": True,
20
- "seed": 42,
21
- }
22
- )
23
-
24
-
25
- # function to format the prompt to include chat history, message
26
- # CREDIT: adapted from Venkata Bhanu Teja Pallakonda in Huggingface discussions
27
- ## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/discussions/
28
-
29
-
30
- def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
31
- prompt = ""
32
- if knowledge != "":
33
- gr.Warning(
34
- """Mistral does not support
35
- additionally knowledge!"""
36
- )
37
-
38
- # if no history, use system prompt and example message
39
- if len(history) == 0:
40
- prompt = f"""<s>[INST] {system_prompt} [/INST] How can I help you today? </s>
41
- [INST] {message} [/INST]"""
42
- else:
43
- # takes the very first exchange and the system prompt as base
44
- for user_prompt, bot_response in history[0]:
45
- prompt = (
46
- f"<s>[INST] {system_prompt} {user_prompt} [/INST] {bot_response}</s>"
47
- )
48
-
49
- # takes all the following conversations and adds them as context
50
- prompt += "".join(
51
- f"[INST] {user_prompt} [/INST] {bot_response}</s>"
52
- for user_prompt, bot_response in history[1:]
53
- )
54
- return prompt
55
-
56
-
57
- # generation class returning the model response based on the input
58
- # CREDIT: adapted from official Mistral Ai 7B Instruct documentation on Huggingface
59
- ## see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
60
- def respond(prompt):
61
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
-
63
- # tokenizing inputs and configuring model
64
- input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")
65
- model_input = input_ids.to(device)
66
- MODEL.to(device)
67
-
68
- # generating text with tokenized input, returning output
69
- output_ids = MODEL.generate(model_input, generation_config=MISTRAL_CONFIG)
70
- output_text = TOKENIZER.batch_decode(output_ids)
71
- return output_text[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
public/credits_dataprotection_license.md CHANGED
@@ -9,19 +9,14 @@
9
  For full credits, please refer to the [thesis print]()
10
 
11
  ### Models
12
- For this project, two different models are used. Both are used through Huggingface's [transformers](https://huggingface.co/docs/transformers/index) library.
13
 
14
- ##### LlaMa 2
15
- LlaMa 2 is an open source model by Meta Research. See [offical paper](https://arxiv.org/pdf/2307.09288.pdf) for more information.
16
 
17
- - the version used in this project is LlaMa 2 7B Chat HF (HF = special version for huggingface), see [huggingface model hub](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
18
- - the model is fine-tuned for chat interactions by Meta Research
19
-
20
- ##### Mistral
21
- Mistral is an open source model by Mistral AI. See [offical paper](https://arxiv.org/pdf/2310.06825.pdf) for more information.
22
-
23
- - the version used in this project is Mistral Instruct, see [huggingface model hub](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)
24
- - the model is fine-tuned for instruction following by Mistral AI
25
 
26
  ### Libraries
27
  This project uses a number of open source libraries, only the most important ones are listed below.
@@ -29,7 +24,7 @@ This project uses a number of open source libraries, only the most important one
29
  ##### Shap
30
  This application uses a custom version of the shap library, which is available at [GitHub](https://github.com/shap/shap).
31
 
32
- - please refer to the [shap-adapter](https://github.com/LennardZuendorf/thesis-shap-adapter) repository for more information about the changes made to the library, specifically the README and CHANGES files
33
 - the shap library and the partition SHAP explainer used here are based on work by Lundberg et al. (2017), see [official paper](https://arxiv.org/pdf/1705.07874.pdf) for more information
34
 
35
  ##### BertViz
@@ -40,10 +35,11 @@ This application uses a slightly customized version of the bertviz library, whic
40
 
41
 
42
  # Data Protection
43
- This is a non-commercial project, which does not collect any personal data. The only data collected is the data you enter into the application. This data is only used to generate the explanations and is not stored anywhere.
44
- However, the application may be hosted with an external service (i.e. Huggingface Spaces), which may collect data. Please refer to the data protection policies of the respective service for more information.
 
45
 
46
- If you use the "flag" feature, the data you enter will be stored in *publicly available* csv file.
47
 
48
 
49
  # License
 
9
  For full credits, please refer to the [thesis print]()
10
 
11
  ### Models
12
+ This implementation is built on GODEL by Microsoft, Inc.
13
 
14
+ ##### GODEL
15
+ GODEL is an open source model by Microsoft. See the [official paper](https://arxiv.org/abs/2206.11309) for more information.
16
 
17
+ - the version used in this project is GODEL Large, see [huggingface model hub](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq?text=Hey+my+name+is+Thomas%21+How+are+you%3F)
18
+ - the model is a generative seq2seq transformer fine-tuned for goal-directed dialog
19
+ - it supports context and knowledge base inputs
 
 
 
 
 
20
 
21
  ### Libraries
22
  This project uses a number of open source libraries, only the most important ones are listed below.
 
24
  ##### Shap
25
  This application uses a custom version of the shap library, which is available at [GitHub](https://github.com/shap/shap).
26
 
27
+ - please refer to the [thesis-custom-shap](https://github.com/LennardZuendorf/thesis-custom-shap) repository for more information about the changes made to the library, specifically the README and CHANGES files
28
 - the shap library and the partition SHAP explainer used here are based on work by Lundberg et al. (2017), see [official paper](https://arxiv.org/pdf/1705.07874.pdf) for more information
29
 
30
  ##### BertViz
 
35
 
36
 
37
  # Data Protection
38
+ This is a non-commercial research project, which does not collect any personal data. The only data collected is the data you enter into the application. This data is only used to generate the explanations and is not stored anywhere.
39
+
40
+ > However, the application may be hosted with an external service (i.e. Huggingface Spaces), which may collect data.
41
 
42
+ Please refer to the data protection policies of the respective service for more information. If you use the "flag" feature, the data you enter will be stored in a *publicly available* CSV file.
43
 
44
 
45
  # License
pyproject.toml CHANGED
@@ -1,6 +1,8 @@
 
1
  [tool.black]
2
  line-length = 88
3
  include = '\.pyi?$'
 
4
  exclude = '''
5
  /(
6
  \.eggs
 
1
+ # configuration for formatting & linting tools
2
  [tool.black]
3
  line-length = 88
4
  include = '\.pyi?$'
5
+ preview = true
6
  exclude = '''
7
  /(
8
  \.eggs
railway.json DELETED
@@ -1,13 +0,0 @@
1
- {
2
- "$schema": "https://railway.app/railway.schema.json",
3
- "build": {
4
- "builder": "DOCKERFILE",
5
- "dockerfilePath": "Dockerfile"
6
- },
7
- "deploy": {
8
- "numReplicas": 1,
9
- "sleepApplication": false,
10
- "restartPolicyType": "ON_FAILURE",
11
- "restartPolicyMaxRetries": 10
12
- }
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
utils/__init__.py ADDED
File without changes
utils/formatting.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # formatting util module providing formatting functions for the model input and output
2
+
3
+ # external imports
4
+ import re
5
+
6
+
7
+ # function to format the model response nicely
8
+ def format_output_text(output: list):
9
+ # remove special tokens from list
10
+ formatted_output = format_tokens(output)
11
+
12
+ # start string with first list item if it is not empty
13
+ if formatted_output[0] != "":
14
+ output_str = formatted_output[0]
15
+ else:
16
+ # alternatively start with second list item
17
+ output_str = formatted_output[1]
18
+
19
+ # add all other list items with a space in between
20
+ for txt in formatted_output[1:]:
21
+ # check if the token is a punctuation mark
22
+ if txt in [".", ",", "!", "?"]:
23
+ # add punctuation mark without space
24
+ output_str += txt
25
+ # add token with space if not empty
26
+ elif txt != "":
27
+ output_str += " " + txt
28
+
29
+ # return the combined string with multiple spaces removed
30
+ return re.sub(" +", " ", output_str)
31
+
32
+
33
+ # format the tokens by removing special tokens and special characters
34
+ def format_tokens(tokens: list):
35
+ # define special tokens to remove and initialize empty list
36
+ special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
37
+ updated_tokens = []
38
+
39
+ # loop through tokens
40
+ for t in tokens:
41
+ # remove special token from start of token if found
42
+ if t.startswith("▁"):
43
+ t = t.lstrip("▁")
44
+
45
+ # loop through special tokens and remove them if found
46
+ for s in special_tokens:
47
+ t = t.replace(s, "")
48
+
49
+ # add token to list
50
+ updated_tokens.append(t)
51
+
52
+ # return the list of tokens
53
+ return updated_tokens
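A small usage sketch for the new formatting helpers (the token list is illustrative, in SentencePiece style):

```python
# sketch: joining decoded tokens into readable text with the new helpers
from utils import formatting as fmt

tokens = ["▁Black", "▁holes", "▁form", "▁after", "▁a", "▁supernova", ".", "[SEP]"]
print(fmt.format_tokens(tokens))       # ['Black', 'holes', 'form', 'after', 'a', 'supernova', '.', '']
print(fmt.format_output_text(tokens))  # "Black holes form after a supernova."
```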
utils/modelling.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # module for modelling utilities
2
+
3
+ # external imports
4
+ import gradio as gr
5
+
6
+
7
+ def prompt_limiter(
8
+ tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
9
+ ):
10
+ # initializing the prompt history empty
11
+ prompt_history = []
12
+ # getting the token count for the message, system prompt, and knowledge
13
+ pre_count = (
14
+ token_counter(tokenizer, message)
15
+ + token_counter(tokenizer, system_prompt)
16
+ + token_counter(tokenizer, knowledge)
17
+ )
18
+
19
+ # validating the token count
20
+ # check if token count already too high
21
+ if pre_count > 1024:
22
+
23
+ # check if token count too high even without knowledge
24
+ if (
25
+ token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
26
+ > 1024
27
+ ):
28
+
29
+ # show warning and raise error
30
+ gr.Warning("Message and system prompt are too long. Please shorten them.")
31
+ raise RuntimeError(
32
+ "Message and system prompt are too long. Please shorten them."
33
+ )
34
+
35
+ # show warning and remove knowledge
36
+ gr.Warning("Knowledge is too long. It has been removed to keep model running.")
37
+ return message, prompt_history, system_prompt, ""
38
+
39
+ # if token count small enough, add history
40
+ if pre_count < 800:
41
+ # setting the count to the precount
42
+ count = pre_count
43
+ # reversing the history to prioritize recent conversations
44
+ history.reverse()
45
+
46
+ # iterating through the history
47
+ for conversation in history:
48
+
49
+ # checking the token count with the current conversation
50
+ count += token_counter(tokenizer, conversation[0]) + token_counter(
51
+ tokenizer, conversation[1]
52
+ )
53
+
54
+ # add conversation or break loop depending on token count
55
+ if count < 1024:
56
+ prompt_history.append(conversation)
57
+ else:
58
+ break
59
+
60
+ # return the message, prompt history, system prompt, and knowledge
61
+ return message, prompt_history, system_prompt, knowledge
62
+
63
+
64
+ # token counter function using the model tokenizer
65
+ def token_counter(tokenizer, text: str):
66
+ # tokenize the text
67
+ tokens = tokenizer(text, return_tensors="pt").input_ids
68
+ # return the token count
69
+ return len(tokens[0])
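And a brief usage sketch for the new prompt limiter (tokenizer and texts are illustrative; the helper keeps the combined prompt under the 1024-token budget and returns the trimmed history most-recent-first):

```python
# sketch: trimming the chat history to fit the GODEL token budget
from model import godel
from utils import modelling as mdl

message, history, system_prompt, knowledge = mdl.prompt_limiter(
    godel.TOKENIZER,
    message="How does a black hole form in space?",
    history=[("Hi!", "Hello, how can I help?")],
    system_prompt="Instruction: respond helpfully and safely.",
    knowledge="",
)
```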