Spaces:
Runtime error
Runtime error
File size: 4,409 Bytes
fe1089d 2492536 fe1089d 5d99c07 dacf466 fe1089d 2492536 fe1089d 5d99c07 fe1089d 2492536 5d99c07 fe1089d 21aad16 5d99c07 7ad098c 5d99c07 7ad098c 5d99c07 2492536 ba1dc89 2492536 fe1089d 21aad16 dacf466 ba1dc89 58a02af fe1089d 2492536 fe1089d 2492536 c28c597 5d99c07 c28c597 fe1089d 2492536 fe1089d 5d99c07 fe1089d f5ebee7 fe1089d d2116db fe1089d f5ebee7 fe1089d 7ad098c 21aad16 fe1089d 2492536 fe1089d 7ad098c 21aad16 fe1089d dacf466 aaf0c9d dacf466 fe1089d 2492536 f5ebee7 d2116db fe1089d f5ebee7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
# Controller for the application: calls the model and explanation functions
# and returns the updated conversation history plus extra (XAI) elements.
# external imports
import gradio as gr
# internal imports
from model import godel
from model import mistral
from explanation import (
interpret_shap as shap_int,
interpret_captum as cpt_int,
visualize as viz,
)
# Main inference entry point: dispatches to the plain or the explained chat
# function depending on the UI selections. Called on every chat submit.
def interference(
    prompt: str,
    history: list,
    knowledge: str,
    system_prompt: str,
    xai_selection: str,
    model_selection: str,
):
    """Handle one chat submission and return the updated outputs.

    Args:
        prompt: The user's message.
        history: Conversation history; the selected chat function appends the
            new ``(message, answer)`` turn to it.
        knowledge: Extra knowledge text forwarded to the model's prompt
            formatter.
        system_prompt: System prompt. Empty, whitespace-only or None values
            are replaced by a built-in default prompt.
        xai_selection: Explanation approach, matched case-insensitively
            against "shap" and "attention". Any other value (including None)
            runs the chat without an explanation.
        model_selection: "mistral" (case-insensitive) selects Mistral; any
            other value selects GODEL.

    Returns:
        Tuple ``(prompt_output, history_output, xai_graphic, xai_markup)``.
        Without an XAI approach, ``xai_graphic`` is a disclaimer HTML string
        and ``xai_markup`` is ``[("", "")]``.
    """
    # fall back to a default system prompt when none (or only whitespace) is given;
    # the original check only caught "" and a single space
    if not system_prompt or not system_prompt.strip():
        system_prompt = """
You are a helpful, respectful and honest assistant.
Always answer as helpfully as possible, while being safe.
"""

    is_mistral = (model_selection or "").lower() == "mistral"
    if is_mistral:
        model = mistral
        print("Indetified model as Mistral")
    else:
        model = godel
        print("Indetified model as GODEL")

    # resolve the XAI module once, case-insensitively; previously the gate was
    # case-sensitive while the dispatch lowercased, and its error branch was
    # unreachable
    xai_key = (xai_selection or "").strip().lower()
    if xai_key == "shap":
        # Mistral is explained with the Captum-based module, GODEL with SHAP
        xai = cpt_int if is_mistral else shap_int
    elif xai_key == "attention":
        xai = viz
    else:
        xai = None

    if xai is not None:
        # call the explained chat function with the model and XAI module
        prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
            model=model,
            xai=xai,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
    else:
        # no XAI approach selected: call the vanilla chat function
        prompt_output, history_output = vanilla_chat(
            model=model,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
        # set XAI outputs to disclaimer html / empty markup
        xai_graphic, xai_markup = (
            """
<div style="text-align: center"><h4>Without Selected XAI Approach,
no graphic will be displayed</h4></div>
""",
            [("", "")],
        )

    return prompt_output, history_output, xai_graphic, xai_markup
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Run one chat turn without any explanation.

    The prompt is built with ``model.format_prompt`` (which sees the history
    *before* the new turn) and answered by ``model.respond``.

    Args:
        model: Object/module exposing ``format_prompt`` and ``respond``.
        message: The user's message.
        history: Conversation history; the new ``(message, answer)`` tuple is
            appended to this same list in place.
        system_prompt: System prompt passed to the formatter.
        knowledge: Optional extra knowledge passed to the formatter.

    Returns:
        ``("", history)``: an empty string (presumably used to clear the input
        box; confirm against the UI wiring) and the same, mutated history list.
    """
    print(f"Running normal chat with {model}.")
    reply = model.respond(
        model.format_prompt(message, history, system_prompt, knowledge)
    )
    history.append((message, reply))
    return ("", history)
def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Run one chat turn through an explanation (XAI) module.

    The prompt is built with ``model.format_prompt`` (which sees the history
    *before* the new turn), then ``xai.chat_explained(model, prompt)`` produces
    the answer together with its explanation outputs.

    Args:
        model: Object/module exposing ``format_prompt``; also handed to the
            XAI module.
        xai: Object/module exposing ``chat_explained`` that returns a
            3-tuple ``(answer, graphic, markup)``.
        message: The user's message.
        history: Conversation history; the new ``(message, answer)`` tuple is
            appended to this same list in place.
        system_prompt: System prompt passed to the formatter.
        knowledge: Optional extra knowledge passed to the formatter.

    Returns:
        ``("", history, graphic, markup)`` where ``graphic`` and ``markup`` are
        passed through unchanged from ``xai.chat_explained``.
    """
    print(f"Running explained chat with {xai} with {model}.")
    outcome = xai.chat_explained(
        model, model.format_prompt(message, history, system_prompt, knowledge)
    )
    reply, graphic, markup = outcome
    history.append((message, reply))
    return ("", history, graphic, markup)
|