File size: 4,843 Bytes
0eeee8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
This module is responsible for modifying the chat prompt and history.
"""
import json
import re

import extensions.superboogav2.parameters as parameters

from modules import chat
from modules.text_generation import get_encoded_length
from modules.logging_colors import logger
from extensions.superboogav2.utils import create_context_text, create_metadata_source

from .data_processor import process_and_add_to_collector
from .chromadb import ChromaCollector


# Metadata attached to documents that are inserted automatically from the chat history.
# The same value is used as the `where` filter when those documents are deleted again.
CHAT_METADATA = create_metadata_source('automatic-chat-insert')

# Values of state['mode'] for which the instruct-specific speaker names are used.
INSTRUCT_MODE = 'instruct'
CHAT_INSTRUCT_MODE = 'chat-instruct'


def _is_instruct_mode(state: dict):
    """
    Return True when state['mode'] is one of the two instruct modes.

    A missing 'mode' key is treated like any other non-instruct value.
    """
    return state.get('mode') in (INSTRUCT_MODE, CHAT_INSTRUCT_MODE)


def _remove_tag_if_necessary(user_input: str):
    """
    Strip the manual `!c` query tag from the user input.

    When manual mode is off (parameters.get_is_manual() is falsy) the input is
    returned unchanged. Otherwise every `!c` tag found at the start or at the end
    of a line is removed together with the whitespace around it.
    """
    if not parameters.get_is_manual():
        return user_input

    # _should_query detects the tag with re.MULTILINE, i.e. per line. The removal has to
    # anchor per line as well; otherwise a tag on an interior line of a multi-line message
    # would trigger a query but stay in the text that is searched and sent to the model.
    return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input, flags=re.MULTILINE)


def _should_query(input: str):
    """
    Decide whether the collector should be queried for this input.

    Outside manual mode every input triggers a query. In manual mode a query is
    triggered only when some line of the input starts (after optional whitespace)
    or ends (before optional whitespace) with the `!c` tag.
    """
    manual_mode = parameters.get_is_manual()
    if not manual_mode:
        return True

    tag_match = re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE)
    return tag_match is not None


def _format_single_exchange(name, text):
    if re.search(r':\s*$', name):
        return '{} {}\n'.format(name, text)
    else:
        return '{}: {}\n'.format(name, text)


def _get_names(state: dict):
    """
    Return the (user_name, bot_name) pair for the current mode.

    In an instruct mode the names come from state['name1_instruct'] and
    state['name2_instruct'], otherwise from state['name1'] and state['name2'].
    Empty names fall back to 'User' and 'Assistant' respectively.
    """
    if _is_instruct_mode(state):
        configured_user = state['name1_instruct']
        configured_bot = state['name2_instruct']
    else:
        configured_user = state['name1']
        configured_bot = state['name2']

    # Falsy names (empty string, None) are replaced by the generic defaults.
    return configured_user or 'User', configured_bot or 'Assistant'


def _concatinate_history(history: dict, state: dict):
    """
    Flatten history['internal'] into one text with one "<name>: <message>" line per message.

    Each exchange is a list whose first element is formatted with the user name and
    whose second element with the bot name; further elements are ignored. The
    trailing newline is dropped, so an empty history yields an empty string.
    """
    speakers = _get_names(state)

    exchanges = history['internal']
    assert isinstance(exchanges, list)

    rendered = []
    for exchange in exchanges:
        assert isinstance(exchange, list)

        # zip pairs (user, bot) with at most the first two messages of the exchange.
        for speaker, message in zip(speakers, exchange):
            rendered.append(_format_single_exchange(speaker, message))

    # Every rendered line ends with a newline; cut off the final one.
    return ''.join(rendered)[:-1]


def _hijack_last(context_text: str, history: dict, max_len: int, state: dict):
    """
    Overwrite one message of history['internal'] in place with context_text.

    The messages are walked from newest to oldest while a running token count is kept.
    A message is a candidate for replacement as long as the tokens of all newer messages
    plus the tokens of context_text do not exceed max_len. The oldest candidate is the one
    that gets overwritten. If no message qualifies, a warning is logged and the history
    is left untouched.

    Nothing is returned; the only effect is the mutation of history['internal'].
    """
    num_context_tokens = get_encoded_length(context_text)

    # Messages inside an exchange are visited in reverse order, so the
    # (user, bot) names are reversed to line up with the inner index j.
    names = _get_names(state)[::-1]

    # Flat newest-to-oldest view of the history: (exchange index, message index, message),
    # both indices counted from the end.
    newest_first = (
        (i, j, message)
        for i, messages in enumerate(reversed(history['internal']))
        for j, message in enumerate(reversed(messages))
    )

    history_tokens = 0
    replace_position = None
    for i, j, message in newest_first:
        # TODO: This is an extremely naive solution. A more robust implementation must be made.
        if history_tokens + num_context_tokens > max_len:
            # The running total only grows (assuming token counts are never negative), so no
            # older message can qualify any more; stop instead of tokenizing the rest.
            break

        # This message can be replaced
        replace_position = (i, j)
        history_tokens += get_encoded_length(_format_single_exchange(names[j], message))

    if replace_position is None:
        # `warning` is used because `warn` is a deprecated alias on standard loggers.
        logger.warning("The provided context_text is too long to replace any message in the history.")
    else:
        # replace the message at replace_position with context_text
        i, j = replace_position
        history['internal'][-i-1][-j-1] = context_text


def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs):
    """
    Build the chat prompt, optionally enriching it with context retrieved from the collector.

    Steps:
      1. If adding the chat to the data is enabled and the history is not empty, the
         previously auto-inserted documents are deleted and the current history is inserted.
      2. If a query is due for this input, the query tag is stripped, the collector is
         queried, and the result is injected according to the configured strategy:
         appended to or prepended to the user input, or written over a message of
         kwargs['history'].
      3. The prompt is produced by chat.generate_chat_prompt, whose return value is returned.
    """
    if parameters.get_add_chat_to_data():
        # Get the whole history as one string
        chat_transcript = _concatinate_history(kwargs['history'], state)

        if chat_transcript:
            # Drop what was auto-inserted earlier, then insert the up-to-date history.
            collector.delete(ids_to_delete=None, where=CHAT_METADATA)
            process_and_add_to_collector(chat_transcript, collector, False, CHAT_METADATA)

    if not _should_query(user_input):
        return chat.generate_chat_prompt(user_input, state, **kwargs)

    user_input = _remove_tag_if_necessary(user_input)
    results = collector.get_sorted_by_dist(
        user_input,
        n_results=parameters.get_chunk_count(),
        max_token_count=int(parameters.get_max_token_count()),
    )

    # The strategy decides where the retrieved context ends up; unknown strategies inject nothing.
    strategy = parameters.get_injection_strategy()
    if strategy == parameters.APPEND_TO_LAST:
        user_input = user_input + create_context_text(results)
    elif strategy == parameters.PREPEND_TO_LAST:
        user_input = create_context_text(results) + user_input
    elif strategy == parameters.HIJACK_LAST_IN_CONTEXT:
        _hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state)

    return chat.generate_chat_prompt(user_input, state, **kwargs)