# convosim-ui/models/business_logic_utils/response_processing.py
import re
import random
from .config import AI_PHRASES
from .response_generation import generate_sim


def parse_model_response(response: dict, name: str = "") -> str:
"""
Parse the LLM response to extract the assistant's message and apply initial post-processing.
Args:
response (dict): The raw response from the LLM.
name (str, optional): Name to strip from the beginning of the text. Defaults to "".
Returns:
str: The cleaned and parsed assistant's message.
"""
assistant_message = response["choices"][0]["message"]["content"]
cleaned_text = postprocess_text(
assistant_message,
name=name,
human_prefix="user:",
assistant_prefix="assistant:"
)
return cleaned_text
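

# A minimal usage sketch (hypothetical payload, mirroring the OpenAI-style
# response shape this function indexes into):
#   >>> resp = {"choices": [{"message": {"content": "texter: hi  there..."}}]}
#   >>> parse_model_response(resp)
#   'texter: hi there.'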
def postprocess_text(
text: str,
name: str = "",
human_prefix: str = "user:",
assistant_prefix: str = "assistant:",
strip_name: bool = True
) -> str:
"""Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.
Args:
text (str): The text to process.
name (str, optional): Name to strip from the beginning of the text. Defaults to "".
human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.
Returns:
str: Cleaned text.
"""
if text:
# Replace ellipses with a single period
text = re.sub(r'\.\.\.', '.', text)
# Remove unnecessary role prefixes
        text = text.replace(human_prefix, "").replace(assistant_prefix, "")
        # Strip the speaker's name (e.g., "Kit:") from the start of the text, if requested;
        # this implements the documented `name`/`strip_name` parameters
        if strip_name and name:
            text = re.sub(rf"^\s*{re.escape(name)}\s*:?\s*", "", text)
# Remove whispers or other marked reactions
whispers = re.compile(r"(\([\w\s]+\))") # remove things like "(whispers)"
reactions = re.compile(r"(\*[\w\s]+\*)") # remove things like "*stutters*"
text = whispers.sub("", text)
text = reactions.sub("", text)
# Remove all quotation marks (both single and double)
text = text.replace('"', '').replace("'", "")
# Normalize spaces
text = re.sub(r"\s+", " ", text).strip()
return text
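

# Illustrative behavior (inputs chosen to exercise each cleanup step):
#   >>> postprocess_text('(whispers) Hello there... *nervous laugh* I am "fine"')
#   'Hello there. I am fine'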
def apply_guardrails(model_input: dict, response: str, endpoint_url: str, endpoint_bearer_token: str) -> str:
"""Apply the 'I am an AI' guardrail to model responses"""
attempt = 0
max_attempts = 2
while attempt < max_attempts and contains_ai_phrase(response):
# Regenerate the response without modifying the conversation history
completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
response = parse_model_response(completion)
attempt += 1
if contains_ai_phrase(response):
# Use only the last user message for regeneration
memory = model_input['messages']
last_user_message = next((msg for msg in reversed(memory) if msg['role'] == 'user'), None)
if last_user_message:
# Create a new conversation with system message and last user message
model_input_copy = {
**model_input,
'messages': [memory[0], last_user_message] # memory[0] is the system message
}
completion = generate_sim(model_input_copy, endpoint_url, endpoint_bearer_token)
response = parse_model_response(completion)
return response
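

# Sketch of a guardrail call (endpoint values are placeholders; generate_sim
# performs the actual request against the serving endpoint):
#   response = apply_guardrails(
#       model_input,                       # dict with the full 'messages' history
#       response="As an AI, I cannot...",  # draft reply that may trip the guardrail
#       endpoint_url="https://<workspace>/serving-endpoints/<model>/invocations",
#       endpoint_bearer_token=token,
#   )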
def contains_ai_phrase(text: str) -> bool:
"""Check if the text contains any 'I am an AI' phrases."""
text_lower = text.lower()
return any(phrase.lower() in text_lower for phrase in AI_PHRASES)
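

# Example (assumes AI_PHRASES, imported from .config, includes "I am an AI"):
#   >>> contains_ai_phrase("Remember, I am an AI and cannot feel.")
#   True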
def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
"""
Truncate the text at the last occurrence of a specified punctuation mark.
Args:
text (str): The text to truncate.
punctuation_marks (tuple, optional): A tuple of punctuation marks to use for truncation. Defaults to ('.', '!', '?', '…').
Returns:
str: The truncated text.
"""
# Find the last position of any punctuation mark from the provided set
last_punct_position = max(text.rfind(p) for p in punctuation_marks)
# Check if any punctuation mark is found
if last_punct_position == -1:
# No punctuation found, return the original text
return text.strip()
# Return the truncated text up to and including the last punctuation mark
return text[:last_punct_position + 1].strip()
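

# Illustrative behavior: the unfinished fragment after the last sentence-ending
# punctuation mark is dropped.
#   >>> truncate_response("I hear you. That sounds really hard! And then I")
#   'I hear you. That sounds really hard!'
#   >>> truncate_response("no terminal punctuation")
#   'no terminal punctuation'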
def split_texter_response(text: str) -> str:
"""
Splits the texter's response into multiple messages,
introducing '\ntexter:' prefixes after punctuation.
The number of messages is randomly chosen based on specified probabilities:
- 1 message: 30% chance
- 2 messages: 25% chance
- 3 messages: 45% chance
The first message does not include the '\ntexter:' prefix.
"""
# Use regex to split text into sentences, keeping the punctuation
sentences = re.findall(r'[^.!?]+[.!?]*', text)
# Remove empty strings from sentences
sentences = [s.strip() for s in sentences if s.strip()]
# Decide number of messages based on specified probabilities
num_messages = random.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]
# If not enough sentences to make the splits, adjust num_messages
if len(sentences) < num_messages:
num_messages = len(sentences)
# If num_messages is 1, return the original text
if num_messages == 1:
return text.strip()
# Calculate split points
# We need to divide the sentences into num_messages parts
avg = len(sentences) / num_messages
split_indices = [int(round(avg * i)) for i in range(1, num_messages)]
# Build the new text
new_text = ''
start = 0
for i, end in enumerate(split_indices + [len(sentences)]):
segment_sentences = sentences[start:end]
segment_text = ' '.join(segment_sentences).strip()
if i == 0:
# First segment, do not add '\ntexter:'
new_text += segment_text
else:
# Subsequent segments, add '\ntexter:'
new_text += f"\ntexter: {segment_text}"
start = end
return new_text.strip()
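

# Example (output varies from run to run because the split count is sampled):
#   >>> split_texter_response("I can't. It hurts. Nobody listens.")
#   "I can't.\ntexter: It hurts.\ntexter: Nobody listens."
# (one possible outcome; when a single message is sampled, the text is returned unchanged)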
def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
"""
Process the raw model response, including parsing, applying guardrails,
truncation, and splitting the response into multiple messages if necessary.
Args:
completion (dict): Raw response from the LLM.
model_input (dict): The model input containing the conversation history.
endpoint_url (str): The URL of the endpoint.
endpoint_bearer_token (str): The authentication token for endpoint.
Returns:
str: Final processed response ready for the APP.
"""
# Step 1: Parse the raw response to extract the assistant's message
assistant_message = parse_model_response(completion)
# Step 2: Apply guardrails (handle possible AI responses)
guardrail_message = apply_guardrails(model_input, assistant_message, endpoint_url, endpoint_bearer_token)
# Step 3: Apply response truncation
truncated_message = truncate_response(guardrail_message)
# Step 4: Split the response into multiple messages if needed
final_response = split_texter_response(truncated_message)
return final_response
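

# End-to-end sketch (hypothetical setup; `completion` is the raw dict returned
# by the serving endpoint for the current turn via generate_sim):
#   completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
#   reply = process_model_response(completion, model_input, endpoint_url, endpoint_bearer_token)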