File size: 12,210 Bytes
1d777c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import logging
from re import split, sub
from threading import Lock
from time import sleep
from typing import Tuple, Dict

try:
    from extensions.telegram_bot.source.generators.abstract_generator import AbstractGenerator
except ImportError:
    from source.generators.abstract_generator import AbstractGenerator

try:
    import extensions.telegram_bot.source.const as const
    import extensions.telegram_bot.source.utils as utils
    import extensions.telegram_bot.source.generator as generator
    from extensions.telegram_bot.source.user import User as User
    from extensions.telegram_bot.source.conf import cfg
    from extensions.telegram_bot.source.conf import cfg
except ImportError:
    import source.const as const
    import source.utils as utils
    import source.generator as generator
    from source.user import User as User
    from source.conf import cfg

# Module-wide lock that get_answer acquires (with a timeout) before touching
# user state and calling the generator; intended to prevent GPU overloading
# by keeping generations from running concurrently.
generator_lock = Lock()

# Debug switch: when True, get_answer prints the assembled prompt and the
# generated answer to stdout.
debug_flag = True


# ====================================================================================
# TEXT LOGIC
async def aget_answer(text_in: str, user: User, bot_mode: str, generation_params: Dict, name_in="") -> Tuple[str, str]:
    """Await ``get_answer`` with the same arguments and hand back its result.

    Args:
        text_in: Incoming text, forwarded unchanged.
        user: User object, forwarded unchanged.
        bot_mode: Bot mode identifier, forwarded unchanged.
        generation_params: Generation settings, forwarded unchanged.
        name_in: Optional speaker name, forwarded unchanged.

    Returns:
        Whatever ``get_answer`` produces, annotated as a ``(text, action)`` pair.
    """
    reply = await get_answer(text_in, user, bot_mode, generation_params, name_in)
    return reply


