import streamlit as st
from ibm_watsonx_ai.foundation_models import ModelInference
from ibm_watsonx_ai import Credentials, APIClient
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
from knowledge_bases import KNOWLEDGE_BASE_OPTIONS, SYSTEM_PROMPTS
import genparam
import time

def check_password():
    """Password protection check for the app."""
    def password_entered():
        if st.session_state["password"] == st.secrets["app_password"]:
            st.session_state["password_correct"] = True
            del st.session_state["password"]
        else:
            st.session_state["password_correct"] = False

    if "password_correct" not in st.session_state:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    elif not st.session_state["password_correct"]:
        st.markdown("\n\n")
        st.text_input("Enter the password", type="password", on_change=password_entered, key="password")
        st.divider()
        st.error("😕 Incorrect password")
        st.info("Designed and developed by Milan Mrdenovic © IBM Norway 2024")
        return False
    else:
        return True
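
# Illustrative usage of check_password() (assumed main-page wiring, not part of
# the original module): halt rendering until the password gate passes.
#
#     if not check_password():
#         st.stop()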

def initialize_session_state():
    """Initialize all session state variables."""
    if 'chat_history_1' not in st.session_state:
        st.session_state.chat_history_1 = []
    if 'chat_history_2' not in st.session_state:
        st.session_state.chat_history_2 = []
    if 'chat_history_3' not in st.session_state:
        st.session_state.chat_history_3 = []
    if 'first_question' not in st.session_state:
        st.session_state.first_question = False 
    if "counter" not in st.session_state:
        st.session_state["counter"] = 0
    if 'token_statistics' not in st.session_state:
        st.session_state.token_statistics = []
    if 'selected_kb' not in st.session_state:
        st.session_state.selected_kb = KNOWLEDGE_BASE_OPTIONS[0]
    if 'current_system_prompts' not in st.session_state:
        st.session_state.current_system_prompts = SYSTEM_PROMPTS[st.session_state.selected_kb]

def setup_client(project_id=None):
    """Setup WatsonX client with credentials."""
    credentials = Credentials(
        url=st.secrets["url"],
        api_key=st.secrets["api_key"]
    )
    project_id = project_id or st.secrets["project_id"]
    client = APIClient(credentials, project_id=project_id)
    return credentials, client
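
# Illustrative .streamlit/secrets.toml layout assumed by check_password() and
# setup_client() (placeholder values only):
#
#     app_password = "..."
#     url = "https://us-south.ml.cloud.ibm.com"
#     api_key = "..."
#     project_id = "..."
#
# Typical call: credentials, client = setup_client()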

def get_active_model():
    """Get the currently active model based on configuration."""
    return genparam.SELECTED_MODEL_1 if genparam.ACTIVE_MODEL == 0 else genparam.SELECTED_MODEL_2

def get_active_prompt_template():
    """Get the currently active prompt template."""
    return genparam.PROMPT_TEMPLATE_1 if genparam.ACTIVE_MODEL == 0 else genparam.PROMPT_TEMPLATE_2

def prepare_prompt(prompt, chat_history):
    """Prepare the prompt with chat history if available."""
    if genparam.TYPE == "chat" and chat_history:
        chats = "\n".join([f"{message['role']}: \"{message['content']}\"" for message in chat_history])
        return f"Conversation History:\n{chats}\n\nNew User Input: {prompt}"
    return f"User Input: {prompt}"

def apply_prompt_syntax(prompt, system_prompt, prompt_template, bake_in_prompt_syntax):
    """Apply appropriate syntax to the prompt based on model requirements."""
    model_family_syntax = {
        "llama3-instruct (llama-3, 3.1 & 3.2) - system": """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "llama3-instruct (llama-3, 3.1 & 3.2) - user": """<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n""",
        "granite-13b-chat & instruct - system": """<|system|>\n{system_prompt}\n<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "granite-13b-chat & instruct - user": """<|user|>\n{prompt}\n<|assistant|>\n\n""",
        "mistral & mixtral v2 tokenizer - system": """<s>[INST] System Prompt: {system_prompt} [/INST][INST] {prompt} [/INST]\n\n""",
        "mistral & mixtral v2 tokenizer - user": """<s>[INST] {prompt} [/INST]\n\n""",
        "no syntax - system": """{system_prompt}\n\n{prompt}""",
        "no syntax - user": """{prompt}"""
    }
    
    if bake_in_prompt_syntax:
        template = model_family_syntax[prompt_template]
        if system_prompt:
            return template.format(system_prompt=system_prompt, prompt=prompt)
        # "- user" templates carry no {system_prompt} placeholder, so apply the
        # syntax with the prompt alone instead of skipping it entirely.
        return template.format(prompt=prompt)
    return prompt
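
# Example of apply_prompt_syntax() output (derived from the templates above).
# With prompt_template = "llama3-instruct (llama-3, 3.1 & 3.2) - system",
# system_prompt = "You are a helpful assistant." and prompt = "User Input: Hi",
# the baked-in prompt becomes:
#
#     <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#     You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
#     User Input: Hi<|eot_id|><|start_header_id|>assistant<|end_header_id|>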

def generate_response(watsonx_llm, prompt_data, params):
    """Generate streaming response from the model."""
    generated_response = watsonx_llm.generate_text_stream(prompt=prompt_data, params=params)
    for chunk in generated_response:
        yield chunk

def capture_tokens(prompt_data, response, client, bot_name):
    """Capture token usage statistics."""
    if not genparam.TOKEN_CAPTURE_ENABLED:
        return

    watsonx_llm = ModelInference(
        api_client=client,
        model_id=get_active_model(),  # tokenize with the currently active model, consistent with fetch_response()
        verify=genparam.VERIFY
    )

    input_tokens = watsonx_llm.tokenize(prompt=prompt_data)["result"]["token_count"]
    output_tokens = watsonx_llm.tokenize(prompt=response)["result"]["token_count"]
    total_tokens = input_tokens + output_tokens

    return {
        "bot_name": bot_name,
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": total_tokens,
        "timestamp": time.strftime("%H:%M:%S")
    }
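
# Illustrative use of the statistics dict returned above (assumed caller wiring):
# append it to the token_statistics list set up in initialize_session_state().
#
#     stats = capture_tokens(prompt_data, full_response, client, bot_name="Bot 1")
#     if stats:
#         st.session_state.token_statistics.append(stats)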

def fetch_response(user_input, client, system_prompt, chat_history):
    """Fetch response from the model for the given input."""
    prompt = prepare_prompt(user_input, chat_history)
    prompt_data = apply_prompt_syntax(
        prompt,
        system_prompt,
        get_active_prompt_template(),
        genparam.BAKE_IN_PROMPT_SYNTAX
    )

    watsonx_llm = ModelInference(
        api_client=client, 
        model_id=get_active_model(),
        verify=genparam.VERIFY
    )

    params = {
        GenParams.DECODING_METHOD: genparam.DECODING_METHOD,
        GenParams.MAX_NEW_TOKENS: genparam.MAX_NEW_TOKENS,
        GenParams.MIN_NEW_TOKENS: genparam.MIN_NEW_TOKENS,
        GenParams.REPETITION_PENALTY: genparam.REPETITION_PENALTY,
        GenParams.STOP_SEQUENCES: genparam.STOP_SEQUENCES
    }

    stream = generate_response(watsonx_llm, prompt_data, params)
    return stream, prompt_data
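
# Illustrative caller sketch (assumed page wiring, not part of this module):
# consume the stream returned by fetch_response() in a Streamlit chat UI.
#
#     credentials, client = setup_client()
#     stream, prompt_data = fetch_response(
#         user_input, client, system_prompt, st.session_state.chat_history_1
#     )
#     with st.chat_message("assistant"):
#         full_response = st.write_stream(stream)  # collects streamed chunks into one string
#     st.session_state.chat_history_1.append({"role": "assistant", "content": full_response})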