# Mistral model module for chat interaction and model instance control

# external imports
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import gradio as gr

# internal imports
from utils import modelling as mdl
from utils import formatting as fmt

# global model and tokenizer instance (created on initial build)
device = mdl.get_device()
if device == torch.device("cuda"):
    n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()

    MODEL = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2",
        quantization_config=bnb_config,
        device_map="auto",  # dispatch efficiently the model on the available ressources
        max_memory={i: max_memory for i in range(n_gpus)},
    )

else:
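    # fall back to loading without quantization on CPU / other non-CUDA devices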
    MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    MODEL.to(device)
TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

# default model config
CONFIG = GenerationConfig.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
base_config_dict = {
    "temperature": 0.7,
    "max_new_tokens": 64,
    "max_length": 64,
    "top_p": 0.9,
    "repetition_penalty": 1.2,
    "do_sample": True,
    "seed": 42,
}
CONFIG.update(**base_config_dict)


# function to (re)set the generation config
def set_config(config_dict: dict):
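    """Updates the global generation CONFIG; an empty dict resets it to the defaults."""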

    # if the config dict is empty, fall back to the default
    if config_dict == {}:
        config_dict = base_config_dict
    CONFIG.update(**config_dict)


# advanced formatting function that takes the conversation history into account
# CREDIT: adapted from the Mistral AI Instruct chat template
# see https://github.com/chujiezheng/chat_templates/
def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
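    """Builds a Mistral-Instruct prompt from the new message, the conversation
    history (a list of [user, assistant] pairs) and the system prompt;
    the knowledge argument is accepted but ignored by this model."""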
    prompt = ""

    if knowledge != "":
        gr.Info("""
            Mistral does not support additional knowledge; it will be ignored.
            """)

    # if no history, use system prompt and example message
    if len(history) == 0:
        prompt = f"""
            <s>[INST] {system_prompt} [/INST] How can I help you today? </s>
            [INST] {message} [/INST]
            """
    else:
        # takes the very first exchange and the system prompt as the base
        prompt = f"""
            <s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
            """

        # adds conversation history to the prompt
        for conversation in history[1:]:
            # takes all the following conversations and adds them as context
            prompt += "".join(
                f"\n[INST] {conversation[0]} [/INST] {conversation[1]}</s>"
            )

        prompt += """\n[INST] {message} [/INST]"""

    # returns full prompt
    return prompt


# function to extract the actual answer because Mistral always returns the full prompt
def format_answer(answer: str):
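    """Extracts the model's reply, i.e. the text after the last [/INST] token."""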
    # empty answer string
    formatted_answer = ""

    # splitting answer by instruction tokens
    segments = answer.split("[/INST]")

    # checking if proper history got returned
    if len(segments) > 1:
        # return the text after the last [/INST] token - the response to the last message
        formatted_answer = segments[-1].strip()
    else:
        # show a warning and return the full answer if no [/INST] token was found
        gr.Warning("""
                   There was an issue with answer formatting...\n
                   returning the full answer.
                   """)
        formatted_answer = answer

    print(f"CUT:\n {answer}\nINTO:\n{formatted_answer}")
    return formatted_answer


def respond(prompt: str):
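    """Tokenizes the given prompt, generates with the global MODEL and CONFIG,
    and returns the formatted answer string."""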

    # tokenizing the input prompt and moving it to the model device
    input_ids = TOKENIZER(f"{prompt}", return_tensors="pt")["input_ids"].to(device)

    # generating text from the tokenized input and decoding the output
    output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
    output_text = TOKENIZER.batch_decode(output_ids)
    output_text = fmt.format_output_text(output_text)

    return format_answer(output_text)
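

# minimal usage sketch: running this module directly performs one generation
# round trip; the system prompt and message below are illustrative placeholders
if __name__ == "__main__":
    example_prompt = format_prompt(
        message="What is the capital of France?",
        history=[],
        system_prompt="You are a helpful assistant.",
    )
    print(respond(example_prompt))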