# modelling util module providing shared formatting and loading helpers for the models

# external imports
import torch
import gradio as gr
from transformers import BitsAndBytesConfig


# function that limits the prompt length to keep the model runtime bounded
# tries to keep as much context as possible, always retaining at least the message and system prompt
def prompt_limiter(
    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    # initializing the new prompt history empty
    prompt_history = []
    # getting the current token count for the message, system prompt, and knowledge
    pre_count = (
        token_counter(tokenizer, message)
        + token_counter(tokenizer, system_prompt)
        + token_counter(tokenizer, knowledge)
    )

    # validating the token count against the threshold of 1024
    # checking if the token count is already too high without history
    if pre_count > 1024:

        # checking if the token count is too high even without knowledge and history
        if (
            token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
            > 1024
        ):

            # show warning and raise error
            gr.Warning("Message and system prompt are too long. Please shorten them.")
            raise RuntimeError(
                "Message and system prompt are too long. Please shorten them."
            )

        # show warning and return with empty history and empty knowledge
        gr.Warning("""
                   Input too long.
                   Knowledge and conversation history have been removed to keep model running.
                   """)
        return message, prompt_history, system_prompt, ""

    # if the token count is small enough, adding history bit by bit
    if pre_count < 800:
        # setting the count to the precount
        count = pre_count
        # iterating through the history in reverse to prioritize recent conversations
        # (reversed() avoids mutating the caller's history list)
        for conversation in reversed(history):

            # checking the token count with the current conversation added
            count += token_counter(tokenizer, conversation[0]) + token_counter(
                tokenizer, conversation[1]
            )

            # add conversation or break loop depending on token count
            if count < 1024:
                prompt_history.append(conversation)
            else:
                break

        # restoring chronological order after collecting the most recent turns
        prompt_history.reverse()

    # returning the message, adapted history, system prompt, and knowledge
    return message, prompt_history, system_prompt, knowledge
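
# example usage of prompt_limiter (a sketch, not part of the module API;
# the tokenizer id below is an assumption, any Hugging Face tokenizer works):
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
#   message, history, system_prompt, knowledge = prompt_limiter(
#       tokenizer,
#       "What is attention in transformers?",
#       [("Hi", "Hello! How can I help you?")],
#       "You are a helpful assistant.",
#   )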


# token counter function using the model tokenizer
def token_counter(tokenizer, text: str):
    # tokenize the text
    tokens = tokenizer(text, return_tensors="pt").input_ids
    # return the token count
    return len(tokens[0])
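
# example (sketch): the returned count includes any special tokens the
# tokenizer adds, so it can differ slightly between tokenizers
#
#   count = token_counter(tokenizer, "Hello world")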


# setting device based on available hardware
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    return device


# creating the quantization config for loading a model on gpu
def gpu_loading_config(max_memory: str = "15000MB"):
    n_gpus = torch.cuda.device_count()

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    return n_gpus, max_memory, bnb_config
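
# example usage of gpu_loading_config (a sketch; the model id is an assumption):
#
#   from transformers import AutoModelForCausalLM
#
#   n_gpus, max_memory, bnb_config = gpu_loading_config()
#   model = AutoModelForCausalLM.from_pretrained(
#       "mistralai/Mistral-7B-Instruct-v0.2",
#       quantization_config=bnb_config,
#       device_map="auto",
#       max_memory={i: max_memory for i in range(n_gpus)},
#   )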


# formatting mistral attention values
# CREDIT: copied and adapted from BERTViz
# see https://github.com/jessevig/bertviz
def format_mistral_attention(attention_values):
    squeezed = []
    for layer_attention in attention_values:
        # removing the batch dimension from each layer's attention tensor
        layer_attention = layer_attention.squeeze(0)
        squeezed.append(layer_attention)
    # stacking the per-layer tensors and moving them to the cpu
    return torch.stack(squeezed).to(torch.device("cpu"))
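
# example usage of format_mistral_attention (a sketch; the attention values are
# the ones returned by a transformers model called with output_attentions=True,
# i.e. a tuple with one tensor of shape (batch, heads, seq_len, seq_len) per layer):
#
#   outputs = model(input_ids, output_attentions=True)
#   attention = format_mistral_attention(outputs.attentions)
#   # attention now has shape (layers, heads, seq_len, seq_len) on the cpu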