|
|
|
import os |
|
import json |
|
import logging |
|
import numpy as np |
|
import requests |
|
from fastapi import FastAPI, HTTPException, Depends, status |
|
from pydantic import BaseModel, Field, constr |
|
from sentence_transformers import SentenceTransformer |
|
from sklearn.metrics.pairwise import cosine_similarity |
|
from datetime import datetime |
|
import firebase_admin |
|
from firebase_admin import credentials, firestore |
|
from bs4 import BeautifulSoup |
|
import re |
|
from typing import List, Dict, Optional, Tuple |
|
from cachetools import TTLCache |
|
import gc |
|
from llama_cpp import Llama |
|
import asyncio |
|
import nest_asyncio |
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
# Patch asyncio at import time, presumably so already-running event loops can
# be re-entered (nested run_until_complete calls).
# NOTE(review): confirm this is still required and compatible with the event
# loop implementation the ASGI server uses.
nest_asyncio.apply()
|
|
|
|
|
|
|
# Directory where downloaded GGUF model files are cached; created eagerly.
DOWNLOAD_DIR = "./downloaded_models/"
os.makedirs(DOWNLOAD_DIR, exist_ok=True)

# (display name, direct download URL) for each selectable strategic model.
_HF_MODEL_SPECS = (
    (
        "Foundation-Sec-8B-Q8_0",
        "https://huggingface.co/fdtn-ai/Foundation-Sec-8B-Q8_0-GGUF/resolve/main/foundation-sec-8b-q8_0.gguf",
    ),
    (
        "Lily-Cybersecurity-7B-v0.2-Q8_0",
        "https://huggingface.co/Nekuromento/Lily-Cybersecurity-7B-v0.2-Q8_0-GGUF/resolve/main/lily-cybersecurity-7b-v0.2-q8_0.gguf",
    ),
    (
        "SecurityLLM-GGUF (sarvam-m-q8_0)",
        "https://huggingface.co/QuantFactory/SecurityLLM-GGUF/resolve/main/sarvam-m-q8_0.gguf",
    ),
)

# Model catalogue as a list of {"name": ..., "url": ...} dicts.
HUGGINGFACE_MODELS = [{"name": name, "url": url} for name, url in _HF_MODEL_SPECS]

# Root for local data: the RAG knowledge-base file and the embedding-model cache.
DATA_DIR = "./data"

