PeAI / app.py
MrA7A's picture
Create app.py
e26acf0 verified
# app.py (Strategic Agent Service for Hugging Face Spaces - CPU Only, Preload All Models, No ngrok)
import os
import json
import logging
import numpy as np
import requests
from fastapi import FastAPI, HTTPException, Depends, status
from pydantic import BaseModel, Field, constr
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from datetime import datetime
import firebase_admin
from firebase_admin import credentials, firestore
from bs4 import BeautifulSoup
import re
from typing import List, Dict, Optional, Tuple
from cachetools import TTLCache
import gc
from llama_cpp import Llama
import asyncio
import nest_asyncio
from fastapi.responses import JSONResponse # Added for explicit JSONResponse
# Apply nest_asyncio to allow running asyncio.run() in environments with existing event loops
nest_asyncio.apply()
# --- Configuration ---
# Downloaded GGUF model files are kept in this directory, relative to the
# working directory of the Space. It is created eagerly at import time.
DOWNLOAD_DIR = "./downloaded_models/"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# (display name, direct download URL) for every predefined Hugging Face GGUF
# model that can be loaded dynamically.
_MODEL_SOURCES = (
    (
        "Foundation-Sec-8B-Q8_0",
        "https://huggingface.co/fdtn-ai/Foundation-Sec-8B-Q8_0-GGUF/resolve/main/foundation-sec-8b-q8_0.gguf",
    ),
    (
        "Lily-Cybersecurity-7B-v0.2-Q8_0",
        "https://huggingface.co/Nekuromento/Lily-Cybersecurity-7B-v0.2-Q8_0-GGUF/resolve/main/lily-cybersecurity-7b-v0.2-q8_0.gguf",
    ),
    (
        "SecurityLLM-GGUF (sarvam-m-q8_0)",
        "https://huggingface.co/QuantFactory/SecurityLLM-GGUF/resolve/main/sarvam-m-q8_0.gguf",
    ),
)
# Public catalogue: a list of {"name": ..., "url": ...} dicts, in the order above.
HUGGINGFACE_MODELS = [
    {"name": model_name, "url": model_url}
    for model_name, model_url in _MODEL_SOURCES
]

# Local data directory (knowledge base file and embedding-model cache live here).
DATA_DIR = "./data"
# Time-to-live handed to the deep-search TTLCache.
DEEP_SEARCH_CACHE_TTL = 3600

