import argparse
import glob
import json
import logging
import multiprocessing as mp
import os
import time
import uuid
from datetime import timedelta
from functools import lru_cache
from typing import List, Union

import boto3
import gradio as gr
import requests
from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForSequenceClassification
from rebuff import Rebuff
from transformers import AutoTokenizer, Pipeline, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Client used to upload benchmark results to the HF dataset (token read from HF_TOKEN).
hf_api = HfApi(token=os.getenv("HF_TOKEN"))

# Number of worker processes used to query the providers in parallel.
# Could be raised to mp.cpu_count(); kept at 2 deliberately.
num_processes = 2

# Credentials / endpoints for the external providers; each is None when unset.
lakera_api_key = os.getenv("LAKERA_API_KEY")
sydelabs_api_key = os.getenv("SYDELABS_API_KEY")
rebuff_api_key = os.getenv("REBUFF_API_KEY")
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")

aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")


# The file loads three distinct classifiers (deepset, ProtectAI v2, FMOps); the cache
# must hold at least that many entries for a repeated call with the same arguments
# to reuse the already-loaded pipeline instead of reloading it.
@lru_cache(maxsize=3)
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> Pipeline:
    """Load an ONNX sequence-classification model and wrap it in a CPU text-classification pipeline.

    Args:
        prompt_injection_ort_model: HF Hub repo id holding ``model.onnx`` and the tokenizer.
        subfolder: Optional subfolder inside the repo where the model/tokenizer live.

    Returns:
        A ``text-classification`` pipeline that truncates inputs to 512 tokens.
    """
    hf_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model, export=False, subfolder=subfolder, file_name="model.onnx"
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, subfolder=subfolder)
    # Restrict tokenizer outputs to the inputs the ONNX graph is given
    # (e.g. no token_type_ids).
    hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")
    return pipeline(
        "text-classification",
        model=hf_model,
        tokenizer=hf_tokenizer,
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )


def convert_elapsed_time(diff_time: float) -> float:
    """Return *diff_time* (seconds) rounded to 2 decimals, after normalizing to microsecond precision."""
    return round(timedelta(seconds=diff_time).total_seconds(), 2)


deepset_classifier = init_prompt_injection_model(
    "ProtectAI/deberta-v3-base-injection-onnx"
)  # ONNX version of deepset/deberta-v3-base-injection
protectai_v2_classifier = init_prompt_injection_model(
    "ProtectAI/deberta-v3-base-prompt-injection-v2", "onnx"
)
fmops_classifier = init_prompt_injection_model(
    "ProtectAI/fmops-distilbert-prompt-injection-onnx"
)  # ONNX version of fmops/distilbert-prompt-injection

# Upper bound (seconds) for each HTTP call to an external provider, so a hung
# provider cannot stall a worker process and with it the whole benchmark run.
REQUEST_TIMEOUT_SECONDS = 30

# Failures that mean "this provider could not be evaluated":
# - RequestException: transport errors and timeouts.
# - ValueError: the base class of JSON decode errors (undecodable body).
# - KeyError / IndexError / TypeError: a response missing the expected fields,
#   e.g. an error payload.
_PROVIDER_ERRORS = (requests.RequestException, ValueError, KeyError, IndexError, TypeError)


def _post_json(url: str, payload: dict, headers: dict):
    """POST *payload* as JSON to *url* (with a timeout) and return the decoded JSON body."""
    response = requests.post(url, json=payload, headers=headers, timeout=REQUEST_TIMEOUT_SECONDS)
    return response.json()


def detect_hf(
    prompt: str,
    threshold: float = 0.5,
    classifier=protectai_v2_classifier,
    label: str = "INJECTION",
) -> (bool, bool):
    """Classify *prompt* with a locally loaded HF text-classification pipeline.

    Args:
        prompt: Text to classify.
        threshold: The prompt is flagged when the injection score, rounded to
            2 decimals, is strictly greater than this value.
        classifier: Pipeline produced by ``init_prompt_injection_model``.
        label: Label the model uses for the injection class. When the top label
            differs, the injection score is taken as ``1 - score`` (this assumes
            a binary classifier).

    Returns:
        ``(processed, is_injection)``; ``(False, False)`` if the model call fails.
    """
    try:
        pi_result = classifier(prompt)
        top = pi_result[0]
        injection_score = round(top["score"] if top["label"] == label else 1 - top["score"], 2)
        logger.info(f"Prompt injection result from the HF model: {pi_result}")
        return True, injection_score > threshold
    except Exception as err:  # any model failure is reported as "not processed"
        logger.error(f"Failed to call HF model: {err}")
        return False, False


