Spaces:
Runtime error
Runtime error
| # Character_Interaction_tab.py | |
| # Description: This file contains the functions that are used for Character Interactions in the Gradio UI. | |
| # | |
| # Imports | |
| import base64 | |
| import io | |
| import uuid | |
| from datetime import datetime as datetime | |
| import logging | |
| import json | |
| import os | |
| from typing import List, Dict, Tuple, Union | |
| # | |
| # External Imports | |
| import gradio as gr | |
| from PIL import Image | |
| # | |
| # Local Imports | |
| from App_Function_Libraries.Chat.Chat_Functions import chat, load_characters, save_chat_history_to_db_wrapper | |
| from App_Function_Libraries.Gradio_UI.Chat_ui import chat_wrapper | |
| from App_Function_Libraries.Gradio_UI.Writing_tab import generate_writing_feedback | |
| from App_Function_Libraries.Utils.Utils import default_api_endpoint, format_api_name, global_api_endpoints | |
| # | |
| ######################################################################################################################## | |
| # | |
| # Single-Character chat Functions: | |
| # FIXME - add these functions to the Personas library | |
| def chat_with_character(user_message, history, char_data, api_name_input, api_key): | |
| if char_data is None: | |
| return history, "Please import a character card first." | |
| bot_message = generate_writing_feedback(user_message, char_data['name'], "Overall", api_name_input, | |
| api_key) | |
| history.append((user_message, bot_message)) | |
| return history, "" | |
| def import_character_card(file): | |
| if file is None: | |
| logging.warning("No file provided for character card import") | |
| return None | |
| try: | |
| if file.name.lower().endswith(('.png', '.webp')): | |
| logging.info(f"Attempting to import character card from image: {file.name}") | |
| json_data = extract_json_from_image(file) | |
| if json_data: | |
| logging.info("JSON data extracted from image, attempting to parse") | |
| card_data = import_character_card_json(json_data) | |
| if card_data: | |
| # Save the image data | |
| with Image.open(file) as img: | |
| img_byte_arr = io.BytesIO() | |
| img.save(img_byte_arr, format='PNG') | |
| card_data['image'] = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8') | |
| return card_data | |
| else: | |
| logging.warning("No JSON data found in the image") | |
| else: | |
| logging.info(f"Attempting to import character card from JSON file: {file.name}") | |
| content = file.read().decode('utf-8') | |
| return import_character_card_json(content) | |
| except Exception as e: | |
| logging.error(f"Error importing character card: {e}") | |
| return None | |
| def import_character_card_json(json_content): | |
| try: | |
| # Remove any leading/trailing whitespace | |
| json_content = json_content.strip() | |
| # Log the first 100 characters of the content | |
| logging.debug(f"JSON content (first 100 chars): {json_content[:100]}...") | |
| card_data = json.loads(json_content) | |
| logging.debug(f"Parsed JSON data keys: {list(card_data.keys())}") | |
| if 'spec' in card_data and card_data['spec'] == 'chara_card_v2': | |
| logging.info("Detected V2 character card") | |
| return card_data['data'] | |
| else: | |
| logging.info("Assuming V1 character card") | |
| return card_data | |
| except json.JSONDecodeError as e: | |
| logging.error(f"JSON decode error: {e}") | |
| logging.error(f"Problematic JSON content: {json_content[:500]}...") | |
| except Exception as e: | |
| logging.error(f"Unexpected error parsing JSON: {e}") | |
| return None | |
| def extract_json_from_image(image_file): | |
| logging.debug(f"Attempting to extract JSON from image: {image_file.name}") | |
| try: | |
| with Image.open(image_file) as img: | |
| logging.debug("Image opened successfully") | |
| metadata = img.info | |
| if 'chara' in metadata: | |
| logging.debug("Found 'chara' in image metadata") | |
| chara_content = metadata['chara'] | |
| logging.debug(f"Content of 'chara' metadata (first 100 chars): {chara_content[:100]}...") | |
| try: | |
| decoded_content = base64.b64decode(chara_content).decode('utf-8') | |
| logging.debug(f"Decoded content (first 100 chars): {decoded_content[:100]}...") | |
| return decoded_content | |
| except Exception as e: | |
| logging.error(f"Error decoding base64 content: {e}") | |
| logging.debug("'chara' not found in metadata, checking for base64 encoded data") | |
| raw_data = img.tobytes() | |
| possible_json = raw_data.split(b'{', 1)[-1].rsplit(b'}', 1)[0] | |
| if possible_json: | |
| try: | |
| decoded = base64.b64decode(possible_json).decode('utf-8') | |
| if decoded.startswith('{') and decoded.endswith('}'): | |
| logging.debug("Found and decoded base64 JSON data") | |
| return '{' + decoded + '}' | |
| except Exception as e: | |
| logging.error(f"Error decoding base64 data: {e}") | |
| logging.warning("No JSON data found in the image") | |
| except Exception as e: | |
| logging.error(f"Error extracting JSON from image: {e}") | |
| return None | |
| def load_chat_history(file): | |
| try: | |
| content = file.read().decode('utf-8') | |
| chat_data = json.loads(content) | |
| return chat_data['history'], chat_data['character'] | |
| except Exception as e: | |
| logging.error(f"Error loading chat history: {e}") | |
| return None, None | |
| # | |
| # End of X | |
| ###################################################################################################################### | |
| # | |
| # Multi-Character Chat Interface | |
| # FIXME - refactor and move these functions to the Character_Chat library so that it uses the same functions | |
| def character_interaction_setup(): | |
| characters = load_characters() | |
| return characters, [], None, None | |
| def extract_character_response(response: Union[str, Tuple]) -> str: | |
| if isinstance(response, tuple): | |
| # If it's a tuple, try to extract the first string element | |
| for item in response: | |
| if isinstance(item, str): | |
| return item.strip() | |
| # If no string found, return a default message | |
| return "I'm not sure how to respond." | |
| elif isinstance(response, str): | |
| # If it's already a string, just return it | |
| return response.strip() | |
| else: | |
| # For any other type, return a default message | |
| return "I'm having trouble forming a response." | |
| # def process_character_response(response: str) -> str: | |
| # # Remove any leading explanatory text before the first '---' | |
| # parts = response.split('---') | |
| # if len(parts) > 1: | |
| # return '---' + '---'.join(parts[1:]) | |
| # return response.strip() | |
| def process_character_response(response: Union[str, Tuple]) -> str: | |
| if isinstance(response, tuple): | |
| response = ' '.join(str(item) for item in response if isinstance(item, str)) | |
| if isinstance(response, str): | |
| # Remove any leading explanatory text before the first '---' | |
| parts = response.split('---') | |
| if len(parts) > 1: | |
| return '---' + '---'.join(parts[1:]) | |
| return response.strip() | |
| else: | |
| return "I'm having trouble forming a response." | |
| def character_turn(characters: Dict, conversation: List[Tuple[str, str]], | |
| current_character: str, other_characters: List[str], | |
| api_endpoint: str, api_key: str, temperature: float, | |
| scenario: str = "") -> Tuple[List[Tuple[str, str]], str]: | |
| if not current_character or current_character not in characters: | |
| return conversation, current_character | |
| if not conversation and scenario: | |
| conversation.append(("Scenario", scenario)) | |
| current_char = characters[current_character] | |
| other_chars = [characters[char] for char in other_characters if char in characters and char != current_character] | |
| prompt = f"{current_char['name']}'s personality: {current_char['personality']}\n" | |
| for char in other_chars: | |
| prompt += f"{char['name']}'s personality: {char['personality']}\n" | |
| prompt += "Conversation so far:\n" + "\n".join([f"{sender}: {message}" for sender, message in conversation]) | |
| prompt += f"\n\nHow would {current_char['name']} respond?" | |
| try: | |
| response = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", None, False, temperature, "") | |
| processed_response = process_character_response(response) | |
| conversation.append((current_char['name'], processed_response)) | |
| except Exception as e: | |
| error_message = f"Error generating response: {str(e)}" | |
| conversation.append((current_char['name'], error_message)) | |
| return conversation, current_character | |
| def character_interaction(character1: str, character2: str, api_endpoint: str, api_key: str, | |
| num_turns: int, scenario: str, temperature: float, | |
| user_interjection: str = "") -> List[str]: | |
| characters = load_characters() | |
| char1 = characters[character1] | |
| char2 = characters[character2] | |
| conversation = [] | |
| current_speaker = char1 | |
| other_speaker = char2 | |
| # Add scenario to the conversation start | |
| if scenario: | |
| conversation.append(f"Scenario: {scenario}") | |
| for turn in range(num_turns): | |
| # Construct the prompt for the current speaker | |
| prompt = f"{current_speaker['name']}'s personality: {current_speaker['personality']}\n" | |
| prompt += f"{other_speaker['name']}'s personality: {other_speaker['personality']}\n" | |
| prompt += f"Conversation so far:\n" + "\n".join( | |
| [msg if isinstance(msg, str) else f"{msg[0]}: {msg[1]}" for msg in conversation]) | |
| # Add user interjection if provided | |
| if user_interjection and turn == num_turns // 2: | |
| prompt += f"\n\nUser interjection: {user_interjection}\n" | |
| conversation.append(f"User: {user_interjection}") | |
| prompt += f"\n\nHow would {current_speaker['name']} respond?" | |
| # FIXME - figure out why the double print is happening | |
| # Get response from the LLM | |
| response = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", None, False, temperature, "") | |
| # Add the response to the conversation | |
| conversation.append((current_speaker['name'], response)) | |
| # Switch speakers | |
| current_speaker, other_speaker = other_speaker, current_speaker | |
| # Convert the conversation to a list of strings for output | |
| return [f"{msg[0]}: {msg[1]}" if isinstance(msg, tuple) else msg for msg in conversation] | |
| def create_multiple_character_chat_tab(): | |
| try: | |
| default_value = None | |
| if default_api_endpoint: | |
| if default_api_endpoint in global_api_endpoints: | |
| default_value = format_api_name(default_api_endpoint) | |
| else: | |
| logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") | |
| except Exception as e: | |
| logging.error(f"Error setting default API endpoint: {str(e)}") | |
| default_value = None | |
| with gr.TabItem("Multi-Character Chat", visible=True): | |
| characters, conversation, current_character, other_character = character_interaction_setup() | |
| with gr.Blocks() as character_interaction: | |
| gr.Markdown("# Multi-Character Chat") | |
| with gr.Row(): | |
| num_characters = gr.Dropdown(label="Number of Characters", choices=["2", "3", "4"], value="2") | |
| character_selectors = [gr.Dropdown(label=f"Character {i + 1}", choices=list(characters.keys())) for i in | |
| range(4)] | |
| # Refactored API selection dropdown | |
| api_endpoint = gr.Dropdown( | |
| choices=["None"] + [format_api_name(api) for api in global_api_endpoints], | |
| value=default_value, | |
| label="API for Interaction (Optional)" | |
| ) | |
| api_key = gr.Textbox(label="API Key (if required)", type="password") | |
| temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7) | |
| scenario = gr.Textbox(label="Scenario (optional)", lines=3) | |
| chat_display = gr.Chatbot(label="Character Interaction") | |
| current_index = gr.State(0) | |
| next_turn_btn = gr.Button("Next Turn") | |
| narrator_input = gr.Textbox(label="Narrator Input", placeholder="Add a narration or description...") | |
| add_narration_btn = gr.Button("Add Narration") | |
| error_box = gr.Textbox(label="Error Messages", visible=False) | |
| reset_btn = gr.Button("Reset Conversation") | |
| chat_media_name = gr.Textbox(label="Custom Chat Name(optional)", visible=True) | |
| save_chat_history_to_db = gr.Button("Save Chat History to DataBase") | |
| def update_character_selectors(num): | |
| return [gr.update(visible=True) if i < int(num) else gr.update(visible=False) for i in range(4)] | |
| num_characters.change( | |
| update_character_selectors, | |
| inputs=[num_characters], | |
| outputs=character_selectors | |
| ) | |
| def reset_conversation(): | |
| return [], 0, gr.update(value=""), gr.update(value="") | |
| def take_turn(conversation, current_index, char1, char2, char3, char4, api_endpoint, api_key, temperature, | |
| scenario): | |
| char_selectors = [char for char in [char1, char2, char3, char4] if char] # Remove None values | |
| num_chars = len(char_selectors) | |
| if num_chars == 0: | |
| return conversation, current_index # No characters selected, return without changes | |
| if not conversation: | |
| conversation = [] | |
| if scenario: | |
| conversation.append(("Scenario", scenario)) | |
| current_character = char_selectors[current_index % num_chars] | |
| next_index = (current_index + 1) % num_chars | |
| prompt = f"Character speaking: {current_character}\nOther characters: {', '.join(char for char in char_selectors if char != current_character)}\n" | |
| prompt += "Generate the next part of the conversation, including character dialogues and actions. Characters should speak in first person." | |
| response, new_conversation, _ = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", | |
| None, False, temperature, "") | |
| # Format the response | |
| formatted_lines = [] | |
| for line in response.split('\n'): | |
| if ':' in line: | |
| speaker, text = line.split(':', 1) | |
| formatted_lines.append(f"**{speaker.strip()}**: {text.strip()}") | |
| else: | |
| formatted_lines.append(line) | |
| formatted_response = '\n'.join(formatted_lines) | |
| # Update the last message in the conversation with the formatted response | |
| if new_conversation: | |
| new_conversation[-1] = (new_conversation[-1][0], formatted_response) | |
| else: | |
| new_conversation.append((current_character, formatted_response)) | |
| return new_conversation, next_index | |
| def add_narration(narration, conversation): | |
| if narration: | |
| conversation.append(("Narrator", narration)) | |
| return conversation, "" | |
| def take_turn_with_error_handling(conversation, current_index, char1, char2, char3, char4, api_endpoint, | |
| api_key, temperature, scenario): | |
| try: | |
| new_conversation, next_index = take_turn(conversation, current_index, char1, char2, char3, char4, | |
| api_endpoint, api_key, temperature, scenario) | |
| return new_conversation, next_index, gr.update(visible=False, value="") | |
| except Exception as e: | |
| error_message = f"An error occurred: {str(e)}" | |
| return conversation, current_index, gr.update(visible=True, value=error_message) | |
| # Define States for conversation_id and media_content, which are required for saving chat history | |
| media_content = gr.State({}) | |
| conversation_id = gr.State(str(uuid.uuid4())) | |
| next_turn_btn.click( | |
| take_turn_with_error_handling, | |
| inputs=[chat_display, current_index] + character_selectors + [api_endpoint, api_key, temperature, | |
| scenario], | |
| outputs=[chat_display, current_index, error_box] | |
| ) | |
| add_narration_btn.click( | |
| add_narration, | |
| inputs=[narrator_input, chat_display], | |
| outputs=[chat_display, narrator_input] | |
| ) | |
| reset_btn.click( | |
| reset_conversation, | |
| outputs=[chat_display, current_index, scenario, narrator_input] | |
| ) | |
| # FIXME - Implement saving chat history to database; look at Chat_UI.py for reference | |
| save_chat_history_to_db.click( | |
| save_chat_history_to_db_wrapper, | |
| inputs=[chat_display, conversation_id, media_content, chat_media_name], | |
| outputs=[conversation_id, gr.Textbox(label="Save Status")] | |
| ) | |
| return character_interaction | |
| # | |
| # End of Multi-Character chat tab | |
| ######################################################################################################################## | |
| # | |
| # Narrator-Controlled Conversation Tab | |
| # From `Fuzzlewumper` on Reddit. | |
| def create_narrator_controlled_conversation_tab(): | |
| try: | |
| default_value = None | |
| if default_api_endpoint: | |
| if default_api_endpoint in global_api_endpoints: | |
| default_value = format_api_name(default_api_endpoint) | |
| else: | |
| logging.warning(f"Default API endpoint '{default_api_endpoint}' not found in global_api_endpoints") | |
| except Exception as e: | |
| logging.error(f"Error setting default API endpoint: {str(e)}") | |
| default_value = None | |
| with gr.TabItem("Narrator-Controlled Conversation", visible=True): | |
| gr.Markdown("# Narrator-Controlled Conversation") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| # Refactored API selection dropdown | |
| api_endpoint = gr.Dropdown( | |
| choices=["None"] + [format_api_name(api) for api in global_api_endpoints], | |
| value=default_value, | |
| label="API for Chat Interaction (Optional)" | |
| ) | |
| api_key = gr.Textbox(label="API Key (if required)", type="password") | |
| temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=1.0, step=0.1, value=0.7) | |
| with gr.Column(scale=2): | |
| narrator_input = gr.Textbox( | |
| label="Narrator Input", | |
| placeholder="Set the scene or provide context...", | |
| lines=3 | |
| ) | |
| character_inputs = [] | |
| for i in range(4): # Allow up to 4 characters | |
| with gr.Row(): | |
| name = gr.Textbox(label=f"Character {i + 1} Name") | |
| description = gr.Textbox(label=f"Character {i + 1} Description", lines=3) | |
| character_inputs.append((name, description)) | |
| conversation_display = gr.Chatbot(label="Conversation", height=400) | |
| user_input = gr.Textbox(label="Your Input (optional)", placeholder="Add your own dialogue or action...") | |
| with gr.Row(): | |
| generate_btn = gr.Button("Generate Next Interaction") | |
| reset_btn = gr.Button("Reset Conversation") | |
| chat_media_name = gr.Textbox(label="Custom Chat Name(optional)", visible=True) | |
| save_chat_history_to_db = gr.Button("Save Chat History to DataBase") | |
| error_box = gr.Textbox(label="Error Messages", visible=False) | |
| # Define States for conversation_id and media_content, which are required for saving chat history | |
| conversation_id = gr.State(str(uuid.uuid4())) | |
| media_content = gr.State({}) | |
| def generate_interaction(conversation, narrator_text, user_text, api_endpoint, api_key, temperature, | |
| *character_data): | |
| try: | |
| characters = [{"name": name.strip(), "description": desc.strip()} | |
| for name, desc in zip(character_data[::2], character_data[1::2]) | |
| if name.strip() and desc.strip()] | |
| if not characters: | |
| raise ValueError("At least one character must be defined.") | |
| prompt = f"Narrator: {narrator_text}\n\n" | |
| for char in characters: | |
| prompt += f"Character '{char['name']}': {char['description']}\n" | |
| prompt += "\nGenerate the next part of the conversation, including character dialogues and actions. " | |
| prompt += "Characters should speak in first person. " | |
| if user_text: | |
| prompt += f"\nIncorporate this user input: {user_text}" | |
| prompt += "\nResponse:" | |
| response, conversation, _ = chat_wrapper(prompt, conversation, {}, [], api_endpoint, api_key, "", None, | |
| False, temperature, "") | |
| # Format the response | |
| formatted_lines = [] | |
| for line in response.split('\n'): | |
| if ':' in line: | |
| speaker, text = line.split(':', 1) | |
| formatted_lines.append(f"**{speaker.strip()}**: {text.strip()}") | |
| else: | |
| formatted_lines.append(line) | |
| formatted_response = '\n'.join(formatted_lines) | |
| # Update the last message in the conversation with the formatted response | |
| if conversation: | |
| conversation[-1] = (conversation[-1][0], formatted_response) | |
| else: | |
| conversation.append((None, formatted_response)) | |
| return conversation, gr.update(value=""), gr.update(value=""), gr.update(visible=False, value="") | |
| except Exception as e: | |
| error_message = f"An error occurred: {str(e)}" | |
| return conversation, gr.update(), gr.update(), gr.update(visible=True, value=error_message) | |
| def reset_conversation(): | |
| return [], gr.update(value=""), gr.update(value=""), gr.update(visible=False, value="") | |
| generate_btn.click( | |
| generate_interaction, | |
| inputs=[conversation_display, narrator_input, user_input, api_endpoint, api_key, temperature] + | |
| [input for char_input in character_inputs for input in char_input], | |
| outputs=[conversation_display, narrator_input, user_input, error_box] | |
| ) | |
| reset_btn.click( | |
| reset_conversation, | |
| outputs=[conversation_display, narrator_input, user_input, error_box] | |
| ) | |
| # FIXME - Implement saving chat history to database; look at Chat_UI.py for reference | |
| save_chat_history_to_db.click( | |
| save_chat_history_to_db_wrapper, | |
| inputs=[conversation_display, conversation_id, media_content, chat_media_name], | |
| outputs=[conversation_id, gr.Textbox(label="Save Status")] | |
| ) | |
| return api_endpoint, api_key, temperature, narrator_input, conversation_display, user_input, generate_btn, reset_btn, error_box | |
| # | |
| # End of Narrator-Controlled Conversation tab | |
| ######################################################################################################################## |