# Source: convosim-ui/models/business_logic_utils/response_processing.py
# Author: aftorresc
# Commit: Llama 3.1 Update and Spanish Scenarios (#17)
# Revision: f755aca (verified)
import re
import random
from .config import AI_PHRASES, REFUSAL_PHRASES, REPLACEMENT_PHRASES
from .response_generation import generate_sim
def parse_model_response(response: dict, name: str = "") -> str:
    """
    Extract the assistant's message from an LLM response and post-process it.

    Args:
        response (dict): The raw response from the LLM. The message text is read
            from ``response["choices"][0]["message"]["content"]``.
        name (str, optional): Forwarded to ``postprocess_text`` as its ``name``
            argument. Defaults to "".

    Returns:
        str: The cleaned assistant message. An empty string is returned when the
            response carries no text content.

    Raises:
        KeyError, IndexError: If ``response`` lacks the keys/indices read above.
    """
    # NOTE(review): ``content`` is presumably nullable in some chat-completion
    # responses. Normalise a missing value to "" so the declared ``str`` return
    # type holds and later string operations do not fail on None.
    assistant_message = response["choices"][0]["message"]["content"] or ""

    return postprocess_text(
        assistant_message,
        name=name,
        human_prefix="user:",
        assistant_prefix="assistant:",
    )
# Pre-compiled patterns used by postprocess_text.
_ELLIPSIS_RE = re.compile(r"\.{3,}")          # "..." as well as longer runs of dots
_WHISPER_RE = re.compile(r"\([\w\s]+\)")      # things like "(whispers)"
_REACTION_RE = re.compile(r"\*[\w\s]+\*")     # things like "*stutters*"
# A letter followed by one or more "-<same letter>" repeats, then the rest of the
# word. Only letters are matched so numeric text such as "1-1" is left alone.
_STUTTER_RE = re.compile(r"\b([^\W\d_])(?:-\1)+(\w*)", re.IGNORECASE)
_WHITESPACE_RE = re.compile(r"\s+")


def postprocess_text(
    text: str,
    name: str = "",
    human_prefix: str = "user:",
    assistant_prefix: str = "assistant:",
    strip_name: bool = True
) -> str:
    """Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.

    Processing order: collapse ellipses, drop role prefixes, drop "(...)" and
    "*...*" annotations, drop double quotes, collapse stutters, normalise
    whitespace, and finally drop a leading "<name>:" speaker label.

    Args:
        text (str): The text to process. Falsy values (``""``/``None``) are
            returned unchanged.
        name (str, optional): Speaker name to strip from the beginning of the text.
            Only a leading ``"<name>:"`` label is removed (case-insensitive), so a
            sentence that merely starts with the name is left intact. Defaults to "".
        human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
        assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
        strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.

    Returns:
        str: Cleaned text.
    """
    if not text:
        return text

    # Replace ellipses (three or more dots) with a single period
    text = _ELLIPSIS_RE.sub(".", text)

    # Remove unnecessary role prefixes wherever they occur
    text = text.replace(human_prefix, "").replace(assistant_prefix, "")

    # Remove whispers or other marked reactions
    text = _WHISPER_RE.sub("", text)
    text = _REACTION_RE.sub("", text)

    # Remove double quotation marks
    text = text.replace('"', "")

    # Remove stutters of any length ("M-My", "M-m-my", "M-m-m-m-my" -> "My");
    # the casing of the first letter is the one that is kept.
    text = _STUTTER_RE.sub(r"\1\2", text)

    # Normalize spaces
    text = _WHITESPACE_RE.sub(" ", text).strip()

    # Remove a leading "<name>:" speaker label, if requested
    if strip_name and name:
        speaker_label = re.compile(rf"^{re.escape(name)}\s*:\s*", re.IGNORECASE)
        text = speaker_label.sub("", text, count=1)

    return text
def apply_guardrails(
    model_input: dict,
    response: str,
    endpoint_url: str,
    endpoint_bearer_token: str
) -> str:
    """
    Apply guardrails to the model's response, including checking for AI phrases and refusal phrases.

    Steps:
        1. While the response contains an AI phrase, regenerate it with the
           unchanged ``model_input`` (at most 2 times).
        2. If an AI phrase is still present, regenerate once from a narrowed
           conversation: the first message plus the last user message.
        3. While the response contains a refusal phrase, regenerate it with the
           unchanged ``model_input`` (at most 5 times).
        4. If the response still contains a refusal phrase or an AI phrase,
           replace it with a random entry of ``REPLACEMENT_PHRASES`` for the
           language (falling back to the "en" list).

    Args:
        model_input (dict): The model input. ``model_input["messages"]`` is read
            for the narrowed retry; ``model_input.get("language", "en")`` selects
            the phrase lists.
        response (str): The already parsed model response to check.
        endpoint_url (str): Passed through to ``generate_sim``.
        endpoint_bearer_token (str): Passed through to ``generate_sim``.

    Returns:
        str: The original, a regenerated, or a replacement response.
    """
    max_attempts_ai = 2
    max_attempts_refusal = 5

    # Retrieve language from model_input (defaulting to 'en')
    language = model_input.get("language", "en")

    def regenerate(payload: dict) -> str:
        # Ask the endpoint for a fresh completion and parse it like the first one.
        completion = generate_sim(payload, endpoint_url, endpoint_bearer_token)
        return parse_model_response(completion)

    # 1. Guardrail against AI phrases: retry without modifying the conversation history
    attempt_ai = 0
    while attempt_ai < max_attempts_ai and contains_ai_phrase(response, language=language):
        response = regenerate(model_input)
        attempt_ai += 1

    # 2. If AI phrases still appear, try narrowing to the first message + last user message
    if contains_ai_phrase(response, language=language):
        memory = model_input['messages']
        last_user_message = next((msg for msg in reversed(memory) if msg['role'] == 'user'), None)
        if last_user_message:
            # When the conversation starts with that very user message, do not
            # send it twice.
            if memory[0] is last_user_message:
                narrowed_messages = [last_user_message]
            else:
                narrowed_messages = [memory[0], last_user_message]
            response = regenerate({**model_input, 'messages': narrowed_messages})

    # 3. Guardrail against refusal phrases
    attempt_refusal = 0
    while attempt_refusal < max_attempts_refusal and contains_refusal_phrase(response, language=language):
        response = regenerate(model_input)
        attempt_refusal += 1

    # 4. Last resort. A response regenerated in step 3 has not been checked for
    # AI phrases, and step 2 may have failed, so both guardrails are enforced
    # here before anything is returned.
    if (
        contains_refusal_phrase(response, language=language)
        or contains_ai_phrase(response, language=language)
    ):
        # Fallback to English list if language not found
        replacement_list = REPLACEMENT_PHRASES.get(language, REPLACEMENT_PHRASES["en"])
        response = random.choice(replacement_list)

    return response