# TTL passed to the deep-search TTLCache (presumably seconds).
DEEP_SEARCH_CACHE_TTL = 3600
|
|
|
|
|
|
|
|
|
|
|
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Configure the root logger for the whole service (basicConfig is a no-op if
# the root logger already has handlers). DEBUG is verbose: the agent code
# logs prompts and raw LLM responses.
logging.basicConfig(format=_LOG_FORMAT, level=logging.DEBUG)

# Module logger used throughout this file.
logger = logging.getLogger(__name__)
logger.info("Logging initialized with DEBUG level.")
|
|
|
|
|
# ASGI application object for this service.
app = FastAPI(
    version="1.0.0",
    title="Hugging Face Strategic Agent Service",
    description=(
        "Provides knowledge base access and strategic reasoning for the "
        "pentest agent on Hugging Face Spaces."
    ),
)
|
|
|
|
|
# Path to the Firebase service-account JSON; overridable via the environment.
firebase_creds_path = os.getenv("FIREBASE_CREDS_PATH", "cred.json")

# Shared Firestore client. It stays None when Firebase is unavailable, and the
# knowledge-base helpers check for that and degrade to empty results.
db = None

try:
    if firebase_admin._apps:
        # An app already exists (e.g. this module was imported twice or
        # hot-reloaded). Reuse it; otherwise ``db`` would remain None and
        # every Firestore lookup would be silently disabled.
        db = firestore.client()
        logger.info("Firebase already initialized; reusing existing app.")
    elif os.path.exists(firebase_creds_path):
        cred = credentials.Certificate(firebase_creds_path)
        firebase_admin.initialize_app(cred)
        db = firestore.client()
        logger.info("Firebase initialized successfully.")
    else:
        logger.warning(f"Firebase credentials file not found at {firebase_creds_path}. Firebase will not be initialized.")
except Exception as e:
    # Best effort: the service keeps running without Firestore.
    logger.error(f"Failed to initialize Firebase: {e}. Ensure FIREBASE_CREDS_PATH is set correctly and the file exists.", exc_info=True)
|
|
|
|
|
# Process-wide handle to the currently loaded llama.cpp model; None until
# load_strategic_llm succeeds and reset to None by unload_strategic_llm.
strategic_llm: Optional[Llama] = None

# URL the loaded ``strategic_llm`` came from. load_strategic_llm compares it
# with the requested URL to detect that the model is already loaded.
current_strategic_model_url: Optional[str] = None
|
|
|
|
|
# Command prefixes that parse_llm_response accepts when the LLM answers with
# a bare command (no "Command:" label or code fence). Order is significant
# only for readability; membership is what matters.
SUPPORTED_TOOLS = (
    "nmap gobuster nikto sqlmap adb frida "
    "drozer apktool msfconsole mobsfscan burpsuite "
    "metasploit curl wget hydra john aircrack-ng"
).split()
|
|
|
|
|
# Bounded cache holding at most 100 entries, each expiring after
# DEEP_SEARCH_CACHE_TTL; presumably used for deep-search results.
# NOTE(review): no reader/writer of this cache is visible in this part of the
# file — confirm how entries are keyed.
deep_search_cache = TTLCache(maxsize=100, ttl=DEEP_SEARCH_CACHE_TTL)
|
|
|
|
|
# System rules embedded at the top of the prompt built by _generate_llm_prompt.
# Written as one entry per line and joined with "\n"; the empty strings
# produce the blank lines that separate sections. No trailing newline.
SYSTEM_INSTRUCTION = "\n".join((
    "You are an expert pentest agent. Strictly follow these rules:",
    "1. Output ONLY valid shell commands",
    "2. NEVER include timestamps, dates, or any text outside commands",
    "3. Never repeat previous commands",
    "4. Always verify command safety before execution",
    "",
    "Example valid response:",
    "nmap -sV 192.168.1.6",
    "",
    "Key Principles:",
    "- Never give up until the goal is achieved",
    "- Learn from failures and adapt strategies",
    "- Leverage all available knowledge and tools",
    "- Break complex tasks into smaller achievable steps",
    "- Always ensure actions are ethical and within scope",
    "",
    "Available Tools:",
    "- nmap: Network scanning and service detection",
    "- gobuster: Web directory brute-forcing",
    "- nikto: Web server vulnerability scanner",
    "- sqlmap: SQL injection testing",
    "- adb: Android Debug Bridge",
    "- metasploit: Exploitation framework",
    "",
    "Error Handling Examples:",
    "Example 1 (Command Failure):",
    " If nmap fails because host is down, try: nmap -Pn -sV 192.168.1.6",
    "Example 2 (Web Server Error):",
    " If web server returns 403, try: gobuster dir -u http://192.168.1.6 -w /usr/share/wordlists/dirbuster/directory-list-2.3-medium.txt",
    "Example 3 (ADB Connection Failed):",
    " If ADB connection fails, try: adb kill-server && adb start-server",
))
|
|
|
|
|
class FirebaseKnowledgeBase:
    """Keyword search over the Firestore ``knowledge_base`` collection."""

    def __init__(self):
        # ``db`` is the module-level Firestore client (None when Firebase is
        # unavailable); ``query`` resolves the collection lazily if needed.
        self.collection = db.collection('knowledge_base') if db else None

    @staticmethod
    def _metadata_dict(entry: dict) -> dict:
        """Return ``entry['metadata']`` when it is a dict, else an empty dict.

        Metadata may be stored as free text or as a mapping with
        ``priority`` / ``timestamp`` keys. Only the mapping form carries
        sortable fields.
        """
        metadata = entry.get('metadata')
        return metadata if isinstance(metadata, dict) else {}

    @staticmethod
    def _searchable_text(entry: dict) -> str:
        """Lower-cased ``prompt completion metadata`` text used for matching.

        Each field is stringified first, so ``None`` values and dict
        metadata no longer raise (previously ``.lower()`` on a dict aborted
        the whole query).
        """
        parts = []
        for key in ('prompt', 'completion', 'metadata'):
            value = entry.get(key)
            parts.append('' if value is None else str(value).lower())
        return ' '.join(parts)

    @classmethod
    def _sort_key(cls, entry: dict) -> tuple:
        """Order by priority (high, medium, then low/unknown), then timestamp."""
        metadata = cls._metadata_dict(entry)
        priority = str(metadata.get('priority', 'low')).lower()
        rank = {"high": 1, "medium": 2, "low": 3}.get(priority, 3)
        return (rank, metadata.get('timestamp', 0))

    def query(self, goal: str, phase: str = None, limit: int = 10) -> list:
        """Return up to ``limit`` knowledge-base documents relevant to ``goal``.

        A document matches when its prompt/completion/metadata text contains
        the goal, the optional phase, or one of a few fixed mobile-pentest
        keywords (all case-insensitive substrings). Streaming stops after
        ``limit`` matches; those are then ordered by metadata priority and
        timestamp.

        Args:
            goal: Free-text goal. An empty goal adds no keyword, rather than
                the empty string that matched every document.
            phase: Optional phase name added as an extra keyword.
            limit: Maximum number of documents to collect and return.

        Returns:
            List of raw document dicts. Empty when Firestore is unavailable,
            ``limit`` is not positive, or the query fails (errors are logged).
        """
        if not db or not firebase_admin._apps:
            logger.error("Firestore client not initialized. Cannot query knowledge base.")
            return []
        if limit <= 0:
            return []

        if getattr(self, 'collection', None) is None:
            self.collection = db.collection('knowledge_base')

        keywords = ['android', 'pentest', 'mobile', 'device']
        if goal:
            keywords.insert(0, goal.lower())
        if phase:
            keywords.append(phase.lower())

        try:
            results = []
            # NOTE: the whole collection is streamed and filtered client-side.
            for doc in self.collection.stream():
                data = doc.to_dict() or {}
                text = self._searchable_text(data)
                if any(keyword in text for keyword in keywords):
                    results.append(data)
                    if len(results) >= limit:
                        break

            results.sort(key=self._sort_key)
            return results[:limit]
        except Exception as e:
            logger.error(f"Failed to query knowledge base: {e}", exc_info=True)
            return []
|
|
|
|
|
class KnowledgeIndex:
    """In-memory semantic index over ``DATA_DIR/knowledge_base.json``.

    Each entry is stored as ``{'text', 'embedding', 'source'}``. Queries are
    ranked by cosine similarity between SentenceTransformer embeddings.
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        # Create DATA_DIR before anything is placed under it (the model cache
        # below and the knowledge-base file).
        os.makedirs(DATA_DIR, exist_ok=True)
        self.model = SentenceTransformer(
            model_name,
            cache_folder=os.path.join(DATA_DIR, "hf_cache")
        )
        self.knowledge_base = []
        self.load_knowledge_from_file(os.path.join(DATA_DIR, 'knowledge_base.json'))

    def load_knowledge_from_file(self, file_path):
        """Embed and append the entries of the JSON list at ``file_path``.

        Items may be plain strings (source ``'local'``) or dicts with a
        ``text`` key and optional ``source``. Other item types and empty
        texts are skipped. A missing file is created containing ``[]``.
        Errors are logged, never raised.
        """
        logger.debug(f"Attempting to load knowledge from file: {file_path}")
        if not os.path.exists(file_path):
            logger.warning(f"Knowledge base file not found: {file_path}. RAG will operate on an empty knowledge base.")
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    json.dump([], f)
                logger.info(f"Created empty knowledge base file at: {file_path}")
            except Exception as e:
                logger.error(f"Error creating empty knowledge base file at {file_path}: {e}", exc_info=True)
            return

        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, list):
                logger.error("Knowledge base file is not a list. Please check the file format.")
                return

            pending = []  # (text, source) pairs to embed
            for item in data:
                if isinstance(item, dict):
                    text = item.get('text', '')
                    source = item.get('source', 'local')
                elif isinstance(item, str):
                    text = item
                    source = 'local'
                else:
                    logger.warning(f"Skipping unsupported item type: {type(item)}")
                    continue
                if text:
                    pending.append((text, source))

            if pending:
                # One batched encode call instead of one model call per item.
                vectors = self.model.encode([text for text, _ in pending])
                for (text, source), vector in zip(pending, vectors):
                    self.knowledge_base.append({'text': text, 'embedding': vector.tolist(), 'source': source})
            logger.info(f"Loaded {len(self.knowledge_base)} items into RAG knowledge base.")
        except Exception as e:
            logger.error(f"Error loading knowledge from {file_path}: {e}", exc_info=True)

    @staticmethod
    def _top_indices(similarities, top_k: int) -> List[int]:
        """Indices of the ``top_k`` largest scores, best first.

        A non-positive ``top_k`` yields no results. The previous
        ``argsort()[-top_k:]`` returned *every* index for ``top_k == 0``.
        """
        if top_k <= 0:
            return []
        return [int(i) for i in np.argsort(similarities)[::-1][:top_k]]

    def retrieve(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return the ``top_k`` entries most similar to ``query``.

        Each result is ``{'text', 'similarity', 'source'}`` with
        ``similarity`` as a builtin float. Returns ``[]`` when the index is
        empty or retrieval fails (errors are logged).
        """
        if not self.knowledge_base:
            logger.debug("Knowledge base is empty, no RAG retrieval possible.")
            return []

        try:
            query_embedding = self.model.encode(query).reshape(1, -1)
            embeddings = np.array([item['embedding'] for item in self.knowledge_base])
            similarities = cosine_similarity(query_embedding, embeddings)[0]

            results = [
                {
                    "text": self.knowledge_base[i]['text'],
                    # Builtin float: numpy scalar types are not always JSON-serializable.
                    "similarity": float(similarities[i]),
                    "source": self.knowledge_base[i].get('source', 'RAG'),
                }
                for i in self._top_indices(similarities, top_k)
            ]
            logger.debug(f"RAG retrieved {len(results)} results for query: '{query}'")
            return results
        except Exception as e:
            logger.error(f"Error during RAG retrieval for query '{query}': {e}", exc_info=True)
            return []
|
|
|
|
|
class DeepSearchEngine:
    """Best-effort scraper of public vulnerability, exploit and advisory pages.

    Each public search method catches and logs its own errors and returns an
    empty list, so one unreachable or changed site does not break the caller.
    """

    def __init__(self, timeout: float = 15.0):
        """
        Args:
            timeout: Seconds to wait for each outbound HTTP request. Without
                a timeout a stalled site would block the caller indefinitely.
        """
        self.headers = {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
        }
        self.timeout = timeout

    @staticmethod
    def _build_url(base: str, params: dict) -> str:
        """Return ``base`` with ``params`` URL-encoded as its query string.

        Device/OS strings may contain spaces, ``&`` or ``#``. Left raw, these
        corrupt the query, so they are percent-encoded here.
        """
        from urllib.parse import urlencode
        return f"{base}?{urlencode(params)}"

    def _fetch_soup(self, url: str):
        """GET ``url`` with this engine's headers and timeout, then parse it.

        Raises ``requests`` exceptions on network errors or HTTP 4xx/5xx.
        """
        response = requests.get(url, headers=self.headers, timeout=self.timeout)
        response.raise_for_status()
        return BeautifulSoup(response.text, 'html.parser')

    def search_device_info(self, device_info: str, os_version: str) -> dict:
        """Aggregate CVEs, exploits and vendor recommendations for a device."""
        logger.debug(f"Performing deep search for device: {device_info}, OS: {os_version}")
        results = {
            "device": device_info,
            "os_version": os_version,
            "vulnerabilities": [],
            "exploits": [],
            "recommendations": []
        }
        try:
            results["vulnerabilities"] = self.search_cve(device_info, os_version)
            results["exploits"] = self.search_exploits(device_info, os_version)
            results["recommendations"] = self.get_security_recommendations(os_version)
            logger.debug("Deep search completed.")
        except Exception as e:
            logger.error(f"Deep search failed: {e}", exc_info=True)
        return results

    def search_cve(self, device: str, os_version: str) -> list:
        """Scrape up to 10 CVE entries from the CVE Mitre keyword search."""
        cves = []
        try:
            query = f"{device} {os_version} CVE"
            search_url = self._build_url("https://cve.mitre.org/cgi-bin/cvekey.cgi", {"keyword": query})
            logger.debug(f"Searching CVE Mitre: {search_url}")
            soup = self._fetch_soup(search_url)
            table = soup.find('div', id='TableWithRules')
            if table:
                # First row is the header.
                for row in table.find_all('tr')[1:]:
                    cols = row.find_all('td')
                    if len(cols) >= 2:
                        cves.append({
                            "cve_id": cols[0].get_text(strip=True),
                            "description": cols[1].get_text(strip=True),
                            "source": "CVE Mitre"
                        })
            logger.debug(f"Found {len(cves)} CVEs.")
            return cves[:10]
        except Exception as e:
            logger.error(f"CVE search failed: {e}", exc_info=True)
            return []

    def search_exploits(self, device: str, os_version: str) -> list:
        """Scrape up to 10 exploit titles/links from the ExploitDB search page."""
        exploits = []
        try:
            query = f"{device} {os_version}"
            search_url = self._build_url("https://www.exploit-db.com/search", {"q": query})
            logger.debug(f"Searching ExploitDB: {search_url}")
            soup = self._fetch_soup(search_url)
            for card in soup.select('.card .card-title'):
                anchor = card.find('a')
                if anchor is None or not anchor.get('href'):
                    # Skip malformed cards instead of aborting the whole search.
                    continue
                link = anchor['href']
                if not link.startswith('http'):
                    link = f"https://www.exploit-db.com{link}"
                exploits.append({
                    "title": card.get_text(strip=True),
                    "link": link,
                    "source": "ExploitDB"
                })
            logger.debug(f"Found {len(exploits)} exploits.")
            return exploits[:10]
        except Exception as e:
            logger.error(f"Exploit search failed: {e}", exc_info=True)
            return []

    def get_security_recommendations(self, os_version: str) -> list:
        """Scrape up to 5 vendor bulletin items mentioning ``os_version``.

        Only Android (source.android.com) and iOS (support.apple.com) are
        handled. Other OS strings yield ``[]``.
        """
        recommendations = []
        try:
            logger.debug(f"Getting security recommendations for OS: {os_version}")
            lowered = os_version.lower()
            if "android" in lowered:
                soup = self._fetch_soup("https://source.android.com/docs/security/bulletin")
                for heading in soup.select('.devsite-article-body h2'):
                    if os_version in heading.get_text():
                        next_ul = heading.find_next('ul')
                        if next_ul:
                            recommendations.extend(li.get_text(strip=True) for li in next_ul.select('li'))
            elif "ios" in lowered:
                soup = self._fetch_soup("https://support.apple.com/en-us/HT201222")
                for section in soup.select('#sections'):
                    if os_version in section.get_text():
                        recommendations.extend(li.get_text(strip=True) for li in section.select('li'))
            logger.debug(f"Found {len(recommendations)} recommendations.")
            return recommendations[:5]
        except Exception as e:
            logger.error(f"Security recommendations search failed: {e}", exc_info=True)
            return []

    def _search_github(self, device_info: str) -> list:
        """GitHub repositories matching ``<device_info> pentest`` (errors logged)."""
        found = []
        try:
            url = self._build_url("https://github.com/search", {"q": f"{device_info} pentest"})
            soup = self._fetch_soup(url)
            for repo in soup.select('.repo-list-item'):
                title_link = repo.select_one('.v-align-middle')
                if title_link is None or not title_link.get('href'):
                    continue
                description = repo.select_one('.mb-1')
                found.append({
                    "title": title_link.get_text(strip=True),
                    "description": description.get_text(strip=True) if description else "",
                    "url": f"https://github.com{title_link['href']}",
                    "source": "GitHub"
                })
        except Exception as e:
            logger.error(f"Public resources search failed: {e}", exc_info=True)
        return found

    def _search_hackforums(self, device_info: str) -> list:
        """HackForums threads matching ``device_info`` (errors logged)."""
        found = []
        try:
            url = self._build_url(
                "https://hackforums.net/search.php",
                {"action": "finduserthreads", "keywords": device_info},
            )
            soup = self._fetch_soup(url)
            for thread in soup.select('.thread'):
                title_el = thread.select_one('.threadtitle')
                link_el = thread.select_one('.threadtitle a')
                if title_el is None or link_el is None or not link_el.get('href'):
                    continue
                found.append({
                    "title": title_el.get_text(strip=True),
                    "description": "Forum discussion",
                    "url": f"https://hackforums.net{link_el['href']}",
                    "source": "HackForums"
                })
        except Exception as e:
            logger.error(f"Public resources search failed: {e}", exc_info=True)
        return found

    def search_public_resources(self, device_info: str) -> list:
        """Up to 10 public resources (GitHub first, then HackForums).

        Each source is queried independently, so a failure on one no longer
        discards results already gathered from the other.
        """
        logger.debug(f"Searching public resources for device: {device_info}")
        resources = self._search_github(device_info) + self._search_hackforums(device_info)
        logger.debug(f"Found {len(resources)} public resources.")
        return resources[:10]
|
|
|
|
|
# Module-level singletons. Constructing KnowledgeIndex loads the sentence
# embedding model and embeds DATA_DIR/knowledge_base.json at import time.
firebase_kb = FirebaseKnowledgeBase()

rag_index = KnowledgeIndex()

deep_search_engine = DeepSearchEngine()
|
|
|
|
|
class StrategicAgentBrain: |
|
def __init__(self):
    """Create an idle brain: no goal, no plan and no model loaded yet."""
    # Model handle and the file name it was loaded from.
    self.llm: Optional[Llama] = None
    self.loaded_model_name: Optional[str] = None

    # Goal and plan bookkeeping.
    self.current_goal: Optional[str] = None
    self.current_plan: List[Dict] = []
    self.current_phase_index: int = 0
    self.current_phase: str = "initial_reconnaissance"
    self.goal_achieved = False

    # Knowledge accumulated while pursuing the goal.
    self.identified_vulnerabilities: List[Dict] = []
    self.gathered_info: List[str] = []

    # Command / conversation tracking.
    self.execution_history = []
    self.used_commands = set()
    self.command_retry_counts: Dict[str, int] = {}
    self.conversation_history: List[Dict] = []
    self.no_progress_count = 0
    self.react_cycle_count = 0

    logger.info("StrategicAgentBrain initialized.")
|
|
|
async def load_strategic_llm(self, model_url: str):
    """Download (if needed) and load a GGUF model as the strategic LLM.

    The file is cached in DOWNLOAD_DIR under the last path segment of
    ``model_url``. The loaded model is shared process-wide through the
    module-level ``strategic_llm`` / ``current_strategic_model_url``
    globals.

    Args:
        model_url: Direct download URL of a ``.gguf`` file.

    Returns:
        ``(success, message)`` tuple describing the outcome.

    Notes:
        * Downloads go to ``<file>.part`` and are renamed only when
          complete. An interrupted download therefore cannot leave a
          truncated file that later calls would mistake for a cached model.
        * The previously loaded model is unloaded only after the new file
          is on disk, so a failed download keeps the old model usable.
        * NOTE(review): download and model load are blocking calls made
          directly inside this coroutine.
    """
    global strategic_llm, current_strategic_model_url
    logger.info(f"Attempting to load strategic LLM from URL: {model_url}")

    # Ignore any query string/fragment (e.g. "...gguf?download=true") when
    # naming the cached file.
    model_filename = model_url.split('#', 1)[0].split('?', 1)[0].split('/')[-1]
    if not model_filename:
        logger.error(f"Cannot derive a model file name from URL: {model_url}")
        return False, "Failed to load model: could not derive a file name from the URL."
    local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)

    if strategic_llm and current_strategic_model_url == model_url:
        logger.info(f"Strategic LLM model from {model_url} is already loaded.")
        self.llm = strategic_llm
        # The shared model may have been loaded via another instance; record
        # the name here so the message never reports 'None'.
        self.loaded_model_name = model_filename
        return True, f"Model '{self.loaded_model_name}' is already loaded."

    if not os.path.exists(local_model_path):
        logger.info(f"Model not found locally. Attempting to download from {model_url} to {local_model_path}...")
        partial_path = local_model_path + ".part"
        try:
            # (connect, read) timeouts so a stalled server cannot hang forever.
            with requests.get(model_url, stream=True, timeout=(10, 60)) as response:
                response.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            f.write(chunk)
            # Publish atomically: a file at local_model_path is always complete.
            os.replace(partial_path, local_model_path)
            logger.info(f"Model downloaded successfully to {local_model_path}.")
        except Exception as e:
            logger.error(f"Failed to download model from {model_url}: {e}", exc_info=True)
            try:
                if os.path.exists(partial_path):
                    os.remove(partial_path)
            except OSError as cleanup_error:
                logger.warning(f"Could not remove partial download {partial_path}: {cleanup_error}")
            return False, f"Failed to download model: {str(e)}"

    # The new file is available: free the old model before loading, so two
    # models are never resident at the same time.
    if strategic_llm:
        await self.unload_strategic_llm()

    try:
        logger.info(f"Loading Strategic LLM model from {local_model_path}...")
        strategic_llm = Llama(
            model_path=local_model_path,
            n_ctx=3096,
            n_gpu_layers=0,
            n_threads=os.cpu_count(),
            n_batch=512,
            verbose=False
        )
        current_strategic_model_url = model_url
        self.llm = strategic_llm
        self.loaded_model_name = model_filename
        logger.info(f"Strategic LLM model {model_filename} loaded successfully (CPU-only).")
        return True, f"Model '{model_filename}' loaded successfully (CPU-only)."
    except Exception as e:
        logger.error(f"Failed to load Strategic LLM model from {local_model_path}: {e}", exc_info=True)
        strategic_llm = None
        current_strategic_model_url = None
        self.llm = None
        self.loaded_model_name = None
        return False, f"Failed to load model: {str(e)}"