def detect_hf_protectai_v2(prompt: str) -> (bool, bool):
    """Run the ProtectAI v2 classifier on *prompt*."""
    return detect_hf(prompt, classifier=protectai_v2_classifier)


def detect_hf_deepset(prompt: str) -> (bool, bool):
    """Run the deepset classifier on *prompt*."""
    return detect_hf(prompt, classifier=deepset_classifier)


def detect_hf_fmops(prompt: str) -> (bool, bool):
    """Run the FMOps classifier on *prompt*; its injection class is labelled ``LABEL_1``."""
    return detect_hf(prompt, classifier=fmops_classifier, label="LABEL_1")


def detect_lakera(prompt: str) -> (bool, bool):
    """Query Lakera Guard; return ``(processed, flagged)`` or ``(False, False)`` on failure."""
    try:
        response_json = _post_json(
            "https://api.lakera.ai/v1/prompt_injection",
            {"input": prompt},
            {"Authorization": f"Bearer {lakera_api_key}"},
        )
        logger.info(f"Prompt injection result from Lakera: {response_json}")
        return True, response_json["results"][0]["flagged"]
    except _PROVIDER_ERRORS as err:
        logger.error(f"Failed to call Lakera API: {err}")
        return False, False


def detect_rebuff(prompt: str) -> (bool, bool):
    """Query Rebuff; return ``(processed, injection_detected)`` or ``(False, False)`` on failure."""
    try:
        rb = Rebuff(api_token=rebuff_api_key, api_url="https://www.rebuff.ai")
        result = rb.detect_injection(prompt)
        logger.info(f"Prompt injection result from Rebuff: {result}")
        return True, result.injectionDetected
    except Exception as err:
        logger.error(f"Failed to call Rebuff API: {err}")
        return False, False


def detect_azure(prompt: str) -> (bool, bool):
    """Query Azure Content Safety jailbreak detection.

    Returns ``(processed, detected)``. Returns ``(False, False)`` when the call
    fails or the response has no ``jailbreakAnalysis`` section.
    """
    try:
        # The endpoint value is expected to end with "/"; the path is appended directly.
        response_json = _post_json(
            f"{azure_content_safety_endpoint}contentsafety/text:detectJailbreak?api-version=2023-10-15-preview",
            {"text": prompt},
            {"Ocp-Apim-Subscription-Key": azure_content_safety_key},
        )
        logger.info(f"Prompt injection result from Azure: {response_json}")
        if "jailbreakAnalysis" not in response_json:
            return False, False
        return True, response_json["jailbreakAnalysis"]["detected"]
    except _PROVIDER_ERRORS as err:
        logger.error(f"Failed to call Azure API: {err}")
        return False, False


def detect_aws_comprehend(prompt: str) -> (bool, bool):
    """Classify *prompt* with the AWS Comprehend prompt-safety classifier endpoint.

    The prompt is flagged when the highest-scoring class is ``UNSAFE_PROMPT``.
    Returns ``(False, False)`` when the call raises, returns a non-200 status,
    or returns no classes.
    """
    try:
        response = aws_comprehend_client.classify_document(
            EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
            Text=prompt,
        )
    except Exception as err:  # boto3 raises on transport/service errors
        logger.error(f"Failed to call AWS Comprehend API: {err}")
        return False, False

    logger.info(f"Prompt injection result from AWS Comprehend: {response}")
    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        logger.error(f"Failed to call AWS Comprehend API: {response}")
        return False, False

    classes = response.get("Classes") or []
    if not classes:
        logger.error(f"Failed to call AWS Comprehend API: {response}")
        return False, False
    top_class = max(classes, key=lambda cls: cls["Score"])
    return True, top_class["Name"] == "UNSAFE_PROMPT"


