# controller for the application that calls the model and explanation functions
# returns the updated conversation history and extra elements

# external imports
import gradio as gr

# internal imports
from model import godel
from model import mistral
from explanation import (
    attention as attention_viz,
    interpret_shap as shap_int,
    interpret_captum as cpt_int,
)


# main interference function that calls the appropriate chat function depending on the selections
# called on every chat submit
def interference(
    prompt: str,
    history: list,
    knowledge: str,
    system_prompt: str,
    xai_selection: str,
    model_selection: str,
):
    # if no proper system prompt is given, use a default one
    if not system_prompt.strip():
        system_prompt = """
            You are a helpful, respectful and honest assistant.
            Always answer as helpfully as possible, while being safe.
        """

    if model_selection.lower() == "mistral":
        model = mistral
        print("Indentified model as Mistral")
    else:
        model = godel
        print("Indentified model as GODEL")

    # if an XAI approach is selected, grab the corresponding XAI module instance
    if xai_selection in ("SHAP", "Attention"):
        # match the selection to the corresponding XAI module
        match xai_selection.lower():
            case "shap":
                if model_selection.lower() == "mistral":
                    xai = cpt_int
                else:
                    xai = shap_int
            case "attention":
                xai = attention_viz
            case _:
                # use Gradio warning to display error message
                gr.Warning(f"""
                    There was an error in the selected XAI approach:
                    "{xai_selection}" is not a valid selection.
                    """)
                # raise runtime exception
                raise RuntimeError("There was an error in the selected XAI approach.")

        # call the explained chat function with the model instance
        prompt_output, history_output, xai_interactive, xai_markup, xai_plot = (
            explained_chat(
                model=model,
                xai=xai,
                message=prompt,
                history=history,
                system_prompt=system_prompt,
                knowledge=knowledge,
            )
        )
    # if no XAI approach is selected, call the vanilla chat function
    else:
        # call the vanilla chat function
        prompt_output, history_output = vanilla_chat(
            model=model,
            message=prompt,
            history=history,
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
        # set the XAI outputs to a disclaimer message, empty markup and no plot
        xai_interactive, xai_markup, xai_plot = (
            """
            <div style="text-align: center"><h4>Without a selected XAI approach,
            no graphic will be displayed.</h4></div>
            """,
            [("", "")],
            None,
        )

    # return the cleared prompt field, the updated history and the XAI output elements
    return prompt_output, history_output, xai_interactive, xai_markup, xai_plot


# simple chat function that calls the model without any explanation
# formats the prompt, generates an answer and returns the updated conversation history
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    print(f"Running normal chat with {model}.")

    # formatting the prompt using the model's format_prompt function
    prompt = model.format_prompt(message, history, system_prompt, knowledge)

    # generating an answer using the model's respond function
    answer = model.respond(prompt)

    # updating the chat history with the new answer
    history.append((message, answer))
    # return an empty string (to clear the input field) and the updated history
    return "", history


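# chat function that calls the model together with the selected XAI method
# formats the prompt, generates an explained answer and returns the updated
# conversation history together with the XAI graphic, markup and plot elements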
def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    print(f"Running explained chat with {xai} with {model}.")

    # formatting the prompt using the model's format_prompt function
    # message, history, system_prompt, knowledge = mdl.prompt_limiter(
    #    message, history, system_prompt, knowledge
    # )
    prompt = model.format_prompt(message, history, system_prompt, knowledge)

    # generating an answer and explanation using the XAI method's chat_explained function
    answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)

    # updating the chat history with the new answer
    history.append((message, answer))

    # return an empty string (to clear the input field), the updated history
    # and the XAI graphic, markup and plot elements
    return "", history, xai_graphic, xai_markup, xai_plot