import os import re from dotenv import load_dotenv, find_dotenv import json import gradio as gr import torch # first import torch then transformers from torch.nn.functional import softmax from transformers import AutoModelForSequenceClassification from huggingface_hub import InferenceClient from transformers import pipeline from huggingface_hub import login from transformers import AutoTokenizer, AutoModelForCausalLM import logging import sys from datetime import datetime import psutil from typing import Dict, Any, Optional, Tuple # # Add model caching and optimization # from functools import lru_cache # import torch.nn as nn # Custom tprint function with timestamp def tprint(*args, **kwargs): timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") print(f"[{timestamp}] [{sys._getframe().f_back.f_lineno}]", *args, **kwargs) # Configure logging with timestamp and line numbers logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] [%(filename)s:%(lineno)d] %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger(__name__) def get_available_memory(): """Get available GPU and system memory""" gpu_memory = None if torch.cuda.is_available(): gpu_memory = torch.cuda.get_device_properties(0).total_memory system_memory = psutil.virtual_memory().available return gpu_memory, system_memory def load_env(): _ = load_dotenv(find_dotenv()) def get_huggingface_api_key(): load_env() huggingface_api_key = os.getenv("HUGGINGFACE_API_KEY") if not huggingface_api_key: logging.error("HUGGINGFACE_API_KEY not found in environment variables") raise ValueError("HUGGINGFACE_API_KEY not found in environment variables") return huggingface_api_key def get_huggingface_inference_key(): load_env() huggingface_inference_key = os.getenv("HUGGINGFACE_INFERENCE_KEY") if not huggingface_inference_key: logging.error("HUGGINGFACE_API_KEY not found in environment variables") raise ValueError("HUGGINGFACE_API_KEY not found in environment variables") return huggingface_inference_key # Model configuration MODEL_CONFIG = { "main_model": { # "name": "meta-llama/Llama-3.2-3B-Instruct", # "name": "meta-llama/Llama-3.2-1B-Instruct", # to fit in cpu on hugging face space "name": "meta-llama/Llama-3.2-1B", # to fit in cpu on hugging face space # "name": "TinyLlama/TinyLlama-1.1B-Chat-v1.0", # to fit in cpu on hugging face space # "name": "microsoft/phi-2", # "dtype": torch.bfloat16, "dtype": torch.float32, # Use float32 for CPU "max_length": 512, "device": "cuda" if torch.cuda.is_available() else "cpu", }, "safety_model": { "name": "meta-llama/Llama-Guard-3-1B", # "dtype": torch.bfloat16, "dtype": torch.float32, # Use float32 for CPU "max_length": 256, "device": "cuda" if torch.cuda.is_available() else "cpu", "max_tokens": 500, }, } PROMPT_GUARD_CONFIG = { "model_id": "meta-llama/Prompt-Guard-86M", "temperature": 1.0, "jailbreak_threshold": 0.5, "injection_threshold": 0.9, "device": "cpu", "safe_commands": [ "look around", "investigate", "explore", "search", "examine", "take", "use", "go", "walk", "continue", "help", "inventory", "quest", "status", "map", "talk", "fight", "run", "hide", ], "max_length": 512, } def initialize_prompt_guard(): """Initialize Prompt Guard model""" try: api_key = get_huggingface_api_key() login(token=api_key) tokenizer = AutoTokenizer.from_pretrained(PROMPT_GUARD_CONFIG["model_id"]) model = AutoModelForSequenceClassification.from_pretrained( PROMPT_GUARD_CONFIG["model_id"] ) return model, tokenizer except Exception as e: logger.error(f"Failed to initialize Prompt Guard: {e}") raise def get_class_probabilities(text: str, guard_model, guard_tokenizer) -> torch.Tensor: """Evaluate model probabilities with temperature scaling""" try: inputs = guard_tokenizer( text, return_tensors="pt", padding=True, truncation=True, max_length=PROMPT_GUARD_CONFIG["max_length"], ).to(PROMPT_GUARD_CONFIG["device"]) with torch.no_grad(): logits = guard_model(**inputs).logits scaled_logits = logits / PROMPT_GUARD_CONFIG["temperature"] return softmax(scaled_logits, dim=-1) except Exception as e: logger.error(f"Error getting class probabilities: {e}") return None def get_jailbreak_score(text: str, guard_model, guard_tokenizer) -> float: """Get jailbreak probability score""" try: probabilities = get_class_probabilities(text, guard_model, guard_tokenizer) if probabilities is None: return 1.0 # Fail safe return probabilities[0, 2].item() except Exception as e: logger.error(f"Error getting jailbreak score: {e}") return 1.0 def get_injection_score(text: str, guard_model, guard_tokenizer) -> float: """Get injection probability score""" try: probabilities = get_class_probabilities(text, guard_model, guard_tokenizer) if probabilities is None: return 1.0 # Fail safe return (probabilities[0, 1] + probabilities[0, 2]).item() except Exception as e: logger.error(f"Error getting injection score: {e}") return 1.0 # Initialize safety model pipeline try: # Initialize Prompt Guard guard_model, guard_tokenizer = initialize_prompt_guard() except Exception as e: logger.error(f"Failed to initialize model: {str(e)}") def is_prompt_safe(message: str) -> bool: """Enhanced safety check with Prompt Guard""" try: # Allow safe game commands if any(cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"]): logger.info("Message matched safe command pattern") return True # Get safety scores jailbreak_score = get_jailbreak_score(message, guard_model, guard_tokenizer) injection_score = get_injection_score(message, guard_model, guard_tokenizer) logger.info( f"Safety scores - Jailbreak: {jailbreak_score}, Injection: {injection_score}" ) # Check against thresholds is_safe = ( jailbreak_score < PROMPT_GUARD_CONFIG["jailbreak_threshold"] # and injection_score < PROMPT_GUARD_CONFIG["injection_threshold"] # Disable for now because injection is too strict and current prompt guard model seems malfunctioning for now. ) logger.info(f"Final safety result: {is_safe}") return is_safe except Exception as e: logger.error(f"Safety check failed: {e}") return False # def initialize_model_pipeline(model_name, force_cpu=False): # """Initialize pipeline with memory management""" # try: # if force_cpu: # device = -1 # else: # device = MODEL_CONFIG["main_model"]["device"] # api_key = get_huggingface_api_key() # # Use 8-bit quantization for memory efficiency # model = AutoModelForCausalLM.from_pretrained( # model_name, # load_in_8bit=False, # torch_dtype=MODEL_CONFIG["main_model"]["dtype"], # use_cache=True, # device_map="auto", # low_cpu_mem_usage=True, # trust_remote_code=True, # token=api_key, # Add token here # ) # model.config.use_cache = True # tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key) # # Initialize pipeline # logger.info(f"Initializing pipeline with device: {device}") # generator = pipeline( # "text-generation", # model=model, # tokenizer=tokenizer, # # device=device, # # temperature=0.7, # model_kwargs={"low_cpu_mem_usage": True}, # ) # logger.info("Model Pipeline initialized successfully") # return generator, tokenizer # except ImportError as e: # logger.error(f"Missing required package: {str(e)}") # raise # except Exception as e: # logger.error(f"Failed to initialize pipeline: {str(e)}") # raise # # Initialize model pipeline # try: # # Use a smaller model for testing # # model_name = "meta-llama/Meta-Llama-3-8B-Instruct" # # model_name = "google/gemma-2-2b" # Start with a smaller model # # model_name = "microsoft/phi-2" # # model_name = "meta-llama/Llama-3.2-1B-Instruct" # # model_name = "meta-llama/Llama-3.2-3B-Instruct" # model_name = MODEL_CONFIG["main_model"]["name"] # # Initialize the pipeline with memory management # generator, tokenizer = initialize_model_pipeline(model_name) # except Exception as e: # logger.error(f"Failed to initialize model: {str(e)}") # # Fallback to CPU if GPU initialization fails # try: # logger.info("Attempting CPU fallback...") # generator, tokenizer = initialize_model_pipeline(model_name, force_cpu=True) # except Exception as e: # logger.error(f"CPU fallback failed: {str(e)}") # raise def initialize_inference_client(): """Initialize HuggingFace Inference Client""" try: inference_key = get_huggingface_inference_key() client = InferenceClient(api_key=inference_key) logger.info("Inference Client initialized successfully") return client except Exception as e: logger.error(f"Failed to initialize Inference Client: {e}") raise # Initialize inference client and make API call try: inference_client = initialize_inference_client() except Exception as e: logger.error(f"Failed to initialize the inference client model: {str(e)}") def load_world(filename): with open(filename, "r") as f: return json.load(f) # Define system_prompt and model system_prompt = """You are an AI Game master. Your job is to write what happens next in a player's adventure game. CRITICAL Rules: - Write EXACTLY 3 sentences maximum - Use daily English language - Start with "You " - Don't use 'Elara' or 'she/he', only use 'you' - Use only second person ("you") - Never include dialogue after the response - Never continue with additional actions or responses - Never add follow-up questions or choices - Never include 'User:' or 'Assistant:' in response - Never include any note or these kinds of sentences: 'Note from the game master' - Never use ellipsis (...) - Never include 'What would you like to do?' or similar prompts - Always finish with one real response - Never use 'Your turn' or or anything like conversation starting prompts - Always end the response with a period(.)""" def get_game_state(inventory: Dict = None) -> Dict[str, Any]: """Initialize game state with safe defaults and quest system""" try: # Load world data world = load_world("shared_data/Ethoria.json") character = world["kingdoms"]["Valdor"]["towns"]["Ravenhurst"]["npcs"][ "Elara Brightshield" ] tprint(f"character in get_game_state: {character}") game_state = { "name": world["name"], "world": world["description"], "kingdom": world["kingdoms"]["Valdor"]["description"], "town_name": world["kingdoms"]["Valdor"]["towns"]["Ravenhurst"]["name"], "town": world["kingdoms"]["Valdor"]["towns"]["Ravenhurst"]["description"], "character_name": character["name"], "character_description": character["description"], "start": world["start"], "inventory": inventory or { "cloth pants": 1, "cloth shirt": 1, "goggles": 1, "leather bound journal": 1, "gold": 5, }, "player": None, "dungeon": None, "current_quest": None, "completed_quests": [], "exp": 0, "level": 1, "reputation": {"Valdor": 0, "Ravenhurst": 0}, } # tprint(f"game_state in get_game_state: {game_state}") # Extract required data with fallbacks return game_state except (FileNotFoundError, KeyError, json.JSONDecodeError) as e: logger.error(f"Error loading world data: {e}") # Provide default values if world loading fails return { "world": "Ethoria is a realm of seven kingdoms, each founded on distinct moral principles.", "kingdom": "Valdor, the Kingdom of Courage", "town": "Ravenhurst, a town of skilled hunters and trappers", "character_name": "Elara Brightshield", "character_description": "A sturdy warrior with shining silver armor", "start": "Your journey begins in the mystical realm of Ethoria...", "inventory": inventory or { "cloth pants": 1, "cloth shirt": 1, "goggles": 1, "leather bound journal": 1, "gold": 5, }, "player": None, "dungeon": None, "current_quest": None, "completed_quests": [], "exp": 0, "level": 1, "reputation": {"Valdor": 0, "Ravenhurst": 0}, } def generate_dynamic_quest(game_state: Dict) -> Dict: """Generate varied quests based on progress and level""" completed = len(game_state.get("completed_quests", [])) level = game_state.get("level", 1) # Quest templates by type quest_types = { "combat": [ { "title": "The Beast's Lair", "description": "A fearsome {creature} has been terrorizing the outskirts of Ravenhurst.", "objective": "Hunt down and defeat the {creature}.", "creatures": [ "shadow wolf", "frost bear", "ancient wyrm", "spectral tiger", ], }, ], "exploration": [ { "title": "Lost Secrets", "description": "Rumors speak of an ancient {location} containing powerful artifacts.", "objective": "Explore the {location} and uncover its secrets.", "locations": [ "crypt", "temple ruins", "abandoned mine", "forgotten library", ], }, ], "mystery": [ { "title": "Dark Omens", "description": "The {sign} has appeared, marking the rise of an ancient power.", "objective": "Investigate the meaning of the {sign}.", "signs": [ "blood moon", "mysterious runes", "spectral lights", "corrupted wildlife", ], }, ], } # Select quest type and template quest_type = list(quest_types.keys())[completed % len(quest_types)] template = quest_types[quest_type][0] # Could add more templates per type # Fill in dynamic elements if quest_type == "combat": creature = template["creatures"][level % len(template["creatures"])] title = template["title"] description = template["description"].format(creature=creature) objective = template["objective"].format(creature=creature) elif quest_type == "exploration": location = template["locations"][level % len(template["locations"])] title = template["title"] description = template["description"].format(location=location) objective = template["objective"].format(location=location) else: # mystery sign = template["signs"][level % len(template["signs"])] title = template["title"] description = template["description"].format(sign=sign) objective = template["objective"].format(sign=sign) return { "id": f"quest_{quest_type}_{completed}", "title": title, "description": f"{description} {objective}", "exp_reward": 150 + (level * 50), "status": "active", "triggers": ["investigate", "explore", quest_type, "search"], "completion_text": f"You've made progress in understanding the growing darkness.", "next_quest_hint": "More mysteries await in the shadows of Ravenhurst.", } def generate_next_quest(game_state: Dict) -> Dict: """Generate next quest based on progress""" completed = len(game_state.get("completed_quests", [])) level = game_state.get("level", 1) quest_chain = [ { "id": "mist_investigation", "title": "Investigate the Mist", "description": "Strange mists have been gathering around Ravenhurst. Investigate their source.", "exp_reward": 100, "status": "active", "triggers": ["mist", "investigate", "explore"], "completion_text": "As you investigate the mist, you discover ancient runes etched into nearby stones.", "next_quest_hint": "The runes seem to point to an old hunting trail.", }, { "id": "hunters_trail", "title": "The Hunter's Trail", "description": "Local hunters have discovered strange tracks in the forest. Follow them to their source.", "exp_reward": 150, "status": "active", "triggers": ["tracks", "follow", "trail"], "completion_text": "The tracks lead to an ancient well, where you hear strange whispers.", "next_quest_hint": "The whispers seem to be coming from deep within the well.", }, { "id": "dark_whispers", "title": "Whispers in the Dark", "description": "Mysterious whispers echo from the old well. Investigate their source.", "exp_reward": 200, "status": "active", "triggers": ["well", "whispers", "listen"], "completion_text": "You discover an ancient seal at the bottom of the well.", "next_quest_hint": "The seal bears markings of an ancient evil.", }, ] # Generate dynamic quests after initial chain if completed >= len(quest_chain): return generate_dynamic_quest(game_state) # current_quest_index = min(completed, len(quest_chain) - 1) # return quest_chain[current_quest_index] return quest_chain[completed] def check_quest_completion(message: str, game_state: Dict) -> Tuple[bool, str]: """Check quest completion and handle progression""" if not game_state.get("current_quest"): return False, "" quest = game_state["current_quest"] triggers = quest.get("triggers", []) if any(trigger in message.lower() for trigger in triggers): # Award experience exp_reward = quest.get("exp_reward", 100) game_state["exp"] += exp_reward # Update player level if needed while game_state["exp"] >= 100 * game_state["level"]: game_state["level"] += 1 game_state["player"].level = ( game_state["level"] if game_state.get("player") else game_state["level"] ) level_up_text = ( f"\nLevel Up! You are now level {game_state['level']}!" if game_state["exp"] >= 100 * (game_state["level"] - 1) else "" ) # Store completed quest game_state["completed_quests"].append(quest) # Generate next quest next_quest = generate_next_quest(game_state) game_state["current_quest"] = next_quest # Update status display if game_state.get("player"): game_state["player"].exp = game_state["exp"] game_state["player"].level = game_state["level"] # Build completion message completion_msg = f""" Quest Complete: {quest['title']}! (+{exp_reward} exp){level_up_text} {quest.get('completion_text', '')} New Quest: {next_quest['title']} {next_quest['description']} {next_quest.get('next_quest_hint', '')}""" return True, completion_msg return False, "" def parse_items_from_story(text: str) -> Dict[str, int]: """Extract item changes from story text with improved pattern matching""" items = {} # Skip parsing if text starts with common narrative phrases skip_patterns = [ "you see", "you find yourself", "you are", "you stand", "you hear", "you feel", ] if any(text.lower().startswith(pattern) for pattern in skip_patterns): return items # Common item keywords and patterns gold_pattern = r"(\d+)\s*gold(?:\s+coins?)?" items_pattern = r"(?:receive|find|given|obtain|pick up|grab)\s+(?:a|an|the)?\s*(\d+)?\s*([\w\s]+?)" try: # Find gold amounts gold_matches = re.findall(gold_pattern, text.lower()) if gold_matches: items["gold"] = sum(int(x) for x in gold_matches) # Find other items item_matches = re.findall(items_pattern, text.lower()) for count, item in item_matches: # Validate item name item = item.strip() if len(item) > 2 and not any( # Minimum length check skip in item for skip in ["yourself", "you", "door", "wall", "floor"] ): # Skip common words count = int(count) if count else 1 if item in items: items[item] += count else: items[item] = count return items except Exception as e: logger.error(f"Error parsing items from story: {e}") return {} def update_game_inventory(game_state: Dict, story_text: str) -> Tuple[str, list]: """Update inventory and return message and updated inventory data""" try: items = parse_items_from_story(story_text) update_msg = "" # Update inventory for item, count in items.items(): if item in game_state["inventory"]: game_state["inventory"][item] += count else: game_state["inventory"][item] = count update_msg += f"\nReceived: {count} {item}" # Create updated inventory data for display inventory_data = [ [item, count] for item, count in game_state["inventory"].items() ] return update_msg, inventory_data except Exception as e: logger.error(f"Error updating inventory: {e}") return "", [] def extract_response_after_action(full_text: str, action: str) -> str: """Extract response text that comes after the user action line""" try: if not full_text: # Add null check logger.error("Received empty response from model") return "You look around carefully." # Split into lines lines = full_text.split("\n") # Find index of line containing user action action_line_index = -1 for i, line in enumerate(lines): if action.lower() in line.lower(): # More flexible matching action_line_index = i break if action_line_index >= 0: # Get all lines after the action line response_lines = lines[action_line_index + 1 :] response = " ".join(line.strip() for line in response_lines if line.strip()) # Clean up any remaining markers response = response.split("user:")[0].strip() response = response.split("system:")[0].strip() response = response.split("assistant:")[0].strip() return response if response else "You look around carefully." return "You look around carefully." # Default response except Exception as e: logger.error(f"Error extracting response: {e}") return "You look around carefully." def run_action(message: str, history: list, game_state: Dict) -> str: """Process game actions and generate responses with quest handling""" try: initial_quest = generate_next_quest(game_state) game_state["current_quest"] = initial_quest # Handle start game command if message.lower() == "start game": start_response = f"""Welcome to {game_state['name']}. {game_state['world']} {game_state['start']} You are currently in {game_state['town_name']}, {game_state['town']}. {game_state['town_name']} is a city in {game_state['kingdom']}. Current Quest: {initial_quest['title']} {initial_quest['description']} What would you like to do?""" return start_response # Verify game state if not isinstance(game_state, dict): logger.error(f"Invalid game state type: {type(game_state)}") return "Error: Invalid game state" # Safety check with Prompt Guard if not is_prompt_safe(message): logger.warning("Unsafe content detected in user prompt") return "I cannot process that request for safety reasons." # logger.info(f"Processing action with game state: {game_state}") logger.info(f"Processing action with game state") world_info = f"""World: {game_state['world']} Kingdom: {game_state['kingdom']} Town: {game_state['town']} Character: {game_state['character_name']} Current Quest: {game_state["current_quest"]['title']} Quest Objective: {game_state["current_quest"]['description']} Inventory: {json.dumps(game_state['inventory'])}""" # # Enhanced system prompt for better response formatting # enhanced_prompt = f"""{system_prompt} # Additional Rules: # - Always start responses with 'You ', 'You see' or 'You hear' or 'You feel' # - Use ONLY second person perspective ('you', not 'Elara' or 'she/he') # - Describe immediate surroundings and sensations # - Keep responses focused on the player's direct experience""" # messages = [ # {"role": "system", "content": system_prompt}, # {"role": "user", "content": world_info}, # ] # Properly formatted messages for API messages = [ {"role": "system", "content": system_prompt}, {"role": "user", "content": world_info}, { "role": "assistant", "content": "I understand the game world and will help guide your adventure.", }, {"role": "user", "content": message}, ] # # Format chat history # if history: # for h in history: # if isinstance(h, tuple): # messages.append({"role": "assistant", "content": h[0]}) # messages.append({"role": "user", "content": h[1]}) # Add history in correct alternating format if history: # for h in history[-3:]: # Last 3 exchanges for h in history: if isinstance(h, tuple): messages.append({"role": "user", "content": h[0]}) messages.append({"role": "assistant", "content": h[1]}) # messages.append({"role": "user", "content": message}) # Convert messages to string format for pipeline prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) logger.info("Generating response...") ## Generate response # model_output = generator( # prompt, # max_new_tokens=len(tokenizer.encode(message)) # + 120, # Set max_new_tokens based on input length # num_return_sequences=1, # # temperature=0.7, # More creative but still focused # repetition_penalty=1.2, # pad_token_id=tokenizer.eos_token_id, # ) # # Check for None response # if not model_output or not isinstance(model_output, list): # logger.error(f"Invalid model output: {model_output}") # tprint(f"Invalid model output: {model_output}") # return "You look around carefully." # if not model_output[0] or not isinstance(model_output[0], dict): # logger.error(f"Invalid response format: {type(model_output[0])}") # return "You look around carefully." # # Extract and clean response # full_response = model_output[0]["generated_text"] # if not full_response: # logger.error("Empty response from model") # return "You look around carefully." # tprint(f"Full response in run_action: {full_response}") # response = extract_response_after_action(full_response, message) # tprint(f"Extracted response in run_action: {response}") # # Convert to second person # response = response.replace("Elara", "You") # # # Format response # # if not response.startswith("You"): # # response = "You see " + response # # Validate no cut-off sentences # if response.rstrip().endswith(("you also", "meanwhile", "suddenly", "...")): # response = response.rsplit(" ", 1)[0] # Remove last word # # Ensure proper formatting # response = response.rstrip("?").rstrip(".") + "." # response = response.replace("...", ".") # Initialize client and make API call # client = initialize_inference_client() client = inference_client # Generate response using Inference API completion = client.chat.completions.create( model="mistralai/Mistral-7B-Instruct-v0.3", # Use inference API model messages=messages, max_tokens=520, ) response = completion.choices[0].message.content tprint(f"Generated response Inference API: {response}") if not response: return "You look around carefully." # Safety check the responce using inference API if not is_safe(response): logger.warning("Unsafe content detected - blocking response") return "This response was blocked for safety reasons." # # Perform safety check before returning # safe = is_safe(response) # tprint(f"\nSafety Check Result: {'SAFE' if safe else 'UNSAFE'}") # logger.info(f"Safety check result: {'SAFE' if safe else 'UNSAFE'}") # if not safe: # logging.warning("Unsafe content detected - blocking response") # tprint("Unsafe content detected - Response blocked") # return "This response was blocked for safety reasons." # if safe: # # Check for quest completion # quest_completed, quest_message = check_quest_completion(message, game_state) # if quest_completed: # response += quest_message # # Check for item updates # inventory_update = update_game_inventory(game_state, response) # if inventory_update: # response += inventory_update # Check for quest completion quest_completed, quest_message = check_quest_completion(message, game_state) if quest_completed: response += quest_message # Check for item-inventory updates inventory_update, inventory_data = update_game_inventory(game_state, response) if inventory_update: response += inventory_update tprint(f"Final response in run_action: {response}") # Validate response return response if response else "You look around carefully." except KeyError as e: logger.error(f"Missing required game state key: {e}") return "Error: Game state is missing required information" except Exception as e: logger.error(f"Error generating response: {e}") return ( "I apologize, but I had trouble processing that command. Please try again." ) def update_game_status(game_state: Dict) -> Tuple[str, str]: """Generate updated status and quest display text""" # Status text status_text = ( f"Health: {game_state.get('player').health if game_state.get('player') else 100}/100\n" f"Level: {game_state.get('level', 1)}\n" f"Exp: {game_state.get('exp', 0)}/{100 * game_state.get('level', 1)}" ) # Quest text quest_text = "No active quest" if game_state.get("current_quest"): quest = game_state["current_quest"] quest_text = f"{quest['title']}\n{quest['description']}" if quest.get("next_quest_hint"): quest_text += f"\n{quest['next_quest_hint']}" return status_text, quest_text def chat_response(message: str, chat_history: list, current_state: dict) -> tuple: """Process chat input and return response with updates""" try: if not message.strip(): return chat_history, current_state, "", "", [] # Add empty inventory data # Get AI response output = run_action(message, chat_history, current_state) # Update chat history without status info chat_history = chat_history or [] chat_history.append((message, output)) # Update status displays status_text, quest_text = update_game_status(current_state) # Get inventory updates update_msg, inventory_data = update_game_inventory(current_state, output) if update_msg: output += update_msg # Return tuple includes empty string to clear input return chat_history, current_state, status_text, quest_text, inventory_data except Exception as e: logger.error(f"Error in chat response: {e}") return chat_history, current_state, "", "", [] def start_game(main_loop, game_state, share=False): """Initialize and launch game interface""" with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("# AI Dungeon Adventure") # Game state storage state = gr.State(game_state) history = gr.State([]) with gr.Row(): # Game display with gr.Column(scale=3): chatbot = gr.Chatbot( height=550, placeholder="Type 'start game' to begin", ) # Input area with submit button with gr.Row(): txt = gr.Textbox( show_label=False, placeholder="What do you want to do?", container=False, ) submit_btn = gr.Button("Submit", variant="primary") clear = gr.ClearButton([txt, chatbot]) # Enhanced Status panel with gr.Column(scale=1): with gr.Group(): gr.Markdown("### Character Status") status = gr.Textbox( label="Status", value="Health: 100/100\nLevel: 1\nExp: 0/100", interactive=False, ) quest_display = gr.Textbox( label="Current Quest", value="No active quest", interactive=False, ) inventory_data = [ [item, count] for item, count in game_state.get("inventory", {}).items() ] inventory = gr.Dataframe( value=inventory_data, headers=["Item", "Quantity"], label="Inventory", interactive=False, ) # Command suggestions gr.Examples( examples=[ "look around", "continue the story", "take sword", "go to the forest", ], inputs=txt, ) # def chat_response( # message: str, chat_history: list, current_state: dict # ) -> tuple: # """Process chat input and return response with updates""" # try: # if not message.strip(): # return chat_history, current_state, "" # Only clear input # # Get AI response # output = run_action(message, chat_history, current_state) # # Update chat history # chat_history = chat_history or [] # chat_history.append((message, output)) # # Update status if player exists # # Update displays # status_text = ( # f"Health: {current_state['player'].health}/{current_state['player'].max_health}\n" # f"Level: {current_state['player'].level}\n" # f"Exp: {current_state['player'].exp}/{current_state['player'].exp_to_level}" # ) # quest_text = "No active quest" # if current_state["current_quest"]: # quest = current_state["current_quest"] # quest_text = f"{quest['title']}\n{quest['description']}" # # Update inventory display # inv_data = [ # [item, count] # for item, count in current_state.get("inventory", {}).items() # ] # return (chat_history, current_state, status_text, quest_text) # except Exception as e: # logger.error(f"Error in chat response: {e}") # return ( # chat_history + [(message, "Error processing command")], # current_state, # status.value, # quest_display.value, # ) # # Connect components # submit_btn.click( # chat_response, # inputs=[txt, chatbot, state], # # outputs=[chatbot, state, txt], # outputs=[chatbot, state, status, quest_display], # show_progress=True, # ) # # txt.submit( # # chat_response, # # [txt, chatbot, state], # # [chatbot, state, txt], # # show_progress=True, # # ) # txt.submit( # chat_response, # inputs=[txt, chatbot, state], # outputs=[chatbot, state, status, quest_display], # Keep status separate # show_progress=True, # ) def submit_action(message, history, state): # Process response new_history, new_state, status_text, quest_text, inventory_data = ( chat_response(message, history, state) ) # Update inventory display inventory.value = inventory_data # Clear input return "", new_history, new_state, status_text, quest_text, inventory submit_btn.click( submit_action, inputs=[txt, chatbot, state], outputs=[txt, chatbot, state, status, quest_display, inventory], ) txt.submit( submit_action, inputs=[txt, chatbot, state], outputs=[txt, chatbot, state, status, quest_display, inventory], ) demo.launch(share=share) everyone_content_policy = { "policy": """O1: Violence and Hate. Should not - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics. - Use slurs based on personal characteristics - Write gruesome / graphic violence. Can - Write tame fantasy violence O2: Sexual Content. Should not - Write any somewhat sexual content that mentions sexual functions or sexual body parts. - Write any content that would be inappropriate for kids. Can - Write romantic content (including kissing, flirting etc...) O3: Self-Harm. Should not - Encourage people to harm themselves. - Romanticize or condone self-harm or suicide. - Write story depicting suicide. O4: Profanity. Should not - Write with any profane language that wouldn't be appropriate for kids. Can - Write with kid appropriate language O5: Regulated or Controlled Substances. Should not - Write stories that depicts, glorifies or encourages drug use.""" } def initialize_safety_client(): """Initialize HuggingFace Inference Client""" try: inference_key = get_huggingface_inference_key() # api_key = get_huggingface_api_key() return InferenceClient(api_key=inference_key) except Exception as e: logger.error(f"Failed to initialize safety client: {e}") raise # Initialize safety model pipeline try: safety_client = initialize_safety_client() except Exception as e: logger.error(f"Failed to initialize model: {str(e)}") # Fallback to CPU if GPU initialization fails def is_safe(message: str) -> bool: """Check content safety using Inference API""" try: # client = initialize_safety_client() messages = [ {"role": "user", "content": f"Check if this content is safe:\n{message}"}, { "role": "assistant", "content": f"I will check if the content is safe based on this content policy:\n{everyone_content_policy['policy']}", }, {"role": "user", "content": "Is it safe or unsafe?"}, ] try: completion = safety_client.chat.completions.create( model=MODEL_CONFIG["safety_model"]["name"], messages=messages, max_tokens=MODEL_CONFIG["safety_model"]["max_tokens"], temperature=0.1, ) response = completion.choices[0].message.content.lower() logger.info(f"Safety check response: {response}") is_safe = "safe" in response and "unsafe" not in response logger.info(f"Safety check result: {'SAFE' if is_safe else 'UNSAFE'}") return is_safe except Exception as api_error: logger.error(f"API error: {api_error}") # Fallback to allow common game commands return any( cmd in message.lower() for cmd in PROMPT_GUARD_CONFIG["safe_commands"] ) except Exception as e: logger.error(f"Safety check failed: {e}") return False # def init_safety_model(model_name, force_cpu=False): # """Initialize safety checking model with optimized memory usage""" # try: # if force_cpu: # device = -1 # else: # device = MODEL_CONFIG["safety_model"]["device"] # # model_id = "meta-llama/Llama-Guard-3-8B" # # model_id = "meta-llama/Llama-Guard-3-1B" # api_key = get_huggingface_api_key() # safety_model = AutoModelForCausalLM.from_pretrained( # model_name, # token=api_key, # torch_dtype=MODEL_CONFIG["safety_model"]["dtype"], # use_cache=True, # device_map="auto", # ) # safety_model.config.use_cache = True # safety_tokenizer = AutoTokenizer.from_pretrained(model_name, token=api_key) # # Set pad token explicitly # safety_tokenizer.pad_token = safety_tokenizer.eos_token # logger.info(f"Safety model initialized successfully on {device}") # return safety_model, safety_tokenizer # except Exception as e: # logger.error(f"Failed to initialize safety model: {e}") # raise # # Initialize safety model pipeline # try: # safety_model_name = MODEL_CONFIG["safety_model"]["name"] # api_key = get_huggingface_api_key() # # Initialize the pipeline with memory management # safety_model, safety_tokenizer = init_safety_model(safety_model_name) # except Exception as e: # logger.error(f"Failed to initialize model: {str(e)}") # # Fallback to CPU if GPU initialization fails # try: # logger.info("Attempting CPU fallback...") # safety_model, safety_tokenizer = init_safety_model( # safety_model_name, force_cpu=True # ) # except Exception as e: # logger.error(f"CPU fallback failed: {str(e)}") # raise # @lru_cache(maxsize=1000) # def get_safety_response(prompt: str) -> str: # """Cached safety check responses""" # with torch.no_grad(): # inputs = safety_tokenizer( # prompt, # return_tensors="pt", # padding=True, # truncation=True, # max_length=MODEL_CONFIG["safety_model"]["max_length"], # ).to(safety_model.device) # output = safety_model.generate( # **inputs, # max_new_tokens=5, # temperature=0.1, # Lower temperature for more consistent output # pad_token_id=safety_tokenizer.pad_token_id, # eos_token_id=safety_tokenizer.eos_token_id, # ) # return safety_tokenizer.decode(output[0], skip_special_tokens=True) # def is_safe(message: str) -> bool: # """Check if message content is safe""" # try: # logger.info("Performing safety check...") # # Build the prompt with embedded values # # prompt = f"""[INST] Task: Check if this content is safe according to our safety policy with the below categories. # # # # {everyone_content_policy['policy']} # # # # # # {message} # # # # Provide your safety assessment for user in the above CONTENT TO CHECK section: # # - Reply ONLY with single word 'safe' or 'unsafe' [/INST] # # """ # prompt = f"""[INST]Check if safe: # {message} # Reply 'safe' or 'unsafe'[/INST] # """ # # Generate safety check on cuda # # with torch.no_grad(): # # inputs = safety_tokenizer( # # prompt, # # return_tensors="pt", # # padding=True, # # truncation=True, # # ) # # # Move inputs to correct device # # inputs = {k: v.to(device) for k, v in inputs.items()} # # output = safety_model.generate( # # **inputs, # # max_new_tokens=10, # # temperature=0.1, # Lower temperature for more consistent output # # pad_token_id=safety_tokenizer.pad_token_id, # Use configured pad token # # eos_token_id=safety_tokenizer.eos_token_id, # # do_sample=False, # # ) # # result = safety_tokenizer.decode(output[0], skip_special_tokens=True) # result = get_safety_response(prompt) # tprint(f"Raw safety check result: {result}") # # # Extract response after prompt # # if "[/INST]" in result: # # result = result.split("[/INST]")[-1] # # # Clean response # # result = result.lower().strip() # # tprint(f"Cleaned safety check result: {result}") # # words = [word for word in result.split() if word in ["safe", "unsafe"]] # # # Take first valid response word # # is_safe = words[0] == "safe" if words else False # # tprint("Final Safety check result:", is_safe) # is_safe = "safe" in result.lower().split() # logger.info( # f"Safety check completed - Result: {'SAFE' if is_safe else 'UNSAFE'}" # ) # return is_safe # except Exception as e: # logger.error(f"Safety check failed: {e}") # return False # def detect_inventory_changes(game_state, output): # inventory = game_state["inventory"] # messages = [ # {"role": "system", "content": system_prompt}, # {"role": "user", "content": f"Current Inventory: {str(inventory)}"}, # {"role": "user", "content": f"Recent Story: {output}"}, # {"role": "user", "content": "Inventory Updates"}, # ] # input_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) # model_output = generator(input_text, num_return_sequences=1, temperature=0.0) # response = model_output[0]["generated_text"] # result = json.loads(response) # return result["itemUpdates"] # def update_inventory(inventory, item_updates): # update_msg = "" # for update in item_updates: # name = update["name"] # change_amount = update["change_amount"] # if change_amount > 0: # if name not in inventory: # inventory[name] = change_amount # else: # inventory[name] += change_amount # update_msg += f"\nInventory: {name} +{change_amount}" # elif name in inventory and change_amount < 0: # inventory[name] += change_amount # update_msg += f"\nInventory: {name} {change_amount}" # if name in inventory and inventory[name] < 0: # del inventory[name] # return update_msg logging.info("Finished helper function")