Spaces:
Sleeping
Sleeping
# +------------------------------------+ | |
# | |
# Prompt Injection Detection | |
# | |
# +------------------------------------+ | |
# Thank you users! We ❤️ you! - Krrish & Ishaan | |
## Reject a call if it contains a prompt injection attack. | |
from difflib import SequenceMatcher | |
from typing import List, Literal, Optional | |
from fastapi import HTTPException | |
import litellm | |
from litellm._logging import verbose_proxy_logger | |
from litellm.caching.caching import DualCache | |
from litellm.constants import DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD | |
from litellm.integrations.custom_logger import CustomLogger | |
from litellm.litellm_core_utils.prompt_templates.factory import ( | |
prompt_injection_detection_default_pt, | |
) | |
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth | |
from litellm.router import Router | |
from litellm.utils import get_formatted_prompt | |
class _OPTIONAL_PromptInjectionDetection(CustomLogger): | |
# Class variables or attributes | |
def __init__( | |
self, | |
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None, | |
): | |
self.prompt_injection_params = prompt_injection_params | |
self.llm_router: Optional[Router] = None | |
self.verbs = [ | |
"Ignore", | |
"Disregard", | |
"Skip", | |
"Forget", | |
"Neglect", | |
"Overlook", | |
"Omit", | |
"Bypass", | |
"Pay no attention to", | |
"Do not follow", | |
"Do not obey", | |
] | |
self.adjectives = [ | |
"", | |
"prior", | |
"previous", | |
"preceding", | |
"above", | |
"foregoing", | |
"earlier", | |
"initial", | |
] | |
self.prepositions = [ | |
"", | |
"and start over", | |
"and start anew", | |
"and begin afresh", | |
"and start from scratch", | |
] | |
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): | |
if level == "INFO": | |
verbose_proxy_logger.info(print_statement) | |
elif level == "DEBUG": | |
verbose_proxy_logger.debug(print_statement) | |
if litellm.set_verbose is True: | |
print(print_statement) # noqa | |
def update_environment(self, router: Optional[Router] = None): | |
self.llm_router = router | |
if ( | |
self.prompt_injection_params is not None | |
and self.prompt_injection_params.llm_api_check is True | |
): | |
if self.llm_router is None: | |
raise Exception( | |
"PromptInjectionDetection: Model List not set. Required for Prompt Injection detection." | |
) | |
self.print_verbose( | |
f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}" | |
) | |
if ( | |
self.prompt_injection_params.llm_api_name is None | |
or self.prompt_injection_params.llm_api_name | |
not in self.llm_router.model_names | |
): | |
raise Exception( | |
"PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'." | |
) | |
def generate_injection_keywords(self) -> List[str]: | |
combinations = [] | |
for verb in self.verbs: | |
for adj in self.adjectives: | |
for prep in self.prepositions: | |
phrase = " ".join(filter(None, [verb, adj, prep])).strip() | |
if ( | |
len(phrase.split()) > 2 | |
): # additional check to ensure more than 2 words | |
combinations.append(phrase.lower()) | |
return combinations | |
def check_user_input_similarity( | |
self, | |
user_input: str, | |
similarity_threshold: float = DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD, | |
) -> bool: | |
user_input_lower = user_input.lower() | |
keywords = self.generate_injection_keywords() | |
for keyword in keywords: | |
# Calculate the length of the keyword to extract substrings of the same length from user input | |
keyword_length = len(keyword) | |
for i in range(len(user_input_lower) - keyword_length + 1): | |
# Extract a substring of the same length as the keyword | |
substring = user_input_lower[i : i + keyword_length] | |
# Calculate similarity | |
match_ratio = SequenceMatcher(None, substring, keyword).ratio() | |
if match_ratio > similarity_threshold: | |
self.print_verbose( | |
print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}", | |
level="INFO", | |
) | |
return True # Found a highly similar substring | |
return False # No substring crossed the threshold | |
async def async_pre_call_hook( | |
self, | |
user_api_key_dict: UserAPIKeyAuth, | |
cache: DualCache, | |
data: dict, | |
call_type: str, # "completion", "embeddings", "image_generation", "moderation" | |
): | |
try: | |
""" | |
- check if user id part of call | |
- check if user id part of blocked list | |
""" | |
self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook") | |
try: | |
assert call_type in [ | |
"completion", | |
"text_completion", | |
"embeddings", | |
"image_generation", | |
"moderation", | |
"audio_transcription", | |
] | |
except Exception: | |
self.print_verbose( | |
f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']" | |
) | |
return data | |
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore | |
is_prompt_attack = False | |
if self.prompt_injection_params is not None: | |
# 1. check if heuristics check turned on | |
if self.prompt_injection_params.heuristics_check is True: | |
is_prompt_attack = self.check_user_input_similarity( | |
user_input=formatted_prompt | |
) | |
if is_prompt_attack is True: | |
raise HTTPException( | |
status_code=400, | |
detail={ | |
"error": "Rejected message. This is a prompt injection attack." | |
}, | |
) | |
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet | |
if self.prompt_injection_params.vector_db_check is True: | |
pass | |
else: | |
is_prompt_attack = self.check_user_input_similarity( | |
user_input=formatted_prompt | |
) | |
if is_prompt_attack is True: | |
raise HTTPException( | |
status_code=400, | |
detail={ | |
"error": "Rejected message. This is a prompt injection attack." | |
}, | |
) | |
return data | |
except HTTPException as e: | |
if ( | |
e.status_code == 400 | |
and isinstance(e.detail, dict) | |
and "error" in e.detail # type: ignore | |
and self.prompt_injection_params is not None | |
and self.prompt_injection_params.reject_as_response | |
): | |
return e.detail.get("error") | |
raise e | |
except Exception as e: | |
verbose_proxy_logger.exception( | |
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( | |
str(e) | |
) | |
) | |
async def async_moderation_hook( # type: ignore | |
self, | |
data: dict, | |
user_api_key_dict: UserAPIKeyAuth, | |
call_type: Literal[ | |
"completion", | |
"embeddings", | |
"image_generation", | |
"moderation", | |
"audio_transcription", | |
], | |
) -> Optional[bool]: | |
self.print_verbose( | |
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}" | |
) | |
if self.prompt_injection_params is None: | |
return None | |
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore | |
is_prompt_attack = False | |
prompt_injection_system_prompt = getattr( | |
self.prompt_injection_params, | |
"llm_api_system_prompt", | |
prompt_injection_detection_default_pt(), | |
) | |
# 3. check if llm api check turned on | |
if ( | |
self.prompt_injection_params.llm_api_check is True | |
and self.prompt_injection_params.llm_api_name is not None | |
and self.llm_router is not None | |
): | |
# make a call to the llm api | |
response = await self.llm_router.acompletion( | |
model=self.prompt_injection_params.llm_api_name, | |
messages=[ | |
{ | |
"role": "system", | |
"content": prompt_injection_system_prompt, | |
}, | |
{"role": "user", "content": formatted_prompt}, | |
], | |
) | |
self.print_verbose(f"Received LLM Moderation response: {response}") | |
self.print_verbose( | |
f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}" | |
) | |
if isinstance(response, litellm.ModelResponse) and isinstance( | |
response.choices[0], litellm.Choices | |
): | |
if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore | |
is_prompt_attack = True | |
if is_prompt_attack is True: | |
raise HTTPException( | |
status_code=400, | |
detail={ | |
"error": "Rejected message. This is a prompt injection attack." | |
}, | |
) | |
return is_prompt_attack | |