def detect_sydelabs(prompt: str) -> (bool, bool):
    """Query SydeLabs Guard.

    Returns ``(processed, risk)``. ``risk`` is the ``risk`` field of the
    ``PROMPT_INJECT`` category, or ``False`` when that category is absent.
    Returns ``(False, False)`` on failure.
    """
    try:
        response_json = _post_json(
            "https://guard.sydelabs.ai/api/v1/guard/generate-score",
            {"prompt": prompt},
            {
                # Only SydeLabs' own key is sent; another provider's credential
                # must never reach this endpoint.
                "Authorization": f"Bearer {sydelabs_api_key}",
                "X-Api-Key": sydelabs_api_key,
            },
        )
        logger.info(f"Prompt injection result from SydeLabs: {response_json}")
        # NOTE(review): the type of "risk" (bool vs. score/level) is not visible
        # here; it is passed through as the "is injection" value unchanged.
        prompt_injection_risk = next(
            (
                category["risk"]
                for category in response_json["category_scores"]
                if category["category"] == "PROMPT_INJECT"
            ),
            False,
        )
        return True, prompt_injection_risk
    except _PROVIDER_ERRORS as err:
        logger.error(f"Failed to call SydeLabs API: {err}")
        return False, False


# Display name -> detector. Insertion order is the row order of the results table.
# Rebuff and AWS Comprehend are currently disabled.
detection_providers = {
    "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
    "Deepset (HF model)": detect_hf_deepset,
    "FMOps (HF model)": detect_hf_fmops,
    "Lakera Guard": detect_lakera,
    # "Rebuff": detect_rebuff,
    "Azure Content Safety": detect_azure,
    "SydeLabs": detect_sydelabs,
    # "AWS Comprehend": detect_aws_comprehend,
}


def is_detected(provider: str, prompt: str) -> (str, bool, bool, float):
    """Run one provider on *prompt* and time it.

    Returns:
        ``(provider, processed, is_injection, latency_seconds)``. An unknown
        provider yields ``(provider, False, False, 0.0)``, which has the same
        shape as every other results row.
    """
    detector = detection_providers.get(provider)
    if detector is None:
        logger.warning(f"Provider {provider} is not supported")
        return provider, False, False, 0.0

    start_time = time.monotonic()
    request_result, is_injection = detector(prompt)
    end_time = time.monotonic()
    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)


def execute(prompt: str) -> List[Union[str, bool, float]]:
    """Run every enabled provider on *prompt* in a process pool and record the outcome.

    Returns one ``(provider, processed, is_injection, latency)`` row per provider.
    The rows are also uploaded to the benchmark dataset. An upload failure is
    logged and does not discard the computed results.
    """
    with mp.Pool(processes=num_processes) as pool:
        results = pool.starmap(
            is_detected, [(provider, prompt) for provider in detection_providers]
        )

    # Persist the prompt together with its results.
    fileobj = json.dumps(
        {"prompt": prompt, "results": results}, indent=2, ensure_ascii=False
    ).encode("utf-8")
    result_path = f"/prompts/train/{uuid.uuid4()}.json"
    try:
        hf_api.upload_file(
            path_or_fileobj=fileobj,
            path_in_repo=result_path,
            repo_id="ProtectAI/prompt-injection-benchmark",
            repo_type="dataset",
        )
    except Exception as err:  # storage is best-effort; still show the user the results
        logger.error(f"Failed to store prompt results: {err}")
    else:
        logger.info(f"Stored prompt: {prompt}")
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    # Unrecognized arguments are ignored rather than rejected.
    args, _ = parser.parse_known_args()

    # Sorted so the example list has a stable order across runs.
    example_files = sorted(
        glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    )
    examples = []
    for example_file in example_files:
        # The context manager closes each file instead of leaking the handle.
        with open(example_file, encoding="utf-8") as f:
            examples.append(f.read())

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is processed successfully?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Solutions Benchmark",
        # NOTE(review): the last two lines read like link texts, but no link
        # targets are present in this source; re-add them if intended.
        description=(
            "This interface aims to benchmark the known prompt injection detection providers. "
            "The results are stored in the private dataset for further analysis and improvements. "
            "This interface is for research purposes only."
            "\n\n"
            "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs."
            "\n\n"
            "Join our Slack community to discuss LLM Security\n"
            "Secure your LLM interactions with LLM Guard"
        ),
        # One value per input component (the single Prompt textbox).
        examples=[[example] for example in examples],
        cache_examples=True,
        allow_flagging="never",
        concurrency_limit=1,
    ).launch(server_name=args.url, server_port=args.port)