from __future__ import annotations

import json
import random
import re
from abc import ABC, abstractmethod
from typing import List, Dict, Union, Optional

from huggingface_hub import InferenceClient
from tenacity import retry, stop_after_attempt, wait_random_exponential
from transformers import AutoTokenizer

from config import *

ROLE_SYSTEM = 'system'
ROLE_USER = 'user'
ROLE_ASSISTANT = 'assistant'

SUPPORTED_MISTRAL_MODELS = ['mistralai/Mixtral-8x7B-Instruct-v0.1', 'mistralai/Mistral-7B-Instruct-v0.2']
SUPPORTED_NOUS_MODELS = ['NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO']
SUPPORTED_LLAMA_MODELS = ['meta-llama/Llama-2-70b-chat-hf',
                          'meta-llama/Llama-2-13b-chat-hf',
                          'meta-llama/Llama-2-7b-chat-hf']
ALL_SUPPORTED_MODELS = SUPPORTED_MISTRAL_MODELS + SUPPORTED_NOUS_MODELS + SUPPORTED_LLAMA_MODELS


def select_model(model_name: str, system_prompt: str, **kwargs) -> Model:
    if model_name in SUPPORTED_MISTRAL_MODELS:
        return MistralModel(system_prompt, model_name)
    elif model_name in SUPPORTED_NOUS_MODELS:
        return NousHermesModel(system_prompt, model_name)
    elif model_name in SUPPORTED_LLAMA_MODELS:
        return LlamaModel(system_prompt, model_name)
    else:
        raise ValueError(f'Model {model_name} not supported')


class Model(ABC):
    name: str
    messages: List[Dict[str, str]]
    system_prompt: str

    def __init__(self, model_name: str, system_prompt: str):
        self.name = model_name
        self.system_prompt = system_prompt
        self.messages = [
            {'role': ROLE_SYSTEM, 'content': system_prompt}
        ]

    @abstractmethod
    def __call__(self, *args, **kwargs) -> Union[str, Dict]:
        raise NotImplementedError

    def add_message(self, role: str, content: str):
        assert role in [ROLE_SYSTEM, ROLE_USER, ROLE_ASSISTANT]
        self.messages.append({'role': role, 'content': content})

    def clear_conversations(self):
        self.messages.clear()
        self.add_message(ROLE_SYSTEM, self.system_prompt)

    def __str__(self) -> str:
        return self.name

    def __repr__(self) -> str:
        return self.name


class HFAPIModel(Model):

    def __call__(self, user_prompt: str, *args,
                 use_json: bool = False,
                 temperature: float = 0,
                 timeout: Optional[float] = None,
                 cache: bool = False,
                 json_retry_count: int = 5,
                 **kwargs) -> Union[str, Dict]:
        """
        Returns the model's response.
        If use_json is True, tries its best to return a JSON dict (retrying up to
        json_retry_count times), but this is not guaranteed: if no valid JSON object
        can be parsed, the raw response string is returned instead.
        """
        self.add_message(ROLE_USER, user_prompt)
        if use_json:
            response = None
            for i in range(max(1, json_retry_count)):
                # use the HF cache only when requested and only on the first attempt,
                # so that retries actually produce a fresh generation
                response = self.get_response(temperature, use_json, timeout, cache and i == 0)
                json_obj = self.find_first_valid_json(response)
                if json_obj is not None:
                    response = json_obj
                    break
        else:
            response = self.get_response(temperature, use_json, timeout, cache)
        # keep the conversation history as plain text even when a dict is returned to the caller
        self.add_message(ROLE_ASSISTANT, response if isinstance(response, str) else json.dumps(response))
        return response

    @retry(stop=stop_after_attempt(5), wait=wait_random_exponential(max=10), reraise=True)  # retry on transient API errors
    def get_response(self, temperature: float, use_json: bool, timeout: Optional[float], cache: bool) -> str:
        # NOTE: use_json is accepted for signature consistency, but no JSON grammar is
        # passed to the endpoint here; JSON extraction happens in __call__.
        client = InferenceClient(model=self.name, timeout=timeout)
        # client = InferenceClient(model=self.name, token=random.choice(HF_API_TOKENS), timeout=timeout)
        if not cache:
            client.headers["x-use-cache"] = "0"
        # print(self.format_messages())  # debug
        r = client.text_generation(self.format_messages(),
                                   do_sample=temperature > 0,
                                   temperature=temperature if temperature > 0 else None,
                                   max_new_tokens=4096)
        return r

    @abstractmethod
    def format_messages(self) -> str:
        raise NotImplementedError

    def get_short_name(self) -> str:
        """
        Returns the last part of the model name.
        For example, "mistralai/Mixtral-8x7B-Instruct-v0.1" -> "Mixtral-8x7B-Instruct-v0.1"
        """
        return self.name.split('/')[-1]

    @staticmethod
    def find_first_valid_json(s: str) -> Optional[Dict]:
        """Returns the first substring of s that parses as a JSON object, or None if there is none."""
        s = re.sub(r'\\(?!["\\/bfnrt]|u[0-9a-fA-F]{4})', lambda m: m.group(0)[1:], s)  # drop the backslash of any invalid escape sequence
        for i in range(len(s)):
            if s[i] != '{':
                continue
            for j in range(i + 1, len(s) + 1):
                if s[j - 1] != '}':
                    continue
                try:
                    potential_json = s[i:j]
                    json_obj = json.loads(potential_json, strict=False)
                    return json_obj  # Return the first valid JSON object found
                except json.JSONDecodeError:
                    pass  # Continue searching if JSON decoding fails
        return None  # Return None if no valid JSON object is found
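
    # Illustrative example (hypothetical input): for the string
    #   'Sure! Here is the result: {"score": 3, "reason": "ok"} Hope that helps.'
    # find_first_valid_json returns {'score': 3, 'reason': 'ok'}; for a string containing
    # no complete {...} object it returns None.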


class MistralModel(HFAPIModel):

    def __init__(self, system_prompt: str, model_name: str = 'mistralai/Mixtral-8x7B-Instruct-v0.1') -> None:
        assert model_name in SUPPORTED_MISTRAL_MODELS, f'Model {model_name} not supported'
        super().__init__(model_name, system_prompt)

    def format_messages(self) -> str:
        messages = self.messages
        # Mistral chat templates do not accept a system role, so fold the system prompt into the first user message
        if messages[0]['role'] == ROLE_SYSTEM:
            assert len(self.messages) >= 2
            messages = [{'role'   : ROLE_USER,
                         'content': messages[0]['content'] + '\n' + messages[1]['content']}] + messages[2:]
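            # Illustration (hypothetical values): with system_prompt='You are terse.' and a
            # first user turn 'Hi', the merged history starts with a single user message whose
            # content is 'You are terse.\nHi'; the Mistral chat template then renders it
            # roughly as '<s>[INST] You are terse.\nHi [/INST]'.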
        tokenizer = AutoTokenizer.from_pretrained(self.name)
        r = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, max_length=4096)
        # print(r)
        return r


class NousHermesModel(HFAPIModel):

    def __init__(self, system_prompt: str, model_name: str = 'NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO') -> None:
        assert model_name in SUPPORTED_NOUS_MODELS, f'Model {model_name} not supported'
        super().__init__(model_name, system_prompt)

    def format_messages(self) -> str:
        messages = self.messages
        assert len(messages) >= 2  # must be at least a system and a user
        assert messages[0]['role'] == ROLE_SYSTEM and messages[1]['role'] == ROLE_USER
        tokenizer = AutoTokenizer.from_pretrained(self.name)
        r = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, max_length=4096)
        # print(r)
        return r


class LlamaModel(HFAPIModel):

    def __init__(self, system_prompt: str, model_name: str = 'meta-llama/Llama-2-70b-chat-hf') -> None:
        assert model_name in SUPPORTED_LLAMA_MODELS, f'Model {model_name} not supported'
        super().__init__(model_name, system_prompt)

    def format_messages(self) -> str:
        """
        <s>[INST] <<SYS>>
        {system_prompt}
        <</SYS>>

        {user_message} [/INST]
        """
        messages = self.messages
        assert len(messages) >= 2  # must be at least a system and a user
        r = f'<s>[INST] <<SYS>>\n{messages[0]["content"]}\n<</SYS>>\n\n{messages[1]["content"]} [/INST]'
        for msg in messages[2:]:
            role, content = msg['role'], msg['content']
            if role == ROLE_SYSTEM:
                raise ValueError('system messages are only allowed at the start of the conversation')
            elif role == ROLE_USER:
                if r.endswith('</s>'):
                    r += '<s>'
                r += f'[INST] {content} [/INST]'
            elif role == ROLE_ASSISTANT:
                r += f'{content}</s>'
            else:
                raise ValueError(f'Unknown role: {role}')
        return r
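

# Minimal usage sketch (illustrative only). Assumes the HF Inference API endpoints for
# these models are reachable with valid credentials and that `config` provides whatever
# the token logic above expects; the prompts below are hypothetical.
if __name__ == '__main__':
    model = select_model('mistralai/Mixtral-8x7B-Instruct-v0.1',
                         system_prompt='You are a concise assistant.')

    # plain text generation
    print(model('Name one prime number below 10.'))

    # ask for JSON; falls back to the raw string if no valid JSON object can be parsed
    result = model('Return a JSON object with a single key "answer" set to 42.', use_json=True)
    print(type(result), result)

    # reset the conversation history back to just the system prompt
    model.clear_conversations()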