@utils.async_wrap
def get_answer(text_in: str, user: User, bot_mode: str, generation_params: Dict, name_in=""):
    """Process one incoming message and return the text to show plus a message action.

    The function works in three stages:

    1. Immediate commands (rename bot/user, add context, replace or trim the
       last message) are applied to ``user`` and returned without generation.
    2. Otherwise the input is appended to the user's history in a form that
       depends on ``bot_mode`` and on special values/prefixes of ``text_in``.
    3. A prompt is assembled from context, example, greeting and as much of
       the history as fits into the token budget, the generator is called and
       the answer is appended to the last history entry.

    The module-level ``generator_lock`` is acquired with a timeout of
    ``cfg.generation_timeout``. If it cannot be acquired in time the function
    proceeds anyway; the lock is released on exit only if this call actually
    acquired it.

    Args:
        text_in: Incoming text; may start with one of the configured command
            prefixes or equal one of the ``const.GENERATOR_MODE_*`` values.
        user: User object whose names, context and history are read and mutated.
        bot_mode: Mode identifier compared against ``const.MODE_*`` values.
        generation_params: Dict read for ``"stopping_strings"``,
            ``"eos_token"`` and ``"truncation_length"`` and passed on to
            ``generator.generate_answer``.
        name_in: Name of the speaker; defaults to ``user.name1`` when empty.

    Returns:
        A ``(text, action)`` tuple where ``action`` is one of the
        ``const.MSG_*`` values. If preprocessing or generation raises, the
        error is logged and ``(user.history_last_out, const.MSG_SYSTEM)`` is
        returned.
    """
    # additional delay option
    if cfg.answer_delay > 0:
        sleep(cfg.answer_delay)
    # if generation will fail, return "fail" answer
    answer = const.GENERATOR_FAIL
    # default result action - message
    return_msg_action = const.MSG_SEND
    # if user is default equal to user1
    name_in = user.name1 if name_in == "" else name_in
    # Acquire generator lock if we can. acquire() returns False on timeout;
    # remember the outcome so that we never release a lock we do not hold
    # (a threading.Lock can be released by any thread, which would silently
    # break the exclusion of whoever really owns it).
    lock_acquired = generator_lock.acquire(timeout=cfg.generation_timeout)
    try:
        # user_input preprocessing
        try:
            # Preprocessing: actions which return result immediately:
            if text_in.startswith(tuple(cfg.permanent_change_name2_prefixes)):
                # Permanent bot rename: everything after the first two characters
                # becomes name2 (assumes a two-character prefix - TODO confirm).
                user.name2 = text_in[2:]
                return "New bot name: " + user.name2, const.MSG_SYSTEM
            if text_in.startswith(tuple(cfg.permanent_change_name1_prefixes)):
                # Permanent user rename: same slicing, but name1 is replaced.
                user.name1 = text_in[2:]
                return "New user name: " + user.name1, const.MSG_SYSTEM
            if text_in.startswith(tuple(cfg.permanent_add_context_prefixes)):
                # Append the rest of the message to the context on a new line.
                user.context += "\n" + text_in[2:]
                return "Added to context: " + text_in[2:], const.MSG_SYSTEM
            if text_in.startswith(tuple(cfg.replace_prefixes)):
                # If user_in starts with replace_prefix - fully replace last message
                user.change_last_message(history_out=text_in[1:])
                return user.history_last_out, const.MSG_DEL_LAST
            if text_in == const.GENERATOR_MODE_DEL_WORD:
                # Delete-word command: cut the trailing fragment (text after the
                # last separator matched below) from the last bot message.
                last_message = user.history_last_out
                last_word = split(r"\n|\.+ +|: +|! +|\? +|\' +|\" +|; +|\) +|\* +", last_message)[-1]
                if len(last_word) == 0 and len(last_message) > 1:
                    # message ends with a separator: remove one character instead
                    last_word = " "
                new_last_message = last_message[: -(len(last_word))]
                new_last_message = new_last_message.strip()
                if len(new_last_message) == 0:
                    return_msg_action = const.MSG_NOTHING_TO_DO
                else:
                    user.change_last_message(history_out=new_last_message)
                return user.history_last_out, return_msg_action

            # Preprocessing: actions which not depends on user input:
            if bot_mode in [const.MODE_QUERY]:
                user.history = []

            # if regenerate - msg_id the same, text and name the same. But history clearing:
            if text_in == const.GENERATOR_MODE_REGENERATE:
                if str(user.msg_id[-1]) not in user.previous_history:
                    user.previous_history.update({str(user.msg_id[-1]): []})
                user.previous_history[str(user.msg_id[-1])].append(user.history_last_out)
                text_in = user.text_in[-1]
                name_in = user.name_in[-1]
                last_msg_id = user.msg_id[-1]
                user.truncate_last_message()
                user.msg_id.append(last_msg_id)

            # Preprocessing: add user_in/names/whitespaces to history in right order depends on mode:
            if bot_mode in [const.MODE_NOTEBOOK]:
                # If notebook mode - append to history only user_in, no additional preparing;
                user.text_in.append(text_in)
                user.history_append("", text_in)
            elif text_in == const.GENERATOR_MODE_IMPERSONATE:
                # if impersonate - append to history only "name1:", no adding "" history
                # line to prevent bug in history sequence, add "name1:" prefix for generation
                user.text_in.append(text_in)
                user.name_in.append(name_in)
                user.history_append("", name_in + ":")
            elif text_in == const.GENERATOR_MODE_NEXT:
                # if user_in is "" - no user text, it is like continue generation adding "" history line
                #  to prevent bug in history sequence, add "name2:" prefix for generation
                user.text_in.append(text_in)
                user.name_in.append(name_in)
                user.history_append("", user.name2 + ":")
            elif text_in == const.GENERATOR_MODE_CONTINUE:
                # continue: nothing is appended, generation extends the last history entry
                pass
            elif text_in.startswith(tuple(cfg.sd_api_prefixes)):
                # If user_in starts with prefix - impersonate-like (if you try to get "impersonate view")
                # adding "" line to prevent bug in history sequence, user_in is prefix for bot answer
                user.msg_id.append(0)
                user.text_in.append(text_in)
                user.name_in.append(name_in)
                if len(text_in) == 1:
                    user.history_append("", cfg.sd_api_prompt_self)
                else:
                    user.history_append("", cfg.sd_api_prompt_of.replace("OBJECT", text_in[1:].strip()))
                return_msg_action = const.MSG_SD_API
            elif text_in.startswith(tuple(cfg.impersonate_prefixes)):
                # If user_in starts with prefix - impersonate-like (if you try to get "impersonate view")
                # adding "" line to prevent bug in history sequence, user_in is prefix for bot answer
                user.text_in.append(text_in)
                user.name_in.append(text_in[1:])
                user.history_append("", text_in[1:] + ":")
            else:
                # If not notebook/impersonate/continue mode then ordinary chat preparing
                # add "name1&2:" to user and bot message (generation from name2 point of view);
                user.text_in.append(text_in)
                user.name_in.append(name_in)
                user.history_append(name_in + ": " + text_in, user.name2 + ":")
        except Exception as exception:
            logging.error("get_answer (prepare text part) " + str(exception) + str(exception.args))
            # History may be half-updated, so do not generate on top of it.
            # Report the same way the generator-part handler below does; the
            # lock is released by the outer ``finally``.
            return user.history_last_out, const.MSG_SYSTEM

        # Text processing with LLM
        try:
            # Set eos_token and stopping_strings.
            stopping_strings = generation_params["stopping_strings"].copy()
            eos_token = generation_params["eos_token"]
            if bot_mode in [const.MODE_CHAT, const.MODE_CHAT_R, const.MODE_ADMIN]:
                stopping_strings += [
                    name_in + ":",
                    user.name1 + ":",
                    user.name2 + ":",
                ]
            if cfg.bot_prompt_end != "":
                stopping_strings.append(cfg.bot_prompt_end)

            # adjust context/greeting/example
            if user.context.strip().endswith("\n"):
                context = f"{user.context.strip()}"
            else:
                context = f"{user.context.strip()}\n"
            context = cfg.context_prompt_begin + context + cfg.context_prompt_end
            if len(user.example) > 0:
                example = user.example
            else:
                example = ""
            if len(user.greeting) > 0:
                greeting = "\n" + user.name2 + ": " + user.greeting
            else:
                greeting = ""
            # Make prompt: context + example + conversation history
            available_len = generation_params["truncation_length"]
            context_len = generator.get_tokens_count(context)
            available_len -= context_len
            if available_len < 0:
                available_len = 0
                logging.info("telegram_bot - CONTEXT IS TOO LONG!!!")

            conversation = [example, greeting]
            for i in user.history:
                if len(i["in"]) > 0:
                    conversation.append("".join([cfg.user_prompt_begin, i["in"], cfg.user_prompt_end]))
                if len(i["out"]) > 0:
                    conversation.append("".join([cfg.bot_prompt_begin, i["out"], cfg.bot_prompt_end]))
            # Leave the final bot turn open so the model continues it. Only strip
            # the end marker when the last piece really ends with it, otherwise
            # unrelated text (e.g. the greeting) would lose its tail.
            if len(cfg.bot_prompt_end) and conversation[-1].endswith(cfg.bot_prompt_end):
                conversation[-1] = conversation[-1][: -len(cfg.bot_prompt_end)]

            # Fill the remaining token budget from the newest piece backwards.
            prompt = ""
            for s in reversed(conversation):
                s = "\n" + s if len(s) > 0 else s
                s_len = generator.get_tokens_count(s)
                if available_len >= s_len:
                    prompt = s + prompt
                    available_len -= s_len
                else:
                    break
            prompt = context + prompt
            # collapse runs of spaces after a colon into a single space
            prompt = sub(
                r": +",
                ": ",
                prompt,
            )
            # Generate!
            if debug_flag:
                print(prompt)
            answer = generator.generate_answer(
                prompt=prompt,
                generation_params=generation_params,
                eos_token=eos_token,
                stopping_strings=stopping_strings,
                default_answer=answer,
                turn_template=user.turn_template,
            )
            if debug_flag:
                print(answer)
            # Truncate prompt prefix/postfix
            if len(cfg.bot_prompt_end) > 0 and answer.endswith(cfg.bot_prompt_end):
                answer = answer[: -len(cfg.bot_prompt_end)]
            if len(cfg.bot_prompt_end) > 2 and answer.endswith(cfg.bot_prompt_end[:-1]):
                answer = answer[: -len(cfg.bot_prompt_end[:-1])]
            if len(cfg.bot_prompt_begin) > 0 and answer.startswith(cfg.bot_prompt_begin):
                # the marker is at the start, so cut it from the start
                answer = answer[len(cfg.bot_prompt_begin):]
            # If generation result zero length - return  "Empty answer."
            if len(answer) < 1:
                answer = const.GENERATOR_EMPTY_ANSWER
            # Final return
            if answer not in [const.GENERATOR_EMPTY_ANSWER, const.GENERATOR_FAIL]:
                # if everything ok - add generated answer in history and return
                # last
                for end in stopping_strings:
                    # skip empty stopping strings: "".endswith("") is always
                    # true and answer[:-0] would wipe the whole answer
                    if end and answer.endswith(end):
                        answer = answer[: -len(end)]
                user.change_last_message(history_out=user.history_last_out + " " + answer)
            # A regeneration that reproduced the previous result is reported
            # as "nothing to do".
            if len(user.msg_id) > 0:
                if str(user.msg_id[-1]) in user.previous_history:
                    if user.previous_history[str(user.msg_id[-1])][-1] == user.history_last_out:
                        return_msg_action = const.MSG_NOTHING_TO_DO
            return user.history_last_out, return_msg_action
        except Exception as exception:
            logging.error("get_answer (generator part) " + str(exception) + str(exception.args))
            return_msg_action = const.MSG_SYSTEM
            return user.history_last_out, return_msg_action
    finally:
        # anyway, release generator lock - exactly once, and only if we own it
        if lock_acquired:
            generator_lock.release()