|
|
|
async def unload_strategic_llm(self):
    """Release the shared strategic LLM; does nothing when none is loaded.

    Clears the module-level model/URL globals and this instance's handles,
    then runs ``gc.collect()`` so the dropped model can be reclaimed promptly.
    """
    global strategic_llm, current_strategic_model_url
    if not strategic_llm:
        return

    logger.info("Unloading Strategic LLM model...")
    strategic_llm = None
    current_strategic_model_url = None
    self.llm = None
    self.loaded_model_name = None
    gc.collect()
    logger.info("Strategic LLM model unloaded.")
|
|
|
def _get_rag_context(self, query: str) -> str:
    """Format local RAG hits for ``query`` as a numbered list ("" when none)."""
    hits = rag_index.retrieve(query)
    if not hits:
        return ""
    numbered = [
        f"{rank}. [{hit.get('source', 'RAG')}] {hit.get('text', '') or hit.get('completion', '')}\n"
        for rank, hit in enumerate(hits, start=1)
    ]
    return "Relevant Knowledge for Current Context:\n" + "".join(numbered)
|
|
|
def _get_firebase_knowledge(self, goal: str, phase: str = None) -> str: |
|
if not db or not firebase_admin._apps: |
|
logger.error("Firestore client not initialized. Cannot query knowledge base.") |
|
return "" |
|
|
|
|
|
if not hasattr(self, 'collection') or self.collection is None: |
|
self.collection = db.collection('knowledge_base') |
|
|
|
keywords = [goal.lower(), 'android', 'pentest', 'mobile', 'device'] |
|
if phase: |
|
keywords.append(phase.lower()) |
|
|
|
try: |
|
query_ref = self.collection |
|
results = [] |
|
docs = query_ref.stream() |
|
|
|
for doc in docs: |
|
data = doc.to_dict() |
|
text = f"{data.get('prompt', '').lower()} {data.get('completion', '').lower()} {data.get('metadata', '').lower()}" |
|
if any(keyword in text for keyword in keywords): |
|
results.append(data) |
|
if len(results) >= 10: |
|
break |
|
|
|
priority_order = {"high": 1, "medium": 2, "low": 3} |
|
results.sort(key=lambda x: ( |
|
priority_order.get(x.get('metadata', {}).get('priority', 'low').lower(), 3), |
|
x.get('metadata', {}).get('timestamp', 0) |
|
)) |
|
|
|
return results[:10] |
|
except Exception as e: |
|
logger.error(f"Failed to query knowledge base: {e}", exc_info=True) |
|
return "" |
|
|
|
def extract_device_info(self) -> str: |
|
for info in self.gathered_info: |
|
if "model" in info.lower() or "device" in info.lower(): |
|
match = re.search(r'(?:model|device)\s*[:=]\s*([^\n]+)', info, re.IGNORECASE) |
|
if match: |
|
return match.group(1).strip() |
|
ip_match = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "") |
|
return ip_match.group(0) if ip_match else "Unknown Device" |
|
|
|
def extract_os_version(self) -> str: |
|
for info in self.gathered_info: |
|
if "android" in info.lower() or "ios" in info.lower() or "os" in info.lower(): |
|
android_match = re.search(r'android\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE) |
|
if android_match: |
|
return f"Android {android_match.group(1)}" |
|
ios_match = re.search(r'ios\s+(\d+(?:\.\d+)+)', info, re.IGNORECASE) |
|
if ios_match: |
|
return f"iOS {ios_match.group(1)}" |
|
linux_match = re.search(r'linux\s+kernel\s+(\d+\.\d+\.\d+)', info, re.IGNORECASE) |
|
if linux_match: |
|
return f"Linux {linux_match.group(1)}" |
|
return "Unknown OS Version" |
|
|
|
def format_deep_search_results(self, results: dict) -> str: |
|
context = "Deep Search Results:\n" |
|
context += f"Device: {results.get('device', 'Unknown')}\n" |
|
context += f"OS Version: {results.get('os_version', 'Unknown')}\n\n" |
|
if results.get('vulnerabilities'): |
|
context += "Discovered Vulnerabilities:\n" |
|
for i, vuln in enumerate(results['vulnerabilities'][:5], 1): |
|
context += f"{i}. {vuln.get('cve_id', 'CVE-XXXX-XXXX')}: {vuln.get('description', 'No description')}\n" |
|
context += "\n" |
|
if results.get('exploits'): |
|
context += "Available Exploits:\n" |
|
for i, exploit in enumerate(results['exploits'][:5], 1): |
|
context += f"{i}. {exploit.get('title', 'Untitled exploit')} [Source: {exploit.get('source', 'Unknown')}]\n" |
|
context += "\n" |
|
if results.get('recommendations'): |
|
context += "Security Recommendations:\n" |
|
for i, rec in enumerate(results['recommendations'][:3], 1): |
|
context += f"{i}. {rec}\n" |
|
context += "\n" |
|
if results.get('public_resources'): |
|
context += "Public Resources:\n" |
|
for i, res in enumerate(results['public_resources'][:3], 1): |
|
context += f"{i}. {res.get('title', 'Untitled resource')} [Source: {res.get('source', 'Unknown')}]\n" |
|
return context |
|
|
|
def generate_deep_search_prompt(self, context: str) -> str: |
|
return f""" |
|
You are an expert pentester. Below are deep search results for the target device. |
|
Use this information to generate the next penetration testing command.{context} |
|
|
|
Current Goal: {self.current_goal} |
|
Current Phase: {self.current_phase} |
|
|
|
Recent Command History:{', '.join(list(self.used_commands)[-3:]) if self.used_commands else 'None'} |
|
|
|
Based on this information, what is the SINGLE MOST EFFECTIVE shell command to execute next? |
|
Focus on exploiting the most critical vulnerabilities or gathering more information. |
|
|
|
Response Format: |
|
Command: <your_command_here> |
|
""" |
|
|
|
def _generate_llm_prompt(self) -> str:
    """Build the prompt asking the strategic LLM for exactly one shell command.

    The prompt contains:
      * SYSTEM_INSTRUCTION;
      * the goal and the current phase objective;
      * truncated RAG and Firestore knowledge;
      * the last two execution results and conversation turns;
      * the most recent commands to avoid.
    """
    rag_context = self._get_rag_context(f"{self.current_goal} {self.current_phase}")
    # Query Firestore only when there is a goal to match against.
    firebase_knowledge = (
        self._get_firebase_knowledge(self.current_goal, self.current_phase)
        if self.current_goal else ""
    )
    # Truncation below works on text. Coerce any non-string result (e.g. a
    # list of documents) and turn empty results into "".
    if not isinstance(firebase_knowledge, str):
        firebase_knowledge = str(firebase_knowledge) if firebase_knowledge else ""

    history_context = "\n".join(
        f"{entry['role']}: {entry['content']}" for entry in self.conversation_history[-2:]
    )

    execution_history = "\n".join(
        f"Command: {res['command']}\nResult: {str(res.get('output') or '')[:100]}...\nSuccess: {res['success']}"
        for res in self.execution_history[-2:]
    ) if self.execution_history else "No previous results."

    strategic_advice = self._get_rag_context(self.current_phase)

    # execution_history is ordered, unlike the used_commands set, so its
    # tail really holds the most recent commands.
    recent_commands = [res.get('command', '') for res in self.execution_history[-3:]]
    if not recent_commands:
        recent_commands = list(self.used_commands)[-3:]
    recent_text = ', '.join(recent_commands) if recent_commands else 'None'

    if self.current_plan and self.current_phase_index < len(self.current_plan):
        objective = self.current_plan[self.current_phase_index]['objective']
    else:
        objective = 'No objective'

    def shorten_text(text, max_length=300):
        if len(text) > max_length:
            return text[:max_length] + "... [truncated]"
        return text

    rag_context = shorten_text(rag_context, max_length=200)
    firebase_knowledge = shorten_text(firebase_knowledge, max_length=200)
    strategic_advice = shorten_text(strategic_advice, max_length=100)
    history_context = shorten_text(history_context, max_length=150)
    execution_history = shorten_text(execution_history, max_length=500)

    # The prompt ends immediately after "Command:". Generation is stopped at
    # the first newline, so a trailing newline here can lead the model to
    # start with a line break and produce an empty completion.
    prompt = f"""
System Instructions: {SYSTEM_INSTRUCTION}

Current Goal: '{self.current_goal}'
Current Phase: {self.current_phase} - {objective}

Based on the following knowledge and previous results, generate the SINGLE, VALID SHELL COMMAND to advance the penetration testing process.

**Knowledge from External Services (RAG & Firebase):**
{rag_context}
{firebase_knowledge}

**Previous Execution Results:**
{execution_history}

**Recent Conversation History:**
{history_context}

**Strategic Advice for Current Phase:**
{strategic_advice}

***CRITICAL RULES FOR OUTPUT:***
1. **OUTPUT ONLY THE COMMAND.**
2. **DO NOT INCLUDE ANY REASONING, THOUGHTS, EXPLANATIONS, OR ANY OTHER TEXT.**
3. The command MUST be directly executable in a Linux terminal.
4. Avoid repeating these recent commands: {recent_text}
5. If the previous command failed, try a different approach or a related tool.
6. For the 'android_enumeration' phase, prioritize ADB commands.

Example valid commands for initial reconnaissance of an Android phone:
nmap -sV -Pn 192.168.1.14
adb devices
adb connect 192.168.1.14:5555

Command:"""
    return prompt
|
|
|
def _fallback_target_ip(self) -> str:
    """First IPv4-looking address in the current goal, else "192.168.1.1"."""
    found = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', self.current_goal or "")
    return found.group(0) if found else "192.168.1.1"

def _get_llm_response(self, custom_prompt: str = None) -> str:
    """Ask the strategic LLM for one line of output.

    Uses ``custom_prompt`` when given, otherwise the standard prompt.
    Generation stops at the first newline. When no model is loaded it
    returns an ``echo`` command explaining that. When the model returns
    nothing or inference fails it returns ``Command: nmap -sV -Pn <ip>``
    for the goal's target IP.
    """
    if not self.llm:
        logger.error("Strategic LLM instance is None. Cannot get response. Please load a model first.")
        return f"Command: echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {self._fallback_target_ip()}'"

    prompt = custom_prompt or self._generate_llm_prompt()
    logger.info(f"Sending prompt to Strategic LLM:\n{prompt[:500]}...")

    try:
        completion = self.llm(
            prompt,
            max_tokens=512,
            temperature=0.3,
            stop=["\n"]
        )
        text = completion['choices'][0]['text'].strip()
        logger.info(f"Strategic LLM raw response: {text}")
        if text:
            return text
        logger.warning("Strategic LLM returned an empty response. Using fallback command.")
    except Exception as e:
        logger.error(f"Error during Strategic LLM inference: {e}", exc_info=True)
        logger.warning("Strategic LLM inference failed. Using fallback command.")

    return f"Command: nmap -sV -Pn {self._fallback_target_ip()}"
|
|
|
def parse_llm_response(self, response: str) -> Optional[str]:
    """Extract one new, whitelisted shell command from an LLM reply.

    Extraction order:
      1. a fenced code block;
      2. a ``Command: ...`` line;
      3. the whole reply, if it starts with a SUPPORTED_TOOLS name.

    The candidate is then handled as follows:
      * label prefixes (Command/Answer/Note/Result) are stripped;
      * it is checked against a conservative character whitelist;
      * it is rejected if it looks like reasoning text;
      * it is rejected if it is already in ``used_commands``.

    Returns:
        The command (also added to ``used_commands``), or None when nothing
        acceptable was found or parsing failed.
    """
    logger.info(f"Attempting to parse LLM response: '{response}'")
    command = None
    try:
        # The fence language tag must match as a whole word. With the former
        # "(?:bash|sh)?", a "```shell" fence matched "sh" and leaked "ell"
        # into the extracted command.
        code_block = re.search(r'```(?:(?:bash|shell|sh|zsh|console)\b)?\s*([\s\S]*?)```', response)
        if code_block:
            command = code_block.group(1).strip()
            logger.info(f"Command extracted from code block: '{command}'")

        if not command:
            command_match = re.search(r'^\s*Command\s*:\s*(.+)$', response, re.MULTILINE | re.IGNORECASE)
            if command_match:
                command = command_match.group(1).strip()
                logger.info(f"Command extracted from 'Command:' line: '{command}'")

        if not command:
            stripped_response = response.strip()
            if any(stripped_response.startswith(tool) for tool in SUPPORTED_TOOLS):
                command = stripped_response
                logger.info(f"Command extracted as direct supported tool command: '{command}'")

        if not command:
            logger.warning("No valid command could be extracted from LLM response based on strict rules.")
            return None

        original_command = command
        command = re.sub(r'^\s*(Command|Answer|Note|Result)\s*[:.-]?\s*', '', command, flags=re.IGNORECASE).strip()
        logger.info(f"Cleaned command: from '{original_command}' to '{command}'")

        # Whitelist: letters, digits, whitespace and _ . / : ; = - ' "
        # NOTE(review): this also rejects &, |, > and similar, including the
        # "adb kill-server && adb start-server" example in SYSTEM_INSTRUCTION.
        if not re.match(r'^[a-zA-Z0-9_./:;= \-\'"\s]+$', command):
            logger.error(f"Invalid command characters detected after cleanup: '{command}'")
            return None

        if re.search(r'(reason|thought|explanation|rationale|note|result):', command, re.IGNORECASE):
            logger.warning(f"Command '{command}' appears to be reasoning/explanation. Rejecting.")
            return None

        if command in self.used_commands:
            logger.warning(f"Command '{command}' already used. Skipping.")
            return None

        self.used_commands.add(command)
        logger.info(f"Returning valid and new command: '{command}'")
        return command
    except Exception as e:
        logger.error(f"Error parsing LLM response: {e}", exc_info=True)
        return None
|
|
|
def set_goal(self, goal: str):
    """Start working on ``goal`` from a clean slate.

    Builds a fresh strategic plan and resets all per-goal state: phase,
    findings, history and used commands. The loaded LLM is left untouched.
    """
    self.current_goal = goal
    self.goal_achieved = False
    self.react_cycle_count = 0
    self.no_progress_count = 0

    self.current_plan = self._generate_strategic_plan(goal)
    self.current_phase_index = 0
    # Keep the phase name in sync with the reset index. Otherwise the new
    # goal inherits the previous goal's phase (e.g. "completed").
    self.current_phase = (
        self.current_plan[0]["phase"] if self.current_plan else "initial_reconnaissance"
    )

    self.identified_vulnerabilities = []
    self.gathered_info = []
    self.command_retry_counts = {}
    self.conversation_history = [{"role": "user", "content": f"New goal set: {goal}"}]
    self.used_commands.clear()
    self.execution_history = []
    logger.info(f"Strategic Agent Goal set: {goal}. Starting initial reconnaissance.")
|
|
|
|
|
def _generate_strategic_plan(self, goal: str) -> List[Dict]:
    """Build an ordered list of ``{"phase", "objective"}`` steps for ``goal``.

    Every plan starts with initial reconnaissance and ends with reporting.
    The middle steps are chosen by keyword:
      * web track when the goal mentions "web" or "http";
      * Android track when it mentions "android", "mobile" or "adb";
      * otherwise a generic network track.
    """
    logger.debug(f"Generating strategic plan for goal: {goal}")
    lowered = goal.lower()

    if any(word in lowered for word in ("web", "http")):
        middle = (
            ("web_enumeration", "Enumerate web server for directories and files"),
            ("web_vulnerability_analysis", "Analyze web vulnerabilities (SQLi, XSS, etc.)"),
            ("web_exploitation", "Attempt to exploit web vulnerabilities"),
            ("post_exploitation", "Perform post exploitation activities"),
        )
    elif any(word in lowered for word in ("android", "mobile", "adb")):
        middle = (
            ("android_enumeration", "Enumerate Android device via ADB"),
            ("android_app_analysis", "Analyze Android application for vulnerabilities"),
            ("android_exploitation", "Attempt to exploit Android vulnerabilities"),
            ("data_extraction", "Extract sensitive data from device"),
        )
    else:
        middle = (
            ("network_scanning", "Perform detailed network scanning"),
            ("service_enumeration", "Enumerate services and identify versions"),
            ("vulnerability_analysis", "Analyze services for vulnerabilities"),
            ("exploitation", "Attempt to exploit vulnerabilities"),
            ("post_exploitation", "Perform post exploitation (privilege escalation, data exfiltration)"),
        )

    steps = (
        (("initial_reconnaissance", f"Perform initial reconnaissance for {goal}"),)
        + middle
        + (("reporting", "Generate pentest report"),)
    )
    plan = [{"phase": phase, "objective": objective} for phase, objective in steps]

    logger.info(f"Generated strategic plan for goal '{goal}': {plan}")
    return plan
|
|
|
def evaluate_phase_completion(self) -> float: |
|
phase_commands = [cmd for cmd in self.execution_history |
|
if cmd.get('phase', '') == self.current_phase] |
|
if not phase_commands: |
|
return 0.0 |
|
successful = sum(1 for cmd in phase_commands if cmd['success']) |
|
return successful / len(phase_commands) |
|
|
|
def advance_phase(self):
    """Step to the next plan phase, or mark the goal achieved when none remain.

    If a later entry exists in ``current_plan``, the index moves forward,
    ``current_phase`` takes that entry's ``"phase"`` value and both per-phase
    counters (``no_progress_count``, ``react_cycle_count``) are reset to 0.
    Otherwise ``current_phase`` becomes ``"completed"`` and ``goal_achieved``
    is set to ``True``.
    """
    last_index = len(self.current_plan) - 1
    if self.current_phase_index >= last_index:
        # Plan exhausted (or empty): finish the run.
        self.current_phase = "completed"
        self.goal_achieved = True
        logger.info("Strategic Agent: All planned phases completed. Goal achieved!")
        return

    self.current_phase_index += 1
    upcoming = self.current_plan[self.current_phase_index]
    self.current_phase = upcoming["phase"]
    logger.info(f"Strategic Agent advancing to new phase: {self.current_phase.replace('_', ' ').title()}")
    self.no_progress_count = 0
    self.react_cycle_count = 0
|
|
|
def observe_result(self, command: str, output: str, success: bool):
    """Record the outcome of an executed command and update agent progress.

    Appends the result to ``execution_history``, tagged with the phase that
    was active when it ran. Stores the raw output in ``gathered_info`` and
    runs the strategic output analysis. Then updates ``no_progress_count``.
    After a successful command, the plan advances once at least 80% of the
    current phase's commands have succeeded.

    Args:
        command: The command line that was executed.
        output: Captured output of the command.
        success: Whether the command succeeded.
    """
    logger.debug(f"Strategic Agent observing result for command '{command}': Success={success}")
    self.execution_history.append({
        "command": command,
        "output": output,
        "success": success,
        # evaluate_phase_completion() filters history by this key; without it
        # no entry ever matched and the phase could never advance from here.
        "phase": self.current_phase,
        "timestamp": datetime.now().isoformat(),
    })
    self.gathered_info.append(output)

    self.analyze_command_output_strategic(command, output)

    if success:
        self.no_progress_count = 0
    else:
        self.no_progress_count += 1

    # Only phases before the last one are advanced automatically here.
    if success and self.current_phase_index < len(self.current_plan) - 1:
        if self.evaluate_phase_completion() >= 0.8:
            self.advance_phase()
|
|
|
def analyze_command_output_strategic(self, command: str, output: str):
    """Strategic Agent performs deeper analysis of command output for vulnerabilities.

    The tool is identified from the first word of ``command``. Any leading
    ``sudo`` is skipped and a directory prefix is stripped, so ``nmap ...``,
    ``sudo nmap ...`` and ``/usr/bin/nmap ...`` are treated alike. Prefix
    matching is kept, so ``nikto.pl`` and ``sqlmap.py`` still match.

    Findings are recorded via ``ingest_vulnerability``, and open TCP ports
    reported by nmap are appended to ``gathered_info``. Any exception is
    logged and swallowed, so a parsing problem never interrupts the caller.

    Args:
        command: The command line that produced ``output``.
        output: The captured output of the command.
    """
    # NOTE(review): the "*-SCAN"/"ADB-*" tags below land in the ``cve_id``
    # parameter of ingest_vulnerability; they are source labels, not CVE IDs.
    try:
        logger.debug(f"Analyzing strategic command output for: {command}")

        words = command.split()
        # Skip sudo wrappers so privileged invocations are still recognised.
        while words and words[0] == "sudo":
            words.pop(0)
        tool = os.path.basename(words[0]) if words else ""
        output_lower = output.lower()

        if tool.startswith("nmap"):
            if "open" in output and "vulnerable" in output_lower:
                self.ingest_vulnerability(
                    "Potential vulnerability found in NMAP scan",
                    "Medium",
                    "NMAP-SCAN"
                )
            port_matches = re.findall(r'(\d+)/tcp\s+open\s+(\S+)', output)
            for port, service in port_matches:
                self.gathered_info.append(f"Discovered open port {port} with service {service}")

        elif tool.startswith("nikto"):
            if "OSVDB-" in output:
                vuln_matches = re.findall(r'OSVDB-\d+:\s*(.+)', output)
                # Record at most three findings per scan.
                for vuln in vuln_matches[:3]:
                    self.ingest_vulnerability(
                        f"Nikto vulnerability: {vuln}",
                        "High",
                        "NIKTO-SCAN"
                    )

        elif tool.startswith("sqlmap"):
            # sqlmap logs "testing for SQL injection on ..." for every parameter
            # it probes, so a bare "injection" substring flagged every run,
            # including clean ones. Report only when sqlmap states that an
            # injection point was found.
            if re.search(
                r"(?:identified|resumed) the following injection point"
                r"|parameter '[^']*' is vulnerable",
                output_lower,
            ):
                self.ingest_vulnerability(
                    "SQL injection vulnerability detected",
                    "Critical",
                    "SQLMAP-SCAN"
                )

        elif tool.startswith("adb"):
            if "debuggable" in output_lower:
                self.ingest_vulnerability(
                    "Debuggable Android application found",
                    "High",
                    "ADB-DEBUG"
                )
            if "permission" in output_lower and "denied" in output_lower:
                self.ingest_vulnerability(
                    "Permission issue detected on Android device",
                    "Medium",
                    "ADB-PERMISSION"
                )
    except Exception as e:
        logger.error(f"Strategic Agent: Error analyzing command output: {e}", exc_info=True)
|
|
|
def ingest_vulnerability(self, description: str, severity: str, cve_id: Optional[str] = None, exploit_id: Optional[str] = None):
    """Append a finding to ``identified_vulnerabilities`` and log it.

    The stored dict always has ``description``, ``severity`` and an ISO-8601
    local ``timestamp``. ``cve_id`` and ``exploit_id`` are added, in that
    order, only when they are truthy.

    Args:
        description: Human-readable summary of the finding.
        severity: Severity label supplied by the caller (e.g. "Medium").
        cve_id: Optional identifier stored under ``"cve_id"``.
        exploit_id: Optional identifier stored under ``"exploit_id"``.
    """
    record = {
        "description": description,
        "severity": severity,
        "timestamp": datetime.now().isoformat(),
    }
    optional_fields = (("cve_id", cve_id), ("exploit_id", exploit_id))
    record.update({key: value for key, value in optional_fields if value})

    self.identified_vulnerabilities.append(record)
    logger.info(f"Strategic Agent identified vulnerability: {description} (Severity: {severity})")
|
|
|
|
|
# Process-wide agent instance. Every /strategic_agent/* endpoint below reads
# and mutates this single object, so goal, plan and history are shared by all
# clients of the service (there is no per-session isolation).
strategic_brain = StrategicAgentBrain()
|
|
|
|
|
class RAGRequest(BaseModel):
    """Request body for ``POST /rag/retrieve``."""

    # Search text forwarded to rag_index.retrieve; validated to 3..500 characters.
    query: constr(min_length=3, max_length=500)
    # Forwarded to rag_index.retrieve as its second argument; validated as 1..20, default 5.
    top_k: int = Field(5, gt=0, le=20)
|
|
|
class FirebaseQueryRequest(BaseModel):
    """Request body for ``POST /firebase/query``.

    The fields are forwarded positionally to
    ``firebase_kb.query(goal, phase, limit)``.
    """

    # Goal text used to look up knowledge-base entries.
    goal: str
    # Optional phase filter. Declared Optional so an explicit JSON ``null``
    # validates; a bare ``str = None`` annotation rejects it under pydantic v2.
    phase: Optional[str] = None
    # Maximum number of results requested from the knowledge base.
    limit: int = 10
|
|
|
class DeepSearchRequest(BaseModel):
    """Request body for ``POST /deep_search``."""

    # Device description; passed to both search_device_info and search_public_resources.
    device_info: str
    # OS version string; passed to search_device_info only.
    os_version: str
|
|
|
class SetGoalRequest(BaseModel):
    """Request body for ``POST /strategic_agent/set_goal``."""

    # Free-text objective handed to strategic_brain.set_goal.
    goal: str
|
|
|
class GetNextCommandRequest(BaseModel):
    """Request body for ``POST /strategic_agent/get_next_command``."""

    # NOTE(review): current_state, last_command_output and last_command_success
    # are required but not read by the get_next_command handler in this module;
    # confirm whether they should be observed or dropped.
    current_state: str

    last_command_output: str

    last_command_success: bool

    # The three summaries below overwrite the agent's execution_history,
    # gathered_info and identified_vulnerabilities wholesale.
    # NOTE(review): pydantic copies field defaults per instance, so these []
    # literals are not shared between requests; confirm for the pydantic version in use.
    execution_history_summary: List[Dict] = []

    gathered_info_summary: List[str] = []

    identified_vulnerabilities_summary: List[Dict] = []
|
|
|
class ObserveResultRequest(BaseModel):
    """Request body for ``POST /strategic_agent/observe_result``.

    All three fields are forwarded to ``strategic_brain.observe_result``.
    """

    # Command line that was executed.
    command: str
    # Captured output of that command.
    output: str
    # Whether the command succeeded.
    success: bool
|
|
|
class LoadStrategicModelRequest(BaseModel):
    """Request body for ``POST /strategic_agent/load_model``."""

    # Model location passed to strategic_brain.load_strategic_llm.
    model_url: str
|
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: confirm the service process is up and answering."""
    logger.debug("Health check requested.")
    payload = {"status": "ok", "message": "Knowledge service is running."}
    return payload
|
|
|
@app.post("/rag/retrieve")
async def rag_retrieve_endpoint(request: RAGRequest):
    """Run ``rag_index.retrieve(query, top_k)`` and wrap the result.

    Responds with ``{"success": True, "data": {"results": ...}, "error": None}``.
    Any exception is logged with its traceback and surfaced as HTTP 500.
    """
    logger.debug(f"RAG retrieve endpoint called with query: {request.query}")
    try:
        hits = rag_index.retrieve(request.query, request.top_k)
    except Exception as e:
        logger.error(f"RAG retrieval failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": {"results": hits}, "error": None}
|
|
|
@app.post("/firebase/query")
async def firebase_query_endpoint(request: FirebaseQueryRequest):
    """Query the Firebase knowledge base by goal and optional phase.

    Forwards ``goal``, ``phase`` and ``limit`` to ``firebase_kb.query``.
    Errors are logged with traceback and re-raised as HTTP 500.
    """
    logger.debug(f"Firebase query endpoint called with goal: {request.goal}, phase: {request.phase}")
    try:
        matches = firebase_kb.query(request.goal, request.phase, request.limit)
    except Exception as e:
        logger.error(f"Firebase query failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": {"results": matches}, "error": None}
|
|
|
@app.post("/deep_search")
async def deep_search_endpoint(request: DeepSearchRequest):
    """Combine device-specific search results with public-resource results.

    The ``search_device_info`` result is extended in place with a
    ``"public_resources"`` entry and returned as ``data``. Failures from
    either lookup are logged and surfaced as HTTP 500.
    """
    logger.debug(f"Deep search endpoint called for device: {request.device_info}, OS: {request.os_version}")
    try:
        device_report = deep_search_engine.search_device_info(request.device_info, request.os_version)
        device_report["public_resources"] = deep_search_engine.search_public_resources(request.device_info)
    except Exception as e:
        logger.error(f"Deep search failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    return {"success": True, "data": device_report, "error": None}
|
|
|
@app.post("/strategic_agent/load_model")
async def load_strategic_model(request: LoadStrategicModelRequest):
    """Load the strategic LLM from ``request.model_url``.

    Delegates to ``strategic_brain.load_strategic_llm``. On success, returns
    the loader's message plus ``strategic_brain.loaded_model_name``.
    Otherwise, raises HTTP 500 carrying the loader's message.
    """
    logger.info(f"Request to load strategic model: {request.model_url}")
    success, message = await strategic_brain.load_strategic_llm(request.model_url)
    if not success:
        logger.error(f"Failed to load strategic model: {message}")
        raise HTTPException(status_code=500, detail=message)
    logger.info(f"Strategic model loaded successfully: {message}")
    return {"status": "success", "message": message, "model": strategic_brain.loaded_model_name}
|
|
|
@app.post("/strategic_agent/unload_model")
async def unload_strategic_model():
    """Release the strategic LLM via ``strategic_brain.unload_strategic_llm``."""
    logger.info("Request to unload strategic model.")
    await strategic_brain.unload_strategic_llm()
    response_body = {"status": "success", "message": "Strategic LLM unloaded."}
    return response_body
|
|
|
@app.post("/strategic_agent/set_goal")
async def strategic_set_goal(request: SetGoalRequest):
    """Hand a new objective to the shared strategic agent and echo it back."""
    new_goal = request.goal
    logger.info(f"Strategic Agent received new goal: {request.goal}")
    strategic_brain.set_goal(new_goal)
    return {"status": "success", "message": f"Goal set to: {request.goal}"}
|
|
|
@app.post("/strategic_agent/get_next_command")
async def strategic_get_next_command(request: GetNextCommandRequest):
    """Ask the strategic agent's LLM for the next command to run.

    The agent's history, gathered info and vulnerabilities are first
    overwritten with the client's summaries. The LLM is then prompted and
    its reply parsed.

    If no command can be parsed, an ``nmap -sV -Pn`` fallback is returned.
    It targets the first IPv4-looking address in the current goal, or
    ``192.168.1.1`` when there is none. When no LLM is loaded, that fallback
    is embedded in an ``echo`` message instead of being returned as a
    runnable command.
    """
    logger.debug("Strategic Agent received request for next command.")

    brain = strategic_brain
    # Adopt the client's view of the session state.
    brain.execution_history = request.execution_history_summary
    brain.gathered_info = request.gathered_info_summary
    brain.identified_vulnerabilities = request.identified_vulnerabilities_summary

    prompt = brain._generate_llm_prompt()
    llm_reply = brain._get_llm_response(prompt)
    command = brain.parse_llm_response(llm_reply)

    if command:
        brain.used_commands.add(command)
        logger.info(f"Strategic Agent generated command: {command}")
        return {"command": command, "status": "success"}

    ip_in_goal = re.search(r'\b\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}\b', brain.current_goal or "")
    fallback_ip = ip_in_goal.group(0) if ip_in_goal else "192.168.1.1"
    logger.warning(f"Strategic Agent failed to generate command. Returning fallback: {fallback_ip}")

    if brain.llm is None:
        return {"command": f"echo 'No LLM loaded. Please load a model from settings. Fallback: nmap -sV -Pn {fallback_ip}'", "status": "fallback", "message": "No LLM loaded on Strategic Agent. Please load one from the frontend settings."}
    return {"command": f"nmap -sV -Pn {fallback_ip}", "status": "fallback", "message": "Strategic Agent could not determine a valid next command."}
|
|
|
|
|
@app.post("/strategic_agent/observe_result")
async def strategic_observe_result(request: ObserveResultRequest):
    """Feed one executed command's outcome into the shared strategic agent."""
    logger.debug(f"Strategic Agent received observation for command: {request.command}, success: {request.success}")
    cmd, out, ok = request.command, request.output, request.success
    strategic_brain.observe_result(cmd, out, ok)
    return {"status": "success", "message": "Observation received and processed."}
|
|
|
@app.get("/strategic_agent/get_status")
async def strategic_get_status():
    """Return a JSON-friendly snapshot of the shared strategic agent.

    ``executionHistorySummary`` and ``identifiedVulnerabilities`` are built
    from entries that clients can supply verbatim through
    ``/strategic_agent/get_next_command``, so their fields are read with
    ``dict.get``. A client entry lacking ``command``/``success``/``timestamp``
    or ``description`` yields ``None``/``""`` instead of an HTTP 500.

    ``gatheredInfo`` holds the last five entries. Each is cut to 100
    characters, with ``"..."`` appended only when text was actually removed.
    """
    logger.debug("Strategic Agent status requested.")
    brain = strategic_brain
    preview_len = 100

    def _preview(text):
        # Mark truncation only when the text was actually shortened.
        return text if len(text) <= preview_len else text[:preview_len] + "..."

    return {
        "currentGoal": brain.current_goal,
        # Tolerate an unset phase instead of raising AttributeError.
        "currentPhase": (brain.current_phase or "").replace('_', ' ').title(),
        "reactCycleCount": brain.react_cycle_count,
        "noProgressCount": brain.no_progress_count,
        "identifiedVulnerabilities": [
            v.get('description', '') for v in brain.identified_vulnerabilities
        ],
        "gatheredInfo": [_preview(info) for info in (brain.gathered_info or [])[-5:]],
        "executionHistorySummary": [
            {
                "command": e.get('command'),
                "success": e.get('success'),
                "timestamp": e.get('timestamp'),
            }
            for e in brain.execution_history[-10:]
        ],
        "strategicPlan": brain.current_plan,
        "currentPhaseIndex": brain.current_phase_index,
        "goalAchieved": brain.goal_achieved,
        "strategicAgentStatus": "Running" if brain.current_goal and not brain.goal_achieved else "Idle",
        "loadedModel": brain.loaded_model_name,
    }
|
|
|
@app.get("/api/models")
async def get_available_models_strategic():
    """List predefined Hugging Face models for strategic agent.

    Responds with ``HUGGINGFACE_MODELS`` as a JSON array of objects with
    ``"name"`` and ``"url"`` keys.
    """
    logger.debug("Request for available strategic models received.")
    # JSONResponse serialises ``content`` itself. Passing json.dumps(...) here
    # double-encoded the payload, so clients received a JSON *string* holding
    # the array instead of the array itself.
    return JSONResponse(content=HUGGINGFACE_MODELS, media_type="application/json")
|
|
|
|
|
def _download_model_file(
    model_url: str,
    local_model_path: str,
    timeout: Tuple[float, float] = (10.0, 60.0),
    chunk_size: int = 1024 * 1024,
) -> None:
    """Stream ``model_url`` to ``local_model_path`` atomically.

    Data is written to ``<local_model_path>.part`` and renamed into place
    only after the whole body has been received. An interrupted download
    therefore never leaves a truncated file at the final path, which the
    startup existence check would otherwise treat as already downloaded.

    Args:
        model_url: HTTP(S) URL of the model file.
        local_model_path: Final destination path.
        timeout: ``(connect, read)`` timeout in seconds passed to ``requests``.
        chunk_size: Bytes per streamed chunk.

    Raises:
        Whatever ``requests`` or the filesystem raises. The partial file is
        removed before the exception propagates.
    """
    partial_path = local_model_path + ".part"
    try:
        # The context manager releases the connection even on error.
        with requests.get(model_url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    if chunk:
                        f.write(chunk)
        os.replace(partial_path, local_model_path)
    except BaseException:
        # Best effort: drop the incomplete file so the next startup retries.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise


@app.on_event("startup")
async def startup_event_download_models():
    """Download every model listed in ``HUGGINGFACE_MODELS`` into ``DOWNLOAD_DIR``.

    A model whose file already exists is skipped. A failed download is
    logged and does not stop the remaining ones.

    NOTE: downloads run synchronously inside this handler, so startup is
    presumably not finished until every model has been fetched.
    """
    logger.info("Application startup event triggered. Attempting to download all predefined models.")

    for model_info in HUGGINGFACE_MODELS:
        model_url = model_info["url"]
        model_name = model_info["name"]
        model_filename = model_url.split('/')[-1]
        local_model_path = os.path.join(DOWNLOAD_DIR, model_filename)

        if os.path.exists(local_model_path):
            logger.info(f"Model '{model_name}' already exists at {local_model_path}. Skipping download.")
            continue

        logger.info(f"Downloading model '{model_name}' from {model_url} to {local_model_path}...")
        try:
            _download_model_file(model_url, local_model_path)
        except Exception as e:
            logger.error(f"Failed to download model '{model_name}': {e}", exc_info=True)
        else:
            logger.info(f"Model '{model_name}' downloaded successfully.")

    logger.info("Finished attempting to download all predefined models.")
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_event_cleanup():
    """Shutdown hook; currently it only logs that shutdown has started."""
    message = "Application shutdown event triggered. Performing cleanup."
    logger.info(message)
|
|
|
|
|
if __name__ == "__main__":
    # Local import: uvicorn is needed only when the module is run directly.
    import uvicorn

    logger.info("Starting FastAPI application on Hugging Face Spaces (port 7860)...")
    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)
|
|