Spaces:
Sleeping
Sleeping
File size: 7,462 Bytes
42a7266 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 |
import re
import random
from .config import AI_PHRASES
from .response_generation import generate_sim
def parse_model_response(response: dict, name: str = "") -> str:
    """
    Extract the assistant's message from a raw LLM response and clean it up.

    Args:
        response (dict): Raw LLM response; expects the OpenAI-style
            ``choices[0].message.content`` layout.
        name (str, optional): Name forwarded to the post-processor so it can
            be stripped from the start of the text. Defaults to "".

    Returns:
        str: The post-processed assistant message.
    """
    raw_message = response["choices"][0]["message"]["content"]
    return postprocess_text(
        raw_message,
        name=name,
        human_prefix="user:",
        assistant_prefix="assistant:"
    )
def postprocess_text(
    text: str,
    name: str = "",
    human_prefix: str = "user:",
    assistant_prefix: str = "assistant:",
    strip_name: bool = True
) -> str:
    """Eliminates whispers, reactions, ellipses, and quotation marks from generated text by LLMs.

    Args:
        text (str): The text to process.
        name (str, optional): Name to strip from the beginning of the text. Defaults to "".
        human_prefix (str, optional): The user prefix to remove. Defaults to "user:".
        assistant_prefix (str, optional): The assistant prefix to remove. Defaults to "assistant:".
        strip_name (bool, optional): Whether to remove the name at the beginning of the text. Defaults to True.

    Returns:
        str: Cleaned text.
    """
    if text:
        # Fix: `name`/`strip_name` were documented but previously unused.
        # Drop a leading "<name>:" speaker tag when requested.
        if strip_name and name:
            text = re.sub(rf"^\s*{re.escape(name)}\s*:\s*", "", text)
        # Replace ellipses with a single period
        text = re.sub(r'\.\.\.', '.', text)
        # Remove unnecessary role prefixes
        text = text.replace(human_prefix, "").replace(assistant_prefix, "")
        # Remove whispers or other marked reactions
        whispers = re.compile(r"(\([\w\s]+\))")  # remove things like "(whispers)"
        reactions = re.compile(r"(\*[\w\s]+\*)")  # remove things like "*stutters*"
        text = whispers.sub("", text)
        text = reactions.sub("", text)
        # Remove all quotation marks (both single and double)
        text = text.replace('"', '').replace("'", "")
        # Normalize spaces
        text = re.sub(r"\s+", " ", text).strip()
    return text
def apply_guardrails(model_input: dict, response: str, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """Apply the 'I am an AI' guardrail to model responses."""
    # First pass: regenerate up to twice while the reply still sounds like an AI.
    for _ in range(2):
        if not contains_ai_phrase(response):
            break
        completion = generate_sim(model_input, endpoint_url, endpoint_bearer_token)
        response = parse_model_response(completion)
    # Last resort: retry with only the system prompt plus the latest user turn.
    if contains_ai_phrase(response):
        history = model_input['messages']
        last_user_message = next((msg for msg in reversed(history) if msg['role'] == 'user'), None)
        if last_user_message:
            trimmed_input = {
                **model_input,
                'messages': [history[0], last_user_message]  # history[0] is the system message
            }
            completion = generate_sim(trimmed_input, endpoint_url, endpoint_bearer_token)
            response = parse_model_response(completion)
    return response
def contains_ai_phrase(text: str) -> bool:
    """Return True if the text contains any 'I am an AI' phrase (case-insensitive)."""
    lowered = text.lower()
    for phrase in AI_PHRASES:
        if phrase.lower() in lowered:
            return True
    return False
def truncate_response(text: str, punctuation_marks: tuple = ('.', '!', '?', '…')) -> str:
    """
    Truncate the text at the last occurrence of a specified punctuation mark.

    Args:
        text (str): The text to truncate.
        punctuation_marks (tuple, optional): Punctuation marks eligible as
            truncation points. Defaults to ('.', '!', '?', '…').

    Returns:
        str: The text up to and including the last punctuation mark, or the
        stripped original text when no punctuation is found.
    """
    # `default=-1` keeps an empty `punctuation_marks` tuple from raising
    # ValueError (max() over an empty sequence); it behaves like "not found".
    last_punct_position = max((text.rfind(p) for p in punctuation_marks), default=-1)
    if last_punct_position == -1:
        # No punctuation found, return the original text
        return text.strip()
    # Include the punctuation mark itself in the truncated text
    return text[:last_punct_position + 1].strip()
def split_texter_response(text: str) -> str:
    """
    Splits the texter's response into multiple messages,
    introducing '\ntexter:' prefixes after punctuation.

    The number of messages is randomly chosen based on specified probabilities:
    - 1 message: 30% chance
    - 2 messages: 25% chance
    - 3 messages: 45% chance

    The first message does not include the '\ntexter:' prefix.
    """
    # Use regex to split text into sentences, keeping the punctuation
    sentences = re.findall(r'[^.!?]+[.!?]*', text)
    # Remove empty strings from sentences
    sentences = [s.strip() for s in sentences if s.strip()]
    # Fix: empty/whitespace-only input yields zero sentences, which previously
    # caused a ZeroDivisionError (0/0) when computing the average split size.
    if not sentences:
        return text.strip()
    # Decide number of messages based on specified probabilities
    num_messages = random.choices([1, 2, 3], weights=[0.3, 0.25, 0.45], k=1)[0]
    # If not enough sentences to make the splits, adjust num_messages
    if len(sentences) < num_messages:
        num_messages = len(sentences)
    # If num_messages is 1, return the original text
    if num_messages == 1:
        return text.strip()
    # Calculate split points: divide the sentences into num_messages parts
    avg = len(sentences) / num_messages
    split_indices = [int(round(avg * i)) for i in range(1, num_messages)]
    # Build the new text; only segments after the first get the prefix
    new_text = ''
    start = 0
    for i, end in enumerate(split_indices + [len(sentences)]):
        segment_sentences = sentences[start:end]
        segment_text = ' '.join(segment_sentences).strip()
        if i == 0:
            # First segment, do not add '\ntexter:'
            new_text += segment_text
        else:
            # Subsequent segments, add '\ntexter:'
            new_text += f"\ntexter: {segment_text}"
        start = end
    return new_text.strip()
def process_model_response(completion: dict, model_input: dict, endpoint_url: str, endpoint_bearer_token: str) -> str:
    """
    Turn a raw LLM completion into the final APP-ready reply.

    The pipeline is: parse -> guardrails -> truncation -> message splitting.

    Args:
        completion (dict): Raw response from the LLM.
        model_input (dict): The model input containing the conversation history.
        endpoint_url (str): The URL of the endpoint.
        endpoint_bearer_token (str): The authentication token for endpoint.

    Returns:
        str: Final processed response ready for the APP.
    """
    # Parse the raw completion, then scrub any "I am an AI" replies.
    reply = parse_model_response(completion)
    reply = apply_guardrails(model_input, reply, endpoint_url, endpoint_bearer_token)
    # Trim trailing fragments and split into chat-sized messages.
    reply = truncate_response(reply)
    return split_texter_response(reply)
|