# --- ngrok Configuration (Removed) ---
# NGROK_AUTH_TOKEN and NGROK_STRATEGIC_AGENT_TUNNEL_URL are removed
# --- Logging Setup ---
# Configure the root logger once at import time. DEBUG is deliberately
# verbose: every logger.debug(...) call in this module will be emitted.
logging.basicConfig(
    level=logging.DEBUG, # Changed from INFO to DEBUG
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every class and function below.
logger = logging.getLogger(__name__)
logger.info("Logging initialized with DEBUG level.")
# Initialize FastAPI app
# The title/description/version are metadata passed to the FastAPI constructor.
app = FastAPI(
    title="Hugging Face Strategic Agent Service",
    description="Provides knowledge base access and strategic reasoning for the pentest agent on Hugging Face Spaces.",
    version="1.0.0"
)
# Initialize Firebase
# `db` is the module-wide Firestore client, or None when Firebase could not be
# set up. Everything that touches Firestore checks `db` first, so a failure
# here degrades the service (no Firebase knowledge) instead of crashing it.
firebase_creds_path = os.getenv("FIREBASE_CREDS_PATH", "cred.json")
db = None
try:
    if firebase_admin._apps:
        # An app is already registered (for example this module was executed
        # again inside the same interpreter). Reuse it: previously `db` was
        # left as None in this case, silently disabling the knowledge base.
        db = firestore.client()
        logger.info("Firebase app already initialized; reusing existing Firestore client.")
    elif os.path.exists(firebase_creds_path):
        cred = credentials.Certificate(firebase_creds_path)
        firebase_admin.initialize_app(cred)
        db = firestore.client()
        logger.info("Firebase initialized successfully.")
    else:
        logger.warning(f"Firebase credentials file not found at {firebase_creds_path}. Firebase will not be initialized.")
except Exception as e:
    # Best-effort: any credential/SDK problem is logged and `db` stays None.
    db = None
    logger.error(f"Failed to initialize Firebase: {e}. Ensure FIREBASE_CREDS_PATH is set correctly and the file exists.", exc_info=True)
# Global LLM instance for Strategic Agent
# Both globals are assigned and cleared by
# StrategicAgentBrain.load_strategic_llm / unload_strategic_llm.
strategic_llm: Optional[Llama] = None
current_strategic_model_url: Optional[str] = None # Now tracks URL, not local path
# Supported tools (Strategic Agent needs to know these for command generation)
# parse_llm_response accepts a bare LLM reply as a command only when it starts
# with one of these names.
SUPPORTED_TOOLS = [
    "nmap", "gobuster", "nikto", "sqlmap", "adb", "frida",
    "drozer", "apktool", "msfconsole", "mobsfscan", "burpsuite",
    "metasploit", "curl", "wget", "hydra", "john", "aircrack-ng"
]
# --- Deep Search Cache ---
# Bounded cache (at most 100 entries) whose entries expire after
# DEEP_SEARCH_CACHE_TTL. NOTE(review): not referenced in this part of the
# file — presumably used by the API endpoints further down; confirm.
deep_search_cache = TTLCache(maxsize=100, ttl=DEEP_SEARCH_CACHE_TTL)
# --- Enhanced System Instruction (English) ---
# Fixed preamble that StrategicAgentBrain._generate_llm_prompt places at the
# top of every prompt. It is built from adjacent string literals, so the text
# below is one single string; edit with care, the wording is part of the
# model's contract (commands only, no extra text).
SYSTEM_INSTRUCTION = (
    "You are an expert pentest agent. Strictly follow these rules:\n"
    "1. Output ONLY valid shell commands\n"
    "2. NEVER include timestamps, dates, or any text outside commands\n"
    "3. Never repeat previous commands\n"
    "4. Always verify command safety before execution\n\n"
    "Example valid response:\n"
    "nmap -sV 192.168.1.6\n\n"
    "Key Principles:\n"
    "- Never give up until the goal is achieved\n"
    "- Learn from failures and adapt strategies\n"
    "- Leverage all available knowledge and tools\n"
    "- Break complex tasks into smaller achievable steps\n"
    "- Always ensure actions are ethical and within scope\n\n"
    # Only a subset of SUPPORTED_TOOLS is described to the model here.
    "Available Tools:\n"
    "- nmap: Network scanning and service detection\n"
    "- gobuster: Web directory brute-forcing\n"
    "- nikto: Web server vulnerability scanner\n"
    "- sqlmap: SQL injection testing\n"
    "- adb: Android Debug Bridge\n"
    "- metasploit: Exploitation framework\n\n"
    "Error Handling Examples:\n"
    "Example 1 (Command Failure):\n"
    " If nmap fails because host is down, try: nmap -Pn -sV 192.168.1.6\n"
    "Example 2 (Web Server Error):\n"
    " If web server returns 403, try: gobuster dir -u http://192.168.1.6 -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt\n"
    "Example 3 (ADB Connection Failed):\n"
    " If ADB connection fails, try: adb kill-server && adb start-server"
)
# --- Firebase Knowledge Base Integration ---
class FirebaseKnowledgeBase:
    """Keyword-filtered reader for the Firestore ``knowledge_base`` collection.

    Documents are streamed from Firestore and filtered client-side: a document
    matches when any keyword occurs in its lower-cased ``prompt``,
    ``completion`` or ``metadata`` fields.
    """

    # Keywords that are always searched for, in addition to the goal and phase.
    _BASE_KEYWORDS = ('android', 'pentest', 'mobile', 'device')
    # Sort rank per metadata priority; unknown priorities rank like "low".
    _PRIORITY_ORDER = {"high": 1, "medium": 2, "low": 3}

    def __init__(self):
        # None when Firebase was not initialised; query() retries lazily.
        self.collection = db.collection('knowledge_base') if db else None

    @staticmethod
    def _searchable_text(data: dict) -> str:
        """Return the lower-cased text of the fields that keywords are matched against.

        Every field is passed through ``str`` first, so a dict-valued
        ``metadata`` (the shape the sort key expects) or a ``None`` field no
        longer raises and aborts the whole query.
        """
        fields = (data.get('prompt'), data.get('completion'), data.get('metadata'))
        return " ".join(str(value).lower() for value in fields if value is not None)

    @staticmethod
    def _sort_key(data: dict) -> tuple:
        """Return ``(priority rank, timestamp)`` for ordering matched documents."""
        metadata = data.get('metadata')
        if not isinstance(metadata, dict):
            # Free-text or missing metadata carries no priority/timestamp.
            metadata = {}
        priority = str(metadata.get('priority', 'low')).lower()
        rank = FirebaseKnowledgeBase._PRIORITY_ORDER.get(priority, 3)
        return rank, metadata.get('timestamp', 0)

    def query(self, goal: str, phase: Optional[str] = None, limit: int = 10) -> list:
        """Return up to ``limit`` knowledge-base documents relevant to ``goal``.

        Args:
            goal: Free-text goal; used (lower-cased) as a search keyword.
            phase: Optional phase name added as a further keyword.
            limit: Maximum number of documents to return. Previously ignored
                in favour of a hard-coded 10.

        Returns:
            Matching documents as dicts, ordered by priority (high first) and
            then by ascending timestamp. An empty list when Firebase is not
            available, ``limit`` is not positive, or the query fails.
        """
        if not db or not firebase_admin._apps: # Check if Firebase is initialized
            logger.error("Firestore client not initialized. Cannot query knowledge base.")
            return []
        # Re-instantiate collection if it's None (e.g., if Firebase init failed initially)
        if getattr(self, 'collection', None) is None:
            self.collection = db.collection('knowledge_base')
        if limit <= 0:
            return []

        keywords = [(goal or '').lower(), *self._BASE_KEYWORDS]
        if phase:
            keywords.append(phase.lower())
        # An empty keyword is a substring of everything and would match all documents.
        keywords = [keyword for keyword in keywords if keyword]

        try:
            results = []
            for doc in self.collection.stream():
                data = doc.to_dict() or {}
                text = self._searchable_text(data)
                if any(keyword in text for keyword in keywords):
                    results.append(data)
                    # Stop streaming as soon as enough matches were collected.
                    if len(results) >= limit:
                        break
            results.sort(key=self._sort_key)
            return results[:limit]
        except Exception as e:
            logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
            return []
# --- RAG Knowledge Index ---
class KnowledgeIndex:
    """In-memory semantic index over a local JSON knowledge base.

    Texts from ``DATA_DIR/knowledge_base.json`` are embedded with a
    SentenceTransformer model at construction time; ``retrieve`` ranks them by
    cosine similarity to a query.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        os.makedirs(DATA_DIR, exist_ok=True)
        self.model = SentenceTransformer(
            model_name,
            cache_folder=os.path.join(DATA_DIR, "hf_cache") # Use local data dir for cache
        )
        # Each entry: {'text': str, 'embedding': list of floats, 'source': str}
        self.knowledge_base = []
        self.load_knowledge_from_file(os.path.join(DATA_DIR, 'knowledge_base.json'))

    @staticmethod
    def _normalize_item(item):
        """Return ``(text, source)`` for a raw JSON item, or None if unsupported.

        A dict contributes its ``text`` and ``source`` keys (source defaults to
        ``'local'``); a bare string is the text itself with source ``'local'``.
        """
        if isinstance(item, dict):
            return item.get('text', ''), item.get('source', 'local')
        if isinstance(item, str):
            return item, 'local'
        return None

    def load_knowledge_from_file(self, file_path):
        """Load, embed and append the items stored in ``file_path``.

        The file must contain a JSON list of strings and/or dicts. Items with
        empty text or an unsupported type are skipped. When the file does not
        exist an empty one is created so later runs find it. Errors are logged
        and never raised.
        """
        logger.debug(f"Attempting to load knowledge from file: {file_path}")
        if not os.path.exists(file_path):
            logger.warning(f"Knowledge base file not found: {file_path}. RAG will operate on an empty knowledge base.")
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump([], f)
                logger.info(f"Created empty knowledge base file at: {file_path}")
            except Exception as e:
                logger.error(f"Error creating empty knowledge base file at {file_path}: {e}", exc_info=True)
            return

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                logger.error("Knowledge base file is not a list. Please check the file format.")
                return

            entries = []
            for item in data:
                normalized = self._normalize_item(item)
                if normalized is None:
                    logger.warning(f"Skipping unsupported item type: {type(item)}")
                    continue
                text, source = normalized
                if text:
                    entries.append((text, source))

            if entries:
                # One batched encode call instead of one model call per item.
                embeddings = self.model.encode([text for text, _ in entries])
                for (text, source), embedding in zip(entries, embeddings):
                    self.knowledge_base.append({
                        'text': text,
                        'embedding': np.asarray(embedding).tolist(),
                        'source': source
                    })
            logger.info(f"Loaded {len(self.knowledge_base)} items into RAG knowledge base.")
        except Exception as e:
            logger.error(f"Error loading knowledge from {file_path}: {e}", exc_info=True)

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` knowledge items most similar to ``query``.

        Returns:
            Dicts with ``text``, ``similarity`` (a plain Python float, so the
            result can be JSON-serialised) and ``source``, best match first.
            An empty list when the index is empty, ``top_k`` is not positive,
            or retrieval fails.
        """
        if not self.knowledge_base:
            logger.debug("Knowledge base is empty, no RAG retrieval possible.")
            return []
        if top_k <= 0:
            # A slice of [-0:] would otherwise return every item.
            return []
        try:
            query_embedding = self.model.encode(query).reshape(1, -1)
            embeddings = np.array([item['embedding'] for item in self.knowledge_base])
            similarities = cosine_similarity(query_embedding, embeddings)[0]
            top_indices = similarities.argsort()[-top_k:][::-1]
            results = [
                {
                    "text": self.knowledge_base[i]['text'],
                    "similarity": float(similarities[i]),
                    "source": self.knowledge_base[i].get('source', 'RAG')
                }
                for i in top_indices
            ]
            logger.debug(f"RAG retrieved {len(results)} results for query: '{query}'")
            return results
        except Exception as e:
            logger.error(f"Error during RAG retrieval for query '{query}': {e}", exc_info=True)
            return []
# --- Deep Search Engine ---
class DeepSearchEngine:
    """Best-effort scraper for public vulnerability information.

    Every ``search_*`` / ``get_*`` method downloads public web pages with
    ``requests``, parses them with BeautifulSoup and returns plain
    lists/dicts. Failures are logged and reported as an empty result; no
    method raises.
    """

    # Seconds to wait for a remote site before giving up on a request.
    DEFAULT_TIMEOUT = 15.0

    def __init__(self, timeout: float = DEFAULT_TIMEOUT):
        """Create an engine; ``timeout`` is applied to every HTTP request."""
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        self.timeout = timeout

    def _fetch_soup(self, url: str, params: Optional[dict] = None):
        """GET ``url`` and return the parsed HTML.

        ``params`` is sent as the query string, so values are URL-encoded by
        ``requests`` instead of being pasted into the URL verbatim.

        Raises:
            requests.RequestException: on connection errors, timeouts or a
                non-success HTTP status.
        """
        response = requests.get(url, params=params, headers=self.headers, timeout=self.timeout)
        response.raise_for_status() # Raise an exception for HTTP errors
        return BeautifulSoup(response.text, 'html.parser')

    def search_device_info(self, device_info: str, os_version: str) -> dict:
        """Collect CVEs, exploits and recommendations for a device.

        Returns:
            A dict with keys ``device``, ``os_version``, ``vulnerabilities``,
            ``exploits`` and ``recommendations``. Lists stay empty for any
            sub-search that failed.
        """
        logger.debug(f"Performing deep search for device: {device_info}, OS: {os_version}")
        results = {
            "device": device_info,
            "os_version": os_version,
            "vulnerabilities": [],
            "exploits": [],
            "recommendations": []
        }
        try:
            results["vulnerabilities"] = self.search_cve(device_info, os_version)
            results["exploits"] = self.search_exploits(device_info, os_version)
            results["recommendations"] = self.get_security_recommendations(os_version)
            logger.debug("Deep search completed.")
        except Exception as e:
            logger.error(f"Deep search failed: {e}", exc_info=True)
        return results

    def search_cve(self, device: str, os_version: str) -> list:
        """Return up to 10 ``{"cve_id", "description", "source"}`` dicts from CVE Mitre."""
        cves = []
        try:
            query = f"{device} {os_version} CVE"
            logger.debug(f"Searching CVE Mitre for: {query}")
            soup = self._fetch_soup("https://cve.mitre.org/cgi-bin/cvekey.cgi", {"keyword": query})
            table = soup.find('div', id='TableWithRules')
            # First row is skipped, as in the page's header row.
            rows = table.find_all('tr')[1:] if table else []
            for row in rows:
                cols = row.find_all('td')
                if len(cols) >= 2:
                    cves.append({
                        "cve_id": cols[0].get_text(strip=True),
                        "description": cols[1].get_text(strip=True),
                        "source": "CVE Mitre"
                    })
            logger.debug(f"Found {len(cves)} CVEs.")
            return cves[:10]
        except Exception as e:
            logger.error(f"CVE search failed: {e}", exc_info=True)
            return []

    def search_exploits(self, device: str, os_version: str) -> list:
        """Return up to 10 ``{"title", "link", "source"}`` dicts from ExploitDB."""
        exploits = []
        try:
            query = f"{device} {os_version}"
            logger.debug(f"Searching ExploitDB for: {query}")
            soup = self._fetch_soup("https://www.exploit-db.com/search", {"q": query})
            for card in soup.select('.card .card-title'):
                anchor = card.find('a')
                link = anchor.get('href') if anchor is not None else None
                if not link:
                    # A card without a link used to raise and discard every result.
                    continue
                if not link.startswith('http'):
                    link = f"https://www.exploit-db.com{link}"
                exploits.append({
                    "title": card.get_text(strip=True),
                    "link": link,
                    "source": "ExploitDB"
                })
            logger.debug(f"Found {len(exploits)} exploits.")
            return exploits[:10]
        except Exception as e:
            logger.error(f"Exploit search failed: {e}", exc_info=True)
            return []

    def get_security_recommendations(self, os_version: str) -> list:
        """Return up to 5 recommendation strings for an Android or iOS version.

        ``os_version`` must contain "android" or "ios" (case-insensitive) to
        select a source; any other value yields an empty list.
        """
        recommendations = []
        try:
            logger.debug(f"Getting security recommendations for OS: {os_version}")
            os_lower = os_version.lower()
            if "android" in os_lower:
                soup = self._fetch_soup("https://source.android.com/docs/security/bulletin")
                for heading in soup.select('.devsite-article-body h2'):
                    if os_version not in heading.get_text():
                        continue
                    bullet_list = heading.find_next('ul')
                    if bullet_list:
                        recommendations.extend(
                            item.get_text(strip=True) for item in bullet_list.select('li')
                        )
            elif "ios" in os_lower:
                soup = self._fetch_soup("https://support.apple.com/en-us/HT201222")
                for section in soup.select('#sections'):
                    if os_version in section.get_text():
                        recommendations.extend(
                            item.get_text(strip=True) for item in section.select('li')
                        )
            logger.debug(f"Found {len(recommendations)} recommendations.")
            return recommendations[:5]
        except Exception as e:
            logger.error(f"Security recommendations search failed: {e}", exc_info=True)
            return []

    def _search_github(self, device_info: str) -> list:
        """Scrape GitHub search results for ``"<device_info> pentest"``."""
        soup = self._fetch_soup("https://github.com/search", {"q": f"{device_info} pentest"})
        found = []
        for repo in soup.select('.repo-list-item'):
            title_link = repo.select_one('.v-align-middle')
            if title_link is None or not title_link.get('href'):
                continue
            summary = repo.select_one('.mb-1')
            found.append({
                "title": title_link.get_text(strip=True),
                "description": summary.get_text(strip=True) if summary else "",
                "url": f"https://github.com{title_link['href']}",
                "source": "GitHub"
            })
        return found

    def _search_hackforums(self, device_info: str) -> list:
        """Scrape HackForums thread search results for ``device_info``."""
        soup = self._fetch_soup(
            "https://hackforums.net/search.php",
            {"action": "finduserthreads", "keywords": device_info},
        )
        found = []
        for thread in soup.select('.thread'):
            title_node = thread.select_one('.threadtitle')
            thread_link = thread.select_one('.threadtitle a')
            if title_node is None or thread_link is None or not thread_link.get('href'):
                continue
            found.append({
                "title": title_node.get_text(strip=True),
                "description": "Forum discussion",
                "url": f"https://hackforums.net{thread_link['href']}",
                "source": "HackForums"
            })
        return found

    def search_public_resources(self, device_info: str) -> list:
        """Return up to 10 ``{"title", "description", "url", "source"}`` dicts.

        GitHub is queried first, then HackForums. Each source is isolated: a
        failure of one is logged and no longer throws away the results
        already collected from the other.
        """
        logger.debug(f"Searching public resources for device: {device_info}")
        resources = []
        collectors = (
            ("GitHub", self._search_github),
            ("HackForums", self._search_hackforums),
        )
        for source_name, collect in collectors:
            try:
                resources.extend(collect(device_info))
            except Exception as e:
                logger.error(f"Public resources search failed ({source_name}): {e}", exc_info=True)
        logger.debug(f"Found {len(resources)} public resources.")
        return resources[:10]
# --- Initialize Services (Local to Strategic Agent) ---
# Module-level singletons, created once at import time.
# FirebaseKnowledgeBase() reads the module-level `db`, so it must come after
# the Firebase initialisation above.
firebase_kb = FirebaseKnowledgeBase()
# Constructing KnowledgeIndex loads the SentenceTransformer model and embeds
# the local knowledge base file.
rag_index = KnowledgeIndex()
deep_search_engine = DeepSearchEngine()
# --- Strategic Agent Brain (formerly SmartExecutionEngine logic) ---
class StrategicAgentBrain:
def __init__(self):
self.llm: Optional[Llama] = None
self.current_goal: Optional[str] = None
self.current_phase: str = "initial_reconnaissance"
self.current_plan: List[Dict] = []
self.current_phase_index: int = 0
self.identified_vulnerabilities: List[Dict] = []
self.gathered_info: List[str] = []
self.command_retry_counts: Dict[str, int] = {}
self.conversation_history: List[Dict] = []
self.used_commands = set()
self.execution_history = []
self.goal_achieved = False
self.no_progress_count = 0
self.react_cycle_count = 0
self.loaded_model_name: Optional[str] = None # To store the name of the loaded model
logger.info("StrategicAgentBrain initialized.")
async def load_strategic_llm(self, model_url: str):
global strategic_llm, current_strategic_model_url
logger.info(f"Attempting to load strategic LLM from URL: {model_url}")
# Determine local path for the model
model_filename = model_url.split('/')[-1]
local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)
if strategic_llm and current_strategic_model_url == model_url:
logger.info(f"Strategic LLM model from {model_url} is already loaded.")
self.llm = strategic_llm
return True, f"Model '{self.loaded_model_name}' is already loaded."
# If a model is currently loaded, unload it first
if strategic_llm:
await self.unload_strategic_llm()
# Ensure model is downloaded before attempting to load
if not os.path.exists(local_model_path):
logger.info(f"Model not found locally. Attempting to download from {model_url} to {local_model_path}...")
try:
response = requests.get(model_url, stream=True)
response.raise_for_status()
with open(local_model_path, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
logger.info(f"Model downloaded successfully to {local_model_path}.")
except Exception as e:
logger.error(f"Failed to download model from {model_url}: {e}", exc_info=True)
return False, f"Failed to download model: {str(e)}"
try:
logger.info(f"Loading Strategic LLM model from {local_model_path}...")
strategic_llm = Llama(
model_path=local_model_path,
n_ctx=3096,
n_gpu_layers=0, # Explicitly set to 0 for CPU-only
n_threads=os.cpu_count(), # Use all available CPU threads
n_batch=512,
verbose=False
)
current_strategic_model_url = model_url
self.llm = strategic_llm
self.loaded_model_name = model_filename # Store the filename
logger.info(f"Strategic LLM model {model_filename} loaded successfully (CPU-only).")
return True, f"Model '{model_filename}' loaded successfully (CPU-only)."
except Exception as e:
logger.error(f"Failed to load Strategic LLM model from {local_model_path}: {e}", exc_info=True)
strategic_llm = None
current_strategic_model_url = None
self.llm = None
self.loaded_model_name = None
return False, f"Failed to load model: {str(e)}"
async def unload_strategic_llm(self):
global strategic_llm, current_strategic_model_url
if strategic_llm:
logger.info("Unloading Strategic LLM model...")
del strategic_llm
strategic_llm = None
current_strategic_model_url = None
self.llm = None
self.loaded_model_name = None
gc.collect()
logger.info("Strategic LLM model unloaded.")
def _get_rag_context(self, query: str) -> str:
results = rag_index.retrieve(query)
if not results:
return ""
rag_context = "Relevant Knowledge for Current Context:\n"
for i, result in enumerate(results):
text = result.get('text', '') or result.get('completion', '')
source = result.get('source', 'RAG')
rag_context += f"{i+1}. [{source}] {text}\n"
return rag_context
def _get_firebase_knowledge(self, goal: str, phase: str = None) -> str:
if not db or not firebase_admin._apps: # Check if Firebase is initialized
logger.error("Firestore client not initialized. Cannot query knowledge base.")
return ""
# Re-instantiate collection if it's None (e.g., if Firebase init failed initially)
if not hasattr(self, 'collection') or self.collection is None:
self.collection = db.collection('knowledge_base')
keywords = [goal.lower(), 'android', 'pentest', 'mobile', 'device']
if phase:
keywords.append(phase.lower())
try:
query_ref = self.collection
results = []
docs = query_ref.stream() # Use query_ref instead of self.collection directly
for doc in docs:
data = doc.to_dict()
text = f"{data.get('prompt', '').lower()} {data.get('completion', '').lower()} {data.get('metadata', '').lower()}"
if any(keyword in text for keyword in keywords):
results.append(data)
if len(results) >= 10: # Use a fixed limit for stream
break
priority_order = {"high": 1, "medium": 2, "low": 3}
results.sort(key=lambda x: (
priority_order.get(x.get('metadata', {}).get('priority', 'low').lower(), 3),
x.get('metadata', {}).get('timestamp', 0)
))
return results[:10] # Ensure limit is applied
except Exception as e:
logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
return ""
def extract_device_info(self) -> str:
for info in self.gathered_info:
if "model" in info.lower() or "device" in info.lower():
match = re.search(r'(?:model|device)\s*[:=]\s*([^\n]+)', info, re.IGNORECASE)
if match:
return match.group(1).strip()
ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
return ip_match.group(0) if ip_match else "Unknown Device"
def extract_os_version(self) -> str:
for info in self.gathered_info:
if "android" in info.lower() or "ios" in info.lower() or "os" in info.lower():
android_match = re.search(r'android\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE)
if android_match:
return f"Android {android_match.group(1)}"
ios_match = re.search(r'ios\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE)
if ios_match:
return f"iOS {ios_match.group(1)}"
linux_match = re.search(r'linux\s+kernel\s+(\d+\.\d+\.\d+)', info, re.IGNORECASE)
if linux_match:
return f"Linux {linux_match.group(1)}"
return "Unknown OS Version"
def format_deep_search_results(self, results: dict) -> str:
context = "Deep Search Results:\n"
context += f"Device: {results.get('device', 'Unknown')}\n"
context += f"OS Version: {results.get('os_version', 'Unknown')}\n\n"
if results.get('vulnerabilities'):
context += "Discovered Vulnerabilities:\n"
for i, vuln in enumerate(results['vulnerabilities'][:5], 1):
context += f"{i}. {vuln.get('cve_id', 'CVE-XXXX-XXXX')}: {vuln.get('description', 'No description')}\n"
context += "\n"
if results.get('exploits'):
context += "Available Exploits:\n"
for i, exploit in enumerate(results['exploits'][:5], 1):
context += f"{i}. {exploit.get('title', 'Untitled exploit')} [Source: {exploit.get('source', 'Unknown')}]\n"
context += "\n"
if results.get('recommendations'):
context += "Security Recommendations:\n"
for i, rec in enumerate(results['recommendations'][:3], 1):
context += f"{i}. {rec}\n"
context += "\n"
if results.get('public_resources'):
context += "Public Resources:\n"
for i, res in enumerate(results['public_resources'][:3], 1):
context += f"{i}. {res.get('title', 'Untitled resource')} [Source: {res.get('source', 'Unknown')}]\n"
return context
def generate_deep_search_prompt(self, context: str) -> str:
return f"""
You are an expert pentester. Below are deep search results for the target device.
Use this information to generate the next penetration testing command.{context}
Current Goal: {self.current_goal}
Current Phase: {self.current_phase}
Recent Command History:{', '.join(list(self.used_commands)[-3:]) if self.used_commands else 'None'}
Based on this information, what is the SINGLE MOST EFFECTIVE shell command to execute next?
Focus on exploiting the most critical vulnerabilities or gathering more information.
Response Format:
Command: <your_command_here>
"""
def _generate_llm_prompt(self) -> str:
rag_context = self._get_rag_context(f"{self.current_goal} {self.current_phase}")
firebase_knowledge = self._get_firebase_knowledge(self.current_goal, self.current_phase)
history_context = "\n".join(
f"{entry['role']}: {entry['content']}" for entry in self.conversation_history[-2:]
)
execution_history = "\n".join(
f"Command: {res['command']}\nResult: {res['output'][:100]}...\nSuccess: {res['success']}"
for res in self.execution_history[-2:]
) if self.execution_history else "No previous results."
strategic_advice = self._get_rag_context(self.current_phase) # Using RAG for strategic advice too
def shorten_text(text, max_length=300):
if len(text) > max_length:
return text[:max_length] + "... [truncated]"
return text
rag_context = shorten_text(rag_context, max_length=200)
firebase_knowledge = shorten_text(firebase_knowledge, max_length=200)
strategic_advice = shorten_text(strategic_advice, max_length=100)
history_context = shorten_text(history_context, max_length=150)
execution_history = shorten_text(execution_history, max_length=500)
prompt = f"""
System Instructions: {SYSTEM_INSTRUCTION}
Current Goal: '{self.current_goal}'
Current Phase: {self.current_phase} - {self.current_plan[self.current_phase_index]['objective'] if self.current_plan and self.current_phase_index < len(self.current_plan) else 'No objective'}
Based on the following knowledge and previous results, generate the SINGLE, VALID SHELL COMMAND to advance the penetration testing process.
**Knowledge from External Services (RAG & Firebase):**
{rag_context}
{firebase_knowledge}
**Previous Execution Results:**
{execution_history}
**Recent Conversation History:**
{history_context}
**Strategic Advice for Current Phase:**
{strategic_advice}
***CRITICAL RULES FOR OUTPUT:***
1. **OUTPUT ONLY THE COMMAND.**
2. **DO NOT INCLUDE ANY REASONING, THOUGHTS, EXPLANATIONS, OR ANY OTHER TEXT.**
3. The command MUST be directly executable in a Linux terminal.
4. Avoid repeating these recent commands: {', '.join(list(self.used_commands)[-3:]) if self.used_commands else 'None'}
5. If the previous command failed, try a different approach or a related tool.
6. For the 'android_enumeration' phase, prioritize ADB commands.
Example valid commands for initial reconnaissance of an Android phone:
nmap -sV -Pn 192.168.1.14
adb devices
adb connect 192.168.1.14:5555
Command:
"""
return prompt
def _get_llm_response(self, custom_prompt: str = None) -> str:
if not self.llm:
logger.error("Strategic LLM instance is None. Cannot get response. Please load a model first.")
target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
target_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
return f"Command: echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {target_ip}'"
prompt = custom_prompt if custom_prompt else self._generate_llm_prompt()
logger.info(f"Sending prompt to Strategic LLM:\n{prompt[:500]}...")
try:
response = self.llm(
prompt,
max_tokens=512,
temperature=0.3,
stop=["\n"]
)
llm_response = response['choices'][0]['text'].strip()
logger.info(f"Strategic LLM raw response: {llm_response}")
if not llm_response:
logger.warning("Strategic LLM returned an empty response. Using fallback command.")
target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
target_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
return f"Command: nmap -sV -Pn {target_ip}"
return llm_response
except Exception as e:
logger.error(f"Error during Strategic LLM inference: {e}", exc_info=True)
logger.warning("Strategic LLM inference failed. Using fallback command.")
target_ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
target_ip = target_ip_match.group(0) if target_ip_match else "192.168.1.1"
return f"Command: nmap -sV -Pn {target_ip}"
def parse_llm_response(self, response: str) -> str:
logger.info(f"Attempting to parse LLM response: '{response}'")
command = None
try:
code_block = re.search(r'```(?:bash|sh)?\s*([\s\S]*?)```', response)
if code_block:
command = code_block.group(1).strip()
logger.info(f"Command extracted from code block: '{command}'")
if not command:
command_match = re.search(r'^\s*Command\s*:\s*(.+)$', response, re.MULTILINE | re.IGNORECASE)
if command_match:
command = command_match.group(1).strip()
logger.info(f"Command extracted from 'Command:' line: '{command}'")
if not command:
stripped_response = response.strip()
if any(stripped_response.startswith(tool) for tool in SUPPORTED_TOOLS):
command = stripped_response
logger.info(f"Command extracted as direct supported tool command: '{command}'")
if command:
original_command = command
command = re.sub(r'^\s*(Command|Answer|Note|Result)\s*[:.-]?\s*', '', command, flags=re.IGNORECASE).strip()
logger.info(f"Cleaned command: from '{original_command}' to '{command}'")
if not re.match(r'^[a-zA-Z0-9_./:;= \-\'"\s]+$', command):
logger.error(f"Invalid command characters detected after cleanup: '{command}'")
return None
if re.search(r'(reason|thought|explanation|rationale|note|result):', command, re.IGNORECASE):
logger.warning(f"Command '{command}' appears to be reasoning/explanation. Rejecting.")
return None
if command not in self.used_commands:
self.used_commands.add(command)
logger.info(f"Returning valid and new command: '{command}'")
return command
else:
logger.warning(f"Command '{command}' already used. Skipping.")
return None
else:
logger.warning("No valid command could be extracted from LLM response based on strict rules.")
return None
except Exception as e:
logger.error(f"Error parsing LLM response: {e}", exc_info=True)
return None
def set_goal(self, goal: str):
self.current_goal = goal
self.goal_achieved = False
self.react_cycle_count = 0
self.no_progress_count = 0
self.current_plan = self._generate_strategic_plan(goal)
self.current_phase_index = 0
self.identified_vulnerabilities = []
self.gathered_info = []
self.command_retry_counts = {}
self.conversation_history = [{"role": "user", "content": f"New goal set: {goal}"}]
self.used_commands.clear()
self.execution_history = []
self.goal_achieved = False
logger.info(f"Strategic Agent Goal set: {goal}. Starting initial reconnaissance.")
def _generate_strategic_plan(self, goal: str) -> List[Dict]:
logger.debug(f"Generating strategic plan for goal: {goal}")
plan = []
goal_lower = goal.lower()
plan.append({"phase": "initial_reconnaissance", "objective": f"Perform initial reconnaissance for {goal}"})
if "web" in goal_lower or "http" in goal_lower:
plan.append({"phase": "web_enumeration", "objective": "Enumerate web server for directories and files"})
plan.append({"phase": "web_vulnerability_analysis", "objective": "Analyze web vulnerabilities (SQLi, XSS, etc.)"})
plan.append({"phase": "web_exploitation", "objective": "Attempt to exploit web vulnerabilities"})
plan.append({"phase": "post_exploitation", "objective": "Perform post exploitation activities"})
elif "android" in goal_lower or "mobile" in goal_lower or "adb" in goal_lower:
plan.append({"phase": "android_enumeration", "objective": "Enumerate Android device via ADB"})
plan.append({"phase": "android_app_analysis", "objective": "Analyze Android application for vulnerabilities"})
plan.append({"phase": "android_exploitation", "objective": "Attempt to exploit Android vulnerabilities"})
plan.append({"phase": "data_extraction", "objective": "Extract sensitive data from device"})
else:
plan.append({"phase": "network_scanning", "objective": "Perform detailed network scanning"})
plan.append({"phase": "service_enumeration", "objective": "Enumerate services and identify versions"})
plan.append({"phase": "vulnerability_analysis", "objective": "Analyze services for vulnerabilities"})
plan.append({"phase": "exploitation", "objective": "Attempt to exploit vulnerabilities"})
plan.append({"phase": "post_exploitation", "objective": "Perform post exploitation (privilege escalation, data exfiltration)"})
plan.append({"phase": "reporting", "objective": "Generate pentest report"})
logger.info(f"Generated strategic plan for goal '{goal}': {plan}")
return plan
def evaluate_phase_completion(self) -> float:
phase_commands = [cmd for cmd in self.execution_history
if cmd.get('phase', '') == self.current_phase]
if not phase_commands:
return 0.0
successful = sum(1 for cmd in phase_commands if cmd['success'])
return successful / len(phase_commands)
def advance_phase(self):
if self.current_phase_index < len(self.current_plan) - 1:
self.current_phase_index += 1
self.current_phase = self.current_plan[self.current_phase_index]["phase"]
logger.info(f"Strategic Agent advancing to new phase: {self.current_phase.replace('_', ' ').title()}")
self.no_progress_count = 0
self.react_cycle_count = 0
else:
self.current_phase = "completed"
self.goal_achieved = True
logger.info("Strategic Agent: All planned phases completed. Goal achieved!")
def observe_result(self, command: str, output: str, success: bool):
logger.debug(f"Strategic Agent observing result for command '{command}': Success={success}")
self.execution_history.append({"command": command, "output": output, "success": success, "timestamp": datetime.now().isoformat()})
self.gathered_info.append(output)
self.analyze_command_output_strategic(command, output)
if not success:
self.no_progress_count += 1
else:
self.no_progress_count = 0
if success and self.current_phase_index < len(self.current_plan) - 1:
phase_completion = self.evaluate_phase_completion()
if phase_completion >= 0.8:
self.advance_phase()
def analyze_command_output_strategic(self, command: str, output: str):
    """Scan a command's output for signs of vulnerabilities and record them.

    The command prefix selects the heuristic:

    * ``nmap``   - a Medium finding when the output contains "open" and
      (case-insensitively) "vulnerable"; every ``<port>/tcp open <service>``
      line is also added to ``self.gathered_info``.
    * ``nikto``  - up to three ``OSVDB-<n>: <text>`` lines become High findings.
    * ``sqlmap`` - a Critical finding when "injection" appears.
    * ``adb``    - a High finding for "debuggable" and a Medium finding when
      both "permission" and "denied" appear.

    Any exception raised while analysing is logged and swallowed.
    """
    try:
        logger.debug(f"Analyzing strategic command output for: {command}")

        if command.startswith("nmap"):
            looks_vulnerable = "open" in output and "vulnerable" in output.lower()
            if looks_vulnerable:
                self.ingest_vulnerability(
                    "Potential vulnerability found in NMAP scan",
                    "Medium",
                    "NMAP-SCAN",
                )
            for port, service in re.findall(r'(\d+)/tcp\s+open\s+(\S+)', output):
                self.gathered_info.append(f"Discovered open port {port} with service {service}")
            return

        if command.startswith("nikto"):
            if "OSVDB-" not in output:
                return
            findings = re.findall(r'OSVDB-\d+:\s*(.+)', output)
            # Only the first three findings are recorded.
            for finding in findings[:3]:
                self.ingest_vulnerability(
                    f"Nikto vulnerability: {finding}",
                    "High",
                    "NIKTO-SCAN",
                )
            return

        if command.startswith("sqlmap"):
            if "injection" in output.lower():
                self.ingest_vulnerability(
                    "SQL injection vulnerability detected",
                    "Critical",
                    "SQLMAP-SCAN",
                )
            return

        if command.startswith("adb"):
            lowered = output.lower()
            if "debuggable" in lowered:
                self.ingest_vulnerability(
                    "Debuggable Android application found",
                    "High",
                    "ADB-DEBUG",
                )
            if "permission" in lowered and "denied" in lowered:
                self.ingest_vulnerability(
                    "Permission issue detected on Android device",
                    "Medium",
                    "ADB-PERMISSION",
                )
    except Exception as e:
        logger.error(f"Strategic Agent: Error analyzing command output: {e}", exc_info=True)
def ingest_vulnerability(self, description: str, severity: str, cve_id: Optional[str] = None, exploit_id: Optional[str] = None):
    """Append a timestamped vulnerability record to ``self.identified_vulnerabilities``.

    Args:
        description: Human-readable description of the finding.
        severity: Severity label stored as given.
        cve_id: Optional identifier, stored under ``"cve_id"`` when truthy.
        exploit_id: Optional identifier, stored under ``"exploit_id"`` when truthy.
    """
    record = dict(
        description=description,
        severity=severity,
        timestamp=datetime.now().isoformat(),
    )
    # Optional identifiers are only stored when a non-empty value was given.
    optional_ids = (("cve_id", cve_id), ("exploit_id", exploit_id))
    record.update({key: value for key, value in optional_ids if value})
    self.identified_vulnerabilities.append(record)
    logger.info(f"Strategic Agent identified vulnerability: {description} (Severity: {severity})")
# Module-level Strategic Agent Brain instance. The /strategic_agent/* endpoints
# below all read and mutate this single object.
strategic_brain = StrategicAgentBrain()
# --- Request Models for API Endpoints ---
class RAGRequest(BaseModel):
    """Request body for ``POST /rag/retrieve``."""

    # Search text; must be 3 to 500 characters long.
    query: constr(min_length=3, max_length=500)
    # Number of results to return; 1 to 20, defaulting to 5.
    top_k: int = Field(default=5, gt=0, le=20)
class FirebaseQueryRequest(BaseModel):
    """Request body for ``POST /firebase/query``.

    The three fields are passed, in order, to ``firebase_kb.query``.
    """

    goal: str
    # Declared Optional so that omitting the field and sending an explicit
    # null are both accepted; a bare ``str`` annotation with a ``None``
    # default rejects an explicit null under pydantic v2.
    phase: Optional[str] = None
    limit: int = 10
class DeepSearchRequest(BaseModel):
    """Request body for ``POST /deep_search``."""

    # Passed to both search_device_info() and search_public_resources().
    device_info: str
    # Passed to search_device_info() together with device_info.
    os_version: str
class SetGoalRequest(BaseModel):
    """Request body for ``POST /strategic_agent/set_goal``."""

    # Goal text handed to strategic_brain.set_goal().
    goal: str
class GetNextCommandRequest(BaseModel):
    """Request body for ``POST /strategic_agent/get_next_command``.

    The three ``*_summary`` lists overwrite the corresponding state on the
    strategic brain before the next command is generated.
    """

    # NOTE(review): the next three fields are required by the schema, but the
    # get_next_command handler in this file does not read them.
    current_state: str
    last_command_output: str
    last_command_success: bool
    # Replaces strategic_brain.execution_history.
    execution_history_summary: List[Dict] = []
    # Replaces strategic_brain.gathered_info.
    gathered_info_summary: List[str] = []
    # Replaces strategic_brain.identified_vulnerabilities.
    identified_vulnerabilities_summary: List[Dict] = []
class ObserveResultRequest(BaseModel):
    """Request body for ``POST /strategic_agent/observe_result``.

    The fields are forwarded, in order, to ``strategic_brain.observe_result``.
    """

    command: str
    output: str
    success: bool
class LoadStrategicModelRequest(BaseModel):
    """Request body for ``POST /strategic_agent/load_model``."""

    model_url: str  # Now expects a URL instead of a local path
# --- API Endpoints ---
@app.get("/health")
async def health_check():
    """Liveness probe: always answers with a static "ok" payload."""
    logger.debug("Health check requested.")
    payload = {"status": "ok", "message": "Knowledge service is running."}
    return payload
@app.post("/rag/retrieve")
async def rag_retrieve_endpoint(request: RAGRequest):
    """Return the ``top_k`` RAG index matches for ``request.query``.

    Raises:
        HTTPException: 500 with the error text when retrieval fails.
    """
    logger.debug(f"RAG retrieve endpoint called with query: {request.query}")
    try:
        matches = rag_index.retrieve(request.query, request.top_k)
    except Exception as e:
        logger.error(f"RAG retrieval failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": {"results": matches}, "error": None}
@app.post("/firebase/query")
async def firebase_query_endpoint(request: FirebaseQueryRequest):
    """Query the Firebase knowledge base for the given goal, phase and limit.

    Raises:
        HTTPException: 500 with the error text when the query fails.
    """
    logger.debug(f"Firebase query endpoint called with goal: {request.goal}, phase: {request.phase}")
    try:
        records = firebase_kb.query(request.goal, request.phase, request.limit)
    except Exception as e:
        logger.error(f"Firebase query failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": {"results": records}, "error": None}
@app.post("/deep_search")
async def deep_search_endpoint(request: DeepSearchRequest):
    """Run a deep search for a device and attach public-resource results.

    The device search result is extended in place with a
    ``"public_resources"`` entry before being returned.

    Raises:
        HTTPException: 500 with the error text when either search fails.
    """
    logger.debug(f"Deep search endpoint called for device: {request.device_info}, OS: {request.os_version}")
    try:
        findings = deep_search_engine.search_device_info(request.device_info, request.os_version)
        public_hits = deep_search_engine.search_public_resources(request.device_info)
        findings["public_resources"] = public_hits
    except Exception as e:
        logger.error(f"Deep search failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": findings, "error": None}
@app.post("/strategic_agent/load_model")
async def load_strategic_model(request: LoadStrategicModelRequest):
    """Load the strategic LLM identified by ``request.model_url``.

    Returns:
        A success payload including the loaded model's name.

    Raises:
        HTTPException: 500 carrying the loader's message when loading fails.
    """
    logger.info(f"Request to load strategic model: {request.model_url}")
    loaded, detail = await strategic_brain.load_strategic_llm(request.model_url)
    if not loaded:
        logger.error(f"Failed to load strategic model: {detail}")
        raise HTTPException(status_code=500, detail=detail)

    logger.info(f"Strategic model loaded successfully: {detail}")
    return {"status": "success", "message": detail, "model": strategic_brain.loaded_model_name}
@app.post("/strategic_agent/unload_model")
async def unload_strategic_model():
    """Unload the strategic LLM and confirm with a success payload."""
    logger.info("Request to unload strategic model.")
    await strategic_brain.unload_strategic_llm()
    confirmation = {"status": "success", "message": "Strategic LLM unloaded."}
    return confirmation
@app.post("/strategic_agent/set_goal")
async def strategic_set_goal(request: SetGoalRequest):
    """Set a new goal on the strategic brain and echo it back."""
    new_goal = request.goal
    logger.info(f"Strategic Agent received new goal: {new_goal}")
    # set_goal is a plain synchronous method, so it is called without await.
    strategic_brain.set_goal(new_goal)
    return {"status": "success", "message": f"Goal set to: {new_goal}"}
@app.post("/strategic_agent/get_next_command")
async def strategic_get_next_command(request: GetNextCommandRequest):
    """Ask the strategic brain for the next command to execute.

    The brain's execution history, gathered info and identified
    vulnerabilities are first overwritten with the summaries from the
    request. If the LLM pipeline yields a command it is recorded in
    ``used_commands`` and returned with status ``"success"``. Otherwise a
    fallback with status ``"fallback"`` is returned: an nmap scan of the
    first IPv4-looking string in the current goal (``192.168.1.1`` when none
    is found), wrapped in an ``echo`` hint when no LLM is loaded.
    """
    logger.debug("Strategic Agent received request for next command.")

    # Sync the brain with the execution agent's latest view of the world.
    strategic_brain.execution_history = request.execution_history_summary
    strategic_brain.gathered_info = request.gathered_info_summary
    strategic_brain.identified_vulnerabilities = request.identified_vulnerabilities_summary

    # Prompt -> raw LLM reply -> parsed command.
    prompt = strategic_brain._generate_llm_prompt()
    raw_reply = strategic_brain._get_llm_response(prompt)
    command = strategic_brain.parse_llm_response(raw_reply)

    if command:
        strategic_brain.used_commands.add(command)  # Keep track of issued commands
        logger.info(f"Strategic Agent generated command: {command}")
        return {"command": command, "status": "success"}

    # No usable command: fall back to a scan of the goal's target address.
    ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', strategic_brain.current_goal or "")
    if ip_match:
        fallback_ip = ip_match.group(0)
    else:
        fallback_ip = "192.168.1.1"
    logger.warning(f"Strategic Agent failed to generate command. Returning fallback: {fallback_ip}")

    if strategic_brain.llm is not None:
        return {"command": f"nmap -sV -Pn {fallback_ip}", "status": "fallback", "message": "Strategic Agent could not determine a valid next command."}
    # Without a model the caller gets an explanatory echo instead of a scan.
    return {"command": f"echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {fallback_ip}'", "status": "fallback", "message": "No LLM loaded on Strategic Agent. Please load one from the frontend settings."}
@app.post("/strategic_agent/observe_result")
async def strategic_observe_result(request: ObserveResultRequest):
    """Forward a command's outcome to the strategic brain and acknowledge it."""
    logger.debug(f"Strategic Agent received observation for command: {request.command}, success: {request.success}")
    strategic_brain.observe_result(
        request.command,
        request.output,
        request.success,
    )
    acknowledgement = {"status": "success", "message": "Observation received and processed."}
    return acknowledgement
@app.get("/strategic_agent/get_status")
async def strategic_get_status():
    """Return a snapshot of the strategic brain's state.

    The payload contains the goal, the current phase (title-cased), the
    progress counters, vulnerability descriptions, previews of the last five
    gathered-info entries, a summary of the last ten history entries, the
    plan, the phase index, the goal-achieved flag, a Running/Idle status and
    the loaded model's name.
    """
    logger.debug("Strategic Agent status requested.")

    def _preview(text: str, limit: int = 100) -> str:
        # Only flag text as truncated when something was actually cut off.
        return text[:limit] + "..." if len(text) > limit else text

    # History and vulnerability entries may be arbitrary dicts taken from a
    # get_next_command request body, so keys are read with .get() to avoid
    # failing the whole status call on a missing key.
    vulnerability_descriptions = [
        v.get("description", "") for v in strategic_brain.identified_vulnerabilities
    ]
    history_summary = [
        {
            "command": e.get("command"),
            "success": e.get("success"),
            "timestamp": e.get("timestamp"),
        }
        for e in strategic_brain.execution_history[-10:]
    ]
    info_previews = [_preview(info) for info in (strategic_brain.gathered_info or [])[-5:]]
    is_running = strategic_brain.current_goal and not strategic_brain.goal_achieved

    return {
        "currentGoal": strategic_brain.current_goal,
        "currentPhase": strategic_brain.current_phase.replace('_', ' ').title(),
        "reactCycleCount": strategic_brain.react_cycle_count,
        "noProgressCount": strategic_brain.no_progress_count,
        "identifiedVulnerabilities": vulnerability_descriptions,
        "gatheredInfo": info_previews,
        "executionHistorySummary": history_summary,
        "strategicPlan": strategic_brain.current_plan,
        "currentPhaseIndex": strategic_brain.current_phase_index,
        "goalAchieved": strategic_brain.goal_achieved,
        "strategicAgentStatus": "Running" if is_running else "Idle",
        "loadedModel": strategic_brain.loaded_model_name,  # Name of the loaded model
    }
@app.get("/api/models")
async def get_available_models_strategic():
    """List predefined Hugging Face models for strategic agent.

    Returns:
        A JSON array of ``{"name": ..., "url": ...}`` objects taken from
        ``HUGGINGFACE_MODELS``.
    """
    logger.debug("Request for available strategic models received.")
    # JSONResponse serializes its content itself. Passing an already
    # json.dumps()-ed string would be encoded a second time, and clients
    # would receive one JSON string instead of a list.
    return JSONResponse(content=HUGGINGFACE_MODELS, media_type="application/json")
# --- Startup Event: Download All Predefined Models (no ngrok tunnel on HF Spaces) ---
def _download_model_file(model_url: str, destination: str) -> None:
    """Stream ``model_url`` to ``destination``, publishing it only when complete.

    The data is written to ``destination + ".part"`` and renamed into place
    after the whole body has been received, so an interrupted download never
    leaves a truncated file under the final name.

    Raises:
        requests.RequestException: On connection errors, timeouts or a
            non-2xx response.
        OSError: When the file cannot be written or renamed.
    """
    partial_path = destination + ".part"
    try:
        # (connect, read) timeouts keep a stalled server from hanging startup
        # forever; the context manager releases the connection.
        with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                # Model files are large; 1 MiB chunks avoid millions of tiny writes.
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    f.write(chunk)
        os.replace(partial_path, destination)
    finally:
        # After a successful rename the partial file is gone; otherwise
        # discard whatever was written so far.
        if os.path.exists(partial_path):
            os.remove(partial_path)


@app.on_event("startup")
async def startup_event_download_models():  # Renamed function
    """Download every model in ``HUGGINGFACE_MODELS`` that is not yet on disk.

    Each model is stored in ``DOWNLOAD_DIR`` under the last path segment of
    its URL. A failed download is logged and skipped so the remaining models
    are still attempted and the application still starts.
    """
    logger.info("Application startup event triggered. Attempting to download all predefined models.")
    for model_info in HUGGINGFACE_MODELS:
        model_url = model_info["url"]
        model_name = model_info["name"]
        model_filename = model_url.split('/')[-1]
        local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)

        if os.path.exists(local_model_path):
            logger.info(f"Model '{model_name}' already exists at {local_model_path}. Skipping download.")
            continue

        logger.info(f"Downloading model '{model_name}' from {model_url} to {local_model_path}...")
        try:
            _download_model_file(model_url, local_model_path)
            logger.info(f"Model '{model_name}' downloaded successfully.")
        except Exception as e:
            # Best effort: one failing model must not prevent startup.
            logger.error(f"Failed to download model '{model_name}': {e}", exc_info=True)
    logger.info("Finished attempting to download all predefined models.")
# --- Shutdown Event (ngrok related parts removed) ---
@app.on_event("shutdown")
async def shutdown_event_cleanup():  # Renamed function
    """Log that the application is shutting down; nothing else is cleaned up here."""
    logger.info("Application shutdown event triggered. Performing cleanup.")
    # No ngrok.kill() needed here as ngrok is not used
if __name__ == "__main__":
    # Script entry point: serve the FastAPI app with uvicorn.
    import uvicorn

    logger.info("Starting FastAPI application on Hugging Face Spaces (port 7860)...")
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,  # Standard port for Hugging Face Spaces
        "log_level": "info",  # Info keeps the default output less verbose
    }
    uvicorn.run(app, **server_options)