File size: 4,615 Bytes
fe1089d
2492536
fe1089d
 
 
 
 
 
5d99c07
dacf466
226ad46
dacf466
 
 
fe1089d
 
f301e04
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fe1089d
 
 
 
 
 
5d99c07
fe1089d
2492536
5d99c07
1f063be
 
 
 
fe1089d
f301e04
21aad16
5d99c07
a597c76
5d99c07
 
a597c76
5d99c07
2492536
f301e04
ba1dc89
2492536
fe1089d
 
21aad16
dacf466
 
 
ba1dc89
226ad46
fe1089d
 
 
 
 
 
2492536
fe1089d
 
2492536
b324c38
 
 
 
 
 
 
 
 
fe1089d
2492536
fe1089d
f301e04
fe1089d
5d99c07
fe1089d
 
 
 
 
 
30049a9
fe1089d
 
 
 
d2116db
b324c38
fe1089d
 
 
30049a9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# controller for the application that calls the model and explanation functions
# returns the updated conversation history and extra elements

# external imports
import gradio as gr

# internal imports
from model import godel
from model import mistral
from explanation import (
    attention as attention_viz,
    interpret_shap as shap_int,
    interpret_captum as cpt_int,
)


# simple chat function that calls the model
# formats prompts, calls for an answer and returns updated conversation history
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Run one plain chat turn (no explanation) and record it in the history.

    Args:
        model: object providing ``format_prompt(message, history,
            system_prompt, knowledge)`` and ``respond(prompt)``.
        message: the new user message.
        history: conversation history; the new ``(message, answer)`` tuple
            is appended to this list in place.
        system_prompt: system prompt passed on to the prompt formatter.
        knowledge: optional extra knowledge passed on to the formatter.

    Returns:
        The tuple ``("", history)``: an empty string (presumably used to
        clear the input box - confirm against the UI wiring) and the same
        history list, now containing the new turn.
    """
    print(f"Running normal chat with {model}.")

    # the model builds its own prompt format and answers it directly
    reply = model.respond(
        model.format_prompt(message, history, system_prompt, knowledge)
    )

    # store the completed turn on the caller's history list
    history.append((message, reply))

    return "", history


def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Run one chat turn through an XAI module and record it in the history.

    Args:
        model: object providing ``format_prompt(message, history,
            system_prompt, knowledge)``; it is also handed to the XAI module.
        xai: explanation module providing ``chat_explained(model, prompt)``,
            which must return four values: answer, graphic, markup and plot.
        message: the new user message.
        history: conversation history; the new ``(message, answer)`` tuple
            is appended to this list in place.
        system_prompt: system prompt passed on to the prompt formatter.
        knowledge: optional extra knowledge passed on to the formatter.

    Returns:
        The tuple ``("", history, xai_graphic, xai_markup, xai_plot)``.
    """
    print(f"Running explained chat with {xai} with {model}.")

    # NOTE: no prompt limiting is applied before formatting
    # (a former `mdl.prompt_limiter` call is disabled)
    full_prompt = model.format_prompt(message, history, system_prompt, knowledge)

    # the XAI module produces the answer together with its explanation parts
    explained = xai.chat_explained(model, full_prompt)
    reply, graphic, markup, plot = explained

    # store the completed turn on the caller's history list
    history.append((message, reply))

    return "", history, graphic, markup, plot


# main inference function (named `interference`) that calls chat functions depending on selections
def interference(
    prompt: str,
    history: list,
    knowledge: str,
    system_prompt: str,
    xai_selection: str,
    model_selection: str,
):
    # if no proper system prompt is given, use a default one
    if system_prompt in ("", " "):
        system_prompt = (
            "You are a helpful, respectful and honest assistant."
            "Always answer as helpfully as possible, while being safe."
        )

    # if a model is selected, grab the model instance
    if model_selection.lower() == "mistral":
        model = mistral
        print("Identified model as Mistral")
    else:
        model = godel
        print("Identified model as GODEL")

    # if a XAI approach is selected, grab the XAI module instance
    # and call the explained chat function
    if xai_selection in ("SHAP", "Attention"):
        # matching selection
        match xai_selection.lower():
            case "shap":
                if model_selection.lower() == "mistral":
                    xai = cpt_int
                else:
                    xai = shap_int
            case "attention":
                xai = attention_viz
            case _:
                # use Gradio warning to display error message
                gr.Warning(f"""
                    There was an error in the selected XAI Approach.
                    It is "{xai_selection}"
                    """)
                # raise runtime exception
                raise RuntimeError("There was an error in the selected XAI approach.")

        # call the explained chat function with the model instance
        prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
            explained_chat(
                model=model,
                xai=xai,
                message=prompt,
                history=history,
                system_prompt=system_prompt,
                knowledge=knowledge,
            )
        )
    # if no XAI approach is selected call the vanilla chat function
    else:
        # calling the vanilla chat function
        prompt_output, history_output = vanilla_chat(
            model=model,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
        # set XAI outputs to disclaimer html/none
        xai_interactive, xai_markup, xai_plot = (
            """
            <div style="text-align: center"><h4>Without Selected XAI Approach,
            no graphic will be displayed</h4></div>
            """,
            [("", "")],
            None,
        )

    # return the outputs
    return prompt_output, history_output, xai_interactive, xai_markup, xai_plot