""" This module is responsible for modifying the chat prompt and history. """ import json import re import extensions.superboogav2.parameters as parameters from modules import chat from modules.text_generation import get_encoded_length from modules.logging_colors import logger from extensions.superboogav2.utils import create_context_text, create_metadata_source from .data_processor import process_and_add_to_collector from .chromadb import ChromaCollector CHAT_METADATA = create_metadata_source('automatic-chat-insert') INSTRUCT_MODE = 'instruct' CHAT_INSTRUCT_MODE = 'chat-instruct' def _is_instruct_mode(state: dict): mode = state.get('mode') return mode == INSTRUCT_MODE or mode == CHAT_INSTRUCT_MODE def _remove_tag_if_necessary(user_input: str): if not parameters.get_is_manual(): return user_input return re.sub(r'^\s*!c\s*|\s*!c\s*$', '', user_input) def _should_query(input: str): if not parameters.get_is_manual(): return True if re.search(r'^\s*!c|!c\s*$', input, re.MULTILINE): return True return False def _format_single_exchange(name, text): if re.search(r':\s*$', name): return '{} {}\n'.format(name, text) else: return '{}: {}\n'.format(name, text) def _get_names(state: dict): if _is_instruct_mode(state): user_name = state['name1_instruct'] bot_name = state['name2_instruct'] else: user_name = state['name1'] bot_name = state['name2'] if not user_name: user_name = 'User' if not bot_name: bot_name = 'Assistant' return user_name, bot_name def _concatinate_history(history: dict, state: dict): full_history_text = '' user_name, bot_name = _get_names(state) # Grab the internal history. internal_history = history['internal'] assert isinstance(internal_history, list) # Iterate through the history. for exchange in internal_history: assert isinstance(exchange, list) if len(exchange) >= 1: full_history_text += _format_single_exchange(user_name, exchange[0]) if len(exchange) >= 2: full_history_text += _format_single_exchange(bot_name, exchange[1]) return full_history_text[:-1] # Remove the last new line. def _hijack_last(context_text: str, history: dict, max_len: int, state: dict): num_context_tokens = get_encoded_length(context_text) names = _get_names(state)[::-1] history_tokens = 0 replace_position = None for i, messages in enumerate(reversed(history['internal'])): for j, message in enumerate(reversed(messages)): num_message_tokens = get_encoded_length(_format_single_exchange(names[j], message)) # TODO: This is an extremely naive solution. A more robust implementation must be made. if history_tokens + num_context_tokens <= max_len: # This message can be replaced replace_position = (i, j) history_tokens += num_message_tokens if replace_position is None: logger.warn("The provided context_text is too long to replace any message in the history.") else: # replace the message at replace_position with context_text i, j = replace_position history['internal'][-i-1][-j-1] = context_text def custom_generate_chat_prompt_internal(user_input: str, state: dict, collector: ChromaCollector, **kwargs): if parameters.get_add_chat_to_data(): # Get the whole history as one string history_as_text = _concatinate_history(kwargs['history'], state) if history_as_text: # Delete all documents that were auto-inserted collector.delete(ids_to_delete=None, where=CHAT_METADATA) # Insert the processed history process_and_add_to_collector(history_as_text, collector, False, CHAT_METADATA) if _should_query(user_input): user_input = _remove_tag_if_necessary(user_input) results = collector.get_sorted_by_dist(user_input, n_results=parameters.get_chunk_count(), max_token_count=int(parameters.get_max_token_count())) # Check if the strategy is to modify the last message. If so, prepend or append to the user query. if parameters.get_injection_strategy() == parameters.APPEND_TO_LAST: user_input = user_input + create_context_text(results) elif parameters.get_injection_strategy() == parameters.PREPEND_TO_LAST: user_input = create_context_text(results) + user_input elif parameters.get_injection_strategy() == parameters.HIJACK_LAST_IN_CONTEXT: _hijack_last(create_context_text(results), kwargs['history'], state['truncation_length'], state) return chat.generate_chat_prompt(user_input, state, **kwargs)