Spaces:
Runtime error
Runtime error
* small model support
Browse files
- README.md +1 -1
- app.py +22 -24
- requirements.txt +7 -7
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: π
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 4.
|
8 |
pinned: true
|
9 |
license: apache-2.0
|
10 |
---
|
|
|
4 |
colorFrom: yellow
|
5 |
colorTo: gray
|
6 |
sdk: gradio
|
7 |
+
sdk_version: 4.39.0
|
8 |
pinned: true
|
9 |
license: apache-2.0
|
10 |
---
|
app.py
CHANGED
@@ -21,7 +21,8 @@ from transformers import AutoTokenizer, pipeline
|
|
21 |
logging.basicConfig(level=logging.INFO)
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
-
|
|
|
25 |
num_processes = 2 # mp.cpu_count()
|
26 |
|
27 |
lakera_api_key = os.getenv("LAKERA_API_KEY")
|
@@ -35,9 +36,15 @@ aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-
|
|
35 |
@lru_cache(maxsize=2)
|
36 |
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
|
37 |
hf_model = ORTModelForSequenceClassification.from_pretrained(
|
38 |
-
prompt_injection_ort_model,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
)
|
40 |
-
hf_tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, subfolder=subfolder)
|
41 |
hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
42 |
|
43 |
logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")
|
@@ -58,14 +65,17 @@ def convert_elapsed_time(diff_time) -> float:
|
|
58 |
|
59 |
|
60 |
deepset_classifier = init_prompt_injection_model(
|
61 |
-
"
|
62 |
) # ONNX version of deepset/deberta-v3-base-injection
|
63 |
protectai_v2_classifier = init_prompt_injection_model(
|
64 |
-
"
|
65 |
)
|
66 |
fmops_classifier = init_prompt_injection_model(
|
67 |
-
"
|
68 |
) # ONNX version of fmops/distilbert-prompt-injection
|
|
|
|
|
|
|
69 |
|
70 |
|
71 |
def detect_hf(
|
@@ -93,6 +103,10 @@ def detect_hf_protectai_v2(prompt: str) -> (bool, bool):
|
|
93 |
return detect_hf(prompt, classifier=protectai_v2_classifier)
|
94 |
|
95 |
|
|
|
|
|
|
|
|
|
96 |
def detect_hf_deepset(prompt: str) -> (bool, bool):
|
97 |
return detect_hf(prompt, classifier=deepset_classifier)
|
98 |
|
@@ -153,23 +167,6 @@ def detect_aws_comprehend(prompt: str) -> (bool, bool):
|
|
153 |
EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
|
154 |
Text=prompt,
|
155 |
)
|
156 |
-
response = {
|
157 |
-
"Classes": [
|
158 |
-
{"Name": "SAFE_PROMPT", "Score": 0.9010000228881836},
|
159 |
-
{"Name": "UNSAFE_PROMPT", "Score": 0.0989999994635582},
|
160 |
-
],
|
161 |
-
"ResponseMetadata": {
|
162 |
-
"RequestId": "e8900fe1-3346-45c0-bad3-007b2840865a",
|
163 |
-
"HTTPStatusCode": 200,
|
164 |
-
"HTTPHeaders": {
|
165 |
-
"x-amzn-requestid": "e8900fe1-3346-45c0-bad3-007b2840865a",
|
166 |
-
"content-type": "application/x-amz-json-1.1",
|
167 |
-
"content-length": "115",
|
168 |
-
"date": "Mon, 19 Feb 2024 08:34:43 GMT",
|
169 |
-
},
|
170 |
-
"RetryAttempts": 0,
|
171 |
-
},
|
172 |
-
}
|
173 |
logger.info(f"Prompt injection result from AWS Comprehend: {response}")
|
174 |
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
|
175 |
logger.error(f"Failed to call AWS Comprehend API: {response}")
|
@@ -209,13 +206,14 @@ def detect_sydelabs(prompt: str) -> (bool, bool):
|
|
209 |
|
210 |
detection_providers = {
|
211 |
"ProtectAI v2 (HF model)": detect_hf_protectai_v2,
|
|
|
212 |
"Deepset (HF model)": detect_hf_deepset,
|
213 |
"FMOps (HF model)": detect_hf_fmops,
|
214 |
"Lakera Guard": detect_lakera,
|
215 |
# "Rebuff": detect_rebuff,
|
216 |
"Azure Content Safety": detect_azure,
|
217 |
"SydeLabs": detect_sydelabs,
|
218 |
-
|
219 |
}
|
220 |
|
221 |
|
|
|
21 |
logging.basicConfig(level=logging.INFO)
|
22 |
logger = logging.getLogger(__name__)
|
23 |
|
24 |
+
hf_token = os.getenv("HF_TOKEN")
|
25 |
+
hf_api = HfApi(token=hf_token)
|
26 |
num_processes = 2 # mp.cpu_count()
|
27 |
|
28 |
lakera_api_key = os.getenv("LAKERA_API_KEY")
|
|
|
36 |
@lru_cache(maxsize=2)
|
37 |
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
|
38 |
hf_model = ORTModelForSequenceClassification.from_pretrained(
|
39 |
+
prompt_injection_ort_model,
|
40 |
+
export=False,
|
41 |
+
subfolder=subfolder,
|
42 |
+
file_name="model.onnx",
|
43 |
+
token=hf_token,
|
44 |
+
)
|
45 |
+
hf_tokenizer = AutoTokenizer.from_pretrained(
|
46 |
+
prompt_injection_ort_model, subfolder=subfolder, token=hf_token
|
47 |
)
|
|
|
48 |
hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
|
49 |
|
50 |
logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")
|
|
|
65 |
|
66 |
|
67 |
deepset_classifier = init_prompt_injection_model(
|
68 |
+
"protectai/deberta-v3-base-injection-onnx"
|
69 |
) # ONNX version of deepset/deberta-v3-base-injection
|
70 |
protectai_v2_classifier = init_prompt_injection_model(
|
71 |
+
"protectai/deberta-v3-base-prompt-injection-v2", "onnx"
|
72 |
)
|
73 |
fmops_classifier = init_prompt_injection_model(
|
74 |
+
"protectai/fmops-distilbert-prompt-injection-onnx"
|
75 |
) # ONNX version of fmops/distilbert-prompt-injection
|
76 |
+
protectai_v2_small_classifier = init_prompt_injection_model(
|
77 |
+
"protectai/deberta-v3-small-prompt-injection-v2", "onnx"
|
78 |
+
) # ONNX version of protectai/deberta-v3-small-prompt-injection-v2
|
79 |
|
80 |
|
81 |
def detect_hf(
|
|
|
103 |
return detect_hf(prompt, classifier=protectai_v2_classifier)
|
104 |
|
105 |
|
106 |
+
def detect_hf_protectai_v2_small(prompt: str) -> (bool, bool):
|
107 |
+
return detect_hf(prompt, classifier=protectai_v2_small_classifier)
|
108 |
+
|
109 |
+
|
110 |
def detect_hf_deepset(prompt: str) -> (bool, bool):
|
111 |
return detect_hf(prompt, classifier=deepset_classifier)
|
112 |
|
|
|
167 |
EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
|
168 |
Text=prompt,
|
169 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
170 |
logger.info(f"Prompt injection result from AWS Comprehend: {response}")
|
171 |
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
|
172 |
logger.error(f"Failed to call AWS Comprehend API: {response}")
|
|
|
206 |
|
207 |
detection_providers = {
|
208 |
"ProtectAI v2 (HF model)": detect_hf_protectai_v2,
|
209 |
+
"ProtectAI v2 Small (HF model)": detect_hf_protectai_v2_small,
|
210 |
"Deepset (HF model)": detect_hf_deepset,
|
211 |
"FMOps (HF model)": detect_hf_fmops,
|
212 |
"Lakera Guard": detect_lakera,
|
213 |
# "Rebuff": detect_rebuff,
|
214 |
"Azure Content Safety": detect_azure,
|
215 |
"SydeLabs": detect_sydelabs,
|
216 |
+
"AWS Comprehend": detect_aws_comprehend,
|
217 |
}
|
218 |
|
219 |
|
requirements.txt
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
-
boto3==1.34.
|
2 |
-
gradio==4.
|
3 |
-
huggingface_hub==0.
|
4 |
-
onnxruntime==1.
|
5 |
-
optimum[onnxruntime]==1.
|
6 |
rebuff==0.1.1
|
7 |
-
requests==2.
|
8 |
-
transformers==4.
|
|
|
1 |
+
boto3==1.34.146
|
2 |
+
gradio==4.39.0
|
3 |
+
huggingface_hub==0.24.0
|
4 |
+
onnxruntime==1.18.1
|
5 |
+
optimum[onnxruntime]==1.21.2
|
6 |
rebuff==0.1.1
|
7 |
+
requests==2.32.3
|
8 |
+
transformers==4.42.4
|