Spaces:
Running
Running
import re | |
import random | |
from .config import AI_PHRASES, REFUSAL_PHRASES, REPLACEMENT_PHRASES | |
from .response_generation import generate_sim | |
def parse_model_response(response: dict, name: str = "") -> str:
    """
    Extract the assistant message from a raw LLM completion and clean it up.

    Args:
        response (dict): Raw completion object; the message text is expected at
            response["choices"][0]["message"]["content"].
        name (str, optional): Speaker name forwarded to the post-processor so it
            can be stripped from the start of the text. Defaults to "".

    Returns:
        str: The post-processed assistant message.
    """
    raw_message = response["choices"][0]["message"]["content"]
    # Delegate all cleanup (prefix removal, stutters, whispers, etc.) to the
    # shared post-processing routine with the module's standard role prefixes.
    return postprocess_text(
        raw_message,
        name=name,
        human_prefix="user:",
        assistant_prefix="assistant:",
    )
def capitalize_sentences(text: str) -> str:
    """Uppercase the first letter of every sentence in *text*.

    A sentence boundary is a '.', '!' or '?' followed by whitespace; the very
    first character of the text is also capitalized when it is lowercase.
    """
    # Uppercase any lowercase letter that follows terminal punctuation + space.
    result = re.sub(
        r'([.!?]\s+)([a-z])',
        lambda m: m.group(1) + m.group(2).upper(),
        text,
    )
    # Handle the opening of the text, which has no punctuation before it.
    if result and result[0].islower():
        result = result[0].upper() + result[1:]
    return result
def postprocess_text(
    text: str,
    name: str = "",
    human_prefix: str = "user:",
    assistant_prefix: str = "assistant:",
    strip_name: bool = True
) -> str:
    """Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.

    Args:
        text (str): The text to process.
        name (str, optional): Name to strip from the beginning of the text. Defaults to "".
        human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
        assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
        strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.

    Returns:
        str: Cleaned text.
    """
    if text:
        # FIX: `name`/`strip_name` were documented but previously ignored.
        # Remove a leading "Name:" speaker tag when requested.
        if strip_name and name:
            text = re.sub(rf'^\s*{re.escape(name)}\s*:?\s*', '', text, count=1)
        # Replace ellipses with a single period
        text = re.sub(r'\.\.\.', '.', text)
        # Remove unnecessary role prefixes
        text = text.replace(human_prefix, "").replace(assistant_prefix, "")
        # Remove whispers or other marked reactions
        whispers = re.compile(r"(\([\w\s]+\))")  # remove things like "(whispers)"
        reactions = re.compile(r"(\*[\w\s]+\*)")  # remove things like "*stutters*"
        text = whispers.sub("", text)
        text = reactions.sub("", text)
        # Remove double quotation marks
        text = text.replace('"', '')
        # Remove stutters of any length (e.g., "M-m-my", "M-m-m-m-my" or "M-My" to "My").
        # FIX: the old pattern `\b(\w)(-\1)+-\1(\w*)` required at least two dashed
        # repeats, so the single-repeat case "M-My" was never collapsed.
        text = re.sub(r'\b(\w)(?:-\1)+(\w*)', r'\1\2', text, flags=re.IGNORECASE)
        # Normalize spaces
        text = re.sub(r"\s+", " ", text).strip()
        # Capitalize start of sentences
        text = capitalize_sentences(text)
    return text
def apply_guardrails(
    model_input: dict,
    response: str,
    endpoint_url: str,
    endpoint_bearer_token: str
) -> str:
    """
    Apply guardrails to the model's response: regenerate when the text contains
    AI self-reference phrases or refusal phrases, with language-aware fallbacks.
    """
    # Language drives which phrase lists are consulted (defaults to English).
    language = model_input.get("language", "en")

    # 1. Guardrail against AI phrases: regenerate up to 2 times.
    for _ in range(2):
        if not contains_ai_phrase(response, language=language):
            break
        regenerated = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
        response = parse_model_response(regenerated)

    # Still tainted? Retry once with a narrowed conversation consisting of the
    # system message and the most recent user turn only.
    if contains_ai_phrase(response, language=language):
        history = model_input['messages']
        last_user = next((m for m in reversed(history) if m['role'] == 'user'), None)
        if last_user:
            # NOTE(review): assumes history[0] is the system message — matches
            # how the conversation is assembled upstream.
            narrowed_input = {**model_input, 'messages': [history[0], last_user]}
            regenerated = generate_sim(narrowed_input, endpoint_url, endpoint_bearer_token)
            response = parse_model_response(regenerated)

    # 2. Guardrail against refusal phrases: regenerate up to 5 times.
    for _ in range(5):
        if not contains_refusal_phrase(response, language=language):
            break
        regenerated = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
        response = parse_model_response(regenerated)

    # Last resort: substitute a canned replacement phrase in the right language
    # (falling back to the English list when the language is unknown).
    if contains_refusal_phrase(response, language=language):
        candidates = REPLACEMENT_PHRASES.get(language, REPLACEMENT_PHRASES["en"])
        response = random.choice(candidates)

    return response
def contains_ai_phrase(text: str, language: str = "en") -> bool:
    """Return True if *text* contains any 'I am an AI' phrase for *language*."""
    lowered = text.lower()
    # Unknown languages fall back to the English phrase list.
    phrase_list = AI_PHRASES.get(language, AI_PHRASES["en"])
    for phrase in phrase_list:
        if phrase.lower() in lowered:
            return True
    return False
def contains_refusal_phrase(text: str, language: str = "en") -> bool:
    """Return True if *text* contains any refusal phrase for *language*."""
    lowered = text.lower()
    # Unknown languages fall back to the English phrase list.
    phrase_list = REFUSAL_PHRASES.get(language, REFUSAL_PHRASES["en"])
    for phrase in phrase_list:
        if phrase.lower() in lowered:
            return True
    return False
def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
    """
    Cut *text* at the last sentence-ending punctuation mark, dropping any
    trailing incomplete sentence.

    Args:
        text (str): The text to truncate.
        punctuation_marks (tuple, optional): Punctuation marks that count as
            sentence terminators. Defaults to ('.', '!', '?', '…').

    Returns:
        str: Text up to and including the last terminator, stripped; the
        original text (stripped) when no terminator is present.
    """
    # rfind returns -1 when a mark is absent, so -1 means "none found at all".
    last_stop = max(text.rfind(mark) for mark in punctuation_marks)
    if last_stop < 0:
        return text.strip()
    return text[:last_stop + 1].strip()
def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """
    Turn a raw LLM completion into the final user-facing response.

    Pipeline: parse the completion, run guardrails (which may regenerate the
    text), then truncate any trailing incomplete sentence.

    Args:
        completion (dict): Raw response from the LLM.
        model_input (dict): The model input containing the conversation history.
        endpoint_url (str): The URL of the endpoint.
        endpoint_bearer_token (str): The authentication token for the endpoint.

    Returns:
        str: Final processed response ready for the APP.
    """
    parsed = parse_model_response(completion)
    guarded = apply_guardrails(model_input, parsed, endpoint_url, endpoint_bearer_token)
    return truncate_response(guarded)