# Mistral model module for chat interaction and model instance control

# external imports
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import gradio as gr

# internal imports
from utils import modelling as mdl
from utils import formatting as fmt

# global model and tokenizer instances (created on initial build)
# determine if GPU is available and load model accordingly
device = mdl.get_device()
if device == torch.device("cuda"):
    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()

    MODEL = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        quantization_config=bnb_config,
        device_map="auto",
        max_memory={i: max_memory for i in range(n_gpus)},
    )

# otherwise, load model on CPU
else:
    MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    MODEL.to(device)
# load tokenizer
TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# default model config
CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
base_config_dict = {
    "temperature": 0.7,
    "max_new_tokens": 64,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "seed": 42,
}
CONFIG.update(**base_config_dict)


# function to (re)set the generation config
def set_config(config_dict: dict):
    # if no config dict is given, fall back to the default config
    if config_dict == {}:
        config_dict = base_config_dict
    CONFIG.update(**config_dict)
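
# illustrative usage sketch (hypothetical values, not called anywhere in this module):
# set_config({"temperature": 0.2, "max_new_tokens": 128})  # override selected values only
# set_config({})  # fall back to base_config_dict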


# advanced formatting function that takes into account a conversation history
# CREDIT: adapted from the Mistral AI Instruct chat template
# see https://github.com/chujiezheng/chat_templates/
def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
    prompt = ""

    # send information to the ui if knowledge is not empty
    if knowledge != "":
        gr.Info("""
            Mistral does not support additional knowledge; it will be ignored.
            """)

    # if no history, use system prompt and example message
    if len(history) == 0:
        prompt = f"""
            <s>[INST] {system_prompt} [/INST] How can I help you today? </s>
            [INST] {message} [/INST]
            """
    else:
        # use the very first exchange together with the system prompt as the base
        prompt = f"""
            <s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
            """

        # append each following exchange in the conversation history as context
        for conversation in history[1:]:
            prompt += f"\n[INST] {conversation[0]} [/INST] {conversation[1]}</s>"

        # finally, append the current message
        prompt += f"\n[INST] {message} [/INST]"

    # returns full prompt
    return prompt
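
# illustrative sketch (not executed) of the prompt shape format_prompt assembles
# for a multi-turn history, with the leading whitespace of the f-strings omitted:
#   <s>[INST] {system_prompt} {first user message} [/INST] {first model answer}</s>
#   [INST] {next user message} [/INST] {next model answer}</s>
#   [INST] {current message} [/INST]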


# function to extract the actual answer, since Mistral returns the full prompt in its output
def format_answer(answer: str):
    # empty answer string
    formatted_answer = ""

    # splitting answer by instruction tokens
    segments = answer.split("[/INST]")

    # checking whether the answer contains at least one [/INST] token
    if len(segments) > 1:
        # return the text after the last [/INST] token, i.e. the response to the latest message
        formatted_answer = segments[-1].strip()
    else:
        # warn and return the full answer if not enough [/INST] tokens were found
        gr.Warning("""
                   There was an issue with the answer formatting...\n
                   Returning the full answer.
                   """)
        formatted_answer = answer

    print(f"CUT:\n {answer}\nINTO:\n{formatted_answer}")
    return formatted_answer
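
# illustrative example of the expected behaviour (hypothetical input string):
#   format_answer("<s>[INST] Hi [/INST] Hello! How can I help?</s>")
#   returns "Hello! How can I help?</s>"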


# response function calling the model and returning the model output message
# CREDIT: copied from the official inference example on Hugging Face
# see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
def respond(prompt: str):
    # setting config to default
    set_config({})

    # tokenizing the input prompt and moving it to the model device
    input_ids = TOKENIZER(prompt, return_tensors="pt")["input_ids"].to(device)

    # generating text with tokenized input, returning output
    output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
    output_text = TOKENIZER.batch_decode(output_ids)

    # formatting output text with special function
    output_text = fmt.format_output_text(output_text)

    # returning the model output string
    return format_answer(output_text)
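

# minimal manual test sketch (illustrative only; assumes the model above was loaded
# successfully and enough memory is available; not part of the application flow):
if __name__ == "__main__":
    demo_prompt = format_prompt(
        message="What is the capital of France?",
        history=[],
        system_prompt="You are a concise assistant.",
    )
    print(respond(demo_prompt))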