def contains_ai_phrase(text: str, language: str = "en") -> bool:
    """Report whether ``text`` contains one of the configured AI phrases.

    The comparison is a case-insensitive substring match against the
    ``AI_PHRASES`` entry for ``language``; the "en" entry is used when the
    language has no entry of its own.
    """
    haystack = text.lower()
    # Unknown languages fall back to the English phrase list
    candidates = AI_PHRASES.get(language, AI_PHRASES["en"])
    for candidate in candidates:
        if candidate.lower() in haystack:
            return True
    return False
def contains_refusal_phrase(text: str, language: str = "en") -> bool:
    """Report whether ``text`` contains one of the configured refusal phrases.

    The comparison is a case-insensitive substring match against the
    ``REFUSAL_PHRASES`` entry for ``language``; the "en" entry is used when the
    language has no entry of its own.
    """
    haystack = text.lower()
    # Unknown languages fall back to the English phrase list
    candidates = REFUSAL_PHRASES.get(language, REFUSAL_PHRASES["en"])
    for candidate in candidates:
        if candidate.lower() in haystack:
            return True
    return False
def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
    """
    Truncate the text at the last occurrence of a specified punctuation mark.

    Args:
        text (str): The text to truncate.
        punctuation_marks (tuple, optional): Punctuation marks to use for truncation.
            Marks may be longer than one character; an empty tuple disables
            truncation. Defaults to ('.', '!', '?', '…').

    Returns:
        str: The text up to and including the last punctuation mark found, with
            surrounding whitespace stripped. If no mark occurs, the whole text
            (stripped) is returned.
    """
    # End offset (exclusive) of the right-most mark found; -1 means "none found".
    cut_at = -1
    for mark in punctuation_marks:
        position = text.rfind(mark)
        if position != -1:
            # Use the mark's full length so multi-character marks are kept whole.
            cut_at = max(cut_at, position + len(mark))

    if cut_at == -1:
        # No punctuation found, return the original text
        return text.strip()

    return text[:cut_at].strip()
def split_texter_response(text: str) -> str:
    """
    Split the texter's response into multiple messages.

    Messages are separated by a newline followed by "texter: "; the first
    message carries no prefix. The target number of messages is drawn at random:
        - 1 message: 30% chance
        - 2 messages: 25% chance
        - 3 messages: 45% chance
    and is capped by the number of sentences found. Sentences are distributed
    as evenly as possible over the messages.

    Args:
        text (str): The response to split.

    Returns:
        str: The (possibly) split response. When the text yields at most one
            message (including empty or punctuation-only text), the stripped
            original text is returned.
    """
    # Split text into sentences, keeping the trailing punctuation; drop blanks
    sentences = [
        sentence.strip()
        for sentence in re.findall(r'[^.!?]+[.!?]*', text)
        if sentence.strip()
    ]

    # Decide number of messages based on specified probabilities
    num_messages = random.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]

    # Never ask for more messages than there are sentences
    num_messages = min(num_messages, len(sentences))

    # Nothing to split. "<=" (not "==") also covers text without any sentence,
    # which would otherwise divide by zero below.
    if num_messages <= 1:
        return text.strip()

    # Sentence indices at which one message ends and the next begins
    avg = len(sentences) / num_messages
    boundaries = (
        [0]
        + [int(round(avg * i)) for i in range(1, num_messages)]
        + [len(sentences)]
    )

    segments = [
        ' '.join(sentences[start:end])
        for start, end in zip(boundaries, boundaries[1:])
    ]
    return "\ntexter: ".join(segments).strip()
def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """
    Turn a raw model completion into the final response.

    The completion goes through four stages in order: parsing, guardrails,
    truncation, and splitting into multiple messages.

    Args:
        completion (dict): Raw response from the LLM.
        model_input (dict): The model input containing the conversation history.
        endpoint_url (str): The URL of the endpoint.
        endpoint_bearer_token (str): The authentication token for endpoint.

    Returns:
        str: Final processed response ready for the APP.
    """
    # Parse the raw completion, then let the guardrails vet (and possibly
    # regenerate or replace) the message
    message = parse_model_response(completion)
    message = apply_guardrails(model_input, message, endpoint_url, endpoint_bearer_token)

    # Truncate the message, then split it into multiple messages if needed
    return split_texter_response(truncate